Skip to content

Commit 3f3d26d

Browse files
committed
Better detect nil value under interface
1 parent 253f801 commit 3f3d26d

File tree

3 files changed

+127
-1
lines changed

3 files changed

+127
-1
lines changed

internal/query/client.go

+3-1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import (
2121
"github.com/ydb-platform/ydb-go-sdk/v3/internal/types"
2222
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xcontext"
2323
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors"
24+
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xreflect"
2425
"github.com/ydb-platform/ydb-go-sdk/v3/query"
2526
"github.com/ydb-platform/ydb-go-sdk/v3/retry"
2627
"github.com/ydb-platform/ydb-go-sdk/v3/trace"
@@ -641,9 +642,10 @@ func New(ctx context.Context, cc grpc.ClientConnInterface, cfg *config.Config) *
641642

642643
// checkTxControlWithCommit checks that if WithTxControl is used, it must be with WithCommit
643644
func checkTxControlWithCommit(txControl options.TxControl) error {
644-
if txControl != nil && !txControl.Commit() {
645+
if !xreflect.IsPointToNil(txControl) && !txControl.Commit() {
645646
return xerrors.WithStackTrace(errNoCommit)
646647
}
648+
647649
return nil
648650
}
649651

internal/xreflect/is_nil.go

+31
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
package xreflect
2+
3+
import "reflect"
4+
5+
func IsPointToNil(v any) bool {
6+
if v == nil {
7+
return true
8+
}
9+
10+
rVal := reflect.ValueOf(v)
11+
12+
return isValPointToNil(rVal)
13+
}
14+
15+
func isValPointToNil(v reflect.Value) bool {
16+
kind := v.Kind()
17+
var res bool
18+
switch kind {
19+
case reflect.Slice:
20+
return false
21+
case reflect.Chan, reflect.Func, reflect.Map, reflect.UnsafePointer:
22+
res = v.IsNil()
23+
case reflect.Pointer, reflect.Interface:
24+
elem := v.Elem()
25+
if v.IsNil() {
26+
return true
27+
}
28+
res = isValPointToNil(elem)
29+
}
30+
return res
31+
}

internal/xreflect/is_nil_test.go

+93
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
package xreflect
2+
3+
import (
4+
"testing"
5+
)
6+
7+
func TestPointToNilValue(t *testing.T) {
8+
var nilIntPointer *int
9+
var vInterface any
10+
vInterface = nilIntPointer
11+
12+
// Test cases for different nil and non-nil scenarios
13+
tests := []struct {
14+
name string
15+
input any
16+
expected bool
17+
}{
18+
{
19+
name: "nil interface",
20+
input: nil,
21+
expected: true,
22+
},
23+
{
24+
name: "nil pointer to int",
25+
input: (*int)(nil),
26+
expected: true,
27+
},
28+
{
29+
name: "non-nil pointer to int",
30+
input: new(int),
31+
expected: false,
32+
},
33+
{
34+
name: "nil slice",
35+
input: []int(nil),
36+
expected: false,
37+
},
38+
{
39+
name: "empty slice",
40+
input: []int{},
41+
expected: false,
42+
},
43+
{
44+
name: "nil map",
45+
input: map[string]int(nil),
46+
expected: true,
47+
},
48+
{
49+
name: "empty map",
50+
input: map[string]int{},
51+
expected: false,
52+
},
53+
{
54+
name: "nil channel",
55+
input: (chan int)(nil),
56+
expected: true,
57+
},
58+
{
59+
name: "non-nil channel",
60+
input: make(chan int),
61+
expected: false,
62+
},
63+
{
64+
name: "nil function",
65+
input: (func())(nil),
66+
expected: true,
67+
},
68+
{
69+
name: "nested nil pointer",
70+
input: &nilIntPointer,
71+
expected: true,
72+
},
73+
{
74+
name: "interface with stored nil pointer",
75+
input: vInterface,
76+
expected: true,
77+
},
78+
{
79+
name: "non-nil interface value",
80+
input: interface{}("test"),
81+
expected: false,
82+
},
83+
}
84+
85+
// Execute all test cases
86+
for _, tt := range tests {
87+
t.Run(tt.name, func(t *testing.T) {
88+
if got := IsPointToNil(tt.input); got != tt.expected {
89+
t.Errorf("IsPointToNil() = %v, want %v", got, tt.expected)
90+
}
91+
})
92+
}
93+
}

0 commit comments

Comments
 (0)