From 476afeac50d39d7521a90b84e3e64a5a214c7971 Mon Sep 17 00:00:00 2001 From: Anton Nesterov Date: Tue, 13 Aug 2024 14:13:15 +0200 Subject: [PATCH] [wip] proto: validate against builder methods Signed-off-by: Anton Nesterov --- pkg/builder/Builder.go | 6 + pkg/proto/go.mod | 2 +- pkg/proto/request.go | 12 + pkg/proto/request_gen.go | 35 ++- pkg/proto/response.go | 14 ++ pkg/proto/response_gen.go | 412 +++++++++++++++++++++++++++++++++ pkg/proto/response_gen_test.go | 236 +++++++++++++++++++ 7 files changed, 711 insertions(+), 6 deletions(-) create mode 100644 pkg/proto/response.go create mode 100644 pkg/proto/response_gen.go create mode 100644 pkg/proto/response_gen_test.go diff --git a/pkg/builder/Builder.go b/pkg/builder/Builder.go index c8e0214..16056f7 100644 --- a/pkg/builder/Builder.go +++ b/pkg/builder/Builder.go @@ -5,6 +5,12 @@ import ( "strings" ) +const ( + BUILDER_VERSION = "0.0.1" + BUILDER_CLIENT_METHODS = "In|Find|Select|Fields|Join|Group|Sort|Limit|Offset|Delete|Insert|Set|Update|OnConflict|DoUpdate|DoNothing" + BUILDER_SERVER_METHODS = "Sql" +) + type Builder struct { TableName string TableAlias string diff --git a/pkg/proto/go.mod b/pkg/proto/go.mod index fb6793d..94de343 100644 --- a/pkg/proto/go.mod +++ b/pkg/proto/go.mod @@ -18,7 +18,7 @@ replace l12.xyz/dal/filters v0.0.0 => ../filters require ( github.com/pkg/errors v0.9.1 // indirect - l12.xyz/dal/adapter v0.0.0 // indirect + l12.xyz/dal/adapter v0.0.0 l12.xyz/dal/filters v0.0.0 // indirect l12.xyz/dal/utils v0.0.0 // indirect ) diff --git a/pkg/proto/request.go b/pkg/proto/request.go index 07d1b81..29d6cbd 100644 --- a/pkg/proto/request.go +++ b/pkg/proto/request.go @@ -3,6 +3,8 @@ package proto import ( "fmt" "reflect" + "slices" + "strings" "l12.xyz/dal/adapter" "l12.xyz/dal/builder" @@ -16,13 +18,23 @@ type BuildCmd struct { } type Request struct { + Id uint32 `msg:"id"` Db string `msg:"db"` Commands []BuildCmd `msg:"commands"` } +var allowedMethods = strings.Split(builder.BUILDER_CLIENT_METHODS, "|") + func (q *Request) Parse(dialect adapter.Dialect) (adapter.Query, error) { b := builder.New(dialect) for _, cmd := range q.Commands { + if !slices.Contains(allowedMethods, cmd.Method) { + return adapter.Query{}, fmt.Errorf( + "method %s is not allowed, awailable methods are %v", + cmd.Method, + allowedMethods, + ) + } method := reflect.ValueOf(b).MethodByName(cmd.Method) if !method.IsValid() { return adapter.Query{}, fmt.Errorf("method %s not found", cmd.Method) diff --git a/pkg/proto/request_gen.go b/pkg/proto/request_gen.go index e4bec82..668ee67 100644 --- a/pkg/proto/request_gen.go +++ b/pkg/proto/request_gen.go @@ -195,6 +195,12 @@ func (z *Request) DecodeMsg(dc *msgp.Reader) (err error) { return } switch msgp.UnsafeString(field) { + case "id": + z.Id, err = dc.ReadUint32() + if err != nil { + err = msgp.WrapError(err, "Id") + return + } case "db": z.Db, err = dc.ReadString() if err != nil { @@ -275,9 +281,19 @@ func (z *Request) DecodeMsg(dc *msgp.Reader) (err error) { // EncodeMsg implements msgp.Encodable func (z *Request) EncodeMsg(en *msgp.Writer) (err error) { - // map header, size 2 + // map header, size 3 + // write "id" + err = en.Append(0x83, 0xa2, 0x69, 0x64) + if err != nil { + return + } + err = en.WriteUint32(z.Id) + if err != nil { + err = msgp.WrapError(err, "Id") + return + } // write "db" - err = en.Append(0x82, 0xa2, 0x64, 0x62) + err = en.Append(0xa2, 0x64, 0x62) if err != nil { return } @@ -332,9 +348,12 @@ func (z *Request) EncodeMsg(en *msgp.Writer) (err error) { // MarshalMsg implements msgp.Marshaler func (z *Request) MarshalMsg(b []byte) (o []byte, err error) { o = msgp.Require(b, z.Msgsize()) - // map header, size 2 + // map header, size 3 + // string "id" + o = append(o, 0x83, 0xa2, 0x69, 0x64) + o = msgp.AppendUint32(o, z.Id) // string "db" - o = append(o, 0x82, 0xa2, 0x64, 0x62) + o = append(o, 0xa2, 0x64, 0x62) o = msgp.AppendString(o, z.Db) // string "commands" o = append(o, 0xa8, 0x63, 0x6f, 0x6d, 0x6d, 0x61, 0x6e, 0x64, 0x73) @@ -376,6 +395,12 @@ func (z *Request) UnmarshalMsg(bts []byte) (o []byte, err error) { return } switch msgp.UnsafeString(field) { + case "id": + z.Id, bts, err = msgp.ReadUint32Bytes(bts) + if err != nil { + err = msgp.WrapError(err, "Id") + return + } case "db": z.Db, bts, err = msgp.ReadStringBytes(bts) if err != nil { @@ -457,7 +482,7 @@ func (z *Request) UnmarshalMsg(bts []byte) (o []byte, err error) { // Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message func (z *Request) Msgsize() (s int) { - s = 1 + 3 + msgp.StringPrefixSize + len(z.Db) + 9 + msgp.ArrayHeaderSize + s = 1 + 3 + msgp.Uint32Size + 3 + msgp.StringPrefixSize + len(z.Db) + 9 + msgp.ArrayHeaderSize for za0001 := range z.Commands { s += 1 + 7 + msgp.StringPrefixSize + len(z.Commands[za0001].Method) + 5 + msgp.ArrayHeaderSize for za0002 := range z.Commands[za0001].Args { diff --git a/pkg/proto/response.go b/pkg/proto/response.go new file mode 100644 index 0000000..385a96f --- /dev/null +++ b/pkg/proto/response.go @@ -0,0 +1,14 @@ +package proto + +//go:generate msgp + +type RequestError struct { + Message string `msg:"msg"` + ErrorCode int `msg:"error_code"` +} + +type Response struct { + Id uint32 `msg:"id"` + Result []interface{} `msg:"result"` + Error RequestError `msg:"error"` +} diff --git a/pkg/proto/response_gen.go b/pkg/proto/response_gen.go new file mode 100644 index 0000000..3f8196b --- /dev/null +++ b/pkg/proto/response_gen.go @@ -0,0 +1,412 @@ +package proto + +// Code generated by github.com/tinylib/msgp DO NOT EDIT. + +import ( + "github.com/tinylib/msgp/msgp" +) + +// DecodeMsg implements msgp.Decodable +func (z *RequestError) DecodeMsg(dc *msgp.Reader) (err error) { + var field []byte + _ = field + var zb0001 uint32 + zb0001, err = dc.ReadMapHeader() + if err != nil { + err = msgp.WrapError(err) + return + } + for zb0001 > 0 { + zb0001-- + field, err = dc.ReadMapKeyPtr() + if err != nil { + err = msgp.WrapError(err) + return + } + switch msgp.UnsafeString(field) { + case "msg": + z.Message, err = dc.ReadString() + if err != nil { + err = msgp.WrapError(err, "Message") + return + } + case "error_code": + z.ErrorCode, err = dc.ReadInt() + if err != nil { + err = msgp.WrapError(err, "ErrorCode") + return + } + default: + err = dc.Skip() + if err != nil { + err = msgp.WrapError(err) + return + } + } + } + return +} + +// EncodeMsg implements msgp.Encodable +func (z RequestError) EncodeMsg(en *msgp.Writer) (err error) { + // map header, size 2 + // write "msg" + err = en.Append(0x82, 0xa3, 0x6d, 0x73, 0x67) + if err != nil { + return + } + err = en.WriteString(z.Message) + if err != nil { + err = msgp.WrapError(err, "Message") + return + } + // write "error_code" + err = en.Append(0xaa, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x5f, 0x63, 0x6f, 0x64, 0x65) + if err != nil { + return + } + err = en.WriteInt(z.ErrorCode) + if err != nil { + err = msgp.WrapError(err, "ErrorCode") + return + } + return +} + +// MarshalMsg implements msgp.Marshaler +func (z RequestError) MarshalMsg(b []byte) (o []byte, err error) { + o = msgp.Require(b, z.Msgsize()) + // map header, size 2 + // string "msg" + o = append(o, 0x82, 0xa3, 0x6d, 0x73, 0x67) + o = msgp.AppendString(o, z.Message) + // string "error_code" + o = append(o, 0xaa, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x5f, 0x63, 0x6f, 0x64, 0x65) + o = msgp.AppendInt(o, z.ErrorCode) + return +} + +// UnmarshalMsg implements msgp.Unmarshaler +func (z *RequestError) UnmarshalMsg(bts []byte) (o []byte, err error) { + var field []byte + _ = field + var zb0001 uint32 + zb0001, bts, err = msgp.ReadMapHeaderBytes(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + for zb0001 > 0 { + zb0001-- + field, bts, err = msgp.ReadMapKeyZC(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + switch msgp.UnsafeString(field) { + case "msg": + z.Message, bts, err = msgp.ReadStringBytes(bts) + if err != nil { + err = msgp.WrapError(err, "Message") + return + } + case "error_code": + z.ErrorCode, bts, err = msgp.ReadIntBytes(bts) + if err != nil { + err = msgp.WrapError(err, "ErrorCode") + return + } + default: + bts, err = msgp.Skip(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + } + } + o = bts + return +} + +// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message +func (z RequestError) Msgsize() (s int) { + s = 1 + 4 + msgp.StringPrefixSize + len(z.Message) + 11 + msgp.IntSize + return +} + +// DecodeMsg implements msgp.Decodable +func (z *Response) DecodeMsg(dc *msgp.Reader) (err error) { + var field []byte + _ = field + var zb0001 uint32 + zb0001, err = dc.ReadMapHeader() + if err != nil { + err = msgp.WrapError(err) + return + } + for zb0001 > 0 { + zb0001-- + field, err = dc.ReadMapKeyPtr() + if err != nil { + err = msgp.WrapError(err) + return + } + switch msgp.UnsafeString(field) { + case "id": + z.Id, err = dc.ReadUint32() + if err != nil { + err = msgp.WrapError(err, "Id") + return + } + case "result": + var zb0002 uint32 + zb0002, err = dc.ReadArrayHeader() + if err != nil { + err = msgp.WrapError(err, "Result") + return + } + if cap(z.Result) >= int(zb0002) { + z.Result = (z.Result)[:zb0002] + } else { + z.Result = make([]interface{}, zb0002) + } + for za0001 := range z.Result { + z.Result[za0001], err = dc.ReadIntf() + if err != nil { + err = msgp.WrapError(err, "Result", za0001) + return + } + } + case "error": + var zb0003 uint32 + zb0003, err = dc.ReadMapHeader() + if err != nil { + err = msgp.WrapError(err, "Error") + return + } + for zb0003 > 0 { + zb0003-- + field, err = dc.ReadMapKeyPtr() + if err != nil { + err = msgp.WrapError(err, "Error") + return + } + switch msgp.UnsafeString(field) { + case "msg": + z.Error.Message, err = dc.ReadString() + if err != nil { + err = msgp.WrapError(err, "Error", "Message") + return + } + case "error_code": + z.Error.ErrorCode, err = dc.ReadInt() + if err != nil { + err = msgp.WrapError(err, "Error", "ErrorCode") + return + } + default: + err = dc.Skip() + if err != nil { + err = msgp.WrapError(err, "Error") + return + } + } + } + default: + err = dc.Skip() + if err != nil { + err = msgp.WrapError(err) + return + } + } + } + return +} + +// EncodeMsg implements msgp.Encodable +func (z *Response) EncodeMsg(en *msgp.Writer) (err error) { + // map header, size 3 + // write "id" + err = en.Append(0x83, 0xa2, 0x69, 0x64) + if err != nil { + return + } + err = en.WriteUint32(z.Id) + if err != nil { + err = msgp.WrapError(err, "Id") + return + } + // write "result" + err = en.Append(0xa6, 0x72, 0x65, 0x73, 0x75, 0x6c, 0x74) + if err != nil { + return + } + err = en.WriteArrayHeader(uint32(len(z.Result))) + if err != nil { + err = msgp.WrapError(err, "Result") + return + } + for za0001 := range z.Result { + err = en.WriteIntf(z.Result[za0001]) + if err != nil { + err = msgp.WrapError(err, "Result", za0001) + return + } + } + // write "error" + err = en.Append(0xa5, 0x65, 0x72, 0x72, 0x6f, 0x72) + if err != nil { + return + } + // map header, size 2 + // write "msg" + err = en.Append(0x82, 0xa3, 0x6d, 0x73, 0x67) + if err != nil { + return + } + err = en.WriteString(z.Error.Message) + if err != nil { + err = msgp.WrapError(err, "Error", "Message") + return + } + // write "error_code" + err = en.Append(0xaa, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x5f, 0x63, 0x6f, 0x64, 0x65) + if err != nil { + return + } + err = en.WriteInt(z.Error.ErrorCode) + if err != nil { + err = msgp.WrapError(err, "Error", "ErrorCode") + return + } + return +} + +// MarshalMsg implements msgp.Marshaler +func (z *Response) MarshalMsg(b []byte) (o []byte, err error) { + o = msgp.Require(b, z.Msgsize()) + // map header, size 3 + // string "id" + o = append(o, 0x83, 0xa2, 0x69, 0x64) + o = msgp.AppendUint32(o, z.Id) + // string "result" + o = append(o, 0xa6, 0x72, 0x65, 0x73, 0x75, 0x6c, 0x74) + o = msgp.AppendArrayHeader(o, uint32(len(z.Result))) + for za0001 := range z.Result { + o, err = msgp.AppendIntf(o, z.Result[za0001]) + if err != nil { + err = msgp.WrapError(err, "Result", za0001) + return + } + } + // string "error" + o = append(o, 0xa5, 0x65, 0x72, 0x72, 0x6f, 0x72) + // map header, size 2 + // string "msg" + o = append(o, 0x82, 0xa3, 0x6d, 0x73, 0x67) + o = msgp.AppendString(o, z.Error.Message) + // string "error_code" + o = append(o, 0xaa, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x5f, 0x63, 0x6f, 0x64, 0x65) + o = msgp.AppendInt(o, z.Error.ErrorCode) + return +} + +// UnmarshalMsg implements msgp.Unmarshaler +func (z *Response) UnmarshalMsg(bts []byte) (o []byte, err error) { + var field []byte + _ = field + var zb0001 uint32 + zb0001, bts, err = msgp.ReadMapHeaderBytes(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + for zb0001 > 0 { + zb0001-- + field, bts, err = msgp.ReadMapKeyZC(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + switch msgp.UnsafeString(field) { + case "id": + z.Id, bts, err = msgp.ReadUint32Bytes(bts) + if err != nil { + err = msgp.WrapError(err, "Id") + return + } + case "result": + var zb0002 uint32 + zb0002, bts, err = msgp.ReadArrayHeaderBytes(bts) + if err != nil { + err = msgp.WrapError(err, "Result") + return + } + if cap(z.Result) >= int(zb0002) { + z.Result = (z.Result)[:zb0002] + } else { + z.Result = make([]interface{}, zb0002) + } + for za0001 := range z.Result { + z.Result[za0001], bts, err = msgp.ReadIntfBytes(bts) + if err != nil { + err = msgp.WrapError(err, "Result", za0001) + return + } + } + case "error": + var zb0003 uint32 + zb0003, bts, err = msgp.ReadMapHeaderBytes(bts) + if err != nil { + err = msgp.WrapError(err, "Error") + return + } + for zb0003 > 0 { + zb0003-- + field, bts, err = msgp.ReadMapKeyZC(bts) + if err != nil { + err = msgp.WrapError(err, "Error") + return + } + switch msgp.UnsafeString(field) { + case "msg": + z.Error.Message, bts, err = msgp.ReadStringBytes(bts) + if err != nil { + err = msgp.WrapError(err, "Error", "Message") + return + } + case "error_code": + z.Error.ErrorCode, bts, err = msgp.ReadIntBytes(bts) + if err != nil { + err = msgp.WrapError(err, "Error", "ErrorCode") + return + } + default: + bts, err = msgp.Skip(bts) + if err != nil { + err = msgp.WrapError(err, "Error") + return + } + } + } + default: + bts, err = msgp.Skip(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + } + } + o = bts + return +} + +// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message +func (z *Response) Msgsize() (s int) { + s = 1 + 3 + msgp.Uint32Size + 7 + msgp.ArrayHeaderSize + for za0001 := range z.Result { + s += msgp.GuessSize(z.Result[za0001]) + } + s += 6 + 1 + 4 + msgp.StringPrefixSize + len(z.Error.Message) + 11 + msgp.IntSize + return +} diff --git a/pkg/proto/response_gen_test.go b/pkg/proto/response_gen_test.go new file mode 100644 index 0000000..02380f7 --- /dev/null +++ b/pkg/proto/response_gen_test.go @@ -0,0 +1,236 @@ +package proto + +// Code generated by github.com/tinylib/msgp DO NOT EDIT. + +import ( + "bytes" + "testing" + + "github.com/tinylib/msgp/msgp" +) + +func TestMarshalUnmarshalRequestError(t *testing.T) { + v := RequestError{} + bts, err := v.MarshalMsg(nil) + if err != nil { + t.Fatal(err) + } + left, err := v.UnmarshalMsg(bts) + if err != nil { + t.Fatal(err) + } + if len(left) > 0 { + t.Errorf("%d bytes left over after UnmarshalMsg(): %q", len(left), left) + } + + left, err = msgp.Skip(bts) + if err != nil { + t.Fatal(err) + } + if len(left) > 0 { + t.Errorf("%d bytes left over after Skip(): %q", len(left), left) + } +} + +func BenchmarkMarshalMsgRequestError(b *testing.B) { + v := RequestError{} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + v.MarshalMsg(nil) + } +} + +func BenchmarkAppendMsgRequestError(b *testing.B) { + v := RequestError{} + bts := make([]byte, 0, v.Msgsize()) + bts, _ = v.MarshalMsg(bts[0:0]) + b.SetBytes(int64(len(bts))) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + bts, _ = v.MarshalMsg(bts[0:0]) + } +} + +func BenchmarkUnmarshalRequestError(b *testing.B) { + v := RequestError{} + bts, _ := v.MarshalMsg(nil) + b.ReportAllocs() + b.SetBytes(int64(len(bts))) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := v.UnmarshalMsg(bts) + if err != nil { + b.Fatal(err) + } + } +} + +func TestEncodeDecodeRequestError(t *testing.T) { + v := RequestError{} + var buf bytes.Buffer + msgp.Encode(&buf, &v) + + m := v.Msgsize() + if buf.Len() > m { + t.Log("WARNING: TestEncodeDecodeRequestError Msgsize() is inaccurate") + } + + vn := RequestError{} + err := msgp.Decode(&buf, &vn) + if err != nil { + t.Error(err) + } + + buf.Reset() + msgp.Encode(&buf, &v) + err = msgp.NewReader(&buf).Skip() + if err != nil { + t.Error(err) + } +} + +func BenchmarkEncodeRequestError(b *testing.B) { + v := RequestError{} + var buf bytes.Buffer + msgp.Encode(&buf, &v) + b.SetBytes(int64(buf.Len())) + en := msgp.NewWriter(msgp.Nowhere) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + v.EncodeMsg(en) + } + en.Flush() +} + +func BenchmarkDecodeRequestError(b *testing.B) { + v := RequestError{} + var buf bytes.Buffer + msgp.Encode(&buf, &v) + b.SetBytes(int64(buf.Len())) + rd := msgp.NewEndlessReader(buf.Bytes(), b) + dc := msgp.NewReader(rd) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + err := v.DecodeMsg(dc) + if err != nil { + b.Fatal(err) + } + } +} + +func TestMarshalUnmarshalResponse(t *testing.T) { + v := Response{} + bts, err := v.MarshalMsg(nil) + if err != nil { + t.Fatal(err) + } + left, err := v.UnmarshalMsg(bts) + if err != nil { + t.Fatal(err) + } + if len(left) > 0 { + t.Errorf("%d bytes left over after UnmarshalMsg(): %q", len(left), left) + } + + left, err = msgp.Skip(bts) + if err != nil { + t.Fatal(err) + } + if len(left) > 0 { + t.Errorf("%d bytes left over after Skip(): %q", len(left), left) + } +} + +func BenchmarkMarshalMsgResponse(b *testing.B) { + v := Response{} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + v.MarshalMsg(nil) + } +} + +func BenchmarkAppendMsgResponse(b *testing.B) { + v := Response{} + bts := make([]byte, 0, v.Msgsize()) + bts, _ = v.MarshalMsg(bts[0:0]) + b.SetBytes(int64(len(bts))) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + bts, _ = v.MarshalMsg(bts[0:0]) + } +} + +func BenchmarkUnmarshalResponse(b *testing.B) { + v := Response{} + bts, _ := v.MarshalMsg(nil) + b.ReportAllocs() + b.SetBytes(int64(len(bts))) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := v.UnmarshalMsg(bts) + if err != nil { + b.Fatal(err) + } + } +} + +func TestEncodeDecodeResponse(t *testing.T) { + v := Response{} + var buf bytes.Buffer + msgp.Encode(&buf, &v) + + m := v.Msgsize() + if buf.Len() > m { + t.Log("WARNING: TestEncodeDecodeResponse Msgsize() is inaccurate") + } + + vn := Response{} + err := msgp.Decode(&buf, &vn) + if err != nil { + t.Error(err) + } + + buf.Reset() + msgp.Encode(&buf, &v) + err = msgp.NewReader(&buf).Skip() + if err != nil { + t.Error(err) + } +} + +func BenchmarkEncodeResponse(b *testing.B) { + v := Response{} + var buf bytes.Buffer + msgp.Encode(&buf, &v) + b.SetBytes(int64(buf.Len())) + en := msgp.NewWriter(msgp.Nowhere) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + v.EncodeMsg(en) + } + en.Flush() +} + +func BenchmarkDecodeResponse(b *testing.B) { + v := Response{} + var buf bytes.Buffer + msgp.Encode(&buf, &v) + b.SetBytes(int64(buf.Len())) + rd := msgp.NewEndlessReader(buf.Bytes(), b) + dc := msgp.NewReader(rd) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + err := v.DecodeMsg(dc) + if err != nil { + b.Fatal(err) + } + } +}