Skip to content

Commit

Permalink
fix: bug when serde with nil function
Browse files Browse the repository at this point in the history
  • Loading branch information
liuq19 committed Feb 27, 2025
1 parent 86fae91 commit 09960d0
Show file tree
Hide file tree
Showing 12 changed files with 161 additions and 31 deletions.
8 changes: 7 additions & 1 deletion internal/decoder/jitdec/assembler_regabi_amd64.go
Original file line number Diff line number Diff line change
Expand Up @@ -313,9 +313,10 @@ var _OpFuncTab = [256]func(*_Assembler, *_Instr) {
_OP_check_char_0 : (*_Assembler)._asm_OP_check_char_0,
_OP_dismatch_err : (*_Assembler)._asm_OP_dismatch_err,
_OP_go_skip : (*_Assembler)._asm_OP_go_skip,
_OP_skip_emtpy : (*_Assembler)._asm_OP_skip_empty,
_OP_skip_emtpy : (*_Assembler)._asm_OP_skip_empty,
_OP_add : (*_Assembler)._asm_OP_add,
_OP_check_empty : (*_Assembler)._asm_OP_check_empty,
_OP_func : (*_Assembler)._asm_OP_func,
_OP_debug : (*_Assembler)._asm_OP_debug,
}

Expand Down Expand Up @@ -1265,6 +1266,11 @@ func (self *_Assembler) _asm_OP_dyn(p *_Instr) {
self.Link("_decode_end_{n}") // _decode_end_{n}:
}

func (self *_Assembler) _asm_OP_func(p *_Instr) {
self.Emit("MOVQ", jit.Type(p.vt()), _ET) // MOVQ ${p.vt()}, ET
self.Sjmp("JMP" , _LB_type_error) // JMP _LB_type_error
}

func (self *_Assembler) _asm_OP_str(_ *_Instr) {
self.parse_string() // PARSE STRING
self.unquote_once(jit.Ptr(_VP, 0), jit.Ptr(_VP, 8), false, true) // UNQUOTE once, (VP), 8(VP)
Expand Down
15 changes: 12 additions & 3 deletions internal/decoder/jitdec/compiler.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ const (
_OP_skip_emtpy
_OP_add
_OP_check_empty
_OP_func
_OP_debug
)

Expand Down Expand Up @@ -176,6 +177,7 @@ var _OpNames = [256]string {
_OP_add : "add",
_OP_go_skip : "go_skip",
_OP_check_empty : "check_empty",
_OP_func : "func_ptr",
_OP_debug : "debug",
}

Expand Down Expand Up @@ -630,10 +632,19 @@ func (self *_Compiler) compileOps(p *_Program, sp int, vt reflect.Type) {
case reflect.Ptr : self.compilePtr (p, sp, vt)
case reflect.Slice : self.compileSlice (p, sp, vt)
case reflect.Struct : self.compileStruct (p, sp, vt)
case reflect.Func : self.compileFunc (p, vt)
default : panic (&json.UnmarshalTypeError{Type: vt})
}
}

func (self *_Compiler) compileFunc(p *_Program, vt reflect.Type) {
i := p.pc()
p.add(_OP_is_null)
p.rtt(_OP_func, vt)
p.pin(i)
p.add(_OP_nil_1)
}

func (self *_Compiler) compileMap(p *_Program, sp int, vt reflect.Type) {
if reflect.PtrTo(vt.Key()).Implements(encodingTextUnmarshalerType) {
self.compileMapOp(p, sp, vt, _OP_map_key_utext_p)
Expand Down Expand Up @@ -1135,13 +1146,11 @@ func (self *_Compiler) compileInterface(p *_Program, vt reflect.Type) {
p.pin(j)
}

func (self *_Compiler) compilePrimitive(vt reflect.Type, p *_Program, op _Op) {
func (self *_Compiler) compilePrimitive(_ reflect.Type, p *_Program, op _Op) {
i := p.pc()
p.add(_OP_is_null)
// skip := self.checkPrimitive(p, vt)
p.add(op)
p.pin(i)
// p.pin(skip)
}

func (self *_Compiler) compileUnmarshalEnd(p *_Program, vt reflect.Type, i int) {
Expand Down
4 changes: 4 additions & 0 deletions internal/decoder/optdec/compiler.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,10 @@ func (c *compiler) compileBasic(vt reflect.Type) decFunc {
return c.compileSlice(vt)
case reflect.Struct:
return c.compileStruct(vt)
case reflect.Func:
return &funcDecoder{
typ: rt.UnpackType(vt),
}
default:
panic(&json.UnmarshalTypeError{Type: vt})
}
Expand Down
6 changes: 6 additions & 0 deletions internal/decoder/optdec/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,4 +70,10 @@
Msg: msg,
}
}

func error_unsuppoted(typ *rt.GoType) error {
return &json.UnsupportedTypeError{
Type: typ.Pack(),
}
}

14 changes: 14 additions & 0 deletions internal/decoder/optdec/functor.go
Original file line number Diff line number Diff line change
Expand Up @@ -279,3 +279,17 @@ func (d *recuriveDecoder) FromDom(vp unsafe.Pointer, node Node, ctx *context) er
}
return dec.FromDom(vp, node, ctx)
}

