From 7f5c2a32cceba3f6722fcd49917846fd083f3caf Mon Sep 17 00:00:00 2001 From: Anton Nesterov Date: Thu, 15 Aug 2024 09:06:11 +0200 Subject: [PATCH] [ref] refactor adapter, add dialect registry for future extensibility Signed-off-by: Anton Nesterov --- pkg/__test__/builder_test.go | 4 ++-- pkg/__test__/proto_test.go | 2 +- pkg/adapter/{SQLite.go => CommonDialect.go} | 18 ++++++++------ pkg/adapter/registry.go | 24 +++++++++++++++++++ pkg/adapter/types.go | 8 +++---- pkg/builder/Builder_test.go | 8 +++---- pkg/builder/convert_conflict_test.go | 2 +- pkg/builder/convert_find_test.go | 6 ++--- pkg/builder/convert_insert_test.go | 2 +- pkg/builder/convert_join_test.go | 10 +++----- pkg/builder/convert_sort_test.go | 2 +- pkg/builder/convert_update_test.go | 2 +- pkg/builder/types.go | 1 + pkg/server/{handler.go => query_handler.go} | 15 ++++++++---- ...{handler_test.go => query_handler_test.go} | 5 ++-- 15 files changed, 68 insertions(+), 41 deletions(-) rename pkg/adapter/{SQLite.go => CommonDialect.go} (69%) create mode 100644 pkg/adapter/registry.go rename pkg/server/{handler.go => query_handler.go} (68%) rename pkg/server/{handler_test.go => query_handler_test.go} (93%) diff --git a/pkg/__test__/builder_test.go b/pkg/__test__/builder_test.go index db84773..17414c9 100644 --- a/pkg/__test__/builder_test.go +++ b/pkg/__test__/builder_test.go @@ -21,7 +21,7 @@ func TestBuilderBasic(t *testing.T) { t.Fatalf("failed to create table: %v", err) } - insert, values := builder.New(adapter.SQLite{}).In("test t").Insert([]builder.Map{ + insert, values := builder.New(adapter.CommonDialect{}).In("test t").Insert([]builder.Map{ {"name": "a"}, {"name": 'b'}, }).Sql() @@ -31,7 +31,7 @@ func TestBuilderBasic(t *testing.T) { t.Fatalf("failed to insert data: %v", err) } - expr, values := builder.New(adapter.SQLite{}).In("test t").Find(builder.Find{"name": builder.Is{"$in": []interface{}{"a", 98}}}).Sql() + expr, values := builder.New(adapter.CommonDialect{}).In("test t").Find(builder.Find{"name": builder.Is{"$in": []interface{}{"a", 98}}}).Sql() fmt.Println(expr) rows, err := a.Query(adapter.Query{ Db: "file::memory:?cache=shared", diff --git a/pkg/__test__/proto_test.go b/pkg/__test__/proto_test.go index ea558f1..2a36ce8 100644 --- a/pkg/__test__/proto_test.go +++ b/pkg/__test__/proto_test.go @@ -16,7 +16,7 @@ func TestProtoMessagePack(t *testing.T) { } req := proto.Request{} req.UnmarshalMsg(message) - query, err := req.Parse(adapter.SQLite{}) + query, err := req.Parse(adapter.CommonDialect{}) if err != nil { t.Fatalf("failed to parse query: %v", err) } diff --git a/pkg/adapter/SQLite.go b/pkg/adapter/CommonDialect.go similarity index 69% rename from pkg/adapter/SQLite.go rename to pkg/adapter/CommonDialect.go index 4863c9f..6fbec26 100644 --- a/pkg/adapter/SQLite.go +++ b/pkg/adapter/CommonDialect.go @@ -7,13 +7,17 @@ import ( utils "l12.xyz/dal/utils" ) -type SQLite struct { +/** +* CommonDialect is a simple implementation of the Dialect interface. +* Should be usable for most SQL databases. +**/ +type CommonDialect struct { TableName string TableAlias string FieldName string } -func (c SQLite) New(opts DialectOpts) Dialect { +func (c CommonDialect) New(opts DialectOpts) Dialect { tn := opts["TableName"] if tn == "" { tn = c.TableName @@ -26,18 +30,18 @@ func (c SQLite) New(opts DialectOpts) Dialect { if fn == "" { fn = c.FieldName } - return SQLite{ + return CommonDialect{ TableName: tn, TableAlias: ta, FieldName: fn, } } -func (c SQLite) GetTableName() string { +func (c CommonDialect) GetTableName() string { return c.TableName } -func (c SQLite) GetFieldName() string { +func (c CommonDialect) GetFieldName() string { if strings.Contains(c.FieldName, ".") { return c.FieldName } @@ -47,7 +51,7 @@ func (c SQLite) GetFieldName() string { return c.FieldName } -func (c SQLite) GetColumnName(key string) string { +func (c CommonDialect) GetColumnName(key string) string { if strings.Contains(key, ".") { return key } @@ -57,7 +61,7 @@ func (c SQLite) GetColumnName(key string) string { return key } -func (c SQLite) NormalizeValue(value interface{}) interface{} { +func (c CommonDialect) NormalizeValue(value interface{}) interface{} { str, isStr := value.(string) if !isStr { return value diff --git a/pkg/adapter/registry.go b/pkg/adapter/registry.go new file mode 100644 index 0000000..bf38618 --- /dev/null +++ b/pkg/adapter/registry.go @@ -0,0 +1,24 @@ +package adapter + +import "fmt" + +var DIALECTS = map[string]Dialect{ + "sqlite3": CommonDialect{}, +} + +/** + * Register a new dialect for a given driver name. + * `driverName` is the valid name of the db driver (e.g. "sqlite3", "postgres"). + * `dialect` is an implementation of the Dialect interface. +**/ +func RegisterDialect(driverName string, dialect Dialect) { + DIALECTS[driverName] = dialect +} + +func GetDialect(driverName string) Dialect { + dialect, ok := DIALECTS[driverName] + if !ok { + panic(fmt.Errorf("db driver %s not found", driverName)) + } + return dialect +} diff --git a/pkg/adapter/types.go b/pkg/adapter/types.go index dd46e54..3fb8a58 100644 --- a/pkg/adapter/types.go +++ b/pkg/adapter/types.go @@ -8,6 +8,9 @@ type Query struct { type DialectOpts map[string]string +/** +* Dialect interface provides general utilities for normalizing values for particular DB. +**/ type Dialect interface { New(opts DialectOpts) Dialect GetTableName() string @@ -15,8 +18,3 @@ type Dialect interface { GetColumnName(key string) string NormalizeValue(interface{}) interface{} } - -var DIALECTS = map[string]Dialect{ - "sqlite3": SQLite{}, - "sqlite": SQLite{}, -} diff --git a/pkg/builder/Builder_test.go b/pkg/builder/Builder_test.go index 6b9bae8..0ef89db 100644 --- a/pkg/builder/Builder_test.go +++ b/pkg/builder/Builder_test.go @@ -6,7 +6,7 @@ import ( ) func TestBuilderFind(t *testing.T) { - db := New(SQLiteContext{}) + db := New(CommonDialect{}) db.In("table t").Find(Query{ "field": "value", "a": 1, @@ -19,7 +19,7 @@ func TestBuilderFind(t *testing.T) { } func TestBuilderFields(t *testing.T) { - db := New(SQLiteContext{}) + db := New(CommonDialect{}) db.In("table t") db.Find(Query{ "field": "value", @@ -37,7 +37,7 @@ func TestBuilderFields(t *testing.T) { } func TestBuilderGroup(t *testing.T) { - db := New(SQLiteContext{}) + db := New(CommonDialect{}) db.In("table t") db.Find(Query{ "field": Is{ @@ -57,7 +57,7 @@ func TestBuilderGroup(t *testing.T) { } func TestBuilderJoin(t *testing.T) { - db := New(SQLiteContext{}) + db := New(CommonDialect{}) db.In("table t") db.Find(Query{ "field": "value", diff --git a/pkg/builder/convert_conflict_test.go b/pkg/builder/convert_conflict_test.go index a57b54a..a9e685f 100644 --- a/pkg/builder/convert_conflict_test.go +++ b/pkg/builder/convert_conflict_test.go @@ -5,7 +5,7 @@ import ( ) func TestConvertConflict(t *testing.T) { - ctx := SQLiteContext{ + ctx := CommonDialect{ TableName: "test", TableAlias: "t", FieldName: "test", diff --git a/pkg/builder/convert_find_test.go b/pkg/builder/convert_find_test.go index 4a737ac..0ed38c3 100644 --- a/pkg/builder/convert_find_test.go +++ b/pkg/builder/convert_find_test.go @@ -12,7 +12,7 @@ func TestConvertFind(t *testing.T) { "$gt": 2, }, } - ctx := SQLiteContext{ + ctx := CommonDialect{ TableAlias: "t", } result, values := covertFind(ctx, find) @@ -39,7 +39,7 @@ func TestConvertFindAnd(t *testing.T) { }, }, } - ctx := SQLiteContext{ + ctx := CommonDialect{ TableAlias: "t", } result, values := covertFind(ctx, find) @@ -61,7 +61,7 @@ func TestConvertFindOr(t *testing.T) { }, }, } - ctx := SQLiteContext{ + ctx := CommonDialect{ TableAlias: "t", } result, values := covertFind(ctx, find) diff --git a/pkg/builder/convert_insert_test.go b/pkg/builder/convert_insert_test.go index 5dedbc8..6e82ea4 100644 --- a/pkg/builder/convert_insert_test.go +++ b/pkg/builder/convert_insert_test.go @@ -6,7 +6,7 @@ import ( ) func TestConvertInsert(t *testing.T) { - ctx := SQLiteContext{ + ctx := CommonDialect{ TableName: "test", TableAlias: "t", } diff --git a/pkg/builder/convert_join_test.go b/pkg/builder/convert_join_test.go index 11255a3..613444e 100644 --- a/pkg/builder/convert_join_test.go +++ b/pkg/builder/convert_join_test.go @@ -3,12 +3,8 @@ package builder import ( "fmt" "testing" - - adapter "l12.xyz/dal/adapter" ) -type SQLiteContext = adapter.SQLite - func TestJoin(t *testing.T) { j := Join{ For: "artist a", @@ -17,7 +13,7 @@ func TestJoin(t *testing.T) { }, As: "LEFT", } - ctx := SQLiteContext{ + ctx := CommonDialect{ TableAlias: "t", } result, vals := j.Convert(ctx) @@ -38,7 +34,7 @@ func TestConvertJoin(t *testing.T) { }, }, } - ctx := SQLiteContext{ + ctx := CommonDialect{ TableAlias: "t", } result, vals := convertJoin(ctx, joins...) @@ -56,7 +52,7 @@ func TestConvertMap(t *testing.T) { joins := []interface{}{ Map{"$for": "artist a", "$do": Map{"a.impl": "t.impl"}, "$as": "LEFT"}, } - ctx := SQLiteContext{ + ctx := CommonDialect{ TableAlias: "t", } result, vals := convertJoin(ctx, joins...) diff --git a/pkg/builder/convert_sort_test.go b/pkg/builder/convert_sort_test.go index 76fcff3..f150ba2 100644 --- a/pkg/builder/convert_sort_test.go +++ b/pkg/builder/convert_sort_test.go @@ -5,7 +5,7 @@ import ( ) func TestConvertSort(t *testing.T) { - ctx := SQLiteContext{ + ctx := CommonDialect{ TableAlias: "t", FieldName: "test", } diff --git a/pkg/builder/convert_update_test.go b/pkg/builder/convert_update_test.go index e167ce3..f1f8e22 100644 --- a/pkg/builder/convert_update_test.go +++ b/pkg/builder/convert_update_test.go @@ -5,7 +5,7 @@ import ( ) func TestConvertUpdate(t *testing.T) { - ctx := SQLiteContext{ + ctx := CommonDialect{ TableName: "test", TableAlias: "t", FieldName: "test", diff --git a/pkg/builder/types.go b/pkg/builder/types.go index fde0228..5f9ecbd 100644 --- a/pkg/builder/types.go +++ b/pkg/builder/types.go @@ -5,6 +5,7 @@ import ( filters "l12.xyz/dal/filters" ) +type CommonDialect = adapter.CommonDialect type Map = map[string]interface{} type Fields = Map type Find = filters.Find diff --git a/pkg/server/handler.go b/pkg/server/query_handler.go similarity index 68% rename from pkg/server/handler.go rename to pkg/server/query_handler.go index a66a5df..a5df9fb 100644 --- a/pkg/server/handler.go +++ b/pkg/server/query_handler.go @@ -1,7 +1,6 @@ package server import ( - "fmt" "io" "net/http" "reflect" @@ -10,11 +9,17 @@ import ( "l12.xyz/dal/proto" ) +/** +* QueryHandler is a http.Handler that reads a proto.Request from the request body, +* parses it into a query, executes the query on the provided db and writes the +* result to the response body. +* - The request body is expected to be in msgpack format (proto.Request). +* - The response body is written in msgpack format. +* - The respose is a stream of rows (proto.Row), where the first row is the column names. +* - The columns are sorted alphabetically, so it is client's responsibility to match them and sort as needed. +**/ func QueryHandler(db adapter.DBAdapter) http.Handler { - dialect, ok := adapter.DIALECTS[db.Type] - if !ok { - panic(fmt.Errorf("dialect %s not found", db.Type)) - } + dialect := adapter.GetDialect(db.Type) return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { bodyReader, err := r.GetBody() diff --git a/pkg/server/handler_test.go b/pkg/server/query_handler_test.go similarity index 93% rename from pkg/server/handler_test.go rename to pkg/server/query_handler_test.go index 3263d4e..8f6f624 100644 --- a/pkg/server/handler_test.go +++ b/pkg/server/query_handler_test.go @@ -15,6 +15,7 @@ import ( ) func TestQueryHandler(t *testing.T) { + adapter.RegisterDialect("sqlite3", adapter.CommonDialect{}) a := adapter.DBAdapter{Type: "sqlite3"} db, err := a.Open("file::memory:?cache=shared") if err != nil { @@ -52,9 +53,7 @@ func TestQueryHandler(t *testing.T) { t.Fatal(err) } rr := httptest.NewRecorder() - handler := QueryHandler(adapter.DBAdapter{ - Type: "sqlite3", - }) + handler := QueryHandler(a) handler.ServeHTTP(rr, req) res, _ := io.ReadAll(rr.Result().Body) result := proto.UnmarshalRows(res)