Skip to content

Commit 8993e79

Browse files
committed
feat: Enhance type inference scoring and parameter reference handling
This commit introduces a new function scoreParamRefForTypeInference to evaluate parameter references based on their context for type inference. The scoring system prioritizes contexts like explicit type casts and comparison operations, which are excellent for type inference, while assigning lower scores to contexts like IS NULL checks. Additionally, the uniqueParamRefs function is refactored to group parameter references by their number and select the reference with the best type inference context. This ensures more accurate type inference by leveraging the newly introduced scoring system. Unnamed parameters are also handled more effectively for non-dollar parameter styles.
1 parent b76563b commit 8993e79

1 file changed

Lines changed: 125 additions & 7 deletions

File tree

internal/compiler/parse.go

Lines changed: 125 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -165,29 +165,147 @@ func rangeVars(root ast.Node) []*ast.RangeVar {
165165
return vars
166166
}
167167

168+
// scoreParamRefForTypeInference scores a parameter reference based on how good
169+
// its context is for type inference. Higher scores indicate better contexts.
170+
func scoreParamRefForTypeInference(ref paramRef) int {
171+
if ref.parent == nil {
172+
return 0 // No context
173+
}
174+
175+
switch parent := ref.parent.(type) {
176+
case *ast.TypeCast:
177+
// Explicit type cast - excellent for type inference
178+
return 100
179+
180+
case *ast.A_Expr:
181+
// Expression context - quality depends on the operator
182+
if parent.Name != nil && len(parent.Name.Items) > 0 {
183+
if nameStr, ok := parent.Name.Items[0].(*ast.String); ok {
184+
switch nameStr.Str {
185+
case "=", "==", "!=", "<>", "<", "<=", ">", ">=":
186+
// Comparison operations - very good for type inference
187+
return 100
188+
case "+", "-", "*", "/", "%":
189+
// Mathematical operations - good for type inference
190+
return 90
191+
case "||":
192+
// String concatenation - good for type inference
193+
return 90
194+
case "~~", "!~~", "~~*", "!~~*":
195+
// LIKE operations - good for type inference
196+
return 90
197+
case "IS", "IS NOT":
198+
// IS NULL/IS NOT NULL - poor for type inference
199+
return 0
200+
default:
201+
return 50
202+
}
203+
}
204+
}
205+
return 50 // Default for A_Expr without clear operator
206+
207+
case *ast.BoolExpr:
208+
// Boolean expressions
209+
switch parent.Boolop {
210+
case ast.BoolExprTypeAnd, ast.BoolExprTypeOr:
211+
// Logical operations - still useful but lower priority
212+
return 60
213+
case ast.BoolExprTypeIsNull, ast.BoolExprTypeIsNotNull:
214+
// IS NULL/IS NOT NULL - poor for type inference
215+
return 20
216+
case ast.BoolExprTypeNot:
217+
// NOT operations - moderate for type inference
218+
return 50
219+
default:
220+
return 40
221+
}
222+
223+
case *ast.BetweenExpr:
224+
// BETWEEN expressions - good for type inference
225+
return 75
226+
227+
case *ast.FuncCall:
228+
// Function call context - depends on function, generally moderate
229+
// sqlc.narg() and similar functions have poor type inference context
230+
if parent.Funcname != nil && len(parent.Funcname.Items) > 0 {
231+
if nameStr, ok := parent.Funcname.Items[0].(*ast.String); ok {
232+
if nameStr.Str == "sqlc.narg" || nameStr.Str == "sqlc.arg" {
233+
// sqlc parameter functions in isolation - poor for type inference
234+
return 30
235+
}
236+
}
237+
}
238+
return 40
239+
240+
case *ast.ResTarget:
241+
// SELECT target or similar - can be good for type inference
242+
return 60
243+
244+
case *ast.In:
245+
// IN expression - good for type inference
246+
return 70
247+
248+
case *limitCount, *limitOffset:
249+
// LIMIT/OFFSET - known to be integer, good for type inference
250+
return 90
251+
252+
default:
253+
// Unknown context - assign low score
254+
return 10
255+
}
256+
}
257+
168258
func uniqueParamRefs(in []paramRef, dollar bool) []paramRef {
169-
m := make(map[int]bool, len(in))
170-
o := make([]paramRef, 0, len(in))
259+
// Group parameter references by their number
260+
paramGroups := make(map[int][]paramRef)
171261
for _, v := range in {
172-
if !m[v.ref.Number] {
173-
m[v.ref.Number] = true
174-
if v.ref.Number != 0 {
175-
o = append(o, v)
262+
if v.ref.Number != 0 {
263+
paramGroups[v.ref.Number] = append(paramGroups[v.ref.Number], v)
264+
}
265+
}
266+
267+
// For each parameter number, select the reference with the best type inference context
268+
o := make([]paramRef, 0, len(paramGroups))
269+
for _, refs := range paramGroups {
270+
if len(refs) == 1 {
271+
// Only one reference, use it
272+
o = append(o, refs[0])
273+
} else {
274+
// Multiple references, select the one with the highest score
275+
bestRef := refs[0]
276+
bestScore := scoreParamRefForTypeInference(refs[0])
277+
278+
for _, ref := range refs[1:] {
279+
score := scoreParamRefForTypeInference(ref)
280+
if score > bestScore {
281+
bestScore = score
282+
bestRef = ref
283+
}
176284
}
285+
o = append(o, bestRef)
177286
}
178287
}
288+
289+
// Handle unnamed parameters (number == 0) for non-dollar parameter styles
179290
if !dollar {
180291
start := 1
292+
usedNumbers := make(map[int]bool)
293+
for _, v := range o {
294+
usedNumbers[v.ref.Number] = true
295+
}
296+
181297
for _, v := range in {
182298
if v.ref.Number == 0 {
183-
for m[start] {
299+
for usedNumbers[start] {
184300
start++
185301
}
186302
v.ref.Number = start
303+
usedNumbers[start] = true
187304
o = append(o, v)
188305
}
189306
}
190307
}
308+
191309
return o
192310
}
193311

0 commit comments

Comments
 (0)