From 04912c07c29d66794851df52b7d8e7de13a11935 Mon Sep 17 00:00:00 2001 From: Joe Tsai Date: Tue, 28 Jan 2025 16:30:50 -0800 Subject: [PATCH] Preserve semantic options in Encoder and Decoder A jsontext.Encoder and jsontext.Decoder could be constructed with semantic options (e.g., json.Deterministic) where such options are ignored for all encode/decode specific operations. However, allow semantic options to take effect when such an Encoder or Decoder is passed to MarshalEncode or UnmarshalDecode. The semantic option can still be overridden. One reason for this behavior is for easier migration. For example, this v1 code: dec := jsonv1.NewDecoder(in) dec.DisallowUnknownFields() for { ... := dec.Decode(...) } can be migrated as: dec := jsontext.NewDecoder(in, json.RejectUnknownMembers(true)) for { ... := json.UnmarshalDecode(dec, ...) } Notice that RejectUnknownMembers does not need to be repeatedly passed to every UnmarshalDecode call because it is implicitly stored on the Decoder. The alternative behavior is to have the construction of an Encoder or Decoder explicitly drop any semantic options. However, this seems like extra work for no benefit. --- arshal.go | 17 ++-- arshal_test.go | 124 ++++++++++++++++++++++++++++++ internal/jsonopts/options.go | 54 ++++++++----- internal/jsonopts/options_test.go | 107 +++++++++++++++++++------- 4 files changed, 250 insertions(+), 52 deletions(-) diff --git a/arshal.go b/arshal.go index 8a7891a..e4603de 100644 --- a/arshal.go +++ b/arshal.go @@ -45,9 +45,8 @@ func mayReuseOpt(coderOpts *jsonopts.Struct, opts []Options) *jsonopts.Struct { return coderOpts } // If the caller provides no options, then just reuse the coder's options, - // which should only contain encoding/decoding related flags. + // which may contain both marshaling/unmarshaling and encoding/decoding flags. case 0: - // TODO: This is buggy if coderOpts ever contains non-coder options. return coderOpts } return nil @@ -224,8 +223,11 @@ func MarshalWrite(out io.Writer, in any, opts ...Options) (err error) { // MarshalEncode serializes a Go value into an [jsontext.Encoder] according to // the provided marshal options (while ignoring unmarshal, encode, or decode options). +// Any marshal-relevant options already specified on the [jsontext.Encoder] +// take lower precedence than the set of options provided by the caller. // Unlike [Marshal] and [MarshalWrite], encode options are ignored because // they must have already been specified on the provided [jsontext.Encoder]. +// // See [Marshal] for details about the conversion of a Go value into JSON. func MarshalEncode(out *jsontext.Encoder, in any, opts ...Options) (err error) { xe := export.Encoder(out) @@ -233,8 +235,8 @@ func MarshalEncode(out *jsontext.Encoder, in any, opts ...Options) (err error) { if mo == nil { mo = getStructOptions() defer putStructOptions(mo) - mo.Join(opts...) - mo.CopyCoderOptions(&xe.Struct) + *mo = xe.Struct // initialize with encoder options before joining + mo.JoinWithoutCoderOptions(opts...) } err = marshalEncode(out, in, mo) if err != nil && mo.Flags.Get(jsonflags.ReportErrorsWithLegacySemantics) { @@ -467,8 +469,11 @@ func unmarshalFull(in *jsontext.Decoder, out any, uo *jsonopts.Struct) error { // UnmarshalDecode deserializes a Go value from a [jsontext.Decoder] according to // the provided unmarshal options (while ignoring marshal, encode, or decode options). +// Any unmarshal options already specified on the [jsontext.Decoder] +// take lower precedence than the set of options provided by the caller. // Unlike [Unmarshal] and [UnmarshalRead], decode options are ignored because // they must have already been specified on the provided [jsontext.Decoder]. +// // The input may be a stream of one or more JSON values, // where this only unmarshals the next JSON value in the stream. // The output must be a non-nil pointer. @@ -479,8 +484,8 @@ func UnmarshalDecode(in *jsontext.Decoder, out any, opts ...Options) (err error) if uo == nil { uo = getStructOptions() defer putStructOptions(uo) - uo.Join(opts...) - uo.CopyCoderOptions(&xd.Struct) + *uo = xd.Struct // initialize with decoder options before joining + uo.JoinWithoutCoderOptions(opts...) } err = unmarshalDecode(in, out, uo) if err != nil && uo.Flags.Get(jsonflags.ReportErrorsWithLegacySemantics) { diff --git a/arshal_test.go b/arshal_test.go index 94769e2..0ef80da 100644 --- a/arshal_test.go +++ b/arshal_test.go @@ -9300,6 +9300,66 @@ func TestUintSet(t *testing.T) { } } +func TestUnmarshalDecodeOptions(t *testing.T) { + var calledFuncs int + var calledOptions Options + in := strings.NewReader(strings.Repeat("\"\xde\xad\xbe\xef\"\n", 5)) + dec := jsontext.NewDecoder(in, + jsontext.AllowInvalidUTF8(true), // decoder-specific option + WithUnmarshalers(UnmarshalFromFunc(func(_ *jsontext.Decoder, _ any, opts Options) error { + if v, _ := GetOption(opts, jsontext.AllowInvalidUTF8); !v { + t.Errorf("nested Options.AllowInvalidUTF8 = false, want true") + } + calledFuncs++ + calledOptions = opts + return SkipFunc + })), // unmarshal-specific option; only relevant for UnmarshalDecode + ) + + if err := UnmarshalDecode(dec, new(string)); err != nil { + t.Fatalf("UnmarshalDecode: %v", err) + } + if calledFuncs != 1 { + t.Fatalf("calledFuncs = %d, want 1", calledFuncs) + } + if err := UnmarshalDecode(dec, new(string), calledOptions); err != nil { + t.Fatalf("UnmarshalDecode: %v", err) + } + if calledFuncs != 2 { + t.Fatalf("calledFuncs = %d, want 2", calledFuncs) + } + if err := UnmarshalDecode(dec, new(string), + jsontext.AllowInvalidUTF8(false), // should be ignored + WithUnmarshalers(nil), // should override + ); err != nil { + t.Fatalf("UnmarshalDecode: %v", err) + } + if calledFuncs != 2 { + t.Fatalf("calledFuncs = %d, want 2", calledFuncs) + } + if err := UnmarshalDecode(dec, new(string)); err != nil { + t.Fatalf("UnmarshalDecode: %v", err) + } + if calledFuncs != 3 { + t.Fatalf("calledFuncs = %d, want 3", calledFuncs) + } + if err := UnmarshalDecode(dec, new(string), JoinOptions( + jsontext.AllowInvalidUTF8(false), // should be ignored + WithUnmarshalers(UnmarshalFromFunc(func(_ *jsontext.Decoder, _ any, opts Options) error { + if v, _ := GetOption(opts, jsontext.AllowInvalidUTF8); !v { + t.Errorf("nested Options.AllowInvalidUTF8 = false, want true") + } + calledFuncs = math.MaxInt + return SkipFunc + })), // should override + )); err != nil { + t.Fatalf("UnmarshalDecode: %v", err) + } + if calledFuncs != math.MaxInt { + t.Fatalf("calledFuncs = %d, want %d", calledFuncs, math.MaxInt) + } +} + // BenchmarkUnmarshalDecodeOptions is a minimal decode operation to measure // the overhead options setup before the unmarshal operation. func BenchmarkUnmarshalDecodeOptions(b *testing.B) { @@ -9323,6 +9383,70 @@ func BenchmarkUnmarshalDecodeOptions(b *testing.B) { b.Run("New", makeBench(DefaultOptionsV2())) } +func TestMarshalEncodeOptions(t *testing.T) { + var calledFuncs int + var calledOptions Options + out := new(bytes.Buffer) + enc := jsontext.NewEncoder( + out, + jsontext.AllowInvalidUTF8(true), // encoder-specific option + WithMarshalers(MarshalToFunc(func(_ *jsontext.Encoder, _ any, opts Options) error { + if v, _ := GetOption(opts, jsontext.AllowInvalidUTF8); !v { + t.Errorf("nested Options.AllowInvalidUTF8 = false, want true") + } + calledFuncs++ + calledOptions = opts + return SkipFunc + })), // marshal-specific option; only relevant for MarshalEncode + ) + + if err := MarshalEncode(enc, "\xde\xad\xbe\xef"); err != nil { + t.Fatalf("MarshalEncode: %v", err) + } + if calledFuncs != 1 { + t.Fatalf("calledFuncs = %d, want 1", calledFuncs) + } + if err := MarshalEncode(enc, "\xde\xad\xbe\xef", calledOptions); err != nil { + t.Fatalf("MarshalEncode: %v", err) + } + if calledFuncs != 2 { + t.Fatalf("calledFuncs = %d, want 2", calledFuncs) + } + if err := MarshalEncode(enc, "\xde\xad\xbe\xef", + jsontext.AllowInvalidUTF8(false), // should be ignored + WithMarshalers(nil), // should override + ); err != nil { + t.Fatalf("MarshalEncode: %v", err) + } + if calledFuncs != 2 { + t.Fatalf("calledFuncs = %d, want 2", calledFuncs) + } + if err := MarshalEncode(enc, "\xde\xad\xbe\xef"); err != nil { + t.Fatalf("MarshalEncode: %v", err) + } + if calledFuncs != 3 { + t.Fatalf("calledFuncs = %d, want 3", calledFuncs) + } + if err := MarshalEncode(enc, "\xde\xad\xbe\xef", JoinOptions( + jsontext.AllowInvalidUTF8(false), // should be ignored + WithMarshalers(MarshalToFunc(func(_ *jsontext.Encoder, _ any, opts Options) error { + if v, _ := GetOption(opts, jsontext.AllowInvalidUTF8); !v { + t.Errorf("nested Options.AllowInvalidUTF8 = false, want true") + } + calledFuncs = math.MaxInt + return SkipFunc + })), // should override + )); err != nil { + t.Fatalf("MarshalEncode: %v", err) + } + if calledFuncs != math.MaxInt { + t.Fatalf("calledFuncs = %d, want %d", calledFuncs, math.MaxInt) + } + if out.String() != strings.Repeat("\"\xde\xad\ufffd\ufffd\"\n", 5) { + t.Fatalf("output mismatch:\n\tgot: %s\n\twant: %s", out.String(), strings.Repeat("\"\xde\xad\xbe\xef\"\n", 5)) + } +} + // BenchmarkMarshalEncodeOptions is a minimal encode operation to measure // the overhead of options setup before the marshal operation. func BenchmarkMarshalEncodeOptions(b *testing.B) { diff --git a/internal/jsonopts/options.go b/internal/jsonopts/options.go index c23c280..e689c69 100644 --- a/internal/jsonopts/options.go +++ b/internal/jsonopts/options.go @@ -59,17 +59,6 @@ var DefaultOptionsV1 = Struct{ }, } -// CopyCoderOptions copies coder-specific options from src to dst. -// This is used by json.MarshalEncode and json.UnmarshalDecode since those -// functions ignore any coder-specific options and uses the options from the -// Encoder or Decoder that is passed in. -func (dst *Struct) CopyCoderOptions(src *Struct) { - srcFlags := src.Flags - srcFlags.Clear(^jsonflags.AllCoderFlags) - dst.Flags.Join(srcFlags) - dst.CoderValues = src.CoderValues -} - func (*Struct) JSONOptions(internal.NotForPublicUse) {} // GetUnknownOption is injected by the "json" package to handle Options @@ -123,43 +112,70 @@ func GetOption[T any](opts Options, setter func(T) Options) (T, bool) { var JoinUnknownOption = func(*Struct, Options) { panic("unknown option") } func (dst *Struct) Join(srcs ...Options) { + dst.join(false, srcs...) +} + +func (dst *Struct) JoinWithoutCoderOptions(srcs ...Options) { + dst.join(true, srcs...) +} + +func (dst *Struct) join(excludeCoderOptions bool, srcs ...Options) { for _, src := range srcs { switch src := src.(type) { case nil: continue case jsonflags.Bools: + if excludeCoderOptions { + src &= ^jsonflags.AllCoderFlags + } dst.Flags.Set(src) case Indent: + if excludeCoderOptions { + continue + } dst.Flags.Set(jsonflags.Multiline | jsonflags.Indent | 1) dst.Indent = string(src) case IndentPrefix: + if excludeCoderOptions { + continue + } dst.Flags.Set(jsonflags.Multiline | jsonflags.IndentPrefix | 1) dst.IndentPrefix = string(src) case ByteLimit: + if excludeCoderOptions { + continue + } dst.Flags.Set(jsonflags.ByteLimit | 1) dst.ByteLimit = int64(src) case DepthLimit: + if excludeCoderOptions { + continue + } dst.Flags.Set(jsonflags.DepthLimit | 1) dst.DepthLimit = int(src) case *Struct: - dst.Flags.Join(src.Flags) - if src.Flags.Has(jsonflags.NonBooleanFlags) { - if src.Flags.Has(jsonflags.Indent) { + srcFlags := src.Flags // shallow copy the flags + if excludeCoderOptions { + srcFlags.Clear(jsonflags.AllCoderFlags) + } + dst.Flags.Join(srcFlags) + if srcFlags.Has(jsonflags.NonBooleanFlags) { + if srcFlags.Has(jsonflags.Indent) { dst.Indent = src.Indent } - if src.Flags.Has(jsonflags.IndentPrefix) { + if srcFlags.Has(jsonflags.IndentPrefix) { dst.IndentPrefix = src.IndentPrefix } - if src.Flags.Has(jsonflags.ByteLimit) { + if srcFlags.Has(jsonflags.ByteLimit) { dst.ByteLimit = src.ByteLimit } - if src.Flags.Has(jsonflags.DepthLimit) { + if srcFlags.Has(jsonflags.DepthLimit) { dst.DepthLimit = src.DepthLimit } - if src.Flags.Has(jsonflags.Marshalers) { + if srcFlags.Has(jsonflags.Marshalers) { dst.Marshalers = src.Marshalers } - if src.Flags.Has(jsonflags.Unmarshalers) { + if srcFlags.Has(jsonflags.Unmarshalers) { dst.Unmarshalers = src.Unmarshalers } } diff --git a/internal/jsonopts/options_test.go b/internal/jsonopts/options_test.go index 613654c..ab87827 100644 --- a/internal/jsonopts/options_test.go +++ b/internal/jsonopts/options_test.go @@ -21,32 +21,11 @@ func makeFlags(f ...jsonflags.Bools) (fs jsonflags.Flags) { return fs } -func TestCopyCoderOptions(t *testing.T) { - got := &Struct{ - Flags: makeFlags(jsonflags.Indent|jsonflags.AllowInvalidUTF8|0, jsonflags.Multiline|jsonflags.AllowDuplicateNames|jsonflags.Unmarshalers|1), - CoderValues: CoderValues{Indent: " "}, - ArshalValues: ArshalValues{Unmarshalers: "something"}, - } - src := &Struct{ - Flags: makeFlags(jsonflags.Indent|jsonflags.Deterministic|jsonflags.Marshalers|1, jsonflags.Multiline|0), - CoderValues: CoderValues{Indent: "\t"}, - ArshalValues: ArshalValues{Marshalers: "something"}, - } - want := &Struct{ - Flags: makeFlags(jsonflags.AllowInvalidUTF8|jsonflags.Multiline|0, jsonflags.Indent|jsonflags.AllowDuplicateNames|jsonflags.Unmarshalers|1), - CoderValues: CoderValues{Indent: "\t"}, - ArshalValues: ArshalValues{Unmarshalers: "something"}, - } - got.CopyCoderOptions(src) - if !reflect.DeepEqual(got, want) { - t.Errorf("CopyCoderOptions:\n\tgot: %+v\n\twant: %+v", got, want) - } -} - func TestJoin(t *testing.T) { tests := []struct { - in Options - want *Struct + in Options + excludeCoders bool + want *Struct }{{ in: jsonflags.AllowInvalidUTF8 | 1, want: &Struct{Flags: makeFlags(jsonflags.AllowInvalidUTF8 | 1)}, @@ -69,7 +48,8 @@ func TestJoin(t *testing.T) { CoderValues: CoderValues{Indent: "\t"}, }, }, { - in: &DefaultOptionsV1, want: func() *Struct { + in: &DefaultOptionsV1, + want: func() *Struct { v1 := DefaultOptionsV1 v1.Flags.Set(jsonflags.Indent | 1) v1.Flags.Set(jsonflags.Multiline | 0) @@ -77,17 +57,90 @@ func TestJoin(t *testing.T) { return &v1 }(), // v1 fully replaces before (except for whitespace related flags) }, { - in: &DefaultOptionsV2, want: func() *Struct { + in: &DefaultOptionsV2, + want: func() *Struct { v2 := DefaultOptionsV2 v2.Flags.Set(jsonflags.Indent | 1) v2.Flags.Set(jsonflags.Multiline | 0) v2.Indent = "\t" return &v2 }(), // v2 fully replaces before (except for whitespace related flags) + }, { + in: jsonflags.Deterministic | jsonflags.AllowInvalidUTF8 | 1, excludeCoders: true, + want: func() *Struct { + v2 := DefaultOptionsV2 + v2.Flags.Set(jsonflags.Deterministic | 1) + v2.Flags.Set(jsonflags.Indent | 1) + v2.Flags.Set(jsonflags.Multiline | 0) + v2.Indent = "\t" + return &v2 + }(), + }, { + in: jsontext.WithIndentPrefix(" "), excludeCoders: true, + want: func() *Struct { + v2 := DefaultOptionsV2 + v2.Flags.Set(jsonflags.Deterministic | 1) + v2.Flags.Set(jsonflags.Indent | 1) + v2.Flags.Set(jsonflags.Multiline | 0) + v2.Indent = "\t" + return &v2 + }(), + }, { + in: jsontext.WithIndentPrefix(" "), excludeCoders: false, + want: func() *Struct { + v2 := DefaultOptionsV2 + v2.Flags.Set(jsonflags.Deterministic | 1) + v2.Flags.Set(jsonflags.Indent | 1) + v2.Flags.Set(jsonflags.IndentPrefix | 1) + v2.Flags.Set(jsonflags.Multiline | 1) + v2.Indent = "\t" + v2.IndentPrefix = " " + return &v2 + }(), + }, { + in: &Struct{ + Flags: jsonflags.Flags{ + Presence: uint64(jsonflags.Deterministic | jsonflags.Indent | jsonflags.IndentPrefix), + Values: uint64(jsonflags.Indent | jsonflags.IndentPrefix), + }, + CoderValues: CoderValues{Indent: " ", IndentPrefix: " "}, + }, + excludeCoders: true, + want: func() *Struct { + v2 := DefaultOptionsV2 + v2.Flags.Set(jsonflags.Indent | 1) + v2.Flags.Set(jsonflags.IndentPrefix | 1) + v2.Flags.Set(jsonflags.Multiline | 1) + v2.Indent = "\t" + v2.IndentPrefix = " " + return &v2 + }(), + }, { + in: &Struct{ + Flags: jsonflags.Flags{ + Presence: uint64(jsonflags.Deterministic | jsonflags.Indent | jsonflags.IndentPrefix), + Values: uint64(jsonflags.Indent | jsonflags.IndentPrefix), + }, + CoderValues: CoderValues{Indent: " ", IndentPrefix: " "}, + }, + excludeCoders: false, + want: func() *Struct { + v2 := DefaultOptionsV2 + v2.Flags.Set(jsonflags.Indent | 1) + v2.Flags.Set(jsonflags.IndentPrefix | 1) + v2.Flags.Set(jsonflags.Multiline | 1) + v2.Indent = " " + v2.IndentPrefix = " " + return &v2 + }(), }} got := new(Struct) for i, tt := range tests { - got.Join(tt.in) + if tt.excludeCoders { + got.JoinWithoutCoderOptions(tt.in) + } else { + got.Join(tt.in) + } if !reflect.DeepEqual(got, tt.want) { t.Fatalf("%d: Join:\n\tgot: %+v\n\twant: %+v", i, got, tt.want) }