Skip to content

Commit 7663e7e

Browse files
authored
Fixes in \copy implementation for Clickhouse (#457)
* Fixes in \copy implementation for Clickhouse * fixup: remove TODOs (pr feedback)
1 parent d140a28 commit 7663e7e

File tree

3 files changed

+162
-3
lines changed

3 files changed

+162
-3
lines changed

drivers/clickhouse/clickhouse.go

+76-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,10 @@
55
package clickhouse
66

77
import (
8+
"context"
89
"database/sql"
10+
"fmt"
11+
"reflect"
912
"strconv"
1013
"strings"
1114

@@ -35,7 +38,79 @@ func init() {
3538
}
3639
return false
3740
},
38-
Copy: drivers.CopyWithInsert(func(int) string { return "?" }),
41+
Copy: CopyWithInsert,
3942
NewMetadataReader: NewMetadataReader,
4043
})
4144
}
45+
46+
// CopyWithInsert copies every row of rows into a ClickHouse table using a
// prepared INSERT statement executed inside a single transaction.
//
// table may take one of three forms:
//   - a bare table name ("db.tbl"): the target column list is discovered by
//     running "SELECT * FROM <table> WHERE 1=0" and appended to the name;
//   - a table name with an explicit column list ("db.tbl(a, b)");
//   - a complete statement starting with "insert into" (case-insensitive),
//     which is used verbatim.
//
// The number of placeholders in a generated statement equals the number of
// columns in rows.
//
// On success it returns the sum of RowsAffected reported by the individual
// inserts. On any failure the transaction is rolled back and 0 is returned
// together with the error, since nothing has been committed.
func CopyWithInsert(ctx context.Context, db *sql.DB, rows *sql.Rows, table string) (int64, error) {
	columns, err := rows.Columns()
	if err != nil {
		return 0, fmt.Errorf("failed to fetch source rows columns: %w", err)
	}
	clen := len(columns)
	if clen == 0 {
		// Without this guard strings.Repeat below would panic on a negative count.
		return 0, fmt.Errorf("source rows have no columns")
	}

	query := table
	if !strings.HasPrefix(strings.ToLower(query), "insert into") {
		if !strings.ContainsRune(table, '(') {
			// No explicit column list: ask the target table for its columns
			// with a query that returns no rows.
			colRows, err := db.QueryContext(ctx, "SELECT * FROM "+table+" WHERE 1=0")
			if err != nil {
				return 0, fmt.Errorf("failed to execute query to determine target table columns: %w", err)
			}
			targetColumns, err := colRows.Columns()
			// Only the column names are needed; a close error is not actionable here.
			_ = colRows.Close()
			if err != nil {
				return 0, fmt.Errorf("failed to fetch target table columns: %w", err)
			}
			table += "(" + strings.Join(targetColumns, ", ") + ")"
		}
		query = "INSERT INTO " + table + " VALUES (" + strings.Repeat("?, ", clen-1) + "?)"
	}

	tx, err := db.BeginTx(ctx, nil)
	if err != nil {
		return 0, fmt.Errorf("failed to begin transaction: %w", err)
	}
	// Release the transaction on every early return. After a successful
	// Commit this is a no-op that reports sql.ErrTxDone, which is ignored.
	defer func() { _ = tx.Rollback() }()

	stmt, err := tx.PrepareContext(ctx, query)
	if err != nil {
		return 0, fmt.Errorf("failed to prepare insert query: %w", err)
	}
	defer stmt.Close()

	columnTypes, err := rows.ColumnTypes()
	if err != nil {
		return 0, fmt.Errorf("failed to fetch source column types: %w", err)
	}

	// values holds pointers handed to Scan; valueRefs keeps the reflect
	// handles so the pointed-to values can be extracted for Exec.
	values := make([]interface{}, clen)
	valueRefs := make([]reflect.Value, clen)
	actuals := make([]interface{}, clen)
	for i := range columnTypes {
		valueRefs[i] = reflect.New(columnTypes[i].ScanType())
		values[i] = valueRefs[i].Interface()
	}

	var n int64
	for rows.Next() {
		if err := rows.Scan(values...); err != nil {
			return 0, fmt.Errorf("failed to scan row: %w", err)
		}
		// We can't use values... in Exec() below because, in some cases, the
		// clickhouse driver doesn't accept a pointer to an argument instead of
		// the argument itself.
		for i := range valueRefs {
			actuals[i] = valueRefs[i].Elem().Interface()
		}
		res, err := stmt.ExecContext(ctx, actuals...)
		if err != nil {
			return 0, fmt.Errorf("failed to exec insert: %w", err)
		}
		rn, err := res.RowsAffected()
		if err != nil {
			return 0, fmt.Errorf("failed to check rows affected: %w", err)
		}
		n += rn
	}
	// Check for a source read failure before committing so that a partially
	// read result set is not persisted.
	if err := rows.Err(); err != nil {
		return 0, fmt.Errorf("failed to read source rows: %w", err)
	}
	if err := tx.Commit(); err != nil {
		return 0, fmt.Errorf("failed to commit transaction: %w", err)
	}
	return n, nil
}

drivers/clickhouse/clickhouse_test.go

+78-2
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,25 @@
11
package clickhouse_test
22

33
import (
4+
"context"
45
"database/sql"
56
"flag"
67
"fmt"
8+
"github.com/xo/dburl"
9+
"github.com/xo/usql/drivers"
710
"log"
811
"os"
912
"path/filepath"
1013
"testing"
14+
"time"
1115

1216
dt "github.com/ory/dockertest/v3"
1317
"github.com/xo/usql/drivers/clickhouse"
1418
"github.com/xo/usql/drivers/metadata"
1519
"github.com/yookoala/realpath"
20+
21+
_ "github.com/xo/usql/drivers/csvq"
22+
_ "github.com/xo/usql/drivers/moderncsqlite"
1623
)
1724

