diff --git a/api/next/53346.txt b/api/next/53346.txt new file mode 100644 index 00000000000000..dd39f231d5dc26 --- /dev/null +++ b/api/next/53346.txt @@ -0,0 +1 @@ +pkg encoding/xml, method (*Encoder) Close() error #53346 diff --git a/src/encoding/xml/marshal.go b/src/encoding/xml/marshal.go index 01f673a85165a0..07b6042da87bf9 100644 --- a/src/encoding/xml/marshal.go +++ b/src/encoding/xml/marshal.go @@ -8,6 +8,7 @@ import ( "bufio" "bytes" "encoding" + "errors" "fmt" "io" "reflect" @@ -78,7 +79,11 @@ const ( // Marshal will return an error if asked to marshal a channel, function, or map. func Marshal(v any) ([]byte, error) { var b bytes.Buffer - if err := NewEncoder(&b).Encode(v); err != nil { + enc := NewEncoder(&b) + if err := enc.Encode(v); err != nil { + return nil, err + } + if err := enc.Close(); err != nil { return nil, err } return b.Bytes(), nil @@ -129,6 +134,9 @@ func MarshalIndent(v any, prefix, indent string) ([]byte, error) { if err := enc.Encode(v); err != nil { return nil, err } + if err := enc.Close(); err != nil { + return nil, err + } return b.Bytes(), nil } @@ -139,7 +147,7 @@ type Encoder struct { // NewEncoder returns a new encoder that writes to w. func NewEncoder(w io.Writer) *Encoder { - e := &Encoder{printer{Writer: bufio.NewWriter(w)}} + e := &Encoder{printer{w: bufio.NewWriter(w)}} e.p.encoder = e return e } @@ -163,7 +171,7 @@ func (enc *Encoder) Encode(v any) error { if err != nil { return err } - return enc.p.Flush() + return enc.p.w.Flush() } // EncodeElement writes the XML encoding of v to the stream, @@ -178,7 +186,7 @@ func (enc *Encoder) EncodeElement(v any, start StartElement) error { if err != nil { return err } - return enc.p.Flush() + return enc.p.w.Flush() } var ( @@ -224,7 +232,7 @@ func (enc *Encoder) EncodeToken(t Token) error { case ProcInst: // First token to be encoded which is also a ProcInst with target of xml // is the xml declaration. The only ProcInst where target of xml is allowed. - if t.Target == "xml" && p.Buffered() != 0 { + if t.Target == "xml" && p.w.Buffered() != 0 { return fmt.Errorf("xml: EncodeToken of ProcInst xml target only valid for xml declaration, first token encoded") } if !isNameString(t.Target) { @@ -297,11 +305,18 @@ func isValidDirective(dir Directive) bool { // Flush flushes any buffered XML to the underlying writer. // See the EncodeToken documentation for details about when it is necessary. func (enc *Encoder) Flush() error { - return enc.p.Flush() + return enc.p.w.Flush() +} + +// Close the Encoder, indicating that no more data will be written. It flushes +// any buffered XML to the underlying writer and returns an error if the +// written XML is invalid (e.g. by containing unclosed elements). +func (enc *Encoder) Close() error { + return enc.p.Close() } type printer struct { - *bufio.Writer + w *bufio.Writer encoder *Encoder seq int indent string @@ -313,6 +328,8 @@ type printer struct { attrPrefix map[string]string // map name space -> prefix prefixes []string tags []Name + closed bool + err error } // createAttrPrefix finds the name space prefix attribute to use for the given name space, @@ -961,6 +978,56 @@ func (p *printer) marshalStruct(tinfo *typeInfo, val reflect.Value) error { return p.cachedWriteError() } +// Write implements io.Writer +func (p *printer) Write(b []byte) (n int, err error) { + if p.closed && p.err == nil { + p.err = errors.New("use of closed Encoder") + } + if p.err == nil { + n, p.err = p.w.Write(b) + } + return n, p.err +} + +// WriteString implements io.StringWriter +func (p *printer) WriteString(s string) (n int, err error) { + if p.closed && p.err == nil { + p.err = errors.New("use of closed Encoder") + } + if p.err == nil { + n, p.err = p.w.WriteString(s) + } + return n, p.err +} + +// WriteByte implements io.ByteWriter +func (p *printer) WriteByte(c byte) error { + if p.closed && p.err == nil { + p.err = errors.New("use of closed Encoder") + } + if p.err == nil { + p.err = p.w.WriteByte(c) + } + return p.err +} + +// Close the Encoder, indicating that no more data will be written. It flushes +// any buffered XML to the underlying writer and returns an error if the +// written XML is invalid (e.g. by containing unclosed elements). +func (p *printer) Close() error { + if p.closed { + return nil + } + p.closed = true + if err := p.w.Flush(); err != nil { + return err + } + if len(p.tags) > 0 { + return fmt.Errorf("unclosed tag <%s>", p.tags[len(p.tags)-1].Local) + } + return nil +} + // return the bufio Writer's cached write error func (p *printer) cachedWriteError() error { _, err := p.Write(nil) diff --git a/src/encoding/xml/marshal_test.go b/src/encoding/xml/marshal_test.go index 3fe7e2dc001c4d..774793a6c54ff5 100644 --- a/src/encoding/xml/marshal_test.go +++ b/src/encoding/xml/marshal_test.go @@ -2531,3 +2531,61 @@ func TestMarshalZeroValue(t *testing.T) { t.Fatalf("unexpected unmarshal result, want %q but got %q", proofXml, anotherXML) } } + +var closeTests = []struct { + desc string + toks []Token + want string + err string +}{{ + desc: "unclosed start element", + toks: []Token{ + StartElement{Name{"", "foo"}, nil}, + }, + want: ``, + err: "unclosed tag ", +}, { + desc: "closed element", + toks: []Token{ + StartElement{Name{"", "foo"}, nil}, + EndElement{Name{"", "foo"}}, + }, + want: ``, +}, { + desc: "directive", + toks: []Token{ + Directive("foo"), + }, + want: ``, +}} + +func TestClose(t *testing.T) { + for _, tt := range closeTests { + tt := tt + t.Run(tt.desc, func(t *testing.T) { + var out strings.Builder + enc := NewEncoder(&out) + for j, tok := range tt.toks { + if err := enc.EncodeToken(tok); err != nil { + t.Fatalf("token #%d: %v", j, err) + } + } + err := enc.Close() + switch { + case tt.err != "" && err == nil: + t.Error(" expected error; got none") + case tt.err == "" && err != nil: + t.Errorf(" got error: %v", err) + case tt.err != "" && err != nil && tt.err != err.Error(): + t.Errorf(" error mismatch; got %v, want %v", err, tt.err) + } + if got := out.String(); got != tt.want { + t.Errorf("\ngot %v\nwant %v", got, tt.want) + } + t.Log(enc.p.closed) + if err := enc.EncodeToken(Directive("foo")); err == nil { + t.Errorf("unexpected success when encoding after Close") + } + }) + } +}