Skip to content

Commit 9552a6a

Browse files
committed
allow more loops to be converted to iter
1 parent f711a7c commit 9552a6a

File tree

10 files changed

+219
-108
lines changed

10 files changed

+219
-108
lines changed

lints/for_each/src/lib.rs

Lines changed: 14 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,24 @@
33
#![feature(let_chains)]
44

55
extern crate rustc_errors;
6+
extern crate rustc_hash;
67
extern crate rustc_hir;
8+
extern crate rustc_hir_typeck;
9+
extern crate rustc_infer;
10+
extern crate rustc_middle;
711

8-
use clippy_utils::{higher::ForLoop, ty::is_copy};
12+
mod variable_check;
13+
14+
use clippy_utils::higher::ForLoop;
915
use rustc_errors::Applicability;
1016
use rustc_hir::{
11-
def::Res,
1217
intravisit::{walk_expr, Visitor},
13-
Expr, ExprKind, HirId, Node,
18+
Expr, ExprKind,
1419
};
1520
use rustc_lint::{LateContext, LateLintPass, LintContext};
1621
use utils::span_to_snippet_macro;
22+
use variable_check::check_variables;
23+
1724
dylint_linting::declare_late_lint! {
1825
/// ### What it does
1926
/// parallelize iterators using rayon
@@ -48,16 +55,9 @@ impl<'tcx> LateLintPass<'tcx> for ForEach {
4855
let src_map = cx.sess().source_map();
4956

5057
// Make sure we ignore cases that require a try_foreach
51-
let mut validator = Validator {
52-
is_valid: true,
53-
is_arg: true,
54-
arg_variables: vec![],
55-
cx,
56-
};
57-
validator.visit_expr(arg);
58-
validator.is_arg = false;
58+
let mut validator = Validator { is_valid: true };
5959
validator.visit_expr(body);
60-
if !validator.is_valid {
60+
if !validator.is_valid || !check_variables(cx, body) {
6161
return;
6262
}
6363
// Check whether the iter is explicit
@@ -122,38 +122,17 @@ impl Visitor<'_> for IterExplorer {
122122
}
123123
}
124124

125-
struct Validator<'a, 'tcx> {
125+
struct Validator {
126126
is_valid: bool,
127-
cx: &'a LateContext<'tcx>,
128-
is_arg: bool,
129-
arg_variables: Vec<HirId>,
130127
}
131128

132-
impl<'a, 'tcx> Visitor<'_> for Validator<'a, 'tcx> {
129+
impl Visitor<'_> for Validator {
133130
fn visit_expr(&mut self, ex: &Expr) {
134131
match &ex.kind {
135132
ExprKind::Loop(_, _, _, _)
136133
| ExprKind::Closure(_)
137134
| ExprKind::Ret(_)
138135
| ExprKind::Break(_, _) => self.is_valid = false,
139-
ExprKind::Path(ref path) => {
140-
if let Res::Local(hir_id) = self.cx.typeck_results().qpath_res(path, ex.hir_id) {
141-
if let Node::Local(local) = self.cx.tcx.parent_hir_node(hir_id) {
142-
if self.is_arg {
143-
self.arg_variables.push(local.hir_id);
144-
return;
145-
}
146-
147-
if let Some(expr) = local.init
148-
&& !self.arg_variables.contains(&local.hir_id)
149-
{
150-
let ty = self.cx.typeck_results().expr_ty(expr);
151-
self.is_valid &= is_copy(self.cx, ty)
152-
}
153-
}
154-
}
155-
}
156-
157136
_ => walk_expr(self, ex),
158137
}
159138
}

lints/for_each/src/variable_check.rs

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
use clippy_utils::visitors::for_each_expr_with_closures;
2+
use rustc_hash::FxHashSet;
3+
use rustc_hir as hir;
4+
use rustc_hir_typeck::expr_use_visitor as euv;
5+
use rustc_infer::infer::{InferCtxt, TyCtxtInferExt};
6+
use rustc_lint::LateContext;
7+
use rustc_middle::{
8+
hir::map::associated_body,
9+
mir::FakeReadCause,
10+
ty::{self, Ty, UpvarId, UpvarPath},
11+
};
12+
use std::{collections::HashSet, ops::ControlFlow};
13+
14+
pub struct MutablyUsedVariablesCtxt<'tcx> {
15+
all_vars: FxHashSet<Ty<'tcx>>,
16+
prev_bind: Option<hir::HirId>,
17+
/// In async functions, the inner AST is composed of multiple layers until we reach the code
18+
/// defined by the user. Because of that, some variables are marked as mutably borrowed even
19+
/// though they're not. This field lists the `HirId` that should not be considered as mutable
20+
/// use of a variable.
21+
prev_move_to_closure: hir::HirIdSet,
22+
}
23+
24+
// TODO: remove repetation is this two function almost identical
25+
pub fn check_variables<'tcx>(cx: &LateContext<'tcx>, ex: &'tcx hir::Expr) -> bool {
26+
let MutablyUsedVariablesCtxt { all_vars, .. } = {
27+
let body_owner = ex.hir_id.owner.def_id;
28+
29+
let mut ctx = MutablyUsedVariablesCtxt {
30+
all_vars: FxHashSet::default(),
31+
prev_bind: None,
32+
prev_move_to_closure: hir::HirIdSet::default(),
33+
};
34+
let infcx = cx.tcx.infer_ctxt().build();
35+
euv::ExprUseVisitor::new(
36+
&mut ctx,
37+
&infcx,
38+
body_owner,
39+
cx.param_env,
40+
cx.typeck_results(),
41+
)
42+
.walk_expr(ex);
43+
44+
let mut checked_closures = FxHashSet::default();
45+
46+
// We retrieve all the closures declared in the function because they will not be found
47+
// by `euv::Delegate`.
48+
let mut closures: FxHashSet<hir::def_id::LocalDefId> = FxHashSet::default();
49+
for_each_expr_with_closures(cx, ex, |expr| {
50+
if let hir::ExprKind::Closure(closure) = expr.kind {
51+
closures.insert(closure.def_id);
52+
}
53+
ControlFlow::<()>::Continue(())
54+
});
55+
check_closures(&mut ctx, cx, &infcx, &mut checked_closures, closures);
56+
57+
ctx
58+
};
59+
60+
all_vars.is_empty()
61+
}
62+
63+
pub fn check_closures<'tcx, S: ::std::hash::BuildHasher>(
64+
ctx: &mut MutablyUsedVariablesCtxt<'tcx>,
65+
cx: &LateContext<'tcx>,
66+
infcx: &InferCtxt<'tcx>,
67+
checked_closures: &mut HashSet<hir::def_id::LocalDefId, S>,
68+
closures: HashSet<hir::def_id::LocalDefId, S>,
69+
) {
70+
let hir = cx.tcx.hir();
71+
for closure in closures {
72+
if !checked_closures.insert(closure) {
73+
continue;
74+
}
75+
ctx.prev_bind = None;
76+
ctx.prev_move_to_closure.clear();
77+
if let Some(body) = cx
78+
.tcx
79+
.opt_hir_node_by_def_id(closure)
80+
.and_then(associated_body)
81+
.map(|(_, body_id)| hir.body(body_id))
82+
{
83+
euv::ExprUseVisitor::new(ctx, infcx, closure, cx.param_env, cx.typeck_results())
84+
.consume_body(body);
85+
}
86+
}
87+
}
88+
89+
impl<'tcx> euv::Delegate<'tcx> for MutablyUsedVariablesCtxt<'tcx> {
90+
#[allow(clippy::if_same_then_else)]
91+
fn consume(&mut self, cmt: &euv::PlaceWithHirId<'tcx>, _: hir::HirId) {
92+
if let euv::Place {
93+
base:
94+
euv::PlaceBase::Local(_)
95+
| euv::PlaceBase::Upvar(UpvarId {
96+
var_path: UpvarPath { hir_id: _ },
97+
..
98+
}),
99+
base_ty,
100+
..
101+
} = &cmt.place
102+
{
103+
self.all_vars.insert(*base_ty);
104+
}
105+
}
106+
107+
#[allow(clippy::if_same_then_else)]
108+
fn borrow(&mut self, _: &euv::PlaceWithHirId<'tcx>, _: hir::HirId, _: ty::BorrowKind) {}
109+
110+
fn mutate(&mut self, _: &euv::PlaceWithHirId<'tcx>, _id: hir::HirId) {}
111+
fn copy(&mut self, _: &euv::PlaceWithHirId<'tcx>, _: hir::HirId) {}
112+
fn fake_read(
113+
&mut self,
114+
_: &rustc_hir_typeck::expr_use_visitor::PlaceWithHirId<'tcx>,
115+
_: FakeReadCause,
116+
_id: hir::HirId,
117+
) {
118+
}
119+
120+
fn bind(&mut self, _: &euv::PlaceWithHirId<'tcx>, id: hir::HirId) {
121+
self.prev_bind = Some(id);
122+
}
123+
}

lints/for_each/ui/main.fixed

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,14 @@ impl MyBuilder {
2020
}
2121
}
2222

23+
struct LocalQueue {}
24+
25+
impl LocalQueue {
26+
fn new() -> Self {
27+
Self {}
28+
}
29+
}
30+
2331
// no
2432
fn build_request_builder() {
2533
let headers = vec![("Key1", "Value1"), ("Key2", "Value2")];
@@ -73,6 +81,7 @@ fn nested_loop() {
7381
}
7482
}
7583

84+
// for_each
7685
fn get_upload_file_total_size() -> u64 {
7786
let some_num = vec![0; 10];
7887
let mut file_total_size = 0;
@@ -95,4 +104,14 @@ fn return_loop() {
95104
}
96105
}
97106

107+
// for_each
108+
fn local_into_iter() {
109+
let thread_num = 10;
110+
let mut locals = vec![];
111+
112+
(0..thread_num).into_iter().for_each(|_| {
113+
locals.push(LocalQueue::new());
114+
});
115+
}
116+
98117
// TODO: double capture

lints/for_each/ui/main.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,14 @@ impl MyBuilder {
2020
}
2121
}
2222

