commit d28d976b8eac589f2662cbe4f21ce69fd1a0c989 Author: Anton Nesterov Date: Wed Aug 7 21:16:40 2024 +0200 [wip] dal golang Signed-off-by: Anton Nesterov diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..1f74aa8 --- /dev/null +++ b/go.mod @@ -0,0 +1,3 @@ +module l12.xyz/dal + +go 1.22.6 diff --git a/pkg/dal/builder.go b/pkg/dal/builder.go new file mode 100644 index 0000000..273d475 --- /dev/null +++ b/pkg/dal/builder.go @@ -0,0 +1,37 @@ +package dal + +type SQLParts struct { + operation string + selectExp string + fromExp string + fiterExp string + joinExp []string + groupExp string + orderExp string + limitExp string + updateExp string + upsertExp string +} + +type Builder struct { + parts SQLParts +} + +func New() *Builder { + return &Builder{} +} + +func (b *Builder) In(selectExp string) *Builder { + b.parts.selectExp = selectExp + return b +} + +func (b *Builder) Find(fromExp string) *Builder { + b.parts.fromExp = fromExp + return b +} + +func (b *Builder) Join(fiterExp string) *Builder { + b.parts.fiterExp = fiterExp + return b +} diff --git a/pkg/dal/convert.go b/pkg/dal/convert.go new file mode 100644 index 0000000..69794f1 --- /dev/null +++ b/pkg/dal/convert.go @@ -0,0 +1,7 @@ +package dal + +type FindObject map[string]interface{} + +func CovertFind(findobj FindObject) string { + return "" +} diff --git a/pkg/filters/And.go b/pkg/filters/And.go new file mode 100644 index 0000000..cde9d11 --- /dev/null +++ b/pkg/filters/And.go @@ -0,0 +1,19 @@ +package filters + +import ( + "fmt" +) + +type And struct { + And []interface{} `json:"$and"` +} + +func (a And) ToSQLPart(ctx Context) string { + + fmt.Println(ctx, a) + return "" +} + +func (a And) FromJSON(data interface{}) Filter { + return FromJson[And](data) +} diff --git a/pkg/filters/Between.go b/pkg/filters/Between.go new file mode 100644 index 0000000..6b53be0 --- /dev/null +++ b/pkg/filters/Between.go @@ -0,0 +1,25 @@ +package filters + +import ( + "fmt" + + "l12.xyz/dal/utils" +) + +type Between struct { + Between []interface{} `json:"$between"` +} + +func (f Between) FromJSON(data interface{}) Filter { + return FromJson[Between](data) +} + +func (f Between) ToSQLPart(ctx Context) string { + if f.Between == nil { + return "" + } + name := ctx.GetFieldName() + values := utils.Map(f.Between, ctx.NormalizeValue) + condition := fmt.Sprintf("%v AND %v", values[0], values[1]) + return fmt.Sprintf("%s BETWEEN %v", name, condition) +} diff --git a/pkg/filters/Eq.go b/pkg/filters/Eq.go new file mode 100644 index 0000000..57436b4 --- /dev/null +++ b/pkg/filters/Eq.go @@ -0,0 +1,25 @@ +package filters + +import ( + "fmt" +) + +type Eq struct { + Eq interface{} `json:"$eq"` +} + +func (f Eq) FromJSON(data interface{}) Filter { + return FromJson[Eq](data) +} + +func (f Eq) ToSQLPart(ctx Context) string { + if f.Eq == nil { + return "" + } + name := ctx.GetFieldName() + value := ctx.NormalizeValue(f.Eq) + if value == "NULL" { + return fmt.Sprintf("%s IS NULL", name) + } + return fmt.Sprintf("%s = %v", name, value) +} diff --git a/pkg/filters/Glob.go b/pkg/filters/Glob.go new file mode 100644 index 0000000..fdca1da --- /dev/null +++ b/pkg/filters/Glob.go @@ -0,0 +1,20 @@ +package filters + +import "fmt" + +type Glob struct { + Glob interface{} `json:"$glob"` +} + +func (f Glob) FromJSON(data interface{}) Filter { + return FromJson[Glob](data) +} + +func (f Glob) ToSQLPart(ctx Context) string { + if f.Glob == nil { + return "" + } + name := ctx.GetFieldName() + value := ctx.NormalizeValue(f.Glob) + return fmt.Sprintf("%s GLOB %v", name, value) +} diff --git a/pkg/filters/Gt.go b/pkg/filters/Gt.go new file mode 100644 index 0000000..eb003d9 --- /dev/null +++ b/pkg/filters/Gt.go @@ -0,0 +1,22 @@ +package filters + +import ( + "fmt" +) + +type Gt struct { + Gt interface{} `json:"$gt"` +} + +func (f Gt) FromJSON(data interface{}) Filter { + return FromJson[Gt](data) +} + +func (f Gt) ToSQLPart(ctx Context) string { + if f.Gt == nil { + return "" + } + name := ctx.GetFieldName() + value := ctx.NormalizeValue(f.Gt) + return fmt.Sprintf("%s > %v", name, value) +} diff --git a/pkg/filters/Gte.go b/pkg/filters/Gte.go new file mode 100644 index 0000000..11a8ff2 --- /dev/null +++ b/pkg/filters/Gte.go @@ -0,0 +1,22 @@ +package filters + +import ( + "fmt" +) + +type Gte struct { + Gte interface{} `json:"$gte"` +} + +func (f Gte) FromJSON(data interface{}) Filter { + return FromJson[Gte](data) +} + +func (f Gte) ToSQLPart(ctx Context) string { + if f.Gte == nil { + return "" + } + name := ctx.GetFieldName() + value := ctx.NormalizeValue(f.Gte) + return fmt.Sprintf("%s >= %v", name, value) +} diff --git a/pkg/filters/In.go b/pkg/filters/In.go new file mode 100644 index 0000000..139ac08 --- /dev/null +++ b/pkg/filters/In.go @@ -0,0 +1,31 @@ +package filters + +import ( + "fmt" + "strings" + + "l12.xyz/dal/utils" +) + +type In struct { + In []interface{} `json:"$in"` +} + +func (f In) FromJSON(data interface{}) Filter { + return FromJson[In](data) +} + +func (f In) ToSQLPart(ctx Context) string { + if f.In == nil { + return "" + } + + name := ctx.GetFieldName() + values := utils.Map(f.In, ctx.NormalizeValue) + data := make([]string, len(values)) + for i, v := range values { + data[i] = fmt.Sprintf("%v", v) + } + value := strings.Join(data, ", ") + return fmt.Sprintf("%s IN (%v)", name, value) +} diff --git a/pkg/filters/Lt.go b/pkg/filters/Lt.go new file mode 100644 index 0000000..71b8503 --- /dev/null +++ b/pkg/filters/Lt.go @@ -0,0 +1,22 @@ +package filters + +import ( + "fmt" +) + +type Lt struct { + Lt interface{} `json:"$lt"` +} + +func (f Lt) FromJSON(data interface{}) Filter { + return FromJson[Lt](data) +} + +func (f Lt) ToSQLPart(ctx Context) string { + if f.Lt == nil { + return "" + } + name := ctx.GetFieldName() + value := ctx.NormalizeValue(f.Lt) + return fmt.Sprintf("%s < %v", name, value) +} diff --git a/pkg/filters/Lte.go b/pkg/filters/Lte.go new file mode 100644 index 0000000..62557a6 --- /dev/null +++ b/pkg/filters/Lte.go @@ -0,0 +1,22 @@ +package filters + +import ( + "fmt" +) + +type Lte struct { + Lte interface{} `json:"$lte"` +} + +func (f Lte) FromJSON(data interface{}) Filter { + return FromJson[Lte](data) +} + +func (f Lte) ToSQLPart(ctx Context) string { + if f.Lte == nil { + return "" + } + name := ctx.GetFieldName() + value := ctx.NormalizeValue(f.Lte) + return fmt.Sprintf("%s <= %v", name, value) +} diff --git a/pkg/filters/Ne.go b/pkg/filters/Ne.go new file mode 100644 index 0000000..b32a6c2 --- /dev/null +++ b/pkg/filters/Ne.go @@ -0,0 +1,23 @@ +package filters + +import "fmt" + +type Ne struct { + Ne interface{} `json:"$ne"` +} + +func (f Ne) FromJSON(data interface{}) Filter { + return FromJson[Ne](data) +} + +func (f Ne) ToSQLPart(ctx Context) string { + if f.Ne == nil { + return "" + } + name := ctx.GetFieldName() + value := ctx.NormalizeValue(f.Ne) + if value == "NULL" { + return fmt.Sprintf("%s IS NOT NULL", name) + } + return fmt.Sprintf("%s != %v", name, value) +} diff --git a/pkg/filters/NotBetween.go b/pkg/filters/NotBetween.go new file mode 100644 index 0000000..2a2eb7b --- /dev/null +++ b/pkg/filters/NotBetween.go @@ -0,0 +1,25 @@ +package filters + +import ( + "fmt" + + "l12.xyz/dal/utils" +) + +type NotBetween struct { + NotBetween []interface{} `json:"$nbetween"` +} + +func (f NotBetween) FromJSON(data interface{}) Filter { + return FromJson[NotBetween](data) +} + +func (f NotBetween) ToSQLPart(ctx Context) string { + if f.NotBetween == nil { + return "" + } + name := ctx.GetFieldName() + values := utils.Map(f.NotBetween, ctx.NormalizeValue) + condition := fmt.Sprintf("%v AND %v", values[0], values[1]) + return fmt.Sprintf("%s NOT BETWEEN %v", name, condition) +} diff --git a/pkg/filters/context.go b/pkg/filters/context.go new file mode 100644 index 0000000..cc1db54 --- /dev/null +++ b/pkg/filters/context.go @@ -0,0 +1,67 @@ +package filters + +import ( + "slices" + "strconv" + "strings" + "unicode" + + utils "l12.xyz/dal/utils" +) + +type SQLiteContext struct { + TableAlias string + FieldName string +} + +func (c SQLiteContext) GetFieldName() string { + if strings.Contains(c.FieldName, ".") { + return c.FieldName + } + if c.TableAlias != "" { + return c.TableAlias + "." + c.FieldName + } + return c.FieldName +} + +func (c SQLiteContext) NormalizeValue(value interface{}) interface{} { + str, ok := value.(string) + if isSQLFunction(str) { + return str + } + if strings.Contains(str, ".") { + _, err := strconv.ParseFloat(str, 64) + if err != nil { + return value + } + } + if !ok { + return value + } + val, err := utils.EscapeSQL(str) + if err != nil { + return str + } + return "'" + escapeSingleQuote(string(val)) + "'" +} + +func isSQLFunction(str string) bool { + stopChars := []string{" ", "_", "-", ".", "("} + isUpper := false + for _, char := range str { + if slices.Contains(stopChars, string(char)) { + break + } + if unicode.IsUpper(char) { + isUpper = true + } else { + isUpper = false + break + } + } + return isUpper +} + +func escapeSingleQuote(str string) string { + return strings.ReplaceAll(str, "'", "''") +} diff --git a/pkg/filters/go.mod b/pkg/filters/go.mod new file mode 100644 index 0000000..2aeb161 --- /dev/null +++ b/pkg/filters/go.mod @@ -0,0 +1,9 @@ +module l12.xyz/dal/filters + +go 1.22.6 + +require github.com/pkg/errors v0.9.1 // indirect + +require l12.xyz/dal/utils v0.0.0 + +replace l12.xyz/dal/utils v0.0.0 => ../utils diff --git a/pkg/filters/go.sum b/pkg/filters/go.sum new file mode 100644 index 0000000..7c401c3 --- /dev/null +++ b/pkg/filters/go.sum @@ -0,0 +1,2 @@ +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= diff --git a/pkg/filters/registry.go b/pkg/filters/registry.go new file mode 100644 index 0000000..b64981c --- /dev/null +++ b/pkg/filters/registry.go @@ -0,0 +1,24 @@ +package filters + +var FilterRegistry = map[string]Filter{ + "Eq": &Eq{}, + "Ne": &Ne{}, + "Gt": &Gt{}, + "Gte": &Gte{}, + "Lt": &Lt{}, + "Lte": &Lte{}, + "In": &In{}, + "Between": &Between{}, + "NotBetween": &NotBetween{}, + "Glob": &Glob{}, +} + +func Convert(ctx Context, json interface{}) string { + for _, t := range FilterRegistry { + value := t.FromJSON(json).ToSQLPart(ctx) + if value != "" { + return value + } + } + return "" +} diff --git a/pkg/filters/types.go b/pkg/filters/types.go new file mode 100644 index 0000000..c20fae2 --- /dev/null +++ b/pkg/filters/types.go @@ -0,0 +1,11 @@ +package filters + +type Context interface { + GetFieldName() string + NormalizeValue(interface{}) interface{} +} + +type Filter interface { + ToSQLPart(ctx Context) string + FromJSON(interface{}) Filter +} diff --git a/pkg/filters/unit_test.go b/pkg/filters/unit_test.go new file mode 100644 index 0000000..b8ed979 --- /dev/null +++ b/pkg/filters/unit_test.go @@ -0,0 +1,83 @@ +package filters + +import ( + "testing" +) + +func TestEq(t *testing.T) { + ctx := SQLiteContext{ + TableAlias: "t", + FieldName: "test", + } + result := Convert(ctx, `{"$eq": "NULL"}`) + resultMap := Convert(ctx, map[string]any{"$eq": "NULL"}) + if result != `t.test IS NULL` { + t.Errorf("Expected t.test IS NULL, got %s", result) + } + if resultMap != result { + t.Log(resultMap) + t.Errorf("Expected resultMap to be equal to result") + } +} + +func TestNe(t *testing.T) { + ctx := SQLiteContext{ + FieldName: "test", + } + result := Convert(ctx, `{"$ne": "1"}`) + resultMap := Convert(ctx, map[string]any{"$ne": "1"}) + if result != `test != '1'` { + t.Errorf("Expected test != '1', got %s", result) + } + if resultMap != result { + t.Log(resultMap) + t.Errorf("Expected resultMap to be equal to result") + } +} + +func TestBetween(t *testing.T) { + ctx := SQLiteContext{ + FieldName: "test", + } + result := Convert(ctx, `{"$between": ["1", "5"]}`) + resultMap := Convert(ctx, map[string]any{"$between": []string{"1", "5"}}) + if result != `test BETWEEN '1' AND '5'` { + t.Errorf("Expected test BETWEEN '1' AND '5', got %s", result) + } + if resultMap != result { + t.Log(resultMap) + t.Errorf("Expected resultMap to be equal to result") + } +} + +func TestGlob(t *testing.T) { + ctx := SQLiteContext{ + TableAlias: "t", + FieldName: "test", + } + result := Convert(ctx, `{"$glob": "*son"}`) + resultMap := Convert(ctx, map[string]any{"$glob": "*son"}) + if result != `t.test GLOB '*son'` { + t.Errorf("Expected t.test GLOB '*son', got %s", result) + } + if resultMap != result { + t.Log(resultMap) + t.Errorf("Expected resultMap to be equal to result") + } +} + +func TestIn(t *testing.T) { + ctx := SQLiteContext{ + TableAlias: "t", + FieldName: "test", + } + result := Convert(ctx, `{"$in": [1, 2, 3]}`) + resultMap := Convert(ctx, map[string]any{"$in": []int{1, 2, 3}}) + if result != `t.test IN (1, 2, 3)` { + t.Errorf("Expected t.test IN (1, 2, 3), got %s", result) + } + if resultMap != result { + t.Log(resultMap) + t.Errorf("Expected resultMap to be equal to result") + } +} diff --git a/pkg/filters/utils.go b/pkg/filters/utils.go new file mode 100644 index 0000000..add3835 --- /dev/null +++ b/pkg/filters/utils.go @@ -0,0 +1,29 @@ +package filters + +import ( + "encoding/json" +) + +func FromJson[T Filter](data interface{}) *T { + var t T + str, ok := data.(string) + if ok { + err := json.Unmarshal([]byte(str), &t) + if err != nil { + return nil + } + } + m, ok := data.(map[string]interface{}) + if ok { + s, err := json.Marshal(m) + if err != nil { + return nil + } + + e := json.Unmarshal(s, &t) + if e != nil { + return nil + } + } + return &t +} diff --git a/pkg/utils/common.go b/pkg/utils/common.go new file mode 100644 index 0000000..27ce748 --- /dev/null +++ b/pkg/utils/common.go @@ -0,0 +1,9 @@ +package utils + +func Map[T, U any](ts []T, f func(T) U) []U { + us := make([]U, len(ts)) + for i := range ts { + us[i] = f(ts[i]) + } + return us +} diff --git a/pkg/utils/go.mod b/pkg/utils/go.mod new file mode 100644 index 0000000..e31eed9 --- /dev/null +++ b/pkg/utils/go.mod @@ -0,0 +1,5 @@ +module l12.xyz/dal/utils + +go 1.22.6 + +require github.com/pkg/errors v0.9.1 diff --git a/pkg/utils/go.sum b/pkg/utils/go.sum new file mode 100644 index 0000000..7c401c3 --- /dev/null +++ b/pkg/utils/go.sum @@ -0,0 +1,2 @@ +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= diff --git a/pkg/utils/sql.go b/pkg/utils/sql.go new file mode 100644 index 0000000..41fdaaa --- /dev/null +++ b/pkg/utils/sql.go @@ -0,0 +1,236 @@ +package utils + +import ( + "encoding/json" + "reflect" + "strconv" + "strings" + "time" + "unsafe" + + "github.com/pkg/errors" +) + +func EscapeSQL(sql string, args ...interface{}) ([]byte, error) { + buf := make([]byte, 0, len(sql)) + argPos := 0 + for i := 0; i < len(sql); i++ { + q := strings.IndexByte(sql[i:], '%') + if q == -1 { + buf = append(buf, sql[i:]...) + break + } + buf = append(buf, sql[i:i+q]...) + i += q + + ch := byte(0) + if i+1 < len(sql) { + ch = sql[i+1] // get the specifier + } + switch ch { + case 'n': + if argPos >= len(args) { + return nil, errors.Errorf("missing arguments, need %d-th arg, but only got %d args", argPos+1, len(args)) + } + arg := args[argPos] + argPos++ + + v, ok := arg.(string) + if !ok { + return nil, errors.Errorf("expect a string identifier, got %v", arg) + } + buf = append(buf, '`') + buf = append(buf, strings.ReplaceAll(v, "`", "``")...) + buf = append(buf, '`') + i++ // skip specifier + case '?': + if argPos >= len(args) { + return nil, errors.Errorf("missing arguments, need %d-th arg, but only got %d args", argPos+1, len(args)) + } + arg := args[argPos] + argPos++ + + if arg == nil { + buf = append(buf, "NULL"...) + } else { + switch v := arg.(type) { + case int: + buf = strconv.AppendInt(buf, int64(v), 10) + case int8: + buf = strconv.AppendInt(buf, int64(v), 10) + case int16: + buf = strconv.AppendInt(buf, int64(v), 10) + case int32: + buf = strconv.AppendInt(buf, int64(v), 10) + case int64: + buf = strconv.AppendInt(buf, v, 10) + case uint: + buf = strconv.AppendUint(buf, uint64(v), 10) + case uint8: + buf = strconv.AppendUint(buf, uint64(v), 10) + case uint16: + buf = strconv.AppendUint(buf, uint64(v), 10) + case uint32: + buf = strconv.AppendUint(buf, uint64(v), 10) + case uint64: + buf = strconv.AppendUint(buf, v, 10) + case float32: + buf = strconv.AppendFloat(buf, float64(v), 'g', -1, 32) + case float64: + buf = strconv.AppendFloat(buf, v, 'g', -1, 64) + case bool: + buf = appendSQLArgBool(buf, v) + case time.Time: + if v.IsZero() { + buf = append(buf, "'0000-00-00'"...) + } else { + buf = append(buf, '\'') + buf = v.AppendFormat(buf, "2006-01-02 15:04:05.999999") + buf = append(buf, '\'') + } + case json.RawMessage: + buf = append(buf, '\'') + buf = escapeBytesBackslash(buf, v) + buf = append(buf, '\'') + case []byte: + if v == nil { + buf = append(buf, "NULL"...) + } else { + buf = append(buf, "_binary'"...) + buf = escapeBytesBackslash(buf, v) + buf = append(buf, '\'') + } + case string: + buf = appendSQLArgString(buf, v) + case []string: + for i, k := range v { + if i > 0 { + buf = append(buf, ',') + } + buf = append(buf, '\'') + buf = escapeStringBackslash(buf, k) + buf = append(buf, '\'') + } + case []float32: + for i, k := range v { + if i > 0 { + buf = append(buf, ',') + } + buf = strconv.AppendFloat(buf, float64(k), 'g', -1, 32) + } + case []float64: + for i, k := range v { + if i > 0 { + buf = append(buf, ',') + } + buf = strconv.AppendFloat(buf, k, 'g', -1, 64) + } + default: + // slow path based on reflection + reflectTp := reflect.TypeOf(arg) + kind := reflectTp.Kind() + switch kind { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + buf = strconv.AppendInt(buf, reflect.ValueOf(arg).Int(), 10) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + buf = strconv.AppendUint(buf, reflect.ValueOf(arg).Uint(), 10) + case reflect.Float32: + buf = strconv.AppendFloat(buf, reflect.ValueOf(arg).Float(), 'g', -1, 32) + case reflect.Float64: + buf = strconv.AppendFloat(buf, reflect.ValueOf(arg).Float(), 'g', -1, 64) + case reflect.Bool: + buf = appendSQLArgBool(buf, reflect.ValueOf(arg).Bool()) + case reflect.String: + buf = appendSQLArgString(buf, reflect.ValueOf(arg).String()) + default: + return nil, errors.Errorf("unsupported %d-th argument: %v", argPos, arg) + } + } + } + i++ // skip specifier + case '%': + buf = append(buf, '%') + i++ // skip specifier + default: + buf = append(buf, '%') + } + } + return buf, nil +} + +func EscapeString(s string) string { + buf := make([]byte, 0, len(s)) + return string(escapeStringBackslash(buf, s)) +} + +func appendSQLArgBool(buf []byte, v bool) []byte { + if v { + return append(buf, '1') + } + return append(buf, '0') +} + +func appendSQLArgString(buf []byte, s string) []byte { + buf = append(buf, '\'') + buf = escapeStringBackslash(buf, s) + buf = append(buf, '\'') + return buf +} + +func escapeStringBackslash(buf []byte, v string) []byte { + return escapeBytesBackslash(buf, unsafe.Slice(unsafe.StringData(v), len(v))) +} + +// escapeBytesBackslash will escape []byte into the buffer, with backslash. +func escapeBytesBackslash(buf []byte, v []byte) []byte { + pos := len(buf) + buf = reserveBuffer(buf, len(v)*2) + + for _, c := range v { + switch c { + case '\x00': + buf[pos] = '\\' + buf[pos+1] = '0' + pos += 2 + case '\n': + buf[pos] = '\\' + buf[pos+1] = 'n' + pos += 2 + case '\r': + buf[pos] = '\\' + buf[pos+1] = 'r' + pos += 2 + case '\x1a': + buf[pos] = '\\' + buf[pos+1] = 'Z' + pos += 2 + case '\'': + buf[pos] = '\\' + buf[pos+1] = '\'' + pos += 2 + case '"': + buf[pos] = '\\' + buf[pos+1] = '"' + pos += 2 + case '\\': + buf[pos] = '\\' + buf[pos+1] = '\\' + pos += 2 + default: + buf[pos] = c + pos++ + } + } + + return buf[:pos] +} + +func reserveBuffer(buf []byte, appendSize int) []byte { + newSize := len(buf) + appendSize + if cap(buf) < newSize { + newBuf := make([]byte, len(buf)*2+appendSize) + copy(newBuf, buf) + buf = newBuf + } + return buf[:newSize] +}