From ea0d46164fbd6db5e2d67bd1cad34df6da2f874c Mon Sep 17 00:00:00 2001 From: rnentjes Date: Sun, 4 Aug 2024 12:13:31 +0200 Subject: [PATCH] Add encryption option --- .idea/kotlinc.xml | 2 +- build.gradle.kts | 4 +- .../nl/astraeus/persistence/Datastore.kt | 32 ++-- .../nl/astraeus/persistence/Encryption.kt | 150 ++++++++++++++++++ .../nl/astraeus/persistence/Persistent.kt | 12 +- .../nl/astraeus/persistence/Transaction.kt | 2 +- .../nl/astraeus/persistence/TransactionLog.kt | 24 ++- .../persistence/TestPersistenceJava.java | 7 +- .../nl/astraeus/persistence/EncryptionTest.kt | 54 +++++++ .../persistence/TestOptimisticLocking.kt | 4 +- .../astraeus/persistence/TestPersistence.kt | 6 +- .../TestPersistenceJavaInKotlin.kt | 10 +- .../nl/astraeus/persistence/TestThreaded.kt | 2 +- 13 files changed, 276 insertions(+), 33 deletions(-) create mode 100644 src/main/kotlin/nl/astraeus/persistence/Encryption.kt create mode 100644 src/test/kotlin/nl/astraeus/persistence/EncryptionTest.kt diff --git a/.idea/kotlinc.xml b/.idea/kotlinc.xml index 53bf319..6d0ee1c 100644 --- a/.idea/kotlinc.xml +++ b/.idea/kotlinc.xml @@ -1,6 +1,6 @@ - \ No newline at end of file diff --git a/build.gradle.kts b/build.gradle.kts index 59ab646..9e8ca1d 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -1,13 +1,13 @@ import org.gradle.model.internal.core.ModelNodes.withType plugins { - kotlin("jvm") version "2.0.0-RC2" + kotlin("jvm") version "2.0.0" id("maven-publish") id("signing") } group = "nl.astraeus" -version = "1.0.0-SNAPSHOT" +version = "1.1.2-SNAPSHOT" repositories { mavenCentral() diff --git a/src/main/kotlin/nl/astraeus/persistence/Datastore.kt b/src/main/kotlin/nl/astraeus/persistence/Datastore.kt index fdf0dc5..39b3a92 100644 --- a/src/main/kotlin/nl/astraeus/persistence/Datastore.kt +++ b/src/main/kotlin/nl/astraeus/persistence/Datastore.kt @@ -5,6 +5,7 @@ import java.io.ObjectInputStream import java.io.ObjectOutputStream import java.io.Serializable import java.text.DecimalFormat +import java.util.* import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.atomic.AtomicLong import kotlin.reflect.KClass @@ -39,6 +40,8 @@ class Action( class Datastore( private val directory: File, val enableOptimisticLocking: Boolean = false, + val decryptionKey: String? = null, + val encryptionKey: String? = null, indexes: Array = arrayOf(), ) { private val fileManager = FileManager(directory) @@ -109,7 +112,7 @@ class Datastore( val (lastSnapshot, lastSnapshotFile) = fileManager.findLastSnapshot() if (lastSnapshotFile != null) { - ObjectInputStream(lastSnapshotFile.inputStream()).use { ois -> + ObjectInputStream(DecryptingInputStream(lastSnapshotFile.inputStream(), decryptionKey)).use { ois -> readSnapshot(ois) } } @@ -117,7 +120,7 @@ class Datastore( val transactions = fileManager.findTransactionsAfter(lastSnapshot ?: 0L) transactions?.forEach { file -> - ObjectInputStream(file.inputStream()).use { ois -> + ObjectInputStream(DecryptingInputStream(file.inputStream(), decryptionKey)).use { ois -> readTransaction(ois) } } @@ -126,7 +129,7 @@ class Datastore( Logger.debug("Loaded transactions in %6.3fms", ((System.nanoTime() - start) / 1_000_000f)) } - private fun execute(actions: Set) { + private fun execute(actions: List) { synchronized(this) { if (enableOptimisticLocking) { for (action in actions) { @@ -211,7 +214,7 @@ class Datastore( return indexes[kClass.java]?.get(indexName) } - fun storeAndExecute(actions: Set) { + fun storeAndExecute(actions: List) { if (actions.isNotEmpty()) { synchronized(this) { execute(actions) @@ -225,14 +228,22 @@ class Datastore( check(versionNumber == 1) { "Unsupported version number: $versionNumber" } val transactionNumber = ois.readLong() nextTransactionNumber = transactionNumber + 1 - val actions = ois.readObject() as Set - execute(actions) + val actions = ois.readObject() + when(actions) { + is Set<*> -> { + val list = LinkedList(actions as Set) + execute(list) + } + is List<*> -> { + execute(actions as List) + } + } } - private fun writeTransaction(actions: Set) { + private fun writeTransaction(actions: List) { val number = transactionFormatter.format(nextTransactionNumber) val file = File(directory, "transaction-$number.trn-tmp") - ObjectOutputStream(file.outputStream()).use { oos -> + ObjectOutputStream(EncryptingOutputStream(file.outputStream(), encryptionKey)).use { oos -> // version number oos.writeInt(1) oos.writeLong(nextTransactionNumber++) @@ -246,7 +257,7 @@ class Datastore( synchronized(this) { val number = transactionFormatter.format(nextTransactionNumber) val file = File(directory, "transaction-$number.snp-tmp") - ObjectOutputStream(file.outputStream()).use { oos -> + ObjectOutputStream(EncryptingOutputStream(file.outputStream(), encryptionKey)).use { oos -> // version number oos.writeInt(1) oos.writeLong(nextTransactionNumber++) @@ -270,7 +281,8 @@ class Datastore( val versionNumber = ois.readInt() check(versionNumber == 1) { "Unsupported version number: $versionNumber" } nextTransactionNumber = ois.readLong() + 1 - data = ois.readObject() as MutableMap, TypeData> + val dataObj = ois.readObject() + data = dataObj as ConcurrentHashMap, TypeData> val foundIndexes = mutableMapOf, MutableList>() val numberOfClassesWithIndex = ois.readInt() diff --git a/src/main/kotlin/nl/astraeus/persistence/Encryption.kt b/src/main/kotlin/nl/astraeus/persistence/Encryption.kt new file mode 100644 index 0000000..fcfc319 --- /dev/null +++ b/src/main/kotlin/nl/astraeus/persistence/Encryption.kt @@ -0,0 +1,150 @@ +@file:OptIn(ExperimentalEncodingApi::class, ExperimentalEncodingApi::class) + +package nl.astraeus.persistence + +import java.io.ByteArrayOutputStream +import java.io.InputStream +import java.io.OutputStream +import java.security.SecureRandom +import javax.crypto.Cipher +import javax.crypto.KeyGenerator +import javax.crypto.SecretKey +import javax.crypto.spec.IvParameterSpec +import javax.crypto.spec.SecretKeySpec +import kotlin.io.encoding.Base64 +import kotlin.io.encoding.ExperimentalEncodingApi + +class Encryptor( + base64EncryptionKey: String?, + base64DecryptionKey: String?, +) { + private var decryptionKey: SecretKey? = null + private var encryptionKey: SecretKey? = null + + init { + if (base64EncryptionKey?.isNotEmpty() == true) { + encryptionKey = SecretKeySpec(Base64.UrlSafe.decode(base64EncryptionKey), "AES") + } + if (base64DecryptionKey?.isNotEmpty() == true) { + decryptionKey = SecretKeySpec(Base64.UrlSafe.decode(base64DecryptionKey), "AES") + } + } + + fun encrypt(data: ByteArray): ByteArray { + if (encryptionKey == null) { + return data + } + + val prePaddedData = ByteArray(16) + data + val cipher = Cipher.getInstance("AES/CBC/PKCS5Padding") + + // Generate a new IV (Initialization Vector) + val secureRandom = SecureRandom() + val iv = ByteArray(cipher.blockSize) + secureRandom.nextBytes(iv) + val ivParams = IvParameterSpec(iv) + + cipher.init(Cipher.ENCRYPT_MODE, encryptionKey, ivParams) + return cipher.doFinal(prePaddedData) + } + + fun decrypt(data: ByteArray): ByteArray { + if (decryptionKey == null) { + return data + } + + val cipher = Cipher.getInstance("AES/CBC/PKCS5Padding") + val secureRandom = SecureRandom() + val iv = ByteArray(cipher.blockSize) + secureRandom.nextBytes(iv) + cipher.init(Cipher.DECRYPT_MODE, decryptionKey, IvParameterSpec(iv)) + val completeData = cipher.doFinal(data) + + return completeData.sliceArray(16 until completeData.size) + } +} + +fun generateBase64Key(): String { + val keyGen: KeyGenerator = KeyGenerator.getInstance("AES") + keyGen.init(256) // for AES-256 + val secretKey: SecretKey = keyGen.generateKey() + + return Base64.UrlSafe.encode(secretKey.encoded) +} + +/* +object Encryption { + var encryptor = Encryptor( + System.getenv().getOrDefault("SPK_ENCRYPTION_KEY", ""), + System.getenv().getOrDefault("SPK_DECRYPTION_KEY", "") + ) +} +*/ + +class DecryptingInputStream( + val input: InputStream, + val base64DecryptionKey: String? +) : InputStream() { + val bytes: ByteArray + var index = 0 + + init { + val encryptedBytes = input.readAllBytes() + if (base64DecryptionKey?.isBlank() == true) { + bytes = encryptedBytes + } else { + val encryptor = Encryptor( + base64EncryptionKey = null, + base64DecryptionKey = base64DecryptionKey + ) + bytes = encryptor.decrypt(encryptedBytes) + } + } + +/* override fun readAllBytes(): ByteArray { + index = bytes.size + return bytes + }*/ + + override fun read(): Int { + return if (index < bytes.size) { + bytes[index++].toUByte().toInt() + } else { + -1 + } + } + + override fun close() { + input.close() + } +} + +class EncryptingOutputStream( + val output: OutputStream, + val base64EncryptionKey: String? +) : OutputStream() { + val baos = ByteArrayOutputStream() + + override fun write(b: Int) { + baos.write(b) + } + + override fun close() { + if (base64EncryptionKey?.isBlank() == true) { + output.write(baos.toByteArray()) + } else { + val encryptor = Encryptor( + base64EncryptionKey = base64EncryptionKey, + base64DecryptionKey = null + ) + val encryptedBytes = encryptor.encrypt(baos.toByteArray()) + output.write(encryptedBytes) + } + output.flush() + output.close() + } + + override fun flush() { + // no flush + } +} diff --git a/src/main/kotlin/nl/astraeus/persistence/Persistent.kt b/src/main/kotlin/nl/astraeus/persistence/Persistent.kt index 578bc25..0dadcb7 100644 --- a/src/main/kotlin/nl/astraeus/persistence/Persistent.kt +++ b/src/main/kotlin/nl/astraeus/persistence/Persistent.kt @@ -10,10 +10,18 @@ fun currentTransaction(): Transaction? { class Persistent( directory: File, - indexes: Array = arrayOf(), enableOptimisticLocking: Boolean = false, + decryptionKey: String? = null, + encryptionKey: String? = null, + indexes: Array = arrayOf(), ) { - val datastore: Datastore = Datastore(directory, enableOptimisticLocking, indexes) + val datastore: Datastore = Datastore( + directory, + enableOptimisticLocking, + decryptionKey, + encryptionKey, + indexes + ) fun query(block: Query.() -> T): T { var cleanup = false diff --git a/src/main/kotlin/nl/astraeus/persistence/Transaction.kt b/src/main/kotlin/nl/astraeus/persistence/Transaction.kt index 0c8c209..4eabc43 100644 --- a/src/main/kotlin/nl/astraeus/persistence/Transaction.kt +++ b/src/main/kotlin/nl/astraeus/persistence/Transaction.kt @@ -99,7 +99,7 @@ open class Query( class Transaction( persistent: Persistent, ) : Query(persistent), Serializable { - private val actions = mutableSetOf() + private val actions = ArrayList() fun store(obj: Persistable) { if (obj.id == 0L) { diff --git a/src/main/kotlin/nl/astraeus/persistence/TransactionLog.kt b/src/main/kotlin/nl/astraeus/persistence/TransactionLog.kt index 49dd071..95509ad 100644 --- a/src/main/kotlin/nl/astraeus/persistence/TransactionLog.kt +++ b/src/main/kotlin/nl/astraeus/persistence/TransactionLog.kt @@ -2,9 +2,11 @@ package nl.astraeus.persistence import java.io.File import java.io.ObjectInputStream +import java.util.* class TransactionLog( directory: File, + val decryptionKey: String? = null, ) { val fileManager = FileManager(directory) @@ -14,14 +16,15 @@ class TransactionLog( printer("Snapshot:") snapshot?.inputStream()?.use { input -> - ObjectInputStream(input).use { ois -> + ObjectInputStream(DecryptingInputStream(input, decryptionKey)).use { ois -> val versionNumber = ois.readInt() check(versionNumber == 1) { "Unsupported version number: $versionNumber" } val transactionNumber = ois.readLong() printer("[$versionNumber] $transactionNumber") - val data = ois.readObject() as MutableMap, TypeData> + val dataObj = ois.readObject() + val data = dataObj as MutableMap, TypeData> printer("Data:") printer("\tClasses:") for ((cls, entries) in data.entries) { @@ -35,15 +38,26 @@ class TransactionLog( printer("Transactions:") transactions?.forEach { transaction -> transaction.inputStream().use { input -> - ObjectInputStream(input).use { ois -> + ObjectInputStream(DecryptingInputStream(input, decryptionKey)).use { ois -> val versionNumber = ois.readInt() check(versionNumber == 1) { "Unsupported version number: $versionNumber" } val transactionNumber = ois.readLong() - val actions = ois.readObject() as Set + val actions = ois.readObject() + val actionList = when(actions) { + is Set<*> -> { + LinkedList(actions as Set) + } + is List<*> -> { + actions as List + } + else -> { + emptyList() + } + } printer("\t[$transactionNumber]") - for (action in actions) { + for (action in actionList) { printer("\t\t- $action") } } diff --git a/src/test/java/nl/astraeus/persistence/TestPersistenceJava.java b/src/test/java/nl/astraeus/persistence/TestPersistenceJava.java index cdd1516..f0e055c 100644 --- a/src/test/java/nl/astraeus/persistence/TestPersistenceJava.java +++ b/src/test/java/nl/astraeus/persistence/TestPersistenceJava.java @@ -5,7 +5,6 @@ import org.junit.jupiter.api.Test; import java.io.File; import java.util.List; - public class TestPersistenceJava { static class Person extends AbstractPersistable { @@ -64,14 +63,16 @@ public class TestPersistenceJava { Persistent persistent = new Persistent( new File("data", "java-test"), + false, + null, + null, new Index[] { new Index<>( Person.class, "name", (p) -> ((Person)p).getName() ) - }, - false + } ); persistent.transaction((t) -> { diff --git a/src/test/kotlin/nl/astraeus/persistence/EncryptionTest.kt b/src/test/kotlin/nl/astraeus/persistence/EncryptionTest.kt new file mode 100644 index 0000000..6dcbcd8 --- /dev/null +++ b/src/test/kotlin/nl/astraeus/persistence/EncryptionTest.kt @@ -0,0 +1,54 @@ +package nl.astraeus.persistence + +import org.junit.jupiter.api.Test + +import org.junit.jupiter.api.Assertions.* +import java.io.ByteArrayInputStream +import java.io.ByteArrayOutputStream +import java.security.SecureRandom + +class EncryptionTest { + + @Test + fun testKeyGen() { + println(generateBase64Key()) + } + + @Test + fun testEncryptDecrypt() { + val random = SecureRandom() + val randomBytes = ByteArray(random.nextInt(10000)) + random.nextBytes(randomBytes) + val base64Key = generateBase64Key() + val encryptor = Encryptor( + base64Key, + base64Key, + ) + + val encrypted = encryptor.encrypt(randomBytes) + val decrypted = encryptor.decrypt(encrypted) + + assertArrayEquals(randomBytes, decrypted) + } + + @Test + fun `test encryption-decryption streams`() { + val random = SecureRandom() + val key = generateBase64Key() + val baos = ByteArrayOutputStream() + val encryptionStream = EncryptingOutputStream(baos, key) + val bytes = ByteArray(random.nextInt(10000)) + random.nextBytes(bytes) + + encryptionStream.use { + it.write(bytes) + } + + val bais = ByteArrayInputStream(baos.toByteArray()) + val decryptingStream = DecryptingInputStream(bais, key) + + val decryptedBytes = decryptingStream.readAllBytes() + + assertArrayEquals(bytes, decryptedBytes) + } +} diff --git a/src/test/kotlin/nl/astraeus/persistence/TestOptimisticLocking.kt b/src/test/kotlin/nl/astraeus/persistence/TestOptimisticLocking.kt index 7349239..597b93e 100644 --- a/src/test/kotlin/nl/astraeus/persistence/TestOptimisticLocking.kt +++ b/src/test/kotlin/nl/astraeus/persistence/TestOptimisticLocking.kt @@ -21,10 +21,12 @@ class TestOptimisticLocking { val pst = Persistent( directory = File("data", "test-locking"), + true, + null, + null, arrayOf( index("name") { p -> (p as? Person)?.name ?: "" }, ), - true ) pst.transaction { diff --git a/src/test/kotlin/nl/astraeus/persistence/TestPersistence.kt b/src/test/kotlin/nl/astraeus/persistence/TestPersistence.kt index d1d6e11..32e736f 100644 --- a/src/test/kotlin/nl/astraeus/persistence/TestPersistence.kt +++ b/src/test/kotlin/nl/astraeus/persistence/TestPersistence.kt @@ -22,7 +22,7 @@ class TestPersistence { val pst = Persistent( directory = File("data", "test-persistence"), - arrayOf( + indexes = arrayOf( index("name") { p -> (p as? Person)?.name ?: "" }, index("age") { p -> (p as? Person)?.age ?: -1 }, index("ageGt20") { p -> ((p as? Person)?.age ?: 0) > 20 }, @@ -131,7 +131,7 @@ class TestPersistence { } } - //pst.snapshot() + pst.snapshot() pst.transaction { store( @@ -179,6 +179,6 @@ class TestPersistence { } pst.datastore.printStatus() - //pst.removeOldFiles() + pst.removeOldFiles() } } diff --git a/src/test/kotlin/nl/astraeus/persistence/TestPersistenceJavaInKotlin.kt b/src/test/kotlin/nl/astraeus/persistence/TestPersistenceJavaInKotlin.kt index 2fb072d..a07df50 100644 --- a/src/test/kotlin/nl/astraeus/persistence/TestPersistenceJavaInKotlin.kt +++ b/src/test/kotlin/nl/astraeus/persistence/TestPersistenceJavaInKotlin.kt @@ -3,8 +3,8 @@ package nl.astraeus.persistence import org.junit.jupiter.api.Test import java.io.File - class TestPersistenceJavaInKotlin { + internal class Person( var name: String, var age: Int @@ -23,19 +23,21 @@ class TestPersistenceJavaInKotlin { val persistent = Persistent( File("data", "java-kotlin-test"), - arrayOf( + enableOptimisticLocking = false, + indexes = arrayOf( Index( Person::class, "name" ) { p -> (p as Person).name } ), - false ) persistent.transaction { val person = find(Person::class.java, 1L) + if (person != null) { - println("Person: ${person.name} is ${person.age} years old." + println( + "Person: ${person.name} is ${person.age} years old." ) } } diff --git a/src/test/kotlin/nl/astraeus/persistence/TestThreaded.kt b/src/test/kotlin/nl/astraeus/persistence/TestThreaded.kt index d059932..28af980 100644 --- a/src/test/kotlin/nl/astraeus/persistence/TestThreaded.kt +++ b/src/test/kotlin/nl/astraeus/persistence/TestThreaded.kt @@ -22,7 +22,7 @@ class TestThreaded { val pst = Persistent( directory = File("data", "test-threaded"), - arrayOf( + indexes = arrayOf( index("name") { p -> (p as? Person)?.name ?: "" }, index("age") { p -> (p as? Person)?.age ?: -1 }, index("ageGt20") { p -> ((p as? Person)?.age ?: 0) > 20 },