[ref] use sfmt arguments i/of values in find expressions

Signed-off-by: Anton Nesterov <anton@demiurg.io>
This commit is contained in:
Anton Nesterov 2024-08-12 20:21:34 +02:00
parent 0648807cb2
commit 42ab71e964
No known key found for this signature in database
GPG key ID: 59121E8AE2851FB5
27 changed files with 206 additions and 153 deletions

View file

@ -31,11 +31,12 @@ func TestBuilderBasic(t *testing.T) {
t.Fatalf("failed to insert data: %v", err) t.Fatalf("failed to insert data: %v", err)
} }
expr, _ := builder.New(adapter.SQLite{}).In("test t").Find(builder.Find{"name": builder.Is{"$in": []interface{}{"a", 'b'}}}).Sql() expr, values := builder.New(adapter.SQLite{}).In("test t").Find(builder.Find{"name": builder.Is{"$in": []interface{}{"a", 98}}}).Sql()
fmt.Println(expr) fmt.Println(expr)
rows, err := a.Query(adapter.Query{ rows, err := a.Query(adapter.Query{
Db: "file::memory:?cache=shared", Db: "file::memory:?cache=shared",
Expression: expr, Expression: expr,
Data: values,
}) })
if err != nil { if err != nil {
t.Fatalf("failed to query data: %v", err) t.Fatalf("failed to query data: %v", err)

View file

@ -58,7 +58,13 @@ func (c SQLite) GetColumnName(key string) string {
} }
func (c SQLite) NormalizeValue(value interface{}) interface{} { func (c SQLite) NormalizeValue(value interface{}) interface{} {
str, ok := value.(string) str, isStr := value.(string)
if !isStr {
return value
}
if str == "?" {
return str
}
if utils.IsSQLFunction(str) { if utils.IsSQLFunction(str) {
return str return str
} }
@ -68,12 +74,9 @@ func (c SQLite) NormalizeValue(value interface{}) interface{} {
return value return value
} }
} }
if !ok {
return value
}
val, err := utils.EscapeSQL(str) val, err := utils.EscapeSQL(str)
if err != nil { if err != nil {
return str return str
} }
return "'" + utils.EscapeSingleQuote(string(val)) + "'" return string(val)
} }

View file

