diff --git a/src/main/kotlin/nl/astraeus/persistence/Datastore.kt b/src/main/kotlin/nl/astraeus/persistence/Datastore.kt index 399baa7..6f3459b 100644 --- a/src/main/kotlin/nl/astraeus/persistence/Datastore.kt +++ b/src/main/kotlin/nl/astraeus/persistence/Datastore.kt @@ -17,16 +17,29 @@ enum class ActionType { class TypeData( var nextId: AtomicLong = AtomicLong(1L), val data: MutableMap = ConcurrentHashMap(), -) : Serializable +) : Serializable { + companion object { + private const val serialVersionUID: Long = 1L + } +} class Action( val type: ActionType, val obj: Persistable -) : Serializable +) : Serializable { + override fun toString(): String { + return "Action(type=$type, obj=$obj)" + } + + companion object { + private const val serialVersionUID: Long = 1L + } +} class Datastore( private val directory: File, indexes: Array = arrayOf(), + val enableOptimisticLocking: Boolean = false, ) { private val fileManager = FileManager(directory) private val transactionFormatter = DecimalFormat("#") @@ -58,7 +71,8 @@ class Datastore( fun setMaxId(javaClass: Class, id: Long) { val nextId = data.getOrPut(javaClass) { TypeData() }.nextId - if (nextId.get() <= id) nextId.set(id + 1) + val current = nextId.get() + if (current <= id) nextId.addAndGet(id - current) } override fun toString(): String { @@ -97,15 +111,24 @@ class Datastore( } } - Logger.debug("Loaded transactions in ${(System.nanoTime() - start) / 1_000_000}ms") + Logger.debug("Loaded transactions in %6.3fms", ((System.nanoTime() - start) / 1_000_000f)) } - private fun getTrnx(file: File): Long { - return file.name.substringAfterLast('/').substringAfter("transaction-").substringBefore(".").toLong() - } - - fun execute(actions: MutableList) { + private fun execute(actions: Set) { synchronized(this) { + if (enableOptimisticLocking) { + for (action in actions) { + val typeData = data.getOrPut(action.obj::class.java) { + TypeData() + } + if (action.type == ActionType.STORE) { + if ((typeData.data[action.obj.id]?.version ?: -1L) >= action.obj.version) { + throw OptimisticLockingException(action.obj) + } + } + } + } + for (action in actions) { val typeData = data.getOrPut(action.obj::class.java) { TypeData() @@ -137,7 +160,6 @@ class Datastore( } } - fun count(clazz: KClass): Int { val typeData = data.getOrPut(clazz.java) { TypeData() @@ -176,25 +198,25 @@ class Datastore( return indexes[kClass.java]?.get(indexName) } - fun storeAndExecute(actions: MutableList) { + fun storeAndExecute(actions: Set) { if (actions.isNotEmpty()) { synchronized(this) { - writeTransaction(actions) execute(actions) + writeTransaction(actions) } } } private fun readTransaction(ois: ObjectInputStream) { val versionNumber = ois.readInt() - check (versionNumber == 1) { "Unsupported version number: $versionNumber" } + check(versionNumber == 1) { "Unsupported version number: $versionNumber" } val transactionNumber = ois.readLong() nextTransactionNumber = transactionNumber + 1 - val actions = ois.readObject() as MutableList + val actions = ois.readObject() as Set execute(actions) } - private fun writeTransaction(actions: MutableList) { + private fun writeTransaction(actions: Set) { val number = transactionFormatter.format(nextTransactionNumber) val file = File(directory, "transaction-$number.trn") ObjectOutputStream(file.outputStream()).use { oos -> @@ -226,12 +248,12 @@ class Datastore( } } } - Logger.debug("Snapshot in ${(System.nanoTime() - start) / 1_000_000}ms") + Logger.debug("Snapshot in %6.3fms", ((System.nanoTime() - start) / 1_000_000f)) } private fun readSnapshot(ois: ObjectInputStream) { val versionNumber = ois.readInt() - check (versionNumber == 1) { "Unsupported version number: $versionNumber" } + check(versionNumber == 1) { "Unsupported version number: $versionNumber" } nextTransactionNumber = ois.readLong() + 1 data.clear() data.putAll(ois.readObject() as MutableMap, TypeData>) diff --git a/src/main/kotlin/nl/astraeus/persistence/Logger.kt b/src/main/kotlin/nl/astraeus/persistence/Logger.kt index aebf8ec..c6ff1c3 100644 --- a/src/main/kotlin/nl/astraeus/persistence/Logger.kt +++ b/src/main/kotlin/nl/astraeus/persistence/Logger.kt @@ -1,8 +1,62 @@ package nl.astraeus.nl.astraeus.persistence +enum class LogLevel { + TRACE, + DEBUG, + INFO, + WARN, + ERROR +} + object Logger { - var debug: (String) -> Unit = { println("DEBUG: $it") } - var info: (String) -> Unit = { println("INFO: $it") } - var warn: (String) -> Unit = { println("WARN: $it") } - var error: (String) -> Unit = { println("ERROR: $it") } -} \ No newline at end of file + var level: LogLevel = LogLevel.DEBUG + + var tracePrinter: (String) -> Unit = { println(it) } + var debugPrinter: (String) -> Unit = { println(it) } + var infoPrinter: (String) -> Unit = { println(it) } + var warnPrinter: (String) -> Unit = { println(it) } + var errorPrinter: (String) -> Unit = { System.err.println(it) } + + fun trace(message: String, vararg parameters: Any?) { + if (level <= LogLevel.TRACE) { + writeLogMessage(LogLevel.TRACE, message, *parameters) + } + } + + fun debug(message: String, vararg parameters: Any?) { + if (level <= LogLevel.DEBUG) { + writeLogMessage(LogLevel.DEBUG, message, *parameters) + } + } + + fun info(message: String, vararg parameters: Any?) { + if (level <= LogLevel.INFO) { + writeLogMessage(LogLevel.INFO, message, *parameters) + } + } + + fun warn(message: String, vararg parameters: Any?) { + if (level <= LogLevel.DEBUG) { + writeLogMessage(LogLevel.DEBUG, message, *parameters) + } + } + + fun error(message: String, vararg parameters: Any?) { + if (level <= LogLevel.ERROR) { + writeLogMessage(LogLevel.ERROR, message, *parameters) + } + } + + private fun writeLogMessage(level: LogLevel, message: String, vararg parameters: Any?) { + val formattedMessage = "[${level}] - ${message.format(*parameters)}" + + when (level) { + LogLevel.TRACE -> tracePrinter(formattedMessage) + LogLevel.DEBUG -> debugPrinter(formattedMessage) + LogLevel.INFO -> infoPrinter(formattedMessage) + LogLevel.WARN -> warnPrinter(formattedMessage) + LogLevel.ERROR -> errorPrinter(formattedMessage) + } + + } +} diff --git a/src/main/kotlin/nl/astraeus/persistence/OptimisticLockingException.kt b/src/main/kotlin/nl/astraeus/persistence/OptimisticLockingException.kt new file mode 100644 index 0000000..45d0e71 --- /dev/null +++ b/src/main/kotlin/nl/astraeus/persistence/OptimisticLockingException.kt @@ -0,0 +1,18 @@ +package nl.astraeus.nl.astraeus.persistence + +class OptimisticLockingException : Exception { + constructor( + obj: Persistable + ) : this("Optimistic locking failed for ${obj.javaClass.simpleName} with id ${obj.id}, version ${obj.version}") + + constructor() : super() + constructor(message: String?) : super(message) + constructor(message: String?, cause: Throwable?) : super(message, cause) + constructor(cause: Throwable?) : super(cause) + constructor(message: String?, cause: Throwable?, enableSuppression: Boolean, writableStackTrace: Boolean) : super( + message, + cause, + enableSuppression, + writableStackTrace + ) +} diff --git a/src/main/kotlin/nl/astraeus/persistence/Persistable.kt b/src/main/kotlin/nl/astraeus/persistence/Persistable.kt index dcdb282..f7b6088 100644 --- a/src/main/kotlin/nl/astraeus/persistence/Persistable.kt +++ b/src/main/kotlin/nl/astraeus/persistence/Persistable.kt @@ -17,7 +17,9 @@ interface Persistable : Serializable { } ByteArrayInputStream(baos.toByteArray()).use { bais -> ObjectInputStream(bais).use { ois -> - return ois.readObject() as Persistable + val result = ois.readObject() as Persistable + result.version++ + return result } } } diff --git a/src/main/kotlin/nl/astraeus/persistence/Persistent.kt b/src/main/kotlin/nl/astraeus/persistence/Persistent.kt index a0e59fd..adc7949 100644 --- a/src/main/kotlin/nl/astraeus/persistence/Persistent.kt +++ b/src/main/kotlin/nl/astraeus/persistence/Persistent.kt @@ -11,8 +11,9 @@ fun currentTransaction(): Transaction? { class Persistent( directory: File, indexes: Array = arrayOf(), + enableOptimisticLocking: Boolean = false, ) { - val datastore: Datastore = Datastore(directory, indexes) + val datastore: Datastore = Datastore(directory, indexes, enableOptimisticLocking) fun query(block: Query.() -> T): T { return block(Query(this)) diff --git a/src/main/kotlin/nl/astraeus/persistence/Transaction.kt b/src/main/kotlin/nl/astraeus/persistence/Transaction.kt index 51e3099..113301a 100644 --- a/src/main/kotlin/nl/astraeus/persistence/Transaction.kt +++ b/src/main/kotlin/nl/astraeus/persistence/Transaction.kt @@ -95,11 +95,13 @@ open class Query( class Transaction( persistent: Persistent, ) : Query(persistent), Serializable { - private val actions = mutableListOf() + private val actions = mutableSetOf() fun store(obj: Persistable) { if (obj.id == 0L) { obj.id = persistent.datastore.getNextId(obj.javaClass) + } else if (obj.id > persistent.datastore.getNextId(obj.javaClass)) { + persistent.datastore.setMaxId(obj.javaClass, obj.id + 1) } actions.add(Action(ActionType.STORE, obj)) diff --git a/src/main/kotlin/nl/astraeus/persistence/TransactionLog.kt b/src/main/kotlin/nl/astraeus/persistence/TransactionLog.kt new file mode 100644 index 0000000..89e35f9 --- /dev/null +++ b/src/main/kotlin/nl/astraeus/persistence/TransactionLog.kt @@ -0,0 +1,33 @@ +package nl.astraeus.nl.astraeus.persistence + +import java.io.File +import java.io.ObjectInputStream + +class TransactionLog( + val directory: File, +) { + val fileManager = FileManager(directory) + + fun showTransactions() { + fileManager.findLastSnapshot().let { (after, snapshot) -> + println("Last snapshot: $snapshot") + + val transactions = fileManager.findTransactionsAfter(after ?: 0L) + + println("Transactions:") + transactions?.forEach { transaction -> + transaction.inputStream().use { input -> + ObjectInputStream(input).use { ois -> + val versionNumber = ois.readInt() + check(versionNumber == 1) { "Unsupported version number: $versionNumber" } + val transactionNumber = ois.readLong() + val actions = ois.readObject() as Set + println("[$versionNumber] $transactionNumber - ${actions.joinToString(",")}") + } + } + } + } + + } + +} \ No newline at end of file diff --git a/src/test/java/nl/astraeus/persistence/TestPersistenceJava.java b/src/test/java/nl/astraeus/persistence/TestPersistenceJava.java index e6579de..533816f 100644 --- a/src/test/java/nl/astraeus/persistence/TestPersistenceJava.java +++ b/src/test/java/nl/astraeus/persistence/TestPersistenceJava.java @@ -74,7 +74,8 @@ public class TestPersistenceJava { "name", (p) -> ((Person)p).getName() ) - } + }, + false ); persistent.transaction((t) -> { diff --git a/src/test/kotlin/nl/astraeus/persistence/TestOptimisticLocking.kt b/src/test/kotlin/nl/astraeus/persistence/TestOptimisticLocking.kt new file mode 100644 index 0000000..2feebc9 --- /dev/null +++ b/src/test/kotlin/nl/astraeus/persistence/TestOptimisticLocking.kt @@ -0,0 +1,114 @@ +package nl.astraeus.persistence + +import nl.astraeus.nl.astraeus.persistence.OptimisticLockingException +import nl.astraeus.nl.astraeus.persistence.Persistable +import nl.astraeus.nl.astraeus.persistence.Persistent +import nl.astraeus.nl.astraeus.persistence.TransactionLog +import nl.astraeus.nl.astraeus.persistence.find +import nl.astraeus.nl.astraeus.persistence.findByIndex +import nl.astraeus.nl.astraeus.persistence.index +import org.junit.jupiter.api.Assertions.assertNotNull +import org.junit.jupiter.api.assertThrows +import java.io.File +import kotlin.test.Test + +class TestOptimisticLocking { + + class Person( + override var id: Long = 0, + override var version: Long = 0, + val name: String, + val age: Int, + ) : Persistable, Cloneable { + companion object { + private const val serialVersionUID: Long = 1L + } + + override fun toString(): String { + return "Person(id=$id, version=$version, name='$name', age=$age)" + } + } + + @Test + fun showTransactions() { + val log = TransactionLog(File("data", "test-locking")) + + log.showTransactions() + } + + @Test + fun testOptimisticLocking() { + println("Test locking") + + val pst = Persistent( + directory = File("data", "test-locking"), + arrayOf( + index("name") { p -> (p as? Person)?.name ?: "" }, + ), + true + ) + + pst.transaction { + val person = find(1L) ?: Person( + id = 1L, + name = "John Doe", + age = 25 + ) + + store(person) + + findByIndex("name", "John Doe").forEach { p -> + println("Found person by name: ${p.name} - ${p.age}") + } + } + + pst.query { + val person = find(1L) + + assertNotNull(person) + } + + val threads = Array(2) { index -> + Thread { + println("Start thread $index") + var person: Person? = null + + pst.transaction { + Thread.sleep(10L) + person = find(1L) + println("Thread $index -> ${person?.version}") + } + + if (person != null) { + Thread.sleep((index + 1) * 10L) + + if (index == 1) { + assertThrows { + println("Store thread $index -> ${person!!.version}") + pst.transaction { + store(person!!) + } + } + } else { + println("Store thread $index -> ${person!!.version}") + pst.transaction { + store(person!!) + } + } + } + } + } + + for (thread in threads) { + thread.start() + } + + for (thread in threads) { + thread.join() + } + + pst.datastore.printStatus() + //pst.snapshot() + pst.removeOldFiles() + } +} diff --git a/src/test/kotlin/nl/astraeus/persistence/TestPersistence.kt b/src/test/kotlin/nl/astraeus/persistence/TestPersistence.kt index e89f137..1d1cf46 100644 --- a/src/test/kotlin/nl/astraeus/persistence/TestPersistence.kt +++ b/src/test/kotlin/nl/astraeus/persistence/TestPersistence.kt @@ -127,6 +127,14 @@ class TestPersistence { assertNotNull(person) } + pst.transaction { + val p1 = find(1L) + val p2 = find(1L) + + store(p2!!) + store(p1!!) + } + pst.transaction { val person = find(Person::class, 1L)