Skip to content

Commit 09775dc

Browse files
committed
sqlite: reject invalid qualified refs in correlated subqueries
Signed-off-by: Amirhossein Akhlaghpour <[email protected]>
1 parent a0c474f commit 09775dc

File tree

3 files changed

+33
-67
lines changed

3 files changed

+33
-67
lines changed

internal/compiler/analyze.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -182,8 +182,8 @@ func (c *Compiler) _analyzeQuery(raw *ast.RawStmt, query string, failfast bool)
182182
}
183183

184184
if c.conf.Engine == config.EngineSQLite {
185-
if err := validateSQLiteQualifiedColumnRefs(raw.Stmt); err != nil {
186-
return nil, check(err)
185+
if err := check(validate.ValidateSQLiteQualifiedColumnRefs(raw.Stmt)); err != nil {
186+
return nil, err
187187
}
188188
}
189189

internal/engine/sqlite/analyzer/analyze.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import (
1717
"github.com/sqlc-dev/sqlc/internal/sql/catalog"
1818
"github.com/sqlc-dev/sqlc/internal/sql/named"
1919
"github.com/sqlc-dev/sqlc/internal/sql/sqlerr"
20+
"github.com/sqlc-dev/sqlc/internal/sql/validate"
2021
)
2122

2223
type Analyzer struct {
@@ -76,6 +77,16 @@ func (a *Analyzer) Analyze(ctx context.Context, n ast.Node, query string, migrat
7677
}
7778
}
7879

80+
// SQLite-specific validation
81+
toValidate := n
82+
if raw, ok := n.(*ast.RawStmt); ok && raw != nil && raw.Stmt != nil {
83+
toValidate = raw.Stmt
84+
}
85+
86+
if err := validate.ValidateSQLiteQualifiedColumnRefs(toValidate); err != nil {
87+
return nil, err
88+
}
89+
7990
// Prepare the statement to get column and parameter information
8091
stmt, _, err := a.conn.Prepare(query)
8192
if err != nil {

internal/compiler/validator.go renamed to internal/sql/validate/sqlite_qualified_refs.go

Lines changed: 20 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
package compiler
1+
package validate
22

33
import (
44
"fmt"
@@ -8,47 +8,28 @@ import (
88
"github.com/sqlc-dev/sqlc/internal/sql/sqlerr"
99
)
1010

11-
/*
12-
This file implements SQLite-specific validation for qualified column references.
13-
14-
Problem:
15-
SQLite allows invalid correlated references to pass parsing, e.g.
16-
17-
SELECT *
18-
FROM locations l
19-
WHERE EXISTS (
20-
SELECT 1
21-
FROM projects p
22-
WHERE p.id = location.project_id -- invalid: "location" not in scope
23-
)
24-
25-
SQLite itself errors at runtime, but sqlc historically accepted this query.
26-
This validator rejects such queries during compilation, matching SQLite behavior.
27-
*/
11+
// ValidateSQLiteQualifiedColumnRefs validates that qualified column references
12+
// only use visible tables/aliases in the current or outer SELECT scopes.
13+
func ValidateSQLiteQualifiedColumnRefs(root ast.Node) error {
14+
return validateNodeSQLite(root, nil)
15+
}
2816

29-
// scope represents the set of visible table names and aliases at a SELECT level.
30-
// parent enables correlated subqueries to see outer query tables.
3117
type scope struct {
3218
parent *scope
3319
names map[string]struct{}
3420
}
3521

3622
func newScope(parent *scope) *scope {
37-
return &scope{
38-
parent: parent,
39-
names: map[string]struct{}{},
40-
}
23+
return &scope{parent: parent, names: map[string]struct{}{}}
4124
}
4225

43-
// add registers a visible table name or alias.
4426
func (s *scope) add(name string) {
4527
if name == "" {
4628
return
4729
}
4830
s.names[name] = struct{}{}
4931
}
5032

51-
// has checks whether a name is visible in this scope or any parent scope.
5233
func (s *scope) has(name string) bool {
5334
for cur := s; cur != nil; cur = cur.parent {
5435
if _, ok := cur.names[name]; ok {
@@ -58,13 +39,19 @@ func (s *scope) has(name string) bool {
5839
return false
5940
}
6041

61-
// qualifierFromColumnRef extracts the table/alias portion of a qualified ref.
62-
//
63-
// Examples:
64-
//
65-
// a.b -> "a"
66-
// s.a.b -> "a" (schema.table.column)
67-
// b -> "" (unqualified)
42+
func stringSlice(list *ast.List) []string {
43+
if list == nil {
44+
return nil
45+
}
46+
out := make([]string, 0, len(list.Items))
47+
for _, it := range list.Items {
48+
if s, ok := it.(*ast.String); ok {
49+
out = append(out, s.Str)
50+
}
51+
}
52+
return out
53+
}
54+
6855
func qualifierFromColumnRef(ref *ast.ColumnRef) (string, bool) {
6956
if ref == nil || ref.Fields == nil {
7057
return "", false
@@ -80,7 +67,6 @@ func qualifierFromColumnRef(ref *ast.ColumnRef) (string, bool) {
8067
}
8168
}
8269

83-
// addFromItemToScope records tables and aliases introduced by FROM/JOIN items.
8470
func addFromItemToScope(sc *scope, n ast.Node) {
8571
switch t := n.(type) {
8672
case *ast.RangeVar:
@@ -90,59 +76,40 @@ func addFromItemToScope(sc *scope, n ast.Node) {
9076
if t.Alias != nil && t.Alias.Aliasname != nil {
9177
sc.add(*t.Alias.Aliasname)
9278
}
93-
9479
case *ast.JoinExpr:
9580
addFromItemToScope(sc, t.Larg)
9681
addFromItemToScope(sc, t.Rarg)
97-
9882
case *ast.RangeSubselect:
9983
if t.Alias != nil && t.Alias.Aliasname != nil {
10084
sc.add(*t.Alias.Aliasname)
10185
}
102-
10386
case *ast.RangeFunction:
10487
if t.Alias != nil && t.Alias.Aliasname != nil {
10588
sc.add(*t.Alias.Aliasname)
10689
}
10790
}
10891
}
10992

110-
// validateSQLiteQualifiedColumnRefs is the public entry point.
111-
// It validates that qualified column references only use visible tables/aliases.
112-
func validateSQLiteQualifiedColumnRefs(root ast.Node) error {
113-
return validateNodeSQLite(root, nil)
114-
}
115-
116-
// validateNodeSQLite validates a SELECT node and establishes a new scope.
117-
// Nested SELECTs receive the current scope as their parent.
11893
func validateNodeSQLite(node ast.Node, parent *scope) error {
11994
switch n := node.(type) {
12095
case *ast.SelectStmt:
12196
sc := newScope(parent)
122-
123-
// Collect visible names from FROM clause
12497
if n.FromClause != nil {
12598
for _, item := range n.FromClause.Items {
12699
addFromItemToScope(sc, item)
127100
}
128101
}
129-
130-
// Walk this SELECT subtree with the new scope
131102
return walkSQLite(n, sc)
132-
133103
default:
134-
// Only SELECTs introduce scopes; other nodes are irrelevant here.
135104
return nil
136105
}
137106
}
138107

139-
// walkSQLite recursively walks an AST node, validating ColumnRef qualifiers.
140108
func walkSQLite(node ast.Node, sc *scope) error {
141109
if node == nil {
142110
return nil
143111
}
144112

145-
// Pre-order validation: check ColumnRef immediately
146113
if ref, ok := node.(*ast.ColumnRef); ok {
147114
if qual, ok := qualifierFromColumnRef(ref); ok && !sc.has(qual) {
148115
return &sqlerr.Error{
@@ -153,27 +120,22 @@ func walkSQLite(node ast.Node, sc *scope) error {
153120
}
154121
}
155122

156-
// Explicit handling of subquery boundaries
157123
switch n := node.(type) {
158124
case *ast.SubLink:
159125
if n.Subselect != nil {
160126
return validateNodeSQLite(n.Subselect, sc)
161127
}
162128
return nil
163-
164129
case *ast.RangeSubselect:
165130
if n.Subquery != nil {
166131
return validateNodeSQLite(n.Subquery, sc)
167132
}
168133
return nil
169134
}
170135

171-
// Generic recursion for all other node types
172136
return walkSQLiteReflect(node, sc)
173137
}
174138

175-
// walkSQLiteReflect traverses AST nodes via reflection.
176-
// This avoids dependency on astutils.Walk, whose signature varies.
177139
func walkSQLiteReflect(node ast.Node, sc *scope) error {
178140
v := reflect.ValueOf(node)
179141
if v.Kind() == reflect.Pointer {
@@ -188,25 +150,21 @@ func walkSQLiteReflect(node ast.Node, sc *scope) error {
188150

189151
t := v.Type()
190152
for i := 0; i < v.NumField(); i++ {
191-
// Skip unexported fields
192153
if t.Field(i).PkgPath != "" {
193154
continue
194155
}
195-
196156
f := v.Field(i)
197157
if !f.IsValid() {
198158
continue
199159
}
200160

201-
// Dereference pointers
202161
for f.Kind() == reflect.Pointer {
203162
if f.IsNil() {
204163
goto next
205164
}
206165
f = f.Elem()
207166
}
208167

209-
// Handle ast.List
210168
if f.Type() == reflect.TypeOf(ast.List{}) {
211169
list := f.Addr().Interface().(*ast.List)
212170
for _, n := range list.Items {
@@ -217,7 +175,6 @@ func walkSQLiteReflect(node ast.Node, sc *scope) error {
217175
continue
218176
}
219177

220-
// Handle *ast.List
221178
if f.CanAddr() {
222179
if pl, ok := f.Addr().Interface().(**ast.List); ok && *pl != nil {
223180
for _, n := range (*pl).Items {
@@ -229,7 +186,6 @@ func walkSQLiteReflect(node ast.Node, sc *scope) error {
229186
}
230187
}
231188

232-
// Handle single ast.Node
233189
if f.CanInterface() {
234190
if n, ok := f.Interface().(ast.Node); ok {
235191
if err := walkSQLite(n, sc); err != nil {
@@ -239,7 +195,6 @@ func walkSQLiteReflect(node ast.Node, sc *scope) error {
239195
}
240196
}
241197

242-
// Handle slices of ast.Node
243198
if f.Kind() == reflect.Slice {
244199
for j := 0; j < f.Len(); j++ {
245200
elem := f.Index(j)

0 commit comments

Comments
 (0)