@ -24,6 +24,7 @@ type SQLParts struct {
OrderExp string OrderExp string
LimitExp string LimitExp string
OffsetExp string OffsetExp string
Values []interface{}
Insert InsertData Insert InsertData
Update UpdateData Update UpdateData
} }
@ -49,7 +50,7 @@ func (b *Builder) In(table string) *Builder {
} }
func (b *Builder) Find(query Find) *Builder { func (b *Builder) Find(query Find) *Builder {
b.Parts.FiterExp = covertFind( b.Parts.FiterExp, b.Parts.Values = covertFind(
b.Dialect, b.Dialect,
query, query,
) )
@ -79,7 +80,9 @@ func (b *Builder) Fields(fields ...Map) *Builder {
} }
func (b *Builder) Join(joins ...interface{}) *Builder { func (b *Builder) Join(joins ...interface{}) *Builder {
b.Parts.JoinExps = convertJoin(b.Dialect, joins...) exps, vals := convertJoin(b.Dialect, joins...)
b.Parts.JoinExps = append(b.Parts.JoinExps, exps...)
b.Parts.Values = append(b.Parts.Values, vals...)
return b return b
} }
@ -178,7 +181,7 @@ func (b *Builder) Sql() (string, []interface{}) {
b.Parts.OrderExp, b.Parts.OrderExp,
b.Parts.LimitExp, b.Parts.LimitExp,
b.Parts.OffsetExp, b.Parts.OffsetExp,
}, " ")), nil }, " ")), b.Parts.Values
case operation == "DELETE": case operation == "DELETE":
return unspace(strings.Join([]string{ return unspace(strings.Join([]string{
b.Parts.Operation, b.Parts.Operation,
@ -189,7 +192,7 @@ func (b *Builder) Sql() (string, []interface{}) {
b.Parts.OrderExp, b.Parts.OrderExp,
b.Parts.LimitExp, b.Parts.LimitExp,
b.Parts.OffsetExp, b.Parts.OffsetExp,
}, " ")), nil }, " ")), b.Parts.Values
case operation == "INSERT INTO": case operation == "INSERT INTO":
return b.Parts.Insert.Statement, b.Parts.Insert.Values return b.Parts.Insert.Statement, b.Parts.Insert.Values
case operation == "UPDATE": case operation == "UPDATE":

View file

@ -1,6 +1,7 @@
package builder package builder
import ( import (
"fmt"
"testing" "testing"
) )
@ -10,7 +11,7 @@ func TestBuilderFind(t *testing.T) {
"field": "value", "field": "value",
"a": 1, "a": 1,
}) })
expect := "SELECT * FROM table t WHERE t.a = 1 AND t.field = 'value'" expect := "SELECT * FROM table t WHERE t.a = ? AND t.field = ?"
result, _ := db.Sql() result, _ := db.Sql()
if result != expect { if result != expect {
t.Errorf(`Expected: "%s", Got: %s`, expect, result) t.Errorf(`Expected: "%s", Got: %s`, expect, result)
@ -28,7 +29,7 @@ func TestBuilderFields(t *testing.T) {
"t.field": "field", "t.field": "field",
"t.a": 1, "t.a": 1,
}) })
expect := "SELECT t.a, t.field AS field FROM table t WHERE t.a = 1 AND t.field = 'value'" expect := "SELECT t.a, t.field AS field FROM table t WHERE t.a = ? AND t.field = ?"
result, _ := db.Sql() result, _ := db.Sql()
if result != expect { if result != expect {
t.Errorf(`Expected: "%s", Got: %s`, expect, result) t.Errorf(`Expected: "%s", Got: %s`, expect, result)
@ -47,8 +48,9 @@ func TestBuilderGroup(t *testing.T) {
"SUM(t.field)": "field", "SUM(t.field)": "field",
}) })
db.Group("field") db.Group("field")
expect := "SELECT SUM(t.field) AS field FROM table t GROUP BY t.field HAVING t.field > 1" expect := "SELECT SUM(t.field) AS field FROM table t GROUP BY t.field HAVING t.field > ?"
result, _ := db.Sql() result, _ := db.Sql()
fmt.Println(db.Parts.Values)
if result != expect { if result != expect {
t.Errorf(`Expected: "%s", Got: %s`, expect, result) t.Errorf(`Expected: "%s", Got: %s`, expect, result)
} }
@ -67,7 +69,7 @@ func TestBuilderJoin(t *testing.T) {
"t2.field": "t.field", "t2.field": "t.field",
}, },
}) })
expect := "SELECT * FROM table t JOIN table2 t2 ON t2.field = t.field WHERE t.a = 1 AND t.field = 'value'" expect := "SELECT * FROM table t JOIN table2 t2 ON t2.field = t.field WHERE t.a = ? AND t.field = ?"
result, _ := db.Sql() result, _ := db.Sql()
if result != expect { if result != expect {
t.Errorf(`Expected: "%s", Got: %s`, expect, result) t.Errorf(`Expected: "%s", Got: %s`, expect, result)

View file

@ -7,33 +7,39 @@ import (
filters "l12.xyz/dal/filters" filters "l12.xyz/dal/filters"
) )
func covertFind(ctx Dialect, find Find) string { type Values = []interface{}
func covertFind(ctx Dialect, find Find) (string, Values) {
return covert_find(ctx, find, "") return covert_find(ctx, find, "")
} }
func covert_find(ctx Dialect, find Find, join string) string { func covert_find(ctx Dialect, find Find, join string) (string, Values) {
if join == "" { if join == "" {
join = " AND " join = " AND "
} }
keys := aggregateSortedKeys([]Map{find}) keys := aggregateSortedKeys([]Map{find})
expressions := []string{} expressions := []string{}
values := Values{}
for _, key := range keys { for _, key := range keys {
value := find[key] value := find[key]
if strings.Contains(key, "$and") { if strings.Contains(key, "$and") {
v := covert_find(ctx, value.(Find), "") exp, vals := covert_find(ctx, value.(Find), "")
expressions = append(expressions, fmt.Sprintf("(%s)", v)) values = append(values, vals...)
expressions = append(expressions, fmt.Sprintf("(%s)", exp))
continue continue
} }
if strings.Contains(key, "$or") { if strings.Contains(key, "$or") {
v := covert_find(ctx, value.(Find), " OR ") exp, vals := covert_find(ctx, value.(Find), " OR ")
expressions = append(expressions, fmt.Sprintf("(%s)", v)) values = append(values, vals...)
expressions = append(expressions, fmt.Sprintf("(%s)", exp))
continue continue
} }
context := ctx.New(DialectOpts{ context := ctx.New(DialectOpts{
"FieldName": key, "FieldName": key,
}) })
values, _ := filters.Convert(context, value) expr, vals := filters.Convert(context, value)
expressions = append(expressions, values) values = append(values, vals...)
expressions = append(expressions, expr)
} }
return strings.Join(expressions, join) return strings.Join(expressions, join), values
} }

View file

@ -1,6 +1,7 @@
package builder package builder
import ( import (
"fmt"
"testing" "testing"
) )
@ -8,17 +9,20 @@ func TestConvertFind(t *testing.T) {
find := Find{ find := Find{
"impl": "1", "impl": "1",
"exp": Is{ "exp": Is{
"$gt": 1, "$gt": 2,
}, },
} }
ctx := SQLiteContext{ ctx := SQLiteContext{
TableAlias: "t", TableAlias: "t",
} }
result := covertFind(ctx, find) result, values := covertFind(ctx, find)
if result == `t.exp > 1 AND t.impl = '1'` { if values[1] != "1" {
return t.Errorf("Expected '1', got %v", values[1])
} }
if result == `t.impl = '1' AND t.exp > 1` { if values[0].(float64) != 2 {
t.Errorf("Expected 2, got %v", values[0])
}
if result == `t.exp > ? AND t.impl = ?` {
return return
} }
t.Errorf(`Expected "t.impl = '1' AND t.exp = 1", got %s`, result) t.Errorf(`Expected "t.impl = '1' AND t.exp = 1", got %s`, result)
@ -38,14 +42,12 @@ func TestConvertFindAnd(t *testing.T) {
ctx := SQLiteContext{ ctx := SQLiteContext{
TableAlias: "t", TableAlias: "t",
} }
result := covertFind(ctx, find) result, values := covertFind(ctx, find)
if result == `(t.a > 1 AND t.b < 10)` { fmt.Println(values)
if result == `(t.a > ? AND t.b < ?)` {
return return
} }
if result == `(t.b < 10 AND t.a > 1)` { t.Errorf(`Expected "(t.b < ? AND t.a > ?)", got %s`, result)
return
}
t.Errorf(`Expected "(t.b < 10 AND t.a > 1)", got %s`, result)
} }
func TestConvertFindOr(t *testing.T) { func TestConvertFindOr(t *testing.T) {
@ -62,12 +64,10 @@ func TestConvertFindOr(t *testing.T) {
ctx := SQLiteContext{ ctx := SQLiteContext{
TableAlias: "t", TableAlias: "t",
} }
result := covertFind(ctx, find) result, values := covertFind(ctx, find)
if result == `(t.a > 1 OR t.b < 10)` { fmt.Println(values)
if result == `(t.a > ? OR t.b < ?)` {
return return
} }
if result == `(t.b < 10 OR t.a > 1)` { t.Errorf(`Expected "(t.b < ? OR t.a > ?)", got %s`, result)
return
}
t.Errorf(`Expected "(t.b < 10 OR t.a > 1)", got %s`, result)
} }

View file

@ -11,27 +11,30 @@ type Join struct {
As string `json:"$as"` As string `json:"$as"`
} }
func (j Join) Convert(ctx Dialect) string { func (j Join) Convert(ctx Dialect) (string, Values) {
if j.For == "" { if j.For == "" {
return "" return "", nil
} }
filter := covertFind(ctx, j.Do) filter, values := covertFind(ctx, j.Do)
var as string = "" var as string = ""
if j.As != "" { if j.As != "" {
as = fmt.Sprintf("%s ", j.As) as = fmt.Sprintf("%s ", j.As)
} }
return as + fmt.Sprintf("JOIN %s ON %s", j.For, filter) return as + fmt.Sprintf("JOIN %s ON %s", j.For, filter), values
} }
func convertJoin(ctx Dialect, joins ...interface{}) []string { func convertJoin(ctx Dialect, joins ...interface{}) ([]string, Values) {
var result []string var result []string
var values Values
for _, join := range joins { for _, join := range joins {
jstr, ok := join.(string) jstr, ok := join.(string)
if ok { if ok {
jjson := Join{} jjson := Join{}
err := json.Unmarshal([]byte(jstr), &jjson) err := json.Unmarshal([]byte(jstr), &jjson)
if err == nil { if err == nil {
result = append(result, jjson.Convert(ctx)) r, vals := jjson.Convert(ctx)
result = append(result, r)
values = append(values, vals...)
} }
continue continue
} }
@ -44,7 +47,9 @@ func convertJoin(ctx Dialect, joins ...interface{}) []string {
} }
err = json.Unmarshal(jstr, &jjson) err = json.Unmarshal(jstr, &jjson)
if err == nil { if err == nil {
result = append(result, jjson.Convert(ctx)) r, vals := jjson.Convert(ctx)
result = append(result, r)
values = append(values, vals...)
} }
continue continue
} }
@ -52,7 +57,9 @@ func convertJoin(ctx Dialect, joins ...interface{}) []string {
if !ok { if !ok {
continue continue
} }
result = append(result, j.Convert(ctx)) r, vals := j.Convert(ctx)
result = append(result, r)
values = append(values, vals...)
} }
return result return result, values
} }

