From 4cdcd8a360e6f62ba68914effb0fc4e2fcc92382 Mon Sep 17 00:00:00 2001 From: Mateusz Poliwczak Date: Thu, 11 May 2023 16:00:14 +0200 Subject: [PATCH] initial --- binary.go | 26 +++ builder.go | 192 +++++++++++++++++++ builder_test.go | 145 +++++++++++++++ go.mod | 3 + parser.go | 483 ++++++++++++++++++++++++++++++++++++++++++++++++ parser_test.go | 370 +++++++++++++++++++++++++++++++++++++ types.go | 177 ++++++++++++++++++ 7 files changed, 1396 insertions(+) create mode 100644 binary.go create mode 100644 builder.go create mode 100644 builder_test.go create mode 100644 go.mod create mode 100644 parser.go create mode 100644 parser_test.go create mode 100644 types.go diff --git a/binary.go b/binary.go new file mode 100644 index 0000000..c081a07 --- /dev/null +++ b/binary.go @@ -0,0 +1,26 @@ +package dnsmsg + +//TODO: use binary.BigEndian in entire package (because of lower cost (golang/go#42958)) (but probably after golang/go#54097 gets fixed) + +func unpackUint16(b []byte) uint16 { + _ = b[1] + return uint16(b[0])<<8 | uint16(b[1]) +} + +func unpackUint32(b []byte) uint32 { + _ = b[3] + return uint32(b[0])<<24 | uint32(b[1])<<16 | uint32(b[2])<<8 | uint32(b[3]) +} + +func appendUint16(b []byte, v uint16) []byte { + return append(b, byte(v>>8), byte(v)) +} + +func appendUint32(b []byte, v uint32) []byte { + return append(b, + byte(v>>24), + byte(v>>16), + byte(v>>8), + byte(v), + ) +} diff --git a/builder.go b/builder.go new file mode 100644 index 0000000..c0f4f5c --- /dev/null +++ b/builder.go @@ -0,0 +1,192 @@ +package dnsmsg + +import ( + "errors" +) + +func MakeQuery[T name](msg []byte, id uint16, flags Flags, q Question[T]) []byte { + // Header + msg = appendUint16(msg, id) + msg = appendUint16(msg, uint16(flags)) + msg = appendUint16(msg, 1) + msg = appendUint16(msg, 0) + msg = appendUint16(msg, 0) + msg = appendUint16(msg, 0) + + // Question + appendName(msg, q.Name) + + msg = appendUint16(msg, uint16(q.Type)) + msg = appendUint16(msg, uint16(q.Class)) + return msg +} + +func MakeQueryWithEDNS0[T name](msg []byte, id uint16, flags Flags, q Question[T], ends0 EDNS0) []byte { + // Header + msg = appendUint16(msg, id) + msg = appendUint16(msg, uint16(flags)) + msg = appendUint16(msg, 1) + msg = appendUint16(msg, 0) + msg = appendUint16(msg, 0) + msg = appendUint16(msg, 1) + + // Question + appendName(msg, q.Name) + msg = appendUint16(msg, uint16(q.Type)) + msg = appendUint16(msg, uint16(q.Class)) + + // EDNS0 + msg = append(msg, 0) // root name + msg = appendUint16(msg, uint16(TypeOPT)) + msg = appendUint16(msg, ends0.Payload) + + // TODO: support rest of EDNS0 stuff. + msg = appendUint32(msg, 0) + msg = appendUint16(msg, 0) + return msg +} + +func appendName[T name](buf []byte, n T) []byte { + switch n := any(n).(type) { + case Name: + return appendEscapedName(buf, n.name) + case ParserName: + return n.appendRawName(buf) + default: + panic("appendName: unsupported name type") + } +} + +var errInvalidName = errors.New("invalid name") + +type Name struct { + name string +} + +func NewName(name string) (Name, error) { + if !isValidEscapedName(name) { + return Name{}, errInvalidName + } + return Name{name: name}, nil +} + +func isValidEscapedName(m string) bool { + if m == "" { + return false + } + + if m == "." { + return true + } + + labelLength := 0 + nameLength := 0 + inEscape := false + rooted := false + + for i := 0; i < len(m); i++ { + char := m[i] + rooted = false + + switch char { + case '.': + if inEscape { + labelLength++ + inEscape = false + continue + } + if labelLength == 0 || labelLength > maxLabelLength { + return false + } + rooted = true + nameLength += labelLength + 1 + labelLength = 0 + case '\\': + inEscape = !inEscape + if !inEscape { + labelLength++ + } + default: + if inEscape && isDigit(char) { + if len(m[i:]) < 3 || !isDigit(m[i+1]) || !isDigit(m[i+2]) { + return false + } + if _, ok := decodeDDD([3]byte{char, m[i+1], m[i+2]}); !ok { + return false + } + i += 2 + } + inEscape = false + labelLength++ + } + } + + nameLength += labelLength + + if inEscape { + return false + } + + if nameLength > 254 || nameLength == 254 && !rooted { + return false + } + + return true +} + +func appendEscapedName(buf []byte, m string) []byte { + labelLength := byte(0) + + labelIndex := len(buf) + buf = append(buf, 0) + + for i := 0; i < len(m); i++ { + char := m[i] + switch char { + case '.': + buf[labelIndex] = labelLength + labelLength = 0 + labelIndex = len(buf) + buf = append(buf, 0) + case '\\': + if isDigit(m[i+1]) { + labelLength++ + ddd, _ := decodeDDD([3]byte{m[i+1], m[i+2], m[i+3]}) + buf = append(buf, ddd) + i += 3 + continue + } + buf = append(buf, m[i+1]) + i += 1 + labelLength++ + default: + labelLength++ + buf = append(buf, char) + } + } + + if labelLength != 0 { + buf[labelIndex] = labelLength + } + + if buf[len(buf)-1] != 0 { + buf = append(buf, 0) + } + + return buf +} + +func isDigit(char byte) bool { + return char >= '0' && char <= '9' +} + +func decodeDDD(ddd [3]byte) (uint8, bool) { + ddd[0] -= '0' + ddd[1] -= '0' + ddd[2] -= '0' + num := uint16(ddd[0])*100 + uint16(ddd[1])*10 + uint16(ddd[2]) + if num > 255 { + return 0, false + } + return uint8(num), true +} diff --git a/builder_test.go b/builder_test.go new file mode 100644 index 0000000..0768111 --- /dev/null +++ b/builder_test.go @@ -0,0 +1,145 @@ +package dnsmsg + +import ( + "fmt" + "strings" + "testing" +) + +var ( + escapes = "\\.\\223\\.\\\\" + escapesCharCount = 4 + label54 = escapes + strings.Repeat("a", 54-2*escapesCharCount) + escapes + label63 = escapes + strings.Repeat("a", 63-2*escapesCharCount) + escapes + label64 = escapes + strings.Repeat("a", 64-2*escapesCharCount) + escapes +) + +var newNameTests = []struct { + name string + ok bool + diferentAsString bool +}{ + {name: "", ok: false}, + {name: ".", ok: true}, + {name: "com.", ok: true}, + {name: "com", ok: true}, + + {name: "go.dev", ok: true}, + {name: "go.dev.", ok: true}, + {name: "www.go.dev", ok: true}, + {name: "www.go.dev.", ok: true}, + + {name: "www..go.dev", ok: false}, + {name: ".www.go.dev", ok: false}, + {name: "..www.go.dev", ok: false}, + {name: "www.go.dev..", ok: false}, + + {name: "www.go.dev\\.", ok: true}, + {name: "www.go.dev\\..", ok: true}, + {name: "www.go.dev\\...", ok: false}, + {name: "www\\..go.dev", ok: true}, + {name: "www\\...go.dev", ok: false}, + + {name: "\\\\www.go.dev.", ok: true}, + {name: "\\\\www.go.dev.", ok: true}, + {name: "www.go.dev\\\\\\.", ok: true}, + {name: "www.go.dev\\\\\\.", ok: true}, + {name: "\\ww\\ w.go.dev", ok: true, diferentAsString: true}, + {name: "ww\\w.go.dev", ok: true, diferentAsString: true}, + {name: "www.go.dev\\\\", ok: true}, + + {name: "\\223www.go.dev", ok: true}, + {name: "\\000www.go.dev", ok: true}, + {name: "\\255www.go.dev", ok: true}, + + {name: "\\256www.go.dev", ok: false}, + {name: "\\999www.go.dev", ok: false}, + {name: "\\12www.go.dev", ok: false}, + {name: "\\1www.go.dev", ok: false}, + {name: "www.go.dev\\223", ok: true}, + {name: "www.go.dev\\12", ok: false}, + {name: "www.go.dev\\1", ok: false}, + {name: "www.go.dev\\", ok: false}, + + {name: label63 + ".go.dev", ok: true}, + {name: label64 + ".go.dev", ok: false}, + + // 253B non-rooted name. + { + name: fmt.Sprintf("%[1]v.%[1]v.%[1]v.%v.go.dev", label63, label54), + ok: true, + }, + + // 254B rooted name. + { + name: fmt.Sprintf("%[1]v.%[1]v.%[1]v.%v.go.dev.", label63, label54), + ok: true, + }, + + // 254B non-rooted name. + { + name: fmt.Sprintf("%[1]v.%[1]v.%[1]v.%va.go.dev", label63, label54), + ok: false, + }, + + // 255B rooted name. + { + name: fmt.Sprintf("%[1]v.%[1]v.%[1]v.%va.go.dev.", label63, label54), + ok: false, + }, +} + +func TestNewName(t *testing.T) { + for _, v := range newNameTests { + _, err := NewName(v.name) + expectErr := errInvalidName + if v.ok { + expectErr = nil + } + if expectErr != err { + t.Errorf("'%v' got error: %v, expected: %v", v.name, err, expectErr) + } + } +} + +func TestAppendEscapedName(t *testing.T) { + for _, v := range newNameTests { + n, err := NewName(v.name) + if err != nil { + continue + } + + packedName := appendEscapedName(nil, v.name) + + p, err := NewParser(packedName) + if err != nil { + continue + } + + name := ParserName{m: &p, nameStart: 0} + _, err = name.unpack() + if err != nil { + t.Errorf("'%v' failed while unpacking packed name: %v\n\traw: %v", v.name, err, packedName) + continue + } + + if !name.EqualName(n) { + t.Errorf("'%v' ParserName is not equal to name\n\traw: %v", v.name, packedName) + continue + } + + if v.diferentAsString { + continue + } + + expectName := v.name + dotAtEnd := expectName[len(expectName)-1] == '.' + if !dotAtEnd || (len(expectName) > 2 && dotAtEnd && expectName[len(expectName)-2] == '\\') { + expectName += "." + } + + if name := name.String(); name != expectName { + t.Errorf("'%v' got name: %v, expected: %v\n\traw: %v", v.name, name, expectName, packedName) + } + } +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..c5d0530 --- /dev/null +++ b/go.mod @@ -0,0 +1,3 @@ +module github.com/mateusz834/dnsmsg + +go 1.20 diff --git a/parser.go b/parser.go new file mode 100644 index 0000000..865fed9 --- /dev/null +++ b/parser.go @@ -0,0 +1,483 @@ +package dnsmsg + +import ( + "errors" + "strings" +) + +const ( + maxDNSMessageLength = 1<<16 - 1 + maxEncodedNameLen = 255 + maxLabelLength = 63 +) + +var ( + errInvalidDNSMessage = errors.New("invalid dns message") + errInvalidDNSName = errors.New("invalid dns name encoding") + errPtrLoop = errors.New("dns compression pointer loop") + errDNSMsgTooLong = errors.New("too long dns message") +) + +func NewParser(msg []byte) (Parser, error) { + if len(msg) > maxDNSMessageLength { + return Parser{}, errDNSMsgTooLong + } + + return Parser{ + msg: msg, + }, nil +} + +type Parser struct { + msg []byte + curOffset uint16 +} + +func (m *Parser) availMsgData() uint16 { + return uint16(len(m.msg)) - uint16(m.curOffset) +} + +func (m *Parser) Header() (Header, error) { + var hdr Header + if m.availMsgData() < headerLen { + return hdr, errInvalidDNSMessage + } + + hdr.unpack([headerLen]byte(m.msg[m.curOffset:])) + m.curOffset += headerLen + return hdr, nil +} + +func (m *Parser) Question() (Question[ParserName], error) { + q := Question[ParserName]{ + Name: ParserName{ + m: m, + nameStart: m.curOffset, + }, + } + + offset, err := q.Name.unpack() + if err != nil { + return Question[ParserName]{}, err + } + + tmpOffset := m.curOffset + offset + + if len(m.msg[tmpOffset:]) < 4 { + return Question[ParserName]{}, errInvalidDNSMessage + } + + q.Type = Type(unpackUint16(m.msg[tmpOffset : tmpOffset+2])) + q.Class = Class(unpackUint16(m.msg[tmpOffset+2 : tmpOffset+4])) + m.curOffset = tmpOffset + 4 + + return q, nil +} + +func (m *Parser) ResourceHeader() (ResourceHeader[ParserName], error) { + q := ResourceHeader[ParserName]{ + Name: ParserName{ + m: m, + nameStart: m.curOffset, + }, + } + + offset, err := q.Name.unpack() + if err != nil { + return ResourceHeader[ParserName]{}, err + } + + tmpOffset := m.curOffset + offset + + if m.availMsgData() < 10 { + return ResourceHeader[ParserName]{}, errInvalidDNSMessage + } + + q.Type = Type(unpackUint16(m.msg[tmpOffset : tmpOffset+2])) + q.Class = Class(unpackUint16(m.msg[tmpOffset+2 : tmpOffset+4])) + q.TTL = unpackUint32(m.msg[tmpOffset+4 : tmpOffset+8]) + q.Length = unpackUint16(m.msg[tmpOffset+8 : tmpOffset+10]) + m.curOffset = tmpOffset + 10 + + return q, nil +} + +func (m *Parser) Skip(length int) error { + if int(m.availMsgData()) < length { + return errInvalidDNSMessage + } + m.curOffset += uint16(length) + return nil +} + +func (m *Parser) Name() (ParserName, error) { + name := ParserName{ + m: m, + nameStart: m.curOffset, + } + + offset, err := name.unpack() + if err != nil { + return ParserName{}, err + } + + m.curOffset += offset + return name, nil +} + +func (m *Parser) RawResource(length uint16) ([]byte, error) { + if len(m.msg[m.curOffset:]) < int(length) { + return nil, errInvalidDNSMessage + } + + msg := m.msg[m.curOffset : m.curOffset+length] + m.curOffset += length + + return msg, nil +} + +func (m *Parser) ResourceA(length uint16) (ResourceA, error) { + if length != 4 || m.availMsgData() < 4 { + return ResourceA{}, errInvalidDNSMessage + } + + m.curOffset += 4 + return ResourceA{ + A: *(*[4]byte)(m.msg[m.curOffset-4 : m.curOffset]), + }, nil +} + +func (m *Parser) ResourceAAAA(length uint16) (ResourceAAAA, error) { + if length != 16 || m.availMsgData() < 16 { + return ResourceAAAA{}, errInvalidDNSMessage + } + + m.curOffset += 16 + return ResourceAAAA{ + AAAA: *(*[16]byte)(m.msg[m.curOffset-16 : m.curOffset]), + }, nil +} + +func (m *Parser) ResourceCNAME(RDLength uint16) (ResourceCNAME[ParserName], error) { + r := ResourceCNAME[ParserName]{ + CNAME: ParserName{ + m: m, + nameStart: m.curOffset, + }, + } + + offset, err := r.CNAME.unpack() + if err != nil { + return ResourceCNAME[ParserName]{}, err + } + + if offset != RDLength { + return ResourceCNAME[ParserName]{}, errInvalidDNSMessage + } + + m.curOffset += offset + return r, nil +} + +func (m *Parser) ResourceMX(RDLength uint16) (ResourceMX[ParserName], error) { + r := ResourceMX[ParserName]{ + MX: ParserName{ + m: m, + nameStart: m.curOffset + 2, + }, + } + + if m.availMsgData() < 2 { + return ResourceMX[ParserName]{}, errInvalidDNSMessage + } + + r.Pref = unpackUint16(m.msg[m.curOffset : m.curOffset+2]) + + offset, err := r.MX.unpack() + if err != nil { + return ResourceMX[ParserName]{}, err + } + + if offset != RDLength-2 { + return ResourceMX[ParserName]{}, errInvalidDNSMessage + } + + m.curOffset += offset + 2 + return r, nil +} + +func (m *Parser) ResourceTXT(RDLength uint16) (ResourceTXT, error) { + if len(m.msg[m.curOffset:]) < int(RDLength) { + return ResourceTXT{}, errInvalidDNSMessage + } + + r := ResourceTXT{ + TXT: m.msg[m.curOffset : m.curOffset+RDLength], + } + + for i := 0; i < len(r.TXT); { + i += int(r.TXT[i]) + 1 + if i == len(r.TXT) { + m.curOffset += RDLength + return r, nil + } + } + + return ResourceTXT{}, errInvalidDNSMessage +} + +const ptrLoopCount = 16 + +type ParserName struct { + m *Parser + nameStart uint16 + rawLen uint8 +} + +func (ParserName) name() {} + +func (m *ParserName) RawLen() uint8 { + return m.rawLen +} + +// unpack parses the name, m.m and m.nameStart must be set accordingly +// before calling this method. +func (m *ParserName) unpack() (uint16, error) { + var ( + // length of the raw name, without compression pointers. + rawNameLen = uint16(0) + + // message offset, length up to the first compression pointer (if any, including it). + offset = uint16(0) + + ptrCount = uint8(0) + ) + + for i := int(m.nameStart); i < len(m.m.msg); { + // Compression pointer + if m.m.msg[i]&0xC0 == 0xC0 { + if ptrCount++; ptrCount > ptrLoopCount { + return 0, errPtrLoop + } + + if offset == 0 { + offset = rawNameLen + 2 + } + + // Compression pointer is 2 bytes long. + if len(m.m.msg) == int(i)+1 { + return 0, errInvalidDNSName + } + + i = int(uint16(m.m.msg[i]^0xC0)<<8 | uint16(m.m.msg[i+1])) + continue + } + + // Two leading bits are reserved, except for compression pointer (above). + if m.m.msg[i]&0xC0 != 0 { + return 0, errInvalidDNSName + } + + if rawNameLen++; rawNameLen > maxEncodedNameLen { + return 0, errInvalidDNSName + } + + if m.m.msg[i] == 0 { + if offset == 0 { + offset = rawNameLen + } + m.rawLen = uint8(rawNameLen) + return offset, nil + } + + rawNameLen += uint16(m.m.msg[i]) + i += int(m.m.msg[i]) + 1 + } + + return 0, errInvalidDNSName +} + +// Equal reports whether m and m2 represents the same name. +// It does not require identical internal representation of the name. +// Letters are compared in a case insensitive manner. +// m an m2 might be created using two different parsers. +func (m *ParserName) Equal(m2 *ParserName) bool { + im1 := m.nameStart + im2 := m2.nameStart + + for { + // Resolve all compression pointers of m + for m.m.msg[im1]&0xC0 == 0xC0 { + im1 = uint16(m.m.msg[im1]^0xC0)<<8 | uint16(m.m.msg[im1+1]) + } + + // Resolve all compression pointers of m2 + for m2.m.msg[im2]&0xC0 == 0xC0 { + im2 = uint16(m2.m.msg[im2]^0xC0)<<8 | uint16(m2.m.msg[im2+1]) + } + + // if we point to the same location in the same parser, then it is equal. + if m.m == m2.m && im1 == im2 { + return true + } + + // different label lengths + if m.m.msg[im1] != m2.m.msg[im2] { + return false + } + + if m.m.msg[im1] == 0 { + return true + } + + if !caseInsensitiveEqual(m.m.msg[im1+1:im1+1+uint16(m.m.msg[im1])], m2.m.msg[im2+1:im2+1+uint16(m2.m.msg[im2])]) { + return false + } + + im1 += uint16(m.m.msg[im1]) + 1 + im2 += uint16(m2.m.msg[im2]) + 1 + } +} + +// Equal reports whether m and m2 represents the same name. +func (m *ParserName) EqualName(m2 Name) bool { + im1 := m.nameStart + nameOffset := 0 + + for { + // Resolve all compression pointers of m + for m.m.msg[im1]&0xC0 == 0xC0 { + im1 = uint16(m.m.msg[im1]^0xC0)<<8 | uint16(m.m.msg[im1+1]) + } + + labelLength := m.m.msg[im1] + + if labelLength == 0 { + return len(m2.name) == nameOffset || ((len(m2.name)-nameOffset) == 1 && m2.name[nameOffset] == '.') + } + + im1++ + for _, v := range m.m.msg[im1 : im1+uint16(labelLength)] { + if len(m2.name)-nameOffset == 0 { + return false + } + + char := m2.name[nameOffset] + nameOffset++ + if char == '\\' { + char = m2.name[nameOffset] + nameOffset++ + if isDigit(char) { + char, _ = decodeDDD([3]byte{char, m2.name[nameOffset], m2.name[nameOffset+1]}) + nameOffset += 2 + } + } + + if !equalASCIICaseInsensitive(char, v) { + return false + } + } + + if len(m2.name)-nameOffset != 0 { + if m2.name[nameOffset] != '.' { + return false + } + nameOffset++ + } + + im1 += uint16(labelLength) + } +} + +// len(a) must be caseInsensitiveEqual to len(b) +func caseInsensitiveEqual(a []byte, b []byte) bool { + for i := 0; i < len(a); i++ { + if !equalASCIICaseInsensitive(a[i], b[i]) { + return false + } + } + return true +} + +func equalASCIICaseInsensitive(a, b byte) bool { + const caseDiff = 'a' - 'A' + + if a >= 'a' && a <= 'z' { + a -= caseDiff + } + + if b >= 'a' && b <= 'z' { + b -= caseDiff + } + + return a == b +} + +// String returns the human name encoding of m. Dots inside the label +// (not separating labels) are escaped as '\.', slashes are encoded as '\\', +// other octets not in range (including) 0x21 through 0xFE are encoded using the \DDD syntax. +func (m *ParserName) String() string { + builder := strings.Builder{} + builder.Grow(int(m.RawLen() - 1)) + + i := m.nameStart + for { + if m.m.msg[i]&0xC0 == 0xC0 { + i = uint16(m.m.msg[i]^0xC0)<<8 | uint16(m.m.msg[i+1]) + continue + } + + if m.m.msg[i] == 0 { + if builder.Len() == 0 { + builder.WriteByte('.') + } + return builder.String() + } + + for _, v := range m.m.msg[i+1 : i+uint16(m.m.msg[i])+1] { + switch { + case v == '.': + builder.WriteString("\\.") + case v == '\\': + builder.WriteString("\\\\") + case v < '!' || v > '~': + builder.WriteByte('\\') + builder.Write(toASCIIDecimal(v)) + default: + builder.WriteByte(v) + } + } + + builder.WriteByte('.') + i += uint16(m.m.msg[i]) + 1 + } +} + +func toASCIIDecimal(v byte) []byte { + var d [3]byte + tmp := v / 100 + v -= tmp * 100 + d[0] = tmp + '0' + tmp = v / 10 + v -= tmp * 10 + d[1] = tmp + '0' + d[2] = v + '0' + return d[:] +} + +func (m *ParserName) appendRawName(raw []byte) []byte { + i := m.nameStart + for { + if m.m.msg[i]&0xC0 == 0xC0 { + i = uint16(m.m.msg[i]^0xC0)<<8 | uint16(m.m.msg[i+1]) + continue + } + + if m.m.msg[i] == 0 { + return append(raw, 0) + } + + raw = append(raw, m.m.msg[i:i+uint16(m.m.msg[i])+1]...) + i += uint16(m.m.msg[i]) + 1 + } +} diff --git a/parser_test.go b/parser_test.go new file mode 100644 index 0000000..fb2e500 --- /dev/null +++ b/parser_test.go @@ -0,0 +1,370 @@ +package dnsmsg + +import ( + "bytes" + "fmt" + "testing" +) + +var nameUnpackTests = []struct { + name string + + msg []byte + nameStart uint16 + + err error + offset uint8 + rawLen uint8 +}{ + {name: "valid go.dev", msg: []byte{2, 'g', 'o', 3, 'd', 'e', 'v', 0}, offset: 8, rawLen: 8}, + {name: "nameStart 2 valid go.dev", nameStart: 2, msg: []byte{32, 3, 2, 'g', 'o', 3, 'd', 'e', 'v', 0}, offset: 8, rawLen: 8}, + {name: "nameStart 2 junk after name valid go.dev", nameStart: 2, msg: []byte{32, 3, 2, 'g', 'o', 3, 'd', 'e', 'v', 0, 2, 66, 66, 0}, offset: 8, rawLen: 8}, + {name: "www.go.dev", msg: []byte{3, 'w', 'w', 'w', 2, 'g', 'o', 3, 'd', 'e', 'v', 0}, offset: 12, rawLen: 12}, + {name: "www.go.dev ptr forward", msg: []byte{3, 'w', 'w', 'w', 0xC0, 10, 2, 2, 1, 1, 2, 'g', 'o', 3, 'd', 'e', 'v', 0}, offset: 6, rawLen: 12}, + {name: "www.go.dev ptr forward with junk", nameStart: 3, msg: []byte{2, 1, 1, 3, 'w', 'w', 'w', 0xC0, 13, 2, 2, 1, 1, 2, 'g', 'o', 3, 'd', 'e', 'v', 0, 2, 22, 33}, offset: 6, rawLen: 12}, + {name: "www.go.dev ptr backwards", nameStart: 11, msg: []byte{2, 'g', 'o', 3, 'd', 'e', 'v', 0, 2, 1, 1, 3, 'w', 'w', 'w', 0xC0, 0}, offset: 6, rawLen: 12}, + {name: "www.go.dev ptr backwards with junk", nameStart: 14, msg: []byte{2, 1, 1, 2, 'g', 'o', 3, 'd', 'e', 'v', 0, 2, 1, 1, 3, 'w', 'w', 'w', 0xC0, 3, 2, 22, 22}, offset: 6, rawLen: 12}, + { + name: "255B", + msg: func() []byte { + var buf []byte + a63 := bytes.Repeat([]byte{'a'}, 63) + a61 := bytes.Repeat([]byte{'a'}, 61) + + for i := 0; i < 3; i++ { + buf = append(buf, byte(len(a63))) + buf = append(buf, a63...) + } + + buf = append(buf, byte(len(a61))) + buf = append(buf, a61...) + buf = append(buf, 0) + + if len(buf) != 255 { + panic("invalid name") + } + + return buf + }(), + offset: 255, + rawLen: 255, + }, + { + name: "255B with one compression pointer", + msg: func() []byte { + var buf []byte + a63 := bytes.Repeat([]byte{'a'}, 63) + z61 := bytes.Repeat([]byte{'z'}, 61) + + buf = append(buf, byte(len(z61))) + buf = append(buf, z61...) + buf = append(buf, 0xC0, byte(len(buf))+4) + + buf = append(buf, 32, 32) // random data + + for i := 0; i < 3; i++ { + buf = append(buf, byte(len(a63))) + buf = append(buf, a63...) + } + buf = append(buf, 0) + + // +4 (pointer and random data in between") + if len(buf) != 255+4 { + panic("invalid name") + } + + return buf + }(), + offset: 64, + rawLen: 255, + }, + { + name: "256B", + msg: func() []byte { + var buf []byte + a63 := bytes.Repeat([]byte{'a'}, 63) + a62 := bytes.Repeat([]byte{'a'}, 62) + + for i := 0; i < 3; i++ { + buf = append(buf, byte(len(a63))) + buf = append(buf, a63...) + } + + buf = append(buf, byte(len(a62))) + buf = append(buf, a62...) + buf = append(buf, 0) + + if len(buf) != 256 { + panic("invalid name") + } + + return buf + }(), + err: errInvalidDNSName, + }, + { + name: "256B with one compression pointer", + msg: func() []byte { + var buf []byte + a63 := bytes.Repeat([]byte{'a'}, 63) + z62 := bytes.Repeat([]byte{'z'}, 62) + + buf = append(buf, byte(len(z62))) + buf = append(buf, z62...) + buf = append(buf, 0xC0, byte(len(buf))+4) + + buf = append(buf, 32, 32) // random data + + for i := 0; i < 3; i++ { + buf = append(buf, byte(len(a63))) + buf = append(buf, a63...) + } + + buf = append(buf, 0) + + // +4 (pointer and random data in between") + if len(buf) != 256+4 { + panic("invalid name") + } + + return buf + }(), + err: errInvalidDNSName, + }, + + {name: "smaller name than label length", msg: []byte{3, 'w', 'w', 'w', 2, 'g', 'o', 5, 'd', 'e', 'v', 0}, err: errInvalidDNSName}, + {name: "missing root label", msg: []byte{3, 'w', 'w', 'w', 2, 'g', 'o', 3, 'd', 'e', 'v'}, err: errInvalidDNSName}, + {name: "pointer loop 1", msg: []byte{3, 'w', 'w', 'w', 0xC0, 0}, err: errPtrLoop}, + {name: "pointer loop 2", nameStart: 2, msg: []byte{32, 32, 0xC0, 2, 32}, err: errPtrLoop}, + {name: "reserved label bit 2", msg: []byte{0b10000000}, err: errInvalidDNSName}, + {name: "reserved label bit 1", msg: []byte{0b01000000}, err: errInvalidDNSName}, +} + +func TestParserNameUnpack(t *testing.T) { + for _, v := range nameUnpackTests { + t.Run(v.name, func(t *testing.T) { + msg, err := NewParser(v.msg) + if err != nil { + t.Fatalf("unexpected NewParser() error: %v", err) + } + + m := ParserName{m: &msg, nameStart: v.nameStart} + + offset, err := m.unpack() + if err != v.err { + t.Fatalf("got err: %v, expected: %v", err, v.err) + } + + if offset != uint16(v.offset) { + t.Fatalf("got offset: %v, expected: %v", offset, v.offset) + } + + if rawLen := m.RawLen(); rawLen != v.rawLen { + t.Fatalf("got RawLen: %v, expected: %v", v.rawLen, rawLen) + } + }) + } +} + +func FuzzParserNameUnpack(f *testing.F) { + for _, v := range nameUnpackTests { + f.Add(v.nameStart, v.msg) + } + f.Fuzz(func(_ *testing.T, nameStart uint16, buf []byte) { + msg, err := NewParser(buf) + if err != nil { + return + } + m := ParserName{m: &msg, nameStart: nameStart} + m.unpack() + }) +} + +func prepNameSameMsg(buf []byte, n1Start, n2Start uint16) [2]ParserName { + msg, err := NewParser(buf) + if err != nil { + panic(err) + } + + m1 := ParserName{m: &msg, nameStart: n1Start} + _, err = m1.unpack() + if err != nil { + panic(err) + } + + m2 := ParserName{m: &msg, nameStart: n2Start} + _, err = m2.unpack() + if err != nil { + panic(err) + } + + var n [2]ParserName + n[0] = m1 + n[1] = m2 + return n +} + +func prepNameDifferentMsg(buf1, buf2 []byte, n1Start, n2Start uint16) [2]ParserName { + msg1, err := NewParser(buf1) + if err != nil { + panic(err) + } + + msg2, err := NewParser(buf2) + if err != nil { + panic(err) + } + + m1 := ParserName{m: &msg1, nameStart: n1Start} + _, err = m1.unpack() + if err != nil { + panic(err) + } + + m2 := ParserName{m: &msg2, nameStart: n2Start} + _, err = m2.unpack() + if err != nil { + panic(err) + } + + var n [2]ParserName + n[0] = m1 + n[1] = m2 + return n +} + +var nameEqualTests = []struct { + name string + + names [2]ParserName + equal bool +}{ + { + name: "(same msg) same nameStart", + names: prepNameSameMsg([]byte{ + 2, 'g', 'o', 3, 'd', 'e', 'v', 0, + }, 0, 0), + equal: true, + }, + + { + name: "(same msg) second name directly points to first name", + names: prepNameSameMsg([]byte{ + 2, 'g', 'o', 3, 'd', 'e', 'v', 0, + 0xC0, 0, + }, 0, 8), + equal: true, + }, + + { + name: "(same msg) two separate names, without compression pointers", + names: prepNameSameMsg([]byte{ + 2, 'g', 'o', 3, 'd', 'e', 'v', 0, + 2, 'g', 'o', 3, 'd', 'e', 'v', 0, + }, 0, 8), + equal: true, + }, + + { + name: "(same msg) two separate names without compression pointers with different letter case", + names: prepNameSameMsg([]byte{ + 2, 'G', 'o', 3, 'd', 'E', 'V', 0, + 2, 'g', 'O', 3, 'D', 'e', 'V', 0, + }, 0, 8), + equal: true, + }, + + { + name: "(same msg) two different names go.dev www.go.dev, no pointers", + names: prepNameSameMsg([]byte{ + 2, 'g', 'o', 3, 'd', 'e', 'v', 0, + 3, 'w', 'w', 'w', 2, 'g', 'o', 3, 'd', 'e', 'v', 0, + }, 0, 8), + equal: false, + }, + + { + name: "(same msg) two different names go.dev go.go.dev, no pointers", + names: prepNameSameMsg([]byte{ + 2, 'g', 'o', 3, 'd', 'e', 'b', 0, + 2, 'g', 'o', 2, 'g', 'o', 3, 'd', 'e', 'v', 0, + }, 0, 8), + equal: false, + }, + + { + name: "(same msg) two different names go.dev www.go.dev with pointers", + names: prepNameSameMsg([]byte{ + 2, 'G', 'o', 3, 'd', 'R', 'V', 0, + 3, 'w', 'w', 'w', 0xC0, 0, + }, 0, 8), + equal: false, + }, + + { + name: "(different msgs) same name no pointers", + names: prepNameDifferentMsg([]byte{ + 2, 'g', 'o', 3, 'd', 'e', 'v', 0, + }, []byte{ + 2, 'g', 'o', 3, 'd', 'e', 'v', 0, + }, 0, 0), + equal: true, + }, + + { + name: "(different msgs) same names, different letter case, no pointers", + names: prepNameDifferentMsg([]byte{ + 2, 'G', 'o', 3, 'd', 'E', 'V', 0, + }, []byte{ + 2, 'G', 'O', 3, 'D', 'e', 'v', 0, + }, 0, 0), + equal: true, + }, + + { + name: "(different msgs) different names, no pointers", + names: prepNameDifferentMsg([]byte{ + 2, 'g', 'o', 3, 'd', 'e', 'v', 0, + }, []byte{ + 2, 'g', 'o', 2, 'g', 'o', 3, 'd', 'e', 'v', 0, + }, 0, 0), + equal: false, + }, + + { + name: "(different msgs) same name with pointers", + names: prepNameDifferentMsg([]byte{ + 2, 'g', 'o', 3, 'd', 'e', 'v', 0, + }, []byte{ + 3, 'd', 'e', 'v', 0, 2, 'g', 'o', 0xC0, 0, + }, 0, 5), + equal: true, + }, + + { + name: "(different msgs) different names with pointers", + names: prepNameDifferentMsg([]byte{ + 2, 'g', 'o', 3, 'd', 'e', 'v', 0, + }, []byte{ + 3, 'd', 'e', 'v', 0, 2, 'g', 'o', 2, 'g', 'o', 0xC0, 0, + }, 0, 5), + equal: false, + }, +} + +func TestNameEqual(t *testing.T) { + for i, v := range nameEqualTests { + for ti, tv := range []string{"n[0].Equal(n[1])", "n[1].Equal(n[0])"} { + prefix := fmt.Sprintf("%v: %v: %v:", i, v.name, tv) + + names := v.names + if ti == 1 { + names[0], names[1] = v.names[1], v.names[0] + } + + if eq := names[0].Equal(&names[1]); eq != v.equal { + t.Errorf("%v expected: %v, but: %v", prefix, v.equal, eq) + } + } + } +} diff --git a/types.go b/types.go new file mode 100644 index 0000000..8249da9 --- /dev/null +++ b/types.go @@ -0,0 +1,177 @@ +package dnsmsg + +type EDNS0 struct { + Payload uint16 +} + +type name interface { + name() +} + +type Type uint16 + +const ( + TypeA Type = 1 + TypeNS Type = 2 + TypeCNAME Type = 5 + TypeSOA Type = 6 + TypePTR Type = 12 + TypeMX Type = 15 + TypeTXT Type = 16 + TypeAAAA Type = 28 + TypeOPT Type = 41 +) + +type Class uint16 + +const ( + ClassIN Class = 1 +) + +type Bit uint8 + +const ( + BitAA Bit = 10 + BitTC Bit = 9 + BitRD Bit = 8 + BitRA Bit = 7 + BitAD Bit = 5 + BitCD Bit = 4 +) + +type OpCode uint8 + +const ( + OpCodeQuery OpCode = 0 +) + +type RCode uint8 + +const ( + RCodeSuccess RCode = 0 +) + +type Flags uint16 + +const bitQR = 1 << 15 + +func (f Flags) Query() bool { + return f&bitQR == 0 +} + +func (f Flags) Response() bool { + return f&bitQR != 0 +} + +func (f Flags) Bit(bit Bit) bool { + return f&(1<> 11) & 0b1111) +} + +func (f Flags) RCode() RCode { + return RCode(f & 0b1111) +} + +func (f *Flags) SetQuery() { + *f &= ^Flags(bitQR) // zero the QR bit +} + +func (f *Flags) SetResponse() { + *f |= bitQR +} + +func (f *Flags) SetBit(bit Bit, val bool) { + *f &= ^Flags(1 << bit) // zero bit + if !val { + return + } + *f |= (1 << bit) +} + +func (f *Flags) SetOpCode(o OpCode) { + *f &= ^Flags(0b1111 << 11) // zero the opcode bits + *f |= Flags(o) << 11 +} + +func (f *Flags) SetRCode(r RCode) { + *f &= ^Flags(0b1111) // zero the rcode bits + *f |= Flags(r) +} + +const headerLen = 12 + +type Header struct { + ID uint16 + Flags Flags + QDCount uint16 + ANCount uint16 + NSCount uint16 + ARCount uint16 +} + +func (h *Header) unpack(msg [headerLen]byte) { + h.ID = unpackUint16(msg[:2]) + h.Flags = Flags(unpackUint16(msg[2:4])) + h.QDCount = unpackUint16(msg[4:6]) + h.ANCount = unpackUint16(msg[6:8]) + h.NSCount = unpackUint16(msg[8:10]) + h.ARCount = unpackUint16(msg[10:12]) +} + +type Question[T name] struct { + Name T + Type Type + Class Class +} + +type ResourceHeader[T name] struct { + Name T + Type Type + Class Class + TTL uint32 + Length uint16 +} + +type ResourceA struct { + A [4]byte +} + +type ResourceNS[T name] struct { + NS T +} + +type ResourceCNAME[T name] struct { + CNAME T +} + +type ResourceSOA[T name] struct { + NS T + Mbox T + Serial uint32 + Refresh uint32 + Retry uint32 + Expire uint32 + Minimum uint32 +} + +type ResourcePTR[T name] struct { + PTR T +} + +type ResourceMX[T name] struct { + MX T + Pref uint16 +} + +type ResourceTXT struct { + // TXT is as defined by RFC 1035 a "One or more s" + // so it is a one or more byte-length prefixed data + TXT []byte +} + +type ResourceAAAA struct { + AAAA [16]byte +}