[wip] dal golang

Signed-off-by: Anton Nesterov <anton@demiurg.io>
This commit is contained in:
Anton Nesterov 2024-08-07 21:16:40 +02:00
commit d28d976b8e
No known key found for this signature in database
GPG key ID: 59121E8AE2851FB5
25 changed files with 780 additions and 0 deletions

3
go.mod Normal file
View file

@ -0,0 +1,3 @@
module l12.xyz/dal
go 1.22.6

37
pkg/dal/builder.go Normal file
View file

@ -0,0 +1,37 @@
package dal
type SQLParts struct {
operation string
selectExp string
fromExp string
fiterExp string
joinExp []string
groupExp string
orderExp string
limitExp string
updateExp string
upsertExp string
}
type Builder struct {
parts SQLParts
}
func New() *Builder {
return &Builder{}
}
func (b *Builder) In(selectExp string) *Builder {
b.parts.selectExp = selectExp
return b
}
func (b *Builder) Find(fromExp string) *Builder {
b.parts.fromExp = fromExp
return b
}
func (b *Builder) Join(fiterExp string) *Builder {
b.parts.fiterExp = fiterExp
return b
}

7
pkg/dal/convert.go Normal file
View file

@ -0,0 +1,7 @@
package dal
type FindObject map[string]interface{}
func CovertFind(findobj FindObject) string {
return ""
}

19
pkg/filters/And.go Normal file
View file

@ -0,0 +1,19 @@
package filters
import (
"fmt"
)
type And struct {
And []interface{} `json:"$and"`
}
func (a And) ToSQLPart(ctx Context) string {
fmt.Println(ctx, a)
return ""
}
func (a And) FromJSON(data interface{}) Filter {
return FromJson[And](data)
}

25
pkg/filters/Between.go Normal file
View file

@ -0,0 +1,25 @@
package filters
import (
"fmt"
"l12.xyz/dal/utils"
)
type Between struct {
Between []interface{} `json:"$between"`
}
func (f Between) FromJSON(data interface{}) Filter {
return FromJson[Between](data)
}
func (f Between) ToSQLPart(ctx Context) string {
if f.Between == nil {
return ""
}
name := ctx.GetFieldName()
values := utils.Map(f.Between, ctx.NormalizeValue)
condition := fmt.Sprintf("%v AND %v", values[0], values[1])
return fmt.Sprintf("%s BETWEEN %v", name, condition)
}

25
pkg/filters/Eq.go Normal file
View file

@ -0,0 +1,25 @@
package filters
import (
"fmt"
)
type Eq struct {
Eq interface{} `json:"$eq"`
}
func (f Eq) FromJSON(data interface{}) Filter {
return FromJson[Eq](data)
}
func (f Eq) ToSQLPart(ctx Context) string {
if f.Eq == nil {
return ""
}
name := ctx.GetFieldName()
value := ctx.NormalizeValue(f.Eq)
if value == "NULL" {
return fmt.Sprintf("%s IS NULL", name)
}
return fmt.Sprintf("%s = %v", name, value)
}

20
pkg/filters/Glob.go Normal file
View file

@ -0,0 +1,20 @@
package filters
import "fmt"
type Glob struct {
Glob interface{} `json:"$glob"`
}
func (f Glob) FromJSON(data interface{}) Filter {
return FromJson[Glob](data)
}
func (f Glob) ToSQLPart(ctx Context) string {
if f.Glob == nil {
return ""
}
name := ctx.GetFieldName()
value := ctx.NormalizeValue(f.Glob)
return fmt.Sprintf("%s GLOB %v", name, value)
}

22
pkg/filters/Gt.go Normal file
View file

@ -0,0 +1,22 @@
package filters
import (
"fmt"
)
type Gt struct {
Gt interface{} `json:"$gt"`
}
func (f Gt) FromJSON(data interface{}) Filter {
return FromJson[Gt](data)
}
func (f Gt) ToSQLPart(ctx Context) string {
if f.Gt == nil {
return ""
}
name := ctx.GetFieldName()
value := ctx.NormalizeValue(f.Gt)
return fmt.Sprintf("%s > %v", name, value)
}

22
pkg/filters/Gte.go Normal file
View file

