Skip to content

Commit

Permalink
encoding/xml: add (*Encoder).Close
Browse files Browse the repository at this point in the history
Flush can not check for unclosed elements, as more data might be encoded
after Flush is called. Close implicitly calls Flush and also checks that
all opened elements are closed as well.

Fixes golang#53346

Change-Id: I889b9f5ae54e5dfabb9e6948d96c5ed7bc1110f9
Reviewed-on: https://go-review.googlesource.com/c/go/+/424777
TryBot-Result: Gopher Robot <[email protected]>
Reviewed-by: Ian Lance Taylor <[email protected]>
Auto-Submit: Ian Lance Taylor <[email protected]>
Run-TryBot: Ian Lance Taylor <[email protected]>
Reviewed-by: David Chase <[email protected]>
  • Loading branch information
Merovius authored and gopherbot committed Aug 23, 2022
1 parent be9e244 commit 7f632f7
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 7 deletions.
1 change: 1 addition & 0 deletions api/next/53346.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pkg encoding/xml, method (*Encoder) Close() error #53346
81 changes: 74 additions & 7 deletions src/encoding/xml/marshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"bufio"
"bytes"
"encoding"
"errors"
"fmt"
"io"
"reflect"
Expand Down Expand Up @@ -78,7 +79,11 @@ const (
// Marshal will return an error if asked to marshal a channel, function, or map.
func Marshal(v any) ([]byte, error) {
var b bytes.Buffer
if err := NewEncoder(&b).Encode(v); err != nil {
enc := NewEncoder(&b)
if err := enc.Encode(v); err != nil {
return nil, err
}
if err := enc.Close(); err != nil {
return nil, err
}
return b.Bytes(), nil
Expand Down Expand Up @@ -129,6 +134,9 @@ func MarshalIndent(v any, prefix, indent string) ([]byte, error) {
if err := enc.Encode(v); err != nil {
return nil, err
}
if err := enc.Close(); err != nil {
return nil, err
}
return b.Bytes(), nil
}

Expand All @@ -139,7 +147,7 @@ type Encoder struct {

// NewEncoder returns a new encoder that writes to w.
func NewEncoder(w io.Writer) *Encoder {
e := &Encoder{printer{Writer: bufio.NewWriter(w)}}
e := &Encoder{printer{w: bufio.NewWriter(w)}}
e.p.encoder = e
return e
}
Expand All @@ -163,7 +171,7 @@ func (enc *Encoder) Encode(v any) error {
if err != nil {
return err
}
return enc.p.Flush()
return enc.p.w.Flush()
}

// EncodeElement writes the XML encoding of v to the stream,
Expand All @@ -178,7 +186,7 @@ func (enc *Encoder) EncodeElement(v any, start StartElement) error {
if err != nil {
return err
}
return enc.p.Flush()
return enc.p.w.Flush()
}

var (
Expand Down Expand Up @@ -224,7 +232,7 @@ func (enc *Encoder) EncodeToken(t Token) error {
case ProcInst:
// First token to be encoded which is also a ProcInst with target of xml
// is the xml declaration. The only ProcInst where target of xml is allowed.
if t.Target == "xml" && p.Buffered() != 0 {
if t.Target == "xml" && p.w.Buffered() != 0 {
return fmt.Errorf("xml: EncodeToken of ProcInst xml target only valid for xml declaration, first token encoded")
}
if !isNameString(t.Target) {
Expand Down Expand Up @@ -297,11 +305,18 @@ func isValidDirective(dir Directive) bool {
// Flush flushes any buffered XML to the underlying writer.
// See the EncodeToken documentation for details about when it is necessary.
func (enc *Encoder) Flush() error {
return enc.p.Flush()
return enc.p.w.Flush()
}

// Close the Encoder, indicating that no more data will be written. It flushes
// any buffered XML to the underlying writer and returns an error if the
// written XML is invalid (e.g. by containing unclosed elements).
func (enc *Encoder) Close() error {
return enc.p.Close()
}

type printer struct {
*bufio.Writer
w *bufio.Writer
encoder *Encoder
seq int
indent string
Expand All @@ -313,6 +328,8 @@ type printer struct {
attrPrefix map[string]string // map name space -> prefix
prefixes []string
tags []Name
closed bool
err error
}

// createAttrPrefix finds the name space prefix attribute to use for the given name space,
Expand Down Expand Up @@ -961,6 +978,56 @@ func (p *printer) marshalStruct(tinfo *typeInfo, val reflect.Value) error {
return p.cachedWriteError()
}

// Write implements io.Writer
func (p *printer) Write(b []byte) (n int, err error) {
if p.closed && p.err == nil {
p.err = errors.New("use of closed Encoder")
}
if p.err == nil {
n, p.err = p.w.Write(b)
}
return n, p.err
}

// WriteString implements io.StringWriter
func (p *printer) WriteString(s string) (n int, err error) {
if p.closed && p.err == nil {
p.err = errors.New("use of closed Encoder")
}
if p.err == nil {
n, p.err = p.w.WriteString(s)
}
return n, p.err
}

// WriteByte implements io.ByteWriter
func (p *printer) WriteByte(c byte) error {
if p.closed && p.err == nil {
p.err = errors.New("use of closed Encoder")
}
if p.err == nil {
p.err = p.w.WriteByte(c)
}
return p.err
}

// Close the Encoder, indicating that no more data will be written. It flushes
// any buffered XML to the underlying writer and returns an error if the
// written XML is invalid (e.g. by containing unclosed elements).
func (p *printer) Close() error {
if p.closed {
return nil
}
p.closed = true
if err := p.w.Flush(); err != nil {
return err
}
if len(p.tags) > 0 {
return fmt.Errorf("unclosed tag <%s>", p.tags[len(p.tags)-1].Local)
}
return nil
}

// return the bufio Writer's cached write error
func (p *printer) cachedWriteError() error {
_, err := p.Write(nil)
Expand Down
58 changes: 58 additions & 0 deletions src/encoding/xml/marshal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2531,3 +2531,61 @@ func TestMarshalZeroValue(t *testing.T) {
t.Fatalf("unexpected unmarshal result, want %q but got %q", proofXml, anotherXML)
}
}

var closeTests = []struct {
desc string
toks []Token
want string
err string
}{{
desc: "unclosed start element",
toks: []Token{
StartElement{Name{"", "foo"}, nil},
},
want: `<foo>`,
err: "unclosed tag <foo>",
}, {
desc: "closed element",
toks: []Token{
StartElement{Name{"", "foo"}, nil},
EndElement{Name{"", "foo"}},
},
want: `<foo></foo>`,
}, {
desc: "directive",
toks: []Token{
Directive("foo"),
},
want: `<!foo>`,
}}

func TestClose(t *testing.T) {
for _, tt := range closeTests {
tt := tt
t.Run(tt.desc, func(t *testing.T) {
var out strings.Builder
enc := NewEncoder(&out)
for j, tok := range tt.toks {
if err := enc.EncodeToken(tok); err != nil {
t.Fatalf("token #%d: %v", j, err)
}
}
err := enc.Close()
switch {
case tt.err != "" && err == nil:
t.Error(" expected error; got none")
case tt.err == "" && err != nil:
t.Errorf(" got error: %v", err)
case tt.err != "" && err != nil && tt.err != err.Error():
t.Errorf(" error mismatch; got %v, want %v", err, tt.err)
}
if got := out.String(); got != tt.want {
t.Errorf("\ngot %v\nwant %v", got, tt.want)
}
t.Log(enc.p.closed)
if err := enc.EncodeToken(Directive("foo")); err == nil {
t.Errorf("unexpected success when encoding after Close")
}
})
}
}

0 comments on commit 7f632f7

Please sign in to comment.