From 9d13824fdff7156fd8ee250dcb27d0bda415c153 Mon Sep 17 00:00:00 2001 From: Anton Nesterov Date: Thu, 15 Aug 2024 17:52:21 +0200 Subject: [PATCH] [wip] insert/update/delete operations Signed-off-by: Anton Nesterov --- dal/Builder.ts | 2 +- pkg/__test__/builder_test.go | 67 ++++++++++++- pkg/adapter/CommonDialect.go | 16 ++-- pkg/adapter/DBAdapter.go | 12 +++ pkg/adapter/types.go | 9 +- pkg/builder/Builder.go | 45 ++++++--- pkg/builder/Builder_test.go | 10 ++ pkg/builder/convert_update.go | 3 +- pkg/proto/request.go | 13 ++- pkg/proto/request_gen_test.go | 16 ++-- pkg/proto/response.go | 9 ++ pkg/proto/response_gen.go | 160 +++++++++++++++++++++++++++++++ pkg/proto/response_gen_test.go | 123 ++++++++++++++++++++++++ pkg/server/query_handler.go | 20 ++++ pkg/server/query_handler_test.go | 35 +++++++ 15 files changed, 502 insertions(+), 38 deletions(-) create mode 100644 pkg/proto/response.go create mode 100644 pkg/proto/response_gen.go create mode 100644 pkg/proto/response_gen_test.go diff --git a/dal/Builder.ts b/dal/Builder.ts index b7e0d7a..38eabca 100644 --- a/dal/Builder.ts +++ b/dal/Builder.ts @@ -61,7 +61,7 @@ export default class Builder any> { }); } private formatRow(data: unknown[]) { - if (!this.dtoTemplate) { + if (!this.dtoTemplate || this.dtoTemplate === Object) { return data; } const instance = new this.dtoTemplate(data); diff --git a/pkg/__test__/builder_test.go b/pkg/__test__/builder_test.go index 17414c9..7a71dc4 100644 --- a/pkg/__test__/builder_test.go +++ b/pkg/__test__/builder_test.go @@ -21,11 +21,12 @@ func TestBuilderBasic(t *testing.T) { t.Fatalf("failed to create table: %v", err) } - insert, values := builder.New(adapter.CommonDialect{}).In("test t").Insert([]builder.Map{ + inserts := []builder.Map{ {"name": "a"}, {"name": 'b'}, - }).Sql() - fmt.Println(insert, values) + } + insert, values := builder.New(adapter.CommonDialect{}).In("test t").Insert(inserts...).Sql() + _, err = db.Exec(insert, values...) if err != nil { t.Fatalf("failed to insert data: %v", err) @@ -52,3 +53,63 @@ func TestBuilderBasic(t *testing.T) { fmt.Printf("id: %d, name: %s\n", id, name) } } + +func TestBuilderSet(t *testing.T) { + a := adapter.DBAdapter{Type: "sqlite3"} + db, err := a.Open("file::memory:?cache=shared") + if err != nil { + t.Fatalf("failed to open db: %v", err) + } + defer db.Close() + _, err = db.Exec("CREATE TABLE test (id INTEGER PRIMARY KEY, name BLOB)") + if err != nil { + t.Fatalf("failed to create table: %v", err) + } + + inserts := []builder.Map{ + {"name": "a"}, + {"name": 'b'}, + } + insert, values := builder.New(adapter.CommonDialect{}).In("test t").Insert(inserts...).Sql() + + _, err = db.Exec(insert, values...) + if err != nil { + t.Fatalf("failed to insert data: %v", err) + } + + b := builder.New(adapter.CommonDialect{}).In("test t") + b.Find(builder.Find{"id": builder.Is{"$eq": 2}}) + b.Set(builder.Map{"name": "c"}) + b.Tx() + expr, values := b.Sql() + fmt.Println(expr, values) + _, err = a.Exec(adapter.Query{ + Db: "file::memory:?cache=shared", + Expression: expr, + Data: values, + Transaction: b.Transaction, + }) + if err != nil { + t.Fatalf("failed to query data: %v", err) + } + b = builder.New(adapter.CommonDialect{}).In("test t") + b.Find(builder.Find{}) + expr, values = b.Sql() + fmt.Println(expr, values) + rows, _ := a.Query(adapter.Query{ + Db: "file::memory:?cache=shared", + Expression: expr, + Data: values, + Transaction: b.Transaction, + }) + defer rows.Close() + for rows.Next() { + var id int + var name string + err = rows.Scan(&id, &name) + if err != nil { + t.Fatalf("failed to scan row: %v", err) + } + fmt.Printf("id: %d, name: %s\n", id, name) + } +} diff --git a/pkg/adapter/CommonDialect.go b/pkg/adapter/CommonDialect.go index 5228264..42e54c9 100644 --- a/pkg/adapter/CommonDialect.go +++ b/pkg/adapter/CommonDialect.go @@ -18,16 +18,16 @@ type CommonDialect struct { } func (c CommonDialect) New(opts DialectOpts) Dialect { - tn := opts["TableName"] - if tn == "" { + tn, ok := opts["TableName"] + if !ok { tn = c.TableName } - ta := opts["TableAlias"] - if ta == "" { + ta, ok := opts["TableAlias"] + if !ok { ta = c.TableAlias } - fn := opts["FieldName"] - if fn == "" { + fn, ok := opts["FieldName"] + if !ok { fn = c.FieldName } return CommonDialect{ @@ -41,6 +41,10 @@ func (c CommonDialect) GetTableName() string { return c.TableName } +func (c CommonDialect) GetTableAlias() string { + return c.TableAlias +} + func (c CommonDialect) GetFieldName() string { if strings.Contains(c.FieldName, ".") { return c.FieldName diff --git a/pkg/adapter/DBAdapter.go b/pkg/adapter/DBAdapter.go index 508bbe0..02e5525 100644 --- a/pkg/adapter/DBAdapter.go +++ b/pkg/adapter/DBAdapter.go @@ -110,6 +110,12 @@ func (a *DBAdapter) Query(req Query) (*sql.Rows, error) { if err != nil { return nil, err } + if req.Transaction { + tx, _ := db.Begin() + rows, err := tx.Query(req.Expression, req.Data...) + tx.Commit() + return rows, err + } sfmt, err := db.Prepare(req.Expression) if err != nil { return nil, err @@ -122,6 +128,12 @@ func (a *DBAdapter) Exec(req Query) (sql.Result, error) { if err != nil { return nil, err } + if req.Transaction { + tx, _ := db.Begin() + result, err := tx.Exec(req.Expression, req.Data...) + tx.Commit() + return result, err + } sfmt, err := db.Prepare(req.Expression) if err != nil { return nil, err diff --git a/pkg/adapter/types.go b/pkg/adapter/types.go index 631dfda..02eac73 100644 --- a/pkg/adapter/types.go +++ b/pkg/adapter/types.go @@ -1,9 +1,11 @@ package adapter type Query struct { - Db string `json:"db"` - Expression string `json:"expr"` - Data []interface{} `json:"data"` + Db string `json:"db"` + Expression string `json:"expr"` + Data []interface{} `json:"data"` + Transaction bool `json:"transaction"` + Exec bool `json:"exec"` } type DialectOpts map[string]string @@ -14,6 +16,7 @@ Dialect interface provides general utilities for normalizing values for particul type Dialect interface { New(opts DialectOpts) Dialect GetTableName() string + GetTableAlias() string GetFieldName() string GetColumnName(key string) string NormalizeValue(interface{}) interface{} diff --git a/pkg/builder/Builder.go b/pkg/builder/Builder.go index 16056f7..65da304 100644 --- a/pkg/builder/Builder.go +++ b/pkg/builder/Builder.go @@ -12,10 +12,12 @@ const ( ) type Builder struct { - TableName string - TableAlias string - Parts SQLParts - Dialect Dialect + TableName string + TableAlias string + Parts SQLParts + Dialect Dialect + LastQuery Find + Transaction bool } type SQLParts struct { @@ -56,19 +58,23 @@ func (b *Builder) In(table string) *Builder { } func (b *Builder) Find(query Find) *Builder { + b.LastQuery = query b.Parts.FiterExp, b.Parts.Values = covertFind( b.Dialect, query, ) - if b.Parts.Operation == "" { + if len(b.Parts.Operation) == 0 { b.Parts.Operation = "SELECT" } - if b.Parts.HavingExp == "" { + if len(b.Parts.HavingExp) == 0 { b.Parts.HavingExp = "WHERE" } - if b.Parts.FieldsExp == "" { + if len(b.Parts.FieldsExp) == 0 { b.Parts.FieldsExp = "*" } + if len(b.Parts.FiterExp) == 0 { + b.Parts.HavingExp = "" + } return b } @@ -118,7 +124,7 @@ func (b *Builder) Delete() *Builder { return b } -func (b *Builder) Insert(inserts []Map) *Builder { +func (b *Builder) Insert(inserts ...Map) *Builder { insertData, _ := convertInsert(b.Dialect, inserts) b.Parts = SQLParts{ Operation: "INSERT INTO", @@ -128,11 +134,13 @@ func (b *Builder) Insert(inserts []Map) *Builder { } func (b *Builder) Set(updates Map) *Builder { + b.Dialect = b.Dialect.New(DialectOpts{ + "TableAlias": "", + }) updateData := convertUpdate(b.Dialect, updates) - b.Parts = SQLParts{ - Operation: "UPDATE", - Update: updateData, - } + b.Find(b.LastQuery) + b.Parts.Operation = "UPDATE" + b.Parts.Update = updateData return b } @@ -168,6 +176,11 @@ func (b *Builder) DoNothing() *Builder { return b } +func (b *Builder) Tx() *Builder { + b.Transaction = true + return b +} + func (b *Builder) Sql() (string, []interface{}) { operation := b.Parts.Operation switch { @@ -202,11 +215,17 @@ func (b *Builder) Sql() (string, []interface{}) { case operation == "INSERT INTO": return b.Parts.Insert.Statement, b.Parts.Insert.Values case operation == "UPDATE": + values := append(b.Parts.Update.Values, b.Parts.Values...) return unspace(strings.Join([]string{ b.Parts.Update.Statement, + b.Parts.HavingExp, + b.Parts.FiterExp, + b.Parts.OrderExp, + b.Parts.LimitExp, + b.Parts.OffsetExp, b.Parts.Update.Upsert, b.Parts.Update.UpsertExp, - }, " ")), b.Parts.Update.Values + }, " ")), values default: return "", nil } diff --git a/pkg/builder/Builder_test.go b/pkg/builder/Builder_test.go index 0ef89db..b66cc5e 100644 --- a/pkg/builder/Builder_test.go +++ b/pkg/builder/Builder_test.go @@ -5,6 +5,16 @@ import ( "testing" ) +func TestBuilderFindEmpty(t *testing.T) { + db := New(CommonDialect{}) + db.In("table t").Find(Query{}).Limit(10) + expect := "SELECT * FROM table t LIMIT 10" + result, _ := db.Sql() + if result != expect { + t.Errorf(`Expected: "%s", Got: %s`, expect, result) + } +} + func TestBuilderFind(t *testing.T) { db := New(CommonDialect{}) db.In("table t").Find(Query{ diff --git a/pkg/builder/convert_update.go b/pkg/builder/convert_update.go index c20e809..d101dca 100644 --- a/pkg/builder/convert_update.go +++ b/pkg/builder/convert_update.go @@ -21,7 +21,8 @@ func convertUpdate(ctx Dialect, updates Map) UpdateData { values = append(values, updates[key]) } sfmt := fmt.Sprintf( - "UPDATE %s SET %s", ctx.GetTableName(), + "UPDATE %s SET %s", + ctx.GetTableName(), strings.Join(set, ","), ) return UpdateData{ diff --git a/pkg/proto/request.go b/pkg/proto/request.go index 2c497ad..9f0e238 100644 --- a/pkg/proto/request.go +++ b/pkg/proto/request.go @@ -33,6 +33,7 @@ func (q *Request) Parse(dialect adapter.Dialect) (adapter.Query, error) { return adapter.Query{}, fmt.Errorf("Request format: commands are required") } b := builder.New(dialect) + exec := false for _, cmd := range q.Commands { if !slices.Contains(allowedMethods, cmd.Method) { return adapter.Query{}, fmt.Errorf( @@ -45,16 +46,22 @@ func (q *Request) Parse(dialect adapter.Dialect) (adapter.Query, error) { if !method.IsValid() { return adapter.Query{}, fmt.Errorf("method %s not found", cmd.Method) } + if cmd.Method == "Insert" || cmd.Method == "Set" || cmd.Method == "Delete" { + exec = true + } args := make([]reflect.Value, len(cmd.Args)) for i, arg := range cmd.Args { args[i] = reflect.ValueOf(arg) } + fmt.Print(exec, cmd.Method, args) method.Call(args) } expr, data := b.Sql() return adapter.Query{ - Db: q.Db, - Expression: expr, - Data: data, + Db: q.Db, + Expression: expr, + Data: data, + Transaction: b.Transaction, + Exec: exec, }, nil } diff --git a/pkg/proto/request_gen_test.go b/pkg/proto/request_gen_test.go index 18037cf..7335413 100644 --- a/pkg/proto/request_gen_test.go +++ b/pkg/proto/request_gen_test.go @@ -9,7 +9,7 @@ import ( "github.com/tinylib/msgp/msgp" ) -func TestMarshalUnmarshalBuildCmd(t *testing.T) { +func TestMarshalUnmarshalBuilderMethod(t *testing.T) { v := BuilderMethod{} bts, err := v.MarshalMsg(nil) if err != nil { @@ -32,7 +32,7 @@ func TestMarshalUnmarshalBuildCmd(t *testing.T) { } } -func BenchmarkMarshalMsgBuildCmd(b *testing.B) { +func BenchmarkMarshalMsgBuilderMethod(b *testing.B) { v := BuilderMethod{} b.ReportAllocs() b.ResetTimer() @@ -41,7 +41,7 @@ func BenchmarkMarshalMsgBuildCmd(b *testing.B) { } } -func BenchmarkAppendMsgBuildCmd(b *testing.B) { +func BenchmarkAppendMsgBuilderMethod(b *testing.B) { v := BuilderMethod{} bts := make([]byte, 0, v.Msgsize()) bts, _ = v.MarshalMsg(bts[0:0]) @@ -53,7 +53,7 @@ func BenchmarkAppendMsgBuildCmd(b *testing.B) { } } -func BenchmarkUnmarshalBuildCmd(b *testing.B) { +func BenchmarkUnmarshalBuilderMethod(b *testing.B) { v := BuilderMethod{} bts, _ := v.MarshalMsg(nil) b.ReportAllocs() @@ -67,14 +67,14 @@ func BenchmarkUnmarshalBuildCmd(b *testing.B) { } } -func TestEncodeDecodeBuildCmd(t *testing.T) { +func TestEncodeDecodeBuilderMethod(t *testing.T) { v := BuilderMethod{} var buf bytes.Buffer msgp.Encode(&buf, &v) m := v.Msgsize() if buf.Len() > m { - t.Log("WARNING: TestEncodeDecodeBuildCmd Msgsize() is inaccurate") + t.Log("WARNING: TestEncodeDecodeBuilderMethod Msgsize() is inaccurate") } vn := BuilderMethod{} @@ -91,7 +91,7 @@ func TestEncodeDecodeBuildCmd(t *testing.T) { } } -func BenchmarkEncodeBuildCmd(b *testing.B) { +func BenchmarkEncodeBuilderMethod(b *testing.B) { v := BuilderMethod{} var buf bytes.Buffer msgp.Encode(&buf, &v) @@ -105,7 +105,7 @@ func BenchmarkEncodeBuildCmd(b *testing.B) { en.Flush() } -func BenchmarkDecodeBuildCmd(b *testing.B) { +func BenchmarkDecodeBuilderMethod(b *testing.B) { v := BuilderMethod{} var buf bytes.Buffer msgp.Encode(&buf, &v) diff --git a/pkg/proto/response.go b/pkg/proto/response.go new file mode 100644 index 0000000..70cf4dc --- /dev/null +++ b/pkg/proto/response.go @@ -0,0 +1,9 @@ +package proto + +//go:generate msgp + +type Response struct { + Id uint32 `msg:"i"` + RowsAffected int64 `msg:"ra"` + LastInsertId int64 `msg:"li"` +} diff --git a/pkg/proto/response_gen.go b/pkg/proto/response_gen.go new file mode 100644 index 0000000..2d150ee --- /dev/null +++ b/pkg/proto/response_gen.go @@ -0,0 +1,160 @@ +package proto + +// Code generated by github.com/tinylib/msgp DO NOT EDIT. + +import ( + "github.com/tinylib/msgp/msgp" +) + +// DecodeMsg implements msgp.Decodable +func (z *Response) DecodeMsg(dc *msgp.Reader) (err error) { + var field []byte + _ = field + var zb0001 uint32 + zb0001, err = dc.ReadMapHeader() + if err != nil { + err = msgp.WrapError(err) + return + } + for zb0001 > 0 { + zb0001-- + field, err = dc.ReadMapKeyPtr() + if err != nil { + err = msgp.WrapError(err) + return + } + switch msgp.UnsafeString(field) { + case "i": + z.Id, err = dc.ReadUint32() + if err != nil { + err = msgp.WrapError(err, "Id") + return + } + case "ra": + z.RowsAffected, err = dc.ReadInt64() + if err != nil { + err = msgp.WrapError(err, "RowsAffected") + return + } + case "li": + z.LastInsertId, err = dc.ReadInt64() + if err != nil { + err = msgp.WrapError(err, "LastInsertId") + return + } + default: + err = dc.Skip() + if err != nil { + err = msgp.WrapError(err) + return + } + } + } + return +} + +// EncodeMsg implements msgp.Encodable +func (z Response) EncodeMsg(en *msgp.Writer) (err error) { + // map header, size 3 + // write "i" + err = en.Append(0x83, 0xa1, 0x69) + if err != nil { + return + } + err = en.WriteUint32(z.Id) + if err != nil { + err = msgp.WrapError(err, "Id") + return + } + // write "ra" + err = en.Append(0xa2, 0x72, 0x61) + if err != nil { + return + } + err = en.WriteInt64(z.RowsAffected) + if err != nil { + err = msgp.WrapError(err, "RowsAffected") + return + } + // write "li" + err = en.Append(0xa2, 0x6c, 0x69) + if err != nil { + return + } + err = en.WriteInt64(z.LastInsertId) + if err != nil { + err = msgp.WrapError(err, "LastInsertId") + return + } + return +} + +// MarshalMsg implements msgp.Marshaler +func (z Response) MarshalMsg(b []byte) (o []byte, err error) { + o = msgp.Require(b, z.Msgsize()) + // map header, size 3 + // string "i" + o = append(o, 0x83, 0xa1, 0x69) + o = msgp.AppendUint32(o, z.Id) + // string "ra" + o = append(o, 0xa2, 0x72, 0x61) + o = msgp.AppendInt64(o, z.RowsAffected) + // string "li" + o = append(o, 0xa2, 0x6c, 0x69) + o = msgp.AppendInt64(o, z.LastInsertId) + return +} + +// UnmarshalMsg implements msgp.Unmarshaler +func (z *Response) UnmarshalMsg(bts []byte) (o []byte, err error) { + var field []byte + _ = field + var zb0001 uint32 + zb0001, bts, err = msgp.ReadMapHeaderBytes(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + for zb0001 > 0 { + zb0001-- + field, bts, err = msgp.ReadMapKeyZC(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + switch msgp.UnsafeString(field) { + case "i": + z.Id, bts, err = msgp.ReadUint32Bytes(bts) + if err != nil { + err = msgp.WrapError(err, "Id") + return + } + case "ra": + z.RowsAffected, bts, err = msgp.ReadInt64Bytes(bts) + if err != nil { + err = msgp.WrapError(err, "RowsAffected") + return + } + case "li": + z.LastInsertId, bts, err = msgp.ReadInt64Bytes(bts) + if err != nil { + err = msgp.WrapError(err, "LastInsertId") + return + } + default: + bts, err = msgp.Skip(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + } + } + o = bts + return +} + +// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message +func (z Response) Msgsize() (s int) { + s = 1 + 2 + msgp.Uint32Size + 3 + msgp.Int64Size + 3 + msgp.Int64Size + return +} diff --git a/pkg/proto/response_gen_test.go b/pkg/proto/response_gen_test.go new file mode 100644 index 0000000..ea186a0 --- /dev/null +++ b/pkg/proto/response_gen_test.go @@ -0,0 +1,123 @@ +package proto + +// Code generated by github.com/tinylib/msgp DO NOT EDIT. + +import ( + "bytes" + "testing" + + "github.com/tinylib/msgp/msgp" +) + +func TestMarshalUnmarshalResponse(t *testing.T) { + v := Response{} + bts, err := v.MarshalMsg(nil) + if err != nil { + t.Fatal(err) + } + left, err := v.UnmarshalMsg(bts) + if err != nil { + t.Fatal(err) + } + if len(left) > 0 { + t.Errorf("%d bytes left over after UnmarshalMsg(): %q", len(left), left) + } + + left, err = msgp.Skip(bts) + if err != nil { + t.Fatal(err) + } + if len(left) > 0 { + t.Errorf("%d bytes left over after Skip(): %q", len(left), left) + } +} + +func BenchmarkMarshalMsgResponse(b *testing.B) { + v := Response{} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + v.MarshalMsg(nil) + } +} + +func BenchmarkAppendMsgResponse(b *testing.B) { + v := Response{} + bts := make([]byte, 0, v.Msgsize()) + bts, _ = v.MarshalMsg(bts[0:0]) + b.SetBytes(int64(len(bts))) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + bts, _ = v.MarshalMsg(bts[0:0]) + } +} + +func BenchmarkUnmarshalResponse(b *testing.B) { + v := Response{} + bts, _ := v.MarshalMsg(nil) + b.ReportAllocs() + b.SetBytes(int64(len(bts))) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := v.UnmarshalMsg(bts) + if err != nil { + b.Fatal(err) + } + } +} + +func TestEncodeDecodeResponse(t *testing.T) { + v := Response{} + var buf bytes.Buffer + msgp.Encode(&buf, &v) + + m := v.Msgsize() + if buf.Len() > m { + t.Log("WARNING: TestEncodeDecodeResponse Msgsize() is inaccurate") + } + + vn := Response{} + err := msgp.Decode(&buf, &vn) + if err != nil { + t.Error(err) + } + + buf.Reset() + msgp.Encode(&buf, &v) + err = msgp.NewReader(&buf).Skip() + if err != nil { + t.Error(err) + } +} + +func BenchmarkEncodeResponse(b *testing.B) { + v := Response{} + var buf bytes.Buffer + msgp.Encode(&buf, &v) + b.SetBytes(int64(buf.Len())) + en := msgp.NewWriter(msgp.Nowhere) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + v.EncodeMsg(en) + } + en.Flush() +} + +func BenchmarkDecodeResponse(b *testing.B) { + v := Response{} + var buf bytes.Buffer + msgp.Encode(&buf, &v) + b.SetBytes(int64(buf.Len())) + rd := msgp.NewEndlessReader(buf.Bytes(), b) + dc := msgp.NewReader(rd) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + err := v.DecodeMsg(dc) + if err != nil { + b.Fatal(err) + } + } +} diff --git a/pkg/server/query_handler.go b/pkg/server/query_handler.go index a1738ff..11b2c4d 100644 --- a/pkg/server/query_handler.go +++ b/pkg/server/query_handler.go @@ -37,6 +37,25 @@ func QueryHandler(db adapter.DBAdapter) http.Handler { return } + if query.Exec { + result, err := db.Exec(query) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", "application/x-msgpack") + ra, _ := result.RowsAffected() + la, _ := result.LastInsertId() + res := proto.Response{ + Id: 0, + RowsAffected: ra, + LastInsertId: la, + } + out, _ := res.MarshalMsg(nil) + w.Write(out) + return + } + rows, err := db.Query(query) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) @@ -46,6 +65,7 @@ func QueryHandler(db adapter.DBAdapter) http.Handler { w.Header().Set("X-Content-Type-Options", "nosniff") w.Header().Set("Content-Type", "application/x-msgpack") + flusher, ok := w.(http.Flusher) if !ok { http.Error(w, "expected http.ResponseWriter to be an http.Flusher", http.StatusInternalServerError) diff --git a/pkg/server/query_handler_test.go b/pkg/server/query_handler_test.go index ee34ef2..e4241be 100644 --- a/pkg/server/query_handler_test.go +++ b/pkg/server/query_handler_test.go @@ -59,3 +59,38 @@ func TestQueryHandler(t *testing.T) { result := proto.UnmarshalRows(res) fmt.Println(result) } + +func TestQueryHandlerInsert(t *testing.T) { + adapter.RegisterDialect("sqlite3", adapter.CommonDialect{}) + a := adapter.DBAdapter{Type: "sqlite3"} + db, err := a.Open("file::memory:?cache=shared") + if err != nil { + t.Fatalf("failed to open db: %v", err) + } + defer db.Close() + _, err = db.Exec("CREATE TABLE test (id INTEGER PRIMARY KEY, name BLOB, data TEXT)") + if err != nil { + t.Fatalf("failed to create table: %v", err) + } + data := proto.Request{ + Id: 0, + Db: "file::memory:?cache=shared", + Commands: []proto.BuilderMethod{ + {Method: "In", Args: []interface{}{"test t"}}, + {Method: "Insert", Args: []interface{}{ + map[string]interface{}{"name": "test", "data": "y"}, + }}, + }, + } + body, _ := data.MarshalMsg(nil) + req, err := http.NewRequest("POST", "/", bytes.NewBuffer(body)) + if err != nil { + t.Fatal(err) + } + rr := httptest.NewRecorder() + handler := QueryHandler(a) + handler.ServeHTTP(rr, req) + res, _ := io.ReadAll(rr.Result().Body) + result := proto.UnmarshalRows(res) + fmt.Println(result) +}