From 7f632f76db65f28038b8797fbfe5e15a9f63e387 Mon Sep 17 00:00:00 2001 From: Axel Wagner Date: Thu, 18 Aug 2022 16:18:34 +0200 Subject: [PATCH] encoding/xml: add (*Encoder).Close Flush can not check for unclosed elements, as more data might be encoded after Flush is called. Close implicitly calls Flush and also checks that all opened elements are closed as well. Fixes #53346 Change-Id: I889b9f5ae54e5dfabb9e6948d96c5ed7bc1110f9 Reviewed-on: https://go-review.googlesource.com/c/go/+/424777 TryBot-Result: Gopher Robot Reviewed-by: Ian Lance Taylor Auto-Submit: Ian Lance Taylor Run-TryBot: Ian Lance Taylor Reviewed-by: David Chase --- api/next/53346.txt | 1 + src/encoding/xml/marshal.go | 81 +++++++++++++++++++++++++++++--- src/encoding/xml/marshal_test.go | 58 +++++++++++++++++++++++ 3 files changed, 133 insertions(+), 7 deletions(-) create mode 100644 api/next/53346.txt diff --git a/api/next/53346.txt b/api/next/53346.txt new file mode 100644 index 00000000000000..dd39f231d5dc26 --- /dev/null +++ b/api/next/53346.txt @@ -0,0 +1 @@ +pkg encoding/xml, method (*Encoder) Close() error #53346 diff --git a/src/encoding/xml/marshal.go b/src/encoding/xml/marshal.go index 01f673a85165a0..07b6042da87bf9 100644 --- a/src/encoding/xml/marshal.go +++ b/src/encoding/xml/marshal.go @@ -8,6 +8,7 @@ import ( "bufio" "bytes" "encoding" + "errors" "fmt" "io" "reflect" @@ -78,7 +79,11 @@ const ( // Marshal will return an error if asked to marshal a channel, function, or map. func Marshal(v any) ([]byte, error) { var b bytes.Buffer - if err := NewEncoder(&b).Encode(v); err != nil { + enc := NewEncoder(&b) + if err := enc.Encode(v); err != nil { + return nil, err + } + if err := enc.Close(); err != nil { return nil, err } return b.Bytes(), nil @@ -129,6 +134,9 @@ func MarshalIndent(v any, prefix, indent string) ([]byte, error) { if err := enc.Encode(v); err != nil { return nil, err } + if err := enc.Close(); err != nil { + return nil, err + } return b.Bytes(), nil } @@ -139,7 +147,7 @@ type Encoder struct { // NewEncoder returns a new encoder that writes to w. func NewEncoder(w io.Writer) *Encoder { - e := &Encoder{printer{Writer: bufio.NewWriter(w)}} + e := &Encoder{printer{w: bufio.NewWriter(w)}} e.p.encoder = e return e } @@ -163,7 +171,7 @@ func (enc *Encoder) Encode(v any) error { if err != nil { return err } - return enc.p.Flush() + return enc.p.w.Flush() } // EncodeElement writes the XML encoding of v to the stream, @@ -178,7 +186,7 @@ func (enc *Encoder) EncodeElement(v any, start StartElement) error { if err != nil { return err } - return enc.p.Flush() + return enc.p.w.Flush() } var ( @@ -224,7 +232,7 @@ func (enc *Encoder) EncodeToken(t Token) error { case ProcInst: // First token to be encoded which is also a ProcInst with target of xml // is the xml declaration. The only ProcInst where target of xml is allowed. - if t.Target == "xml" && p.Buffered() != 0 { + if t.Target == "xml" && p.w.Buffered() != 0 { return fmt.Errorf("xml: EncodeToken of ProcInst xml target only valid for xml declaration, first token encoded") } if !isNameString(t.Target) { @@ -297,11 +305,18 @@ func isValidDirective(dir Directive) bool { // Flush flushes any buffered XML to the underlying writer. // See the EncodeToken documentation for details about when it is necessary. func (enc *Encoder) Flush() error { - return enc.p.Flush() + return enc.p.w.Flush() +} + +// Close the Encoder, indicating that no more data will be written. It flushes +// any buffered XML to the underlying writer and returns an error if the +// written XML is invalid (e.g. by containing unclosed elements). +func (enc *Encoder) Close() error { + return enc.p.Close() } type printer struct { - *bufio.Writer + w *bufio.Writer encoder *Encoder seq int indent string @@ -313,6 +328,8 @@ type printer struct { attrPrefix map[string]string // map name space -> prefix prefixes []string tags []Name + closed bool + err error } // createAttrPrefix finds the name space prefix attribute to use for the given name space, @@ -961,6 +978,56 @@ func (p *printer) marshalStruct(tinfo *typeInfo, val reflect.Value) error { return p.cachedWriteError() } +// Write implements io.Writer +func (p *printer) Write(b []byte) (n int, err error) { + if p.closed && p.err == nil { + p.err = errors.New("use of closed Encoder") + } + if p.err == nil { + n, p.err = p.w.Write(b) + } + return n, p.err +} + +// WriteString implements io.StringWriter +func (p *printer) WriteString(s string) (n int, err error) { + if p.closed && p.err == nil { + p.err = errors.New("use of closed Encoder") + } + if p.err == nil { + n, p.err = p.w.WriteString(s) + } + return n, p.err +} + +// WriteByte implements io.ByteWriter +func (p *printer) WriteByte(c byte) error { + if p.closed && p.err == nil { + p.err = errors.New("use of closed Encoder") + } + if p.err == nil { + p.err = p.w.WriteByte(c) + } + return p.err +} + +// Close the Encoder, indicating that no more data will be written. It flushes +// any buffered XML to the underlying writer and returns an error if the +// written XML is invalid (e.g. by containing unclosed elements). +func (p *printer) Close() error { + if p.closed { + return nil + } + p.closed = true + if err := p.w.Flush(); err != nil { + return err + } + if len(p.tags) > 0 { + return fmt.Errorf("unclosed tag <%s>", p.tags[len(p.tags)-1].Local) + } + return nil +} + // return the bufio Writer's cached write error func (p *printer) cachedWriteError() error { _, err := p.Write(nil) diff --git a/src/encoding/xml/marshal_test.go b/src/encoding/xml/marshal_test.go index 3fe7e2dc001c4d..774793a6c54ff5 100644 --- a/src/encoding/xml/marshal_test.go +++ b/src/encoding/xml/marshal_test.go @@ -2531,3 +2531,61 @@ func TestMarshalZeroValue(t *testing.T) { t.Fatalf("unexpected unmarshal result, want %q but got %q", proofXml, anotherXML) } } + +var closeTests = []struct { + desc string + toks []Token + want string + err string +}{{ + desc: "unclosed start element", + toks: []Token{ + StartElement{Name{"", "foo"}, nil}, + }, + want: ``, + err: "unclosed tag ", +}, { + desc: "closed element", + toks: []Token{ + StartElement{Name{"", "foo"}, nil}, + EndElement{Name{"", "foo"}}, + }, + want: ``, +}, { + desc: "directive", + toks: []Token{ + Directive("foo"), + }, + want: ``, +}} + +func TestClose(t *testing.T) { + for _, tt := range closeTests { + tt := tt + t.Run(tt.desc, func(t *testing.T) { + var out strings.Builder + enc := NewEncoder(&out) + for j, tok := range tt.toks { + if err := enc.EncodeToken(tok); err != nil { + t.Fatalf("token #%d: %v", j, err) + } + } + err := enc.Close() + switch { + case tt.err != "" && err == nil: + t.Error(" expected error; got none") + case tt.err == "" && err != nil: + t.Errorf(" got error: %v", err) + case tt.err != "" && err != nil && tt.err != err.Error(): + t.Errorf(" error mismatch; got %v, want %v", err, tt.err) + } + if got := out.String(); got != tt.want { + t.Errorf("\ngot %v\nwant %v", got, tt.want) + } + t.Log(enc.p.closed) + if err := enc.EncodeToken(Directive("foo")); err == nil { + t.Errorf("unexpected success when encoding after Close") + } + }) + } +}