Skip to content

Commit 8ddef3a

Browse files
committed
Add ErrorIs and ErrorAs assertions
1 parent 87709f9 commit 8ddef3a

File tree

2 files changed

+289
-0
lines changed

2 files changed

+289
-0
lines changed

be/errors.go

+103
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
package be
22

33
import (
4+
"errors"
45
"fmt"
6+
"reflect"
57
"strings"
68

79
"github.com/rliebz/ghost"
@@ -121,3 +123,104 @@ want: %v`,
121123
),
122124
}
123125
}
126+
127+
// ErrorIs asserts that an error matches another using [errors.Is].
128+
func ErrorIs(err error, target error) ghost.Result {
129+
args := ghostlib.ArgsFromAST(err, target)
130+
argErr, argTarget := args[0], args[1]
131+
132+
if errors.Is(err, target) {
133+
return ghost.Result{
134+
Ok: true,
135+
Message: fmt.Sprintf(`error %v is target %v
136+
error: %v
137+
target: %v`,
138+
argErr,
139+
argTarget,
140+
err,
141+
target,
142+
),
143+
}
144+
}
145+
146+
return ghost.Result{
147+
Ok: false,
148+
Message: fmt.Sprintf(`error %v is not target %v
149+
error: %v
150+
target: %v`,
151+
argErr,
152+
argTarget,
153+
err,
154+
target,
155+
),
156+
}
157+
}
158+
159+
var errorType = reflect.TypeOf((*error)(nil)).Elem()
160+
161+
// ErrorAs asserts that an error matches another using [errors.As].
162+
func ErrorAs(err error, target any) ghost.Result {
163+
args := ghostlib.ArgsFromAST(err, target)
164+
argErr, argTarget := args[0], args[1]
165+
166+
if err == nil {
167+
return ghost.Result{
168+
Ok: false,
169+
Message: fmt.Sprintf("error %v was nil", argErr),
170+
}
171+
}
172+
173+
// These next few checks are for invalid usage, where errors.As will panic if
174+
// a caller hits any of them. As an assertion library, it's probably more
175+
// polite to never panic.
176+
177+
if target == nil {
178+
return ghost.Result{
179+
Ok: false,
180+
Message: fmt.Sprintf("target %v cannot be nil", argTarget),
181+
}
182+
}
183+
184+
val := reflect.ValueOf(target)
185+
typ := val.Type()
186+
if typ.Kind() != reflect.Ptr || val.IsNil() {
187+
return ghost.Result{
188+
Ok: false,
189+
Message: fmt.Sprintf("target %v must be a non-nil pointer", argTarget),
190+
}
191+
}
192+
targetType := typ.Elem()
193+
194+
if targetType.Kind() != reflect.Interface && !targetType.Implements(errorType) {
195+
return ghost.Result{
196+
Ok: false,
197+
Message: fmt.Sprintf("*target %v must be interface or implement error", argTarget),
198+
}
199+
}
200+
201+
if errors.As(err, target) {
202+
return ghost.Result{
203+
Ok: true,
204+
Message: fmt.Sprintf(`error %v set as target %v
205+
error: %v
206+
target: %v`,
207+
argErr,
208+
argTarget,
209+
err,
210+
targetType,
211+
),
212+
}
213+
}
214+
215+
return ghost.Result{
216+
Ok: false,
217+
Message: fmt.Sprintf(`error %v cannot be set as target %v
218+
error: %v
219+
target: %v`,
220+
argErr,
221+
argTarget,
222+
err,
223+
targetType,
224+
),
225+
}
226+
}

be/errors_test.go

+186
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@ package be_test
22

33
import (
44
"errors"
5+
"fmt"
6+
"io/fs"
7+
"os"
58
"testing"
69

710
"github.com/rliebz/ghost"
@@ -177,3 +180,186 @@ got: <nil>
177180
want: boo`))
178181
})
179182
}
183+
184+
func TestErrorIs(t *testing.T) {
185+
t.Run("match", func(t *testing.T) {
186+
g := ghost.New(t)
187+
188+
target := errors.New("foobar")
189+
err := fmt.Errorf("wrapping: %w", target)
190+
191+
result := be.ErrorIs(err, target)
192+
g.Should(be.True(result.Ok))
193+
g.Should(be.Equal(
194+
result.Message,
195+
`error err is target target
196+
error: wrapping: foobar
197+
target: foobar`,
198+
))
199+
200+
result = be.ErrorIs(fmt.Errorf("wrapping: %w", target), target)
201+
g.Should(be.True(result.Ok))
202+
g.Should(be.Equal(
203+
result.Message,
204+
`error fmt.Errorf("wrapping: %w", target) is target target
205+
error: wrapping: foobar
206+
target: foobar`,
207+
))
208+
})
209+
210+
t.Run("no match", func(t *testing.T) {
211+
g := ghost.New(t)
212+
213+
target := errors.New("foobar")
214+
err := fmt.Errorf("wrapping: %v", target) //nolint:errorlint // test case
215+
216+
result := be.ErrorIs(err, target)
217+
g.Should(be.False(result.Ok))
218+
g.Should(be.Equal(
219+
result.Message,
220+
`error err is not target target
221+
error: wrapping: foobar
222+
target: foobar`,
223+
))
224+
225+
result = be.ErrorIs(fmt.Errorf("wrapping: %v", target), target) //nolint:errorlint // test case
226+
g.Should(be.False(result.Ok))
227+
g.Should(be.Equal(
228+
result.Message,
229+
`error fmt.Errorf("wrapping: %v", target) is not target target
230+
error: wrapping: foobar
231+
target: foobar`,
232+
))
233+
})
234+
235+
t.Run("nil", func(t *testing.T) {
236+
g := ghost.New(t)
237+
238+
var target error
239+
var err error
240+
241+
result := be.ErrorIs(err, target)
242+
g.Should(be.True(result.Ok))
243+
g.Should(be.Equal(result.Message, `error err is target target
244+
error: <nil>
245+
target: <nil>`))
246+
247+
result = be.ErrorIs(nil, nil)
248+
g.Should(be.True(result.Ok))
249+
g.Should(be.Equal(result.Message, `error nil is target nil
250+
error: <nil>
251+
target: <nil>`))
252+
})
253+
}
254+
255+
func TestErrorAs(t *testing.T) {
256+
t.Run("match", func(t *testing.T) {
257+
g := ghost.New(t)
258+
259+
var target *fs.PathError
260+
_, err := os.Open("some-non-existing-file")
261+
262+
result := be.ErrorAs(err, &target)
263+
g.Should(be.True(result.Ok))
264+
g.Should(be.Equal(
265+
result.Message,
266+
`error err set as target &target
267+
error: open some-non-existing-file: no such file or directory
268+
target: *fs.PathError`,
269+
))
270+
271+
result = be.ErrorAs(fmt.Errorf("wrapping: %w", err), &target)
272+
g.Should(be.True(result.Ok))
273+
g.Should(be.Equal(
274+
result.Message,
275+
`error fmt.Errorf("wrapping: %w", err) set as target &target
276+
error: wrapping: open some-non-existing-file: no such file or directory
277+
target: *fs.PathError`,
278+
))
279+
})
280+
281+
t.Run("no match", func(t *testing.T) {
282+
g := ghost.New(t)
283+
284+
var target *fs.PathError
285+
err := errors.New("oh no")
286+
287+
result := be.ErrorAs(err, &target)
288+
g.Should(be.False(result.Ok))
289+
g.Should(be.Equal(
290+
result.Message,
291+
`error err cannot be set as target &target
292+
error: oh no
293+
target: *fs.PathError`,
294+
))
295+
296+
result = be.ErrorAs(errors.New("oh no"), &target)
297+
g.Should(be.False(result.Ok))
298+
g.Should(be.Equal(
299+
result.Message,
300+
`error errors.New("oh no") cannot be set as target &target
301+
error: oh no
302+
target: *fs.PathError`,
303+
))
304+
})
305+
306+
t.Run("nil error", func(t *testing.T) {
307+
g := ghost.New(t)
308+
309+
var target error
310+
var err error
311+
312+
result := be.ErrorAs(err, target)
313+
g.Should(be.False(result.Ok))
314+
g.Should(be.Equal(result.Message, `error err was nil`))
315+
316+
result = be.ErrorAs(nil, nil)
317+
g.Should(be.False(result.Ok))
318+
g.Should(be.Equal(result.Message, `error nil was nil`))
319+
})
320+
321+
t.Run("nil target", func(t *testing.T) {
322+
g := ghost.New(t)
323+
324+
var target error
325+
err := errors.New("oh no")
326+
327+
result := be.ErrorAs(err, target)
328+
g.Should(be.False(result.Ok))
329+
g.Should(be.Equal(result.Message, `target target cannot be nil`))
330+
331+
result = be.ErrorAs(errors.New("oh no"), nil)
332+
g.Should(be.False(result.Ok))
333+
g.Should(be.Equal(result.Message, `target nil cannot be nil`))
334+
})
335+
336+
t.Run("non-pointer target", func(t *testing.T) {
337+
g := ghost.New(t)
338+
339+
target := "Hello"
340+
err := errors.New("oh no")
341+
342+
result := be.ErrorAs(err, target)
343+
g.Should(be.False(result.Ok))
344+
g.Should(be.Equal(result.Message, `target target must be a non-nil pointer`))
345+
346+
result = be.ErrorAs(errors.New("oh no"), "Hello")
347+
g.Should(be.False(result.Ok))
348+
g.Should(be.Equal(result.Message, `target "Hello" must be a non-nil pointer`))
349+
})
350+
351+
t.Run("non-error target element", func(t *testing.T) {
352+
g := ghost.New(t)
353+
354+
target := "Hello"
355+
err := errors.New("oh no")
356+
357+
result := be.ErrorAs(err, &target)
358+
g.Should(be.False(result.Ok))
359+
g.Should(be.Equal(result.Message, `*target &target must be interface or implement error`))
360+
361+
result = be.ErrorAs(errors.New("oh no"), new(string))
362+
g.Should(be.False(result.Ok))
363+
g.Should(be.Equal(result.Message, `*target new(string) must be interface or implement error`))
364+
})
365+
}

0 commit comments

Comments
 (0)