diff --git a/decoder.go b/decoder.go index a0f3ad6..934b621 100644 --- a/decoder.go +++ b/decoder.go @@ -35,6 +35,7 @@ type decoder struct { read byte write byte top byte + lastStartElement bool } // NewDecoder creates a new Decoder. @@ -130,15 +131,24 @@ func (thiz *decoder) NextToken(t *Token) error { // That's fine. We just did not consume the end token // because there could have been an implicit // "/>" close at the end of the start element. + thiz.lastStartElement = false case '/': - // Immediately closing last openend StartElement. - // This will generate an EndElement with the same - // name that we used in the previous StartElement. - _, err = thiz.discard(1) - if err != nil { + if thiz.lastStartElement { + // Immediately closing last openend StartElement. + // This will generate an EndElement with the same + // name that we used in the previous StartElement. + _, err = thiz.discard(1) + if err != nil { + return err + } + thiz.lastStartElement = false + return thiz.decodeEndElement(t, thiz.lastOpen) + } + thiz.unreadByte() + cntn, err := thiz.decodeText(t) + if err != nil || !cntn { return err } - return thiz.decodeEndElement(t, thiz.lastOpen) case '<': b, err = thiz.readByte() if err != nil { @@ -146,6 +156,7 @@ func (thiz *decoder) NextToken(t *Token) error { } switch b { case '?': + thiz.lastStartElement = false err = thiz.decodeProcInst(t) thiz.unreadByte() return err @@ -162,6 +173,7 @@ func (thiz *decoder) NextToken(t *Token) error { return err } case '[': + thiz.lastStartElement = false return thiz.readCDATA() default: return errors.New("invalid XML: comment or CDATA expected") @@ -172,11 +184,14 @@ func (thiz *decoder) NextToken(t *Token) error { if err != nil { return err } + thiz.lastStartElement = false return thiz.decodeEndElement(t, name) default: + thiz.lastStartElement = true return thiz.decodeStartElement(t) } default: + thiz.lastStartElement = false thiz.unreadByte() cntn, err := thiz.decodeText(t) if err != nil || !cntn { diff --git a/fuzzing_test.go b/fuzzing_test.go new file mode 100644 index 0000000..f39d6f5 --- /dev/null +++ b/fuzzing_test.go @@ -0,0 +1,119 @@ +package gosaxml_test + +import ( + "bytes" + "github.com/HBTGmbH/gosaxml" + "github.com/stretchr/testify/assert" + "io" + "math/rand" + "testing" +) + +var startNameRunes = []rune(":-_abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ") +var restNameRunes = []rune("0123456789-_abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ") +var stringRunes = []rune("/:+*#.!§$%&/[]=?`´'0123456789-_abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ") +var textRunes = []rune("\"/:+*#'.!§$%&[]=?`´'0123456789-_abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ") + +func randName(r *rand.Rand) string { + c := 1 + r.Intn(10) + b := make([]rune, c) + b[0] = startNameRunes[r.Intn(len(startNameRunes))] + for i := 1; i < c; i++ { + b[i] = restNameRunes[r.Intn(len(restNameRunes))] + } + return string(b) +} + +func randText(r *rand.Rand) string { + c := 1 + r.Intn(255) + b := make([]rune, c) + for i := 0; i < c; i++ { + b[i] = textRunes[r.Intn(len(textRunes))] + } + return string(b) +} + +func randString(r *rand.Rand) string { + c := r.Intn(30) + b := make([]rune, c) + for i := 0; i < c; i++ { + b[i] = stringRunes[r.Intn(len(stringRunes))] + } + return string(b) +} + +func buildElement(i int, b *bytes.Buffer, r *rand.Rand, lastOpen bool) bool { + switch i { + case 0: + if lastOpen { + b.WriteString(">") + } + name := randName(r) + b.WriteString("<") + b.WriteString(name) + numAttrs := r.Intn(10) + for j := 0; j < numAttrs; j++ { + b.WriteString(" ") + buildAttribute(b, r) + } + ended := buildElement(r.Intn(2), b, r, true) + if !ended { + b.WriteString("") + } + return false + case 1: + if lastOpen { + b.WriteString(">") + } + b.WriteString(randText(r)) + return false + default: + b.WriteString("/>") + return true + } +} + +func buildAttribute(b *bytes.Buffer, r *rand.Rand) { + name := randName(r) + value := randString(r) + randName(r) + b.WriteString(name) + b.WriteString("=\"") + b.WriteString(value) + b.WriteString("\"") +} + +func TestFuzz(t *testing.T) { + // given + s1 := rand.NewSource(123456789) + r := rand.New(s1) + n := 100000 + + for i := 0; i < n; i++ { + b := &bytes.Buffer{} + buildElement(0, b, r, false) + xml := b.String() + reader := bytes.NewReader(b.Bytes()) + dec := gosaxml.NewDecoder(reader) + w := &bytes.Buffer{} + enc := gosaxml.NewEncoder(w, gosaxml.NewNamespaceModifier()) + var tk gosaxml.Token + + // when + for { + err := dec.NextToken(&tk) + if err == io.EOF { + break + } + assert.Nil(t, err) + err = enc.EncodeToken(&tk) + assert.Nil(t, err) + } + assert.Nil(t, enc.Flush()) + + // then + assert.Equal(t, xml, w.String()) + } +}