diff --git a/pkg/dal/builder.go b/pkg/dal/Builder.go similarity index 100% rename from pkg/dal/builder.go rename to pkg/dal/Builder.go diff --git a/pkg/dal/convert_fields.go b/pkg/dal/convert_fields.go new file mode 100644 index 0000000..a325bfb --- /dev/null +++ b/pkg/dal/convert_fields.go @@ -0,0 +1,35 @@ +package dal + +import ( + "fmt" + "strings" +) + +func ConvertFields(ctx Context, fields []Map) (string, error) { + var expressions []string + for _, fieldAssoc := range fields { + for field, as := range fieldAssoc { + asBool, ok := as.(bool) + if ok { + if asBool { + expressions = append(expressions, field) + } + continue + } + asNum, ok := as.(int) + if ok { + if asNum == 1 { + expressions = append(expressions, field) + } + continue + } + asStr, ok := as.(string) + if ok { + expressions = append(expressions, fmt.Sprintf("%s AS %s", field, asStr)) + continue + } + return "", fmt.Errorf("invalid field value: %v", as) + } + } + return strings.Join(expressions, ", "), nil +} diff --git a/pkg/dal/convert_fields_test.go b/pkg/dal/convert_fields_test.go new file mode 100644 index 0000000..405e1b7 --- /dev/null +++ b/pkg/dal/convert_fields_test.go @@ -0,0 +1,58 @@ +package dal + +import ( + "testing" + + filters "l12.xyz/dal/filters" +) + +func TestConvertFieldsBool(t *testing.T) { + ctx := filters.SQLiteContext{ + TableAlias: "t", + FieldName: "test", + } + result, err := ConvertFields(ctx, []Map{ + {"test": true}, + {"test2": false}, + }) + if err != nil { + t.Error(err) + } + if result != `test` { + t.Errorf("Expected test, got %s", result) + } +} + +func TestConvertFieldsInt(t *testing.T) { + ctx := filters.SQLiteContext{ + TableAlias: "t", + FieldName: "test", + } + result, err := ConvertFields(ctx, []Map{ + {"test": 0}, + {"test2": 1}, + }) + if err != nil { + t.Error(err) + } + if result != `test2` { + t.Errorf("Expected test, got %s", result) + } +} + +func TestConvertFieldsStr(t *testing.T) { + ctx := filters.SQLiteContext{ + TableAlias: "t", + FieldName: "test", + } + result, err := ConvertFields(ctx, []Map{ + {"t.test": "Test"}, + {"SUM(t.test, t.int)": "Sum"}, + }) + if err != nil { + t.Error(err) + } + if result != `t.test AS Test, SUM(t.test, t.int) AS Sum` { + t.Errorf("Expected test, got %s", result) + } +} diff --git a/pkg/dal/convert_insert.go b/pkg/dal/convert_insert.go new file mode 100644 index 0000000..8c17a54 --- /dev/null +++ b/pkg/dal/convert_insert.go @@ -0,0 +1,38 @@ +package dal + +import ( + "fmt" + "strings" +) + +type InsertData struct { + Statement string + Values []interface{} +} + +func ConvertInsert(ctx Context, inserts []Map) (InsertData, error) { + keys := AggregateKeys(inserts) + placeholder := make([]string, 0) + for range keys { + placeholder = append(placeholder, "?") + } + + values := make([]interface{}, 0) + for _, insert := range inserts { + vals := make([]interface{}, 0) + for _, key := range keys { + vals = append(vals, insert[key]) + } + values = append(values, vals) + } + + sfmt := fmt.Sprintf( + "INSERT INTO %s (%s) VALUES (%s)", ctx.GetTableName(), + strings.Join(keys, ","), + strings.Join(placeholder, ","), + ) + return InsertData{ + Statement: sfmt, + Values: values, + }, nil +} diff --git a/pkg/dal/convert_insert_test.go b/pkg/dal/convert_insert_test.go new file mode 100644 index 0000000..c816851 --- /dev/null +++ b/pkg/dal/convert_insert_test.go @@ -0,0 +1,29 @@ +package dal + +import ( + "fmt" + "testing" + + filters "l12.xyz/dal/filters" +) + +func TestConvertInsert(t *testing.T) { + ctx := filters.SQLiteContext{ + TableName: "test", + TableAlias: "t", + } + insert := []Map{ + {"a": "1", "b": 2}, + {"b": 2, "a": "1", "c": 3}, + } + result, _ := ConvertInsert(ctx, insert) + + if result.Statement != `INSERT INTO test (a,b,c) VALUES (?,?,?)` { + t.Errorf(`Expected "INSERT INTO test (a,b,c) VALUES (?,?,?)", got %s`, result.Statement) + } + + for _, r := range result.Values { + fmt.Println(r) + } + +} diff --git a/pkg/dal/types.go b/pkg/dal/types.go index 7dd2de1..8fa500e 100644 --- a/pkg/dal/types.go +++ b/pkg/dal/types.go @@ -5,6 +5,7 @@ import ( ) type Map = map[string]interface{} +type Fields = Map type Find = filters.Find type Query = filters.Find type Filter = filters.Filter diff --git a/pkg/dal/utils.go b/pkg/dal/utils.go new file mode 100644 index 0000000..1218150 --- /dev/null +++ b/pkg/dal/utils.go @@ -0,0 +1,20 @@ +package dal + +import "sort" + +func AggregateKeys(maps []Map) []string { + set := make(map[string]int) + keys := make([]string, 0) + for _, item := range maps { + for k := range item { + if set[k] == 1 { + continue + } + keys = append(keys, k) + set[k] = 1 + } + } + set = nil + sort.Strings(keys) + return keys +} diff --git a/pkg/filters/context.go b/pkg/filters/context.go index d4e0f83..03cb338 100644 --- a/pkg/filters/context.go +++ b/pkg/filters/context.go @@ -1,15 +1,14 @@ package filters import ( - "slices" "strconv" "strings" - "unicode" utils "l12.xyz/dal/utils" ) type SQLiteContext struct { + TableName string TableAlias string FieldName string } @@ -29,6 +28,10 @@ func (c SQLiteContext) New(opts CtxOpts) Context { } } +func (c SQLiteContext) GetTableName() string { + return c.TableName +} + func (c SQLiteContext) GetFieldName() string { if strings.Contains(c.FieldName, ".") { return c.FieldName @@ -41,7 +44,7 @@ func (c SQLiteContext) GetFieldName() string { func (c SQLiteContext) NormalizeValue(value interface{}) interface{} { str, ok := value.(string) - if isSQLFunction(str) { + if utils.IsSQLFunction(str) { return str } if strings.Contains(str, ".") { @@ -57,26 +60,5 @@ func (c SQLiteContext) NormalizeValue(value interface{}) interface{} { if err != nil { return str } - return "'" + escapeSingleQuote(string(val)) + "'" -} - -func isSQLFunction(str string) bool { - stopChars := []string{" ", "_", "-", ".", "("} - isUpper := false - for _, char := range str { - if slices.Contains(stopChars, string(char)) { - break - } - if unicode.IsUpper(char) { - isUpper = true - } else { - isUpper = false - break - } - } - return isUpper -} - -func escapeSingleQuote(str string) string { - return strings.ReplaceAll(str, "'", "''") + return "'" + utils.EscapeSingleQuote(string(val)) + "'" } diff --git a/pkg/filters/types.go b/pkg/filters/types.go index ff99d6f..dc40f79 100644 --- a/pkg/filters/types.go +++ b/pkg/filters/types.go @@ -3,6 +3,7 @@ package filters type CtxOpts map[string]string type Context interface { New(opts CtxOpts) Context + GetTableName() string GetFieldName() string NormalizeValue(interface{}) interface{} } diff --git a/pkg/utils/sql_format.go b/pkg/utils/sql_format.go new file mode 100644 index 0000000..ce3b069 --- /dev/null +++ b/pkg/utils/sql_format.go @@ -0,0 +1,28 @@ +package utils + +import ( + "slices" + "strings" + "unicode" +) + +func IsSQLFunction(str string) bool { + stopChars := []string{" ", "_", "-", ".", "("} + isUpper := false + for _, char := range str { + if slices.Contains(stopChars, string(char)) { + break + } + if unicode.IsUpper(char) { + isUpper = true + } else { + isUpper = false + break + } + } + return isUpper +} + +func EscapeSingleQuote(str string) string { + return strings.ReplaceAll(str, "'", "''") +}