Skip to content

Commit

Permalink
Merge pull request #377 from upper/issue-370
Browse files Browse the repository at this point in the history
Fixes #370
  • Loading branch information
José Nieto authored Jun 11, 2017
2 parents 4d7953a + 8ae1524 commit 0b0d44b
Show file tree
Hide file tree
Showing 9 changed files with 510 additions and 38 deletions.
7 changes: 5 additions & 2 deletions internal/sqladapter/collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,8 @@ func (c *collection) InsertReturning(item interface{}) error {
}

// Fetch the row that was just interted into newItem
if err = col.Find(id).One(newItem); err != nil {
err = col.Find(id).One(newItem)
if err != nil {
goto cancel
}

Expand All @@ -184,14 +185,16 @@ func (c *collection) InsertReturning(item interface{}) error {
itemV.SetMapIndex(keyV, newItemV.MapIndex(keyV))
}
default:
panic("default")
err = fmt.Errorf("InsertReturning: expecting a pointer to map or struct, got %T", newItem)
goto cancel
}

if !inTx {
// This is only executed if t.Database() was **not** a transaction and if
// sess was created with sess.NewTransaction().
return tx.Commit()
}

return err

cancel:
Expand Down
6 changes: 5 additions & 1 deletion internal/sqladapter/sqladapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,11 @@ func IsKeyValue(v interface{}) bool {
return true
}
switch v.(type) {
case int64, int, uint, uint64, driver.Valuer:
case int64, int, uint, uint64,
[]int64, []int, []uint, []uint64,
[]byte, []string,
[]interface{},
driver.Valuer:
return true
}
return false
Expand Down
45 changes: 30 additions & 15 deletions internal/sqladapter/testing/adapter.go.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -1053,26 +1053,41 @@ func TestCompositeKeys(t *testing.T) {
compositeKeys := sess.Collection("composite_keys")
n := rand.Intn(100000)
{
n := rand.Intn(100000)
item := itemWithCompoundKey{
"ABCDEF",
strconv.Itoa(n),
"Some value",
}
item := itemWithCompoundKey{
"ABCDEF",
strconv.Itoa(n),
"Some value",
}
id, err := compositeKeys.Insert(&item)
assert.NoError(t, err)
assert.NotZero(t, id)
id, err := compositeKeys.Insert(&item)
assert.NoError(t, err)
assert.NotZero(t, id)
var item2 itemWithCompoundKey
assert.NotEqual(t, item2.SomeVal, item.SomeVal)
var item2 itemWithCompoundKey
assert.NotEqual(t, item2.SomeVal, item.SomeVal)
// Finding by ID
err = compositeKeys.Find(id).One(&item2)
assert.NoError(t, err)
// Finding by ID
err = compositeKeys.Find(id).One(&item2)
assert.NoError(t, err)
assert.Equal(t, item2.SomeVal, item.SomeVal)
assert.Equal(t, item2.SomeVal, item.SomeVal)
}
{
n := rand.Intn(100000)
item := itemWithCompoundKey{
"ABCDEF",
strconv.Itoa(n),
"Some value",
}
err := compositeKeys.InsertReturning(&item)
assert.NoError(t, err)
}
assert.NoError(t, cleanUpCheck(sess))
assert.NoError(t, sess.Close())
Expand Down
8 changes: 5 additions & 3 deletions internal/sqladapter/tx.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,12 @@ func (b *baseTx) Committed() bool {
}

func (b *baseTx) Commit() (err error) {
if err = b.Tx.Commit(); err == nil {
b.committed.Store(struct{}{})
err = b.Tx.Commit()
if err != nil {
return err
}
return err
b.committed.Store(struct{}{})
return nil
}

func (w *databaseTx) Commit() error {
Expand Down
6 changes: 6 additions & 0 deletions lib/sqlbuilder/convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -352,9 +352,15 @@ func (tu *templateWithUtils) toColumnValues(term interface{}) (cv exql.ColumnVal
q, a := Preprocess(value.Raw(), value.Arguments())
columnValue.Value = exql.RawValue(q)
args = append(args, a...)
case driver.Valuer:
columnValue.Value = exql.RawValue("?")
args = append(args, value)
default:
v, isSlice := toInterfaceArguments(value)

//valuer, ok := value.(driver.Valuer)
//log.Printf("valuer: %v, ok: %v, (%v) %T", valuer, ok, value, value)

if isSlice {
if columnValue.Operator == "" {
columnValue.Operator = sqlInOperator
Expand Down
38 changes: 29 additions & 9 deletions mssql/collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@ package mssql

import (
"database/sql"
"log"
"strings"

"upper.io/db.v3"
"upper.io/db.v3/internal/sqladapter"
Expand All @@ -37,6 +35,8 @@ type table struct {

d *database
name string

hasIdentityColumn *bool
}

var (
Expand Down Expand Up @@ -76,7 +76,6 @@ func (t *table) Insert(item interface{}) (interface{}, error) {
for j := 0; j < len(pKey); j++ {
if pKey[j] == columnNames[i] {
if columnValues[i] != nil {
log.Printf("%v -- %v", pKey[j], columnValues[i])
hasKeys = true
break
}
Expand All @@ -85,13 +84,34 @@ func (t *table) Insert(item interface{}) (interface{}, error) {
}

if hasKeys {
_, err = t.d.Exec("SET IDENTITY_INSERT " + t.Name() + " ON")
// TODO: Find a way to check if the table has composite keys without an
// identity property.
if err != nil && !strings.Contains(err.Error(), "does not have the identity property") {
return nil, err
if t.hasIdentityColumn == nil {
var hasIdentityColumn bool
var identityColumns int

row, err := t.d.QueryRow("SELECT COUNT(1) FROM sys.identity_columns WHERE OBJECT_NAME(object_id) = ?", t.Name())
if err != nil {
return nil, err
}

err = row.Scan(&identityColumns)
if err != nil {
return nil, err
}

if identityColumns > 0 {
hasIdentityColumn = true
}

t.hasIdentityColumn = &hasIdentityColumn
}

if *t.hasIdentityColumn {
_, err = t.d.Exec("SET IDENTITY_INSERT " + t.Name() + " ON")
if err != nil {
return nil, err
}
defer t.d.Exec("SET IDENTITY_INSERT " + t.Name() + " OFF")
}
defer t.d.Exec("SET IDENTITY_INSERT " + t.Name() + " OFF")
}

q := t.d.InsertInto(t.Name()).
Expand Down
6 changes: 3 additions & 3 deletions mssql/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,10 +198,10 @@ func (d *database) NewDatabaseTx(ctx context.Context) (sqladapter.DatabaseTx, er

connFn := func() error {
sqlTx, err := compat.BeginTx(clone.BaseDatabase.Session(), ctx, nil)
if err == nil {
return clone.BindTx(ctx, sqlTx)
if err != nil {
return err
}
return err
return clone.BindTx(ctx, sqlTx)
}

if err := d.BaseDatabase.WaitForConnection(connFn); err != nil {
Expand Down
Loading

0 comments on commit 0b0d44b

Please sign in to comment.