[wip] builder

Signed-off-by: Anton Nesterov <anton@demiurg.io>
This commit is contained in:
Anton Nesterov 2024-08-09 21:14:28 +02:00
parent fb13fcbece
commit 1db30b92c2
No known key found for this signature in database
GPG key ID: 59121E8AE2851FB5
20 changed files with 393 additions and 33 deletions

View file

@ -14,6 +14,10 @@ type SQLiteContext struct {
} }
func (c SQLiteContext) New(opts CtxOpts) Context { func (c SQLiteContext) New(opts CtxOpts) Context {
tn := opts["TableName"]
if tn == "" {
tn = c.TableName
}
ta := opts["TableAlias"] ta := opts["TableAlias"]
if ta == "" { if ta == "" {
ta = c.TableAlias ta = c.TableAlias
@ -23,6 +27,7 @@ func (c SQLiteContext) New(opts CtxOpts) Context {
fn = c.FieldName fn = c.FieldName
} }
return SQLiteContext{ return SQLiteContext{
TableName: tn,
TableAlias: ta, TableAlias: ta,
FieldName: fn, FieldName: fn,
} }
@ -42,6 +47,16 @@ func (c SQLiteContext) GetFieldName() string {
return c.FieldName return c.FieldName
} }
func (c SQLiteContext) GetColumnName(key string) string {
if strings.Contains(key, ".") {
return key
}
if c.TableAlias != "" {
return c.TableAlias + "." + key
}
return key
}
func (c SQLiteContext) NormalizeValue(value interface{}) interface{} { func (c SQLiteContext) NormalizeValue(value interface{}) interface{} {
str, ok := value.(string) str, ok := value.(string)
if utils.IsSQLFunction(str) { if utils.IsSQLFunction(str) {

View file

@ -6,5 +6,6 @@ type Context interface {
New(opts CtxOpts) Context New(opts CtxOpts) Context
GetTableName() string GetTableName() string
GetFieldName() string GetFieldName() string
GetColumnName(key string) string
NormalizeValue(interface{}) interface{} NormalizeValue(interface{}) interface{}
} }

View file

@ -1,37 +1,91 @@
package builder package builder
import "strings"
type SQLParts struct { type SQLParts struct {
operation string Operation string
selectExp string From string
fromExp string FieldsExp string
fiterExp string FromExp string
joinExp []string HavingExp string
groupExp string FiterExp string
orderExp string JoinExps []string
limitExp string GroupExp string
OrderExp string
LimitExp string
updateExp string updateExp string
upsertExp string upsertExp string
} }
type Builder struct { type Builder struct {
parts SQLParts Parts SQLParts
TableName string
TableAlias string
Ctx Context
} }
func New() *Builder { func New(ctx Context) *Builder {
return &Builder{} return &Builder{
Parts: SQLParts{
Operation: "SELECT",
From: "FROM",
},
Ctx: ctx,
}
} }
func (b *Builder) In(selectExp string) *Builder { func (b *Builder) In(table string) *Builder {
b.parts.selectExp = selectExp b.TableName, b.TableAlias = getTableAlias(table)
b.Parts.FromExp = table
b.Ctx = b.Ctx.New(CtxOpts{
"TableName": b.TableName,
"TableAlias": b.TableAlias,
})
return b return b
} }
func (b *Builder) Find(fromExp string) *Builder { func (b *Builder) Find(query Find) *Builder {
b.parts.fromExp = fromExp b.Parts.FiterExp = covertFind(
b.Ctx,
query,
)
if b.Parts.Operation == "" {
b.Parts.Operation = "SELECT"
}
if b.Parts.HavingExp == "" {
b.Parts.HavingExp = "WHERE"
}
if b.Parts.FieldsExp == "" {
b.Parts.FieldsExp = "*"
}
return b return b
} }
func (b *Builder) Join(fiterExp string) *Builder { func (b *Builder) Join(joins ...interface{}) *Builder {
b.parts.fiterExp = fiterExp b.Parts.JoinExps = convertJoin(b.Ctx, joins...)
return b return b
} }
func (b *Builder) Sql() string {
operation := b.Parts.Operation
switch {
case operation == "SELECT" || operation == "SELECT DISTINCT":
return unspace(strings.Join([]string{
b.Parts.Operation,
b.Parts.FieldsExp,
b.Parts.From,
b.Parts.FromExp,
strings.Join(
b.Parts.JoinExps,
" ",
),
b.Parts.GroupExp,
b.Parts.HavingExp,
b.Parts.FiterExp,
b.Parts.OrderExp,
b.Parts.LimitExp,
}, " "))
default:
return ""
}
}

View file

@ -0,0 +1,36 @@
package builder
import (
"testing"
)
func TestBuilderFind(t *testing.T) {
db := New(SQLiteContext{})
db.In("table t").Find(Query{
"field": "value",
"a": 1,
})
expect := "SELECT * FROM table t WHERE t.a = 1 AND t.field = 'value'"
if db.Sql() != expect {
t.Errorf(`Expected: "%s", Got: %s`, expect, db.Sql())
}
}
func TestBuilderJoin(t *testing.T) {
db := New(SQLiteContext{})
db.In("table t")
db.Find(Query{
"field": "value",
"a": 1,
})
db.Join(Join{
For: "table2 t2",
Do: Query{
"t2.field": "t.field",
},
})
expect := "SELECT * FROM table t JOIN table2 t2 ON t2.field = t.field WHERE t.a = 1 AND t.field = 'value'"
if db.Sql() != expect {
t.Errorf(`Expected: "%s", Got: %s`, expect, db.Sql())
}
}

View file

@ -0,0 +1,13 @@
package builder
import (
"fmt"
"strings"
utils "l12.xyz/dal/utils"
)
func ConvertConflict(ctx Context, fields ...string) string {
keys := utils.Map(fields, ctx.GetColumnName)
return fmt.Sprintf("ON CONFLICT (%s) DO", strings.Join(keys, ","))
}

View file

@ -0,0 +1,18 @@
package builder
import (
"testing"
)
func TestConvertConflict(t *testing.T) {
ctx := SQLiteContext{
TableName: "test",
TableAlias: "t",
FieldName: "test",
}
result := ConvertConflict(ctx, "a", "b", "tb.c")
if result != `ON CONFLICT (t.a,t.b,tb.c) DO` {
t.Errorf(`Expected "ON CONFLICT (t.a,t.b,tb.c) DO", got %s`, result)
}
}

View file

@ -7,7 +7,7 @@ import (
filters "l12.xyz/dal/filters" filters "l12.xyz/dal/filters"
) )
func CovertFind(ctx Context, find Find) string { func covertFind(ctx Context, find Find) string {
return covert_find(ctx, find, "") return covert_find(ctx, find, "")
} }
@ -15,7 +15,7 @@ func covert_find(ctx Context, find Find, join string) string {
if join == "" { if join == "" {
join = " AND " join = " AND "
} }
keys := AggregateSortedKeys([]Map{find}) keys := aggregateSortedKeys([]Map{find})
expressions := []string{} expressions := []string{}
for _, key := range keys { for _, key := range keys {
value := find[key] value := find[key]

View file

@ -14,7 +14,7 @@ func TestConvertFind(t *testing.T) {
ctx := SQLiteContext{ ctx := SQLiteContext{
TableAlias: "t", TableAlias: "t",
} }
result := CovertFind(ctx, find) result := covertFind(ctx, find)
if result == `t.exp > 1 AND t.impl = '1'` { if result == `t.exp > 1 AND t.impl = '1'` {
return return
} }
@ -38,7 +38,7 @@ func TestConvertFindAnd(t *testing.T) {
ctx := SQLiteContext{ ctx := SQLiteContext{
TableAlias: "t", TableAlias: "t",
} }
result := CovertFind(ctx, find) result := covertFind(ctx, find)
if result == `(t.a > 1 AND t.b < 10)` { if result == `(t.a > 1 AND t.b < 10)` {
return return
} }
@ -62,7 +62,7 @@ func TestConvertFindOr(t *testing.T) {
ctx := SQLiteContext{ ctx := SQLiteContext{
TableAlias: "t", TableAlias: "t",
} }
result := CovertFind(ctx, find) result := covertFind(ctx, find)
if result == `(t.a > 1 OR t.b < 10)` { if result == `(t.a > 1 OR t.b < 10)` {
return return
} }

View file

@ -0,0 +1,12 @@
package builder
import (
"strings"
"l12.xyz/dal/utils"
)
func ConvertGroup(ctx Context, keys []string) string {
set := utils.Map(keys, ctx.GetColumnName)
return strings.Join(set, ", ")
}

View file

@ -11,7 +11,7 @@ type InsertData struct {
} }
func ConvertInsert(ctx Context, inserts []Map) (InsertData, error) { func ConvertInsert(ctx Context, inserts []Map) (InsertData, error) {
keys := AggregateSortedKeys(inserts) keys := aggregateSortedKeys(inserts)
placeholder := make([]string, 0) placeholder := make([]string, 0)
for range keys { for range keys {
placeholder = append(placeholder, "?") placeholder = append(placeholder, "?")

View file

@ -1,7 +1,6 @@
package builder package builder
import ( import (
"fmt"
"testing" "testing"
) )
@ -20,8 +19,8 @@ func TestConvertInsert(t *testing.T) {
t.Errorf(`Expected "INSERT INTO test (a,b,c) VALUES (?,?,?)", got %s`, result.Statement) t.Errorf(`Expected "INSERT INTO test (a,b,c) VALUES (?,?,?)", got %s`, result.Statement)
} }
for _, r := range result.Values { // for _, r := range result.Values {
fmt.Println(r) // fmt.Println(r)
} // }
} }

View file

@ -15,7 +15,7 @@ func (j Join) Convert(ctx Context) string {
if j.For == "" { if j.For == "" {
return "" return ""
} }
filter := CovertFind(ctx, j.Do) filter := covertFind(ctx, j.Do)
var as string = "" var as string = ""
if j.As != "" { if j.As != "" {
as = fmt.Sprintf("%s ", j.As) as = fmt.Sprintf("%s ", j.As)
@ -23,7 +23,7 @@ func (j Join) Convert(ctx Context) string {
return as + fmt.Sprintf("JOIN %s ON %s", j.For, filter) return as + fmt.Sprintf("JOIN %s ON %s", j.For, filter)
} }
func ConvertJoin(ctx Context, joins ...interface{}) []string { func convertJoin(ctx Context, joins ...interface{}) []string {
var result []string var result []string
for _, join := range joins { for _, join := range joins {
jstr, ok := join.(string) jstr, ok := join.(string)

View file

@ -39,7 +39,7 @@ func TestConvertJoin(t *testing.T) {
ctx := SQLiteContext{ ctx := SQLiteContext{
TableAlias: "t", TableAlias: "t",
} }
result := ConvertJoin(ctx, joins...) result := convertJoin(ctx, joins...)
if result[1] != `JOIN artist a ON a.impl = t.impl` { if result[1] != `JOIN artist a ON a.impl = t.impl` {
t.Errorf(`Expected "JOIN artist a ON a.impl = t.impl", got %s`, result[1]) t.Errorf(`Expected "JOIN artist a ON a.impl = t.impl", got %s`, result[1])
} }
@ -56,7 +56,7 @@ func TestConvertMap(t *testing.T) {
ctx := SQLiteContext{ ctx := SQLiteContext{
TableAlias: "t", TableAlias: "t",
} }
result := ConvertJoin(ctx, joins...) result := convertJoin(ctx, joins...)
if result[0] != `LEFT JOIN artist a ON a.impl = t.impl` { if result[0] != `LEFT JOIN artist a ON a.impl = t.impl` {
t.Errorf(`Expected "LEFT JOIN artist a ON a.impl = t.impl", got %s`, result[0]) t.Errorf(`Expected "LEFT JOIN artist a ON a.impl = t.impl", got %s`, result[0])
} }

View file

@ -0,0 +1,41 @@
package builder
import "fmt"
type Pagination struct {
Limit interface{}
Offset interface{}
}
func ConvertLimit(limit int) string {
if limit == 0 {
return ""
}
return fmt.Sprintf("LIMIT %d", limit)
}
func ConvertOffset(offset int) string {
if offset == 0 {
return ""
}
return fmt.Sprintf("OFFSET %d", offset)
}
func ConvertLimitOffset(limit, offset int) string {
if limit == 0 && offset == 0 {
return ""
}
return fmt.Sprintf("LIMIT %d OFFSET %d", limit, offset)
}
func ConvertPagination(p Pagination) string {
limit := ""
if p.Limit != nil {
limit = fmt.Sprintf("LIMIT %d", p.Limit)
}
offset := ""
if p.Offset != nil {
offset = fmt.Sprintf("OFFSET %d", p.Offset)
}
return fmt.Sprintf("%s %s", limit, offset)
}

View file

@ -0,0 +1,58 @@
package builder
import (
"fmt"
"strings"
)
func ConvertSort(ctx Context, sort Map) (string, error) {
if sort == nil {
return "", nil
}
keys := aggregateSortedKeys([]Map{sort})
expressions := make([]string, 0)
for _, key := range keys {
name := ctx.GetColumnName(key)
order := normalize_order(sort[key])
if order != "" {
order = " " + order
}
expressions = append(expressions, name+order)
}
return fmt.Sprintf("ORDER BY %s", strings.Join(expressions, ", ")), nil
}
func normalize_order(order interface{}) string {
if order == nil {
return ""
}
orderInt, ok := order.(int)
if ok {
if orderInt == 1 {
return "ASC"
}
if orderInt == -1 {
return "DESC"
}
}
orderStr, ok := order.(string)
if !ok {
return ""
}
if orderStr == "" {
return ""
}
if orderStr == "1" {
return "ASC"
}
if orderStr == "-1" {
return "DESC"
}
if strings.ToUpper(orderStr) == "ASC" {
return "ASC"
}
if strings.ToUpper(orderStr) == "DESC" {
return "DESC"
}
return ""
}

View file

@ -0,0 +1,24 @@
package builder
import (
"testing"
)
func TestConvertSort(t *testing.T) {
ctx := SQLiteContext{
TableAlias: "t",
FieldName: "test",
}
result, err := ConvertSort(ctx, Map{
"a": -1,
"c": "desc",
"b": 1,
"d": nil,
})
if err != nil {
t.Error(err)
}
if result != `ORDER BY t.a DESC, t.b ASC, t.c DESC, t.d` {
t.Errorf("Expected ORDER BY t.a DESC, t.b ASC, t.c DESC, t.d, got %s", result)
}
}

View file

@ -0,0 +1,29 @@
package builder
import (
"fmt"
"strings"
)
type UpdateData struct {
Statement string
Values []interface{}
}
func ConvertUpdate(ctx Context, updates Map) (UpdateData, error) {
keys := aggregateSortedKeys([]Map{updates})
set := make([]string, 0)
values := make([]interface{}, 0)
for _, key := range keys {
set = append(set, fmt.Sprintf("%s = ?", key))
values = append(values, updates[key])
}
sfmt := fmt.Sprintf(
"UPDATE %s SET %s", ctx.GetTableName(),
strings.Join(set, ","),
)
return UpdateData{
Statement: sfmt,
Values: values,
}, nil
}

View file

@ -0,0 +1,24 @@
package builder
import (
"testing"
)
func TestConvertUpdate(t *testing.T) {
ctx := SQLiteContext{
TableName: "test",
TableAlias: "t",
FieldName: "test",
}
result, err := ConvertUpdate(ctx, Map{
"c": nil,
"a": 1,
"b": 2,
})
if err != nil {
t.Error(err)
}
if result.Statement != `UPDATE test SET a = ?,b = ?,c = ?` {
t.Errorf(`Expected "UPDATE test SET a = ?,b = ?,c = ?", got %s`, result.Statement)
}
}

View file

@ -0,0 +1,17 @@
package builder
import (
"fmt"
"strings"
)
func ConvertUpsert(keys []string) string {
set := make([]string, 0)
for _, key := range keys {
set = append(set, fmt.Sprintf("%s = EXCLUDED.%s", key, key))
}
return fmt.Sprintf(
"UPDATE SET %s",
strings.Join(set, ", "),
)
}

View file

@ -1,8 +1,11 @@
package builder package builder
import "sort" import (
"sort"
"strings"
)
func AggregateSortedKeys(maps []Map) []string { func aggregateSortedKeys(maps []Map) []string {
set := make(map[string]int) set := make(map[string]int)
keys := make([]string, 0) keys := make([]string, 0)
for _, item := range maps { for _, item := range maps {
@ -18,3 +21,19 @@ func AggregateSortedKeys(maps []Map) []string {
sort.Strings(keys) sort.Strings(keys)
return keys return keys
} }
func getTableAlias(tableName string) (string, string) {
if !strings.Contains(tableName, " ") {
return tableName, ""
}
if strings.Contains(strings.ToLower(tableName), " as ") {
data := strings.Split(strings.ToLower(tableName), " as ")
return data[0], data[1]
}
data := strings.Split(tableName, " ")
return data[0], data[1]
}
func unspace(s string) string {
return strings.Join(strings.Fields(s), " ")
}