Add encryption option
This commit is contained in:
2
.idea/kotlinc.xml
generated
2
.idea/kotlinc.xml
generated
@@ -1,6 +1,6 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="KotlinJpsPluginSettings">
|
||||
<option name="version" value="2.0.0-RC2" />
|
||||
<option name="version" value="2.0.0" />
|
||||
</component>
|
||||
</project>
|
||||
@@ -1,13 +1,13 @@
|
||||
import org.gradle.model.internal.core.ModelNodes.withType
|
||||
|
||||
plugins {
|
||||
kotlin("jvm") version "2.0.0-RC2"
|
||||
kotlin("jvm") version "2.0.0"
|
||||
id("maven-publish")
|
||||
id("signing")
|
||||
}
|
||||
|
||||
group = "nl.astraeus"
|
||||
version = "1.0.0-SNAPSHOT"
|
||||
version = "1.1.2-SNAPSHOT"
|
||||
|
||||
repositories {
|
||||
mavenCentral()
|
||||
|
||||
@@ -5,6 +5,7 @@ import java.io.ObjectInputStream
|
||||
import java.io.ObjectOutputStream
|
||||
import java.io.Serializable
|
||||
import java.text.DecimalFormat
|
||||
import java.util.*
|
||||
import java.util.concurrent.ConcurrentHashMap
|
||||
import java.util.concurrent.atomic.AtomicLong
|
||||
import kotlin.reflect.KClass
|
||||
@@ -39,6 +40,8 @@ class Action(
|
||||
class Datastore(
|
||||
private val directory: File,
|
||||
val enableOptimisticLocking: Boolean = false,
|
||||
val decryptionKey: String? = null,
|
||||
val encryptionKey: String? = null,
|
||||
indexes: Array<PersistableIndex> = arrayOf(),
|
||||
) {
|
||||
private val fileManager = FileManager(directory)
|
||||
@@ -109,7 +112,7 @@ class Datastore(
|
||||
val (lastSnapshot, lastSnapshotFile) = fileManager.findLastSnapshot()
|
||||
|
||||
if (lastSnapshotFile != null) {
|
||||
ObjectInputStream(lastSnapshotFile.inputStream()).use { ois ->
|
||||
ObjectInputStream(DecryptingInputStream(lastSnapshotFile.inputStream(), decryptionKey)).use { ois ->
|
||||
readSnapshot(ois)
|
||||
}
|
||||
}
|
||||
@@ -117,7 +120,7 @@ class Datastore(
|
||||
val transactions = fileManager.findTransactionsAfter(lastSnapshot ?: 0L)
|
||||
|
||||
transactions?.forEach { file ->
|
||||
ObjectInputStream(file.inputStream()).use { ois ->
|
||||
ObjectInputStream(DecryptingInputStream(file.inputStream(), decryptionKey)).use { ois ->
|
||||
readTransaction(ois)
|
||||
}
|
||||
}
|
||||
@@ -126,7 +129,7 @@ class Datastore(
|
||||
Logger.debug("Loaded transactions in %6.3fms", ((System.nanoTime() - start) / 1_000_000f))
|
||||
}
|
||||
|
||||
private fun execute(actions: Set<Action>) {
|
||||
private fun execute(actions: List<Action>) {
|
||||
synchronized(this) {
|
||||
if (enableOptimisticLocking) {
|
||||
for (action in actions) {
|
||||
@@ -211,7 +214,7 @@ class Datastore(
|
||||
return indexes[kClass.java]?.get(indexName)
|
||||
}
|
||||
|
||||
fun storeAndExecute(actions: Set<Action>) {
|
||||
fun storeAndExecute(actions: List<Action>) {
|
||||
if (actions.isNotEmpty()) {
|
||||
synchronized(this) {
|
||||
execute(actions)
|
||||
@@ -225,14 +228,22 @@ class Datastore(
|
||||
check(versionNumber == 1) { "Unsupported version number: $versionNumber" }
|
||||
val transactionNumber = ois.readLong()
|
||||
nextTransactionNumber = transactionNumber + 1
|
||||
val actions = ois.readObject() as Set<Action>
|
||||
execute(actions)
|
||||
val actions = ois.readObject()
|
||||
when(actions) {
|
||||
is Set<*> -> {
|
||||
val list = LinkedList(actions as Set<Action>)
|
||||
execute(list)
|
||||
}
|
||||
is List<*> -> {
|
||||
execute(actions as List<Action>)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun writeTransaction(actions: Set<Action>) {
|
||||
private fun writeTransaction(actions: List<Action>) {
|
||||
val number = transactionFormatter.format(nextTransactionNumber)
|
||||
val file = File(directory, "transaction-$number.trn-tmp")
|
||||
ObjectOutputStream(file.outputStream()).use { oos ->
|
||||
ObjectOutputStream(EncryptingOutputStream(file.outputStream(), encryptionKey)).use { oos ->
|
||||
// version number
|
||||
oos.writeInt(1)
|
||||
oos.writeLong(nextTransactionNumber++)
|
||||
@@ -246,7 +257,7 @@ class Datastore(
|
||||
synchronized(this) {
|
||||
val number = transactionFormatter.format(nextTransactionNumber)
|
||||
val file = File(directory, "transaction-$number.snp-tmp")
|
||||
ObjectOutputStream(file.outputStream()).use { oos ->
|
||||
ObjectOutputStream(EncryptingOutputStream(file.outputStream(), encryptionKey)).use { oos ->
|
||||
// version number
|
||||
oos.writeInt(1)
|
||||
oos.writeLong(nextTransactionNumber++)
|
||||
@@ -270,7 +281,8 @@ class Datastore(
|
||||
val versionNumber = ois.readInt()
|
||||
check(versionNumber == 1) { "Unsupported version number: $versionNumber" }
|
||||
nextTransactionNumber = ois.readLong() + 1
|
||||
data = ois.readObject() as MutableMap<Class<*>, TypeData>
|
||||
val dataObj = ois.readObject()
|
||||
data = dataObj as ConcurrentHashMap<Class<*>, TypeData>
|
||||
|
||||
val foundIndexes = mutableMapOf<Class<*>, MutableList<String>>()
|
||||
val numberOfClassesWithIndex = ois.readInt()
|
||||
|
||||
150
src/main/kotlin/nl/astraeus/persistence/Encryption.kt
Normal file
150
src/main/kotlin/nl/astraeus/persistence/Encryption.kt
Normal file
@@ -0,0 +1,150 @@
|
||||
@file:OptIn(ExperimentalEncodingApi::class, ExperimentalEncodingApi::class)
|
||||
|
||||
package nl.astraeus.persistence
|
||||
|
||||
import java.io.ByteArrayOutputStream
|
||||
import java.io.InputStream
|
||||
import java.io.OutputStream
|
||||
import java.security.SecureRandom
|
||||
import javax.crypto.Cipher
|
||||
import javax.crypto.KeyGenerator
|
||||
import javax.crypto.SecretKey
|
||||
import javax.crypto.spec.IvParameterSpec
|
||||
import javax.crypto.spec.SecretKeySpec
|
||||
import kotlin.io.encoding.Base64
|
||||
import kotlin.io.encoding.ExperimentalEncodingApi
|
||||
|
||||
class Encryptor(
|
||||
base64EncryptionKey: String?,
|
||||
base64DecryptionKey: String?,
|
||||
) {
|
||||
private var decryptionKey: SecretKey? = null
|
||||
private var encryptionKey: SecretKey? = null
|
||||
|
||||
init {
|
||||
if (base64EncryptionKey?.isNotEmpty() == true) {
|
||||
encryptionKey = SecretKeySpec(Base64.UrlSafe.decode(base64EncryptionKey), "AES")
|
||||
}
|
||||
if (base64DecryptionKey?.isNotEmpty() == true) {
|
||||
decryptionKey = SecretKeySpec(Base64.UrlSafe.decode(base64DecryptionKey), "AES")
|
||||
}
|
||||
}
|
||||
|
||||
fun encrypt(data: ByteArray): ByteArray {
|
||||
if (encryptionKey == null) {
|
||||
return data
|
||||
}
|
||||
|
||||
val prePaddedData = ByteArray(16) + data
|
||||
val cipher = Cipher.getInstance("AES/CBC/PKCS5Padding")
|
||||
|
||||
// Generate a new IV (Initialization Vector)
|
||||
val secureRandom = SecureRandom()
|
||||
val iv = ByteArray(cipher.blockSize)
|
||||
secureRandom.nextBytes(iv)
|
||||
val ivParams = IvParameterSpec(iv)
|
||||
|
||||
cipher.init(Cipher.ENCRYPT_MODE, encryptionKey, ivParams)
|
||||
return cipher.doFinal(prePaddedData)
|
||||
}
|
||||
|
||||
fun decrypt(data: ByteArray): ByteArray {
|
||||
if (decryptionKey == null) {
|
||||
return data
|
||||
}
|
||||
|
||||
val cipher = Cipher.getInstance("AES/CBC/PKCS5Padding")
|
||||
val secureRandom = SecureRandom()
|
||||
val iv = ByteArray(cipher.blockSize)
|
||||
secureRandom.nextBytes(iv)
|
||||
cipher.init(Cipher.DECRYPT_MODE, decryptionKey, IvParameterSpec(iv))
|
||||
val completeData = cipher.doFinal(data)
|
||||
|
||||
return completeData.sliceArray(16 until completeData.size)
|
||||
}
|
||||
}
|
||||
|
||||
fun generateBase64Key(): String {
|
||||
val keyGen: KeyGenerator = KeyGenerator.getInstance("AES")
|
||||
keyGen.init(256) // for AES-256
|
||||
val secretKey: SecretKey = keyGen.generateKey()
|
||||
|
||||
return Base64.UrlSafe.encode(secretKey.encoded)
|
||||
}
|
||||
|
||||
/*
|
||||
object Encryption {
|
||||
var encryptor = Encryptor(
|
||||
System.getenv().getOrDefault("SPK_ENCRYPTION_KEY", ""),
|
||||
System.getenv().getOrDefault("SPK_DECRYPTION_KEY", "")
|
||||
)
|
||||
}
|
||||
*/
|
||||
|
||||
class DecryptingInputStream(
|
||||
val input: InputStream,
|
||||
val base64DecryptionKey: String?
|
||||
) : InputStream() {
|
||||
val bytes: ByteArray
|
||||
var index = 0
|
||||
|
||||
init {
|
||||
val encryptedBytes = input.readAllBytes()
|
||||
if (base64DecryptionKey?.isBlank() == true) {
|
||||
bytes = encryptedBytes
|
||||
} else {
|
||||
val encryptor = Encryptor(
|
||||
base64EncryptionKey = null,
|
||||
base64DecryptionKey = base64DecryptionKey
|
||||
)
|
||||
bytes = encryptor.decrypt(encryptedBytes)
|
||||
}
|
||||
}
|
||||
|
||||
/* override fun readAllBytes(): ByteArray {
|
||||
index = bytes.size
|
||||
return bytes
|
||||
}*/
|
||||
|
||||
override fun read(): Int {
|
||||
return if (index < bytes.size) {
|
||||
bytes[index++].toUByte().toInt()
|
||||
} else {
|
||||
-1
|
||||
}
|
||||
}
|
||||
|
||||
override fun close() {
|
||||
input.close()
|
||||
}
|
||||
}
|
||||
|
||||
class EncryptingOutputStream(
|
||||
val output: OutputStream,
|
||||
val base64EncryptionKey: String?
|
||||
) : OutputStream() {
|
||||
val baos = ByteArrayOutputStream()
|
||||
|
||||
override fun write(b: Int) {
|
||||
baos.write(b)
|
||||
}
|
||||
|
||||
override fun close() {
|
||||
if (base64EncryptionKey?.isBlank() == true) {
|
||||
output.write(baos.toByteArray())
|
||||
} else {
|
||||
val encryptor = Encryptor(
|
||||
base64EncryptionKey = base64EncryptionKey,
|
||||
base64DecryptionKey = null
|
||||
)
|
||||
val encryptedBytes = encryptor.encrypt(baos.toByteArray())
|
||||
output.write(encryptedBytes)
|
||||
}
|
||||
output.flush()
|
||||
output.close()
|
||||
}
|
||||
|
||||
override fun flush() {
|
||||
// no flush
|
||||
}
|
||||
}
|
||||
@@ -10,10 +10,18 @@ fun currentTransaction(): Transaction? {
|
||||
|
||||
class Persistent(
|
||||
directory: File,
|
||||
indexes: Array<PersistableIndex> = arrayOf(),
|
||||
enableOptimisticLocking: Boolean = false,
|
||||
decryptionKey: String? = null,
|
||||
encryptionKey: String? = null,
|
||||
indexes: Array<PersistableIndex> = arrayOf(),
|
||||
) {
|
||||
val datastore: Datastore = Datastore(directory, enableOptimisticLocking, indexes)
|
||||
val datastore: Datastore = Datastore(
|
||||
directory,
|
||||
enableOptimisticLocking,
|
||||
decryptionKey,
|
||||
encryptionKey,
|
||||
indexes
|
||||
)
|
||||
|
||||
fun <T> query(block: Query.() -> T): T {
|
||||
var cleanup = false
|
||||
|
||||
@@ -99,7 +99,7 @@ open class Query(
|
||||
class Transaction(
|
||||
persistent: Persistent,
|
||||
) : Query(persistent), Serializable {
|
||||
private val actions = mutableSetOf<Action>()
|
||||
private val actions = ArrayList<Action>()
|
||||
|
||||
fun store(obj: Persistable) {
|
||||
if (obj.id == 0L) {
|
||||
|
||||
@@ -2,9 +2,11 @@ package nl.astraeus.persistence
|
||||
|
||||
import java.io.File
|
||||
import java.io.ObjectInputStream
|
||||
import java.util.*
|
||||
|
||||
class TransactionLog(
|
||||
directory: File,
|
||||
val decryptionKey: String? = null,
|
||||
) {
|
||||
val fileManager = FileManager(directory)
|
||||
|
||||
@@ -14,14 +16,15 @@ class TransactionLog(
|
||||
|
||||
printer("Snapshot:")
|
||||
snapshot?.inputStream()?.use { input ->
|
||||
ObjectInputStream(input).use { ois ->
|
||||
ObjectInputStream(DecryptingInputStream(input, decryptionKey)).use { ois ->
|
||||
val versionNumber = ois.readInt()
|
||||
check(versionNumber == 1) {
|
||||
"Unsupported version number: $versionNumber"
|
||||
}
|
||||
val transactionNumber = ois.readLong()
|
||||
printer("[$versionNumber] $transactionNumber")
|
||||
val data = ois.readObject() as MutableMap<Class<*>, TypeData>
|
||||
val dataObj = ois.readObject()
|
||||
val data = dataObj as MutableMap<Class<*>, TypeData>
|
||||
printer("Data:")
|
||||
printer("\tClasses:")
|
||||
for ((cls, entries) in data.entries) {
|
||||
@@ -35,15 +38,26 @@ class TransactionLog(
|
||||
printer("Transactions:")
|
||||
transactions?.forEach { transaction ->
|
||||
transaction.inputStream().use { input ->
|
||||
ObjectInputStream(input).use { ois ->
|
||||
ObjectInputStream(DecryptingInputStream(input, decryptionKey)).use { ois ->
|
||||
val versionNumber = ois.readInt()
|
||||
check(versionNumber == 1) {
|
||||
"Unsupported version number: $versionNumber"
|
||||
}
|
||||
val transactionNumber = ois.readLong()
|
||||
val actions = ois.readObject() as Set<Action>
|
||||
val actions = ois.readObject()
|
||||
val actionList = when(actions) {
|
||||
is Set<*> -> {
|
||||
LinkedList(actions as Set<Action>)
|
||||
}
|
||||
is List<*> -> {
|
||||
actions as List<Action>
|
||||
}
|
||||
else -> {
|
||||
emptyList()
|
||||
}
|
||||
}
|
||||
printer("\t[$transactionNumber]")
|
||||
for (action in actions) {
|
||||
for (action in actionList) {
|
||||
printer("\t\t- $action")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,7 +5,6 @@ import org.junit.jupiter.api.Test;
|
||||
import java.io.File;
|
||||
import java.util.List;
|
||||
|
||||
|
||||
public class TestPersistenceJava {
|
||||
|
||||
static class Person extends AbstractPersistable {
|
||||
@@ -64,14 +63,16 @@ public class TestPersistenceJava {
|
||||
|
||||
Persistent persistent = new Persistent(
|
||||
new File("data", "java-test"),
|
||||
false,
|
||||
null,
|
||||
null,
|
||||
new Index[] {
|
||||
new Index<>(
|
||||
Person.class,
|
||||
"name",
|
||||
(p) -> ((Person)p).getName()
|
||||
)
|
||||
},
|
||||
false
|
||||
}
|
||||
);
|
||||
|
||||
persistent.transaction((t) -> {
|
||||
|
||||
54
src/test/kotlin/nl/astraeus/persistence/EncryptionTest.kt
Normal file
54
src/test/kotlin/nl/astraeus/persistence/EncryptionTest.kt
Normal file
@@ -0,0 +1,54 @@
|
||||
package nl.astraeus.persistence
|
||||
|
||||
import org.junit.jupiter.api.Test
|
||||
|
||||
import org.junit.jupiter.api.Assertions.*
|
||||
import java.io.ByteArrayInputStream
|
||||
import java.io.ByteArrayOutputStream
|
||||
import java.security.SecureRandom
|
||||
|
||||
class EncryptionTest {
|
||||
|
||||
@Test
|
||||
fun testKeyGen() {
|
||||
println(generateBase64Key())
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testEncryptDecrypt() {
|
||||
val random = SecureRandom()
|
||||
val randomBytes = ByteArray(random.nextInt(10000))
|
||||
random.nextBytes(randomBytes)
|
||||
val base64Key = generateBase64Key()
|
||||
val encryptor = Encryptor(
|
||||
base64Key,
|
||||
base64Key,
|
||||
)
|
||||
|
||||
val encrypted = encryptor.encrypt(randomBytes)
|
||||
val decrypted = encryptor.decrypt(encrypted)
|
||||
|
||||
assertArrayEquals(randomBytes, decrypted)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `test encryption-decryption streams`() {
|
||||
val random = SecureRandom()
|
||||
val key = generateBase64Key()
|
||||
val baos = ByteArrayOutputStream()
|
||||
val encryptionStream = EncryptingOutputStream(baos, key)
|
||||
val bytes = ByteArray(random.nextInt(10000))
|
||||
random.nextBytes(bytes)
|
||||
|
||||
encryptionStream.use {
|
||||
it.write(bytes)
|
||||
}
|
||||
|
||||
val bais = ByteArrayInputStream(baos.toByteArray())
|
||||
val decryptingStream = DecryptingInputStream(bais, key)
|
||||
|
||||
val decryptedBytes = decryptingStream.readAllBytes()
|
||||
|
||||
assertArrayEquals(bytes, decryptedBytes)
|
||||
}
|
||||
}
|
||||
@@ -21,10 +21,12 @@ class TestOptimisticLocking {
|
||||
|
||||
val pst = Persistent(
|
||||
directory = File("data", "test-locking"),
|
||||
true,
|
||||
null,
|
||||
null,
|
||||
arrayOf(
|
||||
index<Person>("name") { p -> (p as? Person)?.name ?: "" },
|
||||
),
|
||||
true
|
||||
)
|
||||
|
||||
pst.transaction {
|
||||
|
||||
@@ -22,7 +22,7 @@ class TestPersistence {
|
||||
|
||||
val pst = Persistent(
|
||||
directory = File("data", "test-persistence"),
|
||||
arrayOf(
|
||||
indexes = arrayOf(
|
||||
index<Person>("name") { p -> (p as? Person)?.name ?: "" },
|
||||
index<Person>("age") { p -> (p as? Person)?.age ?: -1 },
|
||||
index<Person>("ageGt20") { p -> ((p as? Person)?.age ?: 0) > 20 },
|
||||
@@ -131,7 +131,7 @@ class TestPersistence {
|
||||
}
|
||||
}
|
||||
|
||||
//pst.snapshot()
|
||||
pst.snapshot()
|
||||
|
||||
pst.transaction {
|
||||
store(
|
||||
@@ -179,6 +179,6 @@ class TestPersistence {
|
||||
}
|
||||
|
||||
pst.datastore.printStatus()
|
||||
//pst.removeOldFiles()
|
||||
pst.removeOldFiles()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,8 +3,8 @@ package nl.astraeus.persistence
|
||||
import org.junit.jupiter.api.Test
|
||||
import java.io.File
|
||||
|
||||
|
||||
class TestPersistenceJavaInKotlin {
|
||||
|
||||
internal class Person(
|
||||
var name: String,
|
||||
var age: Int
|
||||
@@ -23,19 +23,21 @@ class TestPersistenceJavaInKotlin {
|
||||
|
||||
val persistent = Persistent(
|
||||
File("data", "java-kotlin-test"),
|
||||
arrayOf(
|
||||
enableOptimisticLocking = false,
|
||||
indexes = arrayOf(
|
||||
Index(
|
||||
Person::class,
|
||||
"name"
|
||||
) { p -> (p as Person).name }
|
||||
),
|
||||
false
|
||||
)
|
||||
|
||||
persistent.transaction {
|
||||
val person = find(Person::class.java, 1L)
|
||||
|
||||
if (person != null) {
|
||||
println("Person: ${person.name} is ${person.age} years old."
|
||||
println(
|
||||
"Person: ${person.name} is ${person.age} years old."
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -22,7 +22,7 @@ class TestThreaded {
|
||||
|
||||
val pst = Persistent(
|
||||
directory = File("data", "test-threaded"),
|
||||
arrayOf(
|
||||
indexes = arrayOf(
|
||||
index<Person>("name") { p -> (p as? Person)?.name ?: "" },
|
||||
index<Person>("age") { p -> (p as? Person)?.age ?: -1 },
|
||||
index<Person>("ageGt20") { p -> ((p as? Person)?.age ?: 0) > 20 },
|
||||
|
||||
Reference in New Issue
Block a user