Skip to content

Commit 70d98c5

Browse files
committed
Detect compile-time binding cycles
Signed-off-by: Johannes Löthberg <[email protected]>
1 parent a02f5b9 commit 70d98c5

File tree

2 files changed

+107
-0
lines changed

2 files changed

+107
-0
lines changed

core/src/analyze.rs

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,11 @@ pub enum AnalyzeError {
2020
first: proc_macro2::Ident,
2121
second: proc_macro2::Ident,
2222
},
23+
#[error("found cycle in compile-time bindings: {path}")]
24+
CompileTimeBindingCycleDetected {
25+
root_ident: proc_macro2::Ident,
26+
path: String,
27+
},
2328
}
2429

2530
/// This represents the finished second step in the processing pipeline.
@@ -129,13 +134,87 @@ pub(crate) fn analyze(
129134
});
130135
}
131136

137+
compile_time_bindings::validate_compile_time_bindings(&compile_time_bindings)?;
138+
132139
Ok(AnalyzedConditionalQueryAs {
133140
output_type: parsed.output_type,
134141
query_string: parsed.query_string,
135142
compile_time_bindings,
136143
})
137144
}
138145

146+
mod compile_time_bindings {
147+
use std::collections::{HashMap, HashSet};
148+
149+
use super::{AnalyzeError, CompileTimeBinding};
150+
151+
pub(super) fn validate_compile_time_bindings(
152+
compile_time_bindings: &[CompileTimeBinding],
153+
) -> Result<(), AnalyzeError> {
154+
let mut bindings = HashMap::new();
155+
156+
for (_, binding_values) in compile_time_bindings
157+
.iter()
158+
.flat_map(|bindings| &bindings.arms)
159+
{
160+
for (binding, value) in binding_values {
161+
let name = binding.to_string();
162+
163+
let (_, references) = bindings
164+
.entry(name)
165+
.or_insert_with(|| (binding, HashSet::new()));
166+
fill_references(references, &value.value());
167+
}
168+
}
169+
170+
for (name, (ident, _)) in &bindings {
171+
validate_references(&bindings, ident, &[], name)?;
172+
}
173+
174+
Ok(())
175+
}
176+
177+
fn fill_references(references: &mut HashSet<String>, mut fragment: &str) {
178+
while let Some(start_idx) = fragment.find("{#") {
179+
fragment = &fragment[start_idx + 2..];
180+
if let Some(end_idx) = fragment.find("}") {
181+
references.insert(fragment[..end_idx].to_string());
182+
fragment = &fragment[end_idx + 1..];
183+
} else {
184+
break;
185+
}
186+
}
187+
}
188+
189+
fn validate_references(
190+
bindings: &HashMap<String, (&syn::Ident, HashSet<String>)>,
191+
root_ident: &syn::Ident,
192+
path: &[&str],
193+
name: &str,
194+
) -> Result<(), AnalyzeError> {
195+
let mut path = path.to_vec();
196+
path.push(name);
197+
198+
if path.iter().filter(|component| **component == name).count() > 1 {
199+
return Err(AnalyzeError::CompileTimeBindingCycleDetected {
200+
root_ident: root_ident.clone(),
201+
path: path.join(" -> "),
202+
});
203+
}
204+
205+
let Some((_, references)) = bindings.get(name) else {
206+
// This error is caught and handled in all contexts in the expand stage.
207+
return Ok(());
208+
};
209+
210+
for reference in references {
211+
validate_references(bindings, root_ident, &path, reference)?;
212+
}
213+
214+
Ok(())
215+
}
216+
}
217+
139218
#[cfg(test)]
140219
mod tests {
141220
use quote::ToTokens;
@@ -253,4 +332,27 @@ mod tests {
253332
AnalyzeError::DuplicatedCompileTimeBindingsFound { .. }
254333
));
255334
}
335+
336+
#[test]
337+
fn compile_time_binding_cycle_detected() {
338+
let parsed = syn::parse_str::<ParsedConditionalQueryAs>(
339+
r##"
340+
SomeType,
341+
r#"{#a}"#,
342+
#a = match _ {
343+
_ => "{#b}",
344+
},
345+
#b = match _ {
346+
_ => "{#a}",
347+
},
348+
"##,
349+
)
350+
.unwrap();
351+
let analyzed = analyze(parsed.clone()).unwrap_err();
352+
353+
assert!(matches!(
354+
analyzed,
355+
AnalyzeError::CompileTimeBindingCycleDetected { .. }
356+
));
357+
}
256358
}

macros/src/lib.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,11 @@ pub fn conditional_query_as(input: proc_macro::TokenStream) -> proc_macro::Token
4444
AnalyzeError::DuplicatedCompileTimeBindingsFound { first: _, second } => {
4545
abort!(second.span(), "found duplicate compile-time binding")
4646
}
47+
AnalyzeError::CompileTimeBindingCycleDetected { root_ident, path } => abort!(
48+
root_ident.span(),
49+
"detected compile-time binding cycle: {}",
50+
path
51+
),
4752
},
4853
Err(Error::ExpandError(err)) => match err {
4954
// TODO: Make this span point at the binding reference. Requires https://github.com/rust-lang/rust/issues/54725

0 commit comments

Comments
 (0)