[wip] convert insert

Signed-off-by: Anton Nesterov <anton@demiurg.io>
This commit is contained in:
Anton Nesterov 2024-08-09 16:14:42 +02:00
parent ffc8c2b841
commit fde44ce343
No known key found for this signature in database
GPG key ID: 59121E8AE2851FB5
10 changed files with 217 additions and 25 deletions

35
pkg/dal/convert_fields.go Normal file
View file

@ -0,0 +1,35 @@
package dal
import (
"fmt"
"strings"
)
func ConvertFields(ctx Context, fields []Map) (string, error) {
var expressions []string
for _, fieldAssoc := range fields {
for field, as := range fieldAssoc {
asBool, ok := as.(bool)
if ok {
if asBool {
expressions = append(expressions, field)
}
continue
}
asNum, ok := as.(int)
if ok {
if asNum == 1 {
expressions = append(expressions, field)
}
continue
}
asStr, ok := as.(string)
if ok {
expressions = append(expressions, fmt.Sprintf("%s AS %s", field, asStr))
continue
}
return "", fmt.Errorf("invalid field value: %v", as)
}
}
return strings.Join(expressions, ", "), nil
}

View file

@ -0,0 +1,58 @@
package dal
import (
"testing"
filters "l12.xyz/dal/filters"
)
func TestConvertFieldsBool(t *testing.T) {
ctx := filters.SQLiteContext{
TableAlias: "t",
FieldName: "test",
}
result, err := ConvertFields(ctx, []Map{
{"test": true},
{"test2": false},
})
if err != nil {
t.Error(err)
}
if result != `test` {
t.Errorf("Expected test, got %s", result)
}
}
func TestConvertFieldsInt(t *testing.T) {
ctx := filters.SQLiteContext{
TableAlias: "t",
FieldName: "test",
}
result, err := ConvertFields(ctx, []Map{
{"test": 0},
{"test2": 1},
})
if err != nil {
t.Error(err)
}
if result != `test2` {
t.Errorf("Expected test, got %s", result)
}
}
func TestConvertFieldsStr(t *testing.T) {
ctx := filters.SQLiteContext{
TableAlias: "t",
FieldName: "test",
}
result, err := ConvertFields(ctx, []Map{
{"t.test": "Test"},
{"SUM(t.test, t.int)": "Sum"},
})
if err != nil {
t.Error(err)
}
if result != `t.test AS Test, SUM(t.test, t.int) AS Sum` {
t.Errorf("Expected test, got %s", result)
}
}

38
pkg/dal/convert_insert.go Normal file
View file

@ -0,0 +1,38 @@
package dal
import (
"fmt"
"strings"
)
type InsertData struct {
Statement string
Values []interface{}
}
func ConvertInsert(ctx Context, inserts []Map) (InsertData, error) {
keys := AggregateKeys(inserts)
placeholder := make([]string, 0)
for range keys {
placeholder = append(placeholder, "?")
}
values := make([]interface{}, 0)
for _, insert := range inserts {
vals := make([]interface{}, 0)
for _, key := range keys {
vals = append(vals, insert[key])
}
values = append(values, vals)
}
sfmt := fmt.Sprintf(
"INSERT INTO %s (%s) VALUES (%s)", ctx.GetTableName(),
strings.Join(keys, ","),
strings.Join(placeholder, ","),
)
return InsertData{
Statement: sfmt,
Values: values,
}, nil
}

View file

@ -0,0 +1,29 @@
package dal
import (
"fmt"
"testing"
filters "l12.xyz/dal/filters"
)
func TestConvertInsert(t *testing.T) {
ctx := filters.SQLiteContext{
TableName: "test",
TableAlias: "t",
}
insert := []Map{
{"a": "1", "b": 2},
{"b": 2, "a": "1", "c": 3},
}
result, _ := ConvertInsert(ctx, insert)
if result.Statement != `INSERT INTO test (a,b,c) VALUES (?,?,?)` {
t.Errorf(`Expected "INSERT INTO test (a,b,c) VALUES (?,?,?)", got %s`, result.Statement)
}
for _, r := range result.Values {
fmt.Println(r)
}
}

View file

@ -5,6 +5,7 @@ import (
)
type Map = map[string]interface{}
type Fields = Map
type Find = filters.Find
type Query = filters.Find
type Filter = filters.Filter

20
pkg/dal/utils.go Normal file
View file

@ -0,0 +1,20 @@
package dal
import "sort"
func AggregateKeys(maps []Map) []string {
set := make(map[string]int)
keys := make([]string, 0)
for _, item := range maps {
for k := range item {
if set[k] == 1 {
continue
}
keys = append(keys, k)
set[k] = 1
}
}
set = nil
sort.Strings(keys)
return keys
}

View file

@ -1,15 +1,14 @@
package filters
import (
"slices"
"strconv"
"strings"
"unicode"
utils "l12.xyz/dal/utils"
)
type SQLiteContext struct {
TableName string
TableAlias string
FieldName string
}
@ -29,6 +28,10 @@ func (c SQLiteContext) New(opts CtxOpts) Context {
}
}
func (c SQLiteContext) GetTableName() string {
return c.TableName
}
func (c SQLiteContext) GetFieldName() string {
if strings.Contains(c.FieldName, ".") {
return c.FieldName
@ -41,7 +44,7 @@ func (c SQLiteContext) GetFieldName() string {
func (c SQLiteContext) NormalizeValue(value interface{}) interface{} {
str, ok := value.(string)
if isSQLFunction(str) {
if utils.IsSQLFunction(str) {
return str
}
if strings.Contains(str, ".") {
@ -57,26 +60,5 @@ func (c SQLiteContext) NormalizeValue(value interface{}) interface{} {
if err != nil {
return str
}
return "'" + escapeSingleQuote(string(val)) + "'"
}
func isSQLFunction(str string) bool {
stopChars := []string{" ", "_", "-", ".", "("}
isUpper := false
for _, char := range str {
if slices.Contains(stopChars, string(char)) {
break
}
if unicode.IsUpper(char) {
isUpper = true
} else {
isUpper = false
break
}
}
return isUpper
}
func escapeSingleQuote(str string) string {
return strings.ReplaceAll(str, "'", "''")
return "'" + utils.EscapeSingleQuote(string(val)) + "'"
}

View file

@ -3,6 +3,7 @@ package filters
type CtxOpts map[string]string
type Context interface {
New(opts CtxOpts) Context
GetTableName() string
GetFieldName() string
NormalizeValue(interface{}) interface{}
}

28
pkg/utils/sql_format.go Normal file
View file

@ -0,0 +1,28 @@
package utils
import (
"slices"
"strings"
"unicode"
)
func IsSQLFunction(str string) bool {
stopChars := []string{" ", "_", "-", ".", "("}
isUpper := false
for _, char := range str {
if slices.Contains(stopChars, string(char)) {
break
}
if unicode.IsUpper(char) {
isUpper = true
} else {
isUpper = false
break
}
}
return isUpper
}
func EscapeSingleQuote(str string) string {
return strings.ReplaceAll(str, "'", "''")
}