Skip to content

Commit c9077e5

Browse files
committed
Allow recursion using ,recurse
1 parent 5f3e10d commit c9077e5

File tree

2 files changed

+86
-58
lines changed

2 files changed

+86
-58
lines changed

sqlstruct.go

Lines changed: 71 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -9,75 +9,87 @@ the Go standard library's database/sql package.
99
The package matches struct field names to SQL query column names. A field can
1010
also specify a matching column with "sql" tag, if it's different from field
1111
name. Unexported fields or fields marked with `sql:"-"` are ignored, just like
12-
with "encoding/json" package.
12+
with "encoding/json" package. Fields marked with `sql:",recurse"` are treated as
13+
embedded structs and are recursively scanned.
1314
1415
For example:
1516
16-
type T struct {
17-
F1 string
18-
F2 string `sql:"field2"`
19-
F3 string `sql:"-"`
20-
}
17+
type T1 struct {
18+
F4 string `sql:"field4"`
19+
}
2120
22-
rows, err := db.Query(fmt.Sprintf("SELECT %s FROM tablename", sqlstruct.Columns(T{})))
23-
...
21+
type T2 struct {
22+
F5 string `sql:"field5"`
23+
}
24+
25+
type T struct {
26+
F1 string
27+
F2 string `sql:"field2"`
28+
F3 string `sql:"-"`
29+
fieldT1 T1 `sql:",recurse"`
30+
T2
31+
}
2432
25-
for rows.Next() {
26-
var t T
27-
err = sqlstruct.Scan(&t, rows)
28-
...
29-
}
33+
rows, err := db.Query(fmt.Sprintf("SELECT %s FROM tablename", sqlstruct.Columns(T{})))
34+
...
3035
31-
err = rows.Err() // get any errors encountered during iteration
36+
for rows.Next() {
37+
var t T
38+
err = sqlstruct.Scan(&t, rows)
39+
...
40+
}
41+
42+
err = rows.Err() // get any errors encountered during iteration
3243
3344
Aliased tables in a SQL statement may be scanned into a specific structure identified
3445
by the same alias, using the ColumnsAliased and ScanAliased functions:
3546
36-
type User struct {
37-
Id int `sql:"id"`
38-
Username string `sql:"username"`
39-
Email string `sql:"address"`
40-
Name string `sql:"name"`
41-
HomeAddress *Address `sql:"-"`
42-
}
43-
44-
type Address struct {
45-
Id int `sql:"id"`
46-
City string `sql:"city"`
47-
Street string `sql:"address"`
48-
}
49-
50-
...
51-
52-
var user User
53-
var address Address
54-
sql := `
47+
type User struct {
48+
Id int `sql:"id"`
49+
Username string `sql:"username"`
50+
Email string `sql:"address"`
51+
Name string `sql:"name"`
52+
HomeAddress *Address `sql:"-"`
53+
}
54+
55+
type Address struct {
56+
Id int `sql:"id"`
57+
City string `sql:"city"`
58+
Street string `sql:"address"`
59+
}
60+
61+
...
62+
63+
var user User
64+
var address Address
65+
sql := `
66+
5567
SELECT %s, %s FROM users AS u
5668
INNER JOIN address AS a ON a.id = u.address_id
5769
WHERE u.username = ?
5870
`
59-
sql = fmt.Sprintf(sql, sqlstruct.ColumnsAliased(*user, "u"), sqlstruct.ColumnsAliased(*address, "a"))
60-
rows, err := db.Query(sql, "gedi")
61-
if err != nil {
62-
log.Fatal(err)
63-
}
64-
defer rows.Close()
65-
if rows.Next() {
66-
err = sqlstruct.ScanAliased(&user, rows, "u")
67-
if err != nil {
68-
log.Fatal(err)
69-
}
70-
err = sqlstruct.ScanAliased(&address, rows, "a")
71-
if err != nil {
72-
log.Fatal(err)
73-
}
74-
user.HomeAddress = address
75-
}
76-
fmt.Printf("%+v", *user)
77-
// output: "{Id:1 Username:gedi Email:[email protected] Name:Gedas HomeAddress:0xc21001f570}"
78-
fmt.Printf("%+v", *user.HomeAddress)
79-
// output: "{Id:2 City:Vilnius Street:Plento 34}"
8071
72+
sql = fmt.Sprintf(sql, sqlstruct.ColumnsAliased(*user, "u"), sqlstruct.ColumnsAliased(*address, "a"))
73+
rows, err := db.Query(sql, "gedi")
74+
if err != nil {
75+
log.Fatal(err)
76+
}
77+
defer rows.Close()
78+
if rows.Next() {
79+
err = sqlstruct.ScanAliased(&user, rows, "u")
80+
if err != nil {
81+
log.Fatal(err)
82+
}
83+
err = sqlstruct.ScanAliased(&address, rows, "a")
84+
if err != nil {
85+
log.Fatal(err)
86+
}
87+
user.HomeAddress = address
88+
}
89+
fmt.Printf("%+v", *user)
90+
// output: "{Id:1 Username:gedi Email:[email protected] Name:Gedas HomeAddress:0xc21001f570}"
91+
fmt.Printf("%+v", *user.HomeAddress)
92+
// output: "{Id:2 City:Vilnius Street:Plento 34}"
8193
*/
8294
package sqlstruct
8395

@@ -97,7 +109,7 @@ import (
97109
// The default mapper converts field names to lower case. If instead you would prefer
98110
// field names converted to snake case, simply assign sqlstruct.ToSnakeCase to the variable:
99111
//
100-
// sqlstruct.NameMapper = sqlstruct.ToSnakeCase
112+
// sqlstruct.NameMapper = sqlstruct.ToSnakeCase
101113
//
102114
// Alternatively for a custom mapping, any func(string) string can be used instead.
103115
var NameMapper func(string) string = strings.ToLower
@@ -145,8 +157,8 @@ func getFieldInfo(typ reflect.Type) fieldInfo {
145157
continue
146158
}
147159

148-
// Handle embedded structs
149-
if f.Anonymous && f.Type.Kind() == reflect.Struct {
160+
// Handle embedded and recurse tagged structs
161+
if (f.Anonymous || strings.EqualFold(tag, ",recurse")) && f.Type.Kind() == reflect.Struct {
150162
for k, v := range getFieldInfo(f.Type) {
151163
finfo[k] = append([]int{i}, v...)
152164
}
@@ -198,7 +210,8 @@ func Columns(s interface{}) string {
198210
// given alias.
199211
//
200212
// For each field in the given struct it will generate a statement like:
201-
// alias.field AS alias_field
213+
//
214+
// alias.field AS alias_field
202215
//
203216
// It is intended to be used in conjunction with the ScanAliased function.
204217
func ColumnsAliased(s interface{}, alias string) string {

sqlstruct_test.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,11 @@ type testType2 struct {
2525
FieldSec string `sql:"field_sec"`
2626
}
2727

28+
type testType3 struct {
29+
FieldA string `sql:"field_a"`
30+
EmbeddedType EmbeddedType `sql:",recurse"`
31+
}
32+
2833
// testRows is a mock version of sql.Rows which can only scan strings
2934
type testRows struct {
3035
columns []string
@@ -67,6 +72,16 @@ func TestColumns(t *testing.T) {
6772
}
6873
}
6974

75+
func TestColumnDeep(t *testing.T) {
76+
var v testType3
77+
e := "field_a, field_e"
78+
c := Columns(v)
79+
80+
if c != e {
81+
t.Errorf("expected %q got %q", e, c)
82+
}
83+
}
84+
7085
func TestColumnsAliased(t *testing.T) {
7186
var t1 testType
7287
var t2 testType2

0 commit comments

Comments
 (0)