Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ Other supported formats are listed below.
* `multisubnetfailover`
* `true` (Default) Client attempt to connect to all IPs simultaneously.
* `false` Client attempts to connect to IPs in serial.
* `sendStringParametersAsUnicode`
* `true` (Default) Go default string types sent as `nvarchar`.
* `false` Go default string types sent as `varchar`.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ing types sent as varchar.

without any way of indicating the code page, how does this work? Go strings are UTF8

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@shueybubbles it's upto the clients to decide how they want to use Go strings. In the same way other drives like postgres for Go also works.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not convinced it's a good idea for the driver to have such a broad toggle without a more comprehensive fix for mixing string encodings AND having a way for clients to specify the code page of the input.

Note the driver already has a way to designate input parameter strings as varchar to avoid the conversion - you need to use the driver-specific VarChar
Using that type has the benefit of apps being explicit about the desired conversion behavior and being able to mix unicode strings with non-unicode strings in the same connection.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@shueybubbles the fact that this flag is made available in the official support for sqlserver is because it makes the applications which work by default with non-unicode strings have lesser chances of failure. With this once and for all we can add this in the connection string, and nowhere explicit type casts are required on top of a string to varchar, without which any database columns having varchar as a type, will work the way as expected(for example index queries, etc). We are just trying to adhere to the official documentation of sqlserver here.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@shueybubbles can you help close this? This has been pending since long. We can have a quick connect on this if needed.


### Connection parameters for namedpipe package
* `pipe` - If set, no Browser query is made and named pipe used will be `\\<host>\pipe\<pipe>`
Expand Down Expand Up @@ -371,7 +374,7 @@ To pass specific types to the query parameters, say `varchar` or `date` types,
you must convert the types to the type before passing in. The following types
are supported:

* string -> nvarchar
* string -> nvarchar(by default, will be varchar if `sendStringParametersAsUnicode` is set to true)
* mssql.VarChar -> varchar
* time.Time -> datetimeoffset or datetime (TDS version dependent)
* mssql.DateTime1 -> datetime
Expand Down
69 changes: 43 additions & 26 deletions msdsn/conn_str.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,32 +53,33 @@ const (
)

const (
Database = "database"
Encrypt = "encrypt"
Password = "password"
ChangePassword = "change password"
UserID = "user id"
Port = "port"
TrustServerCertificate = "trustservercertificate"
Certificate = "certificate"
TLSMin = "tlsmin"
PacketSize = "packet size"
LogParam = "log"
ConnectionTimeout = "connection timeout"
HostNameInCertificate = "hostnameincertificate"
KeepAlive = "keepalive"
ServerSpn = "serverspn"
WorkstationID = "workstation id"
AppName = "app name"
ApplicationIntent = "applicationintent"
FailoverPartner = "failoverpartner"
FailOverPort = "failoverport"
DisableRetry = "disableretry"
Server = "server"
Protocol = "protocol"
DialTimeout = "dial timeout"
Pipe = "pipe"
MultiSubnetFailover = "multisubnetfailover"
Database = "database"
Encrypt = "encrypt"
Password = "password"
ChangePassword = "change password"
UserID = "user id"
Port = "port"
TrustServerCertificate = "trustservercertificate"
Certificate = "certificate"
TLSMin = "tlsmin"
PacketSize = "packet size"
LogParam = "log"
ConnectionTimeout = "connection timeout"
HostNameInCertificate = "hostnameincertificate"
KeepAlive = "keepalive"
ServerSpn = "serverspn"
WorkstationID = "workstation id"
AppName = "app name"
ApplicationIntent = "applicationintent"
FailoverPartner = "failoverpartner"
FailOverPort = "failoverport"
DisableRetry = "disableretry"
Server = "server"
Protocol = "protocol"
DialTimeout = "dial timeout"
Pipe = "pipe"
MultiSubnetFailover = "multisubnetfailover"
SendStringParametersAsUnicode = "sendstringparametersasunicode"
)

