Skip to content

Commit 094328c

Browse files
authored
Merge pull request #26 from kyrias/catch-compile-time-binding-duplicates-and-cycles
Catch compile-time binding duplicates and cycles
2 parents fd9e21b + 70d98c5 commit 094328c

File tree

2 files changed

+154
-0
lines changed

2 files changed

+154
-0
lines changed

core/src/analyze.rs

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
use std::collections::HashSet;
2+
13
use syn::spanned::Spanned;
24

35
use crate::parse::ParsedConditionalQueryAs;
@@ -13,6 +15,16 @@ pub enum AnalyzeError {
1315
values: usize,
1416
values_span: proc_macro2::Span,
1517
},
18+
#[error("found two compile-time bindings with the same binding: {first}")]
19+
DuplicatedCompileTimeBindingsFound {
20+
first: proc_macro2::Ident,
21+
second: proc_macro2::Ident,
22+
},
23+
#[error("found cycle in compile-time bindings: {path}")]
24+
CompileTimeBindingCycleDetected {
25+
root_ident: proc_macro2::Ident,
26+
path: String,
27+
},
1628
}
1729

1830
/// This represents the finished second step in the processing pipeline.
@@ -45,12 +57,26 @@ pub(crate) fn analyze(
4557
) -> Result<AnalyzedConditionalQueryAs, AnalyzeError> {
4658
let mut compile_time_bindings = Vec::new();
4759

60+
let mut known_binding_names = HashSet::new();
61+
4862
for (names, match_expr) in parsed.compile_time_bindings {
4963
let binding_names_span = names.span();
5064
// Convert the OneOrPunctuated enum in a list of `Ident`s.
5165
// `One(T)` will be converted into a Vec with a single entry.
5266
let binding_names: Vec<_> = names.into_iter().collect();
5367

68+
// Find duplicate compile-time bindings.
69+
for name in &binding_names {
70+
let Some(first) = known_binding_names.get(name) else {
71+
known_binding_names.insert(name.clone());
72+
continue;
73+
};
74+
return Err(AnalyzeError::DuplicatedCompileTimeBindingsFound {
75+
first: first.clone(),
76+
second: name.clone(),
77+
});
78+
}
79+
5480
let mut bindings = Vec::new();
5581
for arm in match_expr.arms {
5682
let arm_span = arm.body.span();
@@ -108,13 +134,87 @@ pub(crate) fn analyze(
108134
});
109135
}
110136

137+
compile_time_bindings::validate_compile_time_bindings(&compile_time_bindings)?;
138+
111139
Ok(AnalyzedConditionalQueryAs {
112140
output_type: parsed.output_type,
113141
query_string: parsed.query_string,
114142
compile_time_bindings,
115143
})
116144
}
117145

146+
mod compile_time_bindings {
147+
use std::collections::{HashMap, HashSet};
148+
149+
use super::{AnalyzeError, CompileTimeBinding};
150+
151+
pub(super) fn validate_compile_time_bindings(
152+
compile_time_bindings: &[CompileTimeBinding],
153+
) -> Result<(), AnalyzeError> {
154+
let mut bindings = HashMap::new();
155+
156+
for (_, binding_values) in compile_time_bindings
157+
.iter()
158+
.flat_map(|bindings| &bindings.arms)
159+
{
160+
for (binding, value) in binding_values {
161+
let name = binding.to_string();
162+
163+
let (_, references) = bindings
164+
.entry(name)
165+
.or_insert_with(|| (binding, HashSet::new()));
166+
fill_references(references, &value.value());
167+
}
168+
}
169+
170+
for (name, (ident, _)) in &bindings {
171+
validate_references(&bindings, ident, &[], name)?;
172+
}
173+
174+
Ok(())
175+
}
176+
177+
fn fill_references(references: &mut HashSet<String>, mut fragment: &str) {
178+
while let Some(start_idx) = fragment.find("{#") {
179+
fragment = &fragment[start_idx + 2..];
180+
if let Some(end_idx) = fragment.find("}") {
181+
references.insert(fragment[..end_idx].to_string());
182+
fragment = &fragment[end_idx + 1..];
183+
} else {
184+
break;
185+
}
186+
}
187+
}
188+
189+
fn validate_references(
190+
bindings: &HashMap<String, (&syn::Ident, HashSet<String>)>,
191+
root_ident: &syn::Ident,
192+
path: &[&str],
193+
name: &str,
194+
) -> Result<(), AnalyzeError> {
195+
let mut path = path.to_vec();
196+
path.push(name);
197+
198+
if path.iter().filter(|component| **component == name).count() > 1 {
199+
return Err(AnalyzeError::CompileTimeBindingCycleDetected {
200+
root_ident: root_ident.clone(),
201+
path: path.join(" -> "),
202+
});
203+
}
204+
205+
let Some((_, references)) = bindings.get(name) else {
206+
// This error is caught and handled in all contexts in the expand stage.
207+
return Ok(());
208+
};
209+
210+
for reference in references {
211+
validate_references(bindings, root_ident, &path, reference)?;
212+
}
213+
214+
Ok(())
215+
}
216+
}
217+
118218
#[cfg(test)]
119219
mod tests {
120220
use quote::ToTokens;
@@ -209,4 +309,50 @@ mod tests {
209309
}
210310
}
211311
}
312+
313+
#[test]
314+
fn duplicate_compile_time_bindings() {
315+
let parsed = syn::parse_str::<ParsedConditionalQueryAs>(
316+
r##"
317+
SomeType,
318+
r#"{#a}"#,
319+
#a = match _ {
320+
_ => "1",
321+
},
322+
#a = match _ {
323+
_ => "2",
324+
},
325+
"##,
326+
)
327+
.unwrap();
328+
let analyzed = analyze(parsed.clone()).unwrap_err();
329+
330+
assert!(matches!(
331+
analyzed,
332+
AnalyzeError::DuplicatedCompileTimeBindingsFound { .. }
333+
));
334+
}
335+
336+
#[test]
337+
fn compile_time_binding_cycle_detected() {
338+
let parsed = syn::parse_str::<ParsedConditionalQueryAs>(
339+
r##"
340+
SomeType,
341+
r#"{#a}"#,
342+
#a = match _ {
343+
_ => "{#b}",
344+
},
345+
#b = match _ {
346+
_ => "{#a}",
347+
},
348+
"##,
349+
)
350+
.unwrap();
351+
let analyzed = analyze(parsed.clone()).unwrap_err();
352+
353+
assert!(matches!(
354+
analyzed,
355+
AnalyzeError::CompileTimeBindingCycleDetected { .. }
356+
));
357+
}
212358
}

macros/src/lib.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,14 @@ pub fn conditional_query_as(input: proc_macro::TokenStream) -> proc_macro::Token
4141
names = names_span => "number of names: {}", names;
4242
values = values_span => "number of values: {}", values;
4343
),
44+
AnalyzeError::DuplicatedCompileTimeBindingsFound { first: _, second } => {
45+
abort!(second.span(), "found duplicate compile-time binding")
46+
}
47+
AnalyzeError::CompileTimeBindingCycleDetected { root_ident, path } => abort!(
48+
root_ident.span(),
49+
"detected compile-time binding cycle: {}",
50+
path
51+
),
4452
},
4553
Err(Error::ExpandError(err)) => match err {
4654
// TODO: Make this span point at the binding reference. Requires https://github.com/rust-lang/rust/issues/54725

0 commit comments

Comments
 (0)