Skip to content

perf: Merge done_cache and active_cache in ObligationForest #67892

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 58 additions & 67 deletions src/librustc_data_structures/obligation_forest/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,12 @@

use crate::fx::{FxHashMap, FxHashSet};

use std::cell::{Cell, RefCell};
use std::cell::Cell;
use std::collections::hash_map::Entry;
use std::fmt::Debug;
use std::hash;
use std::marker::PhantomData;
use std::mem;

mod graphviz;

Expand Down Expand Up @@ -127,26 +128,24 @@ struct ObligationTreeId(usize);
type ObligationTreeIdGenerator =
::std::iter::Map<::std::ops::RangeFrom<usize>, fn(usize) -> ObligationTreeId>;

/// `usize` indices are used here and throughout this module, rather than
/// `rustc_index::newtype_index!` indices, because this code is hot enough
/// that the `u32`-to-`usize` conversions that would be required are
/// significant, and space considerations are not important.
type NodeIndex = usize;

pub struct ObligationForest<O: ForestObligation> {
/// The list of obligations. In between calls to `process_obligations`,
/// this list only contains nodes in the `Pending` or `Waiting` state.
///
/// `usize` indices are used here and throughout this module, rather than
/// `rustc_index::newtype_index!` indices, because this code is hot enough
/// that the `u32`-to-`usize` conversions that would be required are
/// significant, and space considerations are not important.
nodes: Vec<Node<O>>,

/// A cache of predicates that have been successfully completed.
done_cache: FxHashSet<O::Predicate>,

/// A cache of the nodes in `nodes`, indexed by predicate. Unfortunately,
/// its contents are not guaranteed to match those of `nodes`. See the
/// comments in `process_obligation` for details.
active_cache: FxHashMap<O::Predicate, usize>,
active_cache: FxHashMap<O::Predicate, Option<NodeIndex>>,

/// A vector reused in compress(), to avoid allocating new vectors.
node_rewrites: RefCell<Vec<usize>>,
node_rewrites: Vec<NodeIndex>,

obligation_tree_id_generator: ObligationTreeIdGenerator,

Expand All @@ -167,12 +166,12 @@ struct Node<O> {

/// Obligations that depend on this obligation for their completion. They
/// must all be in a non-pending state.
dependents: Vec<usize>,
dependents: Vec<NodeIndex>,

/// If true, dependents[0] points to a "parent" node, which requires
/// special treatment upon error but is otherwise treated the same.
/// (It would be more idiomatic to store the parent node in a separate
/// `Option<usize>` field, but that slows down the common case of
/// `Option<NodeIndex>` field, but that slows down the common case of
/// iterating over the parent and other descendants together.)
has_parent: bool,

Expand All @@ -181,7 +180,11 @@ struct Node<O> {
}

impl<O> Node<O> {
fn new(parent: Option<usize>, obligation: O, obligation_tree_id: ObligationTreeId) -> Node<O> {
fn new(
parent: Option<NodeIndex>,
obligation: O,
obligation_tree_id: ObligationTreeId,
) -> Node<O> {
Node {
obligation,
state: Cell::new(NodeState::Pending),
Expand Down Expand Up @@ -283,9 +286,8 @@ impl<O: ForestObligation> ObligationForest<O> {
pub fn new() -> ObligationForest<O> {
ObligationForest {
nodes: vec![],
done_cache: Default::default(),
active_cache: Default::default(),
node_rewrites: RefCell::new(vec![]),
node_rewrites: vec![],
obligation_tree_id_generator: (0..).map(ObligationTreeId),
error_cache: Default::default(),
}
Expand All @@ -304,14 +306,18 @@ impl<O: ForestObligation> ObligationForest<O> {
}

// Returns Err(()) if we already know this obligation failed.
fn register_obligation_at(&mut self, obligation: O, parent: Option<usize>) -> Result<(), ()> {
if self.done_cache.contains(obligation.as_predicate()) {
return Ok(());
}

fn register_obligation_at(
&mut self,
obligation: O,
parent: Option<NodeIndex>,
) -> Result<(), ()> {
match self.active_cache.entry(obligation.as_predicate().clone()) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This now clones the predicate even when it is done. Maybe that is the problem?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looked into it the predicate type, it is always a ty::Predicate (or &str in tests) which is copyable and only takes up 32 bytes which isn't tiny but shouldn't really affect things...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ugh, ok I think it is because active_cache now gets really big, so the retain becomes really expensive as it fills up with None values...

Entry::Occupied(o) => {
let node = &mut self.nodes[*o.get()];
let index = match o.get() {
Some(index) => *index,
None => return Ok(()), // Done!
};
let node = &mut self.nodes[index];
if let Some(parent_index) = parent {
// If the node is already in `active_cache`, it has already
// had its chance to be marked with a parent. So if it's
Expand Down Expand Up @@ -340,7 +346,7 @@ impl<O: ForestObligation> ObligationForest<O> {
Err(())
} else {
let new_index = self.nodes.len();
v.insert(new_index);
v.insert(Some(new_index));
self.nodes.push(Node::new(parent, obligation, obligation_tree_id));
Ok(())
}
Expand Down Expand Up @@ -375,7 +381,7 @@ impl<O: ForestObligation> ObligationForest<O> {
.collect()
}

fn insert_into_error_cache(&mut self, index: usize) {
fn insert_into_error_cache(&mut self, index: NodeIndex) {
let node = &self.nodes[index];
self.error_cache
.entry(node.obligation_tree_id)
Expand Down Expand Up @@ -465,8 +471,8 @@ impl<O: ForestObligation> ObligationForest<O> {

/// Returns a vector of obligations for `p` and all of its
/// ancestors, putting them into the error state in the process.
fn error_at(&self, mut index: usize) -> Vec<O> {
let mut error_stack: Vec<usize> = vec![];
fn error_at(&self, mut index: NodeIndex) -> Vec<O> {
let mut error_stack: Vec<NodeIndex> = vec![];
let mut trace = vec![];

loop {
Expand Down Expand Up @@ -558,8 +564,12 @@ impl<O: ForestObligation> ObligationForest<O> {
debug_assert!(stack.is_empty());
}

fn find_cycles_from_node<P>(&self, stack: &mut Vec<usize>, processor: &mut P, index: usize)
where
fn find_cycles_from_node<P>(
&self,
stack: &mut Vec<NodeIndex>,
processor: &mut P,
index: NodeIndex,
) where
P: ObligationProcessor<Obligation = O>,
{
let node = &self.nodes[index];
Expand All @@ -576,7 +586,7 @@ impl<O: ForestObligation> ObligationForest<O> {
Some(rpos) => {
// Cycle detected.
processor.process_backedge(
stack[rpos..].iter().map(GetObligation(&self.nodes)),
stack[rpos..].iter().map(|i| &self.nodes[*i].obligation),
PhantomData,
);
}
Expand All @@ -590,7 +600,8 @@ impl<O: ForestObligation> ObligationForest<O> {
#[inline(never)]
fn compress(&mut self, do_completed: DoCompleted) -> Option<Vec<O>> {
let orig_nodes_len = self.nodes.len();
let mut node_rewrites: Vec<_> = self.node_rewrites.replace(vec![]);
let remove_node_marker = orig_nodes_len + 1;
let mut node_rewrites: Vec<_> = mem::take(&mut self.node_rewrites);
debug_assert!(node_rewrites.is_empty());
node_rewrites.extend(0..orig_nodes_len);
let mut dead_nodes = 0;
Expand All @@ -613,17 +624,6 @@ impl<O: ForestObligation> ObligationForest<O> {
}
}
NodeState::Done => {
// This lookup can fail because the contents of
// `self.active_cache` are not guaranteed to match those of
// `self.nodes`. See the comment in `process_obligation`
// for more details.
if let Some((predicate, _)) =
self.active_cache.remove_entry(node.obligation.as_predicate())
{
self.done_cache.insert(predicate);
} else {
self.done_cache.insert(node.obligation.as_predicate().clone());
}
if do_completed == DoCompleted::Yes {
// Extract the success stories.
removed_done_obligations.push(node.obligation.clone());
Expand All @@ -637,7 +637,7 @@ impl<O: ForestObligation> ObligationForest<O> {
// check against.
self.active_cache.remove(node.obligation.as_predicate());
self.insert_into_error_cache(index);
node_rewrites[index] = orig_nodes_len;
node_rewrites[index] = remove_node_marker;
dead_nodes += 1;
}
NodeState::Success => unreachable!(),
Expand All @@ -650,14 +650,15 @@ impl<O: ForestObligation> ObligationForest<O> {
self.apply_rewrites(&node_rewrites);
}

node_rewrites.truncate(0);
self.node_rewrites.replace(node_rewrites);
node_rewrites.clear();
self.node_rewrites = node_rewrites;

if do_completed == DoCompleted::Yes { Some(removed_done_obligations) } else { None }
}

fn apply_rewrites(&mut self, node_rewrites: &[usize]) {
fn apply_rewrites(&mut self, node_rewrites: &[NodeIndex]) {
let orig_nodes_len = node_rewrites.len();
let remove_node_marker = orig_nodes_len + 1;

for node in &mut self.nodes {
let mut i = 0;
Expand All @@ -678,31 +679,21 @@ impl<O: ForestObligation> ObligationForest<O> {

// This updating of `self.active_cache` is necessary because the
// removal of nodes within `compress` can fail. See above.
self.active_cache.retain(|_predicate, index| {
let new_index = node_rewrites[*index];
if new_index >= orig_nodes_len {
false
self.active_cache.retain(|_predicate, opt_index| {
if let Some(index) = opt_index {
let new_index = node_rewrites[*index];
if new_index == orig_nodes_len {
*opt_index = None;
true
} else if new_index == remove_node_marker {
false
} else {
*index = new_index;
true
}
} else {
*index = new_index;
true
}
});
}
}

// I need a Clone closure.
#[derive(Clone)]
struct GetObligation<'a, O>(&'a [Node<O>]);

impl<'a, 'b, O> FnOnce<(&'b usize,)> for GetObligation<'a, O> {
type Output = &'a O;
extern "rust-call" fn call_once(self, args: (&'b usize,)) -> &'a O {
&self.0[*args.0].obligation
}
}

impl<'a, 'b, O> FnMut<(&'b usize,)> for GetObligation<'a, O> {
extern "rust-call" fn call_mut(&mut self, args: (&'b usize,)) -> &'a O {
&self.0[*args.0].obligation
}
}