[wip] insert/update/delete operations

Signed-off-by: Anton Nesterov <anton@demiurg.io>
This commit is contained in:
Anton Nesterov 2024-08-15 17:52:21 +02:00
parent 60d213f21c
commit 9d13824fdf
No known key found for this signature in database
GPG key ID: 59121E8AE2851FB5
15 changed files with 502 additions and 38 deletions

View file

@ -61,7 +61,7 @@ export default class Builder<I extends abstract new (...args: any) => any> {
}); });
} }
private formatRow(data: unknown[]) { private formatRow(data: unknown[]) {
if (!this.dtoTemplate) { if (!this.dtoTemplate || this.dtoTemplate === Object) {
return data; return data;
} }
const instance = new this.dtoTemplate(data); const instance = new this.dtoTemplate(data);

View file

@ -21,11 +21,12 @@ func TestBuilderBasic(t *testing.T) {
t.Fatalf("failed to create table: %v", err) t.Fatalf("failed to create table: %v", err)
} }
insert, values := builder.New(adapter.CommonDialect{}).In("test t").Insert([]builder.Map{ inserts := []builder.Map{
{"name": "a"}, {"name": "a"},
{"name": 'b'}, {"name": 'b'},
}).Sql() }
fmt.Println(insert, values) insert, values := builder.New(adapter.CommonDialect{}).In("test t").Insert(inserts...).Sql()
_, err = db.Exec(insert, values...) _, err = db.Exec(insert, values...)
if err != nil { if err != nil {
t.Fatalf("failed to insert data: %v", err) t.Fatalf("failed to insert data: %v", err)
@ -52,3 +53,63 @@ func TestBuilderBasic(t *testing.T) {
fmt.Printf("id: %d, name: %s\n", id, name) fmt.Printf("id: %d, name: %s\n", id, name)
} }
} }
func TestBuilderSet(t *testing.T) {
a := adapter.DBAdapter{Type: "sqlite3"}
db, err := a.Open("file::memory:?cache=shared")
if err != nil {
t.Fatalf("failed to open db: %v", err)
}
defer db.Close()
_, err = db.Exec("CREATE TABLE test (id INTEGER PRIMARY KEY, name BLOB)")
if err != nil {
t.Fatalf("failed to create table: %v", err)
}
inserts := []builder.Map{
{"name": "a"},
{"name": 'b'},
}
insert, values := builder.New(adapter.CommonDialect{}).In("test t").Insert(inserts...).Sql()
_, err = db.Exec(insert, values...)
if err != nil {
t.Fatalf("failed to insert data: %v", err)
}
b := builder.New(adapter.CommonDialect{}).In("test t")
b.Find(builder.Find{"id": builder.Is{"$eq": 2}})
b.Set(builder.Map{"name": "c"})
b.Tx()
expr, values := b.Sql()
fmt.Println(expr, values)
_, err = a.Exec(adapter.Query{
Db: "file::memory:?cache=shared",
Expression: expr,
Data: values,
Transaction: b.Transaction,
})
if err != nil {
t.Fatalf("failed to query data: %v", err)
}
b = builder.New(adapter.CommonDialect{}).In("test t")
b.Find(builder.Find{})
expr, values = b.Sql()
fmt.Println(expr, values)
rows, _ := a.Query(adapter.Query{
Db: "file::memory:?cache=shared",
Expression: expr,
Data: values,
Transaction: b.Transaction,
})
defer rows.Close()
for rows.Next() {
var id int
var name string
err = rows.Scan(&id, &name)
if err != nil {
t.Fatalf("failed to scan row: %v", err)
}
fmt.Printf("id: %d, name: %s\n", id, name)
}
}

View file

