Skip to content

Commit f57b285

Browse files
authored
Merge pull request #2151 from ludusrusso/fix-2146
handling double pointer on sql.Scanner interface when scanning rows
2 parents 2ec9004 + 5c9b565 commit f57b285

File tree

4 files changed

+73
-2
lines changed

4 files changed

+73
-2
lines changed

pgtype/json.go

+17-1
Original file line numberDiff line numberDiff line change
@@ -143,10 +143,12 @@ func (c *JSONCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanP
143143
case BytesScanner:
144144
return scanPlanBinaryBytesToBytesScanner{}
145145

146+
}
147+
146148
// Cannot rely on sql.Scanner being handled later because scanPlanJSONToJSONUnmarshal will take precedence.
147149
//
148150
// https://github.com/jackc/pgx/issues/1418
149-
case sql.Scanner:
151+
if isSQLScanner(target) {
150152
return &scanPlanSQLScanner{formatCode: format}
151153
}
152154

@@ -155,6 +157,20 @@ func (c *JSONCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanP
155157
}
156158
}
157159

160+
// we need to check if the target is a pointer to a sql.Scanner (or any of the pointer ref tree implements a sql.Scanner).
161+
//
162+
// https://github.com/jackc/pgx/issues/2146
163+
func isSQLScanner(v any) bool {
164+
val := reflect.ValueOf(v)
165+
for val.Kind() == reflect.Ptr {
166+
if _, ok := val.Interface().(sql.Scanner); ok {
167+
return true
168+
}
169+
val = val.Elem()
170+
}
171+
return false
172+
}
173+
158174
type scanPlanAnyToString struct{}
159175

160176
func (scanPlanAnyToString) Scan(src []byte, dst any) error {

pgtype/json_test.go

+27
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@ func TestJSONCodec(t *testing.T) {
6363

6464
// Test driver.Valuer is used before json.Marshaler (https://github.com/jackc/pgx/issues/1805)
6565
{Issue1805(7), new(Issue1805), isExpectedEq(Issue1805(7))},
66+
// Test driver.Scanner is used before json.Unmarshaler (https://github.com/jackc/pgx/issues/2146)
67+
{Issue2146(7), new(*Issue2146), isPtrExpectedEq(Issue2146(7))},
6668
})
6769

6870
pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, pgxtest.KnownOIDQueryExecModes, "json", []pgxtest.ValueRoundTripTest{
@@ -109,6 +111,31 @@ func (i Issue1805) MarshalJSON() ([]byte, error) {
109111
return nil, errors.New("MarshalJSON called")
110112
}
111113

114+
type Issue2146 int
115+
116+
func (i *Issue2146) Scan(src any) error {
117+
var source []byte
118+
switch src.(type) {
119+
case string:
120+
source = []byte(src.(string))
121+
case []byte:
122+
source = src.([]byte)
123+
default:
124+
return errors.New("unknown source type")
125+
}
126+
var newI int
127+
if err := json.Unmarshal(source, &newI); err != nil {
128+
return err
129+
}
130+
*i = Issue2146(newI + 1)
131+
return nil
132+
}
133+
134+
func (i Issue2146) Value() (driver.Value, error) {
135+
b, err := json.Marshal(int(i - 1))
136+
return string(b), err
137+
}
138+
112139
// https://github.com/jackc/pgx/issues/1273#issuecomment-1221414648
113140
func TestJSONCodecUnmarshalSQLNull(t *testing.T) {
114141
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {

pgtype/pgtype.go

+21-1
Original file line numberDiff line numberDiff line change
@@ -396,7 +396,12 @@ type scanPlanSQLScanner struct {
396396
}
397397

398398
func (plan *scanPlanSQLScanner) Scan(src []byte, dst any) error {
399-
scanner := dst.(sql.Scanner)
399+
scanner := getSQLScanner(dst)
400+
401+
if scanner == nil {
402+
return fmt.Errorf("cannot scan into %T", dst)
403+
}
404+
400405
if src == nil {
401406
// This is necessary because interface value []byte:nil does not equal nil:nil for the binary format path and the
402407
// text format path would be converted to empty string.
@@ -408,6 +413,21 @@ func (plan *scanPlanSQLScanner) Scan(src []byte, dst any) error {
408413
}
409414
}
410415

416+
// we don't know if the target is a sql.Scanner or a pointer on a sql.Scanner, so we need to check recursively
417+
func getSQLScanner(target any) sql.Scanner {
418+
val := reflect.ValueOf(target)
419+
for val.Kind() == reflect.Ptr {
420+
if _, ok := val.Interface().(sql.Scanner); ok {
421+
if val.IsNil() {
422+
val.Set(reflect.New(val.Type().Elem()))
423+
}
424+
return val.Interface().(sql.Scanner)
425+
}
426+
val = val.Elem()
427+
}
428+
return nil
429+
}
430+
411431
type scanPlanString struct{}
412432

413433
func (scanPlanString) Scan(src []byte, dst any) error {

pgtype/pgtype_test.go

+8
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"fmt"
1010
"net"
1111
"os"
12+
"reflect"
1213
"regexp"
1314
"strconv"
1415
"testing"
@@ -631,3 +632,10 @@ func isExpectedEq(a any) func(any) bool {
631632
return a == v
632633
}
633634
}
635+
636+
func isPtrExpectedEq(a any) func(any) bool {
637+
return func(v any) bool {
638+
val := reflect.ValueOf(v)
639+
return a == val.Elem().Interface()
640+
}
641+
}

0 commit comments

Comments
 (0)