View file

@ -1,6 +1,7 @@
package builder package builder
import ( import (
"fmt"
"testing" "testing"
adapter "l12.xyz/dal/adapter" adapter "l12.xyz/dal/adapter"
@ -19,7 +20,8 @@ func TestJoin(t *testing.T) {
ctx := SQLiteContext{ ctx := SQLiteContext{
TableAlias: "t", TableAlias: "t",
} }
result := j.Convert(ctx) result, vals := j.Convert(ctx)
fmt.Println("Join:", vals)
if result == `LEFT JOIN artist a ON a.impl = t.impl` { if result == `LEFT JOIN artist a ON a.impl = t.impl` {
return return
} }
@ -39,7 +41,8 @@ func TestConvertJoin(t *testing.T) {
ctx := SQLiteContext{ ctx := SQLiteContext{
TableAlias: "t", TableAlias: "t",
} }
result := convertJoin(ctx, joins...) result, vals := convertJoin(ctx, joins...)
fmt.Println("Join:", vals)
if result[1] != `JOIN artist a ON a.impl = t.impl` { if result[1] != `JOIN artist a ON a.impl = t.impl` {
t.Errorf(`Expected "JOIN artist a ON a.impl = t.impl", got %s`, result[1]) t.Errorf(`Expected "JOIN artist a ON a.impl = t.impl", got %s`, result[1])
} }
@ -56,7 +59,8 @@ func TestConvertMap(t *testing.T) {
ctx := SQLiteContext{ ctx := SQLiteContext{
TableAlias: "t", TableAlias: "t",
} }
result := convertJoin(ctx, joins...) result, vals := convertJoin(ctx, joins...)
fmt.Println("Join:", vals)
if result[0] != `LEFT JOIN artist a ON a.impl = t.impl` { if result[0] != `LEFT JOIN artist a ON a.impl = t.impl` {
t.Errorf(`Expected "LEFT JOIN artist a ON a.impl = t.impl", got %s`, result[0]) t.Errorf(`Expected "LEFT JOIN artist a ON a.impl = t.impl", got %s`, result[0])
} }

View file

@ -9,12 +9,12 @@ type And struct {
And []string `json:"$and"` And []string `json:"$and"`
} }
func (f And) ToSQLPart(ctx Dialect) string { func (f And) ToSQLPart(ctx Dialect) (string, Values) {
if f.And == nil { if f.And == nil {
return "" return "", nil
} }
value := strings.Join(f.And, " AND ") value := strings.Join(f.And, " AND ")
return fmt.Sprintf("(%s)", value) return fmt.Sprintf("(%s)", value), nil
} }
func (a And) FromJSON(data interface{}) IFilter { func (a And) FromJSON(data interface{}) IFilter {

View file

@ -14,12 +14,13 @@ func (f Between) FromJSON(data interface{}) IFilter {
return FromJson[Between](data) return FromJson[Between](data)
} }
func (f Between) ToSQLPart(ctx Dialect) string { func (f Between) ToSQLPart(ctx Dialect) (string, Values) {
if f.Between == nil { if f.Between == nil {
return "" return "", nil
} }
name := ctx.GetFieldName() name := ctx.GetFieldName()
values := utils.Map(f.Between, ctx.NormalizeValue) values := utils.Map(f.Between, ctx.NormalizeValue)
condition := fmt.Sprintf("%v AND %v", values[0], values[1]) placeholders := utils.Map(values, ValueOrPlaceholder)
return fmt.Sprintf("%s BETWEEN %v", name, condition) condition := fmt.Sprintf("%s AND %s", placeholders[0], placeholders[1])
return fmt.Sprintf("%s BETWEEN %v", name, condition), values
} }

View file

@ -12,14 +12,14 @@ func (f Eq) FromJSON(data interface{}) IFilter {
return FromJson[Eq](data) return FromJson[Eq](data)
} }
func (f Eq) ToSQLPart(ctx Dialect) string { func (f Eq) ToSQLPart(ctx Dialect) (string, Values) {
if f.Eq == nil { if f.Eq == nil {
return "" return "", nil
} }
name := ctx.GetFieldName() name := ctx.GetFieldName()
value := ctx.NormalizeValue(f.Eq) value := ctx.NormalizeValue(f.Eq)
if value == "NULL" { if value == "NULL" {
return fmt.Sprintf("%s IS NULL", name) return fmt.Sprintf("%s IS NULL", ValueOrPlaceholder(name)), Values{name}
} }
return fmt.Sprintf("%s = %v", name, value) return FmtCompare("=", name, value)
} }

View file

@ -10,11 +10,11 @@ func (f Glob) FromJSON(data interface{}) IFilter {
return FromJson[Glob](data) return FromJson[Glob](data)
} }
func (f Glob) ToSQLPart(ctx Dialect) string { func (f Glob) ToSQLPart(ctx Dialect) (string, Values) {
if f.Glob == nil { if f.Glob == nil {
return "" return "", nil
} }
name := ctx.GetFieldName() name := ctx.GetFieldName()
value := ctx.NormalizeValue(f.Glob) value := ctx.NormalizeValue(f.Glob)
return fmt.Sprintf("%s GLOB %v", name, value) return fmt.Sprintf("%s GLOB ?", name), Values{value}
} }

View file

@ -1,9 +1,5 @@
package filters package filters
import (
"fmt"
)
type Gt struct { type Gt struct {
Gt interface{} `json:"$gt"` Gt interface{} `json:"$gt"`
} }
@ -12,11 +8,11 @@ func (f Gt) FromJSON(data interface{}) IFilter {
return FromJson[Gt](data) return FromJson[Gt](data)
} }
func (f Gt) ToSQLPart(ctx Dialect) string { func (f Gt) ToSQLPart(ctx Dialect) (string, Values) {
if f.Gt == nil { if f.Gt == nil {
return "" return "", nil
} }
name := ctx.GetFieldName() name := ctx.GetFieldName()
value := ctx.NormalizeValue(f.Gt) value := ctx.NormalizeValue(f.Gt)
return fmt.Sprintf("%s > %v", name, value) return FmtCompare(">", name, value)
} }

View file

@ -1,9 +1,5 @@
package filters package filters
import (
"fmt"
)
type Gte struct { type Gte struct {
Gte interface{} `json:"$gte"` Gte interface{} `json:"$gte"`
} }
@ -12,11 +8,11 @@ func (f Gte) FromJSON(data interface{}) IFilter {
return FromJson[Gte](data) return FromJson[Gte](data)
} }
func (f Gte) ToSQLPart(ctx Dialect) string { func (f Gte) ToSQLPart(ctx Dialect) (string, Values) {
if f.Gte == nil { if f.Gte == nil {
return "" return "", nil
} }
name := ctx.GetFieldName() name := ctx.GetFieldName()
value := ctx.NormalizeValue(f.Gte) value := ctx.NormalizeValue(f.Gte)
return fmt.Sprintf("%s >= %v", name, value) return FmtCompare(">=", name, value)
} }

View file

@ -15,17 +15,22 @@ func (f In) FromJSON(data interface{}) IFilter {
return FromJson[In](data) return FromJson[In](data)
} }
func (f In) ToSQLPart(ctx Dialect) string { func (f In) ToSQLPart(ctx Dialect) (string, Values) {
if f.In == nil { if f.In == nil {
return "" return "", nil
} }
name := ctx.GetFieldName() name := ctx.GetFieldName()
values := utils.Map(f.In, ctx.NormalizeValue) values := utils.Map(f.In, ctx.NormalizeValue)
returnValues := make(Values, 0)
data := make([]string, len(values)) data := make([]string, len(values))
for i, v := range values { for i, value := range values {
data[i] = fmt.Sprintf("%v", v) val := ValueOrPlaceholder(value).(string)
data[i] = val
if val == "?" {
returnValues = append(returnValues, value)
}
} }
value := strings.Join(data, ", ") value := strings.Join(data, ", ")
return fmt.Sprintf("%s IN (%v)", name, value) return fmt.Sprintf("%s IN (%v)", name, value), returnValues
} }