23+
struct LocalQueue {}
24+
25+
impl LocalQueue {
26+
fn new() -> Self {
27+
Self {}
28+
}
29+
}
30+
2331
// no
2432
fn build_request_builder() {
2533
let headers = vec![("Key1", "Value1"), ("Key2", "Value2")];
@@ -73,6 +81,7 @@ fn nested_loop() {
7381
}
7482
}
7583

84+
// for_each
7685
fn get_upload_file_total_size() -> u64 {
7786
let some_num = vec![0; 10];
7887
let mut file_total_size = 0;
@@ -95,4 +104,14 @@ fn return_loop() {
95104
}
96105
}
97106

107+
// for_each
108+
fn local_into_iter() {
109+
let thread_num = 10;
110+
let mut locals = vec![];
111+
112+
for _ in 0..thread_num {
113+
locals.push(LocalQueue::new());
114+
}
115+
}
116+
98117
// TODO: double capture

lints/for_each/ui/main.stderr

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
warning: use a for_each to enable iterator refinement
2-
--> $DIR/main.rs:35:5
2+
--> $DIR/main.rs:43:5
33
|
44
LL | / for x in 1..=100 {
55
LL | | println!("{x}");
@@ -15,7 +15,7 @@ LL + });
1515
|
1616

1717
warning: use a for_each to enable iterator refinement
18-
--> $DIR/main.rs:44:5
18+
--> $DIR/main.rs:52:5
1919
|
2020
LL | / for a in vec_a {
2121
LL | | if a == 1 {
@@ -36,7 +36,7 @@ LL + });
3636
|
3737

