Skip to content

Commit 0adc681

Browse files
committed
Make be.ErrorAs type safe
This should be exactly as expressive in correct scenarios, but also prevent a number of invalid inputs. It could theoretically be a breaking change, but in practice I suspect it isn't.
1 parent fe621ff commit 0adc681

File tree

2 files changed

+10
-64
lines changed

2 files changed

+10
-64
lines changed

be/errors.go

+5-29
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ package be
33
import (
44
"errors"
55
"fmt"
6-
"reflect"
76
"strings"
87

98
"github.com/rliebz/ghost"
@@ -156,10 +155,8 @@ target: %v`,
156155
}
157156
}
158157

159-
var errorType = reflect.TypeOf((*error)(nil)).Elem()
160-
161158
// ErrorAs asserts that an error matches another using [errors.As].
162-
func ErrorAs(err error, target any) ghost.Result {
159+
func ErrorAs[T error](err error, target *T) ghost.Result {
163160
args := ghostlib.ArgsFromAST(err, target)
164161
argErr, argTarget := args[0], args[1]
165162

@@ -170,44 +167,23 @@ func ErrorAs(err error, target any) ghost.Result {
170167
}
171168
}
172169

173-
// These next few checks are for invalid usage, where errors.As will panic if
174-
// a caller hits any of them. As an assertion library, it's probably more
175-
// polite to never panic.
176-
177170
if target == nil {
178171
return ghost.Result{
179172
Ok: false,
180173
Message: fmt.Sprintf("target %v cannot be nil", argTarget),
181174
}
182175
}
183176

184-
val := reflect.ValueOf(target)
185-
typ := val.Type()
186-
if typ.Kind() != reflect.Ptr || val.IsNil() {
187-
return ghost.Result{
188-
Ok: false,
189-
Message: fmt.Sprintf("target %v must be a non-nil pointer", argTarget),
190-
}
191-
}
192-
targetType := typ.Elem()
193-
194-
if targetType.Kind() != reflect.Interface && !targetType.Implements(errorType) {
195-
return ghost.Result{
196-
Ok: false,
197-
Message: fmt.Sprintf("*target %v must be interface or implement error", argTarget),
198-
}
199-
}
200-
201177
if errors.As(err, target) {
202178
return ghost.Result{
203179
Ok: true,
204180
Message: fmt.Sprintf(`error %v set as target %v
205181
error: %v
206-
target: %v`,
182+
target: %T`,
207183
argErr,
208184
argTarget,
209185
err,
210-
targetType,
186+
*target,
211187
),
212188
}
213189
}
@@ -216,11 +192,11 @@ target: %v`,
216192
Ok: false,
217193
Message: fmt.Sprintf(`error %v cannot be set as target %v
218194
error: %v
219-
target: %v`,
195+
target: %T`,
220196
argErr,
221197
argTarget,
222198
err,
223-
targetType,
199+
*target,
224200
),
225201
}
226202
}

be/errors_test.go

+5-35
Original file line numberDiff line numberDiff line change
@@ -308,57 +308,27 @@ target: *strconv.NumError`,
308308
var target error
309309
var err error
310310

311-
result := be.ErrorAs(err, target)
311+
result := be.ErrorAs(err, &target)
312312
g.Should(be.False(result.Ok))
313313
g.Should(be.Equal(result.Message, `error err was nil`))
314314

315-
result = be.ErrorAs(nil, nil)
315+
result = be.ErrorAs(nil, new(error))
316316
g.Should(be.False(result.Ok))
317317
g.Should(be.Equal(result.Message, `error nil was nil`))
318318
})
319319

320320
t.Run("nil target", func(t *testing.T) {
321321
g := ghost.New(t)
322322

323-
var target error
323+
var target *error
324324
err := errors.New("oh no")
325325

326326
result := be.ErrorAs(err, target)
327327
g.Should(be.False(result.Ok))
328328
g.Should(be.Equal(result.Message, `target target cannot be nil`))
329329

330-
result = be.ErrorAs(errors.New("oh no"), nil)
331-
g.Should(be.False(result.Ok))
332-
g.Should(be.Equal(result.Message, `target nil cannot be nil`))
333-
})
334-
335-
t.Run("non-pointer target", func(t *testing.T) {
336-
g := ghost.New(t)
337-
338-
target := "Hello"
339-
err := errors.New("oh no")
340-
341-
result := be.ErrorAs(err, target)
342-
g.Should(be.False(result.Ok))
343-
g.Should(be.Equal(result.Message, `target target must be a non-nil pointer`))
344-
345-
result = be.ErrorAs(errors.New("oh no"), "Hello")
346-
g.Should(be.False(result.Ok))
347-
g.Should(be.Equal(result.Message, `target "Hello" must be a non-nil pointer`))
348-
})
349-
350-
t.Run("non-error target element", func(t *testing.T) {
351-
g := ghost.New(t)
352-
353-
target := "Hello"
354-
err := errors.New("oh no")
355-
356-
result := be.ErrorAs(err, &target)
357-
g.Should(be.False(result.Ok))
358-
g.Should(be.Equal(result.Message, `*target &target must be interface or implement error`))
359-
360-
result = be.ErrorAs(errors.New("oh no"), new(string))
330+
result = be.ErrorAs[error](errors.New("oh no"), nil)
361331
g.Should(be.False(result.Ok))
362-
g.Should(be.Equal(result.Message, `*target new(string) must be interface or implement error`))
332+
g.Should(be.Equal(result.Message, `target <nil> cannot be nil`))
363333
})
364334
}

0 commit comments

Comments
 (0)