View file

@ -10,11 +10,11 @@ func (f Like) FromJSON(data interface{}) IFilter {
return FromJson[Like](data) return FromJson[Like](data)
} }
func (f Like) ToSQLPart(ctx Dialect) string { func (f Like) ToSQLPart(ctx Dialect) (string, Values) {
if f.Like == nil { if f.Like == nil {
return "" return "", nil
} }
name := ctx.GetFieldName() name := ctx.GetFieldName()
value := ctx.NormalizeValue(f.Like) value := ctx.NormalizeValue(f.Like)
return fmt.Sprintf("%s LIKE %v ESCAPE '\\'", name, value) return fmt.Sprintf("%s LIKE ? ESCAPE '\\'", name), Values{value}
} }

View file

@ -1,9 +1,5 @@
package filters package filters
import (
"fmt"
)
type Lt struct { type Lt struct {
Lt interface{} `json:"$lt"` Lt interface{} `json:"$lt"`
} }
@ -12,11 +8,11 @@ func (f Lt) FromJSON(data interface{}) IFilter {
return FromJson[Lt](data) return FromJson[Lt](data)
} }
func (f Lt) ToSQLPart(ctx Dialect) string { func (f Lt) ToSQLPart(ctx Dialect) (string, Values) {
if f.Lt == nil { if f.Lt == nil {
return "" return "", nil
} }
name := ctx.GetFieldName() name := ctx.GetFieldName()
value := ctx.NormalizeValue(f.Lt) value := ctx.NormalizeValue(f.Lt)
return fmt.Sprintf("%s < %v", name, value) return FmtCompare("<", name, value)
} }