type Config struct {
Expand Down Expand Up @@ -131,6 +132,9 @@ type Config struct {
ColumnEncryption bool
// Attempt to connect to all IPs in parallel when MultiSubnetFailover is true
MultiSubnetFailover bool

// Sets a boolean value that indicates if sending string parameters to the server in UNICODE format is enabled.
SendStringParametersAsUnicode bool
}

func readDERFile(filename string) ([]byte, error) {
Expand Down Expand Up @@ -504,6 +508,19 @@ func Parse(dsn string) (Config, error) {
// Defaulting to true to prevent breaking change although other client libraries default to false
p.MultiSubnetFailover = true
}

sendStringParametersAsUnicode, ok := params[SendStringParametersAsUnicode]
if ok {
p.SendStringParametersAsUnicode, err = strconv.ParseBool(sendStringParametersAsUnicode)
if err != nil {
return p, fmt.Errorf("invalid %s '%s': %s", SendStringParametersAsUnicode,
sendStringParametersAsUnicode, err.Error())
}
} else {
// defaulting to true for backward compatibility
p.SendStringParametersAsUnicode = true
}

return p, nil
}

Expand Down
9 changes: 9 additions & 0 deletions msdsn/conn_str_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,15 @@ func TestValidConnectionString(t *testing.T) {
{"sqlserver://somehost?encrypt=true&tlsmin=1.1&columnencryption=1", func(p Config) bool {
return p.Host == "somehost" && p.Encryption == EncryptionRequired && p.TLSConfig.MinVersion == tls.VersionTLS11 && p.ColumnEncryption
}},
{"sqlserver://somehost", func(p Config) bool {
return p.Host == "somehost" && p.SendStringParametersAsUnicode
}},
{"sqlserver://somehost?sendStringParametersAsUnicode=true", func(p Config) bool {
return p.Host == "somehost" && p.SendStringParametersAsUnicode
}},
{"sqlserver://somehost?sendStringParametersAsUnicode=false", func(p Config) bool {
return p.Host == "somehost" && !p.SendStringParametersAsUnicode
}},
}
for _, ts := range connStrings {
p, err := Parse(ts.connStr)
Expand Down
22 changes: 16 additions & 6 deletions mssql.go
Original file line number Diff line number Diff line change
Expand Up @@ -563,8 +563,8 @@ func (s *Stmt) sendQuery(ctx context.Context, args []namedValue) (err error) {
if err != nil {
return
}
params[0] = makeStrParam(s.query)
params[1] = makeStrParam(strings.Join(decls, ","))
params[0] = makeStrParam(s.query, true)
params[1] = makeStrParam(strings.Join(decls, ","), true)
}
if err = sendRpc(conn.sess.buf, headers, proc, 0, params, reset); err != nil {
if conn.sess.logFlags&logErrors != 0 {
Expand Down Expand Up @@ -968,9 +968,19 @@ func (r *Rows) ColumnTypeNullable(index int) (nullable, ok bool) {
return
}

func makeStrParam(val string) (res param) {
res.ti.TypeId = typeNVarChar
res.buffer = str2ucs2(val)
func getSendStringParametersAsUnicode(s *Stmt) bool {
return s == nil || s.c == nil || s.c.connector == nil || s.c.connector.params.SendStringParametersAsUnicode
}

func makeStrParam(val string, sendStringParametersAsUnicode bool) (res param) {
if sendStringParametersAsUnicode {
res.ti.TypeId = typeNVarChar
res.buffer = str2ucs2(val)
res.ti.Size = len(res.buffer)
return
}
res.ti.TypeId = typeBigVarChar
res.buffer = []byte(val)
res.ti.Size = len(res.buffer)
return
}
Expand Down Expand Up @@ -1046,7 +1056,7 @@ func (s *Stmt) makeParam(val driver.Value) (res param, err error) {
res.ti.Size = len(val)
res.buffer = val
case string:
res = makeStrParam(val)
res = makeStrParam(val, getSendStringParametersAsUnicode(s))
case sql.NullString:
// only null values should be getting here
res.ti.TypeId = typeNVarChar
Expand Down
55 changes: 55 additions & 0 deletions queries_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,61 @@ func TestSelectNewTypes(t *testing.T) {
}
}

func TestSelectWithVarchar(t *testing.T) {
conn, logger := openWithVarcharDSN(t)
defer conn.Close()
defer logger.StopLogging()

t.Run("scan into string", func(t *testing.T) {
type testStruct struct {
sql string
args []interface{}
val string
}

longstr := strings.Repeat("x", 10000)

values := []testStruct{
{"'abc'", []interface{}{}, "abc"},
{"N'abc'", []interface{}{}, "abc"},
{"cast(N'abc' as nvarchar(max))", []interface{}{}, "abc"},
{"cast('abc' as text)", []interface{}{}, "abc"},
{"cast(N'abc' as ntext)", []interface{}{}, "abc"},
{"cast('abc' as char(3))", []interface{}{}, "abc"},
{"cast('abc' as varchar(3))", []interface{}{}, "abc"},
{fmt.Sprintf("cast(N'%s' as nvarchar(max))", longstr), []interface{}{}, longstr},
{"cast(cast('abc' as varchar(3)) as sql_variant)", []interface{}{}, "abc"},
{"cast(cast('abc' as char(3)) as sql_variant)", []interface{}{}, "abc"},
{"cast(N'abc' as sql_variant)", []interface{}{}, "abc"},
{"@p1", []interface{}{"abc"}, "abc"},
{"@p1", []interface{}{longstr}, longstr},
}

for _, test := range values {
t.Run(test.sql, func(t *testing.T) {
stmt, err := conn.Prepare("select " + test.sql)
if err != nil {
t.Error("Prepare failed:", test.sql, err.Error())
return
}
defer stmt.Close()

row := stmt.QueryRow(test.args...)
var retval string
err = row.Scan(&retval)
if err != nil {
t.Error("Scan failed:", test.sql, err.Error())
return
}
if retval != test.val {
t.Errorf("Values don't match '%s' '%s' for test: %s", retval, test.val, test.sql)
return
}
})
}
})
}

func TestTrans(t *testing.T) {
conn, logger := open(t)
defer conn.Close()
Expand Down
19 changes: 19 additions & 0 deletions tds_go110_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,22 @@ func getTestConnector(t testing.TB) (*Connector, *testLogger) {
}
return connector, &tl
}

func openWithVarcharDSN(t testing.TB) (*sql.DB, *testLogger) {
connector, logger := getTestConnectorWithVarcharDSN(t)
conn := sql.OpenDB(connector)
return conn, logger
}

func getTestConnectorWithVarcharDSN(t testing.TB) (*Connector, *testLogger) {
tl := testLogger{t: t}
SetLogger(&tl)
s := testConnParams(t)
s.SendStringParametersAsUnicode = true
connector, err := NewConnector(s.URL().String())
if err != nil {
t.Error("Open connection failed:", err.Error())
return nil, &tl
}
return connector, &tl
}