Skip to content

Commit cc876a6

Browse files
murfffikenshaw
authored andcommitted
Dereference fix in common copy implementation
1 parent 17f9e30 commit cc876a6

File tree

5 files changed

+93
-124
lines changed

5 files changed

+93
-124
lines changed

drivers/clickhouse/clickhouse.go

+1-76
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,7 @@
55
package clickhouse
66

77
import (
8-
"context"
98
"database/sql"
10-
"fmt"
11-
"reflect"
129
"strconv"
1310
"strings"
1411

@@ -38,79 +35,7 @@ func init() {
3835
}
3936
return false
4037
},
41-
Copy: CopyWithInsert,
38+
Copy: drivers.CopyWithInsert(func(int) string { return "?" }),
4239
NewMetadataReader: NewMetadataReader,
4340
})
4441
}
45-
46-
// CopyWithInsert builds a copy handler based on insert.
47-
func CopyWithInsert(ctx context.Context, db *sql.DB, rows *sql.Rows, table string) (int64, error) {
48-
columns, err := rows.Columns()
49-
if err != nil {
50-
return 0, fmt.Errorf("failed to fetch source rows columns: %w", err)
51-
}
52-
clen := len(columns)
53-
query := table
54-
if !strings.HasPrefix(strings.ToLower(query), "insert into") {
55-
leftParen := strings.IndexRune(table, '(')
56-
if leftParen == -1 {
57-
colRows, err := db.QueryContext(ctx, "SELECT * FROM "+table+" WHERE 1=0")
58-
if err != nil {
59-
return 0, fmt.Errorf("failed to execute query to determine target table columns: %w", err)
60-
}
61-
columns, err := colRows.Columns()
62-
_ = colRows.Close()
63-
if err != nil {
64-
return 0, fmt.Errorf("failed to fetch target table columns: %w", err)
65-
}
66-
table += "(" + strings.Join(columns, ", ") + ")"
67-
}
68-
query = "INSERT INTO " + table + " VALUES (" + strings.Repeat("?, ", clen-1) + "?)"
69-
}
70-
tx, err := db.BeginTx(ctx, nil)
71-
if err != nil {
72-
return 0, fmt.Errorf("failed to begin transaction: %w", err)
73-
}
74-
stmt, err := tx.PrepareContext(ctx, query)
75-
if err != nil {
76-
return 0, fmt.Errorf("failed to prepare insert query: %w", err)
77-
}
78-
defer stmt.Close()
79-
columnTypes, err := rows.ColumnTypes()
80-
if err != nil {
81-
return 0, fmt.Errorf("failed to fetch source column types: %w", err)
82-
}
83-
values := make([]interface{}, clen)
84-
valueRefs := make([]reflect.Value, clen)
85-
actuals := make([]interface{}, clen)
86-
for i := 0; i < len(columnTypes); i++ {
87-
valueRefs[i] = reflect.New(columnTypes[i].ScanType())
88-
values[i] = valueRefs[i].Interface()
89-
}
90-
var n int64
91-
for rows.Next() {
92-
err = rows.Scan(values...)
93-
if err != nil {
94-
return n, fmt.Errorf("failed to scan row: %w", err)
95-
}
96-
//We can't use values... in Exec() below, because, in some cases, clickhouse
97-
//driver doesn't accept pointer to an argument instead of the arg itself.
98-
for i := range values {
99-
actuals[i] = valueRefs[i].Elem().Interface()
100-
}
101-
res, err := stmt.ExecContext(ctx, actuals...)
102-
if err != nil {
103-
return n, fmt.Errorf("failed to exec insert: %w", err)
104-
}
105-
rn, err := res.RowsAffected()
106-
if err != nil {
107-
return n, fmt.Errorf("failed to check rows affected: %w", err)
108-
}
109-
n += rn
110-
}
111-
err = tx.Commit()
112-
if err != nil {
113-
return n, fmt.Errorf("failed to commit transaction: %w", err)
114-
}
115-
return n, rows.Err()
116-
}

drivers/drivers.go