View file

@ -1,9 +1,5 @@
package filters package filters
import (
"fmt"
)
type Lte struct { type Lte struct {
Lte interface{} `json:"$lte"` Lte interface{} `json:"$lte"`
} }
@ -12,11 +8,11 @@ func (f Lte) FromJSON(data interface{}) IFilter {
return FromJson[Lte](data) return FromJson[Lte](data)
} }
func (f Lte) ToSQLPart(ctx Dialect) string { func (f Lte) ToSQLPart(ctx Dialect) (string, Values) {
if f.Lte == nil { if f.Lte == nil {
return "" return "", nil
} }
name := ctx.GetFieldName() name := ctx.GetFieldName()
value := ctx.NormalizeValue(f.Lte) value := ctx.NormalizeValue(f.Lte)
return fmt.Sprintf("%s <= %v", name, value) return FmtCompare("<=", name, value)
} }

View file

@ -1,7 +1,5 @@
package filters package filters
import "fmt"
type Ne struct { type Ne struct {
Ne interface{} `json:"$ne"` Ne interface{} `json:"$ne"`
} }
@ -10,14 +8,11 @@ func (f Ne) FromJSON(data interface{}) IFilter {
return FromJson[Ne](data) return FromJson[Ne](data)
} }
func (f Ne) ToSQLPart(ctx Dialect) string { func (f Ne) ToSQLPart(ctx Dialect) (string, Values) {
if f.Ne == nil { if f.Ne == nil {
return "" return "", nil
} }
name := ctx.GetFieldName() name := ctx.GetFieldName()
value := ctx.NormalizeValue(f.Ne) value := ctx.NormalizeValue(f.Ne)
if value == "NULL" { return FmtCompare("!=", name, value)
return fmt.Sprintf("%s IS NOT NULL", name)
}
return fmt.Sprintf("%s != %v", name, value)
} }

