Skip to content

Commit 08f10f9

Browse files
authored
feat: support transaction (#102)
* feat: support transaction * add begin, commit ,rollback * fix serverinfo * fix unit tests * fix * f * settest suit * fix tests * add test for commit, rollback * fix * add tests * fix * fix * fix commit
1 parent 0f677c5 commit 08f10f9

File tree

8 files changed

+153
-55
lines changed

8 files changed

+153
-55
lines changed

connection.go

Lines changed: 32 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,15 @@ const (
2525
)
2626

2727
type DatabendConn struct {
28-
ctx context.Context
29-
cfg *Config
30-
cancel context.CancelFunc
31-
closed int32
32-
stmts []*databendStmt
33-
logger *log.Logger
34-
rest *APIClient
35-
commit func() error
28+
ctx context.Context
29+
cfg *Config
30+
cancel context.CancelFunc
31+
closed int32
32+
stmts []*databendStmt
33+
logger *log.Logger
34+
rest *APIClient
35+
batchMode bool
36+
batchInsert func() error
3637
}
3738

3839
func (dc *DatabendConn) exec(ctx context.Context, query string, args ...driver.Value) (driver.Result, error) {
@@ -95,11 +96,23 @@ func (dc *DatabendConn) query(ctx context.Context, query string, args ...driver.
9596
return newNextRows(ctx, dc, r0)
9697
}
9798

98-
//func (dc *DatabendConn) Begin() (driver.Tx, error) {
99-
// return dc.BeginTx(dc.ctx, driver.TxOptions{})
100-
//}
99+
func (dc *DatabendConn) Begin() (driver.Tx, error) {
101100

102-
func (dc *DatabendConn) Begin() (driver.Tx, error) { return dc, nil }
101+
return dc.BeginTx(dc.ctx, driver.TxOptions{})
102+
}
103+
104+
func (dc *DatabendConn) BeginTx(
105+
ctx context.Context,
106+
opts driver.TxOptions) (
107+
driver.Tx, error) {
108+
if dc.rest == nil {
109+
return nil, driver.ErrBadConn
110+
}
111+
if _, err := dc.exec(ctx, "BEGIN"); err != nil {
112+
return nil, err
113+
}
114+
return &databendTx{dc}, nil
115+
}
103116

104117
func (dc *DatabendConn) cleanup() {
105118
// must flush log buffer while the process is running.
@@ -128,7 +141,8 @@ func (dc *DatabendConn) prepare(ctx context.Context, query string) (*databendStm
128141
if err != nil {
129142
return nil, err
130143
}
131-
dc.commit = batch.BatchInsert
144+
dc.batchInsert = batch.BatchInsert
145+
dc.batchMode = true
132146
stmt := &databendStmt{
133147
dc: dc,
134148
query: query,
@@ -193,22 +207,15 @@ func (dc *DatabendConn) QueryContext(ctx context.Context, query string, args []d
193207
return dc.query(ctx, query, values...)
194208
}
195209

196-
// Commit applies prepared statement if it exists
197-
func (dc *DatabendConn) Commit() (err error) {
198-
if dc.commit == nil {
210+
// ExecuteBatch applies batch prepared statement if it exists
211+
func (dc *DatabendConn) ExecuteBatch() (err error) {
212+
if dc.batchInsert == nil {
199213
return nil
200214
}
201215
defer func() {
202-
dc.commit = nil
216+
dc.batchInsert = nil
203217
}()
204-
return dc.commit()
205-
}
206-
207-
// Rollback cleans prepared statement
208-
func (dc *DatabendConn) Rollback() error {
209-
dc.commit = nil
210-
dc.Close()
211-
return nil
218+
return dc.batchInsert()
212219
}
213220

214221
// checkQueryID checks if query_id exists in context, if not, generate a new one

query.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,10 +89,20 @@ type SessionState struct {
8989
// KeepServerSessionSecs uint64 `json:"keep_server_session_secs,omitempty"`
9090

9191
Settings map[string]string `json:"settings,omitempty"`
92+
93+
// txn
94+
TxnState string `json:"txn_state,omitempty"`
95+
LastServerInfo ServerInfo `json:"last_server_info,omitempty"`
96+
LastQueryIds []string `json:"last_query_ids,omitempty"`
9297
}
9398

9499
type StageAttachmentConfig struct {
95100
Location string `json:"location"`
96101
FileFormatOptions map[string]string `json:"file_format_options,omitempty"`
97102
CopyOptions map[string]string `json:"copy_options,omitempty"`
98103
}
104+
105+
type ServerInfo struct {
106+
Id string `json:"id"`
107+
StartTime string `json:"start_time"`
108+
}

query_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ func Test_SessionState(t *testing.T) {
1717
}
1818
buf, err := json.Marshal(ss)
1919
require.NoError(t, err)
20-
assert.Equal(t, `{"database":"db1"}`, string(buf))
20+
assert.Equal(t, "{\"database\":\"db1\",\"last_server_info\":{\"id\":\"\",\"start_time\":\"\"}}", string(buf))
2121

2222
buf = []byte(`{"database":"db1", "secondary_roles": []}`)
2323
err = json.Unmarshal(buf, ss)

stmt.go

Lines changed: 9 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,25 @@
11
package godatabend
22

33
import (
4-
"context"
54
"database/sql/driver"
6-
"regexp"
75

86
"github.com/pkg/errors"
97
)
108

11-
var (
12-
splitInsertRe = regexp.MustCompile(`(?si)(.+\s*VALUES)\s*(\(.+\))`)
13-
)
14-
159
type databendStmt struct {
16-
dc *DatabendConn
17-
closed int32
18-
prefix string
19-
pattern string
20-
index []int
21-
batchMode bool
22-
args [][]driver.Value
23-
query string
24-
batch Batch
10+
dc *DatabendConn
11+
closed int32
12+
prefix string
13+
pattern string
14+
index []int
15+
args [][]driver.Value
16+
query string
17+
batch Batch
2518
}
2619

2720
func (stmt *databendStmt) Close() error {
2821
logger.WithContext(stmt.dc.ctx).Infoln("Stmt.Close")
29-
return nil
22+
return stmt.dc.Close()
3023
}
3124

3225
func (stmt *databendStmt) NumInput() int {
@@ -41,12 +34,6 @@ func (stmt *databendStmt) Exec(args []driver.Value) (driver.Result, error) {
4134
return nil, err
4235
}
4336

44-
//2. /v1/upload_to_stage csv file
45-
46-
// 3. copy into db.table from @~/csv
47-
48-
// 4. delete the file ?
49-
5037
return driver.RowsAffected(0), nil
5138
}
5239

@@ -62,8 +49,3 @@ func (stmt *databendStmt) Query(args []driver.Value) (driver.Rows, error) {
6249
logger.WithContext(stmt.dc.ctx).Infoln("Stmt.Query")
6350
return nil, errors.New("only Exec method supported in batch mode")
6451
}
65-
66-
func (stmt *databendStmt) commit(ctx context.Context) error {
67-
logger.WithContext(stmt.dc.ctx).Infoln("Stmt Commit")
68-
return stmt.batch.BatchInsert()
69-
}

tests/docker-compose.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ services:
77
volumes:
88
- ./data:/data
99
databend:
10-
image: docker.io/datafuselabs/databend
10+
image: datafuselabs/databend:v1.2.360-nightly
1111
environment:
1212
- QUERY_DEFAULT_USER=databend
1313
- QUERY_DEFAULT_PASSWORD=databend

tests/main_test.go

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,15 @@ const (
2929
createTable2 = `create table %s (a string);`
3030
)
3131

32+
var (
33+
dsn = "http://databend:databend@localhost:8000?presigned_url_disabled=true"
34+
)
35+
36+
func init() {
37+
dsn = os.Getenv("TEST_DATABEND_DSN")
38+
//dsn = "http://databend:databend@localhost:8000?presigned_url_disabled=true"
39+
}
40+
3241
func TestDatabendSuite(t *testing.T) {
3342
suite.Run(t, new(DatabendTestSuite))
3443
}
@@ -44,7 +53,6 @@ type DatabendTestSuite struct {
4453
func (s *DatabendTestSuite) SetupSuite() {
4554
var err error
4655

47-
dsn := os.Getenv("TEST_DATABEND_DSN")
4856
s.NotEmpty(dsn)
4957
s.db, err = sql.Open("databend", dsn)
5058
s.Nil(err)
@@ -80,6 +88,7 @@ func (s *DatabendTestSuite) SetupTest() {
8088

8189
func (s *DatabendTestSuite) TearDownTest() {
8290
// t := s.T()
91+
s.SetupSuite()
8392

8493
// t.Logf("teardown test with table %s", s.table)
8594
_, err := s.db.Exec(fmt.Sprintf("DROP TABLE %s", s.table))
@@ -251,6 +260,52 @@ func (s *DatabendTestSuite) TestQueryNull() {
251260
s.r.NoError(rows.Close())
252261
}
253262

263+
func (s *DatabendTestSuite) TestTransactionCommit() {
264+
tx, err := s.db.Begin()
265+
s.r.Nil(err)
266+
267+
_, err = tx.Exec(fmt.Sprintf("INSERT INTO %s (i64) VALUES (?)", s.table), int64(1))
268+
s.r.Nil(err)
269+
270+
err = tx.Commit()
271+
s.r.Nil(err)
272+
273+
rows, err := s.db.Query(fmt.Sprintf("SELECT * FROM %s", s.table))
274+
s.r.Nil(err)
275+
276+
result, err := scanValues(rows)
277+
s.r.Nil(err)
278+
s.r.Equal([][]interface{}{[]interface{}{"1", "NULL", "NULL", "NULL", "NULL", "NULL", "NULL", "NULL", "NULL"}}, result)
279+
280+
s.r.NoError(rows.Close())
281+
}
282+
283+
func (s *DatabendTestSuite) TestTransactionRollback() {
284+
tx, err := s.db.Begin()
285+
s.r.Nil(err)
286+
287+
_, err = tx.Exec(fmt.Sprintf("INSERT INTO %s (i64) VALUES (?)", s.table), int64(1))
288+
s.r.Nil(err)
289+
rows, err := s.db.Query(fmt.Sprintf("SELECT * FROM %s", s.table))
290+
s.r.Nil(err)
291+
292+
result, err := scanValues(rows)
293+
s.r.Nil(err)
294+
s.r.Equal([][]interface{}(nil), result)
295+
296+
err = tx.Rollback()
297+
s.r.Nil(err)
298+
299+
rows, err = s.db.Query(fmt.Sprintf("SELECT * FROM %s", s.table))
300+
s.r.Nil(err)
301+
302+
result, err = scanValues(rows)
303+
s.r.Nil(err)
304+
s.r.Empty(result)
305+
306+
s.r.NoError(rows.Close())
307+
}
308+
254309
func scanValues(rows *sql.Rows) (interface{}, error) {
255310
var err error
256311
var result [][]interface{}

tests/session_test.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
)
1010

1111
func (s *DatabendTestSuite) TestChangeDatabase() {
12+
s.SetupSuite()
1213
r := require.New(s.T())
1314
var result string
1415

transaction.go

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
package godatabend
2+
3+
import (
4+
"database/sql/driver"
5+
)
6+
7+
type databendTx struct {
8+
dc *DatabendConn
9+
}
10+
11+
func (tx *databendTx) Commit() (err error) {
12+
if tx.dc == nil || tx.dc.rest == nil {
13+
return driver.ErrBadConn
14+
}
15+
defer func() {
16+
tx.dc.batchInsert = nil
17+
}()
18+
if tx.dc.batchMode && tx.dc.batchInsert != nil {
19+
err = tx.dc.batchInsert()
20+
if err != nil {
21+
return
22+
}
23+
}
24+
// complicity with old server version
25+
if tx.dc.rest.sessionState.TxnState != "" {
26+
_, err = tx.dc.exec(tx.dc.ctx, "COMMIT")
27+
if err != nil {
28+
return
29+
}
30+
}
31+
return
32+
}
33+
34+
func (tx *databendTx) Rollback() (err error) {
35+
if tx.dc == nil || tx.dc.rest == nil {
36+
return driver.ErrBadConn
37+
}
38+
_, err = tx.dc.exec(tx.dc.ctx, "ROLLBACK")
39+
if err != nil {
40+
return
41+
}
42+
return
43+
}

0 commit comments

Comments
 (0)