diff --git a/decoder_test.go b/decoder_test.go new file mode 100644 index 0000000..8b50933 --- /dev/null +++ b/decoder_test.go @@ -0,0 +1,96 @@ +package hrt + +import ( + "net/http" + "net/url" + "reflect" + "testing" + "time" +) + +func TestURLDecoder(t *testing.T) { + type Mega struct { + String string `form:"string"` + Number float64 `form:"number"` + Integer int `form:"integer"` + Time time.Time `form:"time"` + OptString *string `form:"optstring"` + OptNumber *float64 `form:"optnumber"` + OptInteger *int `form:"optinteger"` + OptTime *time.Time `form:"opttime"` + } + + tests := []struct { + name string + input url.Values + expect result[Mega] + }{ + { + name: "only required fields", + input: url.Values{ + "string": {"hello"}, + "number": {"3.14"}, + "integer": {"42"}, + "time": {"2021-01-01T00:00:00Z"}, + }, + expect: okResult(Mega{ + String: "hello", + Number: 3.14, + Integer: 42, + Time: time.Date(2021, 1, 1, 0, 0, 0, 0, time.UTC), + }), + }, + { + name: "only optional fields", + input: url.Values{ + "optstring": {"world"}, + "optnumber": {"2.71"}, + "optinteger": {"24"}, + "opttime": {"2020-01-01T00:00:00Z"}, + }, + expect: okResult(Mega{ + OptString: ptrTo("world"), + OptNumber: ptrTo(2.71), + OptInteger: ptrTo(24), + OptTime: ptrTo(time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC)), + }), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + req := &http.Request{ + Form: test.input, + } + + var got Mega + err := URLDecoder.Decode(req, &got) + res := combineResult(got, err) + + if !reflect.DeepEqual(test.expect, res) { + t.Errorf("unexpected test result:\n"+ + "expected: %v\n"+ + "got: %v\n", test.expect, res) + } + }) + } +} + +type result[T any] struct { + value T + error string +} + +func okResult[T any](value T) result[T] { + return result[T]{value: value} +} + +func combineResult[T any](value T, err error) result[T] { + res := result[T]{value: value} + if err != nil { + res.error = err.Error() + } + return res +} + +func ptrTo[T any](v T) *T { return &v } diff --git a/internal/rfutil/rfutil.go b/internal/rfutil/rfutil.go index 20b3261..444aa3c 100644 --- a/internal/rfutil/rfutil.go +++ b/internal/rfutil/rfutil.go @@ -2,12 +2,15 @@ package rfutil import ( + "encoding" "reflect" "strconv" "github.com/pkg/errors" ) +var textUnmarshalerType = reflect.TypeFor[encoding.TextUnmarshaler]() + // SetPrimitiveFromString sets the value of a primitive type from a string. It // supports strings, ints, uints, floats and bools. If s is empty, the value is // left untouched. @@ -16,9 +19,18 @@ func SetPrimitiveFromString(rf reflect.Type, rv reflect.Value, s string) error { return nil } + if rf.Kind() == reflect.Ptr { + rf = rf.Elem() + + newValue := reflect.New(rf) + rv.Set(newValue) + rv = newValue.Elem() + } + switch rf.Kind() { case reflect.String: rv.SetString(s) + return nil case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: i, err := strconv.ParseInt(s, 10, rf.Bits()) @@ -26,6 +38,7 @@ func SetPrimitiveFromString(rf reflect.Type, rv reflect.Value, s string) error { return errors.Wrap(err, "invalid int") } rv.SetInt(i) + return nil case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: i, err := strconv.ParseUint(s, 10, rf.Bits()) @@ -33,6 +46,7 @@ func SetPrimitiveFromString(rf reflect.Type, rv reflect.Value, s string) error { return errors.Wrap(err, "invalid uint") } rv.SetUint(i) + return nil case reflect.Float32, reflect.Float64: f, err := strconv.ParseFloat(s, rf.Bits()) @@ -40,10 +54,19 @@ func SetPrimitiveFromString(rf reflect.Type, rv reflect.Value, s string) error { return errors.Wrap(err, "invalid float") } rv.SetFloat(f) + return nil case reflect.Bool: // False means omitted according to MDN. rv.SetBool(s != "") + return nil + } + + if reflect.PointerTo(rf).Implements(textUnmarshalerType) { + unmarshaler := rv.Addr().Interface().(encoding.TextUnmarshaler) + if err := unmarshaler.UnmarshalText([]byte(s)); err != nil { + return errors.Wrap(err, "text unmarshaling") + } } return nil