Skip to content

Commit d50547d

Browse files
committed
Make sure that it is equivalent to v.Interface().(T)
1 parent cb41764 commit d50547d

File tree

2 files changed

+39
-5
lines changed

2 files changed

+39
-5
lines changed

src/reflect/all_test.go

+30-2
Original file line numberDiff line numberDiff line change
@@ -8701,7 +8701,7 @@ func newPtr[T any](t T) *T {
87018701
return &t
87028702
}
87038703

8704-
func TestTypeAssert(t *testing.T) {
8704+
func TestTypeAssertConcreteTypes(t *testing.T) {
87058705
testTypeAssert(t, int(1111))
87068706
testTypeAssert(t, int(111111111))
87078707
testTypeAssert(t, int(-111111111))
@@ -8714,13 +8714,41 @@ func TestTypeAssert(t *testing.T) {
87148714
testTypeAssert(t, newPtr(111111111))
87158715
testTypeAssert(t, newPtr(-111111111))
87168716
testTypeAssert(t, newPtr([2]int{-111111111, -22222222}))
8717-
8717+
testTypeAssert(t, [2]*int{newPtr(-111111111), newPtr(-22222222)})
87188718
testTypeAssert(t, newPtr(time.Now()))
87198719

87208720
testTypeAssertDifferentType[uint](t, int(111111111))
87218721
testTypeAssertDifferentType[uint](t, int(-111111111))
87228722
}
87238723

8724+
func TestTypeAssertInterfaceTypes(t *testing.T) {
8725+
v, ok := TypeAssert[any](ValueOf(1))
8726+
if v != any(1) || !ok {
8727+
t.Errorf("TypeAssert[any](1) = (%v, %v); want = (1, true)", v, ok)
8728+
}
8729+
8730+
v, ok = TypeAssert[fmt.Stringer](ValueOf(1))
8731+
if v != nil || ok {
8732+
t.Errorf("TypeAssert[fmt.Stringer](1) = (%v, %v); want = (1, false)", v, ok)
8733+
}
8734+
8735+
v, ok = TypeAssert[any](ValueOf(testTypeWithMethod{"test"}))
8736+
if v != any(testTypeWithMethod{"test"}) || !ok {
8737+
t.Errorf(`TypeAssert[any](testTypeWithMethod{"test"}) = (%v, %v); want = (testTypeWithMethod{"test"}, true)`, v, ok)
8738+
}
8739+
8740+
v, ok = TypeAssert[fmt.Stringer](ValueOf(testTypeWithMethod{"test"}))
8741+
if v != fmt.Stringer(testTypeWithMethod{"test"}) || !ok {
8742+
t.Errorf(`TypeAssert[fmt.Stringer](testTypeWithMethod{"test"}) = (%v, %v); want = (testTypeWithMethod{"test"}, true)`, v, ok)
8743+
}
8744+
8745+
val := &testTypeWithMethod{"test"}
8746+
v, ok = TypeAssert[fmt.Stringer](ValueOf(val))
8747+
if v != fmt.Stringer(val) || !ok {
8748+
t.Errorf(`TypeAssert[fmt.Stringer](&testTypeWithMethod{"test"}) = (%v, %v); want = (&testTypeWithMethod{"test"}, true)`, v, ok)
8749+
}
8750+
}
8751+
87248752
type testTypeWithMethod struct {
87258753
val string
87268754
}

src/reflect/value.go

+9-3
Original file line numberDiff line numberDiff line change
@@ -1520,7 +1520,6 @@ func TypeAssert[T any](v Value) (T, bool) {
15201520
if v.flag == 0 {
15211521
panic(&ValueError{"reflect.TypeAssert", Invalid})
15221522
}
1523-
15241523
if v.flag&flagRO != 0 {
15251524
// Do not allow access to unexported values via Interface,
15261525
// because they might be pointers that should not be
@@ -1532,12 +1531,19 @@ func TypeAssert[T any](v Value) (T, bool) {
15321531
v = makeMethodValue("TypeAssert", v)
15331532
}
15341533

1535-
if abi.TypeFor[T]() != v.typ_ {
1534+
if abi.TypeFor[T]() != v.typ() {
1535+
// TypeAssert[T] should work the same way as v.Interface().(T), thus we need
1536+
// to handle following case properly: TypeAssert[any](ValueOf(1)).
1537+
// Note that we will not hit here is such case: TypeAssert[any](ValueOf(any(1))).
1538+
if abi.TypeFor[T]().Kind() == abi.Interface {
1539+
v, ok := packEface(v).(T)
1540+
return v, ok
1541+
}
15361542
var zero T
15371543
return zero, false
15381544
}
15391545

1540-
if v.typ_.IsDirectIface() {
1546+
if v.typ().IsDirectIface() {
15411547
return *(*T)(unsafe.Pointer(&v.ptr)), true
15421548
}
15431549

0 commit comments

Comments
 (0)