Support create sequence statement with all sequence options (#746)

This commit is contained in:
hichem-fazai
2020-01-17 00:09:06 +01:00
committed by Andrey.Tarashevskiy
parent 39cf32888a
commit bd1e5d3d47
15 changed files with 223 additions and 13 deletions

View File

@@ -150,8 +150,8 @@ fun main() {
Outputs:
```
SQL: CREATE TABLE IF NOT EXISTS Cities (id INT AUTO_INCREMENT NOT NULL, name VARCHAR(50) NOT NULL, CONSTRAINT pk_Cities PRIMARY KEY (id))
SQL: CREATE TABLE IF NOT EXISTS Users (id VARCHAR(10) NOT NULL, name VARCHAR(50) NOT NULL, city_id INT NULL, CONSTRAINT pk_Users PRIMARY KEY (id))
SQL: CREATE TABLE IF NOT EXISTS Cities (id INT AUTO_INCREMENT NOT NULL, name VARCHAR(50) NOT NULL, CONSTRAINT PK_Cities_ID PRIMARY KEY (id))
SQL: CREATE TABLE IF NOT EXISTS Users (id VARCHAR(10) NOT NULL, name VARCHAR(50) NOT NULL, city_id INT NULL, CONSTRAINT PK_User_ID PRIMARY KEY (id))
SQL: ALTER TABLE Users ADD FOREIGN KEY (city_id) REFERENCES Cities(id)
SQL: INSERT INTO Cities (name) VALUES ('St. Petersburg')
SQL: INSERT INTO Cities (name) VALUES ('Munich')

View File

@@ -27,6 +27,12 @@ class LowerCase<T: String?>(val expr: Expression<T>) : Function<T>(VarCharColumn
override fun toQueryBuilder(queryBuilder: QueryBuilder) = queryBuilder { append("LOWER(", expr,")") }
}
class NextVal(val seq: Sequence) : Function<Int>(IntegerColumnType()) {
override fun toQueryBuilder(queryBuilder: QueryBuilder) {
currentDialect.functionProvider.nextVal(seq, queryBuilder)
}
}
class UpperCase<T: String?>(val expr: Expression<T>) : Function<T>(VarCharColumnType()) {
override fun toQueryBuilder(queryBuilder: QueryBuilder) = queryBuilder { append("UPPER(", expr,")") }
}

View File

@@ -38,6 +38,8 @@ fun<T:String?> Expression<T>.trim() : Function<T> = Trim(this)
fun<T:String?> Expression<T>.lowerCase() : Function<T> = LowerCase(this)
fun<T:String?> Expression<T>.upperCase() : Function<T> = UpperCase(this)
fun Sequence.nextVal() : Function<Int> = NextVal(this)
fun<T:Any?> ExpressionWithColumnType<T>.function(functionName: String) : Function<T?> = CustomFunction(functionName, columnType, this)
fun CustomStringFunction(functionName: String, vararg params: Expression<*>) = CustomFunction<String?>(functionName, VarCharColumnType(), *params)
fun CustomLongFunction(functionName: String, vararg params: Expression<*>) = CustomFunction<Long?>(functionName, LongColumnType(), *params)

View File

@@ -86,7 +86,19 @@ object SchemaUtils {
} + alters
}
fun createSequence(name: String) = Seq(name).createStatement()
fun createSequence(vararg seq: Sequence, inBatch: Boolean = false) {
with(TransactionManager.current()) {
val createStatements = seq.flatMap { it.createStatement() }
execStatements(inBatch, createStatements)
}
}
fun dropSequence(vararg seq: Sequence, inBatch: Boolean = false) {
with(TransactionManager.current()) {
val dropStatements = seq.flatMap { it.dropStatement() }
execStatements(inBatch, dropStatements)
}
}
fun createFKey(reference: Column<*>) = ForeignKeyConstraint.from(reference).createStatement()

View File

@@ -0,0 +1,66 @@
package org.jetbrains.exposed.sql
import org.jetbrains.exposed.exceptions.UnsupportedByDialectException
import org.jetbrains.exposed.sql.transactions.TransactionManager
import org.jetbrains.exposed.sql.vendors.currentDialect
import java.lang.StringBuilder
/**
* Sequence : an object that generates a sequence of numeric values.
*
* @param name The name of the sequence
* @param startWith The first sequence number to be generated.
* @param incrementBy The interval between sequence numbers.
* @param minValue The minimum value of the sequence.
* @param maxValue The maximum value of the sequence.
* @param cycle Indicates that the sequence continues to generate values after reaching either its maximum or minimum value.
* @param cache Number of values of the sequence the database preallocates and keeps in memory for faster access.
*/
class Sequence(private val name: String,
val startWith: Int? = null,
val incrementBy: Int? = null,
val minValue: Int? = null,
val maxValue: Int? = null,
val cycle: Boolean? = null,
val cache: Int? = null) {
val identifier get() = TransactionManager.current().db.identifierManager.cutIfNecessaryAndQuote(name)
val ddl: List<String>
get() = createStatement()
fun createStatement(): List<String> {
if (!currentDialect.supportsCreateSequence ) {
throw UnsupportedByDialectException("The current dialect doesn't support create sequence statement", currentDialect)
}
val createTableDDL = buildString {
append("CREATE SEQUENCE ")
if (currentDialect.supportsIfNotExists) {
append("IF NOT EXISTS ")
}
append(identifier)
appendIfNotNull(" START WITH", startWith)
appendIfNotNull(" INCREMENT BY", incrementBy)
appendIfNotNull(" MINVALUE", minValue)
appendIfNotNull(" MAXVALUE", maxValue)
if (cycle == true) {
append(" CYCLE")
}
appendIfNotNull(" CACHE", cache)
}
return listOf(createTableDDL)
}
fun dropStatement() = listOf("DROP SEQUENCE $identifier")
fun StringBuilder.appendIfNotNull(str: String, strToCheck: Any?) = apply {
if (strToCheck != null) {
this.append("$str $strToCheck")
}
}
}

View File

@@ -696,7 +696,7 @@ open class Table(name: String = ""): ColumnSet(), DdlAware {
override fun createStatement(): List<String> {
val seqDDL = autoIncColumn?.autoIncSeqName?.let {
Seq(it).createStatement()
Sequence(it).createStatement()
}.orEmpty()
val addForeignKeysInAlterPart = SchemaUtils.checkCycle(this) && currentDialect !is SQLiteDialect
@@ -768,7 +768,7 @@ open class Table(name: String = ""): ColumnSet(), DdlAware {
}
}
val seqDDL = autoIncColumn?.autoIncSeqName?.let {
Seq(it).dropStatement()
Sequence(it).dropStatement()
}.orEmpty()
return listOf(dropTableDDL) + seqDDL
@@ -792,11 +792,10 @@ open class Table(name: String = ""): ColumnSet(), DdlAware {
}
}
data class Seq(private val name: String) {
private val identifier get() = TransactionManager.current().db.identifierManager.cutIfNecessaryAndQuote(name)
fun createStatement() = listOf("CREATE SEQUENCE $identifier")
fun dropStatement() = listOf("DROP SEQUENCE $identifier")
}
@Deprecated("Use Sequence class instead of Seq class.",
ReplaceWith("org.jetbrains.exposed.sql.Sequence"),
DeprecationLevel.ERROR)
data class Seq(private val name: String)
fun ColumnSet.targetTables(): List<Table> = when (this) {
is Alias<*> -> listOf(this.delegate)

View File

@@ -71,7 +71,7 @@ open class InsertStatement<Key:Any>(val table: Table, val isIgnore: Boolean = fa
}
pairs.forEach { (col, value) ->
if (value != DefaultValueMarker) {
if (col.columnType.isAutoInc)
if (col.columnType.isAutoInc || value is NextVal)
map.getOrPut(col) { value }
else
map[col] = value
@@ -118,8 +118,12 @@ open class InsertStatement<Key:Any>(val table: Table, val isIgnore: Boolean = fa
}
}
protected val autoIncColumns = targets.flatMap { it.columns }.filter {
it.columnType.isAutoInc || (it.columnType is EntityIDColumnType<*> && !currentDialect.supportsOnlyIdentifiersInGeneratedKeys)
protected val autoIncColumns
get() = targets.flatMap { it.columns }.filter { column ->
column.columnType.isAutoInc
|| (column.columnType is EntityIDColumnType<*> && !currentDialect.supportsOnlyIdentifiersInGeneratedKeys)
|| (column in values.filter { it.value is NextVal }.map { it.key })
}
override fun prepared(transaction: Transaction, sql: String): PreparedStatementApi = when {

View File

@@ -3,6 +3,7 @@ package org.jetbrains.exposed.sql.vendors
import org.jetbrains.exposed.exceptions.throwUnsupportedException
import org.jetbrains.exposed.sql.*
import org.jetbrains.exposed.sql.transactions.TransactionManager
import java.lang.StringBuilder
import java.nio.ByteBuffer
import java.util.*
import java.util.concurrent.ConcurrentHashMap
@@ -61,6 +62,9 @@ abstract class FunctionProvider {
append(prefix, "(", expr, ", ", start, ", ", length, ")")
}
open fun nextVal(seq: Sequence, builder: QueryBuilder) = builder {
append(seq.identifier,".NEXTVAL")
}
open fun random(seed: Int?): String = "RANDOM(${seed?.toString().orEmpty()})"
@@ -264,11 +268,43 @@ interface DatabaseDialect {
val supportsOnlyIdentifiersInGeneratedKeys get() = false
val supportsCreateSequence get() = true
// Specific SQL statements
fun createIndex(index: Index): String
fun dropIndex(tableName: String, indexName: String): String
fun modifyColumn(column: Column<*>) : String
fun createSequence(identifier: String,
startWith: Int?,
incrementBy: Int?,
minValue: Int?,
maxValue: Int?,
cycle: Boolean?,
cache: Int?): String = buildString {
append("CREATE SEQUENCE ")
if (currentDialect.supportsIfNotExists) {
append("IF NOT EXISTS ")
}
append(identifier)
appendIfNotNull(" START WITH", startWith)
appendIfNotNull(" INCREMENT BY", incrementBy)
appendIfNotNull(" MINVALUE", minValue)
appendIfNotNull(" MAXVALUE", maxValue)
if (cycle == true) {
append(" CYCLE")
}
appendIfNotNull(" CACHE", cache)
}
fun StringBuilder.appendIfNotNull(str1: String, str2: Any?) = apply {
if (str2 != null) {
this.append("$str1 $str2")
}
}
}
abstract class VendorDialect(override val name: String,

View File

@@ -2,11 +2,16 @@ package org.jetbrains.exposed.sql.vendors
import org.jetbrains.exposed.sql.Expression
import org.jetbrains.exposed.sql.QueryBuilder
import org.jetbrains.exposed.sql.Sequence
internal object MariaDBFunctionProvider : MysqlFunctionProvider() {
override fun <T : String?> regexp(expr1: Expression<T>, pattern: Expression<String>, caseSensitive: Boolean, queryBuilder: QueryBuilder) {
queryBuilder{ append(expr1, " REGEXP ", pattern) }
}
override fun nextVal(seq: Sequence, builder: QueryBuilder) = builder {
append("NEXTVAL(", seq.identifier, ")")
}
}
class MariaDBDialect : MysqlDialect() {

View File

@@ -1,5 +1,6 @@
package org.jetbrains.exposed.sql.vendors
import org.jetbrains.exposed.exceptions.UnsupportedByDialectException
import org.jetbrains.exposed.sql.*
import org.jetbrains.exposed.sql.transactions.TransactionManager
import java.math.BigDecimal
@@ -71,6 +72,7 @@ internal open class MysqlFunctionProvider : FunctionProvider() {
}
open class MysqlDialect : VendorDialect(dialectName, MysqlDataTypeProvider, MysqlFunctionProvider.INSTANSE) {
override val supportsCreateSequence = false
override fun isAllowedAsColumnDefault(e: Expression<*>): Boolean {
if (super.isAllowedAsColumnDefault(e)) return true
@@ -129,6 +131,7 @@ open class MysqlDialect : VendorDialect(dialectName, MysqlDataTypeProvider, Mysq
TransactionManager.current().db.isVersionCovers(BigDecimal("8.0"))
}
companion object {
const val dialectName = "mysql"
}

View File

@@ -119,6 +119,10 @@ internal object PostgreSQLFunctionProvider : FunctionProvider() {
append(expr)
append(")")
}
override fun nextVal(seq: Sequence, builder: QueryBuilder) = builder {
append("NEXTVAL('", seq.identifier, "')")
}
}
open class PostgreSQLDialect : VendorDialect(dialectName, PostgreSQLDataTypeProvider, PostgreSQLFunctionProvider) {

View File

@@ -95,6 +95,10 @@ internal object SQLServerFunctionProvider : FunctionProvider() {
override fun <T> year(expr: Expression<T>, queryBuilder: QueryBuilder) = queryBuilder {
append("DATEPART(YEAR, ", expr, ")")
}
override fun nextVal(seq: Sequence, builder: QueryBuilder) = builder {
append("NEXT VALUE FOR ", seq.identifier)
}
}
open class SQLServerDialect : VendorDialect(dialectName, SQLServerDataTypeProvider, SQLServerFunctionProvider) {

View File

@@ -1,5 +1,6 @@
package org.jetbrains.exposed.sql.vendors
import org.jetbrains.exposed.exceptions.UnsupportedByDialectException
import org.jetbrains.exposed.exceptions.throwUnsupportedException
import org.jetbrains.exposed.sql.*
import org.jetbrains.exposed.sql.transactions.TransactionManager
@@ -98,6 +99,7 @@ internal object SQLiteFunctionProvider : FunctionProvider() {
open class SQLiteDialect : VendorDialect(dialectName, SQLiteDataTypeProvider, SQLiteFunctionProvider) {
override val supportsMultipleGeneratedKeys: Boolean = false
override val supportsCreateSequence = false
override fun isAllowedAsColumnDefault(e: Expression<*>): Boolean = true
override fun createIndex(index: Index): String {

View File

@@ -10,6 +10,7 @@ import java.util.concurrent.ConcurrentHashMap
import kotlin.properties.ReadOnlyProperty
import kotlin.reflect.KClass
import kotlin.reflect.full.primaryConstructor
import kotlin.sequences.Sequence
@Suppress("UNCHECKED_CAST")
abstract class EntityClass<ID : Comparable<ID>, out T: Entity<ID>>(val table: IdTable<ID>, entityType: Class<T>? = null) {

View File

@@ -0,0 +1,66 @@
package org.jetbrains.exposed.sql.tests.shared.ddl
import org.jetbrains.exposed.sql.*
import org.jetbrains.exposed.sql.tests.DatabaseTestsBase
import org.jetbrains.exposed.sql.tests.TestDB
import org.jetbrains.exposed.sql.tests.shared.assertEquals
import org.junit.Test
class CreateSequenceTests : DatabaseTestsBase() {
@Test
fun createSequenceStatementTest() {
withDb(excludeSettings = listOf(TestDB.MYSQL, TestDB.H2_MYSQL, TestDB.SQLITE)) {
val seqDDL = myseq.ddl
assertEquals("CREATE SEQUENCE " + addIfNotExistsIfSupported() + "${myseq.identifier} " +
"START WITH ${myseq.startWith} " +
"INCREMENT BY ${myseq.incrementBy} " +
"MINVALUE ${myseq.minValue} " +
"MAXVALUE ${myseq.maxValue} " +
"CYCLE " +
"CACHE ${myseq.cache}",
seqDDL)
}
}
@Test
fun SequenceNextValTest() {
// Exclude databases that doesn't support create sequence statement(Mysql and SQLite)
withTables(listOf(TestDB.MYSQL, TestDB.H2_MYSQL, TestDB.SQLITE), Developer) {
try {
SchemaUtils.createSequence(myseq)
var developerId = Developer.insert {
it[id] = myseq.nextVal()
it[name] = "Hichem"
} get Developer.id
assertEquals(4, developerId)
developerId = Developer.insert {
it[id] = myseq.nextVal()
it[name] = "Andrey"
} get Developer.id
assertEquals(6, developerId)
} finally {
SchemaUtils.dropSequence(myseq)
}
}
}
object Developer : Table() {
val id = integer("id")
var name = varchar("name", 25)
override val primaryKey = PrimaryKey(id, name)
}
val myseq = Sequence(name= "my_sequence",
startWith= 4,
incrementBy= 2,
minValue= 1,
maxValue= 10,
cycle= true,
cache=20)
}