@ -18,16 +18,16 @@ type CommonDialect struct {
} }
func (c CommonDialect) New(opts DialectOpts) Dialect { func (c CommonDialect) New(opts DialectOpts) Dialect {
tn := opts["TableName"] tn, ok := opts["TableName"]
if tn == "" { if !ok {
tn = c.TableName tn = c.TableName
} }
ta := opts["TableAlias"] ta, ok := opts["TableAlias"]
if ta == "" { if !ok {
ta = c.TableAlias ta = c.TableAlias
} }
fn := opts["FieldName"] fn, ok := opts["FieldName"]
if fn == "" { if !ok {
fn = c.FieldName fn = c.FieldName
} }
return CommonDialect{ return CommonDialect{
@ -41,6 +41,10 @@ func (c CommonDialect) GetTableName() string {
return c.TableName return c.TableName
} }
func (c CommonDialect) GetTableAlias() string {
return c.TableAlias
}
func (c CommonDialect) GetFieldName() string { func (c CommonDialect) GetFieldName() string {
if strings.Contains(c.FieldName, ".") { if strings.Contains(c.FieldName, ".") {
return c.FieldName return c.FieldName

View file

@ -110,6 +110,12 @@ func (a *DBAdapter) Query(req Query) (*sql.Rows, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
if req.Transaction {
tx, _ := db.Begin()
rows, err := tx.Query(req.Expression, req.Data...)
tx.Commit()
return rows, err
}
sfmt, err := db.Prepare(req.Expression) sfmt, err := db.Prepare(req.Expression)
if err != nil { if err != nil {
return nil, err return nil, err
@ -122,6 +128,12 @@ func (a *DBAdapter) Exec(req Query) (sql.Result, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
if req.Transaction {
tx, _ := db.Begin()
result, err := tx.Exec(req.Expression, req.Data...)
tx.Commit()
return result, err
}
sfmt, err := db.Prepare(req.Expression) sfmt, err := db.Prepare(req.Expression)
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -4,6 +4,8 @@ type Query struct {
Db string `json:"db"` Db string `json:"db"`
Expression string `json:"expr"` Expression string `json:"expr"`
Data []interface{} `json:"data"` Data []interface{} `json:"data"`
Transaction bool `json:"transaction"`
Exec bool `json:"exec"`
} }
type DialectOpts map[string]string type DialectOpts map[string]string
@ -14,6 +16,7 @@ Dialect interface provides general utilities for normalizing values for particul
type Dialect interface { type Dialect interface {
New(opts DialectOpts) Dialect New(opts DialectOpts) Dialect
GetTableName() string GetTableName() string
GetTableAlias() string
GetFieldName() string GetFieldName() string
GetColumnName(key string) string GetColumnName(key string) string
NormalizeValue(interface{}) interface{} NormalizeValue(interface{}) interface{}

View file

@ -16,6 +16,8 @@ type Builder struct {
TableAlias string TableAlias string
Parts SQLParts Parts SQLParts
Dialect Dialect Dialect Dialect
LastQuery Find
Transaction bool
} }
type SQLParts struct { type SQLParts struct {
@ -56,19 +58,23 @@ func (b *Builder) In(table string) *Builder {
} }
func (b *Builder) Find(query Find) *Builder { func (b *Builder) Find(query Find) *Builder {
b.LastQuery = query
b.Parts.FiterExp, b.Parts.Values = covertFind( b.Parts.FiterExp, b.Parts.Values = covertFind(
b.Dialect, b.Dialect,
query, query,
) )
if b.Parts.Operation == "" { if len(b.Parts.Operation) == 0 {
b.Parts.Operation = "SELECT" b.Parts.Operation = "SELECT"
} }
if b.Parts.HavingExp == "" { if len(b.Parts.HavingExp) == 0 {
b.Parts.HavingExp = "WHERE" b.Parts.HavingExp = "WHERE"
} }
if b.Parts.FieldsExp == "" { if len(b.Parts.FieldsExp) == 0 {
b.Parts.FieldsExp = "*" b.Parts.FieldsExp = "*"
} }
if len(b.Parts.FiterExp) == 0 {
b.Parts.HavingExp = ""
}
return b return b
} }
@ -118,7 +124,7 @@ func (b *Builder) Delete() *Builder {
return b return b
} }
func (b *Builder) Insert(inserts []Map) *Builder { func (b *Builder) Insert(inserts ...Map) *Builder {
insertData, _ := convertInsert(b.Dialect, inserts) insertData, _ := convertInsert(b.Dialect, inserts)
b.Parts = SQLParts{ b.Parts = SQLParts{
Operation: "INSERT INTO", Operation: "INSERT INTO",
@ -128,11 +134,13 @@ func (b *Builder) Insert(inserts []Map) *Builder {
} }
func (b *Builder) Set(updates Map) *Builder { func (b *Builder) Set(updates Map) *Builder {
b.Dialect = b.Dialect.New(DialectOpts{
"TableAlias": "",
})
updateData := convertUpdate(b.Dialect, updates) updateData := convertUpdate(b.Dialect, updates)
b.Parts = SQLParts{ b.Find(b.LastQuery)
Operation: "UPDATE", b.Parts.Operation = "UPDATE"
Update: updateData, b.Parts.Update = updateData
}
return b return b
} }
@ -168,6 +176,11 @@ func (b *Builder) DoNothing() *Builder {
return b return b
} }
func (b *Builder) Tx() *Builder {
b.Transaction = true
return b
}
func (b *Builder) Sql() (string, []interface{}) { func (b *Builder) Sql() (string, []interface{}) {
operation := b.Parts.Operation operation := b.Parts.Operation
switch { switch {
@ -202,11 +215,17 @@ func (b *Builder) Sql() (string, []interface{}) {
case operation == "INSERT INTO": case operation == "INSERT INTO":
return b.Parts.Insert.Statement, b.Parts.Insert.Values return b.Parts.Insert.Statement, b.Parts.Insert.Values
case operation == "UPDATE": case operation == "UPDATE":
values := append(b.Parts.Update.Values, b.Parts.Values...)
return unspace(strings.Join([]string{ return unspace(strings.Join([]string{
b.Parts.Update.Statement, b.Parts.Update.Statement,
b.Parts.HavingExp,
b.Parts.FiterExp,
b.Parts.OrderExp,
b.Parts.LimitExp,
b.Parts.OffsetExp,
b.Parts.Update.Upsert, b.Parts.Update.Upsert,
b.Parts.Update.UpsertExp, b.Parts.Update.UpsertExp,
}, " ")), b.Parts.Update.Values }, " ")), values
default: default:
return "", nil return "", nil
} }

View file

@ -5,6 +5,16 @@ import (
"testing" "testing"
) )
func TestBuilderFindEmpty(t *testing.T) {
db := New(CommonDialect{})
db.In("table t").Find(Query{}).Limit(10)
expect := "SELECT * FROM table t LIMIT 10"
result, _ := db.Sql()
if result != expect {
t.Errorf(`Expected: "%s", Got: %s`, expect, result)
}
}
func TestBuilderFind(t *testing.T) { func TestBuilderFind(t *testing.T) {
db := New(CommonDialect{}) db := New(CommonDialect{})
db.In("table t").Find(Query{ db.In("table t").Find(Query{

View file

@ -21,7 +21,8 @@ func convertUpdate(ctx Dialect, updates Map) UpdateData {
values = append(values, updates[key]) values = append(values, updates[key])
} }
sfmt := fmt.Sprintf( sfmt := fmt.Sprintf(
"UPDATE %s SET %s", ctx.GetTableName(), "UPDATE %s SET %s",
ctx.GetTableName(),
strings.Join(set, ","), strings.Join(set, ","),
) )
return UpdateData{ return UpdateData{

View file

@ -33,6 +33,7 @@ func (q *Request) Parse(dialect adapter.Dialect) (adapter.Query, error) {
return adapter.Query{}, fmt.Errorf("Request format: commands are required") return adapter.Query{}, fmt.Errorf("Request format: commands are required")
} }
b := builder.New(dialect) b := builder.New(dialect)
exec := false
for _, cmd := range q.Commands { for _, cmd := range q.Commands {
if !slices.Contains(allowedMethods, cmd.Method) { if !slices.Contains(allowedMethods, cmd.Method) {
return adapter.Query{}, fmt.Errorf( return adapter.Query{}, fmt.Errorf(
@ -45,10 +46,14 @@ func (q *Request) Parse(dialect adapter.Dialect) (adapter.Query, error) {
if !method.IsValid() { if !method.IsValid() {
return adapter.Query{}, fmt.Errorf("method %s not found", cmd.Method) return adapter.Query{}, fmt.Errorf("method %s not found", cmd.Method)
} }
if cmd.Method == "Insert" || cmd.Method == "Set" || cmd.Method == "Delete" {
exec = true
}
args := make([]reflect.Value, len(cmd.Args)) args := make([]reflect.Value, len(cmd.Args))
for i, arg := range cmd.Args { for i, arg := range cmd.Args {
args[i] = reflect.ValueOf(arg) args[i] = reflect.ValueOf(arg)
} }
fmt.Print(exec, cmd.Method, args)
method.Call(args) method.Call(args)
} }
expr, data := b.Sql() expr, data := b.Sql()
@ -56,5 +61,7 @@ func (q *Request) Parse(dialect adapter.Dialect) (adapter.Query, error) {
Db: q.Db, Db: q.Db,
Expression: expr, Expression: expr,
Data: data, Data: data,
Transaction: b.Transaction,
Exec: exec,
}, nil }, nil
} }

View file

@ -9,7 +9,7 @@ import (
"github.com/tinylib/msgp/msgp" "github.com/tinylib/msgp/msgp"
) )
func TestMarshalUnmarshalBuildCmd(t *testing.T) { func TestMarshalUnmarshalBuilderMethod(t *testing.T) {
v := BuilderMethod{} v := BuilderMethod{}
bts, err := v.MarshalMsg(nil) bts, err := v.MarshalMsg(nil)
if err != nil { if err != nil {
@ -32,7 +32,7 @@ func TestMarshalUnmarshalBuildCmd(t *testing.T) {
} }
} }
func BenchmarkMarshalMsgBuildCmd(b *testing.B) { func BenchmarkMarshalMsgBuilderMethod(b *testing.B) {
v := BuilderMethod{} v := BuilderMethod{}
b.ReportAllocs() b.ReportAllocs()
b.ResetTimer() b.ResetTimer()
@ -41,7 +41,7 @@ func BenchmarkMarshalMsgBuildCmd(b *testing.B) {
} }
} }
func BenchmarkAppendMsgBuildCmd(b *testing.B) { func BenchmarkAppendMsgBuilderMethod(b *testing.B) {
v := BuilderMethod{} v := BuilderMethod{}
bts := make([]byte, 0, v.Msgsize()) bts := make([]byte, 0, v.Msgsize())
bts, _ = v.MarshalMsg(bts[0:0]) bts, _ = v.MarshalMsg(bts[0:0])
@ -53,7 +53,7 @@ func BenchmarkAppendMsgBuildCmd(b *testing.B) {
} }
} }
func BenchmarkUnmarshalBuildCmd(b *testing.B) { func BenchmarkUnmarshalBuilderMethod(b *testing.B) {
v := BuilderMethod{} v := BuilderMethod{}
bts, _ := v.MarshalMsg(nil) bts, _ := v.MarshalMsg(nil)
b.ReportAllocs() b.ReportAllocs()
@ -67,14 +67,14 @@ func BenchmarkUnmarshalBuildCmd(b *testing.B) {
} }
} }
func TestEncodeDecodeBuildCmd(t *testing.T) { func TestEncodeDecodeBuilderMethod(t *testing.T) {
v := BuilderMethod{} v := BuilderMethod{}
var buf bytes.Buffer var buf bytes.Buffer
msgp.Encode(&buf, &v) msgp.Encode(&buf, &v)
m := v.Msgsize() m := v.Msgsize()
if buf.Len() > m { if buf.Len() > m {
t.Log("WARNING: TestEncodeDecodeBuildCmd Msgsize() is inaccurate") t.Log("WARNING: TestEncodeDecodeBuilderMethod Msgsize() is inaccurate")
} }
vn := BuilderMethod{} vn := BuilderMethod{}
@ -91,7 +91,7 @@ func TestEncodeDecodeBuildCmd(t *testing.T) {
} }
} }
func BenchmarkEncodeBuildCmd(b *testing.B) { func BenchmarkEncodeBuilderMethod(b *testing.B) {
v := BuilderMethod{} v := BuilderMethod{}
var buf bytes.Buffer var buf bytes.Buffer
msgp.Encode(&buf, &v) msgp.Encode(&buf, &v)
@ -105,7 +105,7 @@ func BenchmarkEncodeBuildCmd(b *testing.B) {
en.Flush() en.Flush()
} }
func BenchmarkDecodeBuildCmd(b *testing.B) { func BenchmarkDecodeBuilderMethod(b *testing.B) {
v := BuilderMethod{} v := BuilderMethod{}
var buf bytes.Buffer var buf bytes.Buffer
msgp.Encode(&buf, &v) msgp.Encode(&buf, &v)

9
pkg/proto/response.go Normal file
View file

@ -0,0 +1,9 @@
package proto
//go:generate msgp
type Response struct {
Id uint32 `msg:"i"`
RowsAffected int64 `msg:"ra"`
LastInsertId int64 `msg:"li"`
}

160
pkg/proto/response_gen.go Normal file
View file

@ -0,0 +1,160 @@
package proto
// Code generated by github.com/tinylib/msgp DO NOT EDIT.
import (
"github.com/tinylib/msgp/msgp"
)
// DecodeMsg implements msgp.Decodable
func (z *Response) DecodeMsg(dc *msgp.Reader) (err error) {
var field []byte
_ = field
var zb0001 uint32
zb0001, err = dc.ReadMapHeader()
if err != nil {
err = msgp.WrapError(err)
return
}
for zb0001 > 0 {
zb0001--
field, err = dc.ReadMapKeyPtr()
if err != nil {
err = msgp.WrapError(err)
return
}
switch msgp.UnsafeString(field) {
case "i":
z.Id, err = dc.ReadUint32()
if err != nil {
err = msgp.WrapError(err, "Id")
return
}
case "ra":
z.RowsAffected, err = dc.ReadInt64()
if err != nil {
err = msgp.WrapError(err, "RowsAffected")
return
}
case "li":
z.LastInsertId, err = dc.ReadInt64()
if err != nil {
err = msgp.WrapError(err, "LastInsertId")
return
}
default:
err = dc.Skip()
if err != nil {
err = msgp.WrapError(err)
return
}
}
}
return
}
// EncodeMsg implements msgp.Encodable
func (z Response) EncodeMsg(en *msgp.Writer) (err error) {
// map header, size 3
// write "i"
err = en.Append(0x83, 0xa1, 0x69)
if err != nil {
return
}
err = en.WriteUint32(z.Id)
if err != nil {
err = msgp.WrapError(err, "Id")
return
}
// write "ra"
err = en.Append(0xa2, 0x72, 0x61)
if err != nil {
return
}
err = en.WriteInt64(z.RowsAffected)
if err != nil {
err = msgp.WrapError(err, "RowsAffected")
return
}
// write "li"
err = en.Append(0xa2, 0x6c, 0x69)
if err != nil {
return
}
err = en.WriteInt64(z.LastInsertId)
if err != nil {
err = msgp.WrapError(err, "LastInsertId")
return
}
return
}
// MarshalMsg implements msgp.Marshaler
func (z Response) MarshalMsg(b []byte) (o []byte, err error) {
o = msgp.Require(b, z.Msgsize())
// map header, size 3
// string "i"
o = append(o, 0x83, 0xa1, 0x69)
o = msgp.AppendUint32(o, z.Id)
// string "ra"
o = append(o, 0xa2, 0x72, 0x61)
o = msgp.AppendInt64(o, z.RowsAffected)
// string "li"
o = append(o, 0xa2, 0x6c, 0x69)
o = msgp.AppendInt64(o, z.LastInsertId)
return
}
// UnmarshalMsg implements msgp.Unmarshaler
func (z *Response) UnmarshalMsg(bts []byte) (o []byte, err error) {
var field []byte
_ = field
var zb0001 uint32
zb0001, bts, err = msgp.ReadMapHeaderBytes(bts)
if err != nil {
err = msgp.WrapError(err)
return
}
for zb0001 > 0 {
zb0001--
field, bts, err = msgp.ReadMapKeyZC(bts)
if err != nil {
err = msgp.WrapError(err)
return
}
switch msgp.UnsafeString(field) {
case "i":
z.Id, bts, err = msgp.ReadUint32Bytes(bts)
if err != nil {
err = msgp.WrapError(err, "Id")
return
}
case "ra":
z.RowsAffected, bts, err = msgp.ReadInt64Bytes(bts)
if err != nil {
err = msgp.WrapError(err, "RowsAffected")
return
}
case "li":
z.LastInsertId, bts, err = msgp.ReadInt64Bytes(bts)
if err != nil {
err = msgp.WrapError(err, "LastInsertId")
return
}
default:
bts, err = msgp.Skip(bts)
if err != nil {
err = msgp.WrapError(err)
return
}
}
}
o = bts
return
}
// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message
func (z Response) Msgsize() (s int) {
s = 1 + 2 + msgp.Uint32Size + 3 + msgp.Int64Size + 3 + msgp.Int64Size
return
}

View file

@ -0,0 +1,123 @@
package proto
// Code generated by github.com/tinylib/msgp DO NOT EDIT.
import (
"bytes"
"testing"
"github.com/tinylib/msgp/msgp"
)
func TestMarshalUnmarshalResponse(t *testing.T) {
v := Response{}
bts, err := v.MarshalMsg(nil)
if err != nil {
t.Fatal(err)
}
left, err := v.UnmarshalMsg(bts)
if err != nil {
t.Fatal(err)
}
if len(left) > 0 {
t.Errorf("%d bytes left over after UnmarshalMsg(): %q", len(left), left)
}
left, err = msgp.Skip(bts)
if err != nil {
t.Fatal(err)
}
if len(left) > 0 {
t.Errorf("%d bytes left over after Skip(): %q", len(left), left)
}
}
func BenchmarkMarshalMsgResponse(b *testing.B) {
v := Response{}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
v.MarshalMsg(nil)
}
}
func BenchmarkAppendMsgResponse(b *testing.B) {
v := Response{}
bts := make([]byte, 0, v.Msgsize())
bts, _ = v.MarshalMsg(bts[0:0])
b.SetBytes(int64(len(bts)))
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
bts, _ = v.MarshalMsg(bts[0:0])
}
}
func BenchmarkUnmarshalResponse(b *testing.B) {
v := Response{}
bts, _ := v.MarshalMsg(nil)
b.ReportAllocs()
b.SetBytes(int64(len(bts)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := v.UnmarshalMsg(bts)
if err != nil {
b.Fatal(err)
}
}
}
func TestEncodeDecodeResponse(t *testing.T) {
v := Response{}
var buf bytes.Buffer
msgp.Encode(&buf, &v)
m := v.Msgsize()
if buf.Len() > m {
t.Log("WARNING: TestEncodeDecodeResponse Msgsize() is inaccurate")
}
vn := Response{}
err := msgp.Decode(&buf, &vn)
if err != nil {
t.Error(err)
}
buf.Reset()
msgp.Encode(&buf, &v)
err = msgp.NewReader(&buf).Skip()
if err != nil {
t.Error(err)
}
}
func BenchmarkEncodeResponse(b *testing.B) {
v := Response{}
var buf bytes.Buffer
msgp.Encode(&buf, &v)
b.SetBytes(int64(buf.Len()))
en := msgp.NewWriter(msgp.Nowhere)
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
v.EncodeMsg(en)
}
en.Flush()
}
func BenchmarkDecodeResponse(b *testing.B) {
v := Response{}
var buf bytes.Buffer
msgp.Encode(&buf, &v)
b.SetBytes(int64(buf.Len()))
rd := msgp.NewEndlessReader(buf.Bytes(), b)
dc := msgp.NewReader(rd)
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
err := v.DecodeMsg(dc)
if err != nil {
b.Fatal(err)
}
}
}

View file

@ -37,6 +37,25 @@ func QueryHandler(db adapter.DBAdapter) http.Handler {
return return
} }
if query.Exec {
result, err := db.Exec(query)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/x-msgpack")
ra, _ := result.RowsAffected()
la, _ := result.LastInsertId()
res := proto.Response{
Id: 0,
RowsAffected: ra,
LastInsertId: la,
}
out, _ := res.MarshalMsg(nil)
w.Write(out)
return
}
rows, err := db.Query(query) rows, err := db.Query(query)
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)
@ -46,6 +65,7 @@ func QueryHandler(db adapter.DBAdapter) http.Handler {
w.Header().Set("X-Content-Type-Options", "nosniff") w.Header().Set("X-Content-Type-Options", "nosniff")
w.Header().Set("Content-Type", "application/x-msgpack") w.Header().Set("Content-Type", "application/x-msgpack")
flusher, ok := w.(http.Flusher) flusher, ok := w.(http.Flusher)
if !ok { if !ok {
http.Error(w, "expected http.ResponseWriter to be an http.Flusher", http.StatusInternalServerError) http.Error(w, "expected http.ResponseWriter to be an http.Flusher", http.StatusInternalServerError)

View file

@ -59,3 +59,38 @@ func TestQueryHandler(t *testing.T) {
result := proto.UnmarshalRows(res) result := proto.UnmarshalRows(res)
fmt.Println(result) fmt.Println(result)
} }
func TestQueryHandlerInsert(t *testing.T) {
adapter.RegisterDialect("sqlite3", adapter.CommonDialect{})
a := adapter.DBAdapter{Type: "sqlite3"}
db, err := a.Open("file::memory:?cache=shared")
if err != nil {
t.Fatalf("failed to open db: %v", err)
}
defer db.Close()
_, err = db.Exec("CREATE TABLE test (id INTEGER PRIMARY KEY, name BLOB, data TEXT)")
if err != nil {
t.Fatalf("failed to create table: %v", err)
}
data := proto.Request{
Id: 0,
Db: "file::memory:?cache=shared",
Commands: []proto.BuilderMethod{
{Method: "In", Args: []interface{}{"test t"}},
{Method: "Insert", Args: []interface{}{
map[string]interface{}{"name": "test", "data": "y"},
}},
},
}
body, _ := data.MarshalMsg(nil)
req, err := http.NewRequest("POST", "/", bytes.NewBuffer(body))
if err != nil {
t.Fatal(err)
}
rr := httptest.NewRecorder()
handler := QueryHandler(a)
handler.ServeHTTP(rr, req)
res, _ := io.ReadAll(rr.Result().Body)
result := proto.UnmarshalRows(res)
fmt.Println(result)
}