Skip to content

Commit dcb7193

Browse files
authored
Merge pull request #2236 from felix-roehrich/fr/fix-plan-scan
Alternative implementation for JSONCodec.PlanScan
2 parents 1abf7d9 + a5353af commit dcb7193

File tree

3 files changed

+77
-74
lines changed

3 files changed

+77
-74
lines changed

pgtype/json.go

+45-50
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,27 @@ func (c *JSONCodec) PlanEncode(m *Map, oid uint32, format int16, value any) Enco
7171
}
7272
}
7373

74+
// JSON needs its on scan plan for pointers to handle 'null'::json(b).
75+
// Consider making pointerPointerScanPlan more flexible in the future.
76+
type jsonPointerScanPlan struct {
77+
next ScanPlan
78+
}
79+
80+
func (p jsonPointerScanPlan) Scan(src []byte, dst any) error {
81+
el := reflect.ValueOf(dst).Elem()
82+
if src == nil || string(src) == "null" {
83+
el.SetZero()
84+
return nil
85+
}
86+
87+
el.Set(reflect.New(el.Type().Elem()))
88+
if p.next != nil {
89+
return p.next.Scan(src, el.Interface())
90+
}
91+
92+
return nil
93+
}
94+
7495
type encodePlanJSONCodecEitherFormatString struct{}
7596