View file

@ -14,12 +14,13 @@ func (f NotBetween) FromJSON(data interface{}) IFilter {
return FromJson[NotBetween](data) return FromJson[NotBetween](data)
} }
func (f NotBetween) ToSQLPart(ctx Dialect) string { func (f NotBetween) ToSQLPart(ctx Dialect) (string, Values) {
if f.NotBetween == nil { if f.NotBetween == nil {
return "" return "", nil
} }
name := ctx.GetFieldName() name := ctx.GetFieldName()
values := utils.Map(f.NotBetween, ctx.NormalizeValue) values := utils.Map(f.NotBetween, ctx.NormalizeValue)
condition := fmt.Sprintf("%v AND %v", values[0], values[1]) placeholders := utils.Map(values, ValueOrPlaceholder)
return fmt.Sprintf("%s NOT BETWEEN %v", name, condition) condition := fmt.Sprintf("%s AND %s", placeholders[0], placeholders[1])
return fmt.Sprintf("%s NOT BETWEEN %v", name, condition), values
} }

View file

@ -15,17 +15,22 @@ func (f NotIn) FromJSON(data interface{}) IFilter {
return FromJson[NotIn](data) return FromJson[NotIn](data)
} }
func (f NotIn) ToSQLPart(ctx Dialect) string { func (f NotIn) ToSQLPart(ctx Dialect) (string, Values) {
if f.NotIn == nil { if f.NotIn == nil {
return "" return "", nil
} }
name := ctx.GetFieldName() name := ctx.GetFieldName()
values := utils.Map(f.NotIn, ctx.NormalizeValue) values := utils.Map(f.NotIn, ctx.NormalizeValue)
returnValues := make(Values, 0)
data := make([]string, len(values)) data := make([]string, len(values))
for i, v := range values { for i, value := range values {
data[i] = fmt.Sprintf("%v", v) val := ValueOrPlaceholder(value).(string)
data[i] = val
if val == "?" {
returnValues = append(returnValues, value)
}
} }
value := strings.Join(data, ", ") value := strings.Join(data, ", ")
return fmt.Sprintf("%s NOT IN (%v)", name, value) return fmt.Sprintf("%s NOT IN (%v)", name, value), returnValues
} }

View file

@ -10,11 +10,11 @@ func (f NotLike) FromJSON(data interface{}) IFilter {
return FromJson[NotLike](data) return FromJson[NotLike](data)
} }
func (f NotLike) ToSQLPart(ctx Dialect) string { func (f NotLike) ToSQLPart(ctx Dialect) (string, Values) {
if f.NotLike == nil { if f.NotLike == nil {
return "" return "", nil
} }
name := ctx.GetFieldName() name := ctx.GetFieldName()
value := ctx.NormalizeValue(f.NotLike) value := ctx.NormalizeValue(f.NotLike)
return fmt.Sprintf("%s NOT LIKE %v ESCAPE '\\'", name, value) return fmt.Sprintf("%s NOT LIKE ? ESCAPE '\\'", name), Values{value}
} }

View file

