[wip] proto: validate against builder methods

Signed-off-by: Anton Nesterov <anton@demiurg.io>
This commit is contained in:
Anton Nesterov 2024-08-13 14:13:15 +02:00
parent c08aa3e104
commit 476afeac50
No known key found for this signature in database
GPG key ID: 59121E8AE2851FB5
7 changed files with 711 additions and 6 deletions

View file

@ -5,6 +5,12 @@ import (
"strings" "strings"
) )
const (
BUILDER_VERSION = "0.0.1"
BUILDER_CLIENT_METHODS = "In|Find|Select|Fields|Join|Group|Sort|Limit|Offset|Delete|Insert|Set|Update|OnConflict|DoUpdate|DoNothing"
BUILDER_SERVER_METHODS = "Sql"
)
type Builder struct { type Builder struct {
TableName string TableName string
TableAlias string TableAlias string

View file

@ -18,7 +18,7 @@ replace l12.xyz/dal/filters v0.0.0 => ../filters
require ( require (
github.com/pkg/errors v0.9.1 // indirect github.com/pkg/errors v0.9.1 // indirect
l12.xyz/dal/adapter v0.0.0 // indirect l12.xyz/dal/adapter v0.0.0
l12.xyz/dal/filters v0.0.0 // indirect l12.xyz/dal/filters v0.0.0 // indirect
l12.xyz/dal/utils v0.0.0 // indirect l12.xyz/dal/utils v0.0.0 // indirect
) )

View file

@ -3,6 +3,8 @@ package proto
import ( import (
"fmt" "fmt"
"reflect" "reflect"
"slices"
"strings"
"l12.xyz/dal/adapter" "l12.xyz/dal/adapter"
"l12.xyz/dal/builder" "l12.xyz/dal/builder"
@ -16,13 +18,23 @@ type BuildCmd struct {
} }
type Request struct { type Request struct {
Id uint32 `msg:"id"`
Db string `msg:"db"` Db string `msg:"db"`
Commands []BuildCmd `msg:"commands"` Commands []BuildCmd `msg:"commands"`
} }
var allowedMethods = strings.Split(builder.BUILDER_CLIENT_METHODS, "|")
func (q *Request) Parse(dialect adapter.Dialect) (adapter.Query, error) { func (q *Request) Parse(dialect adapter.Dialect) (adapter.Query, error) {
b := builder.New(dialect) b := builder.New(dialect)
for _, cmd := range q.Commands { for _, cmd := range q.Commands {
if !slices.Contains(allowedMethods, cmd.Method) {
return adapter.Query{}, fmt.Errorf(
"method %s is not allowed, awailable methods are %v",
cmd.Method,
allowedMethods,
)
}
method := reflect.ValueOf(b).MethodByName(cmd.Method) method := reflect.ValueOf(b).MethodByName(cmd.Method)
if !method.IsValid() { if !method.IsValid() {
return adapter.Query{}, fmt.Errorf("method %s not found", cmd.Method) return adapter.Query{}, fmt.Errorf("method %s not found", cmd.Method)

View file

@ -195,6 +195,12 @@ func (z *Request) DecodeMsg(dc *msgp.Reader) (err error) {
return return
} }
switch msgp.UnsafeString(field) { switch msgp.UnsafeString(field) {
case "id":
z.Id, err = dc.ReadUint32()
if err != nil {
err = msgp.WrapError(err, "Id")
return
}
case "db": case "db":
z.Db, err = dc.ReadString() z.Db, err = dc.ReadString()
if err != nil { if err != nil {
@ -275,9 +281,19 @@ func (z *Request) DecodeMsg(dc *msgp.Reader) (err error) {
// EncodeMsg implements msgp.Encodable // EncodeMsg implements msgp.Encodable
func (z *Request) EncodeMsg(en *msgp.Writer) (err error) { func (z *Request) EncodeMsg(en *msgp.Writer) (err error) {
// map header, size 2 // map header, size 3
// write "id"
err = en.Append(0x83, 0xa2, 0x69, 0x64)
if err != nil {
return
}
err = en.WriteUint32(z.Id)
if err != nil {
err = msgp.WrapError(err, "Id")
return
}
// write "db" // write "db"
err = en.Append(0x82, 0xa2, 0x64, 0x62) err = en.Append(0xa2, 0x64, 0x62)
if err != nil { if err != nil {
return return
} }
@ -332,9 +348,12 @@ func (z *Request) EncodeMsg(en *msgp.Writer) (err error) {
// MarshalMsg implements msgp.Marshaler // MarshalMsg implements msgp.Marshaler
func (z *Request) MarshalMsg(b []byte) (o []byte, err error) { func (z *Request) MarshalMsg(b []byte) (o []byte, err error) {
o = msgp.Require(b, z.Msgsize()) o = msgp.Require(b, z.Msgsize())
// map header, size 2 // map header, size 3
// string "id"
o = append(o, 0x83, 0xa2, 0x69, 0x64)
o = msgp.AppendUint32(o, z.Id)
// string "db" // string "db"
o = append(o, 0x82, 0xa2, 0x64, 0x62) o = append(o, 0xa2, 0x64, 0x62)
o = msgp.AppendString(o, z.Db) o = msgp.AppendString(o, z.Db)
// string "commands" // string "commands"
o = append(o, 0xa8, 0x63, 0x6f, 0x6d, 0x6d, 0x61, 0x6e, 0x64, 0x73) o = append(o, 0xa8, 0x63, 0x6f, 0x6d, 0x6d, 0x61, 0x6e, 0x64, 0x73)
@ -376,6 +395,12 @@ func (z *Request) UnmarshalMsg(bts []byte) (o []byte, err error) {
return return
} }
switch msgp.UnsafeString(field) { switch msgp.UnsafeString(field) {
case "id":
z.Id, bts, err = msgp.ReadUint32Bytes(bts)
if err != nil {
err = msgp.WrapError(err, "Id")
return
}
case "db": case "db":
z.Db, bts, err = msgp.ReadStringBytes(bts) z.Db, bts, err = msgp.ReadStringBytes(bts)
if err != nil { if err != nil {
@ -457,7 +482,7 @@ func (z *Request) UnmarshalMsg(bts []byte) (o []byte, err error) {
// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message // Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message
func (z *Request) Msgsize() (s int) { func (z *Request) Msgsize() (s int) {
s = 1 + 3 + msgp.StringPrefixSize + len(z.Db) + 9 + msgp.ArrayHeaderSize s = 1 + 3 + msgp.Uint32Size + 3 + msgp.StringPrefixSize + len(z.Db) + 9 + msgp.ArrayHeaderSize
for za0001 := range z.Commands { for za0001 := range z.Commands {
s += 1 + 7 + msgp.StringPrefixSize + len(z.Commands[za0001].Method) + 5 + msgp.ArrayHeaderSize s += 1 + 7 + msgp.StringPrefixSize + len(z.Commands[za0001].Method) + 5 + msgp.ArrayHeaderSize
for za0002 := range z.Commands[za0001].Args { for za0002 := range z.Commands[za0001].Args {

14
pkg/proto/response.go Normal file
View file

@ -0,0 +1,14 @@
package proto
//go:generate msgp
type RequestError struct {
Message string `msg:"msg"`
ErrorCode int `msg:"error_code"`
}
type Response struct {
Id uint32 `msg:"id"`
Result []interface{} `msg:"result"`
Error RequestError `msg:"error"`
}

412
pkg/proto/response_gen.go Normal file
View file

@ -0,0 +1,412 @@
package proto
// Code generated by github.com/tinylib/msgp DO NOT EDIT.
import (
"github.com/tinylib/msgp/msgp"
)
// DecodeMsg implements msgp.Decodable
func (z *RequestError) DecodeMsg(dc *msgp.Reader) (err error) {
var field []byte
_ = field
var zb0001 uint32
zb0001, err = dc.ReadMapHeader()
if err != nil {
err = msgp.WrapError(err)
return
}
for zb0001 > 0 {
zb0001--
field, err = dc.ReadMapKeyPtr()
if err != nil {
err = msgp.WrapError(err)
return
}
switch msgp.UnsafeString(field) {
case "msg":
z.Message, err = dc.ReadString()
if err != nil {
err = msgp.WrapError(err, "Message")
return
}
case "error_code":
z.ErrorCode, err = dc.ReadInt()
if err != nil {
err = msgp.WrapError(err, "ErrorCode")
return
}
default:
err = dc.Skip()
if err != nil {
err = msgp.WrapError(err)
return
}
}
}
return
}
// EncodeMsg implements msgp.Encodable
func (z RequestError) EncodeMsg(en *msgp.Writer) (err error) {
// map header, size 2
// write "msg"
err = en.Append(0x82, 0xa3, 0x6d, 0x73, 0x67)
if err != nil {
return
}
err = en.WriteString(z.Message)
if err != nil {
err = msgp.WrapError(err, "Message")
return
}
// write "error_code"
err = en.Append(0xaa, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x5f, 0x63, 0x6f, 0x64, 0x65)
if err != nil {
return
}
err = en.WriteInt(z.ErrorCode)
if err != nil {
err = msgp.WrapError(err, "ErrorCode")
return
}
return
}
// MarshalMsg implements msgp.Marshaler
func (z RequestError) MarshalMsg(b []byte) (o []byte, err error) {
o = msgp.Require(b, z.Msgsize())
// map header, size 2
// string "msg"
o = append(o, 0x82, 0xa3, 0x6d, 0x73, 0x67)
o = msgp.AppendString(o, z.Message)
// string "error_code"
o = append(o, 0xaa, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x5f, 0x63, 0x6f, 0x64, 0x65)
o = msgp.AppendInt(o, z.ErrorCode)
return
}
// UnmarshalMsg implements msgp.Unmarshaler
func (z *RequestError) UnmarshalMsg(bts []byte) (o []byte, err error) {
var field []byte
_ = field
var zb0001 uint32
zb0001, bts, err = msgp.ReadMapHeaderBytes(bts)
if err != nil {
err = msgp.WrapError(err)
return
}
for zb0001 > 0 {
zb0001--
field, bts, err = msgp.ReadMapKeyZC(bts)
if err != nil {
err = msgp.WrapError(err)
return
}
switch msgp.UnsafeString(field) {
case "msg":
z.Message, bts, err = msgp.ReadStringBytes(bts)
if err != nil {
err = msgp.WrapError(err, "Message")
return
}
case "error_code":
z.ErrorCode, bts, err = msgp.ReadIntBytes(bts)
if err != nil {
err = msgp.WrapError(err, "ErrorCode")
return
}
default:
bts, err = msgp.Skip(bts)
if err != nil {
err = msgp.WrapError(err)
return
}
}
}
o = bts
return
}
// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message
func (z RequestError) Msgsize() (s int) {
s = 1 + 4 + msgp.StringPrefixSize + len(z.Message) + 11 + msgp.IntSize
return
}
// DecodeMsg implements msgp.Decodable
func (z *Response) DecodeMsg(dc *msgp.Reader) (err error) {
var field []byte
_ = field
var zb0001 uint32
zb0001, err = dc.ReadMapHeader()
if err != nil {
err = msgp.WrapError(err)
return
}
for zb0001 > 0 {
zb0001--
field, err = dc.ReadMapKeyPtr()
if err != nil {
err = msgp.WrapError(err)
return
}
switch msgp.UnsafeString(field) {
case "id":
z.Id, err = dc.ReadUint32()
if err != nil {
err = msgp.WrapError(err, "Id")
return
}
case "result":
var zb0002 uint32
zb0002, err = dc.ReadArrayHeader()
if err != nil {
err = msgp.WrapError(err, "Result")
return
}
if cap(z.Result) >= int(zb0002) {
z.Result = (z.Result)[:zb0002]
} else {
z.Result = make([]interface{}, zb0002)
}
for za0001 := range z.Result {
z.Result[za0001], err = dc.ReadIntf()
if err != nil {
err = msgp.WrapError(err, "Result", za0001)
return
}
}
case "error":
var zb0003 uint32
zb0003, err = dc.ReadMapHeader()
if err != nil {
err = msgp.WrapError(err, "Error")
return
}
for zb0003 > 0 {
zb0003--
field, err = dc.ReadMapKeyPtr()
if err != nil {
err = msgp.WrapError(err, "Error")
return
}
switch msgp.UnsafeString(field) {
case "msg":
z.Error.Message, err = dc.ReadString()
if err != nil {
err = msgp.WrapError(err, "Error", "Message")
return
}
case "error_code":
z.Error.ErrorCode, err = dc.ReadInt()
if err != nil {
err = msgp.WrapError(err, "Error", "ErrorCode")
return
}
default:
err = dc.Skip()
if err != nil {
err = msgp.WrapError(err, "Error")
return
}
}
}
default:
err = dc.Skip()
if err != nil {
err = msgp.WrapError(err)
return
}
}
}
return
}
// EncodeMsg implements msgp.Encodable
func (z *Response) EncodeMsg(en *msgp.Writer) (err error) {
// map header, size 3
// write "id"
err = en.Append(0x83, 0xa2, 0x69, 0x64)
if err != nil {
return
}
err = en.WriteUint32(z.Id)
if err != nil {
err = msgp.WrapError(err, "Id")
return
}
// write "result"
err = en.Append(0xa6, 0x72, 0x65, 0x73, 0x75, 0x6c, 0x74)
if err != nil {
return
}
err = en.WriteArrayHeader(uint32(len(z.Result)))
if err != nil {
err = msgp.WrapError(err, "Result")
return
}
for za0001 := range z.Result {
err = en.WriteIntf(z.Result[za0001])
if err != nil {
err = msgp.WrapError(err, "Result", za0001)
return
}
}
// write "error"
err = en.Append(0xa5, 0x65, 0x72, 0x72, 0x6f, 0x72)
if err != nil {
return
}
// map header, size 2
// write "msg"
err = en.Append(0x82, 0xa3, 0x6d, 0x73, 0x67)
if err != nil {
return
}
err = en.WriteString(z.Error.Message)
if err != nil {
err = msgp.WrapError(err, "Error", "Message")
return
}
// write "error_code"
err = en.Append(0xaa, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x5f, 0x63, 0x6f, 0x64, 0x65)
if err != nil {
return
}
err = en.WriteInt(z.Error.ErrorCode)
if err != nil {
err = msgp.WrapError(err, "Error", "ErrorCode")
return
}
return
}
// MarshalMsg implements msgp.Marshaler
func (z *Response) MarshalMsg(b []byte) (o []byte, err error) {
o = msgp.Require(b, z.Msgsize())
// map header, size 3
// string "id"
o = append(o, 0x83, 0xa2, 0x69, 0x64)
o = msgp.AppendUint32(o, z.Id)
// string "result"
o = append(o, 0xa6, 0x72, 0x65, 0x73, 0x75, 0x6c, 0x74)
o = msgp.AppendArrayHeader(o, uint32(len(z.Result)))
for za0001 := range z.Result {
o, err = msgp.AppendIntf(o, z.Result[za0001])
if err != nil {
err = msgp.WrapError(err, "Result", za0001)
return
}
}
// string "error"
o = append(o, 0xa5, 0x65, 0x72, 0x72, 0x6f, 0x72)
// map header, size 2
// string "msg"
o = append(o, 0x82, 0xa3, 0x6d, 0x73, 0x67)
o = msgp.AppendString(o, z.Error.Message)
// string "error_code"
o = append(o, 0xaa, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x5f, 0x63, 0x6f, 0x64, 0x65)
o = msgp.AppendInt(o, z.Error.ErrorCode)
return
}
// UnmarshalMsg implements msgp.Unmarshaler
func (z *Response) UnmarshalMsg(bts []byte) (o []byte, err error) {
var field []byte
_ = field
var zb0001 uint32
zb0001, bts, err = msgp.ReadMapHeaderBytes(bts)
if err != nil {
err = msgp.WrapError(err)
return
}
for zb0001 > 0 {
zb0001--
field, bts, err = msgp.ReadMapKeyZC(bts)
if err != nil {
err = msgp.WrapError(err)
return
}
switch msgp.UnsafeString(field) {
case "id":
z.Id, bts, err = msgp.ReadUint32Bytes(bts)
if err != nil {
err = msgp.WrapError(err, "Id")
return
}
case "result":
var zb0002 uint32
zb0002, bts, err = msgp.ReadArrayHeaderBytes(bts)
if err != nil {
err = msgp.WrapError(err, "Result")
return
}
if cap(z.Result) >= int(zb0002) {
z.Result = (z.Result)[:zb0002]
} else {
z.Result = make([]interface{}, zb0002)
}
for za0001 := range z.Result {
z.Result[za0001], bts, err = msgp.ReadIntfBytes(bts)
if err != nil {
err = msgp.WrapError(err, "Result", za0001)
return
}
}
case "error":
var zb0003 uint32
zb0003, bts, err = msgp.ReadMapHeaderBytes(bts)
if err != nil {
err = msgp.WrapError(err, "Error")
return
}
for zb0003 > 0 {
zb0003--
field, bts, err = msgp.ReadMapKeyZC(bts)
if err != nil {
err = msgp.WrapError(err, "Error")
return
}
switch msgp.UnsafeString(field) {
case "msg":
z.Error.Message, bts, err = msgp.ReadStringBytes(bts)
if err != nil {
err = msgp.WrapError(err, "Error", "Message")
return
}
case "error_code":
z.Error.ErrorCode, bts, err = msgp.ReadIntBytes(bts)
if err != nil {
err = msgp.WrapError(err, "Error", "ErrorCode")
return
}
default:
bts, err = msgp.Skip(bts)
if err != nil {
err = msgp.WrapError(err, "Error")
return
}
}
}
default:
bts, err = msgp.Skip(bts)
if err != nil {
err = msgp.WrapError(err)
return
}
}
}
o = bts
return
}
// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message
func (z *Response) Msgsize() (s int) {
s = 1 + 3 + msgp.Uint32Size + 7 + msgp.ArrayHeaderSize
for za0001 := range z.Result {
s += msgp.GuessSize(z.Result[za0001])
}
s += 6 + 1 + 4 + msgp.StringPrefixSize + len(z.Error.Message) + 11 + msgp.IntSize
return
}

View file

@ -0,0 +1,236 @@
package proto
// Code generated by github.com/tinylib/msgp DO NOT EDIT.
import (
"bytes"
"testing"
"github.com/tinylib/msgp/msgp"
)
func TestMarshalUnmarshalRequestError(t *testing.T) {
v := RequestError{}
bts, err := v.MarshalMsg(nil)
if err != nil {
t.Fatal(err)
}
left, err := v.UnmarshalMsg(bts)
if err != nil {
t.Fatal(err)
}
if len(left) > 0 {
t.Errorf("%d bytes left over after UnmarshalMsg(): %q", len(left), left)
}
left, err = msgp.Skip(bts)
if err != nil {
t.Fatal(err)
}
if len(left) > 0 {
t.Errorf("%d bytes left over after Skip(): %q", len(left), left)
}
}
func BenchmarkMarshalMsgRequestError(b *testing.B) {
v := RequestError{}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
v.MarshalMsg(nil)
}
}
func BenchmarkAppendMsgRequestError(b *testing.B) {
v := RequestError{}
bts := make([]byte, 0, v.Msgsize())
bts, _ = v.MarshalMsg(bts[0:0])
b.SetBytes(int64(len(bts)))
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
bts, _ = v.MarshalMsg(bts[0:0])
}
}
func BenchmarkUnmarshalRequestError(b *testing.B) {
v := RequestError{}
bts, _ := v.MarshalMsg(nil)
b.ReportAllocs()
b.SetBytes(int64(len(bts)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := v.UnmarshalMsg(bts)
if err != nil {
b.Fatal(err)
}
}
}
func TestEncodeDecodeRequestError(t *testing.T) {
v := RequestError{}
var buf bytes.Buffer
msgp.Encode(&buf, &v)
m := v.Msgsize()
if buf.Len() > m {
t.Log("WARNING: TestEncodeDecodeRequestError Msgsize() is inaccurate")
}
vn := RequestError{}
err := msgp.Decode(&buf, &vn)
if err != nil {
t.Error(err)
}
buf.Reset()
msgp.Encode(&buf, &v)
err = msgp.NewReader(&buf).Skip()
if err != nil {
t.Error(err)
}
}
func BenchmarkEncodeRequestError(b *testing.B) {
v := RequestError{}
var buf bytes.Buffer
msgp.Encode(&buf, &v)
b.SetBytes(int64(buf.Len()))
en := msgp.NewWriter(msgp.Nowhere)
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
v.EncodeMsg(en)
}
en.Flush()
}
func BenchmarkDecodeRequestError(b *testing.B) {
v := RequestError{}
var buf bytes.Buffer
msgp.Encode(&buf, &v)
b.SetBytes(int64(buf.Len()))
rd := msgp.NewEndlessReader(buf.Bytes(), b)
dc := msgp.NewReader(rd)
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
err := v.DecodeMsg(dc)
if err != nil {
b.Fatal(err)
}
}
}
func TestMarshalUnmarshalResponse(t *testing.T) {
v := Response{}
bts, err := v.MarshalMsg(nil)
if err != nil {
t.Fatal(err)
}
left, err := v.UnmarshalMsg(bts)
if err != nil {
t.Fatal(err)
}
if len(left) > 0 {
t.Errorf("%d bytes left over after UnmarshalMsg(): %q", len(left), left)
}
left, err = msgp.Skip(bts)
if err != nil {
t.Fatal(err)
}
if len(left) > 0 {
t.Errorf("%d bytes left over after Skip(): %q", len(left), left)
}
}
func BenchmarkMarshalMsgResponse(b *testing.B) {
v := Response{}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
v.MarshalMsg(nil)
}
}
func BenchmarkAppendMsgResponse(b *testing.B) {
v := Response{}
bts := make([]byte, 0, v.Msgsize())
bts, _ = v.MarshalMsg(bts[0:0])
b.SetBytes(int64(len(bts)))
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
bts, _ = v.MarshalMsg(bts[0:0])
}
}
func BenchmarkUnmarshalResponse(b *testing.B) {
v := Response{}
bts, _ := v.MarshalMsg(nil)
b.ReportAllocs()
b.SetBytes(int64(len(bts)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := v.UnmarshalMsg(bts)
if err != nil {
b.Fatal(err)
}
}
}
func TestEncodeDecodeResponse(t *testing.T) {
v := Response{}
var buf bytes.Buffer
msgp.Encode(&buf, &v)
m := v.Msgsize()
if buf.Len() > m {
t.Log("WARNING: TestEncodeDecodeResponse Msgsize() is inaccurate")
}
vn := Response{}
err := msgp.Decode(&buf, &vn)
if err != nil {
t.Error(err)
}
buf.Reset()
msgp.Encode(&buf, &v)
err = msgp.NewReader(&buf).Skip()
if err != nil {
t.Error(err)
}
}
func BenchmarkEncodeResponse(b *testing.B) {
v := Response{}
var buf bytes.Buffer
msgp.Encode(&buf, &v)
b.SetBytes(int64(buf.Len()))
en := msgp.NewWriter(msgp.Nowhere)
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
v.EncodeMsg(en)
}
en.Flush()
}
func BenchmarkDecodeResponse(b *testing.B) {
v := Response{}
var buf bytes.Buffer
msgp.Encode(&buf, &v)
b.SetBytes(int64(buf.Len()))
rd := msgp.NewEndlessReader(buf.Bytes(), b)
dc := msgp.NewReader(rd)
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
err := v.DecodeMsg(dc)
if err != nil {
b.Fatal(err)
}
}
}