From 05f155ecc0f84714fd4f6e75e31aeb82388bc144 Mon Sep 17 00:00:00 2001 From: Anton Nesterov Date: Sat, 10 Aug 2024 00:33:10 +0200 Subject: [PATCH] [wip] builder Signed-off-by: Anton Nesterov --- .../{SQLiteCtx.go => SQLiteContext.go} | 0 pkg/builder/Builder.go | 119 +++++++++++++++++- pkg/builder/Builder_test.go | 47 ++++++- pkg/builder/convert_conflict.go | 4 +- pkg/builder/convert_conflict_test.go | 6 +- pkg/builder/convert_fields.go | 7 +- pkg/builder/convert_fields_test.go | 18 +-- pkg/builder/convert_group.go | 5 +- pkg/builder/convert_insert.go | 5 +- pkg/builder/convert_insert_test.go | 2 +- pkg/builder/convert_limit.go | 41 ------ pkg/builder/convert_sort.go | 2 +- pkg/builder/convert_sort_test.go | 2 +- pkg/builder/convert_update.go | 6 +- pkg/builder/convert_update_test.go | 5 +- pkg/builder/convert_upsert.go | 4 +- 16 files changed, 185 insertions(+), 88 deletions(-) rename pkg/adapter/{SQLiteCtx.go => SQLiteContext.go} (100%) delete mode 100644 pkg/builder/convert_limit.go diff --git a/pkg/adapter/SQLiteCtx.go b/pkg/adapter/SQLiteContext.go similarity index 100% rename from pkg/adapter/SQLiteCtx.go rename to pkg/adapter/SQLiteContext.go diff --git a/pkg/builder/Builder.go b/pkg/builder/Builder.go index 8f05f32..e0223ea 100644 --- a/pkg/builder/Builder.go +++ b/pkg/builder/Builder.go @@ -1,6 +1,9 @@ package builder -import "strings" +import ( + "fmt" + "strings" +) type SQLParts struct { Operation string @@ -13,8 +16,9 @@ type SQLParts struct { GroupExp string OrderExp string LimitExp string - updateExp string - upsertExp string + OffsetExp string + Insert InsertData + Update UpdateData } type Builder struct { @@ -61,12 +65,95 @@ func (b *Builder) Find(query Find) *Builder { return b } +func (b *Builder) Select(fields ...Map) *Builder { + fieldsExp, err := convertFields(fields) + if err != nil { + return b + } + b.Parts.FieldsExp = fieldsExp + return b +} + +func (b *Builder) Fields(fields ...Map) *Builder { + return b.Select(fields...) +} + func (b *Builder) Join(joins ...interface{}) *Builder { b.Parts.JoinExps = convertJoin(b.Ctx, joins...) return b } -func (b *Builder) Sql() string { +func (b *Builder) Group(keys ...string) *Builder { + b.Parts.HavingExp = "HAVING" + b.Parts.GroupExp = convertGroup(b.Ctx, keys) + return b +} + +func (b *Builder) Sort(sort Map) *Builder { + b.Parts.OrderExp, _ = convertSort(b.Ctx, sort) + return b +} + +func (b *Builder) Limit(limit int) *Builder { + b.Parts.LimitExp = fmt.Sprintf("LIMIT %d", limit) + return b +} + +func (b *Builder) Offset(offset int) *Builder { + b.Parts.OffsetExp = fmt.Sprintf("OFFSET %d", offset) + return b +} + +func (b *Builder) Delete() *Builder { + b.Parts.Operation = "DELETE" + return b +} + +func (b *Builder) Insert(inserts []Map) *Builder { + insertData, _ := convertInsert(b.Ctx, inserts) + b.Parts = SQLParts{ + Operation: "INSERT INTO", + Insert: insertData, + } + return b +} + +func (b *Builder) Set(updates Map) *Builder { + updateData := convertUpdate(b.Ctx, updates) + b.Parts = SQLParts{ + Operation: "UPDATE", + Update: updateData, + } + return b +} + +func (b *Builder) Update(updates Map) *Builder { + return b.Set(updates) +} + +func (b *Builder) OnConflict(fields ...string) *Builder { + if b.Parts.Operation == "UPDATE" { + b.Parts.Update.Upsert = convertConflict(b.Ctx, fields...) + b.Parts.Update.UpsertExp = "DO NOTHING" + } + return b +} + +func (b *Builder) DoUpdate(fields ...string) *Builder { + if b.Parts.Operation == "UPDATE" { + b.Parts.Update.UpsertExp = convertUpsert(fields) + } + return b +} + +func (b *Builder) DoNothing() *Builder { + if b.Parts.Operation == "UPDATE" { + b.Parts.Update.UpsertExp = "DO NOTHING" + } + return b +} + +func (b *Builder) Sql() (string, []interface{}) { operation := b.Parts.Operation switch { case operation == "SELECT" || operation == "SELECT DISTINCT": @@ -84,8 +171,28 @@ func (b *Builder) Sql() string { b.Parts.FiterExp, b.Parts.OrderExp, b.Parts.LimitExp, - }, " ")) + b.Parts.OffsetExp, + }, " ")), nil + case operation == "DELETE": + return unspace(strings.Join([]string{ + b.Parts.Operation, + b.Parts.From, + b.Parts.FromExp, + b.Parts.HavingExp, + b.Parts.FiterExp, + b.Parts.OrderExp, + b.Parts.LimitExp, + b.Parts.OffsetExp, + }, " ")), nil + case operation == "INSERT INTO": + return b.Parts.Insert.Statement, b.Parts.Insert.Values + case operation == "UPDATE": + return unspace(strings.Join([]string{ + b.Parts.Update.Statement, + b.Parts.Update.Upsert, + b.Parts.Update.UpsertExp, + }, " ")), b.Parts.Update.Values default: - return "" + return "", nil } } diff --git a/pkg/builder/Builder_test.go b/pkg/builder/Builder_test.go index 63c83ab..f7da7b1 100644 --- a/pkg/builder/Builder_test.go +++ b/pkg/builder/Builder_test.go @@ -11,8 +11,46 @@ func TestBuilderFind(t *testing.T) { "a": 1, }) expect := "SELECT * FROM table t WHERE t.a = 1 AND t.field = 'value'" - if db.Sql() != expect { - t.Errorf(`Expected: "%s", Got: %s`, expect, db.Sql()) + result, _ := db.Sql() + if result != expect { + t.Errorf(`Expected: "%s", Got: %s`, expect, result) + } +} + +func TestBuilderFields(t *testing.T) { + db := New(SQLiteContext{}) + db.In("table t") + db.Find(Query{ + "field": "value", + "a": 1, + }) + db.Fields(Map{ + "t.field": "field", + "t.a": 1, + }) + expect := "SELECT t.a, t.field AS field FROM table t WHERE t.a = 1 AND t.field = 'value'" + result, _ := db.Sql() + if result != expect { + t.Errorf(`Expected: "%s", Got: %s`, expect, result) + } +} + +func TestBuilderGroup(t *testing.T) { + db := New(SQLiteContext{}) + db.In("table t") + db.Find(Query{ + "field": Is{ + "$gt": 1, + }, + }) + db.Fields(Map{ + "SUM(t.field)": "field", + }) + db.Group("field") + expect := "SELECT SUM(t.field) AS field FROM table t GROUP BY t.field HAVING t.field > 1" + result, _ := db.Sql() + if result != expect { + t.Errorf(`Expected: "%s", Got: %s`, expect, result) } } @@ -30,7 +68,8 @@ func TestBuilderJoin(t *testing.T) { }, }) expect := "SELECT * FROM table t JOIN table2 t2 ON t2.field = t.field WHERE t.a = 1 AND t.field = 'value'" - if db.Sql() != expect { - t.Errorf(`Expected: "%s", Got: %s`, expect, db.Sql()) + result, _ := db.Sql() + if result != expect { + t.Errorf(`Expected: "%s", Got: %s`, expect, result) } } diff --git a/pkg/builder/convert_conflict.go b/pkg/builder/convert_conflict.go index db034e6..d97f439 100644 --- a/pkg/builder/convert_conflict.go +++ b/pkg/builder/convert_conflict.go @@ -7,7 +7,7 @@ import ( utils "l12.xyz/dal/utils" ) -func ConvertConflict(ctx Context, fields ...string) string { +func convertConflict(ctx Context, fields ...string) string { keys := utils.Map(fields, ctx.GetColumnName) - return fmt.Sprintf("ON CONFLICT (%s) DO", strings.Join(keys, ",")) + return fmt.Sprintf("ON CONFLICT (%s)", strings.Join(keys, ",")) } diff --git a/pkg/builder/convert_conflict_test.go b/pkg/builder/convert_conflict_test.go index 3f11740..a57b54a 100644 --- a/pkg/builder/convert_conflict_test.go +++ b/pkg/builder/convert_conflict_test.go @@ -10,9 +10,9 @@ func TestConvertConflict(t *testing.T) { TableAlias: "t", FieldName: "test", } - result := ConvertConflict(ctx, "a", "b", "tb.c") + result := convertConflict(ctx, "a", "b", "tb.c") - if result != `ON CONFLICT (t.a,t.b,tb.c) DO` { - t.Errorf(`Expected "ON CONFLICT (t.a,t.b,tb.c) DO", got %s`, result) + if result != `ON CONFLICT (t.a,t.b,tb.c)` { + t.Errorf(`Expected "ON CONFLICT (t.a,t.b,tb.c)", got %s`, result) } } diff --git a/pkg/builder/convert_fields.go b/pkg/builder/convert_fields.go index e9366b3..4c5eda9 100644 --- a/pkg/builder/convert_fields.go +++ b/pkg/builder/convert_fields.go @@ -5,10 +5,13 @@ import ( "strings" ) -func ConvertFields(ctx Context, fields []Map) (string, error) { +func convertFields(fields []Map) (string, error) { var expressions []string for _, fieldAssoc := range fields { - for field, as := range fieldAssoc { + keys := aggregateSortedKeys([]Map{fieldAssoc}) + for _, key := range keys { + field := key + as := fieldAssoc[key] asBool, ok := as.(bool) if ok { if asBool { diff --git a/pkg/builder/convert_fields_test.go b/pkg/builder/convert_fields_test.go index 0e82be5..b809327 100644 --- a/pkg/builder/convert_fields_test.go +++ b/pkg/builder/convert_fields_test.go @@ -5,11 +5,7 @@ import ( ) func TestConvertFieldsBool(t *testing.T) { - ctx := SQLiteContext{ - TableAlias: "t", - FieldName: "test", - } - result, err := ConvertFields(ctx, []Map{ + result, err := convertFields([]Map{ {"test": true}, {"test2": false}, }) @@ -22,11 +18,7 @@ func TestConvertFieldsBool(t *testing.T) { } func TestConvertFieldsInt(t *testing.T) { - ctx := SQLiteContext{ - TableAlias: "t", - FieldName: "test", - } - result, err := ConvertFields(ctx, []Map{ + result, err := convertFields([]Map{ {"test": 0}, {"test2": 1}, }) @@ -39,11 +31,7 @@ func TestConvertFieldsInt(t *testing.T) { } func TestConvertFieldsStr(t *testing.T) { - ctx := SQLiteContext{ - TableAlias: "t", - FieldName: "test", - } - result, err := ConvertFields(ctx, []Map{ + result, err := convertFields([]Map{ {"t.test": "Test"}, {"SUM(t.test, t.int)": "Sum"}, }) diff --git a/pkg/builder/convert_group.go b/pkg/builder/convert_group.go index ceee2bf..933df50 100644 --- a/pkg/builder/convert_group.go +++ b/pkg/builder/convert_group.go @@ -1,12 +1,13 @@ package builder import ( + "fmt" "strings" "l12.xyz/dal/utils" ) -func ConvertGroup(ctx Context, keys []string) string { +func convertGroup(ctx Context, keys []string) string { set := utils.Map(keys, ctx.GetColumnName) - return strings.Join(set, ", ") + return fmt.Sprintf("GROUP BY %s", strings.Join(set, ", ")) } diff --git a/pkg/builder/convert_insert.go b/pkg/builder/convert_insert.go index 6cd0d0d..2ec7902 100644 --- a/pkg/builder/convert_insert.go +++ b/pkg/builder/convert_insert.go @@ -10,7 +10,7 @@ type InsertData struct { Values []interface{} } -func ConvertInsert(ctx Context, inserts []Map) (InsertData, error) { +func convertInsert(ctx Context, inserts []Map) (InsertData, error) { keys := aggregateSortedKeys(inserts) placeholder := make([]string, 0) for range keys { @@ -27,7 +27,8 @@ func ConvertInsert(ctx Context, inserts []Map) (InsertData, error) { } sfmt := fmt.Sprintf( - "INSERT INTO %s (%s) VALUES (%s)", ctx.GetTableName(), + "INSERT INTO %s (%s) VALUES (%s)", + ctx.GetTableName(), strings.Join(keys, ","), strings.Join(placeholder, ","), ) diff --git a/pkg/builder/convert_insert_test.go b/pkg/builder/convert_insert_test.go index 6694bff..c33b261 100644 --- a/pkg/builder/convert_insert_test.go +++ b/pkg/builder/convert_insert_test.go @@ -13,7 +13,7 @@ func TestConvertInsert(t *testing.T) { {"a": "1", "b": 2}, {"b": 2, "a": "1", "c": 3}, } - result, _ := ConvertInsert(ctx, insert) + result, _ := convertInsert(ctx, insert) if result.Statement != `INSERT INTO test (a,b,c) VALUES (?,?,?)` { t.Errorf(`Expected "INSERT INTO test (a,b,c) VALUES (?,?,?)", got %s`, result.Statement) diff --git a/pkg/builder/convert_limit.go b/pkg/builder/convert_limit.go deleted file mode 100644 index 5788a98..0000000 --- a/pkg/builder/convert_limit.go +++ /dev/null @@ -1,41 +0,0 @@ -package builder - -import "fmt" - -type Pagination struct { - Limit interface{} - Offset interface{} -} - -func ConvertLimit(limit int) string { - if limit == 0 { - return "" - } - return fmt.Sprintf("LIMIT %d", limit) -} - -func ConvertOffset(offset int) string { - if offset == 0 { - return "" - } - return fmt.Sprintf("OFFSET %d", offset) -} - -func ConvertLimitOffset(limit, offset int) string { - if limit == 0 && offset == 0 { - return "" - } - return fmt.Sprintf("LIMIT %d OFFSET %d", limit, offset) -} - -func ConvertPagination(p Pagination) string { - limit := "" - if p.Limit != nil { - limit = fmt.Sprintf("LIMIT %d", p.Limit) - } - offset := "" - if p.Offset != nil { - offset = fmt.Sprintf("OFFSET %d", p.Offset) - } - return fmt.Sprintf("%s %s", limit, offset) -} diff --git a/pkg/builder/convert_sort.go b/pkg/builder/convert_sort.go index 94977e7..3a7c9b1 100644 --- a/pkg/builder/convert_sort.go +++ b/pkg/builder/convert_sort.go @@ -5,7 +5,7 @@ import ( "strings" ) -func ConvertSort(ctx Context, sort Map) (string, error) { +func convertSort(ctx Context, sort Map) (string, error) { if sort == nil { return "", nil } diff --git a/pkg/builder/convert_sort_test.go b/pkg/builder/convert_sort_test.go index 1ef6762..76fcff3 100644 --- a/pkg/builder/convert_sort_test.go +++ b/pkg/builder/convert_sort_test.go @@ -9,7 +9,7 @@ func TestConvertSort(t *testing.T) { TableAlias: "t", FieldName: "test", } - result, err := ConvertSort(ctx, Map{ + result, err := convertSort(ctx, Map{ "a": -1, "c": "desc", "b": 1, diff --git a/pkg/builder/convert_update.go b/pkg/builder/convert_update.go index 9079c5a..1a8f706 100644 --- a/pkg/builder/convert_update.go +++ b/pkg/builder/convert_update.go @@ -7,10 +7,12 @@ import ( type UpdateData struct { Statement string + Upsert string + UpsertExp string Values []interface{} } -func ConvertUpdate(ctx Context, updates Map) (UpdateData, error) { +func convertUpdate(ctx Context, updates Map) UpdateData { keys := aggregateSortedKeys([]Map{updates}) set := make([]string, 0) values := make([]interface{}, 0) @@ -25,5 +27,5 @@ func ConvertUpdate(ctx Context, updates Map) (UpdateData, error) { return UpdateData{ Statement: sfmt, Values: values, - }, nil + } } diff --git a/pkg/builder/convert_update_test.go b/pkg/builder/convert_update_test.go index 35293af..e167ce3 100644 --- a/pkg/builder/convert_update_test.go +++ b/pkg/builder/convert_update_test.go @@ -10,14 +10,11 @@ func TestConvertUpdate(t *testing.T) { TableAlias: "t", FieldName: "test", } - result, err := ConvertUpdate(ctx, Map{ + result := convertUpdate(ctx, Map{ "c": nil, "a": 1, "b": 2, }) - if err != nil { - t.Error(err) - } if result.Statement != `UPDATE test SET a = ?,b = ?,c = ?` { t.Errorf(`Expected "UPDATE test SET a = ?,b = ?,c = ?", got %s`, result.Statement) } diff --git a/pkg/builder/convert_upsert.go b/pkg/builder/convert_upsert.go index cf9200e..7cd24bf 100644 --- a/pkg/builder/convert_upsert.go +++ b/pkg/builder/convert_upsert.go @@ -5,13 +5,13 @@ import ( "strings" ) -func ConvertUpsert(keys []string) string { +func convertUpsert(keys []string) string { set := make([]string, 0) for _, key := range keys { set = append(set, fmt.Sprintf("%s = EXCLUDED.%s", key, key)) } return fmt.Sprintf( - "UPDATE SET %s", + "DO UPDATE SET %s", strings.Join(set, ", "), ) }