Skip to content

Commit f1ed1e4

Browse files
authored
feat: Shopspring decimal support (#287)
* Queries and parameters tests. * OUT parameter tests. * Bug fix. * TVP test. * Bug fix. * Another TVP test. * Another TVP test. * Another TVP test. * Bug fix. * Bug fix. * Another TVP test. * Bulk copy test. * Bug fix. * Bug fix. * Another bulk copy test. * Bug fix. * Some MONEY type tests. * Support MONEY type. * Support Decimal type. * Support NULL for MONEY type. * Bug fix and more tests. * MONEY encoding test. * Bug fix. * Query MONEY encoding test. * Bug fix. * Query MONEY encoding tests. * DECIMAL encoding tests. * Bug fix. * Bulk copy money tests. * Bug fix. * Money bulkcopy support. * Bug fix. * SMALLMONEY bulkcopy support. * Bug fix. * Bug fix. * Refactoring - generalize Money type wrapper. * Bug fix. * TVP MONEY test. * TVP MONEY test. * TVP MONEY test. * TVP MONEY test. * Remove redundant file. * Simplify money encoding. * Money test. * More money tests. * Remove redundant file. * Update README. * Money / bulk tests - error cases.
1 parent 0997918 commit f1ed1e4

File tree

13 files changed

+1736
-129
lines changed

13 files changed

+1736
-129
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -438,6 +438,8 @@ are supported:
438438
* "github.com/golang-sql/civil".DateTime -> datetime2
439439
* "github.com/golang-sql/civil".Time -> time
440440
* mssql.TVP -> Table Value Parameter (TDS version dependent)
441+
* "github.com/shopspring/decimal".Decimal -> decimal
442+
* mssql.Money -> money
441443
442444
Using an `int` parameter will send a 4 byte value (int) from a 32bit app and an 8 byte value (bigint) from a 64bit app.
443445
To make sure your integer parameter matches the size of the SQL parameter, use the appropriate sized type like `int32` or `int8`.

bulkcopy.go

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414

1515
"github.com/microsoft/go-mssqldb/internal/decimal"
1616
"github.com/microsoft/go-mssqldb/msdsn"
17+
shopspring "github.com/shopspring/decimal"
1718
)
1819

1920
type Bulk struct {
@@ -346,6 +347,10 @@ func (b *Bulk) makeParam(val DataValue, col columnStruct) (res param, err error)
346347
loc := getTimezone(b.cn)
347348

348349
switch valuer := val.(type) {
350+
case Money[shopspring.Decimal]:
351+
return b.makeParam(valuer.Decimal, col)
352+
case Money[shopspring.NullDecimal]:
353+
return b.makeParam(valuer.Decimal, col)
349354
case driver.Valuer:
350355
var e error
351356
val, e = driver.DefaultParameterConverter.ConvertValue(valuer)
@@ -561,7 +566,37 @@ func (b *Bulk) makeParam(val DataValue, col columnStruct) (res param, err error)
561566
err = fmt.Errorf("mssql: invalid type for time column: %T %s", val, val)
562567
return
563568
}
564-
// case typeMoney, typeMoney4, typeMoneyN:
569+
case typeMoney, typeMoney4, typeMoneyN:
570+
switch v := val.(type) {
571+
case string:
572+
money, err := decimal.StringToDecimalScale(v, 4)
573+
if err != nil {
574+
return res, err
575+
}
576+
577+
buf := make([]byte, col.ti.Size)
578+
579+
integer0 := money.GetInteger(0)
580+
if col.ti.Size == 4 {
581+
if money.IsPositive() {
582+
binary.LittleEndian.PutUint32(buf, integer0)
583+
} else {
584+
binary.LittleEndian.PutUint32(buf, ^integer0+1)
585+
}
586+
} else {
587+
integer := (uint64(money.GetInteger(1)) << 32) | uint64(integer0)
588+
if !money.IsPositive() {
589+
integer = ^integer + 1
590+
}
591+
592+
binary.LittleEndian.PutUint32(buf, uint32(integer>>32))
593+
binary.LittleEndian.PutUint32(buf[4:], uint32(integer))
594+
}
595+
596+
res.buffer = buf
597+
default:
598+
return res, fmt.Errorf("unknown value for money: %T %#v", v, v)
599+
}
565600
case typeDecimal, typeDecimalN, typeNumeric, typeNumericN:
566601
prec := col.ti.Prec
567602
scale := col.ti.Scale

bulkcopy_test.go

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"testing"
1414
"time"
1515

16+
"github.com/shopspring/decimal"
1617
"github.com/stretchr/testify/assert"
1718
)
1819

@@ -29,6 +30,8 @@ func TestBulkcopyWithInvalidNullableType(t *testing.T) {
2930
"test_nullint16",
3031
"test_nulltime",
3132
"test_nulluniqueidentifier",
33+
"test_nulldecimal",
34+
"test_nullmoney",
3235
}
3336
values := []interface{}{
3437
sql.NullFloat64{Valid: false},
@@ -40,6 +43,8 @@ func TestBulkcopyWithInvalidNullableType(t *testing.T) {
4043
sql.NullInt16{Valid: false},
4144
sql.NullTime{Valid: false},
4245
NullUniqueIdentifier{Valid: false},
46+
decimal.NullDecimal{Valid: false},
47+
Money[decimal.NullDecimal]{decimal.NullDecimal{Valid: false}},
4348
}
4449

4550
pool, logger := open(t)
@@ -176,9 +181,11 @@ func testBulkcopy(t *testing.T, guidConversion bool) {
176181
{"test_nullint32", sql.NullInt32{2147483647, true}, 2147483647},
177182
{"test_nullint16", sql.NullInt16{32767, true}, 32767},
178183
{"test_nulltime", sql.NullTime{time.Date(2010, 11, 12, 13, 14, 15, 120000000, time.UTC), true}, time.Date(2010, 11, 12, 13, 14, 15, 120000000, time.UTC)},
184+
{"test_nulldecimal", decimal.NewNullDecimal(decimal.New(1232355, -4)), decimal.New(1232355, -4)},
185+
{"test_nullmoney", Money[decimal.NullDecimal]{decimal.NewNullDecimal(decimal.New(-21232311232355, -4))}, decimal.New(-21232311232355, -4)},
179186
{"test_datetimen_midnight", time.Date(2025, 1, 1, 23, 59, 59, 998_350_000, time.UTC), time.Date(2025, 1, 2, 0, 0, 0, 0, time.UTC)},
180-
// {"test_smallmoney", 1234.56, nil},
181-
// {"test_money", 1234.56, nil},
187+
{"test_smallmoney", Money[decimal.Decimal]{decimal.New(-32856, -4)}, decimal.New(-32856, -4)},
188+
{"test_money", Money[decimal.Decimal]{decimal.New(-21232311232355, -4)}, decimal.New(-21232311232355, -4)},
182189
{"test_decimal_18_0", 1234.0001, "1234"},
183190
{"test_decimal_9_2", -1234.560001, "-1234.56"},
184191
{"test_decimal_20_0", 1234, "1234"},
@@ -334,6 +341,20 @@ func compareValue(a interface{}, expected interface{}) bool {
334341
return expected.Equal(got) && ez == az
335342
}
336343
return false
344+
case decimal.Decimal:
345+
actual, err := decimal.NewFromString(a.(string))
346+
if err != nil {
347+
return false
348+
}
349+
350+
return expected.Equal(actual)
351+
case Money[decimal.Decimal]:
352+
actual, err := decimal.NewFromString(a.(string))
353+
if err != nil {
354+
return false
355+
}
356+
357+
return expected.Decimal.Equal(actual)
337358
default:
338359
return reflect.DeepEqual(expected, a)
339360
}
@@ -351,6 +372,8 @@ func setupNullableTypeTable(ctx context.Context, t *testing.T, conn *sql.Conn, t
351372
[test_nullint16] [smallint] NULL,
352373
[test_nulltime] [datetime] NULL,
353374
[test_nulluniqueidentifier] [uniqueidentifier] NULL,
375+
[test_nulldecimal] [decimal](18, 4) NULL,
376+
[test_nullmoney] [money] NULL,
354377
CONSTRAINT [PK_` + tableName + `_id] PRIMARY KEY CLUSTERED
355378
(
356379
[id] ASC
@@ -438,6 +461,8 @@ func setupTable(ctx context.Context, t *testing.T, conn *sql.Conn, tableName str
438461
[test_nullint32] [int] NULL,
439462
[test_nullint16] [smallint] NULL,
440463
[test_nulltime] [datetime] NULL,
464+
[test_nulldecimal] [decimal](18, 4) NULL,
465+
[test_nullmoney] [money] NULL,
441466
[test_datetimen_midnight] [datetime] NULL,
442467
CONSTRAINT [PK_` + tableName + `_id] PRIMARY KEY CLUSTERED
443468
(

go.mod

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ require (
3333
github.com/kylelemons/godebug v1.1.0 // indirect
3434
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c // indirect
3535
github.com/pmezard/go-difflib v1.0.0 // indirect
36+
github.com/shopspring/decimal v1.4.0 // indirect
3637
golang.org/x/net v0.40.0 // indirect
3738
gopkg.in/yaml.v3 v3.0.1 // indirect
3839
)

go.sum

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,8 @@ github.com/redis/go-redis/v9 v9.8.0 h1:q3nRvjrlge/6UD7eTu/DSg2uYiU2mCL0G/uzBWqhi
6464
github.com/redis/go-redis/v9 v9.8.0/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw=
6565
github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8=
6666
github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4=
67+
github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k=
68+
github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME=
6769
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
6870
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
6971
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=

internal/decimal/decimal.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,11 @@ func init() {
3636

3737
const autoScale = 100
3838

39+
// GetInteger gets the ind'th element of the integer array
40+
func (d *Decimal) GetInteger(ind uint8) uint32 {
41+
return d.integer[ind]
42+
}
43+
3944
// SetInteger sets the ind'th element in the integer array
4045
func (d *Decimal) SetInteger(integer uint32, ind uint8) {
4146
d.integer[ind] = integer

money.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
package mssql
2+
3+
import (
4+
"database/sql"
5+
"database/sql/driver"
6+
7+
"github.com/shopspring/decimal"
8+
)
9+
10+
type Money[D decimal.Decimal|decimal.NullDecimal] struct {
11+
Decimal D
12+
}
13+
14+
func (m Money[D]) Value() (driver.Value, error) {
15+
valuer, _ := any(m.Decimal).(driver.Valuer)
16+
17+
return valuer.Value()
18+
}
19+
20+
func (m *Money[D]) Scan(v any) error {
21+
scanner, _ := any(&m.Decimal).(sql.Scanner)
22+
23+
return scanner.Scan(v);
24+
}

0 commit comments

Comments
 (0)