+12-8
Original file line numberDiff line numberDiff line change
@@ -540,16 +540,12 @@ func CopyWithInsert(placeholder func(int) string) func(ctx context.Context, db *
540540
if !strings.HasPrefix(strings.ToLower(query), "insert into") {
541541
leftParen := strings.IndexRune(table, '(')
542542
if leftParen == -1 {
543-
colStmt, err := db.PrepareContext(ctx, "SELECT * FROM "+table+" WHERE 1=0")
544-
if err != nil {
545-
return 0, fmt.Errorf("failed to prepare query to determine target table columns: %w", err)
546-
}
547-
defer colStmt.Close()
548-
colRows, err := colStmt.QueryContext(ctx)
543+
colRows, err := db.QueryContext(ctx, "SELECT * FROM "+table+" WHERE 1=0")
549544
if err != nil {
550545
return 0, fmt.Errorf("failed to execute query to determine target table columns: %w", err)
551546
}
552547
columns, err := colRows.Columns()
548+
_ = colRows.Close()
553549
if err != nil {
554550
return 0, fmt.Errorf("failed to fetch target table columns: %w", err)
555551
}
@@ -576,16 +572,24 @@ func CopyWithInsert(placeholder func(int) string) func(ctx context.Context, db *
576572
return 0, fmt.Errorf("failed to fetch source column types: %w", err)
577573
}
578574
values := make([]interface{}, clen)
575+
valueRefs := make([]reflect.Value, clen)
576+
actuals := make([]interface{}, clen)
579577
for i := 0; i < len(columnTypes); i++ {
580-
values[i] = reflect.New(columnTypes[i].ScanType()).Interface()
578+
valueRefs[i] = reflect.New(columnTypes[i].ScanType())
579+
values[i] = valueRefs[i].Interface()
581580
}
582581
var n int64
583582
for rows.Next() {
584583
err = rows.Scan(values...)
585584
if err != nil {
586585
return n, fmt.Errorf("failed to scan row: %w", err)
587586
}
588-
res, err := stmt.ExecContext(ctx, values...)
587+
//We can't use values... in Exec() below, because some drivers
588+
//don't accept pointer to an argument instead of the arg itself.
589+
for i := range values {
590+
actuals[i] = valueRefs[i].Elem().Interface()
591+
}
592+
res, err := stmt.ExecContext(ctx, actuals...)
589593
if err != nil {
590594
return n, fmt.Errorf("failed to exec insert: %w", err)
591595
}

drivers/drivers_test.go

+77-40
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,10 @@ var (
115115
DSN: "trino://test@localhost:%s/tpch/sf1",
116116
DockerPort: "8080/tcp",
117117
},
118+
"csvq": {
119+
// go test sets working directory to current package regardless of initial working directory
120+
DSN: "csvq://./testdata/csvq",
121+
},
118122
}
119123
cleanup bool
120124
)
@@ -144,30 +148,21 @@ func TestMain(m *testing.M) {
144148
}
145149

146150
for dbName, db := range dbs {
147-
var ok bool
148-
db.Resource, ok = pool.ContainerByName(db.RunOptions.Name)
149-
if !ok {
150-
buildOpts := &dt.BuildOptions{
151-
ContextDir: "./testdata/docker",
152-
BuildArgs: db.BuildArgs,
153-
}
154-
db.Resource, err = pool.BuildAndRunWithBuildOptions(buildOpts, db.RunOptions)
155-
if err != nil {
156-
log.Fatalf("Could not start %s: %s", dbName, err)
157-
}
158-
}
159-
160-
hostPort := db.Resource.GetPort(db.DockerPort)
161-
db.URL, err = dburl.Parse(fmt.Sprintf(db.DSN, hostPort))
151+
dsn, hostPort := getConnInfo(dbName, db, pool)
152+
db.URL, err = dburl.Parse(dsn)
162153
if err != nil {
163154
log.Fatalf("Failed to parse %s URL %s: %v", dbName, db.DSN, err)
164155
}
165156

166157
if len(db.Exec) != 0 {
158+
readyDSN := db.ReadyDSN
167159
if db.ReadyDSN == "" {
168-
db.ReadyDSN = db.DSN
160+
readyDSN = db.DSN
161+
}
162+
if hostPort != "" {
163+
readyDSN = fmt.Sprintf(db.ReadyDSN, hostPort)
169164
}
170-
readyURL, err := dburl.Parse(fmt.Sprintf(db.ReadyDSN, hostPort))
165+
readyURL, err := dburl.Parse(readyDSN)
171166
if err != nil {
172167
log.Fatalf("Failed to parse %s ready URL %s: %v", dbName, db.ReadyDSN, err)
173168
}
@@ -205,15 +200,46 @@ func TestMain(m *testing.M) {
205200
// You can't defer this because os.Exit doesn't care for defer
206201
if cleanup {
207202
for _, db := range dbs {
208-
if err := pool.Purge(db.Resource); err != nil {
209-
log.Fatal("Could not purge resource: ", err)
203+
if db.Resource != nil {
204+
if err := pool.Purge(db.Resource); err != nil {
205+
log.Fatal("Could not purge resource: ", err)
206+
}
210207
}
211208
}
212209
}
213210

214211
os.Exit(code)
215212
}
216213

214+
func getConnInfo(dbName string, db *Database, pool *dt.Pool) (string, string) {
215+
if db.RunOptions == nil {
216+
return db.DSN, ""
217+
}
218+
219+
var ok bool
220+
db.Resource, ok = pool.ContainerByName(db.RunOptions.Name)
221+
if ok && !db.Resource.Container.State.Running {
222+
err := db.Resource.Close()
223+
if err != nil {
224+
log.Fatalf("Failed to clean up stale container %s: %s", dbName, err)
225+
}
226+
ok = false
227+
}
228+
if !ok {
229+
buildOpts := &dt.BuildOptions{
230+
ContextDir: "./testdata/docker",
231+
BuildArgs: db.BuildArgs,
232+
}
233+
var err error
234+
db.Resource, err = pool.BuildAndRunWithBuildOptions(buildOpts, db.RunOptions)
235+
if err != nil {
236+
log.Fatalf("Failed to start %s: %s", dbName, err)
237+
}
238+
}
239+
hostPort := db.Resource.GetPort(db.DockerPort)
240+
return fmt.Sprintf(db.DSN, hostPort), hostPort
241+
}
242+
217243
func TestWriter(t *testing.T) {
218244
type testFunc struct {
219245
label string
@@ -467,37 +493,48 @@ func TestCopy(t *testing.T) {
467493
src: "select first_name, last_name, address_id, picture, email, store_id, active, username, password, last_update from staff",
468494
dest: "staff_copy(first_name, last_name, address_id, picture, email, store_id, active, username, password, last_update)",
469495
},
496+
{
497+
dbName: "csvq",
498+
setupQueries: []setupQuery{
499+
{query: "CREATE TABLE IF NOT EXISTS staff_copy AS SELECT * FROM `staff.csv` WHERE 0=1", check: true},
500+
},
501+
src: "select first_name, last_name, address_id, email, store_id, active, username, password, last_update from staff",
502+
dest: "staff_copy",
503+
},
470504
}
471505
for _, test := range testCases {
472506
db, ok := dbs[test.dbName]
473507
if !ok {
474508
continue
475509
}
476510

477-
// TODO test copy from a different DB, maybe csvq?
478-
// TODO test copy from same DB
511+
t.Run(test.dbName, func(t *testing.T) {
512+
513+
// TODO test copy from a different DB, maybe csvq?
514+
// TODO test copy from same DB
479515

480-
for _, q := range test.setupQueries {
481-
_, err := db.DB.Exec(q.query)
482-
if q.check && err != nil {
483-
log.Fatalf("Failed to run setup query `%s`: %v", q.query, err)
516+
for _, q := range test.setupQueries {
517+
_, err := db.DB.Exec(q.query)
518+
if q.check && err != nil {
519+
t.Fatalf("Failed to run setup query `%s`: %v", q.query, err)
520+
}
521+
}
522+
rows, err := pg.DB.Query(test.src)
523+
if err != nil {
524+
t.Fatalf("Could not get rows to copy: %v", err)
484525
}
485-
}
486-
rows, err := pg.DB.Query(test.src)
487-
if err != nil {
488-
log.Fatalf("Could not get rows to copy: %v", err)
489-
}
490526

491-
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
492-
defer cancel()
493-
var rlen int64 = 1
494-
n, err := drivers.Copy(ctx, db.URL, nil, nil, rows, test.dest)
495-
if err != nil {
496-
log.Fatalf("Could not copy: %v", err)
497-
}
498-
if n != rlen {
499-
log.Fatalf("Expected to copy %d rows but got %d", rlen, n)
500-
}
527+
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
528+
defer cancel()
529+
var rlen int64 = 1
530+
n, err := drivers.Copy(ctx, db.URL, nil, nil, rows, test.dest)
531+
if err != nil {
532+
t.Fatalf("Could not copy: %v", err)
533+
}
534+
if n != rlen {
535+
t.Fatalf("Expected to copy %d rows but got %d", rlen, n)
536+
}
537+
})
501538
}
502539
}
503540

drivers/testdata/csvq/.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
*_copy

drivers/testdata/csvq/staff.csv

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
first_name,last_name,address_id,email,store_id,active,username,password,last_update
2+
John,Doe,1,[email protected],1,true,jdoe,abc,2024-05-10T08:12:05.46875Z

0 commit comments

Comments
 (0)