[wip] dal golang
Signed-off-by: Anton Nesterov <anton@demiurg.io>
This commit is contained in:
commit
d28d976b8e
37
pkg/dal/builder.go
Normal file
37
pkg/dal/builder.go
Normal file
|
@ -0,0 +1,37 @@
|
|||
package dal
|
||||
|
||||
type SQLParts struct {
|
||||
operation string
|
||||
selectExp string
|
||||
fromExp string
|
||||
fiterExp string
|
||||
joinExp []string
|
||||
groupExp string
|
||||
orderExp string
|
||||
limitExp string
|
||||
updateExp string
|
||||
upsertExp string
|
||||
}
|
||||
|
||||
type Builder struct {
|
||||
parts SQLParts
|
||||
}
|
||||
|
||||
func New() *Builder {
|
||||
return &Builder{}
|
||||
}
|
||||
|
||||
func (b *Builder) In(selectExp string) *Builder {
|
||||
b.parts.selectExp = selectExp
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *Builder) Find(fromExp string) *Builder {
|
||||
b.parts.fromExp = fromExp
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *Builder) Join(fiterExp string) *Builder {
|
||||
b.parts.fiterExp = fiterExp
|
||||
return b
|
||||
}
|
7
pkg/dal/convert.go
Normal file
7
pkg/dal/convert.go
Normal file
|
@ -0,0 +1,7 @@
|
|||
package dal
|
||||
|
||||
type FindObject map[string]interface{}
|
||||
|
||||
func CovertFind(findobj FindObject) string {
|
||||
return ""
|
||||
}
|
19
pkg/filters/And.go
Normal file
19
pkg/filters/And.go
Normal file
|
@ -0,0 +1,19 @@
|
|||
package filters
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type And struct {
|
||||
And []interface{} `json:"$and"`
|
||||
}
|
||||
|
||||
func (a And) ToSQLPart(ctx Context) string {
|
||||
|
||||
fmt.Println(ctx, a)
|
||||
return ""
|
||||
}
|
||||
|
||||
func (a And) FromJSON(data interface{}) Filter {
|
||||
return FromJson[And](data)
|
||||
}
|
25
pkg/filters/Between.go
Normal file
25
pkg/filters/Between.go
Normal file
|
@ -0,0 +1,25 @@
|
|||
package filters
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"l12.xyz/dal/utils"
|
||||
)
|
||||
|
||||
type Between struct {
|
||||
Between []interface{} `json:"$between"`
|
||||
}
|
||||
|
||||
func (f Between) FromJSON(data interface{}) Filter {
|
||||
return FromJson[Between](data)
|
||||
}
|
||||
|
||||
func (f Between) ToSQLPart(ctx Context) string {
|
||||
if f.Between == nil {
|
||||
return ""
|
||||
}
|
||||
name := ctx.GetFieldName()
|
||||
values := utils.Map(f.Between, ctx.NormalizeValue)
|
||||
condition := fmt.Sprintf("%v AND %v", values[0], values[1])
|
||||
return fmt.Sprintf("%s BETWEEN %v", name, condition)
|
||||
}
|
25
pkg/filters/Eq.go
Normal file
25
pkg/filters/Eq.go
Normal file
|
@ -0,0 +1,25 @@
|
|||
package filters
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type Eq struct {
|
||||
Eq interface{} `json:"$eq"`
|
||||
}
|
||||
|
||||
func (f Eq) FromJSON(data interface{}) Filter {
|
||||
return FromJson[Eq](data)
|
||||
}
|
||||
|
||||
func (f Eq) ToSQLPart(ctx Context) string {
|
||||
if f.Eq == nil {
|
||||
return ""
|
||||
}
|
||||
name := ctx.GetFieldName()
|
||||
value := ctx.NormalizeValue(f.Eq)
|
||||
if value == "NULL" {
|
||||
return fmt.Sprintf("%s IS NULL", name)
|
||||
}
|
||||
return fmt.Sprintf("%s = %v", name, value)
|
||||
}
|
20
pkg/filters/Glob.go
Normal file
20
pkg/filters/Glob.go
Normal file
|
@ -0,0 +1,20 @@
|
|||
package filters
|
||||
|
||||
import "fmt"
|
||||
|
||||
type Glob struct {
|
||||
Glob interface{} `json:"$glob"`
|
||||
}
|
||||
|
||||
func (f Glob) FromJSON(data interface{}) Filter {
|
||||
return FromJson[Glob](data)
|
||||
}
|
||||
|
||||
func (f Glob) ToSQLPart(ctx Context) string {
|
||||
if f.Glob == nil {
|
||||
return ""
|
||||
}
|
||||
name := ctx.GetFieldName()
|
||||
value := ctx.NormalizeValue(f.Glob)
|
||||
return fmt.Sprintf("%s GLOB %v", name, value)
|
||||
}
|
22
pkg/filters/Gt.go
Normal file
22
pkg/filters/Gt.go
Normal file
|
@ -0,0 +1,22 @@
|
|||
package filters
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type Gt struct {
|
||||
Gt interface{} `json:"$gt"`
|
||||
}
|
||||
|
||||
func (f Gt) FromJSON(data interface{}) Filter {
|
||||
return FromJson[Gt](data)
|
||||
}
|
||||
|
||||
func (f Gt) ToSQLPart(ctx Context) string {
|
||||
if f.Gt == nil {
|
||||
return ""
|
||||
}
|
||||
name := ctx.GetFieldName()
|
||||
value := ctx.NormalizeValue(f.Gt)
|
||||
return fmt.Sprintf("%s > %v", name, value)
|
||||
}
|
22
pkg/filters/Gte.go
Normal file
22
pkg/filters/Gte.go
Normal file
|
@ -0,0 +1,22 @@
|
|||
package filters
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type Gte struct {
|
||||
Gte interface{} `json:"$gte"`
|
||||
}
|
||||
|
||||
func (f Gte) FromJSON(data interface{}) Filter {
|
||||
return FromJson[Gte](data)
|
||||
}
|
||||
|
||||
func (f Gte) ToSQLPart(ctx Context) string {
|
||||
if f.Gte == nil {
|
||||
return ""
|
||||
}
|
||||
name := ctx.GetFieldName()
|
||||
value := ctx.NormalizeValue(f.Gte)
|
||||
return fmt.Sprintf("%s >= %v", name, value)
|
||||
}
|
31
pkg/filters/In.go
Normal file
31
pkg/filters/In.go
Normal file
|
@ -0,0 +1,31 @@
|
|||
package filters
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"l12.xyz/dal/utils"
|
||||
)
|
||||
|
||||
type In struct {
|
||||
In []interface{} `json:"$in"`
|
||||
}
|
||||
|
||||
func (f In) FromJSON(data interface{}) Filter {
|
||||
return FromJson[In](data)
|
||||
}
|
||||
|
||||
func (f In) ToSQLPart(ctx Context) string {
|
||||
if f.In == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
name := ctx.GetFieldName()
|
||||
values := utils.Map(f.In, ctx.NormalizeValue)
|
||||
data := make([]string, len(values))
|
||||
for i, v := range values {
|
||||
data[i] = fmt.Sprintf("%v", v)
|
||||
}
|
||||
value := strings.Join(data, ", ")
|
||||
return fmt.Sprintf("%s IN (%v)", name, value)
|
||||
}
|
22
pkg/filters/Lt.go
Normal file
22
pkg/filters/Lt.go
Normal file
|
@ -0,0 +1,22 @@
|
|||
package filters
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type Lt struct {
|
||||
Lt interface{} `json:"$lt"`
|
||||
}
|
||||
|
||||
func (f Lt) FromJSON(data interface{}) Filter {
|
||||
return FromJson[Lt](data)
|
||||
}
|
||||
|
||||
func (f Lt) ToSQLPart(ctx Context) string {
|
||||
if f.Lt == nil {
|
||||
return ""
|
||||
}
|
||||
name := ctx.GetFieldName()
|
||||
value := ctx.NormalizeValue(f.Lt)
|
||||
return fmt.Sprintf("%s < %v", name, value)
|
||||
}
|
22
pkg/filters/Lte.go
Normal file
22
pkg/filters/Lte.go
Normal file
|
@ -0,0 +1,22 @@
|
|||
package filters
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type Lte struct {
|
||||
Lte interface{} `json:"$lte"`
|
||||
}
|
||||
|
||||
func (f Lte) FromJSON(data interface{}) Filter {
|
||||
return FromJson[Lte](data)
|
||||
}
|
||||
|
||||
func (f Lte) ToSQLPart(ctx Context) string {
|
||||
if f.Lte == nil {
|
||||
return ""
|
||||
}
|
||||
name := ctx.GetFieldName()
|
||||
value := ctx.NormalizeValue(f.Lte)
|
||||
return fmt.Sprintf("%s <= %v", name, value)
|
||||
}
|
23
pkg/filters/Ne.go
Normal file
23
pkg/filters/Ne.go
Normal file
|
@ -0,0 +1,23 @@
|
|||
package filters
|
||||
|
||||
import "fmt"
|
||||
|
||||
type Ne struct {
|
||||
Ne interface{} `json:"$ne"`
|
||||
}
|
||||
|
||||
func (f Ne) FromJSON(data interface{}) Filter {
|
||||
return FromJson[Ne](data)
|
||||
}
|
||||
|
||||
func (f Ne) ToSQLPart(ctx Context) string {
|
||||
if f.Ne == nil {
|
||||
return ""
|
||||
}
|
||||
name := ctx.GetFieldName()
|
||||
value := ctx.NormalizeValue(f.Ne)
|
||||
if value == "NULL" {
|
||||
return fmt.Sprintf("%s IS NOT NULL", name)
|
||||
}
|
||||
return fmt.Sprintf("%s != %v", name, value)
|
||||
}
|
25
pkg/filters/NotBetween.go
Normal file
25
pkg/filters/NotBetween.go
Normal file
|
@ -0,0 +1,25 @@
|
|||
package filters
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"l12.xyz/dal/utils"
|
||||
)
|
||||
|
||||
type NotBetween struct {
|
||||
NotBetween []interface{} `json:"$nbetween"`
|
||||
}
|
||||
|
||||
func (f NotBetween) FromJSON(data interface{}) Filter {
|
||||
return FromJson[NotBetween](data)
|
||||
}
|
||||
|
||||
func (f NotBetween) ToSQLPart(ctx Context) string {
|
||||
if f.NotBetween == nil {
|
||||
return ""
|
||||
}
|
||||
name := ctx.GetFieldName()
|
||||
values := utils.Map(f.NotBetween, ctx.NormalizeValue)
|
||||
condition := fmt.Sprintf("%v AND %v", values[0], values[1])
|
||||
return fmt.Sprintf("%s NOT BETWEEN %v", name, condition)
|
||||
}
|
67
pkg/filters/context.go
Normal file
67
pkg/filters/context.go
Normal file
|
@ -0,0 +1,67 @@
|
|||
package filters
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
utils "l12.xyz/dal/utils"
|
||||
)
|
||||
|
||||
type SQLiteContext struct {
|
||||
TableAlias string
|
||||
FieldName string
|
||||
}
|
||||
|
||||
func (c SQLiteContext) GetFieldName() string {
|
||||
if strings.Contains(c.FieldName, ".") {
|
||||
return c.FieldName
|
||||
}
|
||||
if c.TableAlias != "" {
|
||||
return c.TableAlias + "." + c.FieldName
|
||||
}
|
||||
return c.FieldName
|
||||
}
|
||||
|
||||
func (c SQLiteContext) NormalizeValue(value interface{}) interface{} {
|
||||
str, ok := value.(string)
|
||||
if isSQLFunction(str) {
|
||||
return str
|
||||
}
|
||||
if strings.Contains(str, ".") {
|
||||
_, err := strconv.ParseFloat(str, 64)
|
||||
if err != nil {
|
||||
return value
|
||||
}
|
||||
}
|
||||
if !ok {
|
||||
return value
|
||||
}
|
||||
val, err := utils.EscapeSQL(str)
|
||||
if err != nil {
|
||||
return str
|
||||
}
|
||||
return "'" + escapeSingleQuote(string(val)) + "'"
|
||||
}
|
||||
|
||||
func isSQLFunction(str string) bool {
|
||||
stopChars := []string{" ", "_", "-", ".", "("}
|
||||
isUpper := false
|
||||
for _, char := range str {
|
||||
if slices.Contains(stopChars, string(char)) {
|
||||
break
|
||||
}
|
||||
if unicode.IsUpper(char) {
|
||||
isUpper = true
|
||||
} else {
|
||||
isUpper = false
|
||||
break
|
||||
}
|
||||
}
|
||||
return isUpper
|
||||
}
|
||||
|
||||
func escapeSingleQuote(str string) string {
|
||||
return strings.ReplaceAll(str, "'", "''")
|
||||
}
|
9
pkg/filters/go.mod
Normal file
9
pkg/filters/go.mod
Normal file
|
@ -0,0 +1,9 @@
|
|||
module l12.xyz/dal/filters
|
||||
|
||||
go 1.22.6
|
||||
|
||||
require github.com/pkg/errors v0.9.1 // indirect
|
||||
|
||||
require l12.xyz/dal/utils v0.0.0
|
||||
|
||||
replace l12.xyz/dal/utils v0.0.0 => ../utils
|
2
pkg/filters/go.sum
Normal file
2
pkg/filters/go.sum
Normal file
|
@ -0,0 +1,2 @@
|
|||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
24
pkg/filters/registry.go
Normal file
24
pkg/filters/registry.go
Normal file
|
@ -0,0 +1,24 @@
|
|||
package filters
|
||||
|
||||
var FilterRegistry = map[string]Filter{
|
||||
"Eq": &Eq{},
|
||||
"Ne": &Ne{},
|
||||
"Gt": &Gt{},
|
||||
"Gte": &Gte{},
|
||||
"Lt": &Lt{},
|
||||
"Lte": &Lte{},
|
||||
"In": &In{},
|
||||
"Between": &Between{},
|
||||
"NotBetween": &NotBetween{},
|
||||
"Glob": &Glob{},
|
||||
}
|
||||
|
||||
func Convert(ctx Context, json interface{}) string {
|
||||
for _, t := range FilterRegistry {
|
||||
value := t.FromJSON(json).ToSQLPart(ctx)
|
||||
if value != "" {
|
||||
return value
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
11
pkg/filters/types.go
Normal file
11
pkg/filters/types.go
Normal file
|
@ -0,0 +1,11 @@
|
|||
package filters
|
||||
|
||||
type Context interface {
|
||||
GetFieldName() string
|
||||
NormalizeValue(interface{}) interface{}
|
||||
}
|
||||
|
||||
type Filter interface {
|
||||
ToSQLPart(ctx Context) string
|
||||
FromJSON(interface{}) Filter
|
||||
}
|
83
pkg/filters/unit_test.go
Normal file
83
pkg/filters/unit_test.go
Normal file
|
@ -0,0 +1,83 @@
|
|||
package filters
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestEq(t *testing.T) {
|
||||
ctx := SQLiteContext{
|
||||
TableAlias: "t",
|
||||
FieldName: "test",
|
||||
}
|
||||
result := Convert(ctx, `{"$eq": "NULL"}`)
|
||||
resultMap := Convert(ctx, map[string]any{"$eq": "NULL"})
|
||||
if result != `t.test IS NULL` {
|
||||
t.Errorf("Expected t.test IS NULL, got %s", result)
|
||||
}
|
||||
if resultMap != result {
|
||||
t.Log(resultMap)
|
||||
t.Errorf("Expected resultMap to be equal to result")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNe(t *testing.T) {
|
||||
ctx := SQLiteContext{
|
||||
FieldName: "test",
|
||||
}
|
||||
result := Convert(ctx, `{"$ne": "1"}`)
|
||||
resultMap := Convert(ctx, map[string]any{"$ne": "1"})
|
||||
if result != `test != '1'` {
|
||||
t.Errorf("Expected test != '1', got %s", result)
|
||||
}
|
||||
if resultMap != result {
|
||||
t.Log(resultMap)
|
||||
t.Errorf("Expected resultMap to be equal to result")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBetween(t *testing.T) {
|
||||
ctx := SQLiteContext{
|
||||
FieldName: "test",
|
||||
}
|
||||
result := Convert(ctx, `{"$between": ["1", "5"]}`)
|
||||
resultMap := Convert(ctx, map[string]any{"$between": []string{"1", "5"}})
|
||||
if result != `test BETWEEN '1' AND '5'` {
|
||||
t.Errorf("Expected test BETWEEN '1' AND '5', got %s", result)
|
||||
}
|
||||
if resultMap != result {
|
||||
t.Log(resultMap)
|
||||
t.Errorf("Expected resultMap to be equal to result")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGlob(t *testing.T) {
|
||||
ctx := SQLiteContext{
|
||||
TableAlias: "t",
|
||||
FieldName: "test",
|
||||
}
|
||||
result := Convert(ctx, `{"$glob": "*son"}`)
|
||||
resultMap := Convert(ctx, map[string]any{"$glob": "*son"})
|
||||
if result != `t.test GLOB '*son'` {
|
||||
t.Errorf("Expected t.test GLOB '*son', got %s", result)
|
||||
}
|
||||
if resultMap != result {
|
||||
t.Log(resultMap)
|
||||
t.Errorf("Expected resultMap to be equal to result")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIn(t *testing.T) {
|
||||
ctx := SQLiteContext{
|
||||
TableAlias: "t",
|
||||
FieldName: "test",
|
||||
}
|
||||
result := Convert(ctx, `{"$in": [1, 2, 3]}`)
|
||||
resultMap := Convert(ctx, map[string]any{"$in": []int{1, 2, 3}})
|
||||
if result != `t.test IN (1, 2, 3)` {
|
||||
t.Errorf("Expected t.test IN (1, 2, 3), got %s", result)
|
||||
}
|
||||
if resultMap != result {
|
||||
t.Log(resultMap)
|
||||
t.Errorf("Expected resultMap to be equal to result")
|
||||
}
|
||||
}
|
29
pkg/filters/utils.go
Normal file
29
pkg/filters/utils.go
Normal file
|
@ -0,0 +1,29 @@
|
|||
package filters
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
func FromJson[T Filter](data interface{}) *T {
|
||||
var t T
|
||||
str, ok := data.(string)
|
||||
if ok {
|
||||
err := json.Unmarshal([]byte(str), &t)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
m, ok := data.(map[string]interface{})
|
||||
if ok {
|
||||
s, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
e := json.Unmarshal(s, &t)
|
||||
if e != nil {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return &t
|
||||
}
|
9
pkg/utils/common.go
Normal file
9
pkg/utils/common.go
Normal file
|
@ -0,0 +1,9 @@
|
|||
package utils
|
||||
|
||||
func Map[T, U any](ts []T, f func(T) U) []U {
|
||||
us := make([]U, len(ts))
|
||||
for i := range ts {
|
||||
us[i] = f(ts[i])
|
||||
}
|
||||
return us
|
||||
}
|
5
pkg/utils/go.mod
Normal file
5
pkg/utils/go.mod
Normal file
|
@ -0,0 +1,5 @@
|
|||
module l12.xyz/dal/utils
|
||||
|
||||
go 1.22.6
|
||||
|
||||
require github.com/pkg/errors v0.9.1
|
2
pkg/utils/go.sum
Normal file
2
pkg/utils/go.sum
Normal file
|
@ -0,0 +1,2 @@
|
|||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
236
pkg/utils/sql.go
Normal file
236
pkg/utils/sql.go
Normal file
|
@ -0,0 +1,236 @@
|
|||
package utils
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
func EscapeSQL(sql string, args ...interface{}) ([]byte, error) {
|
||||
buf := make([]byte, 0, len(sql))
|
||||
argPos := 0
|
||||
for i := 0; i < len(sql); i++ {
|
||||
q := strings.IndexByte(sql[i:], '%')
|
||||
if q == -1 {
|
||||
buf = append(buf, sql[i:]...)
|
||||
break
|
||||
}
|
||||
buf = append(buf, sql[i:i+q]...)
|
||||
i += q
|
||||
|
||||
ch := byte(0)
|
||||
if i+1 < len(sql) {
|
||||
ch = sql[i+1] // get the specifier
|
||||
}
|
||||
switch ch {
|
||||
case 'n':
|
||||
if argPos >= len(args) {
|
||||
return nil, errors.Errorf("missing arguments, need %d-th arg, but only got %d args", argPos+1, len(args))
|
||||
}
|
||||
arg := args[argPos]
|
||||
argPos++
|
||||
|
||||
v, ok := arg.(string)
|
||||
if !ok {
|
||||
return nil, errors.Errorf("expect a string identifier, got %v", arg)
|
||||
}
|
||||
buf = append(buf, '`')
|
||||
buf = append(buf, strings.ReplaceAll(v, "`", "``")...)
|
||||
buf = append(buf, '`')
|
||||
i++ // skip specifier
|
||||
case '?':
|
||||
if argPos >= len(args) {
|
||||
return nil, errors.Errorf("missing arguments, need %d-th arg, but only got %d args", argPos+1, len(args))
|
||||
}
|
||||
arg := args[argPos]
|
||||
argPos++
|
||||
|
||||
if arg == nil {
|
||||
buf = append(buf, "NULL"...)
|
||||
} else {
|
||||
switch v := arg.(type) {
|
||||
case int:
|
||||
buf = strconv.AppendInt(buf, int64(v), 10)
|
||||
case int8:
|
||||
buf = strconv.AppendInt(buf, int64(v), 10)
|
||||
case int16:
|
||||
buf = strconv.AppendInt(buf, int64(v), 10)
|
||||
case int32:
|
||||
buf = strconv.AppendInt(buf, int64(v), 10)
|
||||
case int64:
|
||||
buf = strconv.AppendInt(buf, v, 10)
|
||||
case uint:
|
||||
buf = strconv.AppendUint(buf, uint64(v), 10)
|
||||
case uint8:
|
||||
buf = strconv.AppendUint(buf, uint64(v), 10)
|
||||
case uint16:
|
||||
buf = strconv.AppendUint(buf, uint64(v), 10)
|
||||
case uint32:
|
||||
buf = strconv.AppendUint(buf, uint64(v), 10)
|
||||
case uint64:
|
||||
buf = strconv.AppendUint(buf, v, 10)
|
||||
case float32:
|
||||
buf = strconv.AppendFloat(buf, float64(v), 'g', -1, 32)
|
||||
case float64:
|
||||
buf = strconv.AppendFloat(buf, v, 'g', -1, 64)
|
||||
case bool:
|
||||
buf = appendSQLArgBool(buf, v)
|
||||
case time.Time:
|
||||
if v.IsZero() {
|
||||
buf = append(buf, "'0000-00-00'"...)
|
||||
} else {
|
||||
buf = append(buf, '\'')
|
||||
buf = v.AppendFormat(buf, "2006-01-02 15:04:05.999999")
|
||||
buf = append(buf, '\'')
|
||||
}
|
||||
case json.RawMessage:
|
||||
buf = append(buf, '\'')
|
||||
buf = escapeBytesBackslash(buf, v)
|
||||
buf = append(buf, '\'')
|
||||
case []byte:
|
||||
if v == nil {
|
||||
buf = append(buf, "NULL"...)
|
||||
} else {
|
||||
buf = append(buf, "_binary'"...)
|
||||
buf = escapeBytesBackslash(buf, v)
|
||||
buf = append(buf, '\'')
|
||||
}
|
||||
case string:
|
||||
buf = appendSQLArgString(buf, v)
|
||||
case []string:
|
||||
for i, k := range v {
|
||||
if i > 0 {
|
||||
buf = append(buf, ',')
|
||||
}
|
||||
buf = append(buf, '\'')
|
||||
buf = escapeStringBackslash(buf, k)
|
||||
buf = append(buf, '\'')
|
||||
}
|
||||
case []float32:
|
||||
for i, k := range v {
|
||||
if i > 0 {
|
||||
buf = append(buf, ',')
|
||||
}
|
||||
buf = strconv.AppendFloat(buf, float64(k), 'g', -1, 32)
|
||||
}
|
||||
case []float64:
|
||||
for i, k := range v {
|
||||
if i > 0 {
|
||||
buf = append(buf, ',')
|
||||
}
|
||||
buf = strconv.AppendFloat(buf, k, 'g', -1, 64)
|
||||
}
|
||||
default:
|
||||
// slow path based on reflection
|
||||
reflectTp := reflect.TypeOf(arg)
|
||||
kind := reflectTp.Kind()
|
||||
switch kind {
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
buf = strconv.AppendInt(buf, reflect.ValueOf(arg).Int(), 10)
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
buf = strconv.AppendUint(buf, reflect.ValueOf(arg).Uint(), 10)
|
||||
case reflect.Float32:
|
||||
buf = strconv.AppendFloat(buf, reflect.ValueOf(arg).Float(), 'g', -1, 32)
|
||||
case reflect.Float64:
|
||||
buf = strconv.AppendFloat(buf, reflect.ValueOf(arg).Float(), 'g', -1, 64)
|
||||
case reflect.Bool:
|
||||
buf = appendSQLArgBool(buf, reflect.ValueOf(arg).Bool())
|
||||
case reflect.String:
|
||||
buf = appendSQLArgString(buf, reflect.ValueOf(arg).String())
|
||||
default:
|
||||
return nil, errors.Errorf("unsupported %d-th argument: %v", argPos, arg)
|
||||
}
|
||||
}
|
||||
}
|
||||
i++ // skip specifier
|
||||
case '%':
|
||||
buf = append(buf, '%')
|
||||
i++ // skip specifier
|
||||
default:
|
||||
buf = append(buf, '%')
|
||||
}
|
||||
}
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
func EscapeString(s string) string {
|
||||
buf := make([]byte, 0, len(s))
|
||||
return string(escapeStringBackslash(buf, s))
|
||||
}
|
||||
|
||||
func appendSQLArgBool(buf []byte, v bool) []byte {
|
||||
if v {
|
||||
return append(buf, '1')
|
||||
}
|
||||
return append(buf, '0')
|
||||
}
|
||||
|
||||
func appendSQLArgString(buf []byte, s string) []byte {
|
||||
buf = append(buf, '\'')
|
||||
buf = escapeStringBackslash(buf, s)
|
||||
buf = append(buf, '\'')
|
||||
return buf
|
||||
}
|
||||
|
||||
func escapeStringBackslash(buf []byte, v string) []byte {
|
||||
return escapeBytesBackslash(buf, unsafe.Slice(unsafe.StringData(v), len(v)))
|
||||
}
|
||||
|
||||
// escapeBytesBackslash will escape []byte into the buffer, with backslash.
|
||||
func escapeBytesBackslash(buf []byte, v []byte) []byte {
|
||||
pos := len(buf)
|
||||
buf = reserveBuffer(buf, len(v)*2)
|
||||
|
||||
for _, c := range v {
|
||||
switch c {
|
||||
case '\x00':
|
||||
buf[pos] = '\\'
|
||||
buf[pos+1] = '0'
|
||||
pos += 2
|
||||
case '\n':
|
||||
buf[pos] = '\\'
|
||||
buf[pos+1] = 'n'
|
||||
pos += 2
|
||||
case '\r':
|
||||
buf[pos] = '\\'
|
||||
buf[pos+1] = 'r'
|
||||
pos += 2
|
||||
case '\x1a':
|
||||
buf[pos] = '\\'
|
||||
buf[pos+1] = 'Z'
|
||||
pos += 2
|
||||
case '\'':
|
||||
buf[pos] = '\\'
|
||||
buf[pos+1] = '\''
|
||||
pos += 2
|
||||
case '"':
|
||||
buf[pos] = '\\'
|
||||
buf[pos+1] = '"'
|
||||
pos += 2
|
||||
case '\\':
|
||||
buf[pos] = '\\'
|
||||
buf[pos+1] = '\\'
|
||||
pos += 2
|
||||
default:
|
||||
buf[pos] = c
|
||||
pos++
|
||||
}
|
||||
}
|
||||
|
||||
return buf[:pos]
|
||||
}
|
||||
|
||||
func reserveBuffer(buf []byte, appendSize int) []byte {
|
||||
newSize := len(buf) + appendSize
|
||||
if cap(buf) < newSize {
|
||||
newBuf := make([]byte, len(buf)*2+appendSize)
|
||||
copy(newBuf, buf)
|
||||
buf = newBuf
|
||||
}
|
||||
return buf[:newSize]
|
||||
}
|
Loading…
Reference in a new issue