Add encryption option

This commit is contained in:
2024-08-04 12:13:31 +02:00
parent c6f84224b1
commit ea0d46164f
13 changed files with 276 additions and 33 deletions

2
.idea/kotlinc.xml generated
View File

@@ -1,6 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="KotlinJpsPluginSettings">
<option name="version" value="2.0.0-RC2" />
<option name="version" value="2.0.0" />
</component>
</project>

View File

@@ -1,13 +1,13 @@
import org.gradle.model.internal.core.ModelNodes.withType
plugins {
kotlin("jvm") version "2.0.0-RC2"
kotlin("jvm") version "2.0.0"
id("maven-publish")
id("signing")
}
group = "nl.astraeus"
version = "1.0.0-SNAPSHOT"
version = "1.1.2-SNAPSHOT"
repositories {
mavenCentral()

View File

@@ -5,6 +5,7 @@ import java.io.ObjectInputStream
import java.io.ObjectOutputStream
import java.io.Serializable
import java.text.DecimalFormat
import java.util.*
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicLong
import kotlin.reflect.KClass
@@ -39,6 +40,8 @@ class Action(
class Datastore(
private val directory: File,
val enableOptimisticLocking: Boolean = false,
val decryptionKey: String? = null,
val encryptionKey: String? = null,
indexes: Array<PersistableIndex> = arrayOf(),
) {
private val fileManager = FileManager(directory)
@@ -109,7 +112,7 @@ class Datastore(
val (lastSnapshot, lastSnapshotFile) = fileManager.findLastSnapshot()
if (lastSnapshotFile != null) {
ObjectInputStream(lastSnapshotFile.inputStream()).use { ois ->
ObjectInputStream(DecryptingInputStream(lastSnapshotFile.inputStream(), decryptionKey)).use { ois ->
readSnapshot(ois)
}
}
@@ -117,7 +120,7 @@ class Datastore(
val transactions = fileManager.findTransactionsAfter(lastSnapshot ?: 0L)
transactions?.forEach { file ->
ObjectInputStream(file.inputStream()).use { ois ->
ObjectInputStream(DecryptingInputStream(file.inputStream(), decryptionKey)).use { ois ->
readTransaction(ois)
}
}
@@ -126,7 +129,7 @@ class Datastore(
Logger.debug("Loaded transactions in %6.3fms", ((System.nanoTime() - start) / 1_000_000f))
}
private fun execute(actions: Set<Action>) {
private fun execute(actions: List<Action>) {
synchronized(this) {
if (enableOptimisticLocking) {
for (action in actions) {
@@ -211,7 +214,7 @@ class Datastore(
return indexes[kClass.java]?.get(indexName)
}
fun storeAndExecute(actions: Set<Action>) {
fun storeAndExecute(actions: List<Action>) {
if (actions.isNotEmpty()) {
synchronized(this) {
execute(actions)
@@ -225,14 +228,22 @@ class Datastore(
check(versionNumber == 1) { "Unsupported version number: $versionNumber" }
val transactionNumber = ois.readLong()
nextTransactionNumber = transactionNumber + 1
val actions = ois.readObject() as Set<Action>
execute(actions)
val actions = ois.readObject()
when(actions) {
is Set<*> -> {
val list = LinkedList(actions as Set<Action>)
execute(list)
}
is List<*> -> {
execute(actions as List<Action>)
}
}
}
private fun writeTransaction(actions: Set<Action>) {
private fun writeTransaction(actions: List<Action>) {
val number = transactionFormatter.format(nextTransactionNumber)
val file = File(directory, "transaction-$number.trn-tmp")
ObjectOutputStream(file.outputStream()).use { oos ->
ObjectOutputStream(EncryptingOutputStream(file.outputStream(), encryptionKey)).use { oos ->
// version number
oos.writeInt(1)
oos.writeLong(nextTransactionNumber++)
@@ -246,7 +257,7 @@ class Datastore(
synchronized(this) {
val number = transactionFormatter.format(nextTransactionNumber)
val file = File(directory, "transaction-$number.snp-tmp")
ObjectOutputStream(file.outputStream()).use { oos ->
ObjectOutputStream(EncryptingOutputStream(file.outputStream(), encryptionKey)).use { oos ->
// version number
oos.writeInt(1)
oos.writeLong(nextTransactionNumber++)
@@ -270,7 +281,8 @@ class Datastore(
val versionNumber = ois.readInt()
check(versionNumber == 1) { "Unsupported version number: $versionNumber" }
nextTransactionNumber = ois.readLong() + 1
data = ois.readObject() as MutableMap<Class<*>, TypeData>
val dataObj = ois.readObject()
data = dataObj as ConcurrentHashMap<Class<*>, TypeData>
val foundIndexes = mutableMapOf<Class<*>, MutableList<String>>()
val numberOfClassesWithIndex = ois.readInt()

View File

@@ -0,0 +1,150 @@
@file:OptIn(ExperimentalEncodingApi::class, ExperimentalEncodingApi::class)
package nl.astraeus.persistence
import java.io.ByteArrayOutputStream
import java.io.InputStream
import java.io.OutputStream
import java.security.SecureRandom
import javax.crypto.Cipher
import javax.crypto.KeyGenerator
import javax.crypto.SecretKey
import javax.crypto.spec.IvParameterSpec
import javax.crypto.spec.SecretKeySpec
import kotlin.io.encoding.Base64
import kotlin.io.encoding.ExperimentalEncodingApi
class Encryptor(
base64EncryptionKey: String?,
base64DecryptionKey: String?,
) {
private var decryptionKey: SecretKey? = null
private var encryptionKey: SecretKey? = null
init {
if (base64EncryptionKey?.isNotEmpty() == true) {
encryptionKey = SecretKeySpec(Base64.UrlSafe.decode(base64EncryptionKey), "AES")
}
if (base64DecryptionKey?.isNotEmpty() == true) {
decryptionKey = SecretKeySpec(Base64.UrlSafe.decode(base64DecryptionKey), "AES")
}
}
fun encrypt(data: ByteArray): ByteArray {
if (encryptionKey == null) {
return data
}
val prePaddedData = ByteArray(16) + data
val cipher = Cipher.getInstance("AES/CBC/PKCS5Padding")
// Generate a new IV (Initialization Vector)
val secureRandom = SecureRandom()
val iv = ByteArray(cipher.blockSize)
secureRandom.nextBytes(iv)
val ivParams = IvParameterSpec(iv)
cipher.init(Cipher.ENCRYPT_MODE, encryptionKey, ivParams)
return cipher.doFinal(prePaddedData)
}
fun decrypt(data: ByteArray): ByteArray {
if (decryptionKey == null) {
return data
}
val cipher = Cipher.getInstance("AES/CBC/PKCS5Padding")
val secureRandom = SecureRandom()
val iv = ByteArray(cipher.blockSize)
secureRandom.nextBytes(iv)
cipher.init(Cipher.DECRYPT_MODE, decryptionKey, IvParameterSpec(iv))
val completeData = cipher.doFinal(data)
return completeData.sliceArray(16 until completeData.size)
}
}
fun generateBase64Key(): String {
val keyGen: KeyGenerator = KeyGenerator.getInstance("AES")
keyGen.init(256) // for AES-256
val secretKey: SecretKey = keyGen.generateKey()
return Base64.UrlSafe.encode(secretKey.encoded)
}
/*
object Encryption {
var encryptor = Encryptor(
System.getenv().getOrDefault("SPK_ENCRYPTION_KEY", ""),
System.getenv().getOrDefault("SPK_DECRYPTION_KEY", "")
)
}
*/
class DecryptingInputStream(
val input: InputStream,
val base64DecryptionKey: String?
) : InputStream() {
val bytes: ByteArray
var index = 0
init {
val encryptedBytes = input.readAllBytes()
if (base64DecryptionKey?.isBlank() == true) {
bytes = encryptedBytes
} else {
val encryptor = Encryptor(
base64EncryptionKey = null,
base64DecryptionKey = base64DecryptionKey
)
bytes = encryptor.decrypt(encryptedBytes)
}
}
/* override fun readAllBytes(): ByteArray {
index = bytes.size
return bytes
}*/
override fun read(): Int {
return if (index < bytes.size) {
bytes[index++].toUByte().toInt()
} else {
-1
}
}
override fun close() {
input.close()
}
}
class EncryptingOutputStream(
val output: OutputStream,
val base64EncryptionKey: String?
) : OutputStream() {
val baos = ByteArrayOutputStream()
override fun write(b: Int) {
baos.write(b)
}
override fun close() {
if (base64EncryptionKey?.isBlank() == true) {
output.write(baos.toByteArray())
} else {
val encryptor = Encryptor(
base64EncryptionKey = base64EncryptionKey,
base64DecryptionKey = null
)
val encryptedBytes = encryptor.encrypt(baos.toByteArray())
output.write(encryptedBytes)
}
output.flush()
output.close()
}
override fun flush() {
// no flush
}
}

View File

@@ -10,10 +10,18 @@ fun currentTransaction(): Transaction? {
class Persistent(
directory: File,
indexes: Array<PersistableIndex> = arrayOf(),
enableOptimisticLocking: Boolean = false,
decryptionKey: String? = null,
encryptionKey: String? = null,
indexes: Array<PersistableIndex> = arrayOf(),
) {
val datastore: Datastore = Datastore(directory, enableOptimisticLocking, indexes)
val datastore: Datastore = Datastore(
directory,
enableOptimisticLocking,
decryptionKey,
encryptionKey,
indexes
)
fun <T> query(block: Query.() -> T): T {
var cleanup = false

View File

@@ -99,7 +99,7 @@ open class Query(
class Transaction(
persistent: Persistent,
) : Query(persistent), Serializable {
private val actions = mutableSetOf<Action>()
private val actions = ArrayList<Action>()
fun store(obj: Persistable) {
if (obj.id == 0L) {

View File

@@ -2,9 +2,11 @@ package nl.astraeus.persistence
import java.io.File
import java.io.ObjectInputStream
import java.util.*
class TransactionLog(
directory: File,
val decryptionKey: String? = null,
) {
val fileManager = FileManager(directory)
@@ -14,14 +16,15 @@ class TransactionLog(
printer("Snapshot:")
snapshot?.inputStream()?.use { input ->
ObjectInputStream(input).use { ois ->
ObjectInputStream(DecryptingInputStream(input, decryptionKey)).use { ois ->
val versionNumber = ois.readInt()
check(versionNumber == 1) {
"Unsupported version number: $versionNumber"
}
val transactionNumber = ois.readLong()
printer("[$versionNumber] $transactionNumber")
val data = ois.readObject() as MutableMap<Class<*>, TypeData>
val dataObj = ois.readObject()
val data = dataObj as MutableMap<Class<*>, TypeData>
printer("Data:")
printer("\tClasses:")
for ((cls, entries) in data.entries) {
@@ -35,15 +38,26 @@ class TransactionLog(
printer("Transactions:")
transactions?.forEach { transaction ->
transaction.inputStream().use { input ->
ObjectInputStream(input).use { ois ->
ObjectInputStream(DecryptingInputStream(input, decryptionKey)).use { ois ->
val versionNumber = ois.readInt()
check(versionNumber == 1) {
"Unsupported version number: $versionNumber"
}
val transactionNumber = ois.readLong()
val actions = ois.readObject() as Set<Action>
val actions = ois.readObject()
val actionList = when(actions) {
is Set<*> -> {
LinkedList(actions as Set<Action>)
}
is List<*> -> {
actions as List<Action>
}
else -> {
emptyList()
}
}
printer("\t[$transactionNumber]")
for (action in actions) {
for (action in actionList) {
printer("\t\t- $action")
}
}

View File

@@ -5,7 +5,6 @@ import org.junit.jupiter.api.Test;
import java.io.File;
import java.util.List;
public class TestPersistenceJava {
static class Person extends AbstractPersistable {
@@ -64,14 +63,16 @@ public class TestPersistenceJava {
Persistent persistent = new Persistent(
new File("data", "java-test"),
false,
null,
null,
new Index[] {
new Index<>(
Person.class,
"name",
(p) -> ((Person)p).getName()
)
},
false
}
);
persistent.transaction((t) -> {

View File

@@ -0,0 +1,54 @@
package nl.astraeus.persistence
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.Assertions.*
import java.io.ByteArrayInputStream
import java.io.ByteArrayOutputStream
import java.security.SecureRandom
class EncryptionTest {
@Test
fun testKeyGen() {
println(generateBase64Key())
}
@Test
fun testEncryptDecrypt() {
val random = SecureRandom()
val randomBytes = ByteArray(random.nextInt(10000))
random.nextBytes(randomBytes)
val base64Key = generateBase64Key()
val encryptor = Encryptor(
base64Key,
base64Key,
)
val encrypted = encryptor.encrypt(randomBytes)
val decrypted = encryptor.decrypt(encrypted)
assertArrayEquals(randomBytes, decrypted)
}
@Test
fun `test encryption-decryption streams`() {
val random = SecureRandom()
val key = generateBase64Key()
val baos = ByteArrayOutputStream()
val encryptionStream = EncryptingOutputStream(baos, key)
val bytes = ByteArray(random.nextInt(10000))
random.nextBytes(bytes)
encryptionStream.use {
it.write(bytes)
}
val bais = ByteArrayInputStream(baos.toByteArray())
val decryptingStream = DecryptingInputStream(bais, key)
val decryptedBytes = decryptingStream.readAllBytes()
assertArrayEquals(bytes, decryptedBytes)
}
}

View File

@@ -21,10 +21,12 @@ class TestOptimisticLocking {
val pst = Persistent(
directory = File("data", "test-locking"),
true,
null,
null,
arrayOf(
index<Person>("name") { p -> (p as? Person)?.name ?: "" },
),
true
)
pst.transaction {

View File

@@ -22,7 +22,7 @@ class TestPersistence {
val pst = Persistent(
directory = File("data", "test-persistence"),
arrayOf(
indexes = arrayOf(
index<Person>("name") { p -> (p as? Person)?.name ?: "" },
index<Person>("age") { p -> (p as? Person)?.age ?: -1 },
index<Person>("ageGt20") { p -> ((p as? Person)?.age ?: 0) > 20 },
@@ -131,7 +131,7 @@ class TestPersistence {
}
}
//pst.snapshot()
pst.snapshot()
pst.transaction {
store(
@@ -179,6 +179,6 @@ class TestPersistence {
}
pst.datastore.printStatus()
//pst.removeOldFiles()
pst.removeOldFiles()
}
}

View File

@@ -3,8 +3,8 @@ package nl.astraeus.persistence
import org.junit.jupiter.api.Test
import java.io.File
class TestPersistenceJavaInKotlin {
internal class Person(
var name: String,
var age: Int
@@ -23,19 +23,21 @@ class TestPersistenceJavaInKotlin {
val persistent = Persistent(
File("data", "java-kotlin-test"),
arrayOf(
enableOptimisticLocking = false,
indexes = arrayOf(
Index(
Person::class,
"name"
) { p -> (p as Person).name }
),
false
)
persistent.transaction {
val person = find(Person::class.java, 1L)
if (person != null) {
println("Person: ${person.name} is ${person.age} years old."
println(
"Person: ${person.name} is ${person.age} years old."
)
}
}

View File

@@ -22,7 +22,7 @@ class TestThreaded {
val pst = Persistent(
directory = File("data", "test-threaded"),
arrayOf(
indexes = arrayOf(
index<Person>("name") { p -> (p as? Person)?.name ?: "" },
index<Person>("age") { p -> (p as? Person)?.age ?: -1 },
index<Person>("ageGt20") { p -> ((p as? Person)?.age ?: 0) > 20 },