Skip to content
Draft
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,9 @@ are supported:
* "github.com/golang-sql/civil".Date -> date
* "github.com/golang-sql/civil".DateTime -> datetime2
* "github.com/golang-sql/civil".Time -> time
* mssql.NullDate -> date (nullable)
* mssql.NullDateTime -> datetime2 (nullable)
* mssql.NullTime -> time (nullable)
* mssql.TVP -> Table Value Parameter (TDS version dependent)

Using an `int` parameter will send a 4 byte value (int) from a 32bit app and an 8 byte value (bigint) from a 64bit app.
Expand Down
16 changes: 16 additions & 0 deletions bulkcopy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"testing"
"time"

"github.com/golang-sql/civil"
"github.com/stretchr/testify/assert"
)

Expand All @@ -29,6 +30,9 @@ func TestBulkcopyWithInvalidNullableType(t *testing.T) {
"test_nullint16",
"test_nulltime",
"test_nulluniqueidentifier",
"test_nulldate",
"test_nulldatetime",
"test_nullciviltime",
}
values := []interface{}{
sql.NullFloat64{Valid: false},
Expand All @@ -40,6 +44,9 @@ func TestBulkcopyWithInvalidNullableType(t *testing.T) {
sql.NullInt16{Valid: false},
sql.NullTime{Valid: false},
NullUniqueIdentifier{Valid: false},
NullDate{Valid: false},
NullDateTime{Valid: false},
NullTime{Valid: false},
}

pool, logger := open(t)
Expand Down Expand Up @@ -176,6 +183,9 @@ func testBulkcopy(t *testing.T, guidConversion bool) {
{"test_nullint32", sql.NullInt32{2147483647, true}, 2147483647},
{"test_nullint16", sql.NullInt16{32767, true}, 32767},
{"test_nulltime", sql.NullTime{time.Date(2010, 11, 12, 13, 14, 15, 120000000, time.UTC), true}, time.Date(2010, 11, 12, 13, 14, 15, 120000000, time.UTC)},
{"test_nulldate", NullDate{civil.Date{Year: 2010, Month: 11, Day: 12}, true}, time.Date(2010, 11, 12, 0, 0, 0, 0, time.UTC)},
{"test_nulldatetime", NullDateTime{civil.DateTime{Date: civil.Date{Year: 2010, Month: 11, Day: 12}, Time: civil.Time{Hour: 13, Minute: 14, Second: 15, Nanosecond: 120000000}}, true}, time.Date(2010, 11, 12, 13, 14, 15, 120000000, time.UTC)},
{"test_nullciviltime", NullTime{civil.Time{Hour: 13, Minute: 14, Second: 15, Nanosecond: 123000000}, true}, time.Date(1, 1, 1, 13, 14, 15, 123000000, time.UTC)},
{"test_datetimen_midnight", time.Date(2025, 1, 1, 23, 59, 59, 998_350_000, time.UTC), time.Date(2025, 1, 2, 0, 0, 0, 0, time.UTC)},
// {"test_smallmoney", 1234.56, nil},
// {"test_money", 1234.56, nil},
Expand Down Expand Up @@ -351,6 +361,9 @@ func setupNullableTypeTable(ctx context.Context, t *testing.T, conn *sql.Conn, t
[test_nullint16] [smallint] NULL,
[test_nulltime] [datetime] NULL,
[test_nulluniqueidentifier] [uniqueidentifier] NULL,
[test_nulldate] [date] NULL,
[test_nulldatetime] [datetime2] NULL,
[test_nullciviltime] [time] NULL,
CONSTRAINT [PK_` + tableName + `_id] PRIMARY KEY CLUSTERED
(
[id] ASC
Expand Down Expand Up @@ -438,6 +451,9 @@ func setupTable(ctx context.Context, t *testing.T, conn *sql.Conn, tableName str
[test_nullint32] [int] NULL,
[test_nullint16] [smallint] NULL,
[test_nulltime] [datetime] NULL,
[test_nulldate] [date] NULL,
[test_nulldatetime] [datetime2] NULL,
[test_nullciviltime] [time] NULL,
[test_datetimen_midnight] [datetime] NULL,
CONSTRAINT [PK_` + tableName + `_id] PRIMARY KEY CLUSTERED
(
Expand Down
214 changes: 214 additions & 0 deletions civil_null.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
package mssql

import (
"database/sql/driver"
"encoding/json"
"fmt"
"time"

"github.com/golang-sql/civil"
)

// NullDate represents a civil.Date that may be null.
// NullDate implements the Scanner interface so it can be used as a scan destination,
// similar to sql.NullString.
type NullDate struct {
Date civil.Date
Valid bool // Valid is true if Date is not NULL
}

// Scan implements the Scanner interface.
func (n *NullDate) Scan(value interface{}) error {
if value == nil {
n.Date, n.Valid = civil.Date{}, false
return nil
}
n.Valid = true
switch v := value.(type) {
case time.Time:
n.Date = civil.DateOf(v)
return nil
default:
n.Valid = false
return fmt.Errorf("cannot scan %T into NullDate", value)
}
}

// Value implements the driver Valuer interface.
func (n NullDate) Value() (driver.Value, error) {
if !n.Valid {
return nil, nil
}
return n.Date.In(time.UTC), nil
}

// String returns the string representation of the date or "NULL".
func (n NullDate) String() string {
if !n.Valid {
return "NULL"
}
return n.Date.String()
}

// MarshalText implements the encoding.TextMarshaler interface.
func (n NullDate) MarshalText() ([]byte, error) {
if !n.Valid {
return []byte("null"), nil
}
return n.Date.MarshalText()
}

// UnmarshalJSON implements the json.Unmarshaler interface.
func (n *NullDate) UnmarshalJSON(b []byte) error {
if string(b) == "null" {
n.Date, n.Valid = civil.Date{}, false
return nil
}
err := json.Unmarshal(b, &n.Date)
n.Valid = err == nil
return err
}

// MarshalJSON implements the json.Marshaler interface.
func (n NullDate) MarshalJSON() ([]byte, error) {
if !n.Valid {
return []byte("null"), nil
}
return json.Marshal(n.Date)
}

// NullDateTime represents a civil.DateTime that may be null.
// NullDateTime implements the Scanner interface so it can be used as a scan destination,
// similar to sql.NullString.
type NullDateTime struct {
DateTime civil.DateTime
Valid bool // Valid is true if DateTime is not NULL
}

// Scan implements the Scanner interface.
func (n *NullDateTime) Scan(value interface{}) error {
if value == nil {
n.DateTime, n.Valid = civil.DateTime{}, false
return nil
}
n.Valid = true
switch v := value.(type) {
case time.Time:
n.DateTime = civil.DateTimeOf(v)
return nil
default:
n.Valid = false
return fmt.Errorf("cannot scan %T into NullDateTime", value)
}
}

// Value implements the driver Valuer interface.
func (n NullDateTime) Value() (driver.Value, error) {
if !n.Valid {
return nil, nil
}
return n.DateTime.In(time.UTC), nil
}

// String returns the string representation of the datetime or "NULL".
func (n NullDateTime) String() string {
if !n.Valid {
return "NULL"
}
return n.DateTime.String()
}

// MarshalText implements the encoding.TextMarshaler interface.
func (n NullDateTime) MarshalText() ([]byte, error) {
if !n.Valid {
return []byte("null"), nil
}
return n.DateTime.MarshalText()
}

// UnmarshalJSON implements the json.Unmarshaler interface.
func (n *NullDateTime) UnmarshalJSON(b []byte) error {
if string(b) == "null" {
n.DateTime, n.Valid = civil.DateTime{}, false
return nil
}
err := json.Unmarshal(b, &n.DateTime)
n.Valid = err == nil
return err
}

// MarshalJSON implements the json.Marshaler interface.
func (n NullDateTime) MarshalJSON() ([]byte, error) {
if !n.Valid {
return []byte("null"), nil
}
return json.Marshal(n.DateTime)
}

// NullTime represents a civil.Time that may be null.
// NullTime implements the Scanner interface so it can be used as a scan destination,
// similar to sql.NullString.
type NullTime struct {
Time civil.Time
Valid bool // Valid is true if Time is not NULL
}

// Scan implements the Scanner interface.
func (n *NullTime) Scan(value interface{}) error {
if value == nil {
n.Time, n.Valid = civil.Time{}, false
return nil
}
n.Valid = true
switch v := value.(type) {
case time.Time:
n.Time = civil.TimeOf(v)
return nil
default:
n.Valid = false
return fmt.Errorf("cannot scan %T into NullTime", value)
}
}

// Value implements the driver Valuer interface.
func (n NullTime) Value() (driver.Value, error) {
if !n.Valid {
return nil, nil
}
return time.Date(1, 1, 1, n.Time.Hour, n.Time.Minute, n.Time.Second, n.Time.Nanosecond, time.UTC), nil
}

// String returns the string representation of the time or "NULL".
func (n NullTime) String() string {
if !n.Valid {
return "NULL"
}
return n.Time.String()
}

// MarshalText implements the encoding.TextMarshaler interface.
func (n NullTime) MarshalText() ([]byte, error) {
if !n.Valid {
return []byte("null"), nil
}
return n.Time.MarshalText()
}

// UnmarshalJSON implements the json.Unmarshaler interface.
func (n *NullTime) UnmarshalJSON(b []byte) error {
if string(b) == "null" {
n.Time, n.Valid = civil.Time{}, false
return nil
}
err := json.Unmarshal(b, &n.Time)
n.Valid = err == nil
return err
}

// MarshalJSON implements the json.Marshaler interface.
func (n NullTime) MarshalJSON() ([]byte, error) {
if !n.Valid {
return []byte("null"), nil
}
return json.Marshal(n.Time)
}
Loading
Loading