Skip to content

Commit

Permalink
Support encoding.TextUnmarshaler for URLDecoder
Browse files Browse the repository at this point in the history
  • Loading branch information
diamondburned committed Jun 19, 2024
1 parent 86ff8f6 commit a938bc3
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 0 deletions.
96 changes: 96 additions & 0 deletions decoder_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
package hrt

import (
"net/http"
"net/url"
"reflect"
"testing"
"time"
)

func TestURLDecoder(t *testing.T) {
type Mega struct {
String string `form:"string"`
Number float64 `form:"number"`
Integer int `form:"integer"`
Time time.Time `form:"time"`
OptString *string `form:"optstring"`
OptNumber *float64 `form:"optnumber"`
OptInteger *int `form:"optinteger"`
OptTime *time.Time `form:"opttime"`
}

tests := []struct {
name string
input url.Values
expect result[Mega]
}{
{
name: "only required fields",
input: url.Values{
"string": {"hello"},
"number": {"3.14"},
"integer": {"42"},
"time": {"2021-01-01T00:00:00Z"},
},
expect: okResult(Mega{
String: "hello",
Number: 3.14,
Integer: 42,
Time: time.Date(2021, 1, 1, 0, 0, 0, 0, time.UTC),
}),
},
{
name: "only optional fields",
input: url.Values{
"optstring": {"world"},
"optnumber": {"2.71"},
"optinteger": {"24"},
"opttime": {"2020-01-01T00:00:00Z"},
},
expect: okResult(Mega{
OptString: ptrTo("world"),
OptNumber: ptrTo(2.71),
OptInteger: ptrTo(24),
OptTime: ptrTo(time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC)),
}),
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
req := &http.Request{
Form: test.input,
}

var got Mega
err := URLDecoder.Decode(req, &got)
res := combineResult(got, err)

if !reflect.DeepEqual(test.expect, res) {
t.Errorf("unexpected test result:\n"+
"expected: %v\n"+
"got: %v\n", test.expect, res)
}
})
}
}

type result[T any] struct {
value T
error string
}

func okResult[T any](value T) result[T] {
return result[T]{value: value}
}

func combineResult[T any](value T, err error) result[T] {
res := result[T]{value: value}
if err != nil {
res.error = err.Error()
}
return res
}

func ptrTo[T any](v T) *T { return &v }
23 changes: 23 additions & 0 deletions internal/rfutil/rfutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,15 @@
package rfutil

import (
"encoding"
"reflect"
"strconv"

"github.com/pkg/errors"
)

var textUnmarshalerType = reflect.TypeFor[encoding.TextUnmarshaler]()

// SetPrimitiveFromString sets the value of a primitive type from a string. It
// supports strings, ints, uints, floats and bools. If s is empty, the value is
// left untouched.
Expand All @@ -16,34 +19,54 @@ func SetPrimitiveFromString(rf reflect.Type, rv reflect.Value, s string) error {
return nil
}

if rf.Kind() == reflect.Ptr {
rf = rf.Elem()

newValue := reflect.New(rf)
rv.Set(newValue)
rv = newValue.Elem()
}

switch rf.Kind() {
case reflect.String:
rv.SetString(s)
return nil

case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
i, err := strconv.ParseInt(s, 10, rf.Bits())
if err != nil {
return errors.Wrap(err, "invalid int")
}
rv.SetInt(i)
return nil

case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
i, err := strconv.ParseUint(s, 10, rf.Bits())
if err != nil {
return errors.Wrap(err, "invalid uint")
}
rv.SetUint(i)
return nil

case reflect.Float32, reflect.Float64:
f, err := strconv.ParseFloat(s, rf.Bits())
if err != nil {
return errors.Wrap(err, "invalid float")
}
rv.SetFloat(f)
return nil

case reflect.Bool:
// False means omitted according to MDN.
rv.SetBool(s != "")
return nil
}

if reflect.PointerTo(rf).Implements(textUnmarshalerType) {
unmarshaler := rv.Addr().Interface().(encoding.TextUnmarshaler)
if err := unmarshaler.UnmarshalText([]byte(s)); err != nil {
return errors.Wrap(err, "text unmarshaling")
}
}

return nil
Expand Down

0 comments on commit a938bc3

Please sign in to comment.