[wip] builder

Signed-off-by: Anton Nesterov <anton@demiurg.io>
This commit is contained in:
Anton Nesterov 2024-08-10 00:33:10 +02:00
parent 1db30b92c2
commit 05f155ecc0
No known key found for this signature in database
GPG key ID: 59121E8AE2851FB5
16 changed files with 185 additions and 88 deletions

View file

@ -1,6 +1,9 @@
package builder package builder
import "strings" import (
"fmt"
"strings"
)
type SQLParts struct { type SQLParts struct {
Operation string Operation string
@ -13,8 +16,9 @@ type SQLParts struct {
GroupExp string GroupExp string
OrderExp string OrderExp string
LimitExp string LimitExp string
updateExp string OffsetExp string
upsertExp string Insert InsertData
Update UpdateData
} }
type Builder struct { type Builder struct {
@ -61,12 +65,95 @@ func (b *Builder) Find(query Find) *Builder {
return b return b
} }
func (b *Builder) Select(fields ...Map) *Builder {
fieldsExp, err := convertFields(fields)
if err != nil {
return b
}
b.Parts.FieldsExp = fieldsExp
return b
}
func (b *Builder) Fields(fields ...Map) *Builder {
return b.Select(fields...)
}
func (b *Builder) Join(joins ...interface{}) *Builder { func (b *Builder) Join(joins ...interface{}) *Builder {
b.Parts.JoinExps = convertJoin(b.Ctx, joins...) b.Parts.JoinExps = convertJoin(b.Ctx, joins...)
return b return b
} }
func (b *Builder) Sql() string { func (b *Builder) Group(keys ...string) *Builder {
b.Parts.HavingExp = "HAVING"
b.Parts.GroupExp = convertGroup(b.Ctx, keys)
return b
}
func (b *Builder) Sort(sort Map) *Builder {
b.Parts.OrderExp, _ = convertSort(b.Ctx, sort)
return b
}
func (b *Builder) Limit(limit int) *Builder {
b.Parts.LimitExp = fmt.Sprintf("LIMIT %d", limit)
return b
}
func (b *Builder) Offset(offset int) *Builder {
b.Parts.OffsetExp = fmt.Sprintf("OFFSET %d", offset)
return b
}
func (b *Builder) Delete() *Builder {
b.Parts.Operation = "DELETE"
return b
}
func (b *Builder) Insert(inserts []Map) *Builder {
insertData, _ := convertInsert(b.Ctx, inserts)
b.Parts = SQLParts{
Operation: "INSERT INTO",
Insert: insertData,
}
return b
}
func (b *Builder) Set(updates Map) *Builder {
updateData := convertUpdate(b.Ctx, updates)
b.Parts = SQLParts{
Operation: "UPDATE",
Update: updateData,
}
return b
}
func (b *Builder) Update(updates Map) *Builder {
return b.Set(updates)
}
func (b *Builder) OnConflict(fields ...string) *Builder {
if b.Parts.Operation == "UPDATE" {
b.Parts.Update.Upsert = convertConflict(b.Ctx, fields...)
b.Parts.Update.UpsertExp = "DO NOTHING"
}
return b
}
func (b *Builder) DoUpdate(fields ...string) *Builder {
if b.Parts.Operation == "UPDATE" {
b.Parts.Update.UpsertExp = convertUpsert(fields)
}
return b
}
func (b *Builder) DoNothing() *Builder {
if b.Parts.Operation == "UPDATE" {
b.Parts.Update.UpsertExp = "DO NOTHING"
}
return b
}
func (b *Builder) Sql() (string, []interface{}) {
operation := b.Parts.Operation operation := b.Parts.Operation
switch { switch {
case operation == "SELECT" || operation == "SELECT DISTINCT": case operation == "SELECT" || operation == "SELECT DISTINCT":
@ -84,8 +171,28 @@ func (b *Builder) Sql() string {
b.Parts.FiterExp, b.Parts.FiterExp,
b.Parts.OrderExp, b.Parts.OrderExp,
b.Parts.LimitExp, b.Parts.LimitExp,
}, " ")) b.Parts.OffsetExp,
}, " ")), nil
case operation == "DELETE":
return unspace(strings.Join([]string{
b.Parts.Operation,
b.Parts.From,
b.Parts.FromExp,
b.Parts.HavingExp,
b.Parts.FiterExp,
b.Parts.OrderExp,
b.Parts.LimitExp,
b.Parts.OffsetExp,
}, " ")), nil
case operation == "INSERT INTO":
return b.Parts.Insert.Statement, b.Parts.Insert.Values
case operation == "UPDATE":
return unspace(strings.Join([]string{
b.Parts.Update.Statement,
b.Parts.Update.Upsert,
b.Parts.Update.UpsertExp,
}, " ")), b.Parts.Update.Values
default: default:
return "" return "", nil
} }
} }