@ -0,0 +1,22 @@
package filters
import (
"fmt"
)
type Gte struct {
Gte interface{} `json:"$gte"`
}
func (f Gte) FromJSON(data interface{}) Filter {
return FromJson[Gte](data)
}
func (f Gte) ToSQLPart(ctx Context) string {
if f.Gte == nil {
return ""
}
name := ctx.GetFieldName()
value := ctx.NormalizeValue(f.Gte)
return fmt.Sprintf("%s >= %v", name, value)
}

31
pkg/filters/In.go Normal file
View file

@ -0,0 +1,31 @@
package filters
import (
"fmt"
"strings"
"l12.xyz/dal/utils"
)
type In struct {
In []interface{} `json:"$in"`
}
func (f In) FromJSON(data interface{}) Filter {
return FromJson[In](data)
}
func (f In) ToSQLPart(ctx Context) string {
if f.In == nil {
return ""
}
name := ctx.GetFieldName()
values := utils.Map(f.In, ctx.NormalizeValue)
data := make([]string, len(values))
for i, v := range values {
data[i] = fmt.Sprintf("%v", v)
}
value := strings.Join(data, ", ")
return fmt.Sprintf("%s IN (%v)", name, value)
}

22
pkg/filters/Lt.go Normal file
View file

@ -0,0 +1,22 @@
package filters
import (
"fmt"
)
type Lt struct {
Lt interface{} `json:"$lt"`
}
func (f Lt) FromJSON(data interface{}) Filter {
return FromJson[Lt](data)
}
func (f Lt) ToSQLPart(ctx Context) string {
if f.Lt == nil {
return ""
}
name := ctx.GetFieldName()
value := ctx.NormalizeValue(f.Lt)
return fmt.Sprintf("%s < %v", name, value)
}

22
pkg/filters/Lte.go Normal file
View file

@ -0,0 +1,22 @@
package filters
import (
"fmt"
)
type Lte struct {
Lte interface{} `json:"$lte"`
}
func (f Lte) FromJSON(data interface{}) Filter {
return FromJson[Lte](data)
}
func (f Lte) ToSQLPart(ctx Context) string {
if f.Lte == nil {
return ""
}
name := ctx.GetFieldName()
value := ctx.NormalizeValue(f.Lte)
return fmt.Sprintf("%s <= %v", name, value)
}

23
pkg/filters/Ne.go Normal file
View file

@ -0,0 +1,23 @@
package filters
import "fmt"
type Ne struct {
Ne interface{} `json:"$ne"`
}
func (f Ne) FromJSON(data interface{}) Filter {
return FromJson[Ne](data)
}
func (f Ne) ToSQLPart(ctx Context) string {
if f.Ne == nil {
return ""
}
name := ctx.GetFieldName()
value := ctx.NormalizeValue(f.Ne)
if value == "NULL" {
return fmt.Sprintf("%s IS NOT NULL", name)
}
return fmt.Sprintf("%s != %v", name, value)
}

25
pkg/filters/NotBetween.go Normal file
View file

@ -0,0 +1,25 @@
package filters
import (
"fmt"
"l12.xyz/dal/utils"
)
type NotBetween struct {
NotBetween []interface{} `json:"$nbetween"`
}
func (f NotBetween) FromJSON(data interface{}) Filter {
return FromJson[NotBetween](data)
}
func (f NotBetween) ToSQLPart(ctx Context) string {
if f.NotBetween == nil {
return ""
}
name := ctx.GetFieldName()
values := utils.Map(f.NotBetween, ctx.NormalizeValue)
condition := fmt.Sprintf("%v AND %v", values[0], values[1])
return fmt.Sprintf("%s NOT BETWEEN %v", name, condition)
}

67
pkg/filters/context.go Normal file
View file

@ -0,0 +1,67 @@
package filters
import (
"slices"
"strconv"
"strings"
"unicode"
utils "l12.xyz/dal/utils"
)
type SQLiteContext struct {
TableAlias string
FieldName string
}
func (c SQLiteContext) GetFieldName() string {
if strings.Contains(c.FieldName, ".") {
return c.FieldName
}
if c.TableAlias != "" {
return c.TableAlias + "." + c.FieldName
}
return c.FieldName
}
func (c SQLiteContext) NormalizeValue(value interface{}) interface{} {
str, ok := value.(string)
if isSQLFunction(str) {
return str
}
if strings.Contains(str, ".") {
_, err := strconv.ParseFloat(str, 64)
if err != nil {
return value
}
}
if !ok {
return value
}
val, err := utils.EscapeSQL(str)
if err != nil {
return str
}
return "'" + escapeSingleQuote(string(val)) + "'"
}
func isSQLFunction(str string) bool {
stopChars := []string{" ", "_", "-", ".", "("}
isUpper := false
for _, char := range str {
if slices.Contains(stopChars, string(char)) {
break
}
if unicode.IsUpper(char) {
isUpper = true
} else {
isUpper = false
break
}
}
return isUpper
}
func escapeSingleQuote(str string) string {
return strings.ReplaceAll(str, "'", "''")
}

9
pkg/filters/go.mod Normal file
View file

@ -0,0 +1,9 @@
module l12.xyz/dal/filters
go 1.22.6
require github.com/pkg/errors v0.9.1 // indirect
require l12.xyz/dal/utils v0.0.0
replace l12.xyz/dal/utils v0.0.0 => ../utils

2
pkg/filters/go.sum Normal file
View file

@ -0,0 +1,2 @@
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=

24
pkg/filters/registry.go Normal file
View file

@ -0,0 +1,24 @@
package filters
var FilterRegistry = map[string]Filter{
"Eq": &Eq{},
"Ne": &Ne{},
"Gt": &Gt{},
"Gte": &Gte{},
"Lt": &Lt{},
"Lte": &Lte{},
"In": &In{},
"Between": &Between{},
"NotBetween": &NotBetween{},
"Glob": &Glob{},
}
func Convert(ctx Context, json interface{}) string {
for _, t := range FilterRegistry {
value := t.FromJSON(json).ToSQLPart(ctx)
if value != "" {
return value
}
}
return ""
}

11
pkg/filters/types.go Normal file
View file

@ -0,0 +1,11 @@
package filters
type Context interface {
GetFieldName() string
NormalizeValue(interface{}) interface{}
}
type Filter interface {
ToSQLPart(ctx Context) string
FromJSON(interface{}) Filter
}

83
pkg/filters/unit_test.go Normal file
View file

@ -0,0 +1,83 @@
package filters
import (
"testing"
)
func TestEq(t *testing.T) {
ctx := SQLiteContext{
TableAlias: "t",
FieldName: "test",
}
result := Convert(ctx, `{"$eq": "NULL"}`)
resultMap := Convert(ctx, map[string]any{"$eq": "NULL"})
if result != `t.test IS NULL` {
t.Errorf("Expected t.test IS NULL, got %s", result)
}
if resultMap != result {
t.Log(resultMap)
t.Errorf("Expected resultMap to be equal to result")
}
}
func TestNe(t *testing.T) {
ctx := SQLiteContext{
FieldName: "test",
}
result := Convert(ctx, `{"$ne": "1"}`)
resultMap := Convert(ctx, map[string]any{"$ne": "1"})
if result != `test != '1'` {
t.Errorf("Expected test != '1', got %s", result)
}
if resultMap != result {
t.Log(resultMap)
t.Errorf("Expected resultMap to be equal to result")
}
}
func TestBetween(t *testing.T) {
ctx := SQLiteContext{
FieldName: "test",
}
result := Convert(ctx, `{"$between": ["1", "5"]}`)
resultMap := Convert(ctx, map[string]any{"$between": []string{"1", "5"}})
if result != `test BETWEEN '1' AND '5'` {
t.Errorf("Expected test BETWEEN '1' AND '5', got %s", result)
}
if resultMap != result {
t.Log(resultMap)
t.Errorf("Expected resultMap to be equal to result")
}
}
func TestGlob(t *testing.T) {
ctx := SQLiteContext{
TableAlias: "t",
FieldName: "test",
}
result := Convert(ctx, `{"$glob": "*son"}`)
resultMap := Convert(ctx, map[string]any{"$glob": "*son"})
if result != `t.test GLOB '*son'` {
t.Errorf("Expected t.test GLOB '*son', got %s", result)
}
if resultMap != result {
t.Log(resultMap)
t.Errorf("Expected resultMap to be equal to result")
}
}
func TestIn(t *testing.T) {
ctx := SQLiteContext{
TableAlias: "t",
FieldName: "test",
}
result := Convert(ctx, `{"$in": [1, 2, 3]}`)
resultMap := Convert(ctx, map[string]any{"$in": []int{1, 2, 3}})
if result != `t.test IN (1, 2, 3)` {
t.Errorf("Expected t.test IN (1, 2, 3), got %s", result)
}
if resultMap != result {
t.Log(resultMap)
t.Errorf("Expected resultMap to be equal to result")
}
}

29
pkg/filters/utils.go Normal file
View file

@ -0,0 +1,29 @@
package filters
import (
"encoding/json"
)
func FromJson[T Filter](data interface{}) *T {
var t T
str, ok := data.(string)
if ok {
err := json.Unmarshal([]byte(str), &t)
if err != nil {
return nil
}
}
m, ok := data.(map[string]interface{})
if ok {
s, err := json.Marshal(m)
if err != nil {
return nil
}
e := json.Unmarshal(s, &t)
if e != nil {
return nil
}
}
return &t
}

9
pkg/utils/common.go Normal file
View file

@ -0,0 +1,9 @@
package utils
func Map[T, U any](ts []T, f func(T) U) []U {
us := make([]U, len(ts))
for i := range ts {
us[i] = f(ts[i])
}
return us
}

5
pkg/utils/go.mod Normal file
View file

@ -0,0 +1,5 @@
module l12.xyz/dal/utils
go 1.22.6
require github.com/pkg/errors v0.9.1

2
pkg/utils/go.sum Normal file
View file

@ -0,0 +1,2 @@
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=

236
pkg/utils/sql.go Normal file
View file

@ -0,0 +1,236 @@
package utils
import (
"encoding/json"
"reflect"
"strconv"
"strings"
"time"
"unsafe"
"github.com/pkg/errors"
)
func EscapeSQL(sql string, args ...interface{}) ([]byte, error) {
buf := make([]byte, 0, len(sql))
argPos := 0
for i := 0; i < len(sql); i++ {
q := strings.IndexByte(sql[i:], '%')
if q == -1 {
buf = append(buf, sql[i:]...)
break
}
buf = append(buf, sql[i:i+q]...)
i += q
ch := byte(0)
if i+1 < len(sql) {
ch = sql[i+1] // get the specifier
}
switch ch {
case 'n':
if argPos >= len(args) {
return nil, errors.Errorf("missing arguments, need %d-th arg, but only got %d args", argPos+1, len(args))
}
arg := args[argPos]
argPos++
v, ok := arg.(string)
if !ok {
return nil, errors.Errorf("expect a string identifier, got %v", arg)
}
buf = append(buf, '`')
buf = append(buf, strings.ReplaceAll(v, "`", "``")...)
buf = append(buf, '`')
i++ // skip specifier
case '?':
if argPos >= len(args) {
return nil, errors.Errorf("missing arguments, need %d-th arg, but only got %d args", argPos+1, len(args))
}
arg := args[argPos]
argPos++
if arg == nil {
buf = append(buf, "NULL"...)
} else {
switch v := arg.(type) {
case int:
buf = strconv.AppendInt(buf, int64(v), 10)
case int8:
buf = strconv.AppendInt(buf, int64(v), 10)
case int16:
buf = strconv.AppendInt(buf, int64(v), 10)
case int32:
buf = strconv.AppendInt(buf, int64(v), 10)
case int64:
buf = strconv.AppendInt(buf, v, 10)
case uint:
buf = strconv.AppendUint(buf, uint64(v), 10)
case uint8:
buf = strconv.AppendUint(buf, uint64(v), 10)
case uint16:
buf = strconv.AppendUint(buf, uint64(v), 10)
case uint32:
buf = strconv.AppendUint(buf, uint64(v), 10)
case uint64:
buf = strconv.AppendUint(buf, v, 10)
case float32:
buf = strconv.AppendFloat(buf, float64(v), 'g', -1, 32)
case float64:
buf = strconv.AppendFloat(buf, v, 'g', -1, 64)
case bool:
buf = appendSQLArgBool(buf, v)
case time.Time:
if v.IsZero() {
buf = append(buf, "'0000-00-00'"...)
} else {
buf = append(buf, '\'')
buf = v.AppendFormat(buf, "2006-01-02 15:04:05.999999")
buf = append(buf, '\'')
}
case json.RawMessage:
buf = append(buf, '\'')
buf = escapeBytesBackslash(buf, v)
buf = append(buf, '\'')
case []byte:
if v == nil {
buf = append(buf, "NULL"...)
} else {
buf = append(buf, "_binary'"...)
buf = escapeBytesBackslash(buf, v)
buf = append(buf, '\'')
}
case string:
buf = appendSQLArgString(buf, v)
case []string:
for i, k := range v {
if i > 0 {
buf = append(buf, ',')
}
buf = append(buf, '\'')
buf = escapeStringBackslash(buf, k)
buf = append(buf, '\'')
}
case []float32:
for i, k := range v {
if i > 0 {
buf = append(buf, ',')
}
buf = strconv.AppendFloat(buf, float64(k), 'g', -1, 32)
}
case []float64:
for i, k := range v {
if i > 0 {
buf = append(buf, ',')
}
buf = strconv.AppendFloat(buf, k, 'g', -1, 64)
}
default:
// slow path based on reflection
reflectTp := reflect.TypeOf(arg)
kind := reflectTp.Kind()
switch kind {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
buf = strconv.AppendInt(buf, reflect.ValueOf(arg).Int(), 10)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
buf = strconv.AppendUint(buf, reflect.ValueOf(arg).Uint(), 10)
case reflect.Float32:
buf = strconv.AppendFloat(buf, reflect.ValueOf(arg).Float(), 'g', -1, 32)
case reflect.Float64:
buf = strconv.AppendFloat(buf, reflect.ValueOf(arg).Float(), 'g', -1, 64)
case reflect.Bool:
buf = appendSQLArgBool(buf, reflect.ValueOf(arg).Bool())
case reflect.String:
buf = appendSQLArgString(buf, reflect.ValueOf(arg).String())
default:
return nil, errors.Errorf("unsupported %d-th argument: %v", argPos, arg)
}
}
}
i++ // skip specifier
case '%':
buf = append(buf, '%')
i++ // skip specifier
default:
buf = append(buf, '%')
}
}
return buf, nil
}
func EscapeString(s string) string {
buf := make([]byte, 0, len(s))
return string(escapeStringBackslash(buf, s))
}
func appendSQLArgBool(buf []byte, v bool) []byte {
if v {
return append(buf, '1')
}
return append(buf, '0')
}
func appendSQLArgString(buf []byte, s string) []byte {
buf = append(buf, '\'')
buf = escapeStringBackslash(buf, s)
buf = append(buf, '\'')
return buf
}
func escapeStringBackslash(buf []byte, v string) []byte {
return escapeBytesBackslash(buf, unsafe.Slice(unsafe.StringData(v), len(v)))
}
// escapeBytesBackslash will escape []byte into the buffer, with backslash.
func escapeBytesBackslash(buf []byte, v []byte) []byte {
pos := len(buf)
buf = reserveBuffer(buf, len(v)*2)
for _, c := range v {
switch c {
case '\x00':
buf[pos] = '\\'
buf[pos+1] = '0'
pos += 2
case '\n':
buf[pos] = '\\'
buf[pos+1] = 'n'
pos += 2
case '\r':
buf[pos] = '\\'
buf[pos+1] = 'r'
pos += 2
case '\x1a':
buf[pos] = '\\'
buf[pos+1] = 'Z'
pos += 2
case '\'':
buf[pos] = '\\'
buf[pos+1] = '\''
pos += 2
case '"':
buf[pos] = '\\'
buf[pos+1] = '"'
pos += 2
case '\\':
buf[pos] = '\\'
buf[pos+1] = '\\'
pos += 2
default:
buf[pos] = c
pos++
}
}
return buf[:pos]
}
func reserveBuffer(buf []byte, appendSize int) []byte {
newSize := len(buf) + appendSize
if cap(buf) < newSize {
newBuf := make([]byte, len(buf)*2+appendSize)
copy(newBuf, buf)
buf = newBuf
}
return buf[:newSize]
}