1825
// db is the database connection.
@@ -59,7 +66,7 @@ func doMain(m *testing.M, cleanup bool) (int, error) {
5966
if cleanup {
6067
defer func() {
6168
if err := pool.Purge(db.res); err != nil {
62-
fmt.Fprintf(os.Stderr, "error: could not purge resoure: %v\n", err)
69+
fmt.Fprintf(os.Stderr, "error: could not purge resource: %v\n", err)
6370
}
6471
}()
6572
}
@@ -85,7 +92,7 @@ func TestSchemas(t *testing.T) {
8592
if err != nil {
8693
t.Fatalf("could not read schemas: %v", err)
8794
}
88-
checkNames(t, "schema", res, "default", "system", "tutorial", "tutorial_unexpected", "INFORMATION_SCHEMA", "information_schema")
95+
checkNames(t, "schema", res, "default", "system", "tutorial", "tutorial_unexpected", "INFORMATION_SCHEMA", "information_schema", "copy_test")
8996
}
9097

9198
func TestTables(t *testing.T) {
@@ -119,6 +126,75 @@ func TestColumns(t *testing.T) {
119126
checkNames(t, "column", res, colNames()...)
120127
}
121128

129+
func TestCopy(t *testing.T) {
130+
// Tests with csvq source DB. That driver doesn't support ScanType()
131+
for _, destTableSpec := range []string{
132+
"copy_test.dest",
133+
"copy_test.dest(StringCol, NumCol)",
134+
"insert into copy_test.dest values(?, ?)",
135+
} {
136+
t.Run("csvq_"+destTableSpec, func(t *testing.T) {
137+
testCopy(t, destTableSpec, "csvq:.")
138+
})
139+
}
140+
// Test with a driver that supports ScanType()
141+
t.Run("sqlite", func(t *testing.T) {
142+
testCopy(t, "copy_test.dest", "moderncsqlite://:memory:")
143+
})
144+
}
145+
146+
func testCopy(t *testing.T, destTableSpec string, sourceDbUrlStr string) {
147+
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
148+
defer cancel()
149+
_, err := db.db.ExecContext(ctx, "truncate table copy_test.dest")
150+
if err != nil {
151+
t.Fatalf("could not truncate copy_test table: %v", err)
152+
}
153+
// Prepare copy destination URL
154+
port := db.res.GetPort("9000/tcp")
155+
dbUrlStr := fmt.Sprintf("clickhouse://127.0.0.1:%s", port)
156+
dbUrl, err := dburl.Parse(dbUrlStr)
157+
if err != nil {
158+
t.Fatalf("could not parse clickhouse url %s: %v", dbUrlStr, err)
159+
}
160+
// Prepare source data
161+
sourceDbUrl, err := dburl.Parse(sourceDbUrlStr)
162+
if err != nil {
163+
t.Fatalf("could not parse source DB url %s: %v", sourceDbUrlStr, err)
164+
}
165+
sourceDb, err := drivers.Open(ctx, sourceDbUrl, nil, nil)
166+
if err != nil {
167+
t.Fatalf("could not open sourceDb: %v", err)
168+
}
169+
defer sourceDb.Close()
170+
rows, err := sourceDb.QueryContext(ctx, "select 'string', 1")
171+
if err != nil {
172+
t.Fatalf("could not retrieve source rows: %v", err)
173+
}
174+
// Do Copy, ignoring copied rows count because clickhouse driver doesn't report RowsAffected
175+
_, err = drivers.Copy(ctx, dbUrl, nil, nil, rows, destTableSpec)
176+
if err != nil {
177+
t.Fatalf("copy failed: %v", err)
178+
}
179+
rows, err = db.db.QueryContext(ctx, "select StringCol, NumCol from copy_test.dest")
180+
if err != nil {
181+
t.Fatalf("failed to query: %v", err)
182+
}
183+
defer rows.Close()
184+
var copiedString string
185+
var copiedNum int
186+
if !rows.Next() {
187+
t.Fatalf("nothing copied")
188+
}
189+
err = rows.Scan(&copiedString, &copiedNum)
190+
if err != nil {
191+
t.Fatalf("could not read copied data: %v", err)
192+
}
193+
if copiedString != "string" || copiedNum != 1 {
194+
t.Fatalf("copied data differs: %s != string, %d != 1", copiedString, copiedNum)
195+
}
196+
}
197+
122198
func checkNames(t *testing.T, typ string, res interface{ Next() bool }, exp ...string) {
123199
n := make(map[string]bool)
124200
for _, s := range exp {

drivers/clickhouse/testdata/clickhouse.sql

+8
Original file line numberDiff line numberDiff line change
@@ -340,3 +340,11 @@ CREATE TABLE tutorial_unexpected.hits_v1 (
340340
)
341341
ENGINE = MergeTree()
342342
ORDER BY (Unexpected);
343+
344+
-- Database and table written to and read back by the Go copy test
-- (TestCopy / testCopy): one String column and one UInt32 column.
CREATE DATABASE copy_test;
345+
CREATE TABLE copy_test.dest (
346+
StringCol String,
347+
NumCol UInt32
348+
)
349+
ENGINE = MergeTree()
350+
ORDER BY (StringCol);

0 commit comments

Comments
 (0)