From 10304c2ec28b71ffe49b0fd72a17ae557d62f726 Mon Sep 17 00:00:00 2001 From: diamondburned Date: Tue, 28 Feb 2023 03:33:55 -0800 Subject: [PATCH] Initial commit --- LICENSE | 13 +++++ decoder.go | 136 +++++++++++++++++++++++++++++++++++++++++++++++++++++ encoder.go | 98 ++++++++++++++++++++++++++++++++++++++ error.go | 70 +++++++++++++++++++++++++++ go.mod | 8 ++++ go.sum | 4 ++ hrt.go | 105 +++++++++++++++++++++++++++++++++++++++++ 7 files changed, 434 insertions(+) create mode 100644 LICENSE create mode 100644 decoder.go create mode 100644 encoder.go create mode 100644 error.go create mode 100644 go.mod create mode 100644 go.sum create mode 100644 hrt.go diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..8a557a2 --- /dev/null +++ b/LICENSE @@ -0,0 +1,13 @@ +Copyright 2023 diamondburned + +Permission to use, copy, modify, and/or distribute this software for any purpose +with or without fee is hereby granted, provided that the above copyright notice +and this permission notice appear in all copies. + +THE SOFTWARE IS PROVIDED “AS IS” AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH +REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND +FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT, +INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS +OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER +TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF +THIS SOFTWARE. diff --git a/decoder.go b/decoder.go new file mode 100644 index 0000000..804fca1 --- /dev/null +++ b/decoder.go @@ -0,0 +1,136 @@ +package hrt + +import ( + "net/http" + "reflect" + "strconv" + + "github.com/go-chi/chi/v5" + "github.com/pkg/errors" +) + +// Decoder describes a decoder that decodes the request type. +type Decoder interface { + // Decode decodes the given value from the given reader. + Decode(*http.Request, any) error +} + +// MethodDecoder is an encoder that only encodes or decodes if the request +// method matches the methods in it. +type MethodDecoder map[string]Decoder + +// Decode implements the Decoder interface. +func (e MethodDecoder) Decode(r *http.Request, v any) error { + dec, ok := e[r.Method] + if !ok { + dec, ok = e["*"] + } + if !ok { + return WrapHTTPError(http.StatusMethodNotAllowed, errors.New("method not allowed")) + } + return dec.Decode(r, v) +} + +// URLDecoder decodes chi.URLParams and url.Values into a struct. It only does +// Decoding; the Encode method is a no-op. The decoder makes no effort to +// traverse the struct and decode nested structs. If neither a chi.URLParam nor +// a url.Value is found for a field, the field is left untouched. +// +// For the sake of supporting code generators, the decoder also reads the `json` +// tag if the `url` tag is not present. +// +// # Example +// +// The following Go type would be decoded to have 2 URL parameters: +// +// type Data struct { +// ID string +// Num int `url:"num"` +// Nested struct { +// ID string +// } +// } +// +var URLDecoder Decoder = urlDecoder{} + +type urlDecoder struct{} + +func (d urlDecoder) Decode(r *http.Request, v any) error { + rv := reflect.Indirect(reflect.ValueOf(v)) + if !rv.IsValid() { + return errors.New("invalid value") + } + + if rv.Kind() != reflect.Struct { + return errors.New("value is not a struct") + } + + rt := rv.Type() + nfields := rv.NumField() + + for i := 0; i < nfields; i++ { + rfv := rv.Field(i) + rft := rt.Field(i) + if !rft.IsExported() { + continue + } + + var name string + if tag := rft.Tag.Get("json"); tag != "" { + name = tag + } else if tag := rft.Tag.Get("url"); tag != "" { + name = tag + } else { + name = rft.Name + } + + value := chi.URLParam(r, name) + if value == "" { + value = r.FormValue(name) + } + if value == "" { + continue + } + + setPrimitiveFromString(rfv.Type(), rfv, value) + } + + return nil +} + +func setPrimitiveFromString(rf reflect.Type, rv reflect.Value, s string) error { + switch rf.Kind() { + case reflect.String: + rv.SetString(s) + + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + i, err := strconv.ParseInt(s, 10, 64) + if err != nil { + return errors.Wrap(err, "invalid int") + } + rv.SetInt(i) + + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + i, err := strconv.ParseUint(s, 10, 64) + if err != nil { + return errors.Wrap(err, "invalid uint") + } + rv.SetUint(i) + + case reflect.Float32, reflect.Float64: + f, err := strconv.ParseFloat(s, 64) + if err != nil { + return errors.Wrap(err, "invalid float") + } + rv.SetFloat(f) + + case reflect.Bool: + b, err := strconv.ParseBool(s) + if err != nil { + return errors.Wrap(err, "invalid bool") + } + rv.SetBool(b) + } + + return nil +} diff --git a/encoder.go b/encoder.go new file mode 100644 index 0000000..1aac4d8 --- /dev/null +++ b/encoder.go @@ -0,0 +1,98 @@ +package hrt + +import ( + "encoding/json" + "net/http" +) + +// DefaultEncoder is the default encoder used by the router. It decodes GET +// requests using the query string and URL parameter; everything else uses JSON. +var DefaultEncoder = CombinedEncoder{ + Encoder: JSONEncoder, + Decoder: MethodDecoder{ + // For the sake of being RESTful, we use a URLDecoder for GET requests. + "GET": URLDecoder, + // Everything else will be decoded as JSON. + "*": JSONEncoder, + }, +} + +// Encoder describes an encoder that encodes or decodes the request and response +// types. +type Encoder interface { + // Encode encodes the given value into the given writer. + Encode(http.ResponseWriter, any) error + // An encoder must be able to decode the same type it encodes. + Decoder +} + +// CombinedEncoder combines an encoder and decoder pair into one. +type CombinedEncoder struct { + Encoder Encoder + Decoder Decoder +} + +// Encode implements the Encoder interface. +func (e CombinedEncoder) Encode(w http.ResponseWriter, v any) error { + return e.Encoder.Encode(w, v) +} + +// Decode implements the Decoder interface. +func (e CombinedEncoder) Decode(r *http.Request, v any) error { + return e.Decoder.Decode(r, v) +} + +// JSONEncoder is an encoder that encodes and decodes JSON. +var JSONEncoder Encoder = jsonEncoder{} + +type jsonEncoder struct{} + +func (e jsonEncoder) Encode(w http.ResponseWriter, v any) error { + w.Header().Set("Content-Type", "application/json") + return json.NewEncoder(w).Encode(v) +} + +func (e jsonEncoder) Decode(r *http.Request, v any) error { + return json.NewDecoder(r.Body).Decode(v) +} + +// Validator describes a type that can validate itself. +type Validator interface { + Validate() error +} + +// EncoderWithValidator wraps an encoder with one that calls Validate() on the +// value after decoding and before encoding if the value implements Validator. +func EncoderWithValidator(enc Encoder) Encoder { + return validatorEncoder{enc} +} + +type validatorEncoder struct{ enc Encoder } + +func (e validatorEncoder) Encode(w http.ResponseWriter, v any) error { + if validator, ok := v.(Validator); ok { + if err := validator.Validate(); err != nil { + return err + } + } + + if err := e.enc.Encode(w, v); err != nil { + return err + } + + return nil +} + +func (e validatorEncoder) Decode(r *http.Request, v any) error { + if err := e.enc.Decode(r, v); err != nil { + return err + } + + if validator, ok := v.(Validator); ok { + if err := validator.Validate(); err != nil { + return err + } + } + + return nil +} diff --git a/error.go b/error.go new file mode 100644 index 0000000..be52137 --- /dev/null +++ b/error.go @@ -0,0 +1,70 @@ +package hrt + +import ( + "errors" + "fmt" + "net/http" +) + +// HTTPError extends the error interface with an HTTP status code. +type HTTPError interface { + error + HTTPStatus() int +} + +// ErrorHTTPStatus returns the HTTP status code for the given error. If the +// error is not an HTTPError, it returns defaultCode. +func ErrorHTTPStatus(err error, defaultCode int) int { + var httpErr HTTPError + if errors.As(err, &httpErr) { + return httpErr.HTTPStatus() + } + return defaultCode +} + +type wrappedHTTPError struct { + code int + err error +} + +// WrapHTTPError wraps an error with an HTTP status code. +func WrapHTTPError(code int, err error) HTTPError { + return wrappedHTTPError{code, err} +} + +func (e wrappedHTTPError) HTTPStatus() int { + return e.code +} + +func (e wrappedHTTPError) Error() string { + return fmt.Sprintf("error status %d: %s", e.code, e.err) +} + +func (e wrappedHTTPError) Unwrap() error { + return e.err +} + +// ErrorWriter is a writer that writes an error to the response. +type ErrorWriter interface { + WriteError(w http.ResponseWriter, err error) +} + +// WriteErrorFunc is a function that implements the ErrorWriter interface. +type WriteErrorFunc func(w http.ResponseWriter, err error) + +// WriteError implements the ErrorWriter interface. +func (f WriteErrorFunc) WriteError(w http.ResponseWriter, err error) { + f(w, err) +} + +// TextErrorWriter writes the error into the response in plain text. 500 +// status code is used by default. +var TextErrorWriter ErrorWriter = textErrorWriter{} + +type textErrorWriter struct{} + +func (textErrorWriter) WriteError(w http.ResponseWriter, err error) { + w.Header().Set("Content-Type", "text/plain") + w.WriteHeader(ErrorHTTPStatus(err, http.StatusInternalServerError)) + fmt.Fprintln(w, err) +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..853e469 --- /dev/null +++ b/go.mod @@ -0,0 +1,8 @@ +module github.com/diamondburned/hrt + +go 1.18 + +require ( + github.com/go-chi/chi/v5 v5.0.8 + github.com/pkg/errors v0.9.1 +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..0e66dc9 --- /dev/null +++ b/go.sum @@ -0,0 +1,4 @@ +github.com/go-chi/chi/v5 v5.0.8 h1:lD+NLqFcAi1ovnVZpsnObHGW4xb4J8lNmoYVfECH1Y0= +github.com/go-chi/chi/v5 v5.0.8/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= diff --git a/hrt.go b/hrt.go new file mode 100644 index 0000000..b06f53c --- /dev/null +++ b/hrt.go @@ -0,0 +1,105 @@ +// Package hrt implements a type-safe HTTP router. It aids in creating a uniform +// API interface while making it easier to create API handlers. +// +// HRT stands for (H)TTP (r)outer with (t)ypes. +package hrt + +import ( + "context" + "net/http" +) + +type ctxKey uint8 + +const ( + routerOptsCtxKey ctxKey = iota + requestCtxKey +) + +// RequestFromContext returns the request from the Handler's context. +func RequestFromContext(ctx context.Context) *http.Request { + return ctx.Value(requestCtxKey).(*http.Request) +} + +// Opts contains options for the router. +type Opts struct { + Encoder Encoder + ErrorWriter ErrorWriter +} + +// DefaultOpts is the default options for the router. +var DefaultOpts = Opts{ + Encoder: DefaultEncoder, + ErrorWriter: TextErrorWriter, +} + +// OptsFromContext returns the options from the Handler's context. DefaultOpts +// is returned if no options are found. +func OptsFromContext(ctx context.Context) Opts { + opts, ok := ctx.Value(routerOptsCtxKey).(Opts) + if ok { + return opts + } + return DefaultOpts +} + +// WithOpts returns a new context with the given options. +func WithOpts(ctx context.Context, opts Opts) context.Context { + return context.WithValue(ctx, routerOptsCtxKey, opts) +} + +// Use creates a middleware that injects itself into each request's context. +func Use(opts Opts) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := WithOpts(r.Context(), opts) + next.ServeHTTP(w, r.WithContext(ctx)) + }) + } +} + +// None indicates that the request has no body or the request does not return +// anything. +type None struct{} + +// Empty is a value of None. +var Empty = None{} + +// Handler describes a generic handler that takes in a type and returns a +// response. +type Handler[RequestT, ResponseT any] func(ctx context.Context, req RequestT) (ResponseT, error) + +// Wrap wraps a handler into a http.Handler. It exists because Go's type +// inference doesn't work well with the Handler type. +func Wrap[RequestT, ResponseT any](f func(ctx context.Context, req RequestT) (ResponseT, error)) http.HandlerFunc { + return Handler[RequestT, ResponseT](f).ServeHTTP +} + +// ServeHTTP implements the http.Handler interface. +func (h Handler[RequestT, ResponseT]) ServeHTTP(w http.ResponseWriter, r *http.Request) { + var req RequestT + + // Context cycle! Let's go!! + ctx := context.WithValue(r.Context(), requestCtxKey, r) + + opts := OptsFromContext(ctx) + if _, ok := any(req).(None); !ok { + if err := opts.Encoder.Decode(r, &req); err != nil { + opts.ErrorWriter.WriteError(w, WrapHTTPError(http.StatusBadRequest, err)) + return + } + } + + resp, err := h(ctx, req) + if err != nil { + opts.ErrorWriter.WriteError(w, err) + return + } + + if _, ok := any(resp).(None); !ok { + if err := opts.Encoder.Encode(w, resp); err != nil { + opts.ErrorWriter.WriteError(w, WrapHTTPError(http.StatusInternalServerError, err)) + return + } + } +}