diff --git a/pkg/adapter/SQLiteCtx.go b/pkg/adapter/SQLiteCtx.go index 734fcbc..a8a1914 100644 --- a/pkg/adapter/SQLiteCtx.go +++ b/pkg/adapter/SQLiteCtx.go @@ -14,6 +14,10 @@ type SQLiteContext struct { } func (c SQLiteContext) New(opts CtxOpts) Context { + tn := opts["TableName"] + if tn == "" { + tn = c.TableName + } ta := opts["TableAlias"] if ta == "" { ta = c.TableAlias @@ -23,6 +27,7 @@ func (c SQLiteContext) New(opts CtxOpts) Context { fn = c.FieldName } return SQLiteContext{ + TableName: tn, TableAlias: ta, FieldName: fn, } @@ -42,6 +47,16 @@ func (c SQLiteContext) GetFieldName() string { return c.FieldName } +func (c SQLiteContext) GetColumnName(key string) string { + if strings.Contains(key, ".") { + return key + } + if c.TableAlias != "" { + return c.TableAlias + "." + key + } + return key +} + func (c SQLiteContext) NormalizeValue(value interface{}) interface{} { str, ok := value.(string) if utils.IsSQLFunction(str) { diff --git a/pkg/adapter/types.go b/pkg/adapter/types.go index 8ac8019..a0a4b71 100644 --- a/pkg/adapter/types.go +++ b/pkg/adapter/types.go @@ -6,5 +6,6 @@ type Context interface { New(opts CtxOpts) Context GetTableName() string GetFieldName() string + GetColumnName(key string) string NormalizeValue(interface{}) interface{} } diff --git a/pkg/builder/Builder.go b/pkg/builder/Builder.go index 289c329..8f05f32 100644 --- a/pkg/builder/Builder.go +++ b/pkg/builder/Builder.go @@ -1,37 +1,91 @@ package builder +import "strings" + type SQLParts struct { - operation string - selectExp string - fromExp string - fiterExp string - joinExp []string - groupExp string - orderExp string - limitExp string + Operation string + From string + FieldsExp string + FromExp string + HavingExp string + FiterExp string + JoinExps []string + GroupExp string + OrderExp string + LimitExp string updateExp string upsertExp string } type Builder struct { - parts SQLParts + Parts SQLParts + TableName string + TableAlias string + Ctx Context } -func New() *Builder { - return &Builder{} +func New(ctx Context) *Builder { + return &Builder{ + Parts: SQLParts{ + Operation: "SELECT", + From: "FROM", + }, + Ctx: ctx, + } } -func (b *Builder) In(selectExp string) *Builder { - b.parts.selectExp = selectExp +func (b *Builder) In(table string) *Builder { + b.TableName, b.TableAlias = getTableAlias(table) + b.Parts.FromExp = table + b.Ctx = b.Ctx.New(CtxOpts{ + "TableName": b.TableName, + "TableAlias": b.TableAlias, + }) return b } -func (b *Builder) Find(fromExp string) *Builder { - b.parts.fromExp = fromExp +func (b *Builder) Find(query Find) *Builder { + b.Parts.FiterExp = covertFind( + b.Ctx, + query, + ) + if b.Parts.Operation == "" { + b.Parts.Operation = "SELECT" + } + if b.Parts.HavingExp == "" { + b.Parts.HavingExp = "WHERE" + } + if b.Parts.FieldsExp == "" { + b.Parts.FieldsExp = "*" + } return b } -func (b *Builder) Join(fiterExp string) *Builder { - b.parts.fiterExp = fiterExp +func (b *Builder) Join(joins ...interface{}) *Builder { + b.Parts.JoinExps = convertJoin(b.Ctx, joins...) return b } + +func (b *Builder) Sql() string { + operation := b.Parts.Operation + switch { + case operation == "SELECT" || operation == "SELECT DISTINCT": + return unspace(strings.Join([]string{ + b.Parts.Operation, + b.Parts.FieldsExp, + b.Parts.From, + b.Parts.FromExp, + strings.Join( + b.Parts.JoinExps, + " ", + ), + b.Parts.GroupExp, + b.Parts.HavingExp, + b.Parts.FiterExp, + b.Parts.OrderExp, + b.Parts.LimitExp, + }, " ")) + default: + return "" + } +} diff --git a/pkg/builder/Builder_test.go b/pkg/builder/Builder_test.go new file mode 100644 index 0000000..63c83ab --- /dev/null +++ b/pkg/builder/Builder_test.go @@ -0,0 +1,36 @@ +package builder + +import ( + "testing" +) + +func TestBuilderFind(t *testing.T) { + db := New(SQLiteContext{}) + db.In("table t").Find(Query{ + "field": "value", + "a": 1, + }) + expect := "SELECT * FROM table t WHERE t.a = 1 AND t.field = 'value'" + if db.Sql() != expect { + t.Errorf(`Expected: "%s", Got: %s`, expect, db.Sql()) + } +} + +func TestBuilderJoin(t *testing.T) { + db := New(SQLiteContext{}) + db.In("table t") + db.Find(Query{ + "field": "value", + "a": 1, + }) + db.Join(Join{ + For: "table2 t2", + Do: Query{ + "t2.field": "t.field", + }, + }) + expect := "SELECT * FROM table t JOIN table2 t2 ON t2.field = t.field WHERE t.a = 1 AND t.field = 'value'" + if db.Sql() != expect { + t.Errorf(`Expected: "%s", Got: %s`, expect, db.Sql()) + } +} diff --git a/pkg/builder/convert_conflict.go b/pkg/builder/convert_conflict.go new file mode 100644 index 0000000..db034e6 --- /dev/null +++ b/pkg/builder/convert_conflict.go @@ -0,0 +1,13 @@ +package builder + +import ( + "fmt" + "strings" + + utils "l12.xyz/dal/utils" +) + +func ConvertConflict(ctx Context, fields ...string) string { + keys := utils.Map(fields, ctx.GetColumnName) + return fmt.Sprintf("ON CONFLICT (%s) DO", strings.Join(keys, ",")) +} diff --git a/pkg/builder/convert_conflict_test.go b/pkg/builder/convert_conflict_test.go new file mode 100644 index 0000000..3f11740 --- /dev/null +++ b/pkg/builder/convert_conflict_test.go @@ -0,0 +1,18 @@ +package builder + +import ( + "testing" +) + +func TestConvertConflict(t *testing.T) { + ctx := SQLiteContext{ + TableName: "test", + TableAlias: "t", + FieldName: "test", + } + result := ConvertConflict(ctx, "a", "b", "tb.c") + + if result != `ON CONFLICT (t.a,t.b,tb.c) DO` { + t.Errorf(`Expected "ON CONFLICT (t.a,t.b,tb.c) DO", got %s`, result) + } +} diff --git a/pkg/builder/convert_find.go b/pkg/builder/convert_find.go index 791ba17..cd3b6e0 100644 --- a/pkg/builder/convert_find.go +++ b/pkg/builder/convert_find.go @@ -7,7 +7,7 @@ import ( filters "l12.xyz/dal/filters" ) -func CovertFind(ctx Context, find Find) string { +func covertFind(ctx Context, find Find) string { return covert_find(ctx, find, "") } @@ -15,7 +15,7 @@ func covert_find(ctx Context, find Find, join string) string { if join == "" { join = " AND " } - keys := AggregateSortedKeys([]Map{find}) + keys := aggregateSortedKeys([]Map{find}) expressions := []string{} for _, key := range keys { value := find[key] diff --git a/pkg/builder/convert_find_test.go b/pkg/builder/convert_find_test.go index a212044..80db309 100644 --- a/pkg/builder/convert_find_test.go +++ b/pkg/builder/convert_find_test.go @@ -14,7 +14,7 @@ func TestConvertFind(t *testing.T) { ctx := SQLiteContext{ TableAlias: "t", } - result := CovertFind(ctx, find) + result := covertFind(ctx, find) if result == `t.exp > 1 AND t.impl = '1'` { return } @@ -38,7 +38,7 @@ func TestConvertFindAnd(t *testing.T) { ctx := SQLiteContext{ TableAlias: "t", } - result := CovertFind(ctx, find) + result := covertFind(ctx, find) if result == `(t.a > 1 AND t.b < 10)` { return } @@ -62,7 +62,7 @@ func TestConvertFindOr(t *testing.T) { ctx := SQLiteContext{ TableAlias: "t", } - result := CovertFind(ctx, find) + result := covertFind(ctx, find) if result == `(t.a > 1 OR t.b < 10)` { return } diff --git a/pkg/builder/convert_group.go b/pkg/builder/convert_group.go new file mode 100644 index 0000000..ceee2bf --- /dev/null +++ b/pkg/builder/convert_group.go @@ -0,0 +1,12 @@ +package builder + +import ( + "strings" + + "l12.xyz/dal/utils" +) + +func ConvertGroup(ctx Context, keys []string) string { + set := utils.Map(keys, ctx.GetColumnName) + return strings.Join(set, ", ") +} diff --git a/pkg/builder/convert_insert.go b/pkg/builder/convert_insert.go index 6bbd36b..6cd0d0d 100644 --- a/pkg/builder/convert_insert.go +++ b/pkg/builder/convert_insert.go @@ -11,7 +11,7 @@ type InsertData struct { } func ConvertInsert(ctx Context, inserts []Map) (InsertData, error) { - keys := AggregateSortedKeys(inserts) + keys := aggregateSortedKeys(inserts) placeholder := make([]string, 0) for range keys { placeholder = append(placeholder, "?") diff --git a/pkg/builder/convert_insert_test.go b/pkg/builder/convert_insert_test.go index 9cf7881..6694bff 100644 --- a/pkg/builder/convert_insert_test.go +++ b/pkg/builder/convert_insert_test.go @@ -1,7 +1,6 @@ package builder import ( - "fmt" "testing" ) @@ -20,8 +19,8 @@ func TestConvertInsert(t *testing.T) { t.Errorf(`Expected "INSERT INTO test (a,b,c) VALUES (?,?,?)", got %s`, result.Statement) } - for _, r := range result.Values { - fmt.Println(r) - } + // for _, r := range result.Values { + // fmt.Println(r) + // } } diff --git a/pkg/builder/convert_join.go b/pkg/builder/convert_join.go index e6dc040..ea17ff1 100644 --- a/pkg/builder/convert_join.go +++ b/pkg/builder/convert_join.go @@ -15,7 +15,7 @@ func (j Join) Convert(ctx Context) string { if j.For == "" { return "" } - filter := CovertFind(ctx, j.Do) + filter := covertFind(ctx, j.Do) var as string = "" if j.As != "" { as = fmt.Sprintf("%s ", j.As) @@ -23,7 +23,7 @@ func (j Join) Convert(ctx Context) string { return as + fmt.Sprintf("JOIN %s ON %s", j.For, filter) } -func ConvertJoin(ctx Context, joins ...interface{}) []string { +func convertJoin(ctx Context, joins ...interface{}) []string { var result []string for _, join := range joins { jstr, ok := join.(string) diff --git a/pkg/builder/convert_join_test.go b/pkg/builder/convert_join_test.go index 357a01f..6b375ec 100644 --- a/pkg/builder/convert_join_test.go +++ b/pkg/builder/convert_join_test.go @@ -39,7 +39,7 @@ func TestConvertJoin(t *testing.T) { ctx := SQLiteContext{ TableAlias: "t", } - result := ConvertJoin(ctx, joins...) + result := convertJoin(ctx, joins...) if result[1] != `JOIN artist a ON a.impl = t.impl` { t.Errorf(`Expected "JOIN artist a ON a.impl = t.impl", got %s`, result[1]) } @@ -56,7 +56,7 @@ func TestConvertMap(t *testing.T) { ctx := SQLiteContext{ TableAlias: "t", } - result := ConvertJoin(ctx, joins...) + result := convertJoin(ctx, joins...) if result[0] != `LEFT JOIN artist a ON a.impl = t.impl` { t.Errorf(`Expected "LEFT JOIN artist a ON a.impl = t.impl", got %s`, result[0]) } diff --git a/pkg/builder/convert_limit.go b/pkg/builder/convert_limit.go new file mode 100644 index 0000000..5788a98 --- /dev/null +++ b/pkg/builder/convert_limit.go @@ -0,0 +1,41 @@ +package builder + +import "fmt" + +type Pagination struct { + Limit interface{} + Offset interface{} +} + +func ConvertLimit(limit int) string { + if limit == 0 { + return "" + } + return fmt.Sprintf("LIMIT %d", limit) +} + +func ConvertOffset(offset int) string { + if offset == 0 { + return "" + } + return fmt.Sprintf("OFFSET %d", offset) +} + +func ConvertLimitOffset(limit, offset int) string { + if limit == 0 && offset == 0 { + return "" + } + return fmt.Sprintf("LIMIT %d OFFSET %d", limit, offset) +} + +func ConvertPagination(p Pagination) string { + limit := "" + if p.Limit != nil { + limit = fmt.Sprintf("LIMIT %d", p.Limit) + } + offset := "" + if p.Offset != nil { + offset = fmt.Sprintf("OFFSET %d", p.Offset) + } + return fmt.Sprintf("%s %s", limit, offset) +} diff --git a/pkg/builder/convert_sort.go b/pkg/builder/convert_sort.go new file mode 100644 index 0000000..94977e7 --- /dev/null +++ b/pkg/builder/convert_sort.go @@ -0,0 +1,58 @@ +package builder + +import ( + "fmt" + "strings" +) + +func ConvertSort(ctx Context, sort Map) (string, error) { + if sort == nil { + return "", nil + } + keys := aggregateSortedKeys([]Map{sort}) + expressions := make([]string, 0) + for _, key := range keys { + name := ctx.GetColumnName(key) + order := normalize_order(sort[key]) + if order != "" { + order = " " + order + } + expressions = append(expressions, name+order) + } + return fmt.Sprintf("ORDER BY %s", strings.Join(expressions, ", ")), nil +} + +func normalize_order(order interface{}) string { + if order == nil { + return "" + } + orderInt, ok := order.(int) + if ok { + if orderInt == 1 { + return "ASC" + } + if orderInt == -1 { + return "DESC" + } + } + orderStr, ok := order.(string) + if !ok { + return "" + } + if orderStr == "" { + return "" + } + if orderStr == "1" { + return "ASC" + } + if orderStr == "-1" { + return "DESC" + } + if strings.ToUpper(orderStr) == "ASC" { + return "ASC" + } + if strings.ToUpper(orderStr) == "DESC" { + return "DESC" + } + return "" +} diff --git a/pkg/builder/convert_sort_test.go b/pkg/builder/convert_sort_test.go new file mode 100644 index 0000000..1ef6762 --- /dev/null +++ b/pkg/builder/convert_sort_test.go @@ -0,0 +1,24 @@ +package builder + +import ( + "testing" +) + +func TestConvertSort(t *testing.T) { + ctx := SQLiteContext{ + TableAlias: "t", + FieldName: "test", + } + result, err := ConvertSort(ctx, Map{ + "a": -1, + "c": "desc", + "b": 1, + "d": nil, + }) + if err != nil { + t.Error(err) + } + if result != `ORDER BY t.a DESC, t.b ASC, t.c DESC, t.d` { + t.Errorf("Expected ORDER BY t.a DESC, t.b ASC, t.c DESC, t.d, got %s", result) + } +} diff --git a/pkg/builder/convert_update.go b/pkg/builder/convert_update.go new file mode 100644 index 0000000..9079c5a --- /dev/null +++ b/pkg/builder/convert_update.go @@ -0,0 +1,29 @@ +package builder + +import ( + "fmt" + "strings" +) + +type UpdateData struct { + Statement string + Values []interface{} +} + +func ConvertUpdate(ctx Context, updates Map) (UpdateData, error) { + keys := aggregateSortedKeys([]Map{updates}) + set := make([]string, 0) + values := make([]interface{}, 0) + for _, key := range keys { + set = append(set, fmt.Sprintf("%s = ?", key)) + values = append(values, updates[key]) + } + sfmt := fmt.Sprintf( + "UPDATE %s SET %s", ctx.GetTableName(), + strings.Join(set, ","), + ) + return UpdateData{ + Statement: sfmt, + Values: values, + }, nil +} diff --git a/pkg/builder/convert_update_test.go b/pkg/builder/convert_update_test.go new file mode 100644 index 0000000..35293af --- /dev/null +++ b/pkg/builder/convert_update_test.go @@ -0,0 +1,24 @@ +package builder + +import ( + "testing" +) + +func TestConvertUpdate(t *testing.T) { + ctx := SQLiteContext{ + TableName: "test", + TableAlias: "t", + FieldName: "test", + } + result, err := ConvertUpdate(ctx, Map{ + "c": nil, + "a": 1, + "b": 2, + }) + if err != nil { + t.Error(err) + } + if result.Statement != `UPDATE test SET a = ?,b = ?,c = ?` { + t.Errorf(`Expected "UPDATE test SET a = ?,b = ?,c = ?", got %s`, result.Statement) + } +} diff --git a/pkg/builder/convert_upsert.go b/pkg/builder/convert_upsert.go new file mode 100644 index 0000000..cf9200e --- /dev/null +++ b/pkg/builder/convert_upsert.go @@ -0,0 +1,17 @@ +package builder + +import ( + "fmt" + "strings" +) + +func ConvertUpsert(keys []string) string { + set := make([]string, 0) + for _, key := range keys { + set = append(set, fmt.Sprintf("%s = EXCLUDED.%s", key, key)) + } + return fmt.Sprintf( + "UPDATE SET %s", + strings.Join(set, ", "), + ) +} diff --git a/pkg/builder/utils.go b/pkg/builder/utils.go index cedea80..db205c1 100644 --- a/pkg/builder/utils.go +++ b/pkg/builder/utils.go @@ -1,8 +1,11 @@ package builder -import "sort" +import ( + "sort" + "strings" +) -func AggregateSortedKeys(maps []Map) []string { +func aggregateSortedKeys(maps []Map) []string { set := make(map[string]int) keys := make([]string, 0) for _, item := range maps { @@ -18,3 +21,19 @@ func AggregateSortedKeys(maps []Map) []string { sort.Strings(keys) return keys } + +func getTableAlias(tableName string) (string, string) { + if !strings.Contains(tableName, " ") { + return tableName, "" + } + if strings.Contains(strings.ToLower(tableName), " as ") { + data := strings.Split(strings.ToLower(tableName), " as ") + return data[0], data[1] + } + data := strings.Split(tableName, " ") + return data[0], data[1] +} + +func unspace(s string) string { + return strings.Join(strings.Fields(s), " ") +}