@ -9,12 +9,12 @@ type Or struct {
Or []string `json:"$or"` Or []string `json:"$or"`
} }
func (f Or) ToSQLPart(ctx Dialect) string { func (f Or) ToSQLPart(ctx Dialect) (string, Values) {
if f.Or == nil { if f.Or == nil {
return "" return "", nil
} }
value := strings.Join(f.Or, " OR ") value := strings.Join(f.Or, " OR ")
return fmt.Sprintf("(%s)", value) return fmt.Sprintf("(%s)", value), nil
} }
func (a Or) FromJSON(data interface{}) IFilter { func (a Or) FromJSON(data interface{}) IFilter {

View file

@ -22,17 +22,16 @@ var FilterRegistry = map[string]IFilter{
"NotLike": &NotLike{}, "NotLike": &NotLike{},
} }
func Convert(ctx Dialect, data interface{}) (string, error) { func Convert(ctx Dialect, data interface{}) (string, []interface{}) {
for _, impl := range FilterRegistry { for _, impl := range FilterRegistry {
filter := impl.FromJSON(data) filter := impl.FromJSON(data)
if reflect.DeepEqual(impl, filter) { if reflect.DeepEqual(impl, filter) {
continue continue
} }
value := filter.ToSQLPart(ctx) sfmt, values := filter.ToSQLPart(ctx)
if value != "" { if sfmt != "" {
return value, nil return sfmt, values
} }
} }
value := Eq{Eq: data}.ToSQLPart(ctx) return Eq{Eq: data}.ToSQLPart(ctx)
return value, nil
} }

View file

@ -4,9 +4,9 @@ import "l12.xyz/dal/adapter"
type DialectOpts = adapter.DialectOpts type DialectOpts = adapter.DialectOpts
type Dialect = adapter.Dialect type Dialect = adapter.Dialect
type Values = []interface{}
type IFilter interface { type IFilter interface {
ToSQLPart(ctx Dialect) string ToSQLPart(ctx Dialect) (string, Values)
FromJSON(interface{}) IFilter FromJSON(interface{}) IFilter
} }

View file

@ -1,6 +1,7 @@
package filters package filters
import ( import (
"fmt"
"testing" "testing"
adapter "l12.xyz/dal/adapter" adapter "l12.xyz/dal/adapter"
@ -29,10 +30,13 @@ func TestGte(t *testing.T) {
TableAlias: "t", TableAlias: "t",
FieldName: "test", FieldName: "test",
} }
result, _ := Convert(ctx, `{"$gte": 1}`) result, vals := Convert(ctx, `{"$gte": 1}`)
resultMap, _ := Convert(ctx, Filter{"$gte": 1}) resultMap, _ := Convert(ctx, Filter{"$gte": 1})
if result != `t.test >= 1` { if vals[0].(float64) != 1 {
t.Errorf("Expected t.test >= 1, got %s", result) t.Errorf("Expected 1, got %v", vals[0])
}
if result != `t.test >= ?` {
t.Errorf("Expected t.test >= ?, got %s", result)
} }
if resultMap != result { if resultMap != result {
t.Log(resultMap) t.Log(resultMap)
@ -46,8 +50,8 @@ func TestNe(t *testing.T) {
} }
result, _ := Convert(ctx, `{"$ne": "1"}`) result, _ := Convert(ctx, `{"$ne": "1"}`)
resultMap, _ := Convert(ctx, Filter{"$ne": "1"}) resultMap, _ := Convert(ctx, Filter{"$ne": "1"})
if result != `test != '1'` { if result != `test != ?` {
t.Errorf("Expected test != '1', got %s", result) t.Errorf("Expected test != ?, got %s", result)
} }
if resultMap != result { if resultMap != result {
t.Log(resultMap) t.Log(resultMap)
@ -59,10 +63,11 @@ func TestBetween(t *testing.T) {
ctx := SQLiteContext{ ctx := SQLiteContext{
FieldName: "test", FieldName: "test",
} }
result, _ := Convert(ctx, `{"$between": ["1", "5"]}`) result, vals := Convert(ctx, `{"$between": ["1", "5"]}`)
fmt.Println(vals)
resultMap, _ := Convert(ctx, Filter{"$between": []string{"1", "5"}}) resultMap, _ := Convert(ctx, Filter{"$between": []string{"1", "5"}})
if result != `test BETWEEN '1' AND '5'` { if result != `test BETWEEN ? AND ?` {
t.Errorf("Expected test BETWEEN '1' AND '5', got %s", result) t.Errorf("Expected test BETWEEN ? AND ?, got %s", result)
} }
if resultMap != result { if resultMap != result {
t.Log(resultMap) t.Log(resultMap)
@ -76,8 +81,8 @@ func TestNotBetween(t *testing.T) {
} }
result, _ := Convert(ctx, `{"$nbetween": ["1", "5"]}`) result, _ := Convert(ctx, `{"$nbetween": ["1", "5"]}`)
resultMap, _ := Convert(ctx, Filter{"$nbetween": []string{"1", "5"}}) resultMap, _ := Convert(ctx, Filter{"$nbetween": []string{"1", "5"}})
if result != `test NOT BETWEEN '1' AND '5'` { if result != `test NOT BETWEEN ? AND ?` {
t.Errorf("Expected test BETWEEN '1' AND '5', got %s", result) t.Errorf("Expected test NOT BETWEEN ? AND ?, got %s", result)
} }
if resultMap != result { if resultMap != result {
t.Log(resultMap) t.Log(resultMap)
@ -90,9 +95,12 @@ func TestGlob(t *testing.T) {
TableAlias: "t", TableAlias: "t",
FieldName: "test", FieldName: "test",
} }
result, _ := Convert(ctx, `{"$glob": "*son"}`) result, vals := Convert(ctx, `{"$glob": "*son"}`)
resultMap, _ := Convert(ctx, Filter{"$glob": "*son"}) resultMap, _ := Convert(ctx, Filter{"$glob": "*son"})
if result != `t.test GLOB '*son'` { if vals[0].(string) != "*son" {
t.Errorf("Expected *son, got %v", vals[0])
}
if result != `t.test GLOB ?` {
t.Errorf("Expected t.test GLOB '*son', got %s", result) t.Errorf("Expected t.test GLOB '*son', got %s", result)
} }
if resultMap != result { if resultMap != result {
@ -108,8 +116,8 @@ func TestIn(t *testing.T) {
} }
result, _ := Convert(ctx, `{"$in": [1, 2, 3]}`) result, _ := Convert(ctx, `{"$in": [1, 2, 3]}`)
resultMap, _ := Convert(ctx, Filter{"$in": []int{1, 2, 3}}) resultMap, _ := Convert(ctx, Filter{"$in": []int{1, 2, 3}})
if result != `t.test IN (1, 2, 3)` { if result != `t.test IN (?, ?, ?)` {
t.Errorf("Expected t.test IN (1, 2, 3), got %s", result) t.Errorf("Expected t.test IN (?, ?, ?), got %s", result)
} }
if resultMap != result { if resultMap != result {
t.Log(resultMap) t.Log(resultMap)
@ -122,10 +130,13 @@ func TestNotIn(t *testing.T) {
TableAlias: "t", TableAlias: "t",
FieldName: "test", FieldName: "test",
} }
result, _ := Convert(ctx, `{"$nin": [1, 2, 3]}`) result, vals := Convert(ctx, `{"$nin": [1, 2, 3]}`)
resultMap, _ := Convert(ctx, Filter{"$nin": []int{1, 2, 3}}) resultMap, _ := Convert(ctx, Filter{"$nin": []int{1, 2, 3}})
if result != `t.test NOT IN (1, 2, 3)` { if vals[1].(float64) != 2 {
t.Errorf("Expected t.test NOT IN (1, 2, 3), got %s", result) t.Errorf("Expected 1, got %v", vals[1])
}
if result != `t.test NOT IN (?, ?, ?)` {
t.Errorf("Expected t.test NOT IN (?, ?, ?), got %s", result)
} }
if resultMap != result { if resultMap != result {
t.Log(resultMap) t.Log(resultMap)
@ -138,10 +149,13 @@ func TestLike(t *testing.T) {
TableAlias: "t", TableAlias: "t",
FieldName: "test", FieldName: "test",
} }
result, _ := Convert(ctx, `{"$like": "199_"}`) result, vals := Convert(ctx, `{"$like": "199_"}`)
resultMap, _ := Convert(ctx, Filter{"$like": "199_"}) resultMap, _ := Convert(ctx, Filter{"$like": "199_"})
if result != `t.test LIKE '199_' ESCAPE '\'` { if vals[0].(string) != "199_" {
t.Errorf("Expected t.test LIKE '199_' ESCAPE '\\', got %s", result) t.Errorf("Expected 199_, got %v", vals[0])
}
if result != `t.test LIKE ? ESCAPE '\'` {
t.Errorf("Expected t.test LIKE ? ESCAPE '\\', got %s", result)
} }
if resultMap != result { if resultMap != result {
t.Log(resultMap) t.Log(resultMap)

View file

@ -2,6 +2,8 @@ package filters
import ( import (
"encoding/json" "encoding/json"
"fmt"
"strings"
) )
func FromJson[T IFilter](data interface{}) *T { func FromJson[T IFilter](data interface{}) *T {
@ -27,3 +29,24 @@ func FromJson[T IFilter](data interface{}) *T {
} }
return &t return &t
} }
func ValueOrPlaceholder(value interface{}) interface{} {
if value == nil {
return "NULL"
}
val, ok := value.(string)
if !ok {
return "?"
}
if strings.Contains(val, ".") {
return value
}
return "?"
}
func FmtCompare(operator string, a interface{}, b interface{}) (string, Values) {
if ValueOrPlaceholder(b) == "?" {
return fmt.Sprintf("%s %s ?", a, operator), Values{b}
}
return fmt.Sprintf("%s %s %s", a, operator, ValueOrPlaceholder(b)), nil
}