1+ use std:: collections:: HashSet ;
2+
13use syn:: spanned:: Spanned ;
24
35use crate :: parse:: ParsedConditionalQueryAs ;
@@ -13,6 +15,16 @@ pub enum AnalyzeError {
1315 values : usize ,
1416 values_span : proc_macro2:: Span ,
1517 } ,
18+ #[ error( "found two compile-time bindings with the same binding: {first}" ) ]
19+ DuplicatedCompileTimeBindingsFound {
20+ first : proc_macro2:: Ident ,
21+ second : proc_macro2:: Ident ,
22+ } ,
23+ #[ error( "found cycle in compile-time bindings: {path}" ) ]
24+ CompileTimeBindingCycleDetected {
25+ root_ident : proc_macro2:: Ident ,
26+ path : String ,
27+ } ,
1628}
1729
1830/// This represents the finished second step in the processing pipeline.
@@ -45,12 +57,26 @@ pub(crate) fn analyze(
4557) -> Result < AnalyzedConditionalQueryAs , AnalyzeError > {
4658 let mut compile_time_bindings = Vec :: new ( ) ;
4759
60+ let mut known_binding_names = HashSet :: new ( ) ;
61+
4862 for ( names, match_expr) in parsed. compile_time_bindings {
4963 let binding_names_span = names. span ( ) ;
5064 // Convert the OneOrPunctuated enum in a list of `Ident`s.
5165 // `One(T)` will be converted into a Vec with a single entry.
5266 let binding_names: Vec < _ > = names. into_iter ( ) . collect ( ) ;
5367
68+ // Find duplicate compile-time bindings.
69+ for name in & binding_names {
70+ let Some ( first) = known_binding_names. get ( name) else {
71+ known_binding_names. insert ( name. clone ( ) ) ;
72+ continue ;
73+ } ;
74+ return Err ( AnalyzeError :: DuplicatedCompileTimeBindingsFound {
75+ first : first. clone ( ) ,
76+ second : name. clone ( ) ,
77+ } ) ;
78+ }
79+
5480 let mut bindings = Vec :: new ( ) ;
5581 for arm in match_expr. arms {
5682 let arm_span = arm. body . span ( ) ;
@@ -108,13 +134,87 @@ pub(crate) fn analyze(
108134 } ) ;
109135 }
110136
137+ compile_time_bindings:: validate_compile_time_bindings ( & compile_time_bindings) ?;
138+
111139 Ok ( AnalyzedConditionalQueryAs {
112140 output_type : parsed. output_type ,
113141 query_string : parsed. query_string ,
114142 compile_time_bindings,
115143 } )
116144}
117145
146+ mod compile_time_bindings {
147+ use std:: collections:: { HashMap , HashSet } ;
148+
149+ use super :: { AnalyzeError , CompileTimeBinding } ;
150+
151+ pub ( super ) fn validate_compile_time_bindings (
152+ compile_time_bindings : & [ CompileTimeBinding ] ,
153+ ) -> Result < ( ) , AnalyzeError > {
154+ let mut bindings = HashMap :: new ( ) ;
155+
156+ for ( _, binding_values) in compile_time_bindings
157+ . iter ( )
158+ . flat_map ( |bindings| & bindings. arms )
159+ {
160+ for ( binding, value) in binding_values {
161+ let name = binding. to_string ( ) ;
162+
163+ let ( _, references) = bindings
164+ . entry ( name)
165+ . or_insert_with ( || ( binding, HashSet :: new ( ) ) ) ;
166+ fill_references ( references, & value. value ( ) ) ;
167+ }
168+ }
169+
170+ for ( name, ( ident, _) ) in & bindings {
171+ validate_references ( & bindings, ident, & [ ] , name) ?;
172+ }
173+
174+ Ok ( ( ) )
175+ }
176+
177+ fn fill_references ( references : & mut HashSet < String > , mut fragment : & str ) {
178+ while let Some ( start_idx) = fragment. find ( "{#" ) {
179+ fragment = & fragment[ start_idx + 2 ..] ;
180+ if let Some ( end_idx) = fragment. find ( "}" ) {
181+ references. insert ( fragment[ ..end_idx] . to_string ( ) ) ;
182+ fragment = & fragment[ end_idx + 1 ..] ;
183+ } else {
184+ break ;
185+ }
186+ }
187+ }
188+
189+ fn validate_references (
190+ bindings : & HashMap < String , ( & syn:: Ident , HashSet < String > ) > ,
191+ root_ident : & syn:: Ident ,
192+ path : & [ & str ] ,
193+ name : & str ,
194+ ) -> Result < ( ) , AnalyzeError > {
195+ let mut path = path. to_vec ( ) ;
196+ path. push ( name) ;
197+
198+ if path. iter ( ) . filter ( |component| * * component == name) . count ( ) > 1 {
199+ return Err ( AnalyzeError :: CompileTimeBindingCycleDetected {
200+ root_ident : root_ident. clone ( ) ,
201+ path : path. join ( " -> " ) ,
202+ } ) ;
203+ }
204+
205+ let Some ( ( _, references) ) = bindings. get ( name) else {
206+ // This error is caught and handled in all contexts in the expand stage.
207+ return Ok ( ( ) ) ;
208+ } ;
209+
210+ for reference in references {
211+ validate_references ( bindings, root_ident, & path, reference) ?;
212+ }
213+
214+ Ok ( ( ) )
215+ }
216+ }
217+
118218#[ cfg( test) ]
119219mod tests {
120220 use quote:: ToTokens ;
@@ -209,4 +309,50 @@ mod tests {
209309 }
210310 }
211311 }
312+
313+ #[ test]
314+ fn duplicate_compile_time_bindings ( ) {
315+ let parsed = syn:: parse_str :: < ParsedConditionalQueryAs > (
316+ r##"
317+ SomeType,
318+ r#"{#a}"#,
319+ #a = match _ {
320+ _ => "1",
321+ },
322+ #a = match _ {
323+ _ => "2",
324+ },
325+ "## ,
326+ )
327+ . unwrap ( ) ;
328+ let analyzed = analyze ( parsed. clone ( ) ) . unwrap_err ( ) ;
329+
330+ assert ! ( matches!(
331+ analyzed,
332+ AnalyzeError :: DuplicatedCompileTimeBindingsFound { .. }
333+ ) ) ;
334+ }
335+
336+ #[ test]
337+ fn compile_time_binding_cycle_detected ( ) {
338+ let parsed = syn:: parse_str :: < ParsedConditionalQueryAs > (
339+ r##"
340+ SomeType,
341+ r#"{#a}"#,
342+ #a = match _ {
343+ _ => "{#b}",
344+ },
345+ #b = match _ {
346+ _ => "{#a}",
347+ },
348+ "## ,
349+ )
350+ . unwrap ( ) ;
351+ let analyzed = analyze ( parsed. clone ( ) ) . unwrap_err ( ) ;
352+
353+ assert ! ( matches!(
354+ analyzed,
355+ AnalyzeError :: CompileTimeBindingCycleDetected { .. }
356+ ) ) ;
357+ }
212358}
0 commit comments