diff --git a/.idea/kotlinc.xml b/.idea/kotlinc.xml
index 53bf319..6d0ee1c 100644
--- a/.idea/kotlinc.xml
+++ b/.idea/kotlinc.xml
@@ -1,6 +1,6 @@
-
+
\ No newline at end of file
diff --git a/build.gradle.kts b/build.gradle.kts
index 59ab646..9e8ca1d 100644
--- a/build.gradle.kts
+++ b/build.gradle.kts
@@ -1,13 +1,13 @@
import org.gradle.model.internal.core.ModelNodes.withType
plugins {
- kotlin("jvm") version "2.0.0-RC2"
+ kotlin("jvm") version "2.0.0"
id("maven-publish")
id("signing")
}
group = "nl.astraeus"
-version = "1.0.0-SNAPSHOT"
+version = "1.1.2-SNAPSHOT"
repositories {
mavenCentral()
diff --git a/src/main/kotlin/nl/astraeus/persistence/Datastore.kt b/src/main/kotlin/nl/astraeus/persistence/Datastore.kt
index fdf0dc5..39b3a92 100644
--- a/src/main/kotlin/nl/astraeus/persistence/Datastore.kt
+++ b/src/main/kotlin/nl/astraeus/persistence/Datastore.kt
@@ -5,6 +5,7 @@ import java.io.ObjectInputStream
import java.io.ObjectOutputStream
import java.io.Serializable
import java.text.DecimalFormat
+import java.util.*
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicLong
import kotlin.reflect.KClass
@@ -39,6 +40,8 @@ class Action(
class Datastore(
private val directory: File,
val enableOptimisticLocking: Boolean = false,
+ val decryptionKey: String? = null,
+ val encryptionKey: String? = null,
indexes: Array = arrayOf(),
) {
private val fileManager = FileManager(directory)
@@ -109,7 +112,7 @@ class Datastore(
val (lastSnapshot, lastSnapshotFile) = fileManager.findLastSnapshot()
if (lastSnapshotFile != null) {
- ObjectInputStream(lastSnapshotFile.inputStream()).use { ois ->
+ ObjectInputStream(DecryptingInputStream(lastSnapshotFile.inputStream(), decryptionKey)).use { ois ->
readSnapshot(ois)
}
}
@@ -117,7 +120,7 @@ class Datastore(
val transactions = fileManager.findTransactionsAfter(lastSnapshot ?: 0L)
transactions?.forEach { file ->
- ObjectInputStream(file.inputStream()).use { ois ->
+ ObjectInputStream(DecryptingInputStream(file.inputStream(), decryptionKey)).use { ois ->
readTransaction(ois)
}
}
@@ -126,7 +129,7 @@ class Datastore(
Logger.debug("Loaded transactions in %6.3fms", ((System.nanoTime() - start) / 1_000_000f))
}
- private fun execute(actions: Set) {
+ private fun execute(actions: List) {
synchronized(this) {
if (enableOptimisticLocking) {
for (action in actions) {
@@ -211,7 +214,7 @@ class Datastore(
return indexes[kClass.java]?.get(indexName)
}
- fun storeAndExecute(actions: Set) {
+ fun storeAndExecute(actions: List) {
if (actions.isNotEmpty()) {
synchronized(this) {
execute(actions)
@@ -225,14 +228,22 @@ class Datastore(
check(versionNumber == 1) { "Unsupported version number: $versionNumber" }
val transactionNumber = ois.readLong()
nextTransactionNumber = transactionNumber + 1
- val actions = ois.readObject() as Set
- execute(actions)
+ val actions = ois.readObject()
+ when(actions) {
+ is Set<*> -> {
+ val list = LinkedList(actions as Set)
+ execute(list)
+ }
+ is List<*> -> {
+ execute(actions as List)
+ }
+ }
}
- private fun writeTransaction(actions: Set) {
+ private fun writeTransaction(actions: List) {
val number = transactionFormatter.format(nextTransactionNumber)
val file = File(directory, "transaction-$number.trn-tmp")
- ObjectOutputStream(file.outputStream()).use { oos ->
+ ObjectOutputStream(EncryptingOutputStream(file.outputStream(), encryptionKey)).use { oos ->
// version number
oos.writeInt(1)
oos.writeLong(nextTransactionNumber++)
@@ -246,7 +257,7 @@ class Datastore(
synchronized(this) {
val number = transactionFormatter.format(nextTransactionNumber)
val file = File(directory, "transaction-$number.snp-tmp")
- ObjectOutputStream(file.outputStream()).use { oos ->
+ ObjectOutputStream(EncryptingOutputStream(file.outputStream(), encryptionKey)).use { oos ->
// version number
oos.writeInt(1)
oos.writeLong(nextTransactionNumber++)
@@ -270,7 +281,8 @@ class Datastore(
val versionNumber = ois.readInt()
check(versionNumber == 1) { "Unsupported version number: $versionNumber" }
nextTransactionNumber = ois.readLong() + 1
- data = ois.readObject() as MutableMap, TypeData>
+ val dataObj = ois.readObject()
+ data = dataObj as ConcurrentHashMap, TypeData>
val foundIndexes = mutableMapOf, MutableList>()
val numberOfClassesWithIndex = ois.readInt()
diff --git a/src/main/kotlin/nl/astraeus/persistence/Encryption.kt b/src/main/kotlin/nl/astraeus/persistence/Encryption.kt
new file mode 100644
index 0000000..fcfc319
--- /dev/null
+++ b/src/main/kotlin/nl/astraeus/persistence/Encryption.kt
@@ -0,0 +1,150 @@
+@file:OptIn(ExperimentalEncodingApi::class, ExperimentalEncodingApi::class)
+
+package nl.astraeus.persistence
+
+import java.io.ByteArrayOutputStream
+import java.io.InputStream
+import java.io.OutputStream
+import java.security.SecureRandom
+import javax.crypto.Cipher
+import javax.crypto.KeyGenerator
+import javax.crypto.SecretKey
+import javax.crypto.spec.IvParameterSpec
+import javax.crypto.spec.SecretKeySpec
+import kotlin.io.encoding.Base64
+import kotlin.io.encoding.ExperimentalEncodingApi
+
+class Encryptor(
+ base64EncryptionKey: String?,
+ base64DecryptionKey: String?,
+) {
+ private var decryptionKey: SecretKey? = null
+ private var encryptionKey: SecretKey? = null
+
+ init {
+ if (base64EncryptionKey?.isNotEmpty() == true) {
+ encryptionKey = SecretKeySpec(Base64.UrlSafe.decode(base64EncryptionKey), "AES")
+ }
+ if (base64DecryptionKey?.isNotEmpty() == true) {
+ decryptionKey = SecretKeySpec(Base64.UrlSafe.decode(base64DecryptionKey), "AES")
+ }
+ }
+
+ fun encrypt(data: ByteArray): ByteArray {
+ if (encryptionKey == null) {
+ return data
+ }
+
+ val prePaddedData = ByteArray(16) + data
+ val cipher = Cipher.getInstance("AES/CBC/PKCS5Padding")
+
+ // Generate a new IV (Initialization Vector)
+ val secureRandom = SecureRandom()
+ val iv = ByteArray(cipher.blockSize)
+ secureRandom.nextBytes(iv)
+ val ivParams = IvParameterSpec(iv)
+
+ cipher.init(Cipher.ENCRYPT_MODE, encryptionKey, ivParams)
+ return cipher.doFinal(prePaddedData)
+ }
+
+ fun decrypt(data: ByteArray): ByteArray {
+ if (decryptionKey == null) {
+ return data
+ }
+
+ val cipher = Cipher.getInstance("AES/CBC/PKCS5Padding")
+ val secureRandom = SecureRandom()
+ val iv = ByteArray(cipher.blockSize)
+ secureRandom.nextBytes(iv)
+ cipher.init(Cipher.DECRYPT_MODE, decryptionKey, IvParameterSpec(iv))
+ val completeData = cipher.doFinal(data)
+
+ return completeData.sliceArray(16 until completeData.size)
+ }
+}
+
+fun generateBase64Key(): String {
+ val keyGen: KeyGenerator = KeyGenerator.getInstance("AES")
+ keyGen.init(256) // for AES-256
+ val secretKey: SecretKey = keyGen.generateKey()
+
+ return Base64.UrlSafe.encode(secretKey.encoded)
+}
+
+/*
+object Encryption {
+ var encryptor = Encryptor(
+ System.getenv().getOrDefault("SPK_ENCRYPTION_KEY", ""),
+ System.getenv().getOrDefault("SPK_DECRYPTION_KEY", "")
+ )
+}
+*/
+
+class DecryptingInputStream(
+ val input: InputStream,
+ val base64DecryptionKey: String?
+) : InputStream() {
+ val bytes: ByteArray
+ var index = 0
+
+ init {
+ val encryptedBytes = input.readAllBytes()
+ if (base64DecryptionKey?.isBlank() == true) {
+ bytes = encryptedBytes
+ } else {
+ val encryptor = Encryptor(
+ base64EncryptionKey = null,
+ base64DecryptionKey = base64DecryptionKey
+ )
+ bytes = encryptor.decrypt(encryptedBytes)
+ }
+ }
+
+/* override fun readAllBytes(): ByteArray {
+ index = bytes.size
+ return bytes
+ }*/
+
+ override fun read(): Int {
+ return if (index < bytes.size) {
+ bytes[index++].toUByte().toInt()
+ } else {
+ -1
+ }
+ }
+
+ override fun close() {
+ input.close()
+ }
+}
+
+class EncryptingOutputStream(
+ val output: OutputStream,
+ val base64EncryptionKey: String?
+) : OutputStream() {
+ val baos = ByteArrayOutputStream()
+
+ override fun write(b: Int) {
+ baos.write(b)
+ }
+
+ override fun close() {
+ if (base64EncryptionKey?.isBlank() == true) {
+ output.write(baos.toByteArray())
+ } else {
+ val encryptor = Encryptor(
+ base64EncryptionKey = base64EncryptionKey,
+ base64DecryptionKey = null
+ )
+ val encryptedBytes = encryptor.encrypt(baos.toByteArray())
+ output.write(encryptedBytes)
+ }
+ output.flush()
+ output.close()
+ }
+
+ override fun flush() {
+ // no flush
+ }
+}
diff --git a/src/main/kotlin/nl/astraeus/persistence/Persistent.kt b/src/main/kotlin/nl/astraeus/persistence/Persistent.kt
index 578bc25..0dadcb7 100644
--- a/src/main/kotlin/nl/astraeus/persistence/Persistent.kt
+++ b/src/main/kotlin/nl/astraeus/persistence/Persistent.kt
@@ -10,10 +10,18 @@ fun currentTransaction(): Transaction? {
class Persistent(
directory: File,
- indexes: Array = arrayOf(),
enableOptimisticLocking: Boolean = false,
+ decryptionKey: String? = null,
+ encryptionKey: String? = null,
+ indexes: Array = arrayOf(),
) {
- val datastore: Datastore = Datastore(directory, enableOptimisticLocking, indexes)
+ val datastore: Datastore = Datastore(
+ directory,
+ enableOptimisticLocking,
+ decryptionKey,
+ encryptionKey,
+ indexes
+ )
fun query(block: Query.() -> T): T {
var cleanup = false
diff --git a/src/main/kotlin/nl/astraeus/persistence/Transaction.kt b/src/main/kotlin/nl/astraeus/persistence/Transaction.kt
index 0c8c209..4eabc43 100644
--- a/src/main/kotlin/nl/astraeus/persistence/Transaction.kt
+++ b/src/main/kotlin/nl/astraeus/persistence/Transaction.kt
@@ -99,7 +99,7 @@ open class Query(
class Transaction(
persistent: Persistent,
) : Query(persistent), Serializable {
- private val actions = mutableSetOf()
+ private val actions = ArrayList()
fun store(obj: Persistable) {
if (obj.id == 0L) {
diff --git a/src/main/kotlin/nl/astraeus/persistence/TransactionLog.kt b/src/main/kotlin/nl/astraeus/persistence/TransactionLog.kt
index 49dd071..95509ad 100644
--- a/src/main/kotlin/nl/astraeus/persistence/TransactionLog.kt
+++ b/src/main/kotlin/nl/astraeus/persistence/TransactionLog.kt
@@ -2,9 +2,11 @@ package nl.astraeus.persistence
import java.io.File
import java.io.ObjectInputStream
+import java.util.*
class TransactionLog(
directory: File,
+ val decryptionKey: String? = null,
) {
val fileManager = FileManager(directory)
@@ -14,14 +16,15 @@ class TransactionLog(
printer("Snapshot:")
snapshot?.inputStream()?.use { input ->
- ObjectInputStream(input).use { ois ->
+ ObjectInputStream(DecryptingInputStream(input, decryptionKey)).use { ois ->
val versionNumber = ois.readInt()
check(versionNumber == 1) {
"Unsupported version number: $versionNumber"
}
val transactionNumber = ois.readLong()
printer("[$versionNumber] $transactionNumber")
- val data = ois.readObject() as MutableMap, TypeData>
+ val dataObj = ois.readObject()
+ val data = dataObj as MutableMap, TypeData>
printer("Data:")
printer("\tClasses:")
for ((cls, entries) in data.entries) {
@@ -35,15 +38,26 @@ class TransactionLog(
printer("Transactions:")
transactions?.forEach { transaction ->
transaction.inputStream().use { input ->
- ObjectInputStream(input).use { ois ->
+ ObjectInputStream(DecryptingInputStream(input, decryptionKey)).use { ois ->
val versionNumber = ois.readInt()
check(versionNumber == 1) {
"Unsupported version number: $versionNumber"
}
val transactionNumber = ois.readLong()
- val actions = ois.readObject() as Set
+ val actions = ois.readObject()
+ val actionList = when(actions) {
+ is Set<*> -> {
+ LinkedList(actions as Set)
+ }
+ is List<*> -> {
+ actions as List
+ }
+ else -> {
+ emptyList()
+ }
+ }
printer("\t[$transactionNumber]")
- for (action in actions) {
+ for (action in actionList) {
printer("\t\t- $action")
}
}
diff --git a/src/test/java/nl/astraeus/persistence/TestPersistenceJava.java b/src/test/java/nl/astraeus/persistence/TestPersistenceJava.java
index cdd1516..f0e055c 100644
--- a/src/test/java/nl/astraeus/persistence/TestPersistenceJava.java
+++ b/src/test/java/nl/astraeus/persistence/TestPersistenceJava.java
@@ -5,7 +5,6 @@ import org.junit.jupiter.api.Test;
import java.io.File;
import java.util.List;
-
public class TestPersistenceJava {
static class Person extends AbstractPersistable {
@@ -64,14 +63,16 @@ public class TestPersistenceJava {
Persistent persistent = new Persistent(
new File("data", "java-test"),
+ false,
+ null,
+ null,
new Index[] {
new Index<>(
Person.class,
"name",
(p) -> ((Person)p).getName()
)
- },
- false
+ }
);
persistent.transaction((t) -> {
diff --git a/src/test/kotlin/nl/astraeus/persistence/EncryptionTest.kt b/src/test/kotlin/nl/astraeus/persistence/EncryptionTest.kt
new file mode 100644
index 0000000..6dcbcd8
--- /dev/null
+++ b/src/test/kotlin/nl/astraeus/persistence/EncryptionTest.kt
@@ -0,0 +1,54 @@
+package nl.astraeus.persistence
+
+import org.junit.jupiter.api.Test
+
+import org.junit.jupiter.api.Assertions.*
+import java.io.ByteArrayInputStream
+import java.io.ByteArrayOutputStream
+import java.security.SecureRandom
+
+class EncryptionTest {
+
+ @Test
+ fun testKeyGen() {
+ println(generateBase64Key())
+ }
+
+ @Test
+ fun testEncryptDecrypt() {
+ val random = SecureRandom()
+ val randomBytes = ByteArray(random.nextInt(10000))
+ random.nextBytes(randomBytes)
+ val base64Key = generateBase64Key()
+ val encryptor = Encryptor(
+ base64Key,
+ base64Key,
+ )
+
+ val encrypted = encryptor.encrypt(randomBytes)
+ val decrypted = encryptor.decrypt(encrypted)
+
+ assertArrayEquals(randomBytes, decrypted)
+ }
+
+ @Test
+ fun `test encryption-decryption streams`() {
+ val random = SecureRandom()
+ val key = generateBase64Key()
+ val baos = ByteArrayOutputStream()
+ val encryptionStream = EncryptingOutputStream(baos, key)
+ val bytes = ByteArray(random.nextInt(10000))
+ random.nextBytes(bytes)
+
+ encryptionStream.use {
+ it.write(bytes)
+ }
+
+ val bais = ByteArrayInputStream(baos.toByteArray())
+ val decryptingStream = DecryptingInputStream(bais, key)
+
+ val decryptedBytes = decryptingStream.readAllBytes()
+
+ assertArrayEquals(bytes, decryptedBytes)
+ }
+}
diff --git a/src/test/kotlin/nl/astraeus/persistence/TestOptimisticLocking.kt b/src/test/kotlin/nl/astraeus/persistence/TestOptimisticLocking.kt
index 7349239..597b93e 100644
--- a/src/test/kotlin/nl/astraeus/persistence/TestOptimisticLocking.kt
+++ b/src/test/kotlin/nl/astraeus/persistence/TestOptimisticLocking.kt
@@ -21,10 +21,12 @@ class TestOptimisticLocking {
val pst = Persistent(
directory = File("data", "test-locking"),
+ true,
+ null,
+ null,
arrayOf(
index("name") { p -> (p as? Person)?.name ?: "" },
),
- true
)
pst.transaction {
diff --git a/src/test/kotlin/nl/astraeus/persistence/TestPersistence.kt b/src/test/kotlin/nl/astraeus/persistence/TestPersistence.kt
index d1d6e11..32e736f 100644
--- a/src/test/kotlin/nl/astraeus/persistence/TestPersistence.kt
+++ b/src/test/kotlin/nl/astraeus/persistence/TestPersistence.kt
@@ -22,7 +22,7 @@ class TestPersistence {
val pst = Persistent(
directory = File("data", "test-persistence"),
- arrayOf(
+ indexes = arrayOf(
index("name") { p -> (p as? Person)?.name ?: "" },
index("age") { p -> (p as? Person)?.age ?: -1 },
index("ageGt20") { p -> ((p as? Person)?.age ?: 0) > 20 },
@@ -131,7 +131,7 @@ class TestPersistence {
}
}
- //pst.snapshot()
+ pst.snapshot()
pst.transaction {
store(
@@ -179,6 +179,6 @@ class TestPersistence {
}
pst.datastore.printStatus()
- //pst.removeOldFiles()
+ pst.removeOldFiles()
}
}
diff --git a/src/test/kotlin/nl/astraeus/persistence/TestPersistenceJavaInKotlin.kt b/src/test/kotlin/nl/astraeus/persistence/TestPersistenceJavaInKotlin.kt
index 2fb072d..a07df50 100644
--- a/src/test/kotlin/nl/astraeus/persistence/TestPersistenceJavaInKotlin.kt
+++ b/src/test/kotlin/nl/astraeus/persistence/TestPersistenceJavaInKotlin.kt
@@ -3,8 +3,8 @@ package nl.astraeus.persistence
import org.junit.jupiter.api.Test
import java.io.File
-
class TestPersistenceJavaInKotlin {
+
internal class Person(
var name: String,
var age: Int
@@ -23,19 +23,21 @@ class TestPersistenceJavaInKotlin {
val persistent = Persistent(
File("data", "java-kotlin-test"),
- arrayOf(
+ enableOptimisticLocking = false,
+ indexes = arrayOf(
Index(
Person::class,
"name"
) { p -> (p as Person).name }
),
- false
)
persistent.transaction {
val person = find(Person::class.java, 1L)
+
if (person != null) {
- println("Person: ${person.name} is ${person.age} years old."
+ println(
+ "Person: ${person.name} is ${person.age} years old."
)
}
}
diff --git a/src/test/kotlin/nl/astraeus/persistence/TestThreaded.kt b/src/test/kotlin/nl/astraeus/persistence/TestThreaded.kt
index d059932..28af980 100644
--- a/src/test/kotlin/nl/astraeus/persistence/TestThreaded.kt
+++ b/src/test/kotlin/nl/astraeus/persistence/TestThreaded.kt
@@ -22,7 +22,7 @@ class TestThreaded {
val pst = Persistent(
directory = File("data", "test-threaded"),
- arrayOf(
+ indexes = arrayOf(
index("name") { p -> (p as? Person)?.name ?: "" },
index("age") { p -> (p as? Person)?.age ?: -1 },
index("ageGt20") { p -> ((p as? Person)?.age ?: 0) > 20 },