Skip to content

Commit

Permalink
assert: improve unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
cornelk committed Dec 19, 2024
1 parent 894f1c8 commit f75e8c7
Show file tree
Hide file tree
Showing 2 changed files with 186 additions and 38 deletions.
54 changes: 30 additions & 24 deletions assert/assert.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,37 +5,43 @@ import (
"errors"
"fmt"
"reflect"
"testing"
)

// Testing is an interface that includes the methods used from *testing.T.
type Testing interface {
Helper()
Error(args ...any)
FailNow()
}

// Equal asserts that two objects are equal.
func Equal(t *testing.T, expected, actual any, errorMessage ...string) {
func Equal(t Testing, expected, actual any, msgAndArgs ...any) {
t.Helper()
if equal(expected, actual) {
return
}

msg := fmt.Sprintf("Not equal: \nexpected: %v\nactual : %v", expected, actual)
fail(t, msg, errorMessage...)
fail(t, msg, msgAndArgs...)
}

// NoError asserts that a function returned no error.
func NoError(t *testing.T, err error, errorMessage ...string) {
func NoError(t Testing, err error, msgAndArgs ...any) {
t.Helper()
if err == nil {
return
}

msg := fmt.Sprintf("Unexpected error:\n%+v", err)
fail(t, msg, errorMessage...)
fail(t, msg, msgAndArgs...)
}

// Error asserts that a function returned an error.
func Error(t *testing.T, err error, expectedError string, errorMessage ...string) {
func Error(t Testing, err error, expectedError string, msgAndArgs ...any) {
t.Helper()
if err == nil {
msg := fmt.Sprintf("Error message not equal: \nexpected: %v\nactual : nil", expectedError)
fail(t, msg, errorMessage...)
fail(t, msg, msgAndArgs...)
return
}

Expand All @@ -45,15 +51,15 @@ func Error(t *testing.T, err error, expectedError string, errorMessage ...string
}

msg := fmt.Sprintf("Error message not equal: \nexpected: %v\nactual : %v", expectedError, actual)
fail(t, msg, errorMessage...)
fail(t, msg, msgAndArgs...)
}

// ErrorIs asserts that a function returned an error that matches the specified error.
func ErrorIs(t *testing.T, err, expectedError error, errorMessage ...string) {
func ErrorIs(t Testing, err, expectedError error, msgAndArgs ...any) {
t.Helper()
if err == nil {
msg := fmt.Sprintf("Error not returned: \nexpected: %v\nactual : nil", expectedError)
fail(t, msg, errorMessage...)
fail(t, msg, msgAndArgs...)
return
}

Expand All @@ -62,59 +68,59 @@ func ErrorIs(t *testing.T, err, expectedError error, errorMessage ...string) {
}

msg := fmt.Sprintf("Error not equal: \nexpected: %v\nactual : %v", expectedError, err)
fail(t, msg, errorMessage...)
fail(t, msg, msgAndArgs...)
}

// True asserts that the specified value is true.
func True(t *testing.T, value bool, errorMessage ...string) {
func True(t Testing, value bool, msgAndArgs ...any) {
t.Helper()
if value {
return
}
fail(t, "Unexpected false", errorMessage...)
fail(t, "Unexpected false", msgAndArgs...)
}

// False asserts that the specified value is false.
func False(t *testing.T, value bool, errorMessage ...string) {
func False(t Testing, value bool, msgAndArgs ...any) {
t.Helper()
if !value {
return
}
fail(t, "Unexpected true", errorMessage...)
fail(t, "Unexpected true", msgAndArgs...)
}

// Len asserts that the specified object has the expected length.
func Len(t *testing.T, object any, expectedLen int, errorMessage ...string) {
func Len(t Testing, object any, expectedLen int, msgAndArgs ...any) {
t.Helper()
actualLen := reflect.ValueOf(object).Len()
if actualLen == expectedLen {
return
}

msg := fmt.Sprintf("Length not equal: \nexpected: %d\nactual : %d", expectedLen, actualLen)
fail(t, msg, errorMessage...)
fail(t, msg, msgAndArgs...)
}

// NotNil asserts that the specified object is not nil.
func NotNil(t *testing.T, object any, errorMessage ...string) {
func NotNil(t Testing, object any, msgAndArgs ...any) {
t.Helper()
if !isNil(object) {
return
}

msg := "Expected value to be not nil"
fail(t, msg, errorMessage...)
fail(t, msg, msgAndArgs...)
}

// Nil asserts that the specified object is nil.
func Nil(t *testing.T, object any, errorMessage ...string) {
func Nil(t Testing, object any, msgAndArgs ...any) {
t.Helper()
if isNil(object) {
return
}

msg := "Expected value to be nil"
fail(t, msg, errorMessage...)
fail(t, msg, msgAndArgs...)
}

func equal(expected, actual any) bool {
Expand Down Expand Up @@ -151,10 +157,10 @@ func isNil(value any) bool {
}
}

func fail(t *testing.T, message string, errorMessage ...string) {
func fail(t Testing, message string, msgAndArgs ...any) {
t.Helper()
if len(errorMessage) != 0 {
message = fmt.Sprintf("%s\n%s", message, errorMessage)
if len(msgAndArgs) > 0 {
message += "\n" + fmt.Sprintf(msgAndArgs[0].(string), msgAndArgs[1:]...)
}
t.Error(message)
t.FailNow()
Expand Down
170 changes: 156 additions & 14 deletions assert/assert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,45 +7,187 @@ import (
)

func TestEqual(t *testing.T) {
Equal(t, 1, 1)
tst := &errorCapture{}
Equal(tst, 1, 1)
if tst.failed {
t.Error("Equal failed")
}

tst = &errorCapture{}
Equal(tst, 1, 2)
if !tst.failed {
t.Error("Equal failed")
}
}

func TestNoError(t *testing.T) {
NoError(t, nil)
tst := &errorCapture{}
NoError(tst, nil)
if tst.failed {
t.Error("NoError failed")
}

tst = &errorCapture{}
NoError(tst, errors.New("error"))
if !tst.failed {
t.Error("NoError failed")
}
}

func TestError(t *testing.T) {
err := errors.New("error text")
Error(t, err, err.Error())
tst := &errorCapture{}
Error(tst, errors.New("error"), "error")
if tst.failed {
t.Error("Error failed")
}

tst = &errorCapture{}
Error(tst, nil, "error")
if !tst.failed {
t.Error("Error failed")
}

tst = &errorCapture{}
Error(tst, errors.New("error"), "other")
if !tst.failed {
t.Error("Error failed")
}
}

func TestErrorIs(t *testing.T) {
errTest := errors.New("error")
err := fmt.Errorf("error: %w", errTest)
ErrorIs(t, err, errTest)
tst := &errorCapture{}
ErrorIs(tst, errors.New("error"), errors.New("error"))
if !tst.failed {
t.Error("ErrorIs failed")
}

tst = &errorCapture{}
ErrorIs(tst, errors.New("error"), errors.New("other"))
if !tst.failed {
t.Error("ErrorIs failed")
}

tst = &errorCapture{}
ErrorIs(tst, nil, errors.New("error"))
if !tst.failed {
t.Error("ErrorIs failed")
}

tst = &errorCapture{}
err := errors.New("error")
ErrorIs(tst, fmt.Errorf("wrapped: %w", err), err)
if tst.failed {
t.Error("ErrorIs failed")
}
}

func TestTrue(t *testing.T) {
True(t, true)
tst := &errorCapture{}
True(tst, true)
if tst.failed {
t.Error("True failed")
}

tst = &errorCapture{}
True(tst, false)
if !tst.failed {
t.Error("True failed")
}
}

func TestFalse(t *testing.T) {
False(t, false)
tst := &errorCapture{}
False(tst, false)
if tst.failed {
t.Error("False failed")
}

tst = &errorCapture{}
False(tst, true)
if !tst.failed {
t.Error("False failed")
}
}

func TestInterfaceNilEqual(t *testing.T) {
var values []int
Equal(t, nil, values)
tst := &errorCapture{}
Equal(tst, nil, nil)
if tst.failed {
t.Error("InterfaceNilEqual failed")
}

tst = &errorCapture{}
Equal(tst, nil, 1)
if !tst.failed {
t.Error("InterfaceNilEqual failed")
}
}

func TestLen(t *testing.T) {
Len(t, []int{1, 2, 3}, 3)
tst := &errorCapture{}
Len(tst, []int{1, 2}, 2)
if tst.failed {
t.Error("Len failed")
}

tst = &errorCapture{}
Len(tst, []int{}, 2)
if !tst.failed {
t.Error("Len failed")
}
}

func TestNotNil(t *testing.T) {
NotNil(t, "not nil")
tst := &errorCapture{}
NotNil(tst, 1)
if tst.failed {
t.Error("NotNil failed")
}

tst = &errorCapture{}
NotNil(tst, nil)
if !tst.failed {
t.Error("NotNil failed")
}
}

func TestNil(t *testing.T) {
Nil(t, nil)
tst := &errorCapture{}
Nil(tst, nil)
if tst.failed {
t.Error("Nil failed")
}

tst = &errorCapture{}
Nil(tst, 1)
if !tst.failed {
t.Error("Nil failed")
}
}

func TestFail(t *testing.T) {
tst := &errorCapture{}
fail(tst, "error", "msg %d", 1)
if !tst.failed {
t.Error("Fail failed")
}
if tst.errs[0].(string) != "error\nmsg 1" {
t.Error("Fail failed")
}
}

type errorCapture struct {
errs []any
failed bool
}

func (e *errorCapture) Helper() {
}

func (e *errorCapture) Error(args ...any) {
e.errs = append([]any{}, args...)
}

func (e *errorCapture) FailNow() {
e.failed = true
}

0 comments on commit f75e8c7

Please sign in to comment.