[ref] refactor adapter, add dialect registry for future extensibility

Signed-off-by: Anton Nesterov <anton@demiurg.io>
This commit is contained in:
Anton Nesterov 2024-08-15 09:06:11 +02:00
parent 920098e80c
commit 7f5c2a32cc
No known key found for this signature in database
GPG key ID: 59121E8AE2851FB5
15 changed files with 68 additions and 41 deletions

View file

@ -21,7 +21,7 @@ func TestBuilderBasic(t *testing.T) {
t.Fatalf("failed to create table: %v", err) t.Fatalf("failed to create table: %v", err)
} }
insert, values := builder.New(adapter.SQLite{}).In("test t").Insert([]builder.Map{ insert, values := builder.New(adapter.CommonDialect{}).In("test t").Insert([]builder.Map{
{"name": "a"}, {"name": "a"},
{"name": 'b'}, {"name": 'b'},
}).Sql() }).Sql()
@ -31,7 +31,7 @@ func TestBuilderBasic(t *testing.T) {
t.Fatalf("failed to insert data: %v", err) t.Fatalf("failed to insert data: %v", err)
} }
expr, values := builder.New(adapter.SQLite{}).In("test t").Find(builder.Find{"name": builder.Is{"$in": []interface{}{"a", 98}}}).Sql() expr, values := builder.New(adapter.CommonDialect{}).In("test t").Find(builder.Find{"name": builder.Is{"$in": []interface{}{"a", 98}}}).Sql()
fmt.Println(expr) fmt.Println(expr)
rows, err := a.Query(adapter.Query{ rows, err := a.Query(adapter.Query{
Db: "file::memory:?cache=shared", Db: "file::memory:?cache=shared",

View file

@ -16,7 +16,7 @@ func TestProtoMessagePack(t *testing.T) {
} }
req := proto.Request{} req := proto.Request{}
req.UnmarshalMsg(message) req.UnmarshalMsg(message)
query, err := req.Parse(adapter.SQLite{}) query, err := req.Parse(adapter.CommonDialect{})
if err != nil { if err != nil {
t.Fatalf("failed to parse query: %v", err) t.Fatalf("failed to parse query: %v", err)
} }

View file

@ -7,13 +7,17 @@ import (
utils "l12.xyz/dal/utils" utils "l12.xyz/dal/utils"
) )
type SQLite struct { /**
* CommonDialect is a simple implementation of the Dialect interface.
* Should be usable for most SQL databases.
**/
type CommonDialect struct {
TableName string TableName string
TableAlias string TableAlias string
FieldName string FieldName string
} }
func (c SQLite) New(opts DialectOpts) Dialect { func (c CommonDialect) New(opts DialectOpts) Dialect {
tn := opts["TableName"] tn := opts["TableName"]
if tn == "" { if tn == "" {
tn = c.TableName tn = c.TableName
@ -26,18 +30,18 @@ func (c SQLite) New(opts DialectOpts) Dialect {
if fn == "" { if fn == "" {
fn = c.FieldName fn = c.FieldName
} }
return SQLite{ return CommonDialect{
TableName: tn, TableName: tn,
TableAlias: ta, TableAlias: ta,
FieldName: fn, FieldName: fn,
} }
} }
func (c SQLite) GetTableName() string { func (c CommonDialect) GetTableName() string {
return c.TableName return c.TableName
} }
func (c SQLite) GetFieldName() string { func (c CommonDialect) GetFieldName() string {
if strings.Contains(c.FieldName, ".") { if strings.Contains(c.FieldName, ".") {
return c.FieldName return c.FieldName
} }
@ -47,7 +51,7 @@ func (c SQLite) GetFieldName() string {
return c.FieldName return c.FieldName
} }
func (c SQLite) GetColumnName(key string) string { func (c CommonDialect) GetColumnName(key string) string {
if strings.Contains(key, ".") { if strings.Contains(key, ".") {
return key return key
} }
@ -57,7 +61,7 @@ func (c SQLite) GetColumnName(key string) string {
return key return key
} }
func (c SQLite) NormalizeValue(value interface{}) interface{} { func (c CommonDialect) NormalizeValue(value interface{}) interface{} {
str, isStr := value.(string) str, isStr := value.(string)
if !isStr { if !isStr {
return value return value

24
pkg/adapter/registry.go Normal file
View file

@ -0,0 +1,24 @@
package adapter
import "fmt"
var DIALECTS = map[string]Dialect{
"sqlite3": CommonDialect{},
}
/**
* Register a new dialect for a given driver name.
* `driverName` is the valid name of the db driver (e.g. "sqlite3", "postgres").
* `dialect` is an implementation of the Dialect interface.
**/
func RegisterDialect(driverName string, dialect Dialect) {
DIALECTS[driverName] = dialect
}
func GetDialect(driverName string) Dialect {
dialect, ok := DIALECTS[driverName]
if !ok {
panic(fmt.Errorf("db driver %s not found", driverName))
}
return dialect
}

View file

@ -8,6 +8,9 @@ type Query struct {
type DialectOpts map[string]string type DialectOpts map[string]string
/**
* Dialect interface provides general utilities for normalizing values for particular DB.
**/
type Dialect interface { type Dialect interface {
New(opts DialectOpts) Dialect New(opts DialectOpts) Dialect
GetTableName() string GetTableName() string
@ -15,8 +18,3 @@ type Dialect interface {
GetColumnName(key string) string GetColumnName(key string) string
NormalizeValue(interface{}) interface{} NormalizeValue(interface{}) interface{}
} }
var DIALECTS = map[string]Dialect{
"sqlite3": SQLite{},
"sqlite": SQLite{},
}

View file

@ -6,7 +6,7 @@ import (
) )
func TestBuilderFind(t *testing.T) { func TestBuilderFind(t *testing.T) {
db := New(SQLiteContext{}) db := New(CommonDialect{})
db.In("table t").Find(Query{ db.In("table t").Find(Query{
"field": "value", "field": "value",
"a": 1, "a": 1,
@ -19,7 +19,7 @@ func TestBuilderFind(t *testing.T) {
} }
func TestBuilderFields(t *testing.T) { func TestBuilderFields(t *testing.T) {
db := New(SQLiteContext{}) db := New(CommonDialect{})
db.In("table t") db.In("table t")
db.Find(Query{ db.Find(Query{
"field": "value", "field": "value",
@ -37,7 +37,7 @@ func TestBuilderFields(t *testing.T) {
} }
func TestBuilderGroup(t *testing.T) { func TestBuilderGroup(t *testing.T) {
db := New(SQLiteContext{}) db := New(CommonDialect{})
db.In("table t") db.In("table t")
db.Find(Query{ db.Find(Query{
"field": Is{ "field": Is{
@ -57,7 +57,7 @@ func TestBuilderGroup(t *testing.T) {
} }
func TestBuilderJoin(t *testing.T) { func TestBuilderJoin(t *testing.T) {
db := New(SQLiteContext{}) db := New(CommonDialect{})
db.In("table t") db.In("table t")
db.Find(Query{ db.Find(Query{
"field": "value", "field": "value",

View file

@ -5,7 +5,7 @@ import (
) )
func TestConvertConflict(t *testing.T) { func TestConvertConflict(t *testing.T) {
ctx := SQLiteContext{ ctx := CommonDialect{
TableName: "test", TableName: "test",
TableAlias: "t", TableAlias: "t",
FieldName: "test", FieldName: "test",

View file

@ -12,7 +12,7 @@ func TestConvertFind(t *testing.T) {
"$gt": 2, "$gt": 2,
}, },
} }
ctx := SQLiteContext{ ctx := CommonDialect{
TableAlias: "t", TableAlias: "t",
} }
result, values := covertFind(ctx, find) result, values := covertFind(ctx, find)
@ -39,7 +39,7 @@ func TestConvertFindAnd(t *testing.T) {
}, },
}, },
} }
ctx := SQLiteContext{ ctx := CommonDialect{
TableAlias: "t", TableAlias: "t",
} }
result, values := covertFind(ctx, find) result, values := covertFind(ctx, find)
@ -61,7 +61,7 @@ func TestConvertFindOr(t *testing.T) {
}, },
}, },
} }
ctx := SQLiteContext{ ctx := CommonDialect{
TableAlias: "t", TableAlias: "t",
} }
result, values := covertFind(ctx, find) result, values := covertFind(ctx, find)

View file

@ -6,7 +6,7 @@ import (
) )
func TestConvertInsert(t *testing.T) { func TestConvertInsert(t *testing.T) {
ctx := SQLiteContext{ ctx := CommonDialect{
TableName: "test", TableName: "test",
TableAlias: "t", TableAlias: "t",
} }

View file

@ -3,12 +3,8 @@ package builder
import ( import (
"fmt" "fmt"
"testing" "testing"
adapter "l12.xyz/dal/adapter"
) )
type SQLiteContext = adapter.SQLite
func TestJoin(t *testing.T) { func TestJoin(t *testing.T) {
j := Join{ j := Join{
For: "artist a", For: "artist a",
@ -17,7 +13,7 @@ func TestJoin(t *testing.T) {
}, },
As: "LEFT", As: "LEFT",
} }
ctx := SQLiteContext{ ctx := CommonDialect{
TableAlias: "t", TableAlias: "t",
} }
result, vals := j.Convert(ctx) result, vals := j.Convert(ctx)
@ -38,7 +34,7 @@ func TestConvertJoin(t *testing.T) {
}, },
}, },
} }
ctx := SQLiteContext{ ctx := CommonDialect{
TableAlias: "t", TableAlias: "t",
} }
result, vals := convertJoin(ctx, joins...) result, vals := convertJoin(ctx, joins...)
@ -56,7 +52,7 @@ func TestConvertMap(t *testing.T) {
joins := []interface{}{ joins := []interface{}{
Map{"$for": "artist a", "$do": Map{"a.impl": "t.impl"}, "$as": "LEFT"}, Map{"$for": "artist a", "$do": Map{"a.impl": "t.impl"}, "$as": "LEFT"},
} }
ctx := SQLiteContext{ ctx := CommonDialect{
TableAlias: "t", TableAlias: "t",
} }
result, vals := convertJoin(ctx, joins...) result, vals := convertJoin(ctx, joins...)

View file

@ -5,7 +5,7 @@ import (
) )
func TestConvertSort(t *testing.T) { func TestConvertSort(t *testing.T) {
ctx := SQLiteContext{ ctx := CommonDialect{
TableAlias: "t", TableAlias: "t",
FieldName: "test", FieldName: "test",
} }

View file

@ -5,7 +5,7 @@ import (
) )
func TestConvertUpdate(t *testing.T) { func TestConvertUpdate(t *testing.T) {
ctx := SQLiteContext{ ctx := CommonDialect{
TableName: "test", TableName: "test",
TableAlias: "t", TableAlias: "t",
FieldName: "test", FieldName: "test",

View file

@ -5,6 +5,7 @@ import (
filters "l12.xyz/dal/filters" filters "l12.xyz/dal/filters"
) )
type CommonDialect = adapter.CommonDialect
type Map = map[string]interface{} type Map = map[string]interface{}
type Fields = Map type Fields = Map
type Find = filters.Find type Find = filters.Find

View file

@ -1,7 +1,6 @@
package server package server
import ( import (
"fmt"
"io" "io"
"net/http" "net/http"
"reflect" "reflect"
@ -10,11 +9,17 @@ import (
"l12.xyz/dal/proto" "l12.xyz/dal/proto"
) )
/**
* QueryHandler is a http.Handler that reads a proto.Request from the request body,
* parses it into a query, executes the query on the provided db and writes the
* result to the response body.
* - The request body is expected to be in msgpack format (proto.Request).
* - The response body is written in msgpack format.
* - The respose is a stream of rows (proto.Row), where the first row is the column names.
* - The columns are sorted alphabetically, so it is client's responsibility to match them and sort as needed.
**/
func QueryHandler(db adapter.DBAdapter) http.Handler { func QueryHandler(db adapter.DBAdapter) http.Handler {
dialect, ok := adapter.DIALECTS[db.Type] dialect := adapter.GetDialect(db.Type)
if !ok {
panic(fmt.Errorf("dialect %s not found", db.Type))
}
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
bodyReader, err := r.GetBody() bodyReader, err := r.GetBody()

View file

@ -15,6 +15,7 @@ import (
) )
func TestQueryHandler(t *testing.T) { func TestQueryHandler(t *testing.T) {
adapter.RegisterDialect("sqlite3", adapter.CommonDialect{})
a := adapter.DBAdapter{Type: "sqlite3"} a := adapter.DBAdapter{Type: "sqlite3"}
db, err := a.Open("file::memory:?cache=shared") db, err := a.Open("file::memory:?cache=shared")
if err != nil { if err != nil {
@ -52,9 +53,7 @@ func TestQueryHandler(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
rr := httptest.NewRecorder() rr := httptest.NewRecorder()
handler := QueryHandler(adapter.DBAdapter{ handler := QueryHandler(a)
Type: "sqlite3",
})
handler.ServeHTTP(rr, req) handler.ServeHTTP(rr, req)
res, _ := io.ReadAll(rr.Result().Body) res, _ := io.ReadAll(rr.Result().Body)
result := proto.UnmarshalRows(res) result := proto.UnmarshalRows(res)