Skip to content

Commit

Permalink
ErrorWrap interface and re-organize internals
Browse files Browse the repository at this point in the history
The ErrorWrap interface is useful for formatting.
Use this to reuse formatting code.
And to forward formatting for ErrorWrap
  • Loading branch information
Greg Weber authored and gregwebs committed Oct 31, 2024
1 parent d2ee623 commit 96e5f58
Show file tree
Hide file tree
Showing 5 changed files with 138 additions and 99 deletions.
201 changes: 110 additions & 91 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,28 +87,48 @@ import (
// New returns an error with the supplied message.
// New also records the stack trace at the point it was called.
func New(message string) error {
return &fundamental{withStack{
stderrors.New(message),
callers(),
}}
return &fundamental{stderrors.New(message), callers()}
}

// Errorf formats according to a format specifier and returns the string
// as a value that satisfies error.
// Errorf also records the stack trace at the point it was called.
func Errorf(format string, args ...interface{}) error {
err := fmt.Errorf(format, args...)
stacked := withStack{
error: err,
stack: callers(),
}
// if %w was successfully used then this is not a fundamental error
if _, ok := err.(unwrapper); ok {
return &addStack{stacked}
return &addStack{withStack{err, callers()}}
} else if _, ok := err.(unwraps); ok {
return &addStack{stacked}
return &addStack{withStack{err, callers()}}
}
return &fundamental{err, callers()}
}

// fundamental is a base error that doesn't wrap other errors
// It stores an error rather than just a string. This allows for:
// * reuse of existing patterns
// * usage of Errorf to support any formatting
// The latter is done in part to support %w, but note that if %w is used we don't use fundamental
type fundamental struct {
error
*stack
}

func (f *fundamental) StackTrace() StackTrace { return f.stack.StackTrace() }
func (f *fundamental) HasStack() bool { return true }
func (f *fundamental) Format(s fmt.State, verb rune) {
switch verb {
case 'v':
if s.Flag('+') {
writeString(s, f.Error())
f.StackTrace().Format(s, verb)
return
}
fallthrough
case 's':
writeString(s, f.Error())
case 'q':
fmt.Fprintf(s, "%q", f.Error())
}
return &fundamental{stacked}
}

// StackTraceAware is an optimization to avoid repetitive traversals of an error chain.
Expand All @@ -126,17 +146,6 @@ func HasStack(err error) bool {
return GetStackTracer(err) != nil
}

// fundamental is a base error that doesn't wrap other errors
// originally it stored just a string, but switching to storing an error allows for
// * simple re-use of withStack
// * usage of Errorf to support any formatting
// The latter is done to support %w, but if %w is used we don't use fundamental
type fundamental struct {
withStack
}

func (f *fundamental) ErrorNoUnwrap() string { return f.Error() }

// AddStack annotates err with a stack trace at the point WithStack was called.
// It will first check with HasStack to see if a stack trace already exists before creating another one.
func AddStack(err error) error {
Expand All @@ -146,7 +155,6 @@ func AddStack(err error) error {
if HasStack(err) {
return err
}

return &addStack{withStack{err, callers()}}
}

Expand All @@ -161,42 +169,37 @@ func AddStackSkip(err error, skip int) error {
return &addStack{withStack{err, callersSkip(skip + 3)}}
}

// GetStackTracer will return the first StackTracer in the causer chain.
// This function is used by AddStack to avoid creating redundant stack traces.
//
// You can also use the StackTracer interface on the returned error to get the stack trace.
func GetStackTracer(origErr error) StackTracer {
var stacked StackTracer
WalkDeep(origErr, func(err error) bool {
if stackTracer, ok := err.(StackTracer); ok {
stacked = stackTracer
return true
}
return false
})
return stacked
}

type withStack struct {
error
*stack
}

func (w *withStack) Unwrap() error { return Unwrap(w.error) }

func (w *withStack) StackTrace() StackTrace { return w.stack.StackTrace() }
func (w *withStack) Unwrap() error { return w.error }
func (w *withStack) ErrorNoUnwrap() string { return "" }
func (w *withStack) HasStack() bool { return true }
func (w *withStack) Format(s fmt.State, verb rune) {
formatError(w, s, verb)
}

func formatError(err ErrorUnwrap, s fmt.State, verb rune) {
switch verb {
case 'v':
if s.Flag('+') {
formatterPlusV(s, verb, w.error)
w.stack.Format(s, verb)
formatterPlusV(s, verb, err.Unwrap())
if msg := err.ErrorNoUnwrap(); msg != "" {
writeString(s, "\n"+msg)
}
if stackTracer, ok := err.(StackTracer); ok {
stackTracer.StackTrace().Format(s, verb)
}
return
}
fallthrough
case 's':
writeString(s, w.Error())
writeString(s, err.Error())
case 'q':
fmt.Fprintf(s, "%q", w.Error())
fmt.Fprintf(s, "%q", err.Error())
}
}

Expand All @@ -206,7 +209,10 @@ type addStack struct {
withStack
}

func (w *addStack) Unwrap() error { return w.error }
func (a *addStack) Unwrap() error { return a.error }
func (a *addStack) Format(s fmt.State, verb rune) {
formatError(a, s, verb)
}

// Wrap returns an error annotating err with a stack trace
// at the point Wrap is called, and the supplied message.
Expand All @@ -215,13 +221,9 @@ func Wrap(err error, message string) error {
if err == nil {
return nil
}
return &withStack{
&withMessage{
cause: err,
msg: message,
causeHasStack: HasStack(err),
},
callers(),
return &withMessage{
msg: message,
withStack: withStack{err, callers()},
}
}

Expand All @@ -232,13 +234,9 @@ func Wrapf(err error, format string, args ...interface{}) error {
if err == nil {
return nil
}
return &withStack{
&withMessage{
cause: err,
msg: fmt.Sprintf(format, args...),
causeHasStack: HasStack(err),
},
callers(),
return &withMessage{
msg: fmt.Sprintf(format, args...),
withStack: withStack{err, callers()},
}
}

Expand All @@ -249,23 +247,34 @@ func WithMessage(err error, message string) error {
if err == nil {
return nil
}
return &withMessage{
cause: err,
msg: message,
causeHasStack: HasStack(err),
return &withMessageNoStack{
msg: message,
error: err,
}
}

type withMessage struct {
cause error
msg string
causeHasStack bool
msg string
withStack
}

func (w *withMessage) Error() string { return w.msg + ": " + w.cause.Error() }
func (w *withMessage) Unwrap() error { return w.cause }
func (w *withMessage) Error() string { return w.msg + ": " + w.error.Error() }
func (w *withMessage) ErrorNoUnwrap() string { return w.msg }
func (w *withMessage) HasStack() bool { return w.causeHasStack }
func (w *withMessage) Format(s fmt.State, verb rune) {
formatError(w, s, verb)
}

type withMessageNoStack struct {
msg string
error
}

func (w *withMessageNoStack) Error() string { return w.msg + ": " + w.error.Error() }
func (w *withMessageNoStack) Unwrap() error { return w.error }
func (w *withMessageNoStack) ErrorNoUnwrap() string { return w.msg }
func (w *withMessageNoStack) Format(s fmt.State, verb rune) {
formatError(w, s, verb)
}

func formatterPlusV(s fmt.State, verb rune, err error) {
if f, ok := err.(fmt.Formatter); ok {
Expand All @@ -275,20 +284,6 @@ func formatterPlusV(s fmt.State, verb rune, err error) {
}
}

func (w *withMessage) Format(s fmt.State, verb rune) {
switch verb {
case 'v':
if s.Flag('+') {
formatterPlusV(s, verb, w.Unwrap())
writeString(s, "\n"+w.msg)
return
}
fallthrough
case 's', 'q':
writeString(s, w.Error())
}
}

// Cause returns the underlying cause of the error, if possible.
// Unwrap goes just one level deep, but Cause goes all the way to the bottom
// If nil is given, it will return nil
Expand Down Expand Up @@ -358,10 +353,10 @@ func writeString(w io.Writer, s string) {
}
}

// ErrorNoUnwrap is designed to give just the message of the individual error without any unwrapping.
// ErrorUnwrap allows wrapped errors to give just the message of the individual error without any unwrapping.
//
// The existing Error() string interface loses all structure of error data.
// This extends to all errors that it is wrapping, which will get included in the output of Error()
// The existing Error() convention extends that output to all errors that are wrapped.
// ErrorNoUnwrap() has just the wrapping message without additional unwrapped messages.
//
// Existing Error() definitions look like this:
//
Expand All @@ -370,9 +365,12 @@ func writeString(w io.Writer, s string) {
// An ErrorNoUnwrap() definitions look like this:
//
// func (hasWrapped) ErrorNoUnwrap() string { return hasWrapped.message }
//
// This only needs to be defined if an error has an Unwrap method
type ErrorNotUnwrapped interface {
type ErrorUnwrap interface {
error
Unwrap() error
// ErrorNoUnwrap is the error message component of the wrapping
// It will be a prefix of Error()
// If there is no message in the wrapping then this can return an empty string
ErrorNoUnwrap() string
}

Expand Down Expand Up @@ -418,6 +416,27 @@ func (ew *ErrorWrap) WrapError(wrap func(error) error) {
ew.error = wrap(ew.error)
}

func (ew *ErrorWrap) HasStack() bool {
return HasStack(ew.error)
}

func (ew *ErrorWrap) Format(s fmt.State, verb rune) {
forwardFormatting(ew.error, s, verb)
}

// Forward to a Formatter if it exists
func forwardFormatting(err error, s fmt.State, verb rune) {
if formatter, ok := err.(fmt.Formatter); ok {
formatter.Format(s, verb)
} else if errUnwrap, ok := err.(ErrorUnwrap); ok {
formatError(errUnwrap, s, verb)
} else {
fmtString := fmt.FormatString(s, verb)
// unwrap before calling forwamrdFormatting to avoid infinite recursion
fmt.Fprintf(s, fmtString, err)
}
}

var _ ErrorWrapper = (*ErrorWrap)(nil) // assert implements interface

// WrapFn returns a wrapping function that calls Wrap
Expand Down
7 changes: 4 additions & 3 deletions errors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ func TestWrapf(t *testing.T) {
}

for _, tt := range tests {
got := Wrapf(tt.err, tt.message).Error()
got := Wrap(tt.err, tt.message).Error()
if got != tt.want {
t.Errorf("Wrapf(%v, %q): got: %v, want %v", tt.err, tt.message, got, tt.want)
}
Expand Down Expand Up @@ -411,8 +411,9 @@ func TestFormatWrapped(t *testing.T) {
t.Errorf("Unexpected wrapping format: %+v", wrapped)
}
unwrapped := Unwrap(wrapped)
if fmt.Sprintf("%v", unwrapped) != "underlying" {
t.Errorf("Unexpected unwrapping format: %v", wrapped)
got := fmt.Sprintf("%v", unwrapped)
if got != "underlying" {
t.Errorf("Unexpected unwrapping format, got: %s, wrapped: %v", got, wrapped)
}
if !strings.HasPrefix(fmt.Sprintf("%+v", unwrapped), "underlying") {
t.Errorf("Unexpected unwrapping format: %+v", wrapped)
Expand Down
16 changes: 16 additions & 0 deletions stack.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,22 @@ type StackTracer interface {
StackTrace() StackTrace
}

// GetStackTracer will return the first StackTracer in the causer chain.
// This function is used by AddStack to avoid creating redundant stack traces.
//
// You can also use the StackTracer interface on the returned error to get the stack trace.
func GetStackTracer(origErr error) StackTracer {
var stacked StackTracer
WalkDeep(origErr, func(err error) bool {
if stackTracer, ok := err.(StackTracer); ok {
stacked = stackTracer
return true
}
return false
})
return stacked
}

// Frame represents a program counter inside a stack frame.
type Frame uintptr

Expand Down
11 changes: 7 additions & 4 deletions stack_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,9 @@ func TestStackTrace(t *testing.T) {
}}
for i, tt := range tests {
ste := GetStackTracer(tt.err)
if ste == nil {
t.Fatalf("expected a stack trace from test %d error: %v", i+1, tt.err)
}
st := ste.StackTrace()
for j, want := range tt.want {
testFormatRegexp(t, i, st[j], "%+v", want)
Expand Down Expand Up @@ -247,19 +250,19 @@ func TestStackTraceFormat(t *testing.T) {
}, {
stackTrace()[:2],
"%v",
`[stack_test.go:201 stack_test.go:248]`,
`[stack_test.go:204 stack_test.go:251]`,
}, {
stackTrace()[:2],
"%+v",
"\n" +
"github.com/gregwebs/errors.stackTrace\n" +
"\tgithub.com/gregwebs/errors/stack_test.go:201\n" +
"\tgithub.com/gregwebs/errors/stack_test.go:204\n" +
"github.com/gregwebs/errors.TestStackTraceFormat\n" +
"\tgithub.com/gregwebs/errors/stack_test.go:252",
"\tgithub.com/gregwebs/errors/stack_test.go:255",
}, {
stackTrace()[:2],
"%#v",
`[]errors.Frame{stack_test.go:201, stack_test.go:260}`,
`[]errors.Frame{stack_test.go:204, stack_test.go:263}`,
}}

for i, tt := range tests {
Expand Down
Loading

0 comments on commit 96e5f58

Please sign in to comment.