7697
func (encodePlanJSONCodecEitherFormatString) Encode(value any, buf []byte) (newBuf []byte, err error) {
@@ -117,64 +138,38 @@ func (e *encodePlanJSONCodecEitherFormatMarshal) Encode(value any, buf []byte) (
117138
return buf, nil
118139
}
119140

120-
func (c *JSONCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
141+
func (c *JSONCodec) PlanScan(m *Map, oid uint32, formatCode int16, target any) ScanPlan {
142+
return c.planScan(m, oid, formatCode, target, 0)
143+
}
144+
145+
// JSON cannot fallback to pointerPointerScanPlan because of 'null'::json(b),
146+
// so we need to duplicate the logic here.
147+
func (c *JSONCodec) planScan(m *Map, oid uint32, formatCode int16, target any, depth int) ScanPlan {
148+
if depth > 8 {
149+
return &scanPlanFail{m: m, oid: oid, formatCode: formatCode}
150+
}
151+
121152
switch target.(type) {
122153
case *string:
123-
return scanPlanAnyToString{}
124-
125-
case **string:
126-
// This is to fix **string scanning. It seems wrong to special case **string, but it's not clear what a better
127-
// solution would be.
128-
//
129-
// https://github.com/jackc/pgx/issues/1470 -- **string
130-
// https://github.com/jackc/pgx/issues/1691 -- ** anything else
131-
132-
if wrapperPlan, nextDst, ok := TryPointerPointerScanPlan(target); ok {
133-
if nextPlan := m.planScan(oid, format, nextDst, 0); nextPlan != nil {
134-
if _, failed := nextPlan.(*scanPlanFail); !failed {
135-
wrapperPlan.SetNext(nextPlan)
136-
return wrapperPlan
137-
}
138-
}
139-
}
140-
154+
return &scanPlanAnyToString{}
141155
case *[]byte:
142-
return scanPlanJSONToByteSlice{}
156+
return &scanPlanJSONToByteSlice{}
143157
case BytesScanner:
144-
return scanPlanBinaryBytesToBytesScanner{}
145-
146-
}
147-
148-
// Cannot rely on sql.Scanner being handled later because scanPlanJSONToJSONUnmarshal will take precedence.
149-
//
150-
// https://github.com/jackc/pgx/issues/1418
151-
if isSQLScanner(target) {
152-
return &scanPlanSQLScanner{formatCode: format}
158+
return &scanPlanBinaryBytesToBytesScanner{}
159+
case sql.Scanner:
160+
return &scanPlanSQLScanner{formatCode: formatCode}
153161
}
154162

155-
return &scanPlanJSONToJSONUnmarshal{
156-
unmarshal: c.Unmarshal,
163+
rv := reflect.ValueOf(target)
164+
if rv.Kind() == reflect.Pointer && rv.Elem().Kind() == reflect.Pointer {
165+
var plan jsonPointerScanPlan
166+
plan.next = c.planScan(m, oid, formatCode, rv.Elem().Interface(), depth+1)
167+
return plan
168+
} else {
169+
return &scanPlanJSONToJSONUnmarshal{unmarshal: c.Unmarshal}
157170
}
158171
}
159172

160-
// we need to check if the target is a pointer to a sql.Scanner (or any of the pointer ref tree implements a sql.Scanner).
161-
//
162-
// https://github.com/jackc/pgx/issues/2146
163-
func isSQLScanner(v any) bool {
164-
if _, is := v.(sql.Scanner); is {
165-
return true
166-
}
167-
168-
val := reflect.ValueOf(v)
169-
for val.Kind() == reflect.Ptr {
170-
if _, ok := val.Interface().(sql.Scanner); ok {
171-
return true
172-
}
173-
val = val.Elem()
174-
}
175-
return false
176-
}
177-
178173
type scanPlanAnyToString struct{}
179174

180175
func (scanPlanAnyToString) Scan(src []byte, dst any) error {
@@ -202,7 +197,7 @@ type scanPlanJSONToJSONUnmarshal struct {
202197
}
203198

204199
func (s *scanPlanJSONToJSONUnmarshal) Scan(src []byte, dst any) error {
205-
if src == nil {
200+
if src == nil || string(src) == "null" {
206201
dstValue := reflect.ValueOf(dst)
207202
if dstValue.Kind() == reflect.Ptr {
208203
el := dstValue.Elem()

pgtype/json_test.go

+31
Original file line numberDiff line numberDiff line change
@@ -326,3 +326,34 @@ func TestJSONCodecScanToNonPointerValues(t *testing.T) {
326326
require.Equal(t, 42, m)
327327
})
328328
}
329+
330+
func TestJSONCodecScanNull(t *testing.T) {
331+
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
332+
var dest struct{}
333+
err := conn.QueryRow(ctx, "select null::jsonb").Scan(&dest)
334+
require.Error(t, err)
335+
require.Contains(t, err.Error(), "cannot scan NULL into *struct {}")
336+
337+
err = conn.QueryRow(ctx, "select 'null'::jsonb").Scan(&dest)
338+
require.Error(t, err)
339+
require.Contains(t, err.Error(), "cannot scan NULL into *struct {}")
340+
341+
var destPointer *struct{}
342+
err = conn.QueryRow(ctx, "select null::jsonb").Scan(&destPointer)
343+
require.NoError(t, err)
344+
require.Nil(t, destPointer)
345+
346+
err = conn.QueryRow(ctx, "select 'null'::jsonb").Scan(&destPointer)
347+
require.NoError(t, err)
348+
require.Nil(t, destPointer)
349+
})
350+
}
351+
352+
func TestJSONCodecScanNullToPointerToSQLScanner(t *testing.T) {
353+
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
354+
var dest *Issue2146
355+
err := conn.QueryRow(ctx, "select null::jsonb").Scan(&dest)
356+
require.NoError(t, err)
357+
require.Nil(t, dest)
358+
})
359+
}

pgtype/pgtype.go

+1-24
Original file line numberDiff line numberDiff line change
@@ -396,11 +396,7 @@ type scanPlanSQLScanner struct {
396396
}
397397

398398
func (plan *scanPlanSQLScanner) Scan(src []byte, dst any) error {
399-
scanner := getSQLScanner(dst)
400-
401-
if scanner == nil {
402-
return fmt.Errorf("cannot scan into %T", dst)
403-
}
399+
scanner := dst.(sql.Scanner)
404400

405401
if src == nil {
406402
// This is necessary because interface value []byte:nil does not equal nil:nil for the binary format path and the
@@ -413,25 +409,6 @@ func (plan *scanPlanSQLScanner) Scan(src []byte, dst any) error {
413409
}
414410
}
415411

416-
// we don't know if the target is a sql.Scanner or a pointer on a sql.Scanner, so we need to check recursively
417-
func getSQLScanner(target any) sql.Scanner {
418-
if sc, is := target.(sql.Scanner); is {
419-
return sc
420-
}
421-
422-
val := reflect.ValueOf(target)
423-
for val.Kind() == reflect.Ptr {
424-
if _, ok := val.Interface().(sql.Scanner); ok {
425-
if val.IsNil() {
426-
val.Set(reflect.New(val.Type().Elem()))
427-
}
428-
return val.Interface().(sql.Scanner)
429-
}
430-
val = val.Elem()
431-
}
432-
return nil
433-
}
434-
435412
type scanPlanString struct{}
436413

437414
func (scanPlanString) Scan(src []byte, dst any) error {

0 commit comments

Comments
 (0)