[wip] convert insert
Signed-off-by: Anton Nesterov <anton@demiurg.io>
This commit is contained in:
parent
ffc8c2b841
commit
fde44ce343
35
pkg/dal/convert_fields.go
Normal file
35
pkg/dal/convert_fields.go
Normal file
|
@ -0,0 +1,35 @@
|
|||
package dal
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func ConvertFields(ctx Context, fields []Map) (string, error) {
|
||||
var expressions []string
|
||||
for _, fieldAssoc := range fields {
|
||||
for field, as := range fieldAssoc {
|
||||
asBool, ok := as.(bool)
|
||||
if ok {
|
||||
if asBool {
|
||||
expressions = append(expressions, field)
|
||||
}
|
||||
continue
|
||||
}
|
||||
asNum, ok := as.(int)
|
||||
if ok {
|
||||
if asNum == 1 {
|
||||
expressions = append(expressions, field)
|
||||
}
|
||||
continue
|
||||
}
|
||||
asStr, ok := as.(string)
|
||||
if ok {
|
||||
expressions = append(expressions, fmt.Sprintf("%s AS %s", field, asStr))
|
||||
continue
|
||||
}
|
||||
return "", fmt.Errorf("invalid field value: %v", as)
|
||||
}
|
||||
}
|
||||
return strings.Join(expressions, ", "), nil
|
||||
}
|
58
pkg/dal/convert_fields_test.go
Normal file
58
pkg/dal/convert_fields_test.go
Normal file
|
@ -0,0 +1,58 @@
|
|||
package dal
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
filters "l12.xyz/dal/filters"
|
||||
)
|
||||
|
||||
func TestConvertFieldsBool(t *testing.T) {
|
||||
ctx := filters.SQLiteContext{
|
||||
TableAlias: "t",
|
||||
FieldName: "test",
|
||||
}
|
||||
result, err := ConvertFields(ctx, []Map{
|
||||
{"test": true},
|
||||
{"test2": false},
|
||||
})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if result != `test` {
|
||||
t.Errorf("Expected test, got %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertFieldsInt(t *testing.T) {
|
||||
ctx := filters.SQLiteContext{
|
||||
TableAlias: "t",
|
||||
FieldName: "test",
|
||||
}
|
||||
result, err := ConvertFields(ctx, []Map{
|
||||
{"test": 0},
|
||||
{"test2": 1},
|
||||
})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if result != `test2` {
|
||||
t.Errorf("Expected test, got %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertFieldsStr(t *testing.T) {
|
||||
ctx := filters.SQLiteContext{
|
||||
TableAlias: "t",
|
||||
FieldName: "test",
|
||||
}
|
||||
result, err := ConvertFields(ctx, []Map{
|
||||
{"t.test": "Test"},
|
||||
{"SUM(t.test, t.int)": "Sum"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if result != `t.test AS Test, SUM(t.test, t.int) AS Sum` {
|
||||
t.Errorf("Expected test, got %s", result)
|
||||
}
|
||||
}
|
38
pkg/dal/convert_insert.go
Normal file
38
pkg/dal/convert_insert.go
Normal file
|
@ -0,0 +1,38 @@
|
|||
package dal
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type InsertData struct {
|
||||
Statement string
|
||||
Values []interface{}
|
||||
}
|
||||
|
||||
func ConvertInsert(ctx Context, inserts []Map) (InsertData, error) {
|
||||
keys := AggregateKeys(inserts)
|
||||
placeholder := make([]string, 0)
|
||||
for range keys {
|
||||
placeholder = append(placeholder, "?")
|
||||
}
|
||||
|
||||
values := make([]interface{}, 0)
|
||||
for _, insert := range inserts {
|
||||
vals := make([]interface{}, 0)
|
||||
for _, key := range keys {
|
||||
vals = append(vals, insert[key])
|
||||
}
|
||||
values = append(values, vals)
|
||||
}
|
||||
|
||||
sfmt := fmt.Sprintf(
|
||||
"INSERT INTO %s (%s) VALUES (%s)", ctx.GetTableName(),
|
||||
strings.Join(keys, ","),
|
||||
strings.Join(placeholder, ","),
|
||||
)
|
||||
return InsertData{
|
||||
Statement: sfmt,
|
||||
Values: values,
|
||||
}, nil
|
||||
}
|
29
pkg/dal/convert_insert_test.go
Normal file
29
pkg/dal/convert_insert_test.go
Normal file
|
@ -0,0 +1,29 @@
|
|||
package dal
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
filters "l12.xyz/dal/filters"
|
||||
)
|
||||
|
||||
func TestConvertInsert(t *testing.T) {
|
||||
ctx := filters.SQLiteContext{
|
||||
TableName: "test",
|
||||
TableAlias: "t",
|
||||
}
|
||||
insert := []Map{
|
||||
{"a": "1", "b": 2},
|
||||
{"b": 2, "a": "1", "c": 3},
|
||||
}
|
||||
result, _ := ConvertInsert(ctx, insert)
|
||||
|
||||
if result.Statement != `INSERT INTO test (a,b,c) VALUES (?,?,?)` {
|
||||
t.Errorf(`Expected "INSERT INTO test (a,b,c) VALUES (?,?,?)", got %s`, result.Statement)
|
||||
}
|
||||
|
||||
for _, r := range result.Values {
|
||||
fmt.Println(r)
|
||||
}
|
||||
|
||||
}
|
|
@ -5,6 +5,7 @@ import (
|
|||
)
|
||||
|
||||
type Map = map[string]interface{}
|
||||
type Fields = Map
|
||||
type Find = filters.Find
|
||||
type Query = filters.Find
|
||||
type Filter = filters.Filter
|
||||
|
|
20
pkg/dal/utils.go
Normal file
20
pkg/dal/utils.go
Normal file
|
@ -0,0 +1,20 @@
|
|||
package dal
|
||||
|
||||
import "sort"
|
||||
|
||||
func AggregateKeys(maps []Map) []string {
|
||||
set := make(map[string]int)
|
||||
keys := make([]string, 0)
|
||||
for _, item := range maps {
|
||||
for k := range item {
|
||||
if set[k] == 1 {
|
||||
continue
|
||||
}
|
||||
keys = append(keys, k)
|
||||
set[k] = 1
|
||||
}
|
||||
}
|
||||
set = nil
|
||||
sort.Strings(keys)
|
||||
return keys
|
||||
}
|
|
@ -1,15 +1,14 @@
|
|||
package filters
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
utils "l12.xyz/dal/utils"
|
||||
)
|
||||
|
||||
type SQLiteContext struct {
|
||||
TableName string
|
||||
TableAlias string
|
||||
FieldName string
|
||||
}
|
||||
|
@ -29,6 +28,10 @@ func (c SQLiteContext) New(opts CtxOpts) Context {
|
|||
}
|
||||
}
|
||||
|
||||
func (c SQLiteContext) GetTableName() string {
|
||||
return c.TableName
|
||||
}
|
||||
|
||||
func (c SQLiteContext) GetFieldName() string {
|
||||
if strings.Contains(c.FieldName, ".") {
|
||||
return c.FieldName
|
||||
|
@ -41,7 +44,7 @@ func (c SQLiteContext) GetFieldName() string {
|
|||
|
||||
func (c SQLiteContext) NormalizeValue(value interface{}) interface{} {
|
||||
str, ok := value.(string)
|
||||
if isSQLFunction(str) {
|
||||
if utils.IsSQLFunction(str) {
|
||||
return str
|
||||
}
|
||||
if strings.Contains(str, ".") {
|
||||
|
@ -57,26 +60,5 @@ func (c SQLiteContext) NormalizeValue(value interface{}) interface{} {
|
|||
if err != nil {
|
||||
return str
|
||||
}
|
||||
return "'" + escapeSingleQuote(string(val)) + "'"
|
||||
}
|
||||
|
||||
func isSQLFunction(str string) bool {
|
||||
stopChars := []string{" ", "_", "-", ".", "("}
|
||||
isUpper := false
|
||||
for _, char := range str {
|
||||
if slices.Contains(stopChars, string(char)) {
|
||||
break
|
||||
}
|
||||
if unicode.IsUpper(char) {
|
||||
isUpper = true
|
||||
} else {
|
||||
isUpper = false
|
||||
break
|
||||
}
|
||||
}
|
||||
return isUpper
|
||||
}
|
||||
|
||||
func escapeSingleQuote(str string) string {
|
||||
return strings.ReplaceAll(str, "'", "''")
|
||||
return "'" + utils.EscapeSingleQuote(string(val)) + "'"
|
||||
}
|
||||
|
|
|
@ -3,6 +3,7 @@ package filters
|
|||
type CtxOpts map[string]string
|
||||
type Context interface {
|
||||
New(opts CtxOpts) Context
|
||||
GetTableName() string
|
||||
GetFieldName() string
|
||||
NormalizeValue(interface{}) interface{}
|
||||
}
|
||||
|
|
28
pkg/utils/sql_format.go
Normal file
28
pkg/utils/sql_format.go
Normal file
|
@ -0,0 +1,28 @@
|
|||
package utils
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"strings"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
func IsSQLFunction(str string) bool {
|
||||
stopChars := []string{" ", "_", "-", ".", "("}
|
||||
isUpper := false
|
||||
for _, char := range str {
|
||||
if slices.Contains(stopChars, string(char)) {
|
||||
break
|
||||
}
|
||||
if unicode.IsUpper(char) {
|
||||
isUpper = true
|
||||
} else {
|
||||
isUpper = false
|
||||
break
|
||||
}
|
||||
}
|
||||
return isUpper
|
||||
}
|
||||
|
||||
func EscapeSingleQuote(str string) string {
|
||||
return strings.ReplaceAll(str, "'", "''")
|
||||
}
|
Loading…
Reference in a new issue