type funcDecoder struct {
typ *rt.GoType
}


func (d *funcDecoder) FromDom(vp unsafe.Pointer, node Node, ctx *context) error {
if node.IsNull() {
*(*unsafe.Pointer)(vp) = nil
return nil
}
return error_unsuppoted(d.typ)
}

18 changes: 18 additions & 0 deletions internal/encoder/compiler.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,8 @@ func (self *Compiler) compileOps(p *ir.Program, sp int, vt reflect.Type) {
self.compileSlice(p, sp, vt.Elem())
case reflect.Struct:
self.compileStruct(p, sp, vt)
case reflect.Func:
self.compileUnsupportedType(p, vt)
default:
panic(vars.Error_type(vt))
}
Expand Down Expand Up @@ -644,6 +646,22 @@ func (self *Compiler) compileInterface(p *ir.Program, vt reflect.Type) {
p.Pin(e)
}

func (self *Compiler) compileUnsupportedType(p *ir.Program, vt reflect.Type) {
x := p.PC()
p.Add(ir.OP_is_nil_p1)

/* if not nil, return unsupported */
p.Rtt(ir.OP_unsupported, vt)

/* the "null" value */
e := p.PC()
p.Add(ir.OP_goto)
p.Pin(x)
p.Add(ir.OP_null)
p.Pin(e)
}


