[wip] convert insert
Signed-off-by: Anton Nesterov <anton@demiurg.io>
This commit is contained in:
parent
ffc8c2b841
commit
fde44ce343
35
pkg/dal/convert_fields.go
Normal file
35
pkg/dal/convert_fields.go
Normal file
|
@ -0,0 +1,35 @@
|
||||||
|
package dal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
func ConvertFields(ctx Context, fields []Map) (string, error) {
|
||||||
|
var expressions []string
|
||||||
|
for _, fieldAssoc := range fields {
|
||||||
|
for field, as := range fieldAssoc {
|
||||||
|
asBool, ok := as.(bool)
|
||||||
|
if ok {
|
||||||
|
if asBool {
|
||||||
|
expressions = append(expressions, field)
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
asNum, ok := as.(int)
|
||||||
|
if ok {
|
||||||
|
if asNum == 1 {
|
||||||
|
expressions = append(expressions, field)
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
asStr, ok := as.(string)
|
||||||
|
if ok {
|
||||||
|
expressions = append(expressions, fmt.Sprintf("%s AS %s", field, asStr))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return "", fmt.Errorf("invalid field value: %v", as)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return strings.Join(expressions, ", "), nil
|
||||||
|
}
|
58
pkg/dal/convert_fields_test.go
Normal file
58
pkg/dal/convert_fields_test.go
Normal file
|
@ -0,0 +1,58 @@
|
||||||
|
package dal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
filters "l12.xyz/dal/filters"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestConvertFieldsBool(t *testing.T) {
|
||||||
|
ctx := filters.SQLiteContext{
|
||||||
|
TableAlias: "t",
|
||||||
|
FieldName: "test",
|
||||||
|
}
|
||||||
|
result, err := ConvertFields(ctx, []Map{
|
||||||
|
{"test": true},
|
||||||
|
{"test2": false},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
if result != `test` {
|
||||||
|
t.Errorf("Expected test, got %s", result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConvertFieldsInt(t *testing.T) {
|
||||||
|
ctx := filters.SQLiteContext{
|
||||||
|
TableAlias: "t",
|
||||||
|
FieldName: "test",
|
||||||
|
}
|
||||||
|
result, err := ConvertFields(ctx, []Map{
|
||||||
|
{"test": 0},
|
||||||
|
{"test2": 1},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
if result != `test2` {
|
||||||
|
t.Errorf("Expected test, got %s", result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConvertFieldsStr(t *testing.T) {
|
||||||
|
ctx := filters.SQLiteContext{
|
||||||
|
TableAlias: "t",
|
||||||
|
FieldName: "test",
|
||||||
|
}
|
||||||
|
result, err := ConvertFields(ctx, []Map{
|
||||||
|
{"t.test": "Test"},
|
||||||
|
{"SUM(t.test, t.int)": "Sum"},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
if result != `t.test AS Test, SUM(t.test, t.int) AS Sum` {
|
||||||
|
t.Errorf("Expected test, got %s", result)
|
||||||
|
}
|
||||||
|
}
|
38
pkg/dal/convert_insert.go
Normal file
38
pkg/dal/convert_insert.go
Normal file
|
@ -0,0 +1,38 @@
|
||||||
|
package dal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type InsertData struct {
|
||||||
|
Statement string
|
||||||
|
Values []interface{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ConvertInsert(ctx Context, inserts []Map) (InsertData, error) {
|
||||||
|
keys := AggregateKeys(inserts)
|
||||||
|
placeholder := make([]string, 0)
|
||||||
|
for range keys {
|
||||||
|
placeholder = append(placeholder, "?")
|
||||||
|
}
|
||||||
|
|
||||||
|
values := make([]interface{}, 0)
|
||||||
|
for _, insert := range inserts {
|
||||||
|
vals := make([]interface{}, 0)
|
||||||
|
for _, key := range keys {
|
||||||
|
vals = append(vals, insert[key])
|
||||||
|
}
|
||||||
|
values = append(values, vals)
|
||||||
|
}
|
||||||
|
|
||||||
|
sfmt := fmt.Sprintf(
|
||||||
|
"INSERT INTO %s (%s) VALUES (%s)", ctx.GetTableName(),
|
||||||
|
strings.Join(keys, ","),
|
||||||
|
strings.Join(placeholder, ","),
|
||||||
|
)
|
||||||
|
return InsertData{
|
||||||
|
Statement: sfmt,
|
||||||
|
Values: values,
|
||||||
|
}, nil
|
||||||
|
}
|
29
pkg/dal/convert_insert_test.go
Normal file
29
pkg/dal/convert_insert_test.go
Normal file
|
@ -0,0 +1,29 @@
|
||||||
|
package dal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
filters "l12.xyz/dal/filters"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestConvertInsert(t *testing.T) {
|
||||||
|
ctx := filters.SQLiteContext{
|
||||||
|
TableName: "test",
|
||||||
|
TableAlias: "t",
|
||||||
|
}
|
||||||
|
insert := []Map{
|
||||||
|
{"a": "1", "b": 2},
|
||||||
|
{"b": 2, "a": "1", "c": 3},
|
||||||
|
}
|
||||||
|
result, _ := ConvertInsert(ctx, insert)
|
||||||
|
|
||||||
|
if result.Statement != `INSERT INTO test (a,b,c) VALUES (?,?,?)` {
|
||||||
|
t.Errorf(`Expected "INSERT INTO test (a,b,c) VALUES (?,?,?)", got %s`, result.Statement)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, r := range result.Values {
|
||||||
|
fmt.Println(r)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -5,6 +5,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type Map = map[string]interface{}
|
type Map = map[string]interface{}
|
||||||
|
type Fields = Map
|
||||||
type Find = filters.Find
|
type Find = filters.Find
|
||||||
type Query = filters.Find
|
type Query = filters.Find
|
||||||
type Filter = filters.Filter
|
type Filter = filters.Filter
|
||||||
|
|
20
pkg/dal/utils.go
Normal file
20
pkg/dal/utils.go
Normal file
|
@ -0,0 +1,20 @@
|
||||||
|
package dal
|
||||||
|
|
||||||
|
import "sort"
|
||||||
|
|
||||||
|
func AggregateKeys(maps []Map) []string {
|
||||||
|
set := make(map[string]int)
|
||||||
|
keys := make([]string, 0)
|
||||||
|
for _, item := range maps {
|
||||||
|
for k := range item {
|
||||||
|
if set[k] == 1 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
keys = append(keys, k)
|
||||||
|
set[k] = 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
set = nil
|
||||||
|
sort.Strings(keys)
|
||||||
|
return keys
|
||||||
|
}
|
|
@ -1,15 +1,14 @@
|
||||||
package filters
|
package filters
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"slices"
|
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"unicode"
|
|
||||||
|
|
||||||
utils "l12.xyz/dal/utils"
|
utils "l12.xyz/dal/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
type SQLiteContext struct {
|
type SQLiteContext struct {
|
||||||
|
TableName string
|
||||||
TableAlias string
|
TableAlias string
|
||||||
FieldName string
|
FieldName string
|
||||||
}
|
}
|
||||||
|
@ -29,6 +28,10 @@ func (c SQLiteContext) New(opts CtxOpts) Context {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c SQLiteContext) GetTableName() string {
|
||||||
|
return c.TableName
|
||||||
|
}
|
||||||
|
|
||||||
func (c SQLiteContext) GetFieldName() string {
|
func (c SQLiteContext) GetFieldName() string {
|
||||||
if strings.Contains(c.FieldName, ".") {
|
if strings.Contains(c.FieldName, ".") {
|
||||||
return c.FieldName
|
return c.FieldName
|
||||||
|
@ -41,7 +44,7 @@ func (c SQLiteContext) GetFieldName() string {
|
||||||
|
|
||||||
func (c SQLiteContext) NormalizeValue(value interface{}) interface{} {
|
func (c SQLiteContext) NormalizeValue(value interface{}) interface{} {
|
||||||
str, ok := value.(string)
|
str, ok := value.(string)
|
||||||
if isSQLFunction(str) {
|
if utils.IsSQLFunction(str) {
|
||||||
return str
|
return str
|
||||||
}
|
}
|
||||||
if strings.Contains(str, ".") {
|
if strings.Contains(str, ".") {
|
||||||
|
@ -57,26 +60,5 @@ func (c SQLiteContext) NormalizeValue(value interface{}) interface{} {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return str
|
return str
|
||||||
}
|
}
|
||||||
return "'" + escapeSingleQuote(string(val)) + "'"
|
return "'" + utils.EscapeSingleQuote(string(val)) + "'"
|
||||||
}
|
|
||||||
|
|
||||||
func isSQLFunction(str string) bool {
|
|
||||||
stopChars := []string{" ", "_", "-", ".", "("}
|
|
||||||
isUpper := false
|
|
||||||
for _, char := range str {
|
|
||||||
if slices.Contains(stopChars, string(char)) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if unicode.IsUpper(char) {
|
|
||||||
isUpper = true
|
|
||||||
} else {
|
|
||||||
isUpper = false
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return isUpper
|
|
||||||
}
|
|
||||||
|
|
||||||
func escapeSingleQuote(str string) string {
|
|
||||||
return strings.ReplaceAll(str, "'", "''")
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,6 +3,7 @@ package filters
|
||||||
type CtxOpts map[string]string
|
type CtxOpts map[string]string
|
||||||
type Context interface {
|
type Context interface {
|
||||||
New(opts CtxOpts) Context
|
New(opts CtxOpts) Context
|
||||||
|
GetTableName() string
|
||||||
GetFieldName() string
|
GetFieldName() string
|
||||||
NormalizeValue(interface{}) interface{}
|
NormalizeValue(interface{}) interface{}
|
||||||
}
|
}
|
||||||
|
|
28
pkg/utils/sql_format.go
Normal file
28
pkg/utils/sql_format.go
Normal file
|
@ -0,0 +1,28 @@
|
||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
"unicode"
|
||||||
|
)
|
||||||
|
|
||||||
|
func IsSQLFunction(str string) bool {
|
||||||
|
stopChars := []string{" ", "_", "-", ".", "("}
|
||||||
|
isUpper := false
|
||||||
|
for _, char := range str {
|
||||||
|
if slices.Contains(stopChars, string(char)) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if unicode.IsUpper(char) {
|
||||||
|
isUpper = true
|
||||||
|
} else {
|
||||||
|
isUpper = false
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return isUpper
|
||||||
|
}
|
||||||
|
|
||||||
|
func EscapeSingleQuote(str string) string {
|
||||||
|
return strings.ReplaceAll(str, "'", "''")
|
||||||
|
}
|
Loading…
Reference in a new issue