Skip to content

Commit

Permalink
use a generic wrapper interface Unwrapper
Browse files Browse the repository at this point in the history
This should allow for a proper recovery of the original type

Upgrade the errors package for consistent behavior of Wraps
  • Loading branch information
Greg Weber committed Jun 19, 2024
1 parent 3b9a882 commit 2177897
Show file tree
Hide file tree
Showing 6 changed files with 103 additions and 51 deletions.
58 changes: 43 additions & 15 deletions error_code.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,47 +163,75 @@ type unwrapper interface {
Unwrap() error
}

type Unwrapper[T any] interface {
Unwrapped() T
}

type ErrorCodeWrap[Wrapped ErrorCode] interface {
ErrorCode
Unwrapper[Wrapped]
}

// wrappedErrorCode is a convenience to maintain the ErrorCode type when wrapping errors
type wrappedErrorCode struct {
type WrappedErrorCode[Wrapped ErrorCode] struct {
Err error
ErrorCode ErrorCode
ErrorCode Wrapped
}

// Code fulfills the ErrorCode interface
func (wrapped wrappedErrorCode) Code() Code {
func (wrapped WrappedErrorCode[Wrapped]) Code() Code {
return wrapped.ErrorCode.Code()
}

// Error fulfills the ErrorCode interface
func (wrapped wrappedErrorCode) Error() string {
func (wrapped WrappedErrorCode[Wrapped]) Error() string {
return wrapped.Err.Error()
}

// Allow unwrapping
func (wrapped wrappedErrorCode) Unwrap() error {
return wrapped.Err
func (wrapped WrappedErrorCode[Wrapped]) Unwrap() error {
return wrapped.ErrorCode
}

func (wrapped WrappedErrorCode[Wrapped]) Unwrapped() Wrapped {
return wrapped.ErrorCode
}

// Wrap is a convenience that calls errors.Wrap but still returns the ErrorCode interface
func Wrap(errCode ErrorCode, msg string) ErrorCode {
return wrappedErrorCode{
Err: errors.Wrap(errCode, msg),
// If a nil ErrorCode is given it will be returned as nil
func Wrap[EC ErrorCode](errCode EC, msg string) ErrorCodeWrap[EC] {
err := errors.Wrap(errCode, msg)
if err == nil {
return nil
}
return WrappedErrorCode[EC]{
Err: err,
ErrorCode: errCode,
}
}

// Wrapf is a convenience that calls errors.Wrapf but still returns the ErrorCode interface
func Wrapf(errCode ErrorCode, msg string, args ...interface{}) ErrorCode {
return wrappedErrorCode{
Err: errors.Wrapf(errCode, msg, args...),
// If a nil ErrorCode is given it will be returned as nil
func Wrapf[EC ErrorCode](errCode EC, msg string, args ...interface{}) ErrorCodeWrap[EC] {
err := errors.Wrapf(errCode, msg, args...)
if err == nil {
return nil
}
return WrappedErrorCode[EC]{
Err: err,
ErrorCode: errCode,
}
}

// Wraps is a convenience that calls errors.Wraps but still returns the ErrorCode interface
func Wraps(errCode ErrorCode, msg string, args ...interface{}) ErrorCode {
return wrappedErrorCode{
Err: errors.Wraps(errCode, msg, args...),
// If a nil ErrorCode is given it will be returned as nil
func Wraps[EC ErrorCode](errCode EC, msg string, args ...interface{}) ErrorCodeWrap[EC] {
err := errors.Wraps(errCode, msg, args...)
if err == nil {
return nil
}
return WrappedErrorCode[EC]{
Err: err,
ErrorCode: errCode,
}
}
Expand Down
74 changes: 47 additions & 27 deletions error_code_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func TestMinimalErrorCode(t *testing.T) {
minimal := MinimalError{}
AssertCodes(t, minimal)
ErrorEquals(t, minimal, "error")
ClientDataEquals(t, minimal, minimal)
ClientDataEqualsDef(t, minimal, minimal)
OpEquals(t, minimal, "")
UserMsgEquals(t, minimal, "")
}
Expand All @@ -80,7 +80,7 @@ func TestChildOnlyErrorCode(t *testing.T) {
coe := ChildOnlyError{}
AssertCodes(t, coe)
ErrorEquals(t, coe, "error")
ClientDataEquals(t, coe, coe)
ClientDataEqualsDef(t, coe, coe)
}

// Test a top-level error
Expand Down Expand Up @@ -168,51 +168,65 @@ func TestErrorWrapperCode(t *testing.T) {
wrapped := ErrorWrapper{Err: errors.New("error")}
AssertCodes(t, wrapped)
ErrorEquals(t, wrapped, "error")
ClientDataEquals(t, wrapped, errors.New("error"))
ClientDataEqualsDef(t, wrapped, errors.New("error"))
s2 := Struct2{A: "A", B: "B"}
wrappedS2 := ErrorWrapper{Err: s2}
AssertCodes(t, wrappedS2)
ErrorEquals(t, wrappedS2, "error A & B A & B")
ClientDataEquals(t, wrappedS2, s2)
ClientDataEqualsDef(t, wrappedS2, s2)
s1 := Struct1{A: "A"}
ClientDataEquals(t, ErrorWrapper{Err: s1}, s1)
ClientDataEqualsDef(t, ErrorWrapper{Err: s1}, s1)
sconst := StructConstError1{A: "A"}
ClientDataEquals(t, ErrorWrapper{Err: sconst}, sconst)
ClientDataEqualsDef(t, ErrorWrapper{Err: sconst}, sconst)
}

func TestErrorWrapperNil(t *testing.T) {
if errcode.Wrap[errcode.ErrorCode](nil, "wrapped") != nil {
t.Errorf("not nil")
}
if errcode.Wrapf[errcode.ErrorCode](nil, "wrapped") != nil {
t.Errorf("not nil")
}
if errcode.Wraps[errcode.ErrorCode](nil, "wrapped") != nil {
t.Errorf("not nil")
}
}

func TestErrorWrapperFunctions(t *testing.T) {
coded := errcode.NewBadRequestErr(errors.New("underlying"))
underlying := errors.New("underlying")
coded := errcode.NewBadRequestErr(underlying)
AssertCode(t, coded, errcode.InvalidInputCode.CodeStr())

{
wrap := errcode.Wrap(coded, "wrapped")
AssertCode(t, coded, wrap.Code().CodeStr())
AssertCode(t, wrap, errcode.InvalidInputCode.CodeStr())
if errMsg := wrap.Error(); errMsg != "wrapped: underlying" {
t.Errorf("Wrap unexpected: %s", errMsg)
}
if errors.Unwrap(wrap) == coded {
t.Error("bad unwrap")
if errors.Unwrap(wrap).Error() != underlying.Error() {
t.Errorf("bad unwrap %+v", errors.Unwrap(wrap))
}
}

{
wrapf := errcode.Wrapf(coded, "wrapped %s", "arg")
AssertCode(t, coded, wrapf.Code().CodeStr())
AssertCode(t, wrapf, errcode.InvalidInputCode.CodeStr())
if errMsg := wrapf.Error(); errMsg != "wrapped arg: underlying" {
t.Errorf("Wrap unexpected: %s", errMsg)
}
if errors.Unwrap(wrapf) == coded {
t.Error("bad unwrap")
if errors.Unwrap(wrapf).Error() != underlying.Error() {
t.Errorf("bad unwrap %+v", errors.Unwrap(wrapf))
}
}

{
wraps := errcode.Wraps(coded, "wrapped", "arg", 1)
AssertCode(t, coded, wraps.Code().CodeStr())
AssertCode(t, wraps, errcode.InvalidInputCode.CodeStr())
if errMsg := wraps.Error(); errMsg != "wrapped arg=1: underlying" {
t.Errorf("Wrap unexpected: %s", errMsg)
}
if errors.Unwrap(wraps) == coded {
t.Error("bad unwrap")
if errors.Unwrap(wraps).Error() != underlying.Error() {
t.Errorf("bad unwrap %+v", errors.Unwrap(wraps))
}
}
}
Expand Down Expand Up @@ -242,13 +256,13 @@ func TestNewInvalidInputErr(t *testing.T) {
AssertCode(t, err, internalCodeStr)
AssertHTTPCode(t, err, 500)
ErrorEquals(t, err, "error")
ClientDataEquals(t, err, MinimalError{}, internalCodeStr)
ClientDataEquals(t, err, MinimalError{}, internalCodeStr, MinimalError{})

wrappedInternalErr := errcode.NewInternalErr(internalErr)
AssertCode(t, err, internalCodeStr)
AssertHTTPCode(t, err, 500)
ErrorEquals(t, err, "error")
ClientDataEquals(t, wrappedInternalErr, MinimalError{}, internalCodeStr)
ClientDataEquals(t, wrappedInternalErr, MinimalError{}, internalCodeStr, MinimalError{})
// It should use the original stack trace, not the wrapped
AssertStackEquals(t, wrappedInternalErr, errcode.StackTrace(internalErr))

Expand Down Expand Up @@ -283,14 +297,14 @@ func TestNewInternalErr(t *testing.T) {
AssertCode(t, err, internalCodeStr)
AssertHTTPCode(t, err, 500)
ErrorEquals(t, err, "error")
ClientDataEquals(t, err, MinimalError{}, internalCodeStr)
ClientDataEquals(t, err, MinimalError{}, internalCodeStr, MinimalError{})

invalidErr := errcode.NewInvalidInputErr(MinimalError{})
err = errcode.NewInternalErr(invalidErr)
AssertCode(t, err, internalCodeStr)
AssertHTTPCode(t, err, 500)
ErrorEquals(t, err, "error")
ClientDataEquals(t, err, MinimalError{}, internalCodeStr)
ClientDataEquals(t, err, MinimalError{}, internalCodeStr, MinimalError{})
}

// Test Operation
Expand Down Expand Up @@ -328,7 +342,7 @@ func TestOpErrorCode(t *testing.T) {
AssertOperation(t, has, "has")
AssertCodes(t, has)
ErrorEquals(t, has, "error")
ClientDataEquals(t, has, has)
ClientDataEqualsDef(t, has, has)
OpEquals(t, has, "has")

OpEquals(t, OpErrorEmbed{}, "")
Expand Down Expand Up @@ -370,7 +384,7 @@ func TestUserMsg(t *testing.T) {
AssertUserMsg(t, ue, "user")
AssertCodes(t, ue)
ErrorEquals(t, ue, "error")
ClientDataEquals(t, ue, ue)
ClientDataEqualsDef(t, ue, ue)
UserMsgEquals(t, ue, "user")

UserMsgEquals(t, UserMsgErrorEmbed{}, "")
Expand Down Expand Up @@ -428,11 +442,12 @@ func ErrorEquals(t *testing.T, err error, msg string) {
}
}

func ClientDataEquals(t *testing.T, code errcode.ErrorCode, data interface{}, codeStrs ...errcode.CodeStr) {
codeStr := codeString
if len(codeStrs) > 0 {
codeStr = codeStrs[0]
}
func ClientDataEqualsDef(t *testing.T, code errcode.ErrorCode, data interface{}) {
t.Helper()
ClientDataEquals(t, code, data, codeString)
}

func ClientDataEquals(t *testing.T, code errcode.ErrorCode, data interface{}, codeStr errcode.CodeStr, otherCodes ...errcode.ErrorCode) {
t.Helper()

jsonEquals(t, "ClientData", data, errcode.ClientData(code))
Expand All @@ -441,11 +456,16 @@ func ClientDataEquals(t *testing.T, code errcode.ErrorCode, data interface{}, co
msg = code.Error()
}

others := make([]errcode.JSONFormat, len(otherCodes))
for i, err := range otherCodes {
others[i] = errcode.NewJSONFormat(err)
}
jsonExpected := errcode.JSONFormat{
Data: data,
Msg: msg,
Code: codeStr,
Operation: errcode.Operation(data),
Others: others,
}
newJSON := errcode.NewJSONFormat(code)
jsonEquals(t, "JSONFormat", jsonExpected, newJSON)
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@ module github.com/gregwebs/errcode

go 1.21.9

require github.com/gregwebs/errors v1.2.0
require github.com/gregwebs/errors v1.5.0
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
github.com/gregwebs/errors v1.2.0 h1:9QmMmbIPtgVNKEyinWD08z2brRKEf3CbWj+tRraNas0=
github.com/gregwebs/errors v1.2.0/go.mod h1:1NkCObP7+scylHlC69lwHl2ACOHwktWYrZV4EJDEl6g=
github.com/gregwebs/errors v1.5.0 h1:+vMiQwtPnVVr2RuVebjVQMnMZwUPIpeTU/iXgCOFBfE=
github.com/gregwebs/errors v1.5.0/go.mod h1:1NkCObP7+scylHlC69lwHl2ACOHwktWYrZV4EJDEl6g=
15 changes: 9 additions & 6 deletions group.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,16 @@ import (
// ErrorCodes return all errors (from an ErrorGroup) that are of interface ErrorCode.
// It first calls the Errors function.
func ErrorCodes(err error) []ErrorCode {
errors := errors.Errors(err)
errorCodes := make([]ErrorCode, 0, len(errors))
for _, errItem := range errors {
if errcode, ok := errItem.(ErrorCode); ok {
errorCodes = append(errorCodes, errcode)
errorCodes := make([]ErrorCode, 0)
errors.WalkDeep(err, func(err error) bool {
if errcode, ok := err.(ErrorCode); ok {
// avoid duplicating codes
if len(errorCodes) == 0 || errorCodes[len(errorCodes)-1].Code().codeStr != errcode.Code().codeStr {
errorCodes = append(errorCodes, errcode)
}
}
}
return false
})
return errorCodes
}

Expand Down
1 change: 1 addition & 0 deletions group_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ func AssertLength[Any any](t *testing.T, slice []Any, expected int) {
}

}

func TestErrorCodes(t *testing.T) {
codes := errcode.ErrorCodes(nil)
AssertLength(t, codes, 0)
Expand Down

0 comments on commit 2177897

Please sign in to comment.