Skip to content

Commit a367ec1

Browse files
murfffikenshaw
authored andcommitted
fix: strip trailing semicolon in cosmosdb and unify implementation
cosmosdb, like several other databases, needs trailing semicolons to be removed from the statement before it is passed to the driver. This fix applies that to cosmosdb and makes the other databases with identical processing use the same implementation - presto, trino, athena, SAP ASE. Drivers that strip trailing semicolon as part of other processing are not included in the refactoring to keep scope small. Testing: will run usql with trino, and cosmosdb
1 parent 3719f36 commit a367ec1

File tree

6 files changed

+16
-33
lines changed

6 files changed

+16
-33
lines changed

drivers/athena/athena.go

+1-8
Original file line numberDiff line numberDiff line change
@@ -5,22 +5,15 @@ package athena
55

66
import (
77
"context"
8-
"regexp"
98

109
_ "github.com/uber/athenadriver/go" // DRIVER: awsathena
11-
"github.com/xo/dburl"
1210
"github.com/xo/usql/drivers"
1311
)
1412

1513
func init() {
16-
endRE := regexp.MustCompile(`;?\s*$`)
1714
drivers.Register("awsathena", drivers.Driver{
1815
AllowMultilineComments: true,
19-
Process: func(_ *dburl.URL, prefix string, sqlstr string) (string, string, bool, error) {
20-
sqlstr = endRE.ReplaceAllString(sqlstr, "")
21-
typ, q := drivers.QueryExecType(prefix, sqlstr)
22-
return typ, sqlstr, q, nil
23-
},
16+
Process: drivers.StripTrailingSemicolon,
2417
Version: func(ctx context.Context, db drivers.DB) (string, error) {
2518
var ver string
2619
err := db.QueryRowContext(

drivers/cosmos/cosmos.go

+3-1
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,7 @@ import (
99
)
1010

1111
func init() {
12-
drivers.Register("cosmos", drivers.Driver{})
12+
drivers.Register("cosmos", drivers.Driver{
13+
Process: drivers.StripTrailingSemicolon,
14+
}, "gocosmos")
1315
}

drivers/drivers.go

+9
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"fmt"
1010
"io"
1111
"reflect"
12+
"regexp"
1213
"strings"
1314
"time"
1415
"unicode"
@@ -612,3 +613,11 @@ func CopyWithInsert(placeholder func(int) string) func(ctx context.Context, db *
612613
func init() {
613614
dburl.OdbcIgnoreQueryPrefixes = []string{"usql_"}
614615
}
616+
617+
var endRE = regexp.MustCompile(`;?\s*$`)
618+
619+
func StripTrailingSemicolon(_ *dburl.URL, prefix string, sqlstr string) (string, string, bool, error) {
620+
sqlstr = endRE.ReplaceAllString(sqlstr, "")
621+
typ, q := QueryExecType(prefix, sqlstr)
622+
return typ, sqlstr, q, nil
623+
}

drivers/presto/presto.go

+1-8
Original file line numberDiff line numberDiff line change
@@ -5,22 +5,15 @@ package presto
55

66
import (
77
"context"
8-
"regexp"
98

109
_ "github.com/prestodb/presto-go-client/presto" // DRIVER
11-
"github.com/xo/dburl"
1210
"github.com/xo/usql/drivers"
1311
)
1412

1513
func init() {
16-
endRE := regexp.MustCompile(`;?\s*$`)
1714
drivers.Register("presto", drivers.Driver{
1815
AllowMultilineComments: true,
19-
Process: func(_ *dburl.URL, prefix string, sqlstr string) (string, string, bool, error) {
20-
sqlstr = endRE.ReplaceAllString(sqlstr, "")
21-
typ, q := drivers.QueryExecType(prefix, sqlstr)
22-
return typ, sqlstr, q, nil
23-
},
16+
Process: drivers.StripTrailingSemicolon,
2417
Version: func(ctx context.Context, db drivers.DB) (string, error) {
2518
var ver string
2619
err := db.QueryRowContext(

drivers/sapase/sapase.go

+1-8
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,14 @@ package sapase
66
import (
77
"context"
88
"errors"
9-
"regexp"
109
"strconv"
1110
"strings"
1211

1312
"github.com/thda/tds" // DRIVER: tds
14-
"github.com/xo/dburl"
1513
"github.com/xo/usql/drivers"
1614
)
1715

1816
func init() {
19-
endRE := regexp.MustCompile(`;?\s*$`)
2017
drivers.Register("tds", drivers.Driver{
2118
AllowMultilineComments: true,
2219
RequirePreviousPassword: true,
@@ -49,10 +46,6 @@ func init() {
4946
IsPasswordErr: func(err error) bool {
5047
return strings.Contains(err.Error(), "Login failed")
5148
},
52-
Process: func(_ *dburl.URL, prefix string, sqlstr string) (string, string, bool, error) {
53-
sqlstr = endRE.ReplaceAllString(sqlstr, "")
54-
typ, q := drivers.QueryExecType(prefix, sqlstr)
55-
return typ, sqlstr, q, nil
56-
},
49+
Process: drivers.StripTrailingSemicolon,
5750
})
5851
}

drivers/trino/trino.go

+1-8
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,14 @@ package trino
66
import (
77
"context"
88
"io"
9-
"regexp"
109

1110
_ "github.com/trinodb/trino-go-client/trino" // DRIVER
12-
"github.com/xo/dburl"
1311
"github.com/xo/usql/drivers"
1412
"github.com/xo/usql/drivers/metadata"
1513
infos "github.com/xo/usql/drivers/metadata/informationschema"
1614
)
1715

1816
func init() {
19-
endRE := regexp.MustCompile(`;?\s*$`)
2017
newReader := func(db drivers.DB, opts ...metadata.ReaderOption) metadata.Reader {
2118
ir := infos.New(
2219
infos.WithPlaceholder(func(int) string { return "?" }),
@@ -40,11 +37,7 @@ func init() {
4037
}
4138
drivers.Register("trino", drivers.Driver{
4239
AllowMultilineComments: true,
43-
Process: func(_ *dburl.URL, prefix string, sqlstr string) (string, string, bool, error) {
44-
sqlstr = endRE.ReplaceAllString(sqlstr, "")
45-
typ, q := drivers.QueryExecType(prefix, sqlstr)
46-
return typ, sqlstr, q, nil
47-
},
40+
Process: drivers.StripTrailingSemicolon,
4841
Version: func(ctx context.Context, db drivers.DB) (string, error) {
4942
var ver string
5043
err := db.QueryRowContext(

0 commit comments

Comments
 (0)