func (self *Compiler) compileMarshaler(p *ir.Program, op ir.Op, vt reflect.Type, mt reflect.Type) {
pc := p.PC()
vk := vt.Kind()
Expand Down
6 changes: 6 additions & 0 deletions internal/encoder/ir/op.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ const (
OP_marshal_text_p
OP_cond_set
OP_cond_testc
OP_unsupported
)

const (
Expand Down Expand Up @@ -141,6 +142,7 @@ var OpNames = [256]string{
OP_marshal_text_p: "marshal_text_p",
OP_cond_set: "cond_set",
OP_cond_testc: "cond_testc",
OP_unsupported: "unsupported",
}

func (self Op) String() string {
Expand Down Expand Up @@ -273,6 +275,10 @@ func (self Instr) Vk() reflect.Kind {
return (*rt.GoType)(self.p).Kind()
}

func (self Instr) GoType() *rt.GoType {
return (*rt.GoType)(self.p)
}

func (self Instr) Vt() reflect.Type {
return (*rt.GoType)(self.p).Pack()
}
Expand Down
21 changes: 1 addition & 20 deletions internal/encoder/pools_amd64.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
package encoder

import (
"errors"
"reflect"
"unsafe"

Expand Down Expand Up @@ -52,29 +51,11 @@ var _KeepAlive struct {
frame [x86.FP_offs]byte
}

var errCallShadow = errors.New("DON'T CALL THIS!")

// Faker func of _Encoder, used to export its stackmap as _Encoder's
func _Encoder_Shadow(rb *[]byte, vp unsafe.Pointer, sb *vars.Stack, fv uint64) (err error) {
// align to assembler_amd64.go: x86.FP_offs
var frame [x86.FP_offs]byte

// must keep all args and frames noticeable to GC
_KeepAlive.rb = rb
_KeepAlive.vp = vp
_KeepAlive.sb = sb
_KeepAlive.fv = fv
_KeepAlive.err = err
_KeepAlive.frame = frame

return errCallShadow
}

func makeEncoderX86(vt *rt.GoType, ex ...interface{}) (interface{}, error) {
pp, err := NewCompiler().Compile(vt.Pack(), ex[0].(bool))
if err != nil {
return nil, err
}
}
as := x86.NewAssembler(pp)
as.Name = vt.String()
return as.Load(), nil
Expand Down
4 changes: 4 additions & 0 deletions internal/encoder/vars/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ func Error_number(number json.Number) error {
}
}

func Error_unsuppoted(typ *rt.GoType) error {
return &json.UnsupportedTypeError{Type: typ.Pack() }
}

func Error_marshaler(ret []byte, pos int) error {
return fmt.Errorf("invalid Marshaler output json syntax at %d: %q", pos, ret)
}
Expand Down
9 changes: 2 additions & 7 deletions internal/encoder/vm/vm.go
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,8 @@ func Execute(b *[]byte, p unsafe.Pointer, s *vars.Stack, flags uint64, prog *ir.
if err := alg.EncodeJsonMarshaler(&buf, *(*json.Marshaler)(unsafe.Pointer(&it)), (flags)); err != nil {
return err
}
case ir.OP_unsupported:
return vars.Error_unsuppoted(ins.GoType())
default:
panic(fmt.Sprintf("not implement %s at %d", ins.Op().String(), pc))
}
Expand All @@ -347,13 +349,6 @@ func Execute(b *[]byte, p unsafe.Pointer, s *vars.Stack, flags uint64, prog *ir.
return nil
}

// func to_buf(w unsafe.Pointer, l int, c int) []byte {
// return rt.BytesFrom(unsafe.Pointer(uintptr(w)-uintptr(l)), l, c)
// }

// func from_buf(buf []byte) (unsafe.Pointer, int, int) {
// return rt.IndexByte(buf, len(buf)), len(buf), cap(buf)
// }

func has_opts(opts uint64, bit int) bool {
return opts & (1<<bit) != 0
Expand Down
10 changes: 10 additions & 0 deletions internal/encoder/x86/assembler_regabi_amd64.go
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,7 @@ var _OpFuncTab = [256]func(*Assembler, *ir.Instr){
ir.OP_marshal_text_p: (*Assembler)._asm_OP_marshal_text_p,
ir.OP_cond_set: (*Assembler)._asm_OP_cond_set,
ir.OP_cond_testc: (*Assembler)._asm_OP_cond_testc,
ir.OP_unsupported: (*Assembler)._asm_OP_unsupported,
}

func (self *Assembler) instr(v *ir.Instr) {
Expand Down Expand Up @@ -1187,6 +1188,15 @@ func (self *Assembler) _asm_OP_cond_testc(p *ir.Instr) {
self.Xjmp("JC", p.Vi())
}

var _F_error_unsupported = jit.Func(vars.Error_unsuppoted)

func (self *Assembler) _asm_OP_unsupported(i *ir.Instr) {
typ := int64(uintptr(unsafe.Pointer(i.GoType())))
self.Emit("MOVQ", jit.Imm(typ), _AX)
self.call_go(_F_error_unsupported)
self.Sjmp("JMP", _LB_error)
}

func (self *Assembler) print_gc(i int, p1 *ir.Instr, p2 *ir.Instr) {
self.Emit("MOVQ", jit.Imm(int64(p2.Op())), _CX) // MOVQ $(p2.Op()), AX
self.Emit("MOVQ", jit.Imm(int64(p1.Op())), _BX) // MOVQ $(p1.Op()), BX
Expand Down
77 changes: 77 additions & 0 deletions issue_test/issue491_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package issue_test

import (
"encoding/json"
"testing"

"github.com/bytedance/sonic"
"github.com/davecgh/go-spew/spew"
"github.com/stretchr/testify/require"
)

type Function = func()

func MockFunc() {

}

type Unable struct {
Functions []Function
}

type StructWithUnable struct {
Foo *Unable `json:"foo"`
Bar *Unable `json:"bar,omitempty"`
}

func TestIssue491_MarshalNilFunction(t *testing.T) {
// Wrapper a unbale serde type
tests := []interface{} {
map[string]*Function{},
map[*Function]*Function{},
[]Function{},
StructWithUnable{},
struct {
Foo *int
}{},
struct {
Foo Function
}{},
chan int(nil),
}
for _, v := range(tests) {
assertMarshal(t, sonic.ConfigDefault, v)
}
}

func TestIssue491_UnmarshalUnsupportedNil(t *testing.T) {
type Test struct {
data string
value interface{}
}

tests := []Test{
{
data: "null",
value: new([]Function),
},
{
data: "[null, null]",
value: new([]Function),
},
{
data: "{\"foo\": null}",
value: new(struct {
Foo *Function
}),
},
}
for _, v := range(tests) {
spew.Dump(v)
jerr := json.Unmarshal([]byte(v.data), &v.value)
require.NoError(t, jerr)
serr := sonic.Unmarshal([]byte(v.data), &v.value)
require.NoError(t, serr)
}
}

0 comments on commit 09960d0

Please sign in to comment.