Skip to content

Commit fcc3fd6

Browse files
committed
Fixing issue with vertica driver
1 parent 9494649 commit fcc3fd6

File tree

1 file changed

+32
-74
lines changed

1 file changed

+32
-74
lines changed

drivers/vertica/vertica.go

+32-74
Original file line numberDiff line numberDiff line change
@@ -8,66 +8,25 @@ import (
88
"crypto/tls"
99
"crypto/x509"
1010
"database/sql"
11-
"fmt"
11+
"errors"
1212
"io"
1313
"net/url"
1414
"os"
1515
"regexp"
1616
"strings"
1717

18-
vertigo "github.com/vertica/vertica-sql-go" // DRIVER
18+
vertica "github.com/vertica/vertica-sql-go" // DRIVER
1919
"github.com/vertica/vertica-sql-go/logger"
2020
"github.com/xo/dburl"
2121
"github.com/xo/usql/drivers"
2222
)
2323

2424
func init() {
25-
// List of custom TLS configurations that may be applied via query in connection string.
26-
customTlsConfig := map[string]func(string, *tls.Config) error{
27-
"ca_path": func(queryValue string, c *tls.Config) error {
28-
rootCertPool := x509.NewCertPool()
29-
30-
pem, err := os.ReadFile(queryValue)
31-
if err != nil {
32-
return err
33-
}
34-
35-
if ok := rootCertPool.AppendCertsFromPEM(pem); !ok {
36-
return fmt.Errorf("error: failed to append pem to cert pool")
37-
}
38-
39-
c.RootCAs = rootCertPool
40-
41-
return nil
42-
},
43-
}
44-
45-
hasCustomTlsConfig := func(queries url.Values) bool {
46-
for key := range customTlsConfig {
47-
if queries.Has(key) {
48-
return true
49-
}
50-
}
51-
52-
return false
53-
}
54-
55-
applyCustomTlsConfig := func(queries url.Values, tlsConfig *tls.Config) error {
56-
for key, configFunction := range customTlsConfig {
57-
if queries.Has(key) {
58-
if err := configFunction(queries.Get(key), tlsConfig); err != nil {
59-
return err
60-
}
61-
}
62-
}
63-
64-
return nil
65-
}
66-
6725
// turn off logging
6826
if os.Getenv("VERTICA_SQL_GO_LOG_LEVEL") == "" {
6927
logger.SetLogLevel(logger.NONE)
7028
}
29+
7130
errCodeRE := regexp.MustCompile(`(?i)^\[([0-9a-z]+)\]\s+(.+)`)
7231
drivers.Register("vertica", drivers.Driver{
7332
AllowDollar: true,
@@ -80,42 +39,28 @@ func init() {
8039
return ver, nil
8140
},
8241
Open: func(_ context.Context, u *dburl.URL, stdout, stderr func() io.Writer) (func(string, string) (*sql.DB, error), error) {
83-
return func(_, _ string) (*sql.DB, error) {
84-
queries := u.Query()
85-
86-
if hasCustomTlsConfig(queries) {
87-
if queries.Get("tlsmode") != "server-strict" {
88-
configNames := []string{}
89-
90-
for key := range customTlsConfig {
91-
configNames = append(configNames, key)
92-
}
93-
94-
return nil, fmt.Errorf(fmt.Sprintf("error: when custom tls configurations are set (%s), tlsmode must be set to server-strict", strings.Join(configNames, ",")))
42+
return func(driver, dsn string) (*sql.DB, error) {
43+
u, err := url.Parse(dsn)
44+
if err != nil {
45+
return nil, err
46+
}
47+
q := u.Query()
48+
if name := q.Get("ca_path"); name != "" {
49+
if q.Get("tlsmode") != "server-strict" {
50+
return nil, errors.New("tlsmode must be set to server-strict: ca_path is set")
9551
}
96-
97-
tlsConfig := &tls.Config{ServerName: u.Hostname()}
98-
99-
if err := applyCustomTlsConfig(queries, tlsConfig); err != nil {
52+
cfg := &tls.Config{
53+
ServerName: u.Hostname(),
54+
}
55+
if err := addCA(name, cfg); err != nil {
10056
return nil, err
10157
}
102-
103-
if err := vertigo.RegisterTLSConfig("custom_tls_config", tlsConfig); err != nil {
58+
if err := vertica.RegisterTLSConfig("custom_tls_config", cfg); err != nil {
10459
return nil, err
10560
}
106-
107-
queries.Set("tlsmode", "custom_tls_config")
108-
}
109-
110-
dsn := url.URL{
111-
Scheme: u.Driver,
112-
User: u.User,
113-
Host: u.Host,
114-
Path: u.Path,
115-
RawQuery: queries.Encode(),
61+
q.Set("tlsmode", "custom_tls_config")
11662
}
117-
118-
return sql.Open(u.Driver, dsn.String())
63+
return sql.Open(driver, u.String())
11964
}, nil
12065
},
12166
ChangePassword: func(db drivers.DB, user, newpw, _ string) error {
@@ -134,3 +79,16 @@ func init() {
13479
},
13580
})
13681
}
82+
83+
// addCA adds the specified file name as a ca to the tls config.
84+
func addCA(name string, cfg *tls.Config) error {
85+
pool := x509.NewCertPool()
86+
switch pem, err := os.ReadFile(name); {
87+
case err != nil:
88+
return err
89+
case !pool.AppendCertsFromPEM(pem):
90+
return errors.New("failed to append pem to cert pool")
91+
}
92+
cfg.RootCAs = pool
93+
return nil
94+
}

0 commit comments

Comments
 (0)