diff --git a/pkg/__test__/builder_test.go b/pkg/__test__/builder_test.go index 5108013..db84773 100644 --- a/pkg/__test__/builder_test.go +++ b/pkg/__test__/builder_test.go @@ -31,11 +31,12 @@ func TestBuilderBasic(t *testing.T) { t.Fatalf("failed to insert data: %v", err) } - expr, _ := builder.New(adapter.SQLite{}).In("test t").Find(builder.Find{"name": builder.Is{"$in": []interface{}{"a", 'b'}}}).Sql() + expr, values := builder.New(adapter.SQLite{}).In("test t").Find(builder.Find{"name": builder.Is{"$in": []interface{}{"a", 98}}}).Sql() fmt.Println(expr) rows, err := a.Query(adapter.Query{ Db: "file::memory:?cache=shared", Expression: expr, + Data: values, }) if err != nil { t.Fatalf("failed to query data: %v", err) diff --git a/pkg/adapter/SQLite.go b/pkg/adapter/SQLite.go index 150dece..4863c9f 100644 --- a/pkg/adapter/SQLite.go +++ b/pkg/adapter/SQLite.go @@ -58,7 +58,13 @@ func (c SQLite) GetColumnName(key string) string { } func (c SQLite) NormalizeValue(value interface{}) interface{} { - str, ok := value.(string) + str, isStr := value.(string) + if !isStr { + return value + } + if str == "?" { + return str + } if utils.IsSQLFunction(str) { return str } @@ -68,12 +74,9 @@ func (c SQLite) NormalizeValue(value interface{}) interface{} { return value } } - if !ok { - return value - } val, err := utils.EscapeSQL(str) if err != nil { return str } - return "'" + utils.EscapeSingleQuote(string(val)) + "'" + return string(val) } diff --git a/pkg/builder/Builder.go b/pkg/builder/Builder.go index 4fc3da6..c8e0214 100644 --- a/pkg/builder/Builder.go +++ b/pkg/builder/Builder.go @@ -24,6 +24,7 @@ type SQLParts struct { OrderExp string LimitExp string OffsetExp string + Values []interface{} Insert InsertData Update UpdateData } @@ -49,7 +50,7 @@ func (b *Builder) In(table string) *Builder { } func (b *Builder) Find(query Find) *Builder { - b.Parts.FiterExp = covertFind( + b.Parts.FiterExp, b.Parts.Values = covertFind( b.Dialect, query, ) @@ -79,7 +80,9 @@ func (b *Builder) Fields(fields ...Map) *Builder { } func (b *Builder) Join(joins ...interface{}) *Builder { - b.Parts.JoinExps = convertJoin(b.Dialect, joins...) + exps, vals := convertJoin(b.Dialect, joins...) + b.Parts.JoinExps = append(b.Parts.JoinExps, exps...) + b.Parts.Values = append(b.Parts.Values, vals...) return b } @@ -178,7 +181,7 @@ func (b *Builder) Sql() (string, []interface{}) { b.Parts.OrderExp, b.Parts.LimitExp, b.Parts.OffsetExp, - }, " ")), nil + }, " ")), b.Parts.Values case operation == "DELETE": return unspace(strings.Join([]string{ b.Parts.Operation, @@ -189,7 +192,7 @@ func (b *Builder) Sql() (string, []interface{}) { b.Parts.OrderExp, b.Parts.LimitExp, b.Parts.OffsetExp, - }, " ")), nil + }, " ")), b.Parts.Values case operation == "INSERT INTO": return b.Parts.Insert.Statement, b.Parts.Insert.Values case operation == "UPDATE": diff --git a/pkg/builder/Builder_test.go b/pkg/builder/Builder_test.go index f7da7b1..6b9bae8 100644 --- a/pkg/builder/Builder_test.go +++ b/pkg/builder/Builder_test.go @@ -1,6 +1,7 @@ package builder import ( + "fmt" "testing" ) @@ -10,7 +11,7 @@ func TestBuilderFind(t *testing.T) { "field": "value", "a": 1, }) - expect := "SELECT * FROM table t WHERE t.a = 1 AND t.field = 'value'" + expect := "SELECT * FROM table t WHERE t.a = ? AND t.field = ?" result, _ := db.Sql() if result != expect { t.Errorf(`Expected: "%s", Got: %s`, expect, result) @@ -28,7 +29,7 @@ func TestBuilderFields(t *testing.T) { "t.field": "field", "t.a": 1, }) - expect := "SELECT t.a, t.field AS field FROM table t WHERE t.a = 1 AND t.field = 'value'" + expect := "SELECT t.a, t.field AS field FROM table t WHERE t.a = ? AND t.field = ?" result, _ := db.Sql() if result != expect { t.Errorf(`Expected: "%s", Got: %s`, expect, result) @@ -47,8 +48,9 @@ func TestBuilderGroup(t *testing.T) { "SUM(t.field)": "field", }) db.Group("field") - expect := "SELECT SUM(t.field) AS field FROM table t GROUP BY t.field HAVING t.field > 1" + expect := "SELECT SUM(t.field) AS field FROM table t GROUP BY t.field HAVING t.field > ?" result, _ := db.Sql() + fmt.Println(db.Parts.Values) if result != expect { t.Errorf(`Expected: "%s", Got: %s`, expect, result) } @@ -67,7 +69,7 @@ func TestBuilderJoin(t *testing.T) { "t2.field": "t.field", }, }) - expect := "SELECT * FROM table t JOIN table2 t2 ON t2.field = t.field WHERE t.a = 1 AND t.field = 'value'" + expect := "SELECT * FROM table t JOIN table2 t2 ON t2.field = t.field WHERE t.a = ? AND t.field = ?" result, _ := db.Sql() if result != expect { t.Errorf(`Expected: "%s", Got: %s`, expect, result) diff --git a/pkg/builder/convert_find.go b/pkg/builder/convert_find.go index 9e43ffd..4ea47b6 100644 --- a/pkg/builder/convert_find.go +++ b/pkg/builder/convert_find.go @@ -7,33 +7,39 @@ import ( filters "l12.xyz/dal/filters" ) -func covertFind(ctx Dialect, find Find) string { +type Values = []interface{} + +func covertFind(ctx Dialect, find Find) (string, Values) { return covert_find(ctx, find, "") } -func covert_find(ctx Dialect, find Find, join string) string { +func covert_find(ctx Dialect, find Find, join string) (string, Values) { if join == "" { join = " AND " } keys := aggregateSortedKeys([]Map{find}) expressions := []string{} + values := Values{} for _, key := range keys { value := find[key] if strings.Contains(key, "$and") { - v := covert_find(ctx, value.(Find), "") - expressions = append(expressions, fmt.Sprintf("(%s)", v)) + exp, vals := covert_find(ctx, value.(Find), "") + values = append(values, vals...) + expressions = append(expressions, fmt.Sprintf("(%s)", exp)) continue } if strings.Contains(key, "$or") { - v := covert_find(ctx, value.(Find), " OR ") - expressions = append(expressions, fmt.Sprintf("(%s)", v)) + exp, vals := covert_find(ctx, value.(Find), " OR ") + values = append(values, vals...) + expressions = append(expressions, fmt.Sprintf("(%s)", exp)) continue } context := ctx.New(DialectOpts{ "FieldName": key, }) - values, _ := filters.Convert(context, value) - expressions = append(expressions, values) + expr, vals := filters.Convert(context, value) + values = append(values, vals...) + expressions = append(expressions, expr) } - return strings.Join(expressions, join) + return strings.Join(expressions, join), values } diff --git a/pkg/builder/convert_find_test.go b/pkg/builder/convert_find_test.go index 80db309..4a737ac 100644 --- a/pkg/builder/convert_find_test.go +++ b/pkg/builder/convert_find_test.go @@ -1,6 +1,7 @@ package builder import ( + "fmt" "testing" ) @@ -8,17 +9,20 @@ func TestConvertFind(t *testing.T) { find := Find{ "impl": "1", "exp": Is{ - "$gt": 1, + "$gt": 2, }, } ctx := SQLiteContext{ TableAlias: "t", } - result := covertFind(ctx, find) - if result == `t.exp > 1 AND t.impl = '1'` { - return + result, values := covertFind(ctx, find) + if values[1] != "1" { + t.Errorf("Expected '1', got %v", values[1]) } - if result == `t.impl = '1' AND t.exp > 1` { + if values[0].(float64) != 2 { + t.Errorf("Expected 2, got %v", values[0]) + } + if result == `t.exp > ? AND t.impl = ?` { return } t.Errorf(`Expected "t.impl = '1' AND t.exp = 1", got %s`, result) @@ -38,14 +42,12 @@ func TestConvertFindAnd(t *testing.T) { ctx := SQLiteContext{ TableAlias: "t", } - result := covertFind(ctx, find) - if result == `(t.a > 1 AND t.b < 10)` { + result, values := covertFind(ctx, find) + fmt.Println(values) + if result == `(t.a > ? AND t.b < ?)` { return } - if result == `(t.b < 10 AND t.a > 1)` { - return - } - t.Errorf(`Expected "(t.b < 10 AND t.a > 1)", got %s`, result) + t.Errorf(`Expected "(t.b < ? AND t.a > ?)", got %s`, result) } func TestConvertFindOr(t *testing.T) { @@ -62,12 +64,10 @@ func TestConvertFindOr(t *testing.T) { ctx := SQLiteContext{ TableAlias: "t", } - result := covertFind(ctx, find) - if result == `(t.a > 1 OR t.b < 10)` { + result, values := covertFind(ctx, find) + fmt.Println(values) + if result == `(t.a > ? OR t.b < ?)` { return } - if result == `(t.b < 10 OR t.a > 1)` { - return - } - t.Errorf(`Expected "(t.b < 10 OR t.a > 1)", got %s`, result) + t.Errorf(`Expected "(t.b < ? OR t.a > ?)", got %s`, result) } diff --git a/pkg/builder/convert_join.go b/pkg/builder/convert_join.go index 7bb4955..cd200c1 100644 --- a/pkg/builder/convert_join.go +++ b/pkg/builder/convert_join.go @@ -11,27 +11,30 @@ type Join struct { As string `json:"$as"` } -func (j Join) Convert(ctx Dialect) string { +func (j Join) Convert(ctx Dialect) (string, Values) { if j.For == "" { - return "" + return "", nil } - filter := covertFind(ctx, j.Do) + filter, values := covertFind(ctx, j.Do) var as string = "" if j.As != "" { as = fmt.Sprintf("%s ", j.As) } - return as + fmt.Sprintf("JOIN %s ON %s", j.For, filter) + return as + fmt.Sprintf("JOIN %s ON %s", j.For, filter), values } -func convertJoin(ctx Dialect, joins ...interface{}) []string { +func convertJoin(ctx Dialect, joins ...interface{}) ([]string, Values) { var result []string + var values Values for _, join := range joins { jstr, ok := join.(string) if ok { jjson := Join{} err := json.Unmarshal([]byte(jstr), &jjson) if err == nil { - result = append(result, jjson.Convert(ctx)) + r, vals := jjson.Convert(ctx) + result = append(result, r) + values = append(values, vals...) } continue } @@ -44,7 +47,9 @@ func convertJoin(ctx Dialect, joins ...interface{}) []string { } err = json.Unmarshal(jstr, &jjson) if err == nil { - result = append(result, jjson.Convert(ctx)) + r, vals := jjson.Convert(ctx) + result = append(result, r) + values = append(values, vals...) } continue } @@ -52,7 +57,9 @@ func convertJoin(ctx Dialect, joins ...interface{}) []string { if !ok { continue } - result = append(result, j.Convert(ctx)) + r, vals := j.Convert(ctx) + result = append(result, r) + values = append(values, vals...) } - return result + return result, values } diff --git a/pkg/builder/convert_join_test.go b/pkg/builder/convert_join_test.go index c32601c..11255a3 100644 --- a/pkg/builder/convert_join_test.go +++ b/pkg/builder/convert_join_test.go @@ -1,6 +1,7 @@ package builder import ( + "fmt" "testing" adapter "l12.xyz/dal/adapter" @@ -19,7 +20,8 @@ func TestJoin(t *testing.T) { ctx := SQLiteContext{ TableAlias: "t", } - result := j.Convert(ctx) + result, vals := j.Convert(ctx) + fmt.Println("Join:", vals) if result == `LEFT JOIN artist a ON a.impl = t.impl` { return } @@ -39,7 +41,8 @@ func TestConvertJoin(t *testing.T) { ctx := SQLiteContext{ TableAlias: "t", } - result := convertJoin(ctx, joins...) + result, vals := convertJoin(ctx, joins...) + fmt.Println("Join:", vals) if result[1] != `JOIN artist a ON a.impl = t.impl` { t.Errorf(`Expected "JOIN artist a ON a.impl = t.impl", got %s`, result[1]) } @@ -56,7 +59,8 @@ func TestConvertMap(t *testing.T) { ctx := SQLiteContext{ TableAlias: "t", } - result := convertJoin(ctx, joins...) + result, vals := convertJoin(ctx, joins...) + fmt.Println("Join:", vals) if result[0] != `LEFT JOIN artist a ON a.impl = t.impl` { t.Errorf(`Expected "LEFT JOIN artist a ON a.impl = t.impl", got %s`, result[0]) } diff --git a/pkg/filters/And.go b/pkg/filters/And.go index 95ef9c3..4da7646 100644 --- a/pkg/filters/And.go +++ b/pkg/filters/And.go @@ -9,12 +9,12 @@ type And struct { And []string `json:"$and"` } -func (f And) ToSQLPart(ctx Dialect) string { +func (f And) ToSQLPart(ctx Dialect) (string, Values) { if f.And == nil { - return "" + return "", nil } value := strings.Join(f.And, " AND ") - return fmt.Sprintf("(%s)", value) + return fmt.Sprintf("(%s)", value), nil } func (a And) FromJSON(data interface{}) IFilter { diff --git a/pkg/filters/Between.go b/pkg/filters/Between.go index cc0a6bb..a969066 100644 --- a/pkg/filters/Between.go +++ b/pkg/filters/Between.go @@ -14,12 +14,13 @@ func (f Between) FromJSON(data interface{}) IFilter { return FromJson[Between](data) } -func (f Between) ToSQLPart(ctx Dialect) string { +func (f Between) ToSQLPart(ctx Dialect) (string, Values) { if f.Between == nil { - return "" + return "", nil } name := ctx.GetFieldName() values := utils.Map(f.Between, ctx.NormalizeValue) - condition := fmt.Sprintf("%v AND %v", values[0], values[1]) - return fmt.Sprintf("%s BETWEEN %v", name, condition) + placeholders := utils.Map(values, ValueOrPlaceholder) + condition := fmt.Sprintf("%s AND %s", placeholders[0], placeholders[1]) + return fmt.Sprintf("%s BETWEEN %v", name, condition), values } diff --git a/pkg/filters/Eq.go b/pkg/filters/Eq.go index ff18f99..e9bc4d6 100644 --- a/pkg/filters/Eq.go +++ b/pkg/filters/Eq.go @@ -12,14 +12,14 @@ func (f Eq) FromJSON(data interface{}) IFilter { return FromJson[Eq](data) } -func (f Eq) ToSQLPart(ctx Dialect) string { +func (f Eq) ToSQLPart(ctx Dialect) (string, Values) { if f.Eq == nil { - return "" + return "", nil } name := ctx.GetFieldName() value := ctx.NormalizeValue(f.Eq) if value == "NULL" { - return fmt.Sprintf("%s IS NULL", name) + return fmt.Sprintf("%s IS NULL", ValueOrPlaceholder(name)), Values{name} } - return fmt.Sprintf("%s = %v", name, value) + return FmtCompare("=", name, value) } diff --git a/pkg/filters/Glob.go b/pkg/filters/Glob.go index e542d2a..7c97c8b 100644 --- a/pkg/filters/Glob.go +++ b/pkg/filters/Glob.go @@ -10,11 +10,11 @@ func (f Glob) FromJSON(data interface{}) IFilter { return FromJson[Glob](data) } -func (f Glob) ToSQLPart(ctx Dialect) string { +func (f Glob) ToSQLPart(ctx Dialect) (string, Values) { if f.Glob == nil { - return "" + return "", nil } name := ctx.GetFieldName() value := ctx.NormalizeValue(f.Glob) - return fmt.Sprintf("%s GLOB %v", name, value) + return fmt.Sprintf("%s GLOB ?", name), Values{value} } diff --git a/pkg/filters/Gt.go b/pkg/filters/Gt.go index 6b5eec8..0ea6253 100644 --- a/pkg/filters/Gt.go +++ b/pkg/filters/Gt.go @@ -1,9 +1,5 @@ package filters -import ( - "fmt" -) - type Gt struct { Gt interface{} `json:"$gt"` } @@ -12,11 +8,11 @@ func (f Gt) FromJSON(data interface{}) IFilter { return FromJson[Gt](data) } -func (f Gt) ToSQLPart(ctx Dialect) string { +func (f Gt) ToSQLPart(ctx Dialect) (string, Values) { if f.Gt == nil { - return "" + return "", nil } name := ctx.GetFieldName() value := ctx.NormalizeValue(f.Gt) - return fmt.Sprintf("%s > %v", name, value) + return FmtCompare(">", name, value) } diff --git a/pkg/filters/Gte.go b/pkg/filters/Gte.go index a6fceb5..afbcb53 100644 --- a/pkg/filters/Gte.go +++ b/pkg/filters/Gte.go @@ -1,9 +1,5 @@ package filters -import ( - "fmt" -) - type Gte struct { Gte interface{} `json:"$gte"` } @@ -12,11 +8,11 @@ func (f Gte) FromJSON(data interface{}) IFilter { return FromJson[Gte](data) } -func (f Gte) ToSQLPart(ctx Dialect) string { +func (f Gte) ToSQLPart(ctx Dialect) (string, Values) { if f.Gte == nil { - return "" + return "", nil } name := ctx.GetFieldName() value := ctx.NormalizeValue(f.Gte) - return fmt.Sprintf("%s >= %v", name, value) + return FmtCompare(">=", name, value) } diff --git a/pkg/filters/In.go b/pkg/filters/In.go index 7acbef4..1c4d5c5 100644 --- a/pkg/filters/In.go +++ b/pkg/filters/In.go @@ -15,17 +15,22 @@ func (f In) FromJSON(data interface{}) IFilter { return FromJson[In](data) } -func (f In) ToSQLPart(ctx Dialect) string { +func (f In) ToSQLPart(ctx Dialect) (string, Values) { if f.In == nil { - return "" + return "", nil } name := ctx.GetFieldName() values := utils.Map(f.In, ctx.NormalizeValue) + returnValues := make(Values, 0) data := make([]string, len(values)) - for i, v := range values { - data[i] = fmt.Sprintf("%v", v) + for i, value := range values { + val := ValueOrPlaceholder(value).(string) + data[i] = val + if val == "?" { + returnValues = append(returnValues, value) + } } value := strings.Join(data, ", ") - return fmt.Sprintf("%s IN (%v)", name, value) + return fmt.Sprintf("%s IN (%v)", name, value), returnValues } diff --git a/pkg/filters/Like.go b/pkg/filters/Like.go index 13bf787..b940b24 100644 --- a/pkg/filters/Like.go +++ b/pkg/filters/Like.go @@ -10,11 +10,11 @@ func (f Like) FromJSON(data interface{}) IFilter { return FromJson[Like](data) } -func (f Like) ToSQLPart(ctx Dialect) string { +func (f Like) ToSQLPart(ctx Dialect) (string, Values) { if f.Like == nil { - return "" + return "", nil } name := ctx.GetFieldName() value := ctx.NormalizeValue(f.Like) - return fmt.Sprintf("%s LIKE %v ESCAPE '\\'", name, value) + return fmt.Sprintf("%s LIKE ? ESCAPE '\\'", name), Values{value} } diff --git a/pkg/filters/Lt.go b/pkg/filters/Lt.go index cc79a8d..e63c547 100644 --- a/pkg/filters/Lt.go +++ b/pkg/filters/Lt.go @@ -1,9 +1,5 @@ package filters -import ( - "fmt" -) - type Lt struct { Lt interface{} `json:"$lt"` } @@ -12,11 +8,11 @@ func (f Lt) FromJSON(data interface{}) IFilter { return FromJson[Lt](data) } -func (f Lt) ToSQLPart(ctx Dialect) string { +func (f Lt) ToSQLPart(ctx Dialect) (string, Values) { if f.Lt == nil { - return "" + return "", nil } name := ctx.GetFieldName() value := ctx.NormalizeValue(f.Lt) - return fmt.Sprintf("%s < %v", name, value) + return FmtCompare("<", name, value) } diff --git a/pkg/filters/Lte.go b/pkg/filters/Lte.go index dd04a7a..893e2bf 100644 --- a/pkg/filters/Lte.go +++ b/pkg/filters/Lte.go @@ -1,9 +1,5 @@ package filters -import ( - "fmt" -) - type Lte struct { Lte interface{} `json:"$lte"` } @@ -12,11 +8,11 @@ func (f Lte) FromJSON(data interface{}) IFilter { return FromJson[Lte](data) } -func (f Lte) ToSQLPart(ctx Dialect) string { +func (f Lte) ToSQLPart(ctx Dialect) (string, Values) { if f.Lte == nil { - return "" + return "", nil } name := ctx.GetFieldName() value := ctx.NormalizeValue(f.Lte) - return fmt.Sprintf("%s <= %v", name, value) + return FmtCompare("<=", name, value) } diff --git a/pkg/filters/Ne.go b/pkg/filters/Ne.go index 067d612..34234b9 100644 --- a/pkg/filters/Ne.go +++ b/pkg/filters/Ne.go @@ -1,7 +1,5 @@ package filters -import "fmt" - type Ne struct { Ne interface{} `json:"$ne"` } @@ -10,14 +8,11 @@ func (f Ne) FromJSON(data interface{}) IFilter { return FromJson[Ne](data) } -func (f Ne) ToSQLPart(ctx Dialect) string { +func (f Ne) ToSQLPart(ctx Dialect) (string, Values) { if f.Ne == nil { - return "" + return "", nil } name := ctx.GetFieldName() value := ctx.NormalizeValue(f.Ne) - if value == "NULL" { - return fmt.Sprintf("%s IS NOT NULL", name) - } - return fmt.Sprintf("%s != %v", name, value) + return FmtCompare("!=", name, value) } diff --git a/pkg/filters/NotBetween.go b/pkg/filters/NotBetween.go index d2568df..85a42b9 100644 --- a/pkg/filters/NotBetween.go +++ b/pkg/filters/NotBetween.go @@ -14,12 +14,13 @@ func (f NotBetween) FromJSON(data interface{}) IFilter { return FromJson[NotBetween](data) } -func (f NotBetween) ToSQLPart(ctx Dialect) string { +func (f NotBetween) ToSQLPart(ctx Dialect) (string, Values) { if f.NotBetween == nil { - return "" + return "", nil } name := ctx.GetFieldName() values := utils.Map(f.NotBetween, ctx.NormalizeValue) - condition := fmt.Sprintf("%v AND %v", values[0], values[1]) - return fmt.Sprintf("%s NOT BETWEEN %v", name, condition) + placeholders := utils.Map(values, ValueOrPlaceholder) + condition := fmt.Sprintf("%s AND %s", placeholders[0], placeholders[1]) + return fmt.Sprintf("%s NOT BETWEEN %v", name, condition), values } diff --git a/pkg/filters/NotIn.go b/pkg/filters/NotIn.go index 358bbe7..41d22f6 100644 --- a/pkg/filters/NotIn.go +++ b/pkg/filters/NotIn.go @@ -15,17 +15,22 @@ func (f NotIn) FromJSON(data interface{}) IFilter { return FromJson[NotIn](data) } -func (f NotIn) ToSQLPart(ctx Dialect) string { +func (f NotIn) ToSQLPart(ctx Dialect) (string, Values) { if f.NotIn == nil { - return "" + return "", nil } name := ctx.GetFieldName() values := utils.Map(f.NotIn, ctx.NormalizeValue) + returnValues := make(Values, 0) data := make([]string, len(values)) - for i, v := range values { - data[i] = fmt.Sprintf("%v", v) + for i, value := range values { + val := ValueOrPlaceholder(value).(string) + data[i] = val + if val == "?" { + returnValues = append(returnValues, value) + } } value := strings.Join(data, ", ") - return fmt.Sprintf("%s NOT IN (%v)", name, value) + return fmt.Sprintf("%s NOT IN (%v)", name, value), returnValues } diff --git a/pkg/filters/NotLike.go b/pkg/filters/NotLike.go index 6cc2c59..e0fd6f2 100644 --- a/pkg/filters/NotLike.go +++ b/pkg/filters/NotLike.go @@ -10,11 +10,11 @@ func (f NotLike) FromJSON(data interface{}) IFilter { return FromJson[NotLike](data) } -func (f NotLike) ToSQLPart(ctx Dialect) string { +func (f NotLike) ToSQLPart(ctx Dialect) (string, Values) { if f.NotLike == nil { - return "" + return "", nil } name := ctx.GetFieldName() value := ctx.NormalizeValue(f.NotLike) - return fmt.Sprintf("%s NOT LIKE %v ESCAPE '\\'", name, value) + return fmt.Sprintf("%s NOT LIKE ? ESCAPE '\\'", name), Values{value} } diff --git a/pkg/filters/Or.go b/pkg/filters/Or.go index fed000a..d10aa3b 100644 --- a/pkg/filters/Or.go +++ b/pkg/filters/Or.go @@ -9,12 +9,12 @@ type Or struct { Or []string `json:"$or"` } -func (f Or) ToSQLPart(ctx Dialect) string { +func (f Or) ToSQLPart(ctx Dialect) (string, Values) { if f.Or == nil { - return "" + return "", nil } value := strings.Join(f.Or, " OR ") - return fmt.Sprintf("(%s)", value) + return fmt.Sprintf("(%s)", value), nil } func (a Or) FromJSON(data interface{}) IFilter { diff --git a/pkg/filters/registry.go b/pkg/filters/registry.go index 6e9d167..03c3644 100644 --- a/pkg/filters/registry.go +++ b/pkg/filters/registry.go @@ -22,17 +22,16 @@ var FilterRegistry = map[string]IFilter{ "NotLike": &NotLike{}, } -func Convert(ctx Dialect, data interface{}) (string, error) { +func Convert(ctx Dialect, data interface{}) (string, []interface{}) { for _, impl := range FilterRegistry { filter := impl.FromJSON(data) if reflect.DeepEqual(impl, filter) { continue } - value := filter.ToSQLPart(ctx) - if value != "" { - return value, nil + sfmt, values := filter.ToSQLPart(ctx) + if sfmt != "" { + return sfmt, values } } - value := Eq{Eq: data}.ToSQLPart(ctx) - return value, nil + return Eq{Eq: data}.ToSQLPart(ctx) } diff --git a/pkg/filters/types.go b/pkg/filters/types.go index 13f244d..ea3aa77 100644 --- a/pkg/filters/types.go +++ b/pkg/filters/types.go @@ -4,9 +4,9 @@ import "l12.xyz/dal/adapter" type DialectOpts = adapter.DialectOpts type Dialect = adapter.Dialect - +type Values = []interface{} type IFilter interface { - ToSQLPart(ctx Dialect) string + ToSQLPart(ctx Dialect) (string, Values) FromJSON(interface{}) IFilter } diff --git a/pkg/filters/unit_test.go b/pkg/filters/unit_test.go index b377921..d776ec8 100644 --- a/pkg/filters/unit_test.go +++ b/pkg/filters/unit_test.go @@ -1,6 +1,7 @@ package filters import ( + "fmt" "testing" adapter "l12.xyz/dal/adapter" @@ -29,10 +30,13 @@ func TestGte(t *testing.T) { TableAlias: "t", FieldName: "test", } - result, _ := Convert(ctx, `{"$gte": 1}`) + result, vals := Convert(ctx, `{"$gte": 1}`) resultMap, _ := Convert(ctx, Filter{"$gte": 1}) - if result != `t.test >= 1` { - t.Errorf("Expected t.test >= 1, got %s", result) + if vals[0].(float64) != 1 { + t.Errorf("Expected 1, got %v", vals[0]) + } + if result != `t.test >= ?` { + t.Errorf("Expected t.test >= ?, got %s", result) } if resultMap != result { t.Log(resultMap) @@ -46,8 +50,8 @@ func TestNe(t *testing.T) { } result, _ := Convert(ctx, `{"$ne": "1"}`) resultMap, _ := Convert(ctx, Filter{"$ne": "1"}) - if result != `test != '1'` { - t.Errorf("Expected test != '1', got %s", result) + if result != `test != ?` { + t.Errorf("Expected test != ?, got %s", result) } if resultMap != result { t.Log(resultMap) @@ -59,10 +63,11 @@ func TestBetween(t *testing.T) { ctx := SQLiteContext{ FieldName: "test", } - result, _ := Convert(ctx, `{"$between": ["1", "5"]}`) + result, vals := Convert(ctx, `{"$between": ["1", "5"]}`) + fmt.Println(vals) resultMap, _ := Convert(ctx, Filter{"$between": []string{"1", "5"}}) - if result != `test BETWEEN '1' AND '5'` { - t.Errorf("Expected test BETWEEN '1' AND '5', got %s", result) + if result != `test BETWEEN ? AND ?` { + t.Errorf("Expected test BETWEEN ? AND ?, got %s", result) } if resultMap != result { t.Log(resultMap) @@ -76,8 +81,8 @@ func TestNotBetween(t *testing.T) { } result, _ := Convert(ctx, `{"$nbetween": ["1", "5"]}`) resultMap, _ := Convert(ctx, Filter{"$nbetween": []string{"1", "5"}}) - if result != `test NOT BETWEEN '1' AND '5'` { - t.Errorf("Expected test BETWEEN '1' AND '5', got %s", result) + if result != `test NOT BETWEEN ? AND ?` { + t.Errorf("Expected test NOT BETWEEN ? AND ?, got %s", result) } if resultMap != result { t.Log(resultMap) @@ -90,9 +95,12 @@ func TestGlob(t *testing.T) { TableAlias: "t", FieldName: "test", } - result, _ := Convert(ctx, `{"$glob": "*son"}`) + result, vals := Convert(ctx, `{"$glob": "*son"}`) resultMap, _ := Convert(ctx, Filter{"$glob": "*son"}) - if result != `t.test GLOB '*son'` { + if vals[0].(string) != "*son" { + t.Errorf("Expected *son, got %v", vals[0]) + } + if result != `t.test GLOB ?` { t.Errorf("Expected t.test GLOB '*son', got %s", result) } if resultMap != result { @@ -108,8 +116,8 @@ func TestIn(t *testing.T) { } result, _ := Convert(ctx, `{"$in": [1, 2, 3]}`) resultMap, _ := Convert(ctx, Filter{"$in": []int{1, 2, 3}}) - if result != `t.test IN (1, 2, 3)` { - t.Errorf("Expected t.test IN (1, 2, 3), got %s", result) + if result != `t.test IN (?, ?, ?)` { + t.Errorf("Expected t.test IN (?, ?, ?), got %s", result) } if resultMap != result { t.Log(resultMap) @@ -122,10 +130,13 @@ func TestNotIn(t *testing.T) { TableAlias: "t", FieldName: "test", } - result, _ := Convert(ctx, `{"$nin": [1, 2, 3]}`) + result, vals := Convert(ctx, `{"$nin": [1, 2, 3]}`) resultMap, _ := Convert(ctx, Filter{"$nin": []int{1, 2, 3}}) - if result != `t.test NOT IN (1, 2, 3)` { - t.Errorf("Expected t.test NOT IN (1, 2, 3), got %s", result) + if vals[1].(float64) != 2 { + t.Errorf("Expected 1, got %v", vals[1]) + } + if result != `t.test NOT IN (?, ?, ?)` { + t.Errorf("Expected t.test NOT IN (?, ?, ?), got %s", result) } if resultMap != result { t.Log(resultMap) @@ -138,10 +149,13 @@ func TestLike(t *testing.T) { TableAlias: "t", FieldName: "test", } - result, _ := Convert(ctx, `{"$like": "199_"}`) + result, vals := Convert(ctx, `{"$like": "199_"}`) resultMap, _ := Convert(ctx, Filter{"$like": "199_"}) - if result != `t.test LIKE '199_' ESCAPE '\'` { - t.Errorf("Expected t.test LIKE '199_' ESCAPE '\\', got %s", result) + if vals[0].(string) != "199_" { + t.Errorf("Expected 199_, got %v", vals[0]) + } + if result != `t.test LIKE ? ESCAPE '\'` { + t.Errorf("Expected t.test LIKE ? ESCAPE '\\', got %s", result) } if resultMap != result { t.Log(resultMap) diff --git a/pkg/filters/utils.go b/pkg/filters/utils.go index a2e58a8..cdb6c93 100644 --- a/pkg/filters/utils.go +++ b/pkg/filters/utils.go @@ -2,6 +2,8 @@ package filters import ( "encoding/json" + "fmt" + "strings" ) func FromJson[T IFilter](data interface{}) *T { @@ -27,3 +29,24 @@ func FromJson[T IFilter](data interface{}) *T { } return &t } + +func ValueOrPlaceholder(value interface{}) interface{} { + if value == nil { + return "NULL" + } + val, ok := value.(string) + if !ok { + return "?" + } + if strings.Contains(val, ".") { + return value + } + return "?" +} + +func FmtCompare(operator string, a interface{}, b interface{}) (string, Values) { + if ValueOrPlaceholder(b) == "?" { + return fmt.Sprintf("%s %s ?", a, operator), Values{b} + } + return fmt.Sprintf("%s %s %s", a, operator, ValueOrPlaceholder(b)), nil +}