View file

@ -11,8 +11,46 @@ func TestBuilderFind(t *testing.T) {
"a": 1, "a": 1,
}) })
expect := "SELECT * FROM table t WHERE t.a = 1 AND t.field = 'value'" expect := "SELECT * FROM table t WHERE t.a = 1 AND t.field = 'value'"
if db.Sql() != expect { result, _ := db.Sql()
t.Errorf(`Expected: "%s", Got: %s`, expect, db.Sql()) if result != expect {
t.Errorf(`Expected: "%s", Got: %s`, expect, result)
}
}
func TestBuilderFields(t *testing.T) {
db := New(SQLiteContext{})
db.In("table t")
db.Find(Query{
"field": "value",
"a": 1,
})
db.Fields(Map{
"t.field": "field",
"t.a": 1,
})
expect := "SELECT t.a, t.field AS field FROM table t WHERE t.a = 1 AND t.field = 'value'"
result, _ := db.Sql()
if result != expect {
t.Errorf(`Expected: "%s", Got: %s`, expect, result)
}
}
func TestBuilderGroup(t *testing.T) {
db := New(SQLiteContext{})
db.In("table t")
db.Find(Query{
"field": Is{
"$gt": 1,
},
})
db.Fields(Map{
"SUM(t.field)": "field",
})
db.Group("field")
expect := "SELECT SUM(t.field) AS field FROM table t GROUP BY t.field HAVING t.field > 1"
result, _ := db.Sql()
if result != expect {
t.Errorf(`Expected: "%s", Got: %s`, expect, result)
} }
} }
@ -30,7 +68,8 @@ func TestBuilderJoin(t *testing.T) {
}, },
}) })
expect := "SELECT * FROM table t JOIN table2 t2 ON t2.field = t.field WHERE t.a = 1 AND t.field = 'value'" expect := "SELECT * FROM table t JOIN table2 t2 ON t2.field = t.field WHERE t.a = 1 AND t.field = 'value'"
if db.Sql() != expect { result, _ := db.Sql()
t.Errorf(`Expected: "%s", Got: %s`, expect, db.Sql()) if result != expect {
t.Errorf(`Expected: "%s", Got: %s`, expect, result)
} }
} }

View file

@ -7,7 +7,7 @@ import (
utils "l12.xyz/dal/utils" utils "l12.xyz/dal/utils"
) )
func ConvertConflict(ctx Context, fields ...string) string { func convertConflict(ctx Context, fields ...string) string {
keys := utils.Map(fields, ctx.GetColumnName) keys := utils.Map(fields, ctx.GetColumnName)
return fmt.Sprintf("ON CONFLICT (%s) DO", strings.Join(keys, ",")) return fmt.Sprintf("ON CONFLICT (%s)", strings.Join(keys, ","))
} }

View file

@ -10,9 +10,9 @@ func TestConvertConflict(t *testing.T) {
TableAlias: "t", TableAlias: "t",
FieldName: "test", FieldName: "test",
} }
result := ConvertConflict(ctx, "a", "b", "tb.c") result := convertConflict(ctx, "a", "b", "tb.c")
if result != `ON CONFLICT (t.a,t.b,tb.c) DO` { if result != `ON CONFLICT (t.a,t.b,tb.c)` {
t.Errorf(`Expected "ON CONFLICT (t.a,t.b,tb.c) DO", got %s`, result) t.Errorf(`Expected "ON CONFLICT (t.a,t.b,tb.c)", got %s`, result)
} }
} }

View file

