Skip to content

Commit ab72f1d

Browse files
committed
add test cases
1 parent a399138 commit ab72f1d

22 files changed

+397
-197
lines changed

sqlx/and.go

Lines changed: 0 additions & 55 deletions
This file was deleted.

sqlx/as.go

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,22 @@
11
package sqlx
22

3-
import "fmt"
3+
import (
4+
"fmt"
5+
)
46

57
type AsClause struct {
6-
depth uint
78
stmt Sqlizer
89
alias string
910
}
1011

1112
func As(stmt Sqlizer, alias string) *AsClause {
12-
return &AsClause{0, stmt, alias}
13+
return &AsClause{stmt, alias}
1314
}
1415

1516
func (self AsClause) Sql() string {
16-
return fmt.Sprintf("%s as \"%s\"", self.stmt.Sql(), self.alias)
17+
return fmt.Sprintf(`%s as "%s"`, self.stmt.Sql(), self.alias)
1718
}
1819

19-
func (self AsClause) SqlPretty() string {
20-
return fmt.Sprintf("%s as \"%s\"", self.stmt.SqlPretty(), self.alias)
21-
}
22-
23-
func (self *AsClause) setDepth(depth uint) {
24-
self.depth = depth
25-
self.stmt.setDepth(depth + 1)
20+
func (self AsClause) SqlPretty(indent string) string {
21+
return fmt.Sprintf(`%s as "%s"`, self.stmt.SqlPretty(indent), self.alias)
2622
}

sqlx/columns.go

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
package sqlx
2+
3+
import "strings"
4+
5+
type Columns []Sqlizer
6+
7+
func (self Columns) Sql() string {
8+
parts := []string{}
9+
10+
for _, column := range self {
11+
parts = append(parts, column.Sql())
12+
}
13+
14+
return strings.Join(parts, ", ")
15+
}
16+
17+
func (self Columns) SqlPretty(indent string) string {
18+
parts := []string{}
19+
20+
for i, column := range self {
21+
lines := strings.Split(column.SqlPretty(indent), "\n")
22+
23+
for _, line := range lines {
24+
parts = append(parts, indent+line)
25+
}
26+
27+
if i < len(self)-1 {
28+
parts[len(parts)-1] += ","
29+
}
30+
}
31+
32+
return strings.Join(parts, "\n")
33+
}

sqlx/or.go

Lines changed: 0 additions & 55 deletions
This file was deleted.

sqlx/raw.go

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,17 @@
11
package sqlx
22

33
type RawStatement struct {
4-
depth uint
5-
stmt string
4+
stmt string
65
}
76

87
func Raw(stmt string) *RawStatement {
9-
return &RawStatement{0, stmt}
8+
return &RawStatement{stmt}
109
}
1110

1211
func (self RawStatement) Sql() string {
1312
return self.stmt
1413
}
1514

16-
func (self RawStatement) SqlPretty() string {
15+
func (self RawStatement) SqlPretty(indent string) string {
1716
return self.stmt
1817
}
19-
20-
func (self *RawStatement) setDepth(depth uint) {
21-
self.depth = depth
22-
}

sqlx/select.go

Lines changed: 68 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@ import (
77

88
type SelectStatement struct {
99
depth uint
10-
columns []Sqlizer
10+
columns Columns
1111
from Sqlizer
12-
where Sqlizer
12+
where *WhereClause
1313
}
1414

1515
func Select(columns ...string) *SelectStatement {
@@ -38,7 +38,7 @@ func (self *SelectStatement) ColumnAs(column string, alias string) *SelectStatem
3838
}
3939

4040
func (self *SelectStatement) ColumnSelect(stmt *SelectStatement, alias string) *SelectStatement {
41-
stmt.setDepth(self.depth + 1)
41+
stmt.depth = self.depth + 1
4242
self.columns = append(self.columns, As(stmt, alias))
4343
return self
4444
}
@@ -54,21 +54,54 @@ func (self *SelectStatement) FromSelect(stmt *SelectStatement, alias string) *Se
5454
return self
5555
}
5656

57-
func (self *SelectStatement) Where(where Sqlizer) *SelectStatement {
58-
where.setDepth(self.depth)
59-
self.where = where
57+
func (self *SelectStatement) Where(predicate any) *SelectStatement {
58+
switch v := predicate.(type) {
59+
case string:
60+
self.where = Where(Sql{v})
61+
case Sqlizer:
62+
self.where = Where(Sql{v})
63+
}
64+
6065
return self
6166
}
6267

63-
func (self SelectStatement) Sql() string {
64-
parts := []string{"SELECT"}
65-
columns := []string{}
68+
func (self *SelectStatement) And(predicates ...any) *SelectStatement {
69+
for _, predicate := range predicates {
70+
switch v := predicate.(type) {
71+
case *SelectStatement:
72+
v.depth = self.depth + 1
73+
break
74+
case *WhereClause:
75+
v.depth = self.depth + 1
76+
break
77+
}
78+
79+
self.where.And(predicate)
80+
}
6681

67-
for _, column := range self.columns {
68-
columns = append(columns, column.Sql())
82+
return self
83+
}
84+
85+
func (self *SelectStatement) Or(predicates ...any) *SelectStatement {
86+
for _, predicate := range predicates {
87+
switch v := predicate.(type) {
88+
case *SelectStatement:
89+
v.depth = self.depth + 1
90+
break
91+
case *WhereClause:
92+
v.depth = self.depth + 1
93+
break
94+
}
95+
96+
self.where.Or(predicate)
6997
}
7098

71-
parts = append(parts, strings.Join(columns, ", "))
99+
return self
100+
}
101+
102+
func (self SelectStatement) Sql() string {
103+
parts := []string{"SELECT"}
104+
parts = append(parts, self.columns.Sql())
72105

73106
if self.from != nil {
74107
parts = append(parts, "FROM", self.from.Sql())
@@ -89,35 +122,44 @@ func (self SelectStatement) Sql() string {
89122
return sql
90123
}
91124

92-
func (self SelectStatement) SqlPretty() string {
93-
parts := []string{"SELECT"}
94-
columns := []string{}
125+
func (self SelectStatement) SqlPretty(indent string) string {
126+
parts := []string{}
95127

96-
for _, column := range self.columns {
97-
columns = append(columns, "\t"+column.Sql())
128+
if self.depth > 0 {
129+
parts = append(parts, "(")
98130
}
99131

100-
parts = append(parts, strings.Join(columns, ",\n"))
132+
parts = append(parts, "SELECT")
133+
parts = append(
134+
parts,
135+
strings.Split(self.columns.SqlPretty(indent), "\n")...,
136+
)
101137

102138
if self.from != nil {
103-
parts = append(parts, "FROM "+self.from.Sql())
139+
lines := strings.Split(self.from.SqlPretty(indent), "\n")
140+
parts = append(parts, "FROM "+lines[0])
141+
parts = append(parts, lines[1:]...)
104142
}
105143

106144
if self.where != nil {
107-
parts = append(parts, "WHERE"+self.where.Sql())
145+
lines := strings.Split(self.where.SqlPretty(indent), "\n")
146+
parts = append(parts, "WHERE "+lines[0])
147+
parts = append(parts, lines[1:]...)
148+
}
149+
150+
if self.depth > 0 {
151+
for i := 1; i < len(parts); i++ {
152+
parts[i] = indent + parts[i]
153+
}
154+
155+
parts = append(parts, ")")
108156
}
109157

110158
sql := strings.Join(parts, "\n")
111159

112160
if self.depth == 0 {
113161
sql += ";"
114-
} else {
115-
sql = fmt.Sprintf("(%s)", sql)
116162
}
117163

118164
return sql
119165
}
120-
121-
func (self *SelectStatement) setDepth(depth uint) {
122-
self.depth = depth
123-
}

0 commit comments

Comments
 (0)