[ref] refactor adapter, add dialect registry for future extensibility

Signed-off-by: Anton Nesterov <anton@demiurg.io>
This commit is contained in:
Anton Nesterov 2024-08-15 09:06:11 +02:00
parent 920098e80c
commit 7f5c2a32cc
No known key found for this signature in database
GPG key ID: 59121E8AE2851FB5
15 changed files with 68 additions and 41 deletions

View file

@ -21,7 +21,7 @@ func TestBuilderBasic(t *testing.T) {
t.Fatalf("failed to create table: %v", err)
}
insert, values := builder.New(adapter.SQLite{}).In("test t").Insert([]builder.Map{
insert, values := builder.New(adapter.CommonDialect{}).In("test t").Insert([]builder.Map{
{"name": "a"},
{"name": 'b'},
}).Sql()
@ -31,7 +31,7 @@ func TestBuilderBasic(t *testing.T) {
t.Fatalf("failed to insert data: %v", err)
}
expr, values := builder.New(adapter.SQLite{}).In("test t").Find(builder.Find{"name": builder.Is{"$in": []interface{}{"a", 98}}}).Sql()
expr, values := builder.New(adapter.CommonDialect{}).In("test t").Find(builder.Find{"name": builder.Is{"$in": []interface{}{"a", 98}}}).Sql()
fmt.Println(expr)
rows, err := a.Query(adapter.Query{
Db: "file::memory:?cache=shared",

View file

@ -16,7 +16,7 @@ func TestProtoMessagePack(t *testing.T) {
}
req := proto.Request{}
req.UnmarshalMsg(message)
query, err := req.Parse(adapter.SQLite{})
query, err := req.Parse(adapter.CommonDialect{})
if err != nil {
t.Fatalf("failed to parse query: %v", err)
}

View file

@ -7,13 +7,17 @@ import (
utils "l12.xyz/dal/utils"
)
type SQLite struct {
/**
* CommonDialect is a simple implementation of the Dialect interface.
* Should be usable for most SQL databases.
**/
type CommonDialect struct {
TableName string
TableAlias string
FieldName string
}
func (c SQLite) New(opts DialectOpts) Dialect {
func (c CommonDialect) New(opts DialectOpts) Dialect {
tn := opts["TableName"]
if tn == "" {
tn = c.TableName
@ -26,18 +30,18 @@ func (c SQLite) New(opts DialectOpts) Dialect {
if fn == "" {
fn = c.FieldName
}
return SQLite{
return CommonDialect{
TableName: tn,
TableAlias: ta,
FieldName: fn,
}
}
func (c SQLite) GetTableName() string {
func (c CommonDialect) GetTableName() string {
return c.TableName
}
func (c SQLite) GetFieldName() string {
func (c CommonDialect) GetFieldName() string {
if strings.Contains(c.FieldName, ".") {
return c.FieldName
}
@ -47,7 +51,7 @@ func (c SQLite) GetFieldName() string {
return c.FieldName
}
func (c SQLite) GetColumnName(key string) string {
func (c CommonDialect) GetColumnName(key string) string {
if strings.Contains(key, ".") {
return key
}
@ -57,7 +61,7 @@ func (c SQLite) GetColumnName(key string) string {
return key
}
func (c SQLite) NormalizeValue(value interface{}) interface{} {
func (c CommonDialect) NormalizeValue(value interface{}) interface{} {
str, isStr := value.(string)
if !isStr {
return value

24
pkg/adapter/registry.go Normal file
View file

@ -0,0 +1,24 @@
package adapter
import "fmt"
var DIALECTS = map[string]Dialect{
"sqlite3": CommonDialect{},
}
/**
* Register a new dialect for a given driver name.
* `driverName` is the valid name of the db driver (e.g. "sqlite3", "postgres").
* `dialect` is an implementation of the Dialect interface.
**/
func RegisterDialect(driverName string, dialect Dialect) {
DIALECTS[driverName] = dialect
}
func GetDialect(driverName string) Dialect {
dialect, ok := DIALECTS[driverName]
if !ok {
panic(fmt.Errorf("db driver %s not found", driverName))
}
return dialect
}

View file

@ -8,6 +8,9 @@ type Query struct {
type DialectOpts map[string]string
/**
* Dialect interface provides general utilities for normalizing values for particular DB.
**/
type Dialect interface {
New(opts DialectOpts) Dialect
GetTableName() string
@ -15,8 +18,3 @@ type Dialect interface {
GetColumnName(key string) string
NormalizeValue(interface{}) interface{}
}
var DIALECTS = map[string]Dialect{
"sqlite3": SQLite{},
"sqlite": SQLite{},
}

View file

@ -6,7 +6,7 @@ import (
)
func TestBuilderFind(t *testing.T) {
db := New(SQLiteContext{})
db := New(CommonDialect{})
db.In("table t").Find(Query{
"field": "value",
"a": 1,
@ -19,7 +19,7 @@ func TestBuilderFind(t *testing.T) {
}
func TestBuilderFields(t *testing.T) {
db := New(SQLiteContext{})
db := New(CommonDialect{})
db.In("table t")
db.Find(Query{
"field": "value",
@ -37,7 +37,7 @@ func TestBuilderFields(t *testing.T) {
}
func TestBuilderGroup(t *testing.T) {
db := New(SQLiteContext{})
db := New(CommonDialect{})
db.In("table t")
db.Find(Query{
"field": Is{
@ -57,7 +57,7 @@ func TestBuilderGroup(t *testing.T) {
}
func TestBuilderJoin(t *testing.T) {
db := New(SQLiteContext{})
db := New(CommonDialect{})
db.In("table t")
db.Find(Query{
"field": "value",

View file

@ -5,7 +5,7 @@ import (
)
func TestConvertConflict(t *testing.T) {
ctx := SQLiteContext{
ctx := CommonDialect{
TableName: "test",
TableAlias: "t",
FieldName: "test",

View file

@ -12,7 +12,7 @@ func TestConvertFind(t *testing.T) {
"$gt": 2,
},
}
ctx := SQLiteContext{
ctx := CommonDialect{
TableAlias: "t",
}
result, values := covertFind(ctx, find)
@ -39,7 +39,7 @@ func TestConvertFindAnd(t *testing.T) {
},
},
}
ctx := SQLiteContext{
ctx := CommonDialect{
TableAlias: "t",
}
result, values := covertFind(ctx, find)
@ -61,7 +61,7 @@ func TestConvertFindOr(t *testing.T) {
},
},
}
ctx := SQLiteContext{
ctx := CommonDialect{
TableAlias: "t",
}
result, values := covertFind(ctx, find)

View file

@ -6,7 +6,7 @@ import (
)
func TestConvertInsert(t *testing.T) {
ctx := SQLiteContext{
ctx := CommonDialect{
TableName: "test",
TableAlias: "t",
}

View file

@ -3,12 +3,8 @@ package builder
import (
"fmt"
"testing"
adapter "l12.xyz/dal/adapter"
)
type SQLiteContext = adapter.SQLite
func TestJoin(t *testing.T) {
j := Join{
For: "artist a",
@ -17,7 +13,7 @@ func TestJoin(t *testing.T) {
},
As: "LEFT",
}
ctx := SQLiteContext{
ctx := CommonDialect{
TableAlias: "t",
}
result, vals := j.Convert(ctx)
@ -38,7 +34,7 @@ func TestConvertJoin(t *testing.T) {
},
},
}
ctx := SQLiteContext{
ctx := CommonDialect{
TableAlias: "t",
}
result, vals := convertJoin(ctx, joins...)
@ -56,7 +52,7 @@ func TestConvertMap(t *testing.T) {
joins := []interface{}{
Map{"$for": "artist a", "$do": Map{"a.impl": "t.impl"}, "$as": "LEFT"},
}
ctx := SQLiteContext{
ctx := CommonDialect{
TableAlias: "t",
}
result, vals := convertJoin(ctx, joins...)

View file

@ -5,7 +5,7 @@ import (
)
func TestConvertSort(t *testing.T) {
ctx := SQLiteContext{
ctx := CommonDialect{
TableAlias: "t",
FieldName: "test",
}

View file

@ -5,7 +5,7 @@ import (
)
func TestConvertUpdate(t *testing.T) {
ctx := SQLiteContext{
ctx := CommonDialect{
TableName: "test",
TableAlias: "t",
FieldName: "test",

View file

@ -5,6 +5,7 @@ import (
filters "l12.xyz/dal/filters"
)
type CommonDialect = adapter.CommonDialect
type Map = map[string]interface{}
type Fields = Map
type Find = filters.Find

View file

@ -1,7 +1,6 @@
package server
import (
"fmt"
"io"
"net/http"
"reflect"
@ -10,11 +9,17 @@ import (
"l12.xyz/dal/proto"
)
/**
* QueryHandler is a http.Handler that reads a proto.Request from the request body,
* parses it into a query, executes the query on the provided db and writes the
* result to the response body.
* - The request body is expected to be in msgpack format (proto.Request).
* - The response body is written in msgpack format.
* - The respose is a stream of rows (proto.Row), where the first row is the column names.
* - The columns are sorted alphabetically, so it is client's responsibility to match them and sort as needed.
**/
func QueryHandler(db adapter.DBAdapter) http.Handler {
dialect, ok := adapter.DIALECTS[db.Type]
if !ok {
panic(fmt.Errorf("dialect %s not found", db.Type))
}
dialect := adapter.GetDialect(db.Type)
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
bodyReader, err := r.GetBody()

View file

@ -15,6 +15,7 @@ import (
)
func TestQueryHandler(t *testing.T) {
adapter.RegisterDialect("sqlite3", adapter.CommonDialect{})
a := adapter.DBAdapter{Type: "sqlite3"}
db, err := a.Open("file::memory:?cache=shared")
if err != nil {
@ -52,9 +53,7 @@ func TestQueryHandler(t *testing.T) {
t.Fatal(err)
}
rr := httptest.NewRecorder()
handler := QueryHandler(adapter.DBAdapter{
Type: "sqlite3",
})
handler := QueryHandler(a)
handler.ServeHTTP(rr, req)
res, _ := io.ReadAll(rr.Result().Body)
result := proto.UnmarshalRows(res)