3838
warning: use a for_each to enable iterator refinement
39-
--> $DIR/main.rs:70:9
39+
--> $DIR/main.rs:78:9
4040
|
4141
LL | / for b in &vec_b {
4242
LL | | dbg!(a, b);
@@ -51,7 +51,7 @@ LL + });
5151
|
5252

5353
warning: use a for_each to enable iterator refinement
54-
--> $DIR/main.rs:79:5
54+
--> $DIR/main.rs:88:5
5555
|
5656
LL | / for _ in 0..some_num.len() {
5757
LL | | let (_, upload_size) = (true, 99);
@@ -67,5 +67,20 @@ LL + file_total_size += upload_size;
6767
LL + });
6868
|
6969

70-
warning: 4 warnings emitted
70+
warning: use a for_each to enable iterator refinement
71+
--> $DIR/main.rs:112:5
72+
|
73+
LL | / for _ in 0..thread_num {
74+
LL | | locals.push(LocalQueue::new());
75+
LL | | }
76+
| |_____^
77+
|
78+
help: try using `for_each` on the iterator
79+
|
80+
LL ~ (0..thread_num).into_iter().for_each(|_| {
81+
LL + locals.push(LocalQueue::new());
82+
LL + });
83+
|
84+
85+
warning: 5 warnings emitted
7186

lints/map/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ rustc_private = true
2525
rlib = ["dylint_linting/constituent"]
2626

2727
[[example]]
28-
name = "fold_main"
28+
name = "map_main"
2929
path = "ui/main.rs"
3030

3131
[lints.clippy]

utils/src/constants.rs renamed to lints/par_iter/src/constants.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
pub const TRAIT_PATHS: &[&[&str]] = &[
1+
pub(crate) const TRAIT_PATHS: &[&[&str]] = &[
22
&["rayon", "iter", "IntoParallelIterator"],
33
&["rayon", "iter", "ParallelIterator"],
44
&["rayon", "iter", "IndexedParallelIterator"],

lints/par_iter/src/lib.rs

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,16 @@
44

55
extern crate rustc_data_structures;
66
extern crate rustc_errors;
7+
extern crate rustc_hash;
78
extern crate rustc_hir;
8-
9+
extern crate rustc_hir_typeck;
10+
extern crate rustc_infer;
911
extern crate rustc_middle;
1012
extern crate rustc_span;
13+
extern crate rustc_trait_selection;
14+
15+
mod constants;
16+
mod variable_check;
1117

1218
use clippy_utils::{get_parent_expr, get_trait_def_id};
1319
use rustc_data_structures::fx::FxHashSet;
@@ -17,7 +23,7 @@ use rustc_hir::{self as hir};
1723
use rustc_lint::{LateContext, LateLintPass, LintContext};
1824
use rustc_middle::ty::{self, Ty};
1925
use rustc_span::sym;
20-
use utils::variable_check::{
26+
use variable_check::{
2127
check_implements_par_iter, check_trait_impl, check_variables, generate_suggestion,
2228
is_type_valid,
2329
};

0 commit comments

Comments
 (0)