@ -5,10 +5,13 @@ import (
"strings" "strings"
) )
func ConvertFields(ctx Context, fields []Map) (string, error) { func convertFields(fields []Map) (string, error) {
var expressions []string var expressions []string
for _, fieldAssoc := range fields { for _, fieldAssoc := range fields {
for field, as := range fieldAssoc { keys := aggregateSortedKeys([]Map{fieldAssoc})
for _, key := range keys {
field := key
as := fieldAssoc[key]
asBool, ok := as.(bool) asBool, ok := as.(bool)
if ok { if ok {
if asBool { if asBool {

View file

@ -5,11 +5,7 @@ import (
) )
func TestConvertFieldsBool(t *testing.T) { func TestConvertFieldsBool(t *testing.T) {
ctx := SQLiteContext{ result, err := convertFields([]Map{
TableAlias: "t",
FieldName: "test",
}
result, err := ConvertFields(ctx, []Map{
{"test": true}, {"test": true},
{"test2": false}, {"test2": false},
}) })
@ -22,11 +18,7 @@ func TestConvertFieldsBool(t *testing.T) {
} }
func TestConvertFieldsInt(t *testing.T) { func TestConvertFieldsInt(t *testing.T) {
ctx := SQLiteContext{ result, err := convertFields([]Map{
TableAlias: "t",
FieldName: "test",
}
result, err := ConvertFields(ctx, []Map{
{"test": 0}, {"test": 0},
{"test2": 1}, {"test2": 1},
}) })
@ -39,11 +31,7 @@ func TestConvertFieldsInt(t *testing.T) {
} }
func TestConvertFieldsStr(t *testing.T) { func TestConvertFieldsStr(t *testing.T) {
ctx := SQLiteContext{ result, err := convertFields([]Map{
TableAlias: "t",
FieldName: "test",
}
result, err := ConvertFields(ctx, []Map{
{"t.test": "Test"}, {"t.test": "Test"},
{"SUM(t.test, t.int)": "Sum"}, {"SUM(t.test, t.int)": "Sum"},
}) })

View file

@ -1,12 +1,13 @@
package builder package builder
import ( import (
"fmt"
"strings" "strings"
"l12.xyz/dal/utils" "l12.xyz/dal/utils"
) )
func ConvertGroup(ctx Context, keys []string) string { func convertGroup(ctx Context, keys []string) string {
set := utils.Map(keys, ctx.GetColumnName) set := utils.Map(keys, ctx.GetColumnName)
return strings.Join(set, ", ") return fmt.Sprintf("GROUP BY %s", strings.Join(set, ", "))
} }

View file

@ -10,7 +10,7 @@ type InsertData struct {
Values []interface{} Values []interface{}
} }
func ConvertInsert(ctx Context, inserts []Map) (InsertData, error) { func convertInsert(ctx Context, inserts []Map) (InsertData, error) {
keys := aggregateSortedKeys(inserts) keys := aggregateSortedKeys(inserts)
placeholder := make([]string, 0) placeholder := make([]string, 0)
for range keys { for range keys {
@ -27,7 +27,8 @@ func ConvertInsert(ctx Context, inserts []Map) (InsertData, error) {
} }
sfmt := fmt.Sprintf( sfmt := fmt.Sprintf(
"INSERT INTO %s (%s) VALUES (%s)", ctx.GetTableName(), "INSERT INTO %s (%s) VALUES (%s)",
ctx.GetTableName(),
strings.Join(keys, ","), strings.Join(keys, ","),
strings.Join(placeholder, ","), strings.Join(placeholder, ","),
) )

View file

@ -13,7 +13,7 @@ func TestConvertInsert(t *testing.T) {
{"a": "1", "b": 2}, {"a": "1", "b": 2},
{"b": 2, "a": "1", "c": 3}, {"b": 2, "a": "1", "c": 3},
} }
result, _ := ConvertInsert(ctx, insert) result, _ := convertInsert(ctx, insert)
if result.Statement != `INSERT INTO test (a,b,c) VALUES (?,?,?)` { if result.Statement != `INSERT INTO test (a,b,c) VALUES (?,?,?)` {
t.Errorf(`Expected "INSERT INTO test (a,b,c) VALUES (?,?,?)", got %s`, result.Statement) t.Errorf(`Expected "INSERT INTO test (a,b,c) VALUES (?,?,?)", got %s`, result.Statement)

View file

@ -1,41 +0,0 @@
package builder
import "fmt"
type Pagination struct {
Limit interface{}
Offset interface{}
}
func ConvertLimit(limit int) string {
if limit == 0 {
return ""
}
return fmt.Sprintf("LIMIT %d", limit)
}
func ConvertOffset(offset int) string {
if offset == 0 {
return ""
}
return fmt.Sprintf("OFFSET %d", offset)
}
func ConvertLimitOffset(limit, offset int) string {
if limit == 0 && offset == 0 {
return ""
}
return fmt.Sprintf("LIMIT %d OFFSET %d", limit, offset)
}
func ConvertPagination(p Pagination) string {
limit := ""
if p.Limit != nil {
limit = fmt.Sprintf("LIMIT %d", p.Limit)
}
offset := ""
if p.Offset != nil {
offset = fmt.Sprintf("OFFSET %d", p.Offset)
}
return fmt.Sprintf("%s %s", limit, offset)
}

View file

@ -5,7 +5,7 @@ import (
"strings" "strings"
) )
func ConvertSort(ctx Context, sort Map) (string, error) { func convertSort(ctx Context, sort Map) (string, error) {
if sort == nil { if sort == nil {
return "", nil return "", nil
} }

View file

@ -9,7 +9,7 @@ func TestConvertSort(t *testing.T) {
TableAlias: "t", TableAlias: "t",
FieldName: "test", FieldName: "test",
} }
result, err := ConvertSort(ctx, Map{ result, err := convertSort(ctx, Map{
"a": -1, "a": -1,
"c": "desc", "c": "desc",
"b": 1, "b": 1,

View file

@ -7,10 +7,12 @@ import (
type UpdateData struct { type UpdateData struct {
Statement string Statement string
Upsert string
UpsertExp string
Values []interface{} Values []interface{}
} }
func ConvertUpdate(ctx Context, updates Map) (UpdateData, error) { func convertUpdate(ctx Context, updates Map) UpdateData {
keys := aggregateSortedKeys([]Map{updates}) keys := aggregateSortedKeys([]Map{updates})
set := make([]string, 0) set := make([]string, 0)
values := make([]interface{}, 0) values := make([]interface{}, 0)
@ -25,5 +27,5 @@ func ConvertUpdate(ctx Context, updates Map) (UpdateData, error) {
return UpdateData{ return UpdateData{
Statement: sfmt, Statement: sfmt,
Values: values, Values: values,
}, nil }
} }

View file

@ -10,14 +10,11 @@ func TestConvertUpdate(t *testing.T) {
TableAlias: "t", TableAlias: "t",
FieldName: "test", FieldName: "test",
} }
result, err := ConvertUpdate(ctx, Map{ result := convertUpdate(ctx, Map{
"c": nil, "c": nil,
"a": 1, "a": 1,
"b": 2, "b": 2,
}) })
if err != nil {
t.Error(err)
}
if result.Statement != `UPDATE test SET a = ?,b = ?,c = ?` { if result.Statement != `UPDATE test SET a = ?,b = ?,c = ?` {
t.Errorf(`Expected "UPDATE test SET a = ?,b = ?,c = ?", got %s`, result.Statement) t.Errorf(`Expected "UPDATE test SET a = ?,b = ?,c = ?", got %s`, result.Statement)
} }

View file

@ -5,13 +5,13 @@ import (
"strings" "strings"
) )
func ConvertUpsert(keys []string) string { func convertUpsert(keys []string) string {
set := make([]string, 0) set := make([]string, 0)
for _, key := range keys { for _, key := range keys {
set = append(set, fmt.Sprintf("%s = EXCLUDED.%s", key, key)) set = append(set, fmt.Sprintf("%s = EXCLUDED.%s", key, key))
} }
return fmt.Sprintf( return fmt.Sprintf(
"UPDATE SET %s", "DO UPDATE SET %s",
strings.Join(set, ", "), strings.Join(set, ", "),
) )
} }