Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions errors/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ func New(args ...interface{}) *Error {
case error:
e.Err = arg
e.propagateContexts()
case context.Context:
_ = e.WithContext(arg.(context.Context))
case string:
e.msg = arg
}
Expand Down
347 changes: 347 additions & 0 deletions errors/error_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,48 @@ const (

type testContextKey string

// Custom context type that implements context.Context interface
type customContext struct {
context.Context
customField string
}

// requestContext is a more realistic custom context that might be used in a web application
type requestContext struct {
context.Context
requestID string
userID string
traceID string
}

func newRequestContext(parent context.Context, requestID, userID, traceID string) *requestContext {
// Set the basics and extras in the underlying context
ctx := common.SetBasics(parent, common.Basics{
"request_id": requestID,
"user_id": userID,
})
ctx = common.SetExtras(ctx, common.Extras{
"trace_id": traceID,
"service": "error-service",
})

return &requestContext{
Context: ctx,
requestID: requestID,
userID: userID,
traceID: traceID,
}
}

// Additional methods for the custom context
func (r *requestContext) GetRequestID() string {
return r.requestID
}

func (r *requestContext) GetUserID() string {
return r.userID
}

func TestNew(t *testing.T) {
childErr := New(Op(childOp), BadRequest, childErrMsg)

Expand Down Expand Up @@ -215,3 +257,308 @@ func Test_extras(t *testing.T) {
}
assert.Equal(wantExtras, err.extras())
}

// TestNew_WithContextParameter demonstrates how the New function handles context.Context parameters.
// This test shows that any type implementing context.Context interface will be properly handled,
// including custom context types that embed additional fields or methods.
func TestNew_WithContextParameter(t *testing.T) {
tests := []struct {
name string
setupFunc func() *Error
expectedBasics common.Basics
expectedExtras common.Extras
}{
{
name: "context passed directly to New function",
setupFunc: func() *Error {
ctx := context.Background()
ctx = common.SetBasics(ctx, common.Basics{
"request_id": "req-123",
"user_id": "user-456",
})
ctx = common.SetExtras(ctx, common.Extras{
"trace_id": "trace-789",
})

return New(Op("test.operation"), BadRequest, "test error", ctx)
},
expectedBasics: common.Basics{
"request_id": "req-123",
"user_id": "user-456",
},
expectedExtras: common.Extras{
"trace_id": "trace-789",
},
},
{
name: "context with child error - both via New parameter",
setupFunc: func() *Error {
childCtx := context.Background()
childCtx = common.SetBasics(childCtx, common.Basics{
"session_id": "session-def",
"user_id": "user-child",
})
childCtx = common.SetExtras(childCtx, common.Extras{
"operation": "child-op",
})

parentCtx := context.Background()
parentCtx = common.SetBasics(parentCtx, common.Basics{
"request_id": "req-123",
"user_id": "user-parent",
})
parentCtx = common.SetExtras(parentCtx, common.Extras{
"trace_id": "trace-789",
})

childErr := New(Op("child.operation"), BadRequest, "child error", childCtx)
return New(Op("parent.operation"), "parent error", childErr, parentCtx)
},
expectedBasics: common.Basics{
"request_id": "req-123",
"session_id": "session-def",
"user_id": "user-child", // Child context overrides parent (propagation happens first)
},
expectedExtras: common.Extras{
"trace_id": "trace-789",
"operation": "child-op",
},
},
{
name: "empty context passed to New",
setupFunc: func() *Error {
ctx := context.Background()
return New(Op("test.operation"), BadRequest, "test error", ctx)
},
expectedBasics: nil,
expectedExtras: nil,
},
{
name: "mixed context methods - New parameter and WithContext",
setupFunc: func() *Error {
// Context via New parameter
newCtx := context.Background()
newCtx = common.SetBasics(newCtx, common.Basics{
"from_new": "value1",
"override": "from_new",
})

// Context via WithContext (should not override existing)
withCtx := context.Background()
withCtx = common.SetBasics(withCtx, common.Basics{
"from_with": "value2",
"override": "from_with",
})
withCtx = common.SetExtras(withCtx, common.Extras{
"extra_key": "extra_value",
})

return New(Op("test.operation"), BadRequest, "test error", newCtx).
WithContext(withCtx)
},
expectedBasics: common.Basics{
"from_new": "value1",
"override": "from_new", // New parameter should take precedence
"from_with": "value2", // Additional keys from WithContext should be added
},
expectedExtras: common.Extras{
"extra_key": "extra_value",
},
},
{
name: "custom context type implementing context.Context interface",
setupFunc: func() *Error {
baseCtx := context.Background()
baseCtx = common.SetBasics(baseCtx, common.Basics{
"request_id": "req-123",
})
baseCtx = common.SetExtras(baseCtx, common.Extras{
"trace_id": "trace-789",
})

customCtx := customContext{
Context: baseCtx,
customField: "custom_value",
}

return New(Op("test.operation"), BadRequest, "test error", customCtx)
},
expectedBasics: common.Basics{
"request_id": "req-123", // Custom context types work with type switch
},
expectedExtras: common.Extras{
"trace_id": "trace-789", // Custom context types work with type switch
},
},
{
name: "realistic request context with nested errors",
setupFunc: func() *Error {
// Simulate a realistic web application scenario
reqCtx := newRequestContext(
context.Background(),
"req-web-456",
"user-789",
"trace-abc-def",
)

// Simulate a database error with its own context
dbCtx := common.SetBasics(context.Background(), common.Basics{
"database": "postgres",
"table": "users",
})
dbCtx = common.SetExtras(dbCtx, common.Extras{
"query_duration": "150ms",
})

dbErr := New(Op("database.query"), NotFound, "user not found", dbCtx)

// Service layer error that wraps the database error
serviceErr := New(Op("user.service.get"), "failed to fetch user", dbErr, reqCtx)

// Controller layer error - demonstrates how context propagates up
return New(Op("user.controller.get"), BadRequest, "invalid user request", serviceErr)
},
expectedBasics: common.Basics{
"request_id": "req-web-456", // From request context
"user_id": "user-789", // From request context
"database": "postgres", // From database context
"table": "users", // From database context
},
expectedExtras: common.Extras{
"trace_id": "trace-abc-def", // From request context
"service": "error-service", // From request context
"query_duration": "150ms", // From database context
},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert := assert.New(t)
err := tt.setupFunc()

assert.Equal(tt.expectedBasics, err.basics())
assert.Equal(tt.expectedExtras, err.extras())

if tt.expectedBasics != nil || tt.expectedExtras != nil {
assert.NotNil(err.ctx)
}
})
}
}

func TestError_ContextHandling(t *testing.T) {
tests := []struct {
name string
setupFunc func() *Error
expectedBasics common.Basics
expectedExtras common.Extras
expectContextNil bool
expectChildCtxNil bool
}{
{
name: "context with multiple layers and overrides",
setupFunc: func() *Error {
ctx := context.Background()
ctx = common.SetBasics(ctx, common.Basics{
"request_id": "req-123",
"user_id": "user-456",
})
ctx = common.SetExtras(ctx, common.Extras{
"trace_id": "trace-789",
"span_id": "span-abc",
})

childCtx := common.SetBasics(ctx, common.Basics{
"user_id": "user-override", // This should override parent
"session_id": "session-def", // This should be added
})
childCtx = common.SetExtras(childCtx, common.Extras{
"span_id": "span-override", // This should override parent
"operation": "child-op", // This should be added
})

childErr := New(Op("child.operation"), BadRequest, "child error").
WithContext(childCtx)

return New(Op("parent.operation"), "parent error", childErr).
WithContext(ctx)
},
expectedBasics: common.Basics{
"request_id": "req-123", // From parent, unchanged
"user_id": "user-override", // From child, overridden
"session_id": "session-def", // From child, new
},
expectedExtras: common.Extras{
"trace_id": "trace-789", // From parent, unchanged
"span_id": "span-override", // From child, overridden
"operation": "child-op", // From child, new
},
expectContextNil: false,
expectChildCtxNil: true,
},
{
name: "context without child error",
setupFunc: func() *Error {
ctx := context.Background()
ctx = common.SetBasics(ctx, common.Basics{"key": "value"})
ctx = common.SetExtras(ctx, common.Extras{"extra": "data"})

return New(Op("test.operation"), BadRequest, "test error").
WithContext(ctx)
},
expectedBasics: common.Basics{"key": "value"},
expectedExtras: common.Extras{"extra": "data"},
expectContextNil: false,
},
{
name: "empty context",
setupFunc: func() *Error {
ctx := context.Background()
return New(Op("test.operation"), BadRequest, "test error").
WithContext(ctx)
},
expectedBasics: nil,
expectedExtras: nil,
expectContextNil: false, // WithContext always creates ctx map
},
{
name: "no context set",
setupFunc: func() *Error {
return New(Op("test.operation"), BadRequest, "test error")
},
expectedBasics: nil,
expectedExtras: nil,
expectContextNil: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert := assert.New(t)
err := tt.setupFunc()

assert.Equal(tt.expectedBasics, err.basics())
assert.Equal(tt.expectedExtras, err.extras())

if tt.expectContextNil {
assert.Nil(err.ctx)
} else {
assert.NotNil(err.ctx)
if tt.expectedBasics != nil {
assert.Contains(err.ctx, "basics")
}
if tt.expectedExtras != nil {
assert.Contains(err.ctx, "extras")
}
}

if tt.expectChildCtxNil {
childErrTyped, ok := err.Err.(*Error)
if ok {
assert.Nil(childErrTyped.ctx, "child error context should be cleared after propagation")
}
}
})
}
}
Loading