[wip] dal golang
Signed-off-by: Anton Nesterov <anton@demiurg.io>
This commit is contained in:
commit
d28d976b8e
37
pkg/dal/builder.go
Normal file
37
pkg/dal/builder.go
Normal file
|
@ -0,0 +1,37 @@
|
||||||
|
package dal
|
||||||
|
|
||||||
|
type SQLParts struct {
|
||||||
|
operation string
|
||||||
|
selectExp string
|
||||||
|
fromExp string
|
||||||
|
fiterExp string
|
||||||
|
joinExp []string
|
||||||
|
groupExp string
|
||||||
|
orderExp string
|
||||||
|
limitExp string
|
||||||
|
updateExp string
|
||||||
|
upsertExp string
|
||||||
|
}
|
||||||
|
|
||||||
|
type Builder struct {
|
||||||
|
parts SQLParts
|
||||||
|
}
|
||||||
|
|
||||||
|
func New() *Builder {
|
||||||
|
return &Builder{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Builder) In(selectExp string) *Builder {
|
||||||
|
b.parts.selectExp = selectExp
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Builder) Find(fromExp string) *Builder {
|
||||||
|
b.parts.fromExp = fromExp
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Builder) Join(fiterExp string) *Builder {
|
||||||
|
b.parts.fiterExp = fiterExp
|
||||||
|
return b
|
||||||
|
}
|
7
pkg/dal/convert.go
Normal file
7
pkg/dal/convert.go
Normal file
|
@ -0,0 +1,7 @@
|
||||||
|
package dal
|
||||||
|
|
||||||
|
type FindObject map[string]interface{}
|
||||||
|
|
||||||
|
func CovertFind(findobj FindObject) string {
|
||||||
|
return ""
|
||||||
|
}
|
19
pkg/filters/And.go
Normal file
19
pkg/filters/And.go
Normal file
|
@ -0,0 +1,19 @@
|
||||||
|
package filters
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
type And struct {
|
||||||
|
And []interface{} `json:"$and"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a And) ToSQLPart(ctx Context) string {
|
||||||
|
|
||||||
|
fmt.Println(ctx, a)
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a And) FromJSON(data interface{}) Filter {
|
||||||
|
return FromJson[And](data)
|
||||||
|
}
|
25
pkg/filters/Between.go
Normal file
25
pkg/filters/Between.go
Normal file
|
@ -0,0 +1,25 @@
|
||||||
|
package filters
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"l12.xyz/dal/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Between struct {
|
||||||
|
Between []interface{} `json:"$between"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f Between) FromJSON(data interface{}) Filter {
|
||||||
|
return FromJson[Between](data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f Between) ToSQLPart(ctx Context) string {
|
||||||
|
if f.Between == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
name := ctx.GetFieldName()
|
||||||
|
values := utils.Map(f.Between, ctx.NormalizeValue)
|
||||||
|
condition := fmt.Sprintf("%v AND %v", values[0], values[1])
|
||||||
|
return fmt.Sprintf("%s BETWEEN %v", name, condition)
|
||||||
|
}
|
25
pkg/filters/Eq.go
Normal file
25
pkg/filters/Eq.go
Normal file
|
@ -0,0 +1,25 @@
|
||||||
|
package filters
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Eq struct {
|
||||||
|
Eq interface{} `json:"$eq"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f Eq) FromJSON(data interface{}) Filter {
|
||||||
|
return FromJson[Eq](data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f Eq) ToSQLPart(ctx Context) string {
|
||||||
|
if f.Eq == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
name := ctx.GetFieldName()
|
||||||
|
value := ctx.NormalizeValue(f.Eq)
|
||||||
|
if value == "NULL" {
|
||||||
|
return fmt.Sprintf("%s IS NULL", name)
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s = %v", name, value)
|
||||||
|
}
|
20
pkg/filters/Glob.go
Normal file
20
pkg/filters/Glob.go
Normal file
|
@ -0,0 +1,20 @@
|
||||||
|
package filters
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
|
type Glob struct {
|
||||||
|
Glob interface{} `json:"$glob"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f Glob) FromJSON(data interface{}) Filter {
|
||||||
|
return FromJson[Glob](data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f Glob) ToSQLPart(ctx Context) string {
|
||||||
|
if f.Glob == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
name := ctx.GetFieldName()
|
||||||
|
value := ctx.NormalizeValue(f.Glob)
|
||||||
|
return fmt.Sprintf("%s GLOB %v", name, value)
|
||||||
|
}
|
22
pkg/filters/Gt.go
Normal file
22
pkg/filters/Gt.go
Normal file
|
@ -0,0 +1,22 @@
|
||||||
|
package filters
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Gt struct {
|
||||||
|
Gt interface{} `json:"$gt"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f Gt) FromJSON(data interface{}) Filter {
|
||||||
|
return FromJson[Gt](data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f Gt) ToSQLPart(ctx Context) string {
|
||||||
|
if f.Gt == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
name := ctx.GetFieldName()
|
||||||
|
value := ctx.NormalizeValue(f.Gt)
|
||||||
|
return fmt.Sprintf("%s > %v", name, value)
|
||||||
|
}
|
22
pkg/filters/Gte.go
Normal file
22
pkg/filters/Gte.go
Normal file
|
@ -0,0 +1,22 @@
|
||||||
|
package filters
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Gte struct {
|
||||||
|
Gte interface{} `json:"$gte"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f Gte) FromJSON(data interface{}) Filter {
|
||||||
|
return FromJson[Gte](data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f Gte) ToSQLPart(ctx Context) string {
|
||||||
|
if f.Gte == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
name := ctx.GetFieldName()
|
||||||
|
value := ctx.NormalizeValue(f.Gte)
|
||||||
|
return fmt.Sprintf("%s >= %v", name, value)
|
||||||
|
}
|
31
pkg/filters/In.go
Normal file
31
pkg/filters/In.go
Normal file
|
@ -0,0 +1,31 @@
|
||||||
|
package filters
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"l12.xyz/dal/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
type In struct {
|
||||||
|
In []interface{} `json:"$in"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f In) FromJSON(data interface{}) Filter {
|
||||||
|
return FromJson[In](data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f In) ToSQLPart(ctx Context) string {
|
||||||
|
if f.In == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
name := ctx.GetFieldName()
|
||||||
|
values := utils.Map(f.In, ctx.NormalizeValue)
|
||||||
|
data := make([]string, len(values))
|
||||||
|
for i, v := range values {
|
||||||
|
data[i] = fmt.Sprintf("%v", v)
|
||||||
|
}
|
||||||
|
value := strings.Join(data, ", ")
|
||||||
|
return fmt.Sprintf("%s IN (%v)", name, value)
|
||||||
|
}
|
22
pkg/filters/Lt.go
Normal file
22
pkg/filters/Lt.go
Normal file
|
@ -0,0 +1,22 @@
|
||||||
|
package filters
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Lt struct {
|
||||||
|
Lt interface{} `json:"$lt"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f Lt) FromJSON(data interface{}) Filter {
|
||||||
|
return FromJson[Lt](data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f Lt) ToSQLPart(ctx Context) string {
|
||||||
|
if f.Lt == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
name := ctx.GetFieldName()
|
||||||
|
value := ctx.NormalizeValue(f.Lt)
|
||||||
|
return fmt.Sprintf("%s < %v", name, value)
|
||||||
|
}
|
22
pkg/filters/Lte.go
Normal file
22
pkg/filters/Lte.go
Normal file
|
@ -0,0 +1,22 @@
|
||||||
|
package filters
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Lte struct {
|
||||||
|
Lte interface{} `json:"$lte"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f Lte) FromJSON(data interface{}) Filter {
|
||||||
|
return FromJson[Lte](data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f Lte) ToSQLPart(ctx Context) string {
|
||||||
|
if f.Lte == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
name := ctx.GetFieldName()
|
||||||
|
value := ctx.NormalizeValue(f.Lte)
|
||||||
|
return fmt.Sprintf("%s <= %v", name, value)
|
||||||
|
}
|
23
pkg/filters/Ne.go
Normal file
23
pkg/filters/Ne.go
Normal file
|
@ -0,0 +1,23 @@
|
||||||
|
package filters
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
|
type Ne struct {
|
||||||
|
Ne interface{} `json:"$ne"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f Ne) FromJSON(data interface{}) Filter {
|
||||||
|
return FromJson[Ne](data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f Ne) ToSQLPart(ctx Context) string {
|
||||||
|
if f.Ne == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
name := ctx.GetFieldName()
|
||||||
|
value := ctx.NormalizeValue(f.Ne)
|
||||||
|
if value == "NULL" {
|
||||||
|
return fmt.Sprintf("%s IS NOT NULL", name)
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s != %v", name, value)
|
||||||
|
}
|
25
pkg/filters/NotBetween.go
Normal file
25
pkg/filters/NotBetween.go
Normal file
|
@ -0,0 +1,25 @@
|
||||||
|
package filters
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"l12.xyz/dal/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
type NotBetween struct {
|
||||||
|
NotBetween []interface{} `json:"$nbetween"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f NotBetween) FromJSON(data interface{}) Filter {
|
||||||
|
return FromJson[NotBetween](data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f NotBetween) ToSQLPart(ctx Context) string {
|
||||||
|
if f.NotBetween == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
name := ctx.GetFieldName()
|
||||||
|
values := utils.Map(f.NotBetween, ctx.NormalizeValue)
|
||||||
|
condition := fmt.Sprintf("%v AND %v", values[0], values[1])
|
||||||
|
return fmt.Sprintf("%s NOT BETWEEN %v", name, condition)
|
||||||
|
}
|
67
pkg/filters/context.go
Normal file
67
pkg/filters/context.go
Normal file
|
@ -0,0 +1,67 @@
|
||||||
|
package filters
|
||||||
|
|
||||||
|
import (
|
||||||
|
"slices"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"unicode"
|
||||||
|
|
||||||
|
utils "l12.xyz/dal/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
type SQLiteContext struct {
|
||||||
|
TableAlias string
|
||||||
|
FieldName string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c SQLiteContext) GetFieldName() string {
|
||||||
|
if strings.Contains(c.FieldName, ".") {
|
||||||
|
return c.FieldName
|
||||||
|
}
|
||||||
|
if c.TableAlias != "" {
|
||||||
|
return c.TableAlias + "." + c.FieldName
|
||||||
|
}
|
||||||
|
return c.FieldName
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c SQLiteContext) NormalizeValue(value interface{}) interface{} {
|
||||||
|
str, ok := value.(string)
|
||||||
|
if isSQLFunction(str) {
|
||||||
|
return str
|
||||||
|
}
|
||||||
|
if strings.Contains(str, ".") {
|
||||||
|
_, err := strconv.ParseFloat(str, 64)
|
||||||
|
if err != nil {
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
val, err := utils.EscapeSQL(str)
|
||||||
|
if err != nil {
|
||||||
|
return str
|
||||||
|
}
|
||||||
|
return "'" + escapeSingleQuote(string(val)) + "'"
|
||||||
|
}
|
||||||
|
|
||||||
|
func isSQLFunction(str string) bool {
|
||||||
|
stopChars := []string{" ", "_", "-", ".", "("}
|
||||||
|
isUpper := false
|
||||||
|
for _, char := range str {
|
||||||
|
if slices.Contains(stopChars, string(char)) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if unicode.IsUpper(char) {
|
||||||
|
isUpper = true
|
||||||
|
} else {
|
||||||
|
isUpper = false
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return isUpper
|
||||||
|
}
|
||||||
|
|
||||||
|
func escapeSingleQuote(str string) string {
|
||||||
|
return strings.ReplaceAll(str, "'", "''")
|
||||||
|
}
|
9
pkg/filters/go.mod
Normal file
9
pkg/filters/go.mod
Normal file
|
@ -0,0 +1,9 @@
|
||||||
|
module l12.xyz/dal/filters
|
||||||
|
|
||||||
|
go 1.22.6
|
||||||
|
|
||||||
|
require github.com/pkg/errors v0.9.1 // indirect
|
||||||
|
|
||||||
|
require l12.xyz/dal/utils v0.0.0
|
||||||
|
|
||||||
|
replace l12.xyz/dal/utils v0.0.0 => ../utils
|
2
pkg/filters/go.sum
Normal file
2
pkg/filters/go.sum
Normal file
|
@ -0,0 +1,2 @@
|
||||||
|
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||||
|
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
24
pkg/filters/registry.go
Normal file
24
pkg/filters/registry.go
Normal file
|
@ -0,0 +1,24 @@
|
||||||
|
package filters
|
||||||
|
|
||||||
|
var FilterRegistry = map[string]Filter{
|
||||||
|
"Eq": &Eq{},
|
||||||
|
"Ne": &Ne{},
|
||||||
|
"Gt": &Gt{},
|
||||||
|
"Gte": &Gte{},
|
||||||
|
"Lt": &Lt{},
|
||||||
|
"Lte": &Lte{},
|
||||||
|
"In": &In{},
|
||||||
|
"Between": &Between{},
|
||||||
|
"NotBetween": &NotBetween{},
|
||||||
|
"Glob": &Glob{},
|
||||||
|
}
|
||||||
|
|
||||||
|
func Convert(ctx Context, json interface{}) string {
|
||||||
|
for _, t := range FilterRegistry {
|
||||||
|
value := t.FromJSON(json).ToSQLPart(ctx)
|
||||||
|
if value != "" {
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
11
pkg/filters/types.go
Normal file
11
pkg/filters/types.go
Normal file
|
@ -0,0 +1,11 @@
|
||||||
|
package filters
|
||||||
|
|
||||||
|
type Context interface {
|
||||||
|
GetFieldName() string
|
||||||
|
NormalizeValue(interface{}) interface{}
|
||||||
|
}
|
||||||
|
|
||||||
|
type Filter interface {
|
||||||
|
ToSQLPart(ctx Context) string
|
||||||
|
FromJSON(interface{}) Filter
|
||||||
|
}
|
83
pkg/filters/unit_test.go
Normal file
83
pkg/filters/unit_test.go
Normal file
|
@ -0,0 +1,83 @@
|
||||||
|
package filters
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestEq(t *testing.T) {
|
||||||
|
ctx := SQLiteContext{
|
||||||
|
TableAlias: "t",
|
||||||
|
FieldName: "test",
|
||||||
|
}
|
||||||
|
result := Convert(ctx, `{"$eq": "NULL"}`)
|
||||||
|
resultMap := Convert(ctx, map[string]any{"$eq": "NULL"})
|
||||||
|
if result != `t.test IS NULL` {
|
||||||
|
t.Errorf("Expected t.test IS NULL, got %s", result)
|
||||||
|
}
|
||||||
|
if resultMap != result {
|
||||||
|
t.Log(resultMap)
|
||||||
|
t.Errorf("Expected resultMap to be equal to result")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNe(t *testing.T) {
|
||||||
|
ctx := SQLiteContext{
|
||||||
|
FieldName: "test",
|
||||||
|
}
|
||||||
|
result := Convert(ctx, `{"$ne": "1"}`)
|
||||||
|
resultMap := Convert(ctx, map[string]any{"$ne": "1"})
|
||||||
|
if result != `test != '1'` {
|
||||||
|
t.Errorf("Expected test != '1', got %s", result)
|
||||||
|
}
|
||||||
|
if resultMap != result {
|
||||||
|
t.Log(resultMap)
|
||||||
|
t.Errorf("Expected resultMap to be equal to result")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBetween(t *testing.T) {
|
||||||
|
ctx := SQLiteContext{
|
||||||
|
FieldName: "test",
|
||||||
|
}
|
||||||
|
result := Convert(ctx, `{"$between": ["1", "5"]}`)
|
||||||
|
resultMap := Convert(ctx, map[string]any{"$between": []string{"1", "5"}})
|
||||||
|
if result != `test BETWEEN '1' AND '5'` {
|
||||||
|
t.Errorf("Expected test BETWEEN '1' AND '5', got %s", result)
|
||||||
|
}
|
||||||
|
if resultMap != result {
|
||||||
|
t.Log(resultMap)
|
||||||
|
t.Errorf("Expected resultMap to be equal to result")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGlob(t *testing.T) {
|
||||||
|
ctx := SQLiteContext{
|
||||||
|
TableAlias: "t",
|
||||||
|
FieldName: "test",
|
||||||
|
}
|
||||||
|
result := Convert(ctx, `{"$glob": "*son"}`)
|
||||||
|
resultMap := Convert(ctx, map[string]any{"$glob": "*son"})
|
||||||
|
if result != `t.test GLOB '*son'` {
|
||||||
|
t.Errorf("Expected t.test GLOB '*son', got %s", result)
|
||||||
|
}
|
||||||
|
if resultMap != result {
|
||||||
|
t.Log(resultMap)
|
||||||
|
t.Errorf("Expected resultMap to be equal to result")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIn(t *testing.T) {
|
||||||
|
ctx := SQLiteContext{
|
||||||
|
TableAlias: "t",
|
||||||
|
FieldName: "test",
|
||||||
|
}
|
||||||
|
result := Convert(ctx, `{"$in": [1, 2, 3]}`)
|
||||||
|
resultMap := Convert(ctx, map[string]any{"$in": []int{1, 2, 3}})
|
||||||
|
if result != `t.test IN (1, 2, 3)` {
|
||||||
|
t.Errorf("Expected t.test IN (1, 2, 3), got %s", result)
|
||||||
|
}
|
||||||
|
if resultMap != result {
|
||||||
|
t.Log(resultMap)
|
||||||
|
t.Errorf("Expected resultMap to be equal to result")
|
||||||
|
}
|
||||||
|
}
|
29
pkg/filters/utils.go
Normal file
29
pkg/filters/utils.go
Normal file
|
@ -0,0 +1,29 @@
|
||||||
|
package filters
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
)
|
||||||
|
|
||||||
|
func FromJson[T Filter](data interface{}) *T {
|
||||||
|
var t T
|
||||||
|
str, ok := data.(string)
|
||||||
|
if ok {
|
||||||
|
err := json.Unmarshal([]byte(str), &t)
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
m, ok := data.(map[string]interface{})
|
||||||
|
if ok {
|
||||||
|
s, err := json.Marshal(m)
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
e := json.Unmarshal(s, &t)
|
||||||
|
if e != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &t
|
||||||
|
}
|
9
pkg/utils/common.go
Normal file
9
pkg/utils/common.go
Normal file
|
@ -0,0 +1,9 @@
|
||||||
|
package utils
|
||||||
|
|
||||||
|
func Map[T, U any](ts []T, f func(T) U) []U {
|
||||||
|
us := make([]U, len(ts))
|
||||||
|
for i := range ts {
|
||||||
|
us[i] = f(ts[i])
|
||||||
|
}
|
||||||
|
return us
|
||||||
|
}
|
5
pkg/utils/go.mod
Normal file
5
pkg/utils/go.mod
Normal file
|
@ -0,0 +1,5 @@
|
||||||
|
module l12.xyz/dal/utils
|
||||||
|
|
||||||
|
go 1.22.6
|
||||||
|
|
||||||
|
require github.com/pkg/errors v0.9.1
|
2
pkg/utils/go.sum
Normal file
2
pkg/utils/go.sum
Normal file
|
@ -0,0 +1,2 @@
|
||||||
|
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||||
|
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
236
pkg/utils/sql.go
Normal file
236
pkg/utils/sql.go
Normal file
|
@ -0,0 +1,236 @@
|
||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"reflect"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
func EscapeSQL(sql string, args ...interface{}) ([]byte, error) {
|
||||||
|
buf := make([]byte, 0, len(sql))
|
||||||
|
argPos := 0
|
||||||
|
for i := 0; i < len(sql); i++ {
|
||||||
|
q := strings.IndexByte(sql[i:], '%')
|
||||||
|
if q == -1 {
|
||||||
|
buf = append(buf, sql[i:]...)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
buf = append(buf, sql[i:i+q]...)
|
||||||
|
i += q
|
||||||
|
|
||||||
|
ch := byte(0)
|
||||||
|
if i+1 < len(sql) {
|
||||||
|
ch = sql[i+1] // get the specifier
|
||||||
|
}
|
||||||
|
switch ch {
|
||||||
|
case 'n':
|
||||||
|
if argPos >= len(args) {
|
||||||
|
return nil, errors.Errorf("missing arguments, need %d-th arg, but only got %d args", argPos+1, len(args))
|
||||||
|
}
|
||||||
|
arg := args[argPos]
|
||||||
|
argPos++
|
||||||
|
|
||||||
|
v, ok := arg.(string)
|
||||||
|
if !ok {
|
||||||
|
return nil, errors.Errorf("expect a string identifier, got %v", arg)
|
||||||
|
}
|
||||||
|
buf = append(buf, '`')
|
||||||
|
buf = append(buf, strings.ReplaceAll(v, "`", "``")...)
|
||||||
|
buf = append(buf, '`')
|
||||||
|
i++ // skip specifier
|
||||||
|
case '?':
|
||||||
|
if argPos >= len(args) {
|
||||||
|
return nil, errors.Errorf("missing arguments, need %d-th arg, but only got %d args", argPos+1, len(args))
|
||||||
|
}
|
||||||
|
arg := args[argPos]
|
||||||
|
argPos++
|
||||||
|
|
||||||
|
if arg == nil {
|
||||||
|
buf = append(buf, "NULL"...)
|
||||||
|
} else {
|
||||||
|
switch v := arg.(type) {
|
||||||
|
case int:
|
||||||
|
buf = strconv.AppendInt(buf, int64(v), 10)
|
||||||
|
case int8:
|
||||||
|
buf = strconv.AppendInt(buf, int64(v), 10)
|
||||||
|
case int16:
|
||||||
|
buf = strconv.AppendInt(buf, int64(v), 10)
|
||||||
|
case int32:
|
||||||
|
buf = strconv.AppendInt(buf, int64(v), 10)
|
||||||
|
case int64:
|
||||||
|
buf = strconv.AppendInt(buf, v, 10)
|
||||||
|
case uint:
|
||||||
|
buf = strconv.AppendUint(buf, uint64(v), 10)
|
||||||
|
case uint8:
|
||||||
|
buf = strconv.AppendUint(buf, uint64(v), 10)
|
||||||
|
case uint16:
|
||||||
|
buf = strconv.AppendUint(buf, uint64(v), 10)
|
||||||
|
case uint32:
|
||||||
|
buf = strconv.AppendUint(buf, uint64(v), 10)
|
||||||
|
case uint64:
|
||||||
|
buf = strconv.AppendUint(buf, v, 10)
|
||||||
|
case float32:
|
||||||
|
buf = strconv.AppendFloat(buf, float64(v), 'g', -1, 32)
|
||||||
|
case float64:
|
||||||
|
buf = strconv.AppendFloat(buf, v, 'g', -1, 64)
|
||||||
|
case bool:
|
||||||
|
buf = appendSQLArgBool(buf, v)
|
||||||
|
case time.Time:
|
||||||
|
if v.IsZero() {
|
||||||
|
buf = append(buf, "'0000-00-00'"...)
|
||||||
|
} else {
|
||||||
|
buf = append(buf, '\'')
|
||||||
|
buf = v.AppendFormat(buf, "2006-01-02 15:04:05.999999")
|
||||||
|
buf = append(buf, '\'')
|
||||||
|
}
|
||||||
|
case json.RawMessage:
|
||||||
|
buf = append(buf, '\'')
|
||||||
|
buf = escapeBytesBackslash(buf, v)
|
||||||
|
buf = append(buf, '\'')
|
||||||
|
case []byte:
|
||||||
|
if v == nil {
|
||||||
|
buf = append(buf, "NULL"...)
|
||||||
|
} else {
|
||||||
|
buf = append(buf, "_binary'"...)
|
||||||
|
buf = escapeBytesBackslash(buf, v)
|
||||||
|
buf = append(buf, '\'')
|
||||||
|
}
|
||||||
|
case string:
|
||||||
|
buf = appendSQLArgString(buf, v)
|
||||||
|
case []string:
|
||||||
|
for i, k := range v {
|
||||||
|
if i > 0 {
|
||||||
|
buf = append(buf, ',')
|
||||||
|
}
|
||||||
|
buf = append(buf, '\'')
|
||||||
|
buf = escapeStringBackslash(buf, k)
|
||||||
|
buf = append(buf, '\'')
|
||||||
|
}
|
||||||
|
case []float32:
|
||||||
|
for i, k := range v {
|
||||||
|
if i > 0 {
|
||||||
|
buf = append(buf, ',')
|
||||||
|
}
|
||||||
|
buf = strconv.AppendFloat(buf, float64(k), 'g', -1, 32)
|
||||||
|
}
|
||||||
|
case []float64:
|
||||||
|
for i, k := range v {
|
||||||
|
if i > 0 {
|
||||||
|
buf = append(buf, ',')
|
||||||
|
}
|
||||||
|
buf = strconv.AppendFloat(buf, k, 'g', -1, 64)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
// slow path based on reflection
|
||||||
|
reflectTp := reflect.TypeOf(arg)
|
||||||
|
kind := reflectTp.Kind()
|
||||||
|
switch kind {
|
||||||
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||||
|
buf = strconv.AppendInt(buf, reflect.ValueOf(arg).Int(), 10)
|
||||||
|
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||||
|
buf = strconv.AppendUint(buf, reflect.ValueOf(arg).Uint(), 10)
|
||||||
|
case reflect.Float32:
|
||||||
|
buf = strconv.AppendFloat(buf, reflect.ValueOf(arg).Float(), 'g', -1, 32)
|
||||||
|
case reflect.Float64:
|
||||||
|
buf = strconv.AppendFloat(buf, reflect.ValueOf(arg).Float(), 'g', -1, 64)
|
||||||
|
case reflect.Bool:
|
||||||
|
buf = appendSQLArgBool(buf, reflect.ValueOf(arg).Bool())
|
||||||
|
case reflect.String:
|
||||||
|
buf = appendSQLArgString(buf, reflect.ValueOf(arg).String())
|
||||||
|
default:
|
||||||
|
return nil, errors.Errorf("unsupported %d-th argument: %v", argPos, arg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
i++ // skip specifier
|
||||||
|
case '%':
|
||||||
|
buf = append(buf, '%')
|
||||||
|
i++ // skip specifier
|
||||||
|
default:
|
||||||
|
buf = append(buf, '%')
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return buf, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func EscapeString(s string) string {
|
||||||
|
buf := make([]byte, 0, len(s))
|
||||||
|
return string(escapeStringBackslash(buf, s))
|
||||||
|
}
|
||||||
|
|
||||||
|
func appendSQLArgBool(buf []byte, v bool) []byte {
|
||||||
|
if v {
|
||||||
|
return append(buf, '1')
|
||||||
|
}
|
||||||
|
return append(buf, '0')
|
||||||
|
}
|
||||||
|
|
||||||
|
func appendSQLArgString(buf []byte, s string) []byte {
|
||||||
|
buf = append(buf, '\'')
|
||||||
|
buf = escapeStringBackslash(buf, s)
|
||||||
|
buf = append(buf, '\'')
|
||||||
|
return buf
|
||||||
|
}
|
||||||
|
|
||||||
|
func escapeStringBackslash(buf []byte, v string) []byte {
|
||||||
|
return escapeBytesBackslash(buf, unsafe.Slice(unsafe.StringData(v), len(v)))
|
||||||
|
}
|
||||||
|
|
||||||
|
// escapeBytesBackslash will escape []byte into the buffer, with backslash.
|
||||||
|
func escapeBytesBackslash(buf []byte, v []byte) []byte {
|
||||||
|
pos := len(buf)
|
||||||
|
buf = reserveBuffer(buf, len(v)*2)
|
||||||
|
|
||||||
|
for _, c := range v {
|
||||||
|
switch c {
|
||||||
|
case '\x00':
|
||||||
|
buf[pos] = '\\'
|
||||||
|
buf[pos+1] = '0'
|
||||||
|
pos += 2
|
||||||
|
case '\n':
|
||||||
|
buf[pos] = '\\'
|
||||||
|
buf[pos+1] = 'n'
|
||||||
|
pos += 2
|
||||||
|
case '\r':
|
||||||
|
buf[pos] = '\\'
|
||||||
|
buf[pos+1] = 'r'
|
||||||
|
pos += 2
|
||||||
|
case '\x1a':
|
||||||
|
buf[pos] = '\\'
|
||||||
|
buf[pos+1] = 'Z'
|
||||||
|
pos += 2
|
||||||
|
case '\'':
|
||||||
|
buf[pos] = '\\'
|
||||||
|
buf[pos+1] = '\''
|
||||||
|
pos += 2
|
||||||
|
case '"':
|
||||||
|
buf[pos] = '\\'
|
||||||
|
buf[pos+1] = '"'
|
||||||
|
pos += 2
|
||||||
|
case '\\':
|
||||||
|
buf[pos] = '\\'
|
||||||
|
buf[pos+1] = '\\'
|
||||||
|
pos += 2
|
||||||
|
default:
|
||||||
|
buf[pos] = c
|
||||||
|
pos++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return buf[:pos]
|
||||||
|
}
|
||||||
|
|
||||||
|
func reserveBuffer(buf []byte, appendSize int) []byte {
|
||||||
|
newSize := len(buf) + appendSize
|
||||||
|
if cap(buf) < newSize {
|
||||||
|
newBuf := make([]byte, len(buf)*2+appendSize)
|
||||||
|
copy(newBuf, buf)
|
||||||
|
buf = newBuf
|
||||||
|
}
|
||||||
|
return buf[:newSize]
|
||||||
|
}
|
Loading…
Reference in a new issue