Skip to content

Commit ec4961e

Browse files
Add startAtParameterIndex argument to Convert() (#4)
This argument can be used to control which $X argument the generator starts at. This is useful for when you want to add your own arguments to your query as well. Co-authored-by: Koen Bollen <[email protected]>
1 parent de85a91 commit ec4961e

File tree

7 files changed

+104
-32
lines changed

7 files changed

+104
-32
lines changed

README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,14 +36,14 @@ func main() {
3636

3737
// Convert a filter query to a WHERE clause and values:
3838
input := []byte(`{"title": "Jurassic Park"}`)
39-
where, values, err := converter.Convert(input)
39+
conditions, values, err := converter.Convert(input, 1) // 1 is the starting index for params, $1, $2, ...
4040
if err != nil {
4141
// handle error
4242
}
43-
fmt.Println(where, values) // ("title" = $1), ["Jurassic Park"]
43+
fmt.Println(conditions, values) // ("title" = $1), ["Jurassic Park"]
4444

4545
db, _ := sql.Open("postgres", "...")
46-
db.QueryRow("SELECT * FROM movies WHERE " + where, values...)
46+
db.QueryRow("SELECT * FROM movies WHERE " + conditions, values...)
4747
}
4848
```
4949
(See [examples/](examples/) for more examples)

examples/basic_test.go

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package examples
22

33
import (
4+
"database/sql"
45
"fmt"
56

67
"github.com/poki/mongodb-filter-to-postgres/filter"
@@ -16,7 +17,7 @@ func ExampleNewConverter() {
1617
"$gte": "2020-01-01T00:00:00Z"
1718
}
1819
}`
19-
conditions, values, err := converter.Convert([]byte(mongoFilterQuery))
20+
conditions, values, err := converter.Convert([]byte(mongoFilterQuery), 1)
2021
if err != nil {
2122
// handle error
2223
}
@@ -27,3 +28,37 @@ func ExampleNewConverter() {
2728
// (("created_at" >= $1) AND ("meta"->>'name' = $2))
2829
// []interface {}{"2020-01-01T00:00:00Z", "John"}
2930
}
31+
32+
func ExampleNewConverter_nonIsolatedConditions() {
33+
converter := filter.NewConverter()
34+
35+
mongoFilterQuery := `{
36+
"$or": [
37+
{ "email": "[email protected]" },
38+
{ "name": {"$regex": "^John.*^" },
39+
]
40+
}`
41+
conditions, values, err := converter.Convert([]byte(mongoFilterQuery), 3)
42+
if err != nil {
43+
// handle error
44+
}
45+
46+
query := `
47+
SELECT *
48+
FROM users
49+
WHERE
50+
disabled_at IS NOT NULL
51+
AND role = $1
52+
AND verified_at > $2
53+
AND ` + conditions + `
54+
LIMIT 10
55+
`
56+
57+
role := "user"
58+
verifiedAt := "2020-01-01T00:00:00Z"
59+
values = append([]any{role, verifiedAt}, values...)
60+
61+
db, _ := sql.Open("postgres", "...")
62+
rows := db.QueryRow(query, values...)
63+
_ = rows // actually use rows
64+
}

examples/readme_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ func ExampleNewConverter_readme() {
2424
}
2525
]
2626
}`
27-
conditions, values, err := converter.Convert([]byte(mongoFilterQuery))
27+
conditions, values, err := converter.Convert([]byte(mongoFilterQuery), 1)
2828
if err != nil {
2929
// handle error
3030
panic(err)

filter/converter.go

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -42,14 +42,21 @@ func NewConverter(options ...Option) *Converter {
4242
}
4343

4444
// Convert converts a MongoDB filter query into SQL conditions and values.
45-
func (c *Converter) Convert(query []byte) (string, []any, error) {
45+
//
46+
// startAtParameterIndex is the index to start the parameter numbering at.
47+
// Passing X will make the first indexed parameter $X, the second $X+1, and so on.
48+
func (c *Converter) Convert(query []byte, startAtParameterIndex int) (conditions string, values []any, err error) {
49+
if startAtParameterIndex < 1 {
50+
return "", nil, fmt.Errorf("startAtParameterIndex must be greater than 0")
51+
}
52+
4653
var mongoFilter map[string]any
47-
err := json.Unmarshal(query, &mongoFilter)
54+
err = json.Unmarshal(query, &mongoFilter)
4855
if err != nil {
4956
return "", nil, err
5057
}
5158

52-
conditions, values, err := c.convertFilter(mongoFilter, 0)
59+
conditions, values, err = c.convertFilter(mongoFilter, startAtParameterIndex)
5360
if err != nil {
5461
return "", nil, err
5562
}
@@ -126,8 +133,8 @@ func (c *Converter) convertFilter(filter map[string]any, paramIndex int) (string
126133
if !isScalarSlice(v[operator]) {
127134
return "", nil, fmt.Errorf("invalid value for $in operator (must array of primatives): %v", v[operator])
128135
}
129-
paramIndex++
130136
inner = append(inner, fmt.Sprintf("(%s = ANY($%d))", c.columnName(key), paramIndex))
137+
paramIndex++
131138
if c.arrayDriver != nil {
132139
v[operator] = c.arrayDriver(v[operator])
133140
}
@@ -138,8 +145,8 @@ func (c *Converter) convertFilter(filter map[string]any, paramIndex int) (string
138145
if !ok {
139146
return "", nil, fmt.Errorf("unknown operator: %s", operator)
140147
}
141-
paramIndex++
142148
inner = append(inner, fmt.Sprintf("(%s %s $%d)", c.columnName(key), op, paramIndex))
149+
paramIndex++
143150
values = append(values, value)
144151
}
145152
}
@@ -149,8 +156,8 @@ func (c *Converter) convertFilter(filter map[string]any, paramIndex int) (string
149156
}
150157
conditions = append(conditions, innerResult)
151158
default:
152-
paramIndex++
153159
conditions = append(conditions, fmt.Sprintf("(%s = $%d)", c.columnName(key), paramIndex))
160+
paramIndex++
154161
values = append(values, value)
155162
}
156163
}

filter/converter_test.go

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,10 +194,11 @@ func TestConverter_Convert(t *testing.T) {
194194
nil,
195195
},
196196
}
197+
197198
for _, tt := range tests {
198199
t.Run(tt.name, func(t *testing.T) {
199200
c := filter.NewConverter(tt.option)
200-
conditions, values, err := c.Convert([]byte(tt.input))
201+
conditions, values, err := c.Convert([]byte(tt.input), 1)
201202
if err != nil && (tt.err == nil || err.Error() != tt.err.Error()) {
202203
t.Errorf("Converter.Convert() error = %v, wantErr %v", err, tt.err)
203204
return
@@ -215,3 +216,32 @@ func TestConverter_Convert(t *testing.T) {
215216
})
216217
}
217218
}
219+
220+
func TestConverter_Convert_startAtParameterIndex(t *testing.T) {
221+
c := filter.NewConverter()
222+
conditions, values, err := c.Convert([]byte(`{"name": "John", "password": "secret"}`), 10)
223+
if err != nil {
224+
t.Fatal(err)
225+
}
226+
if want := `(("name" = $10) AND ("password" = $11))`; conditions != want {
227+
t.Errorf("Converter.Convert() conditions = %v, want %v", conditions, want)
228+
}
229+
if !reflect.DeepEqual(values, []any{"John", "secret"}) {
230+
t.Errorf("Converter.Convert() values = %v, want %v", values, []any{"John"})
231+
}
232+
233+
_, _, err = c.Convert([]byte(`{"name": "John"}`), 0)
234+
if want := "startAtParameterIndex must be greater than 0"; err == nil || err.Error() != want {
235+
t.Errorf("Converter.Convert(..., 0) error = nil, wantErr %q", want)
236+
}
237+
238+
_, _, err = c.Convert([]byte(`{"name": "John"}`), -123)
239+
if want := "startAtParameterIndex must be greater than 0"; err == nil || err.Error() != want {
240+
t.Errorf("Converter.Convert(..., -123) error = nil, wantErr %q", want)
241+
}
242+
243+
_, _, err = c.Convert([]byte(`{"name": "John"}`), 1234551231231231231)
244+
if err != nil {
245+
t.Errorf("Converter.Convert(..., 1234551231231231231) error = %v, want nil", err)
246+
}
247+
}

fuzz/fuzz_test.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,15 +40,15 @@ func FuzzConverter(f *testing.F) {
4040

4141
f.Fuzz(func(t *testing.T, in string) {
4242
c := filter.NewConverter(filter.WithArrayDriver(pq.Array))
43-
where, _, err := c.Convert([]byte(in))
44-
if err == nil && where != "" {
45-
j, err := pg_query.ParseToJSON("SELECT * FROM test WHERE 1 AND " + where)
43+
conditions, _, err := c.Convert([]byte(in), 1)
44+
if err == nil && conditions != "" {
45+
j, err := pg_query.ParseToJSON("SELECT * FROM test WHERE 1 AND " + conditions)
4646
if err != nil {
47-
t.Fatalf("%q %q %v", in, where, err)
47+
t.Fatalf("%q %q %v", in, conditions, err)
4848
}
4949

5050
if strings.Contains(j, "CommentStmt") {
51-
t.Fatal(where, "CommentStmt found")
51+
t.Fatal(conditions, "CommentStmt found")
5252
}
5353
}
5454
})

integration/postgres_test.go

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -60,15 +60,15 @@ func TestIntegration_ReadmeExample(t *testing.T) {
6060
]
6161
}`
6262

63-
where, values, err := c.Convert([]byte(in))
63+
conditions, values, err := c.Convert([]byte(in), 1)
6464
if err != nil {
6565
t.Fatal(err)
6666
}
6767

6868
rows, err := db.Query(`
6969
SELECT id
7070
FROM lobbies
71-
WHERE `+where+`;
71+
WHERE `+conditions+`;
7272
`, values...)
7373
if err != nil {
7474
t.Fatal(err)
@@ -124,15 +124,15 @@ func TestIntegration_InAny_PQ(t *testing.T) {
124124
in := `{
125125
"role": { "$in": ["guest", "user"] }
126126
}`
127-
where, values, err := c.Convert([]byte(in))
127+
conditions, values, err := c.Convert([]byte(in), 1)
128128
if err != nil {
129129
t.Fatal(err)
130130
}
131131

132132
rows, err := db.Query(`
133133
SELECT id
134134
FROM users
135-
WHERE `+where+`;
135+
WHERE `+conditions+`;
136136
`, values...)
137137
if err != nil {
138138
t.Fatal(err)
@@ -189,15 +189,15 @@ func TestIntegration_InAny_PGX(t *testing.T) {
189189
in := `{
190190
"role": { "$in": ["guest", "user"] }
191191
}`
192-
where, values, err := c.Convert([]byte(in))
192+
conditions, values, err := c.Convert([]byte(in), 1)
193193
if err != nil {
194194
t.Fatal(err)
195195
}
196196

197197
rows, err := db.Query(ctx, `
198198
SELECT id
199199
FROM users
200-
WHERE `+where+`;
200+
WHERE `+conditions+`;
201201
`, values...)
202202
if err != nil {
203203
t.Fatal(err)
@@ -296,15 +296,15 @@ func TestIntegration_BasicOperators(t *testing.T) {
296296
for _, tt := range tests {
297297
t.Run(tt.name, func(t *testing.T) {
298298
c := filter.NewConverter(filter.WithArrayDriver(pq.Array))
299-
where, values, err := c.Convert([]byte(tt.input))
299+
conditions, values, err := c.Convert([]byte(tt.input), 1)
300300
if err != nil {
301301
t.Fatal(err)
302302
}
303303

304304
rows, err := db.Query(`
305305
SELECT id
306306
FROM players
307-
WHERE `+where+`;
307+
WHERE `+conditions+`;
308308
`, values...)
309309
if err != nil {
310310
if tt.expectedError == nil {
@@ -325,7 +325,7 @@ func TestIntegration_BasicOperators(t *testing.T) {
325325
}
326326

327327
if !reflect.DeepEqual(players, tt.expectedPlayers) {
328-
t.Fatalf("%q expected %v, got %v (where clause used: %q)", tt.input, tt.expectedPlayers, players, where)
328+
t.Fatalf("%q expected %v, got %v (conditions used: %q)", tt.input, tt.expectedPlayers, players, conditions)
329329
}
330330
})
331331
}
@@ -384,15 +384,15 @@ func TestIntegration_NestedJSONB(t *testing.T) {
384384
for _, tt := range tests {
385385
t.Run(tt.name, func(t *testing.T) {
386386
c := filter.NewConverter(filter.WithArrayDriver(pq.Array), filter.WithNestedJSONB("metadata", "name", "level", "class"))
387-
where, values, err := c.Convert([]byte(tt.input))
387+
conditions, values, err := c.Convert([]byte(tt.input), 1)
388388
if err != nil {
389389
t.Fatal(err)
390390
}
391391

392392
rows, err := db.Query(`
393393
SELECT id
394394
FROM players
395-
WHERE `+where+`;
395+
WHERE `+conditions+`;
396396
`, values...)
397397
if err != nil {
398398
t.Fatal(err)
@@ -408,7 +408,7 @@ func TestIntegration_NestedJSONB(t *testing.T) {
408408
}
409409

410410
if !reflect.DeepEqual(players, tt.expectedPlayers) {
411-
t.Fatalf("%q expected %v, got %v (where clause used: %q)", tt.input, tt.expectedPlayers, players, where)
411+
t.Fatalf("%q expected %v, got %v (conditions used: %q)", tt.input, tt.expectedPlayers, players, conditions)
412412
}
413413
})
414414
}
@@ -452,15 +452,15 @@ func TestIntegration_Logic(t *testing.T) {
452452
for _, tt := range tests {
453453
t.Run(tt.name, func(t *testing.T) {
454454
c := filter.NewConverter(filter.WithArrayDriver(pq.Array), filter.WithNestedJSONB("metadata", "name", "level", "class"))
455-
where, values, err := c.Convert([]byte(tt.input))
455+
conditions, values, err := c.Convert([]byte(tt.input), 1)
456456
if err != nil {
457457
t.Fatal(err)
458458
}
459459

460460
rows, err := db.Query(`
461461
SELECT id
462462
FROM players
463-
WHERE `+where+`;
463+
WHERE `+conditions+`;
464464
`, values...)
465465
if err != nil {
466466
t.Fatal(err)
@@ -476,7 +476,7 @@ func TestIntegration_Logic(t *testing.T) {
476476
}
477477

478478
if !reflect.DeepEqual(players, tt.expectedPlayers) {
479-
t.Fatalf("%q expected %v, got %v (where clause used: %q)", tt.input, tt.expectedPlayers, players, where)
479+
t.Fatalf("%q expected %v, got %v (conditions used: %q)", tt.input, tt.expectedPlayers, players, conditions)
480480
}
481481
})
482482
}

0 commit comments

Comments
 (0)