Skip to content

Commit

Permalink
GetClientData: don't default to self
Browse files Browse the repository at this point in the history
opt-in will help prevent leaking data
  • Loading branch information
Greg Weber committed Jun 25, 2024
1 parent 104ea06 commit 4c60bb3
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 24 deletions.
5 changes: 2 additions & 3 deletions error_code.go
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,6 @@ type HasClientData interface {

// ClientData retrieves data from a structure that implements HasClientData
// It will unwrap errors to look for HasClientData
// If HasClientData is not defined it uses the given ErrorCode object.
// Normally this function is used rather than GetClientData.
func ClientData(errCode ErrorCode) interface{} {
if hasData, ok := errCode.(HasClientData); ok {
Expand All @@ -263,7 +262,7 @@ func ClientData(errCode ErrorCode) interface{} {
break
}
}
return errCode
return nil
}

// JSONFormat serializes an ErrorCode to a particular JSON format.
Expand Down Expand Up @@ -293,7 +292,7 @@ type JSONFormat struct {
func OperationClientData(errCode ErrorCode) (string, interface{}) {
op := Operation(errCode)
data := ClientData(errCode)
if op == "" {
if op == "" && data != nil {
op = Operation(data)
}
return op, data
Expand Down
67 changes: 46 additions & 21 deletions error_code_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func TestHttpErrorCode(t *testing.T) {
http := HTTPError{}
AssertHTTPCode(t, http, 900)
ErrorEquals(t, http, "error")
ClientDataEquals(t, http, http, httpCodeStr)
ClientDataEquals(t, http, nil, httpCodeStr)
}

// Test a very simple error
Expand All @@ -60,7 +60,7 @@ func TestMinimalErrorCode(t *testing.T) {
minimal := MinimalError{}
AssertCodes(t, minimal)
ErrorEquals(t, minimal, "error")
ClientDataEqualsDef(t, minimal, minimal)
ClientDataEqualsDef(t, minimal, nil)
OpEquals(t, minimal, "")
UserMsgEquals(t, minimal, "")
}
Expand All @@ -80,7 +80,7 @@ func TestChildOnlyErrorCode(t *testing.T) {
coe := ChildOnlyError{}
AssertCodes(t, coe)
ErrorEquals(t, coe, "error")
ClientDataEqualsDef(t, coe, coe)
ClientDataEqualsDef(t, coe, nil)
}

// Test a top-level error
Expand All @@ -100,7 +100,7 @@ func TestTopErrorCode(t *testing.T) {
top := TopError{}
AssertCodes(t, top, topCodeStr)
ErrorEquals(t, top, "error")
ClientDataEquals(t, top, top, topCodeStr)
ClientDataEquals(t, top, nil, topCodeStr)
}

// Test a deep hierarchy
Expand All @@ -122,7 +122,7 @@ func TestDeepErrorCode(t *testing.T) {
AssertHTTPCode(t, deep, 800)
AssertCode(t, deep, deepCodeStr)
ErrorEquals(t, deep, "error")
ClientDataEquals(t, deep, deep, deepCodeStr)
ClientDataEquals(t, deep, nil, deepCodeStr)
}

// Test an ErrorWrapper that has different error types placed into it
Expand Down Expand Up @@ -249,37 +249,38 @@ func (ic InternalChild) Error() string { return "internal child error" }
func (ic InternalChild) Code() errcode.Code { return internalChild }

func TestNewInvalidInputErr(t *testing.T) {
err := errcode.NewInvalidInputErr(errors.New("new error"))
var err errcode.ErrorCode
err = errcode.NewInvalidInputErr(errors.New("new error"))
AssertCodes(t, err, "input")
ErrorEquals(t, err, "new error")
ClientDataEquals(t, err, errors.New("new error"), "input")

err = errcode.NewInvalidInputErr(MinimalError{})
AssertCodes(t, err, "input.testcode")
ErrorEquals(t, err, "error")
ClientDataEquals(t, err, MinimalError{}, errcode.CodeStr("input.testcode"))
ClientDataEquals(t, err, nil, errcode.CodeStr("input.testcode"))

internalErr := errcode.NewInternalErr(MinimalError{})
err = errcode.NewInvalidInputErr(internalErr)
internalCodeStr := errcode.CodeStr("internal")
AssertCode(t, err, internalCodeStr)
AssertHTTPCode(t, err, 500)
ErrorEquals(t, err, "error")
ClientDataEquals(t, err, MinimalError{}, internalCodeStr, MinimalError{})
ClientDataEquals(t, err, nil, internalCodeStr, MinimalError{})

wrappedInternalErr := errcode.NewInternalErr(internalErr)
AssertCode(t, err, internalCodeStr)
AssertHTTPCode(t, err, 500)
ErrorEquals(t, err, "error")
ClientDataEquals(t, wrappedInternalErr, MinimalError{}, internalCodeStr, MinimalError{})
ClientDataEquals(t, wrappedInternalErr, nil, internalCodeStr, MinimalError{})
// It should use the original stack trace, not the wrapped
AssertStackEquals(t, wrappedInternalErr, errcode.StackTrace(internalErr))

err = errcode.NewInvalidInputErr(InternalChild{})
AssertCode(t, err, internalChildCodeStr)
AssertHTTPCode(t, err, 503)
ErrorEquals(t, err, "internal child error")
ClientDataEquals(t, err, InternalChild{}, internalChildCodeStr)
ClientDataEquals(t, err, nil, internalChildCodeStr)
}

func TestStackTrace(t *testing.T) {
Expand All @@ -306,14 +307,14 @@ func TestNewInternalErr(t *testing.T) {
AssertCode(t, err, internalCodeStr)
AssertHTTPCode(t, err, 500)
ErrorEquals(t, err, "error")
ClientDataEquals(t, err, MinimalError{}, internalCodeStr, MinimalError{})
ClientDataEquals(t, err, nil, internalCodeStr, MinimalError{})

invalidErr := errcode.NewInvalidInputErr(MinimalError{})
err = errcode.NewInternalErr(invalidErr)
AssertCode(t, err, internalCodeStr)
AssertHTTPCode(t, err, 500)
ErrorEquals(t, err, "error")
ClientDataEquals(t, err, MinimalError{}, internalCodeStr, MinimalError{})
ClientDataEquals(t, err, nil, internalCodeStr, MinimalError{})
}

// Test Operation
Expand Down Expand Up @@ -351,7 +352,11 @@ func TestOpErrorCode(t *testing.T) {
AssertOperation(t, has, "has")
AssertCodes(t, has)
ErrorEquals(t, has, "error")
ClientDataEqualsDef(t, has, has)
ClientDataResult(t, has, clientDataResult{
data: nil,
operation: has.GetOperation(),
codeStr: has.Code().CodeStr(),
})
OpEquals(t, has, "has")

OpEquals(t, OpErrorEmbed{}, "")
Expand Down Expand Up @@ -397,7 +402,7 @@ func TestUserMsg(t *testing.T) {
AssertUserMsg(t, ue, "user")
AssertCodes(t, ue)
ErrorEquals(t, ue, "error")
ClientDataEqualsDef(t, ue, ue)
ClientDataEqualsDef(t, ue, nil)
UserMsgEquals(t, ue, "user")

UserMsgEquals(t, UserMsgErrorEmbed{}, "")
Expand Down Expand Up @@ -464,30 +469,50 @@ func ClientDataEqualsDef(t *testing.T, code errcode.ErrorCode, data interface{})
ClientDataEquals(t, code, data, codeString)
}

func ClientDataEquals(t *testing.T, code errcode.ErrorCode, data interface{}, codeStr errcode.CodeStr, otherCodes ...errcode.ErrorCode) {
type clientDataResult struct {
data interface{}
operation string
codeStr errcode.CodeStr
otherCodes []errcode.ErrorCode
}

func ClientDataResult(t *testing.T, code errcode.ErrorCode, result clientDataResult) {
t.Helper()

jsonEquals(t, "ClientData", data, errcode.ClientData(code))
jsonEquals(t, "ClientData", result.data, errcode.ClientData(code))
msg := errcode.GetUserMsg(code)
if msg == "" {
msg = code.Error()
}

others := make([]errcode.JSONFormat, len(otherCodes))
for i, err := range otherCodes {
others := make([]errcode.JSONFormat, len(result.otherCodes))
for i, err := range result.otherCodes {
others[i] = errcode.NewJSONFormat(err)
}
op := result.operation
if op == "" {
op = errcode.Operation(result.data)
}
jsonExpected := errcode.JSONFormat{
Data: data,
Data: result.data,
Msg: msg,
Code: codeStr,
Operation: errcode.Operation(data),
Code: result.codeStr,
Operation: op,
Others: others,
}
newJSON := errcode.NewJSONFormat(code)
jsonEquals(t, "JSONFormat", jsonExpected, newJSON)
}

func ClientDataEquals(t *testing.T, code errcode.ErrorCode, data interface{}, codeStr errcode.CodeStr, otherCodes ...errcode.ErrorCode) {
t.Helper()
ClientDataResult(t, code, clientDataResult{
data: data,
codeStr: codeStr,
otherCodes: otherCodes,
})
}

func jsonEquals(t *testing.T, errPrefix string, expectedIn interface{}, gotIn interface{}) {
t.Helper()
got, err1 := json.Marshal(gotIn)
Expand Down

0 comments on commit 4c60bb3

Please sign in to comment.