@@ -20,6 +20,11 @@ pub enum AnalyzeError {
2020 first : proc_macro2:: Ident ,
2121 second : proc_macro2:: Ident ,
2222 } ,
23+ #[ error( "found cycle in compile-time bindings: {path}" ) ]
24+ CompileTimeBindingCycleDetected {
25+ root_ident : proc_macro2:: Ident ,
26+ path : String ,
27+ } ,
2328}
2429
2530/// This represents the finished second step in the processing pipeline.
@@ -129,13 +134,87 @@ pub(crate) fn analyze(
129134 } ) ;
130135 }
131136
137+ compile_time_bindings:: validate_compile_time_bindings ( & compile_time_bindings) ?;
138+
132139 Ok ( AnalyzedConditionalQueryAs {
133140 output_type : parsed. output_type ,
134141 query_string : parsed. query_string ,
135142 compile_time_bindings,
136143 } )
137144}
138145
146+ mod compile_time_bindings {
147+ use std:: collections:: { HashMap , HashSet } ;
148+
149+ use super :: { AnalyzeError , CompileTimeBinding } ;
150+
151+ pub ( super ) fn validate_compile_time_bindings (
152+ compile_time_bindings : & [ CompileTimeBinding ] ,
153+ ) -> Result < ( ) , AnalyzeError > {
154+ let mut bindings = HashMap :: new ( ) ;
155+
156+ for ( _, binding_values) in compile_time_bindings
157+ . iter ( )
158+ . flat_map ( |bindings| & bindings. arms )
159+ {
160+ for ( binding, value) in binding_values {
161+ let name = binding. to_string ( ) ;
162+
163+ let ( _, references) = bindings
164+ . entry ( name)
165+ . or_insert_with ( || ( binding, HashSet :: new ( ) ) ) ;
166+ fill_references ( references, & value. value ( ) ) ;
167+ }
168+ }
169+
170+ for ( name, ( ident, _) ) in & bindings {
171+ validate_references ( & bindings, ident, & [ ] , name) ?;
172+ }
173+
174+ Ok ( ( ) )
175+ }
176+
177+ fn fill_references ( references : & mut HashSet < String > , mut fragment : & str ) {
178+ while let Some ( start_idx) = fragment. find ( "{#" ) {
179+ fragment = & fragment[ start_idx + 2 ..] ;
180+ if let Some ( end_idx) = fragment. find ( "}" ) {
181+ references. insert ( fragment[ ..end_idx] . to_string ( ) ) ;
182+ fragment = & fragment[ end_idx + 1 ..] ;
183+ } else {
184+ break ;
185+ }
186+ }
187+ }
188+
189+ fn validate_references (
190+ bindings : & HashMap < String , ( & syn:: Ident , HashSet < String > ) > ,
191+ root_ident : & syn:: Ident ,
192+ path : & [ & str ] ,
193+ name : & str ,
194+ ) -> Result < ( ) , AnalyzeError > {
195+ let mut path = path. to_vec ( ) ;
196+ path. push ( name) ;
197+
198+ if path. iter ( ) . filter ( |component| * * component == name) . count ( ) > 1 {
199+ return Err ( AnalyzeError :: CompileTimeBindingCycleDetected {
200+ root_ident : root_ident. clone ( ) ,
201+ path : path. join ( " -> " ) ,
202+ } ) ;
203+ }
204+
205+ let Some ( ( _, references) ) = bindings. get ( name) else {
206+ // This error is caught and handled in all contexts in the expand stage.
207+ return Ok ( ( ) ) ;
208+ } ;
209+
210+ for reference in references {
211+ validate_references ( bindings, root_ident, & path, reference) ?;
212+ }
213+
214+ Ok ( ( ) )
215+ }
216+ }
217+
139218#[ cfg( test) ]
140219mod tests {
141220 use quote:: ToTokens ;
@@ -253,4 +332,27 @@ mod tests {
253332 AnalyzeError :: DuplicatedCompileTimeBindingsFound { .. }
254333 ) ) ;
255334 }
335+
336+ #[ test]
337+ fn compile_time_binding_cycle_detected ( ) {
338+ let parsed = syn:: parse_str :: < ParsedConditionalQueryAs > (
339+ r##"
340+ SomeType,
341+ r#"{#a}"#,
342+ #a = match _ {
343+ _ => "{#b}",
344+ },
345+ #b = match _ {
346+ _ => "{#a}",
347+ },
348+ "## ,
349+ )
350+ . unwrap ( ) ;
351+ let analyzed = analyze ( parsed. clone ( ) ) . unwrap_err ( ) ;
352+
353+ assert ! ( matches!(
354+ analyzed,
355+ AnalyzeError :: CompileTimeBindingCycleDetected { .. }
356+ ) ) ;
357+ }
256358}
0 commit comments