Skip to content

Commit 33f7b64

Browse files
committed
fix: better paths for resolvers, prevent subtle mistakes
1 parent ba0e5c6 commit 33f7b64

File tree

3 files changed

+136
-15
lines changed

3 files changed

+136
-15
lines changed

chain.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
package huma
22

3+
// Middlewares is a list of middleware functions that can be attached to an
4+
// API and will be called for all incoming requests.
35
type Middlewares []func(ctx Context, next func(Context))
46

57
// Handler builds and returns a handler func from the chain of middlewares,

huma.go

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,15 @@ func findParams(registry Registry, op *Operation, t reflect.Type) *findResult[*p
8989
return nil
9090
}
9191

92+
if f.Type.Kind() == reflect.Pointer {
93+
// TODO: support pointers? The problem is that when we dynamically
94+
// create an instance of the input struct the `params.Every(...)`
95+
// call cannot set them as the value is `reflect.Invalid` unless
96+
// dynamically allocated, but we don't know when to allocate until
97+
// after the `Every` callback has run. Doable, but a bigger change.
98+
panic("pointers are not supported for path/query/header parameters")
99+
}
100+
92101
pfi := &paramFieldInfo{
93102
Type: f.Type,
94103
Schema: SchemaFromField(registry, f, ""),
@@ -171,6 +180,9 @@ func findResolvers(resolverType, t reflect.Type) *findResult[bool] {
171180
func findDefaults(t reflect.Type) *findResult[any] {
172181
return findInType(t, nil, func(sf reflect.StructField, i []int) any {
173182
if d := sf.Tag.Get("default"); d != "" {
183+
if sf.Type.Kind() == reflect.Pointer {
184+
panic("pointers cannot have default values")
185+
}
174186
return jsonTagValue(sf, sf.Type, d)
175187
}
176188
return nil
@@ -210,6 +222,12 @@ type findResult[T comparable] struct {
210222
}
211223

212224
func (r *findResult[T]) every(current reflect.Value, path []int, v T, f func(reflect.Value, T)) {
225+
if current.Kind() == reflect.Invalid {
226+
// Indirect from below may have resulted in no value, for example
227+
// an optional field may have been omitted; just ignore it.
228+
return
229+
}
230+
213231
if len(path) == 0 {
214232
f(current, v)
215233
return
@@ -246,19 +264,45 @@ func jsonName(field reflect.StructField) string {
246264
}
247265

248266
func (r *findResult[T]) everyPB(current reflect.Value, path []int, pb *PathBuffer, v T, f func(reflect.Value, T)) {
267+
if current.Kind() == reflect.Invalid {
268+
// Indirect from below may have resulted in no value, for example
269+
// an optional field may have been omitted; just ignore it.
270+
return
271+
}
249272
switch current.Kind() {
250273
case reflect.Struct:
251274
if len(path) == 0 {
252275
f(current, v)
253276
return
254277
}
255278
field := current.Type().Field(path[0])
279+
pops := 0
256280
if !field.Anonymous {
281+
// The path name can come from one of four places: path parameter,
282+
// query parameter, header parameter, or body field.
257283
// TODO: pre-compute type/field names? Could save a few allocations.
258-
pb.Push(jsonName(field))
284+
pops++
285+
if path := field.Tag.Get("path"); path != "" && pb.Len() == 0 {
286+
pb.Push("path")
287+
pb.Push(path)
288+
pops++
289+
} else if query := field.Tag.Get("query"); query != "" && pb.Len() == 0 {
290+
pb.Push("query")
291+
pb.Push(query)
292+
pops++
293+
} else if header := field.Tag.Get("header"); header != "" && pb.Len() == 0 {
294+
pb.Push("header")
295+
pb.Push(header)
296+
pops++
297+
} else {
298+
// The body is _always_ in a field called "Body", which turns into
299+
// `body` in the path buffer, so we don't need to push it separately
300+
// like the the params fields above.
301+
pb.Push(jsonName(field))
302+
}
259303
}
260304
r.everyPB(reflect.Indirect(current.Field(path[0])), path[1:], pb, v, f)
261-
if !field.Anonymous {
305+
for i := 0; i < pops; i++ {
262306
pb.Pop()
263307
}
264308
case reflect.Slice:

huma_test.go

Lines changed: 88 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -612,18 +612,39 @@ func TestOpenAPI(t *testing.T) {
612612
}
613613
}
614614

615+
type IntNot3 int
616+
617+
func (i IntNot3) Resolve(ctx huma.Context, prefix *huma.PathBuffer) []error {
618+
if i != 0 && i%3 == 0 {
619+
return []error{&huma.ErrorDetail{
620+
Location: prefix.String(),
621+
Message: "Value cannot be a multiple of three",
622+
Value: i,
623+
}}
624+
}
625+
return nil
626+
}
627+
628+
var _ huma.ResolverWithPath = (*IntNot3)(nil)
629+
615630
type ExhaustiveErrorsInputBody struct {
616-
Name string `json:"name" maxLength:"10"`
617-
Count int `json:"count" minimum:"1"`
631+
Name string `json:"name" maxLength:"10"`
632+
Count IntNot3 `json:"count" minimum:"1"`
633+
634+
// Having a pointer which is never loaded should not cause
635+
// the tests to fail when running resolvers.
636+
Ptr *IntNot3 `json:"ptr,omitempty" minimum:"1"`
618637
}
619638

620639
func (b *ExhaustiveErrorsInputBody) Resolve(ctx huma.Context) []error {
621640
return []error{fmt.Errorf("body resolver error")}
622641
}
623642

624643
type ExhaustiveErrorsInput struct {
625-
ID string `path:"id" maxLength:"5"`
626-
Body ExhaustiveErrorsInputBody `json:"body"`
644+
ID IntNot3 `path:"id" maximum:"10"`
645+
Query IntNot3 `query:"query"`
646+
Header IntNot3 `header:"header"`
647+
Body ExhaustiveErrorsInputBody `json:"body"`
627648
}
628649

629650
func (i *ExhaustiveErrorsInput) Resolve(ctx huma.Context) []error {
@@ -634,21 +655,21 @@ func (i *ExhaustiveErrorsInput) Resolve(ctx huma.Context) []error {
634655
}}
635656
}
636657

637-
type ExhaustiveErrorsOutput struct {
638-
}
658+
var _ huma.Resolver = (*ExhaustiveErrorsInput)(nil)
639659

640660
func TestExhaustiveErrors(t *testing.T) {
641661
r, app := humatest.New(t, huma.DefaultConfig("Test API", "1.0.0"))
642662
huma.Register(app, huma.Operation{
643663
OperationID: "test",
644664
Method: http.MethodPut,
645665
Path: "/errors/{id}",
646-
}, func(ctx context.Context, input *ExhaustiveErrorsInput) (*ExhaustiveErrorsOutput, error) {
647-
return &ExhaustiveErrorsOutput{}, nil
666+
}, func(ctx context.Context, input *ExhaustiveErrorsInput) (*struct{}, error) {
667+
return nil, nil
648668
})
649669

650-
req, _ := http.NewRequest(http.MethodPut, "/errors/123456", strings.NewReader(`{"name": "12345678901", "count": 0}`))
670+
req, _ := http.NewRequest(http.MethodPut, "/errors/15?query=3", strings.NewReader(`{"name": "12345678901", "count": -6}`))
651671
req.Header.Set("Content-Type", "application/json")
672+
req.Header.Set("Header", "3")
652673
w := httptest.NewRecorder()
653674
r.ServeHTTP(w, req)
654675
assert.Equal(t, http.StatusUnprocessableEntity, w.Code)
@@ -659,23 +680,39 @@ func TestExhaustiveErrors(t *testing.T) {
659680
"detail": "validation failed",
660681
"errors": [
661682
{
662-
"message": "expected length <= 5",
683+
"message": "expected number <= 10",
663684
"location": "path.id",
664-
"value": "123456"
685+
"value": 15
665686
}, {
666687
"message": "expected length <= 10",
667688
"location": "body.name",
668689
"value": "12345678901"
669690
}, {
670691
"message": "expected number >= 1",
671692
"location": "body.count",
672-
"value": 0
693+
"value": -6
673694
}, {
674695
"message": "input resolver error",
675696
"location": "path.id",
676-
"value": "123456"
697+
"value": 15
698+
}, {
699+
"message": "Value cannot be a multiple of three",
700+
"location": "path.id",
701+
"value": 15
702+
}, {
703+
"message": "Value cannot be a multiple of three",
704+
"location": "query.query",
705+
"value": 3
706+
}, {
707+
"message": "Value cannot be a multiple of three",
708+
"location": "header.header",
709+
"value": 3
677710
}, {
678711
"message": "body resolver error"
712+
}, {
713+
"message": "Value cannot be a multiple of three",
714+
"location": "body.count",
715+
"value": -6
679716
}
680717
]
681718
}`, w.Body.String())
@@ -745,6 +782,44 @@ func TestResolverCustomStatus(t *testing.T) {
745782
assert.Contains(t, w.Body.String(), "nope")
746783
}
747784

785+
func TestParamPointerPanics(t *testing.T) {
786+
// For now we don't support these, so we panic rather than have subtle
787+
// bugs that are hard to track down.
788+
_, app := humatest.New(t, huma.DefaultConfig("Test API", "1.0.0"))
789+
790+
assert.Panics(t, func() {
791+
huma.Register(app, huma.Operation{
792+
OperationID: "bug",
793+
Method: http.MethodGet,
794+
Path: "/bug",
795+
}, func(ctx context.Context, input *struct {
796+
Param *string `query:"param"`
797+
}) (*struct{}, error) {
798+
return nil, nil
799+
})
800+
})
801+
}
802+
803+
func TestPointerDefaultPanics(t *testing.T) {
804+
// For now we don't support these, so we panic rather than have subtle
805+
// bugs that are hard to track down.
806+
_, app := humatest.New(t, huma.DefaultConfig("Test API", "1.0.0"))
807+
808+
assert.Panics(t, func() {
809+
huma.Register(app, huma.Operation{
810+
OperationID: "bug",
811+
Method: http.MethodGet,
812+
Path: "/bug",
813+
}, func(ctx context.Context, input *struct {
814+
Body struct {
815+
Value *string `json:"value,omitempty" default:"foo"`
816+
}
817+
}) (*struct{}, error) {
818+
return nil, nil
819+
})
820+
})
821+
}
822+
748823
func BenchmarkSecondDecode(b *testing.B) {
749824
type MediumSized struct {
750825
ID int `json:"id"`

0 commit comments

Comments
 (0)