Skip to content

Commit 531311d

Browse files
authored
Improve data loading performance (#134)
1 parent 3ad5c8f commit 531311d

7 files changed

+145
-57
lines changed

dbr.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ func exec(ctx context.Context, runner runner, log EventReceiver, builder Builder
116116
Dialect: d,
117117
IgnoreBinary: true,
118118
}
119-
err := i.interpolate(placeholder, []interface{}{builder}, true)
119+
err := i.encodePlaceholder(builder, true)
120120
query, value := i.String(), i.Value()
121121
if err != nil {
122122
return nil, log.EventErrKv("dbr.exec.interpolate", err, kvs{
@@ -154,7 +154,7 @@ func query(ctx context.Context, runner runner, log EventReceiver, builder Builde
154154
Dialect: d,
155155
IgnoreBinary: true,
156156
}
157-
err := i.interpolate(placeholder, []interface{}{builder}, true)
157+
err := i.encodePlaceholder(builder, true)
158158
query, value := i.String(), i.Value()
159159
if err != nil {
160160
return 0, log.EventErrKv("dbr.select.interpolate", err, kvs{

insert.go

+14-11
Original file line numberDiff line numberDiff line change
@@ -150,20 +150,23 @@ func (b *InsertStmt) Record(structValue interface{}) *InsertStmt {
150150
v := reflect.Indirect(reflect.ValueOf(structValue))
151151

152152
if v.Kind() == reflect.Struct {
153-
m := structMap(v)
154-
if v.CanSet() {
155-
// ID is recommended by golint here
156-
if field, ok := m["id"]; ok && field.Kind() == reflect.Int64 {
157-
b.RecordID = field.Addr().Interface().(*int64)
153+
found := make([]interface{}, len(b.Column)+1)
154+
// ID is recommended by golint here
155+
findValueByName(v, append(b.Column, "id"), found, false)
156+
157+
value := found[:len(found)-1]
158+
for i, v := range value {
159+
if v != nil {
160+
value[i] = v.(reflect.Value).Interface()
158161
}
159162
}
160163

161-
var value []interface{}
162-
for _, key := range b.Column {
163-
if val, ok := m[key]; ok {
164-
value = append(value, val.Interface())
165-
} else {
166-
value = append(value, nil)
164+
if v.CanSet() {
165+
switch idField := found[len(found)-1].(type) {
166+
case reflect.Value:
167+
if idField.Kind() == reflect.Int64 {
168+
b.RecordID = idField.Addr().Interface().(*int64)
169+
}
167170
}
168171
}
169172
b.Values(value...)

interpolate.go

+2-4
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ func (i *interpolator) interpolate(query string, value []interface{}, topLevel b
6262
}
6363

6464
if valueIndex >= len(value) {
65-
break
65+
return ErrPlaceholderCount
6666
}
6767

6868
i.WriteString(query[:index])
@@ -103,10 +103,8 @@ func (i *interpolator) encodePlaceholder(value interface{}, topLevel bool) error
103103
}
104104
paren := false
105105
switch value.(type) {
106-
case *SelectStmt:
106+
case *SelectStmt, *union:
107107
paren = !topLevel
108-
case *union:
109-
paren = true
110108
}
111109
if paren {
112110
i.WriteString("(")

load.go

+16-11
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,9 @@ func Load(rows *sql.Rows, value interface{}) (int, error) {
4646
var err error
4747

4848
if isMapOfSlices {
49-
elem = reflect.New(v.Type().Elem().Elem()).Elem()
49+
elem = reflectAlloc(v.Type().Elem().Elem())
5050
} else if isSlice || isMap {
51-
elem = reflect.New(v.Type().Elem()).Elem()
51+
elem = reflectAlloc(v.Type().Elem())
5252
} else {
5353
elem = v
5454
}
@@ -58,7 +58,7 @@ func Load(rows *sql.Rows, value interface{}) (int, error) {
5858
if err != nil {
5959
return 0, err
6060
}
61-
keyElem = reflect.New(v.Type().Key()).Elem()
61+
keyElem = reflectAlloc(v.Type().Key())
6262
keyPtr, err := findPtr(column[0:1], keyElem)
6363
if err != nil {
6464
return 0, err
@@ -95,6 +95,13 @@ func Load(rows *sql.Rows, value interface{}) (int, error) {
9595
return count, nil
9696
}
9797

98+
func reflectAlloc(typ reflect.Type) reflect.Value {
99+
if typ.Kind() == reflect.Ptr {
100+
return reflect.New(typ.Elem())
101+
}
102+
return reflect.New(typ).Elem()
103+
}
104+
98105
type dummyScanner struct{}
99106

100107
func (dummyScanner) Scan(interface{}) error {
@@ -107,18 +114,16 @@ var (
107114
)
108115

109116
func findPtr(column []string, value reflect.Value) ([]interface{}, error) {
110-
if value.Addr().Type().Implements(typeScanner) {
117+
if value.CanAddr() && value.Addr().Type().Implements(typeScanner) {
111118
return []interface{}{value.Addr().Interface()}, nil
112119
}
113120
switch value.Kind() {
114121
case reflect.Struct:
115-
var ptr []interface{}
116-
m := structMap(value)
117-
for _, key := range column {
118-
if val, ok := m[key]; ok {
119-
ptr = append(ptr, val.Addr().Interface())
120-
} else {
121-
ptr = append(ptr, dummyDest)
122+
ptr := make([]interface{}, len(column))
123+
findValueByName(value, column, ptr, true)
124+
for i := range ptr {
125+
if ptr[i] == nil {
126+
ptr[i] = dummyDest
122127
}
123128
}
124129
return ptr, nil

load_benchmark_test.go

+75
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
package dbr
2+
3+
import (
4+
"fmt"
5+
"testing"
6+
7+
"github.com/jmoiron/sqlx"
8+
"github.com/stretchr/testify/require"
9+
)
10+
11+
func BenchmarkLoadValues(b *testing.B) {
12+
sess := mysqlSession
13+
for _, v := range []string{
14+
`DROP TABLE IF EXISTS suggestions`,
15+
`CREATE TABLE suggestions (
16+
id serial PRIMARY KEY,
17+
title varchar(255),
18+
body text
19+
)`,
20+
} {
21+
_, err := sess.Exec(v)
22+
require.NoError(b, err)
23+
}
24+
tx, err := sess.Begin()
25+
require.NoError(b, err)
26+
27+
const maxRows = 100000
28+
29+
for i := 0; i < maxRows; i++ {
30+
_, err := tx.InsertInto("suggestions").
31+
Columns("title", "body").
32+
Values("title", "body").
33+
Exec()
34+
require.NoError(b, err)
35+
}
36+
err = tx.Commit()
37+
require.NoError(b, err)
38+
39+
type Suggestion struct {
40+
Title *string
41+
Body *string
42+
}
43+
for n := 10; n <= maxRows; n *= 10 {
44+
query := fmt.Sprintf("SELECT * FROM suggestions ORDER BY id ASC LIMIT %d", n)
45+
46+
b.Run(fmt.Sprintf("sqlx_%d", n), func(b *testing.B) {
47+
b.StopTimer()
48+
db, err := sqlx.Connect("mysql", mysqlDSN)
49+
require.NoError(b, err)
50+
db = db.Unsafe()
51+
defer db.Close()
52+
53+
for i := 0; i < b.N; i++ {
54+
var suggs []*Suggestion
55+
b.StartTimer()
56+
err := db.Select(&suggs, query)
57+
b.StopTimer()
58+
require.NoError(b, err)
59+
require.Len(b, suggs, n)
60+
}
61+
})
62+
b.Run(fmt.Sprintf("dbr_%d", n), func(b *testing.B) {
63+
b.StopTimer()
64+
65+
for i := 0; i < b.N; i++ {
66+
var suggs []*Suggestion
67+
b.StartTimer()
68+
_, err := sess.SelectBySql(query).Load(&suggs)
69+
b.StopTimer()
70+
require.NoError(b, err)
71+
require.Len(b, suggs, n)
72+
}
73+
})
74+
}
75+
}

util.go

+14-11
Original file line numberDiff line numberDiff line change
@@ -43,17 +43,11 @@ func camelCaseToSnakeCase(name string) string {
4343
return buf.String()
4444
}
4545

46-
func structMap(value reflect.Value) map[string]reflect.Value {
47-
m := make(map[string]reflect.Value)
48-
structValue(m, value)
49-
return m
50-
}
51-
5246
var (
5347
typeValuer = reflect.TypeOf((*driver.Valuer)(nil)).Elem()
5448
)
5549

56-
func structValue(m map[string]reflect.Value, value reflect.Value) {
50+
func findValueByName(value reflect.Value, name []string, found []interface{}, retPtr bool) {
5751
if value.Type().Implements(typeValuer) {
5852
return
5953
}
@@ -62,7 +56,7 @@ func structValue(m map[string]reflect.Value, value reflect.Value) {
6256
if value.IsNil() {
6357
return
6458
}
65-
structValue(m, value.Elem())
59+
findValueByName(value.Elem(), name, found, retPtr)
6660
case reflect.Struct:
6761
t := value.Type()
6862
for i := 0; i < t.NumField(); i++ {
@@ -81,10 +75,19 @@ func structValue(m map[string]reflect.Value, value reflect.Value) {
8175
tag = NameMapping(field.Name)
8276
}
8377
fieldValue := value.Field(i)
84-
if _, ok := m[tag]; !ok {
85-
m[tag] = fieldValue
78+
for i, want := range name {
79+
if want != tag {
80+
continue
81+
}
82+
if found[i] == nil {
83+
if retPtr {
84+
found[i] = fieldValue.Addr().Interface()
85+
} else {
86+
found[i] = fieldValue
87+
}
88+
}
8689
}
87-
structValue(m, fieldValue)
90+
findValueByName(fieldValue, name, found, retPtr)
8891
}
8992
}
9093
}

util_test.go

+22-18
Original file line numberDiff line numberDiff line change
@@ -56,54 +56,58 @@ func BenchmarkCamelCaseToSnakeCase(b *testing.B) {
5656
}
5757
}
5858

59-
func TestStructMap(t *testing.T) {
59+
func TestFindValueByName(t *testing.T) {
6060
for _, test := range []struct {
61-
in interface{}
62-
ok []string
63-
bad []string
61+
in interface{}
62+
name []string
63+
want []string
6464
}{
6565
{
6666
in: struct {
6767
CreatedAt time.Time
6868
}{},
69-
ok: []string{"created_at"},
69+
name: []string{"created_at"},
70+
want: []string{"created_at"},
7071
},
7172
{
7273
in: struct {
7374
intVal int
7475
}{},
75-
bad: []string{"int_val"},
76+
name: []string{"int_val"},
7677
},
7778
{
7879
in: struct {
7980
IntVal int `db:"test"`
8081
}{},
81-
ok: []string{"test"},
82-
bad: []string{"int_val"},
82+
name: []string{"test"},
83+
want: []string{"test"},
8384
},
8485
{
8586
in: struct {
8687
IntVal int `db:"-"`
8788
}{},
88-
bad: []string{"int_val"},
89+
name: []string{"int_val"},
8990
},
9091
{
9192
in: struct {
9293
Test1 struct {
9394
Test2 int
9495
}
9596
}{},
96-
ok: []string{"test2"},
97+
name: []string{"test2"},
98+
want: []string{"test2"},
9799
},
98100
} {
99-
m := structMap(reflect.ValueOf(test.in))
100-
for _, c := range test.ok {
101-
_, ok := m[c]
102-
require.True(t, ok)
103-
}
104-
for _, c := range test.bad {
105-
_, ok := m[c]
106-
require.False(t, ok)
101+
found := make([]interface{}, len(test.name))
102+
findValueByName(reflect.ValueOf(test.in), test.name, found, false)
103+
104+
var got []string
105+
for i, v := range found {
106+
if v != nil {
107+
got = append(got, test.name[i])
108+
}
107109
}
110+
111+
require.Equal(t, test.want, got)
108112
}
109113
}

0 commit comments

Comments
 (0)