make enum size not depend on the order of variants #131684

Closed
16 changes: 12 additions & 4 deletions compiler/rustc_abi/src/layout.rs
@@ -196,7 +196,7 @@ impl<Cx: HasDataLayout> LayoutCalculator<Cx> {
pub fn layout_of_struct_or_enum<
'a,
FieldIdx: Idx,
VariantIdx: Idx,
VariantIdx: Idx + PartialOrd,
F: Deref<Target = &'a LayoutS<FieldIdx, VariantIdx>> + fmt::Debug + Copy,
>(
&self,
@@ -468,7 +468,7 @@ impl<Cx: HasDataLayout> LayoutCalculator<Cx> {
fn layout_of_enum<
'a,
FieldIdx: Idx,
VariantIdx: Idx,
VariantIdx: Idx + PartialOrd,
F: Deref<Target = &'a LayoutS<FieldIdx, VariantIdx>> + fmt::Debug + Copy,
>(
&self,
@@ -528,8 +528,16 @@ impl<Cx: HasDataLayout> LayoutCalculator<Cx> {
let niche_variants = all_indices.clone().find(|v| needs_disc(*v)).unwrap()
..=all_indices.rev().find(|v| needs_disc(*v)).unwrap();

let count =
(niche_variants.end().index() as u128 - niche_variants.start().index() as u128) + 1;
let count = {
let niche_variants_len = (niche_variants.end().index() as u128
- niche_variants.start().index() as u128)
+ 1;
if niche_variants.contains(&largest_variant_index) {
niche_variants_len - 1
} else {
niche_variants_len
}
};

// Use the largest niche in the largest variant.
let niche = variant_layouts[largest_variant_index].largest_niche?;
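For illustration, a minimal standalone sketch of the new `count` computation above, using plain integer indices in place of the compiler's `VariantIdx` (names here are hypothetical): when the largest (untagged) variant falls inside the `niche_variants` range it needs no discriminant, so one fewer niche value is required.

```rust
use std::ops::RangeInclusive;

/// Number of niche values the enum needs: the length of `niche_variants`,
/// minus one when the untagged (largest) variant lies inside that range.
fn niche_count(niche_variants: RangeInclusive<u32>, largest_variant_index: u32) -> u128 {
    let len = (*niche_variants.end() as u128 - *niche_variants.start() as u128) + 1;
    if niche_variants.contains(&largest_variant_index) { len - 1 } else { len }
}

fn main() {
    // Variants 0..=3 need handling and the untagged variant is index 2
    // (cf. the `SomeEnum { A, B, C(bool), D }` example later in this PR):
    // only 3 discriminant values are needed.
    assert_eq!(niche_count(0..=3, 2), 3);
    // If the untagged variant sat outside the range, all 4 would be needed.
    assert_eq!(niche_count(0..=3, 5), 4);
}
```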
56 changes: 50 additions & 6 deletions compiler/rustc_abi/src/lib.rs
@@ -1498,15 +1498,59 @@ pub enum TagEncoding<VariantIdx: Idx> {
Direct,

/// Niche (values invalid for a type) encoding the discriminant:
/// Discriminant and variant index coincide.
/// Discriminant and variant index don't always coincide.
///
/// The variant `untagged_variant` contains a niche at an arbitrary
/// offset (field `tag_field` of the enum), which for a variant with
/// discriminant `d` is set to
/// `(d - niche_variants.start).wrapping_add(niche_start)`.
/// discriminant `d` is set to `d.wrapping_add(niche_start)`.
///
/// For example, `Option<(usize, &T)>` is represented such that
/// `None` has a null pointer for the second tuple field, and
/// `Some` is the identity function (with a non-null reference).
/// To compute the discriminant, an optimization applies: when the `untagged_variant`
/// falls inside the `niche_variants` range, discriminant values are allocated starting
/// from the variant right after the `untagged_variant`, wrapping around. Thus the
/// `untagged_variant` is not allocated an unneeded discriminant. The motivation for
/// this is issue #117238.
/// For example,
/// ```
/// enum SomeEnum {
/// A, // 1
/// B, // 2
/// C(bool), // untagged_variant, no discriminant
/// D, // has a discriminant of 0
/// }
Comment on lines +1513 to +1518
Member

Can you add the tag/payload values here too, not just the discriminants?

Comment on lines +1513 to +1518
Member

Also, the given enum already has a size of 1. I think it's the edge-case in #117238 where we need the wraparound. If there's enough room the layout logic already shifts up the entire tag range, including a gap where the untagged variant would otherwise be. It's only when the tag range is full that that becomes a problem.

Well, maybe there also are cases with range-restricted payloads other than bool. Something using rustc_layout_scalar_valid_range_start and rustc_layout_scalar_valid_range_end might also benefit from wraparound.

But generally those should be exceptions, not the rule. So we shouldn't take a more expensive code-path for the normal case.
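For context, a rough sketch of what such a range-restricted payload looks like. This is nightly-only and uses rustc-internal attributes; the type, range, and enum are hypothetical, not from the PR. Values outside the annotated range become niches, much like the 2..=255 niche of `bool`.

```rust
#![feature(rustc_attrs)]
#![allow(internal_features)]

#[rustc_layout_scalar_valid_range_start(2)]
#[rustc_layout_scalar_valid_range_end(250)]
struct Mid(u8); // only 2..=250 are valid, so 0, 1 and 251..=255 are niches

enum E {
    Payload(Mid), // would be the untagged variant
    A,
    B,
}

fn main() {
    // Constructing a range-restricted scalar is unsafe: the compiler trusts
    // the annotated range when computing layout and niches.
    let e = E::Payload(unsafe { Mid(7) });
    // The niches in `Mid` are enough to encode A and B without a separate tag.
    println!("size_of::<E>() = {}", std::mem::size_of::<E>());
    let _ = e;
}
```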

Contributor Author

Yeah, we can still wrap around when the tag range is full to avoid the off-by-1 bug. But taking a step back, a clippy warning might be a better approach for this rare case than complicating codegen?

/// ```
/// The algorithm is as follows:
/// ```rust,ignore (pseudo-code)
/// // We ignore leading and trailing variants that don't need discriminants.
/// adjusted_len = niche_variants.end - niche_variants.start + 1
/// adjusted_index = variant_index - niche_variants.start
/// d = if niche_variants.contains(untagged_variant) {
/// adjusted_untagged_index = untagged_variant - niche_variants.start
/// (adjusted_index + adjusted_len - adjusted_untagged_index) % adjusted_len - 1
/// } else {
/// adjusted_index
/// }
/// tag_value = d.wrapping_add(niche_start)
/// ```
/// To load variant index from tag value:
/// ```rust,ignore (pseudo-code)
/// adjusted_len = niche_variants.end - niche_variants.start + 1
/// d = tag_value.wrapping_sub(niche_start)
/// variant_index = if niche_variants.contains(untagged_variant) {
Member

This code mixes compile-time and runtime behavior. I think that's unnecessarily confusing. It could be split into separate sections (one for interior untagged variants, one without) which each have their own code samples.

/// if d < adjusted_len - 1 {
/// adjusted_untagged_index = untagged_variant - niche_variants.start
/// (d + 1 + adjusted_untagged_index) % adjusted_len + niche_variants.start
/// } else {
/// // When the discriminant is larger than the number of variants having
/// // discriminant, we know it represents the untagged_variant.
/// untagged_variant
/// }
/// } else {
/// if d < adjusted_len {
/// d + niche_variants.start
/// } else {
/// untagged_variant
/// }
/// }
/// ```
Niche {
untagged_variant: VariantIdx,
niche_variants: RangeInclusive<VariantIdx>,
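To make the documented mapping concrete, here is a standalone sketch applying the pseudo-code above to the `SomeEnum { A, B, C(bool), D }` example. The constants and function names are illustrative only (not compiler APIs), and `niche_start = 0` is assumed for simplicity.

```rust
use std::ops::RangeInclusive;

const NICHE_VARIANTS: RangeInclusive<u32> = 0..=3; // variants A..=D all participate
const UNTAGGED_VARIANT: u32 = 2;                   // C(bool) carries the niche
const NICHE_START: u128 = 0;                       // assumed; really derived from the payload's niche

/// variant index -> discriminant / tag value (tagged variants only).
fn tag_for_variant(variant_index: u32) -> u128 {
    assert_ne!(variant_index, UNTAGGED_VARIANT, "the untagged variant has no discriminant");
    let adjusted_len = (NICHE_VARIANTS.end() - NICHE_VARIANTS.start() + 1) as u128;
    let adjusted_index = (variant_index - NICHE_VARIANTS.start()) as u128;
    let d = if NICHE_VARIANTS.contains(&UNTAGGED_VARIANT) {
        // Allocation starts right after the untagged variant and wraps around,
        // so the untagged variant itself never consumes a discriminant value.
        let adjusted_untagged_index = (UNTAGGED_VARIANT - NICHE_VARIANTS.start()) as u128;
        (adjusted_index + adjusted_len - adjusted_untagged_index) % adjusted_len - 1
    } else {
        adjusted_index
    };
    d.wrapping_add(NICHE_START)
}

/// tag value -> variant index.
fn variant_for_tag(tag: u128) -> u32 {
    let adjusted_len = (NICHE_VARIANTS.end() - NICHE_VARIANTS.start() + 1) as u128;
    let d = tag.wrapping_sub(NICHE_START);
    if NICHE_VARIANTS.contains(&UNTAGGED_VARIANT) {
        if d < adjusted_len - 1 {
            let adjusted_untagged_index = (UNTAGGED_VARIANT - NICHE_VARIANTS.start()) as u128;
            ((d + 1 + adjusted_untagged_index) % adjusted_len) as u32 + NICHE_VARIANTS.start()
        } else {
            UNTAGGED_VARIANT
        }
    } else if d < adjusted_len {
        d as u32 + NICHE_VARIANTS.start()
    } else {
        UNTAGGED_VARIANT
    }
}

fn main() {
    // A, B and D get discriminants 1, 2 and 0, matching the doc comment.
    assert_eq!(tag_for_variant(0), 1); // A
    assert_eq!(tag_for_variant(1), 2); // B
    assert_eq!(tag_for_variant(3), 0); // D
    // Round-tripping recovers the variant index; any other tag means C.
    assert_eq!(variant_for_tag(1), 0);
    assert_eq!(variant_for_tag(2), 1);
    assert_eq!(variant_for_tag(0), 3);
    assert_eq!(variant_for_tag(200), UNTAGGED_VARIANT);
}
```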
155 changes: 92 additions & 63 deletions compiler/rustc_codegen_cranelift/src/discriminant.rs
@@ -52,10 +52,20 @@ pub(crate) fn codegen_set_discriminant<'tcx>(
variants: _,
} => {
if variant_index != untagged_variant {
let discr_len = niche_variants.end().index() - niche_variants.start().index() + 1;
let adj_idx = variant_index.index() - niche_variants.start().index();

let niche = place.place_field(fx, FieldIdx::new(tag_field));
let niche_type = fx.clif_type(niche.layout().ty).unwrap();
let niche_value = variant_index.as_u32() - niche_variants.start().as_u32();
let niche_value = (niche_value as u128).wrapping_add(niche_start);

let discr = if niche_variants.contains(&untagged_variant) {
let adj_untagged_idx =
untagged_variant.index() - niche_variants.start().index();
(adj_idx + discr_len - adj_untagged_idx) % discr_len - 1
} else {
adj_idx
};
let niche_value = (discr as u128).wrapping_add(niche_start);
let niche_value = match niche_type {
types::I128 => {
let lsb = fx.bcx.ins().iconst(types::I64, niche_value as u64 as i64);
@@ -131,72 +141,91 @@ pub(crate) fn codegen_get_discriminant<'tcx>(
dest.write_cvalue(fx, res);
}
TagEncoding::Niche { untagged_variant, ref niche_variants, niche_start } => {
let relative_max = niche_variants.end().as_u32() - niche_variants.start().as_u32();

// We have a subrange `niche_start..=niche_end` inside `range`.
// If the value of the tag is inside this subrange, it's a
// "niche value", an increment of the discriminant. Otherwise it
// indicates the untagged variant.
// A general algorithm to extract the discriminant from the tag
// is:
// relative_tag = tag - niche_start
// is_niche = relative_tag <= (ule) relative_max
// discr = if is_niche {
// cast(relative_tag) + niche_variants.start()
// } else {
// untagged_variant
// }
// However, we will likely be able to emit simpler code.

let (is_niche, tagged_discr, delta) = if relative_max == 0 {
// Best case scenario: only one tagged variant. This will
// likely become just a comparison and a jump.
// The algorithm is:
// is_niche = tag == niche_start
// discr = if is_niche {
// niche_start
// } else {
// untagged_variant
// }
// See the algorithm explanation in the definition of `TagEncoding::Niche`.
let discr_len = niche_variants.end().index() - niche_variants.start().index() + 1;

let niche_start_value = match fx.bcx.func.dfg.value_type(tag) {
types::I128 => {
let lsb = fx.bcx.ins().iconst(types::I64, niche_start as u64 as i64);
let msb = fx.bcx.ins().iconst(types::I64, (niche_start >> 64) as u64 as i64);
fx.bcx.ins().iconcat(lsb, msb)
}
ty => fx.bcx.ins().iconst(ty, niche_start as i64),
};

let (is_niche, tagged_discr) = if discr_len == 1 {
// Special case where we only have a single tagged variant.
// The untagged variant can't be contained in niche_variants' range in this case.
// Thus the discriminant of the only tagged variant is 0 and its variant index
// is the start of niche_variants.
let is_niche = codegen_icmp_imm(fx, IntCC::Equal, tag, niche_start as i128);
let tagged_discr =
fx.bcx.ins().iconst(cast_to, niche_variants.start().as_u32() as i64);
(is_niche, tagged_discr, 0)
(is_niche, tagged_discr)
} else {
// The special cases don't apply, so we'll have to go with
// the general algorithm.
let niche_start = match fx.bcx.func.dfg.value_type(tag) {
types::I128 => {
let lsb = fx.bcx.ins().iconst(types::I64, niche_start as u64 as i64);
let msb =
fx.bcx.ins().iconst(types::I64, (niche_start >> 64) as u64 as i64);
fx.bcx.ins().iconcat(lsb, msb)
}
ty => fx.bcx.ins().iconst(ty, niche_start as i64),
};
let relative_discr = fx.bcx.ins().isub(tag, niche_start);
let cast_tag = clif_intcast(fx, relative_discr, cast_to, false);
let is_niche = crate::common::codegen_icmp_imm(
fx,
IntCC::UnsignedLessThanOrEqual,
relative_discr,
i128::from(relative_max),
);
(is_niche, cast_tag, niche_variants.start().as_u32() as u128)
};
// General case.
let discr = fx.bcx.ins().isub(tag, niche_start_value);
let tagged_discr = clif_intcast(fx, discr, cast_to, false);
if niche_variants.contains(&untagged_variant) {
let is_niche = crate::common::codegen_icmp_imm(
fx,
IntCC::UnsignedLessThan,
discr,
(discr_len - 1) as i128,
);
let adj_untagged_idx =
untagged_variant.index() - niche_variants.start().index();
let untagged_delta = 1 + adj_untagged_idx;
let untagged_delta = match cast_to {
types::I128 => {
let lsb = fx.bcx.ins().iconst(types::I64, untagged_delta as i64);
let msb = fx.bcx.ins().iconst(types::I64, 0);
fx.bcx.ins().iconcat(lsb, msb)
}
ty => fx.bcx.ins().iconst(ty, untagged_delta as i64),
};
let tagged_discr = fx.bcx.ins().iadd(tagged_discr, untagged_delta);

let tagged_discr = if delta == 0 {
tagged_discr
} else {
let delta = match cast_to {
types::I128 => {
let lsb = fx.bcx.ins().iconst(types::I64, delta as u64 as i64);
let msb = fx.bcx.ins().iconst(types::I64, (delta >> 64) as u64 as i64);
fx.bcx.ins().iconcat(lsb, msb)
}
ty => fx.bcx.ins().iconst(ty, delta as i64),
};
fx.bcx.ins().iadd(tagged_discr, delta)
let discr_len = match cast_to {
types::I128 => {
let lsb = fx.bcx.ins().iconst(types::I64, discr_len as i64);
let msb = fx.bcx.ins().iconst(types::I64, 0);
fx.bcx.ins().iconcat(lsb, msb)
}
ty => fx.bcx.ins().iconst(ty, discr_len as i64),
};
let tagged_discr = fx.bcx.ins().urem(tagged_discr, discr_len);
Member

discr_len is not a power of two, right? This doesn't look good for performance.

In the motivating issue (#117238) the full discriminant range is essentially packed 2 values for the bool, 254 values for the remaining variants. In that case we can just use wrapping arithmetic of the right type size without any remainder.

So for smaller cases like the A-D case above, can we rearrange the discriminants so that they also fit the natural tag wrap-around?
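To illustrate that point with a minimal sketch (not compiler code): when the effective modulus equals the full range of the tag type, the `% len` reduction is exactly what wrapping or truncating arithmetic already does, so no `urem` is needed.

```rust
fn main() {
    // Hypothetical case: a u8 tag whose whole 256-value range is packed,
    // as in the motivating issue (2 values for the bool payload + 254 variants).
    let len: u32 = 256; // discriminant-space length == 2^8
    for x in [0u32, 1, 255, 256, 300, 511] {
        // An explicit remainder ...
        let via_rem = x % len;
        // ... equals simply truncating to the tag width, which costs nothing.
        let via_wrap = (x as u8) as u32;
        assert_eq!(via_rem, via_wrap);
    }
}
```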

Contributor Author

I didn't realize the modulo operation had such bad performance. Now the foundation of this PR is mostly gone.

Member

Also from this comment it sounds like it should be possible to fix this issue without changing codegen at all. IIUC, we don't need to change anything about TagEncoding::Niche, we just need to be more clever in how we construct enums with that encoding. Or did I misunderstand?

Contributor Author

I basically implemented what that comment suggested: allocating discriminants starting right after the untagged variant and wrapping around. This changes how we convert between variant_index and tag, and thus codegen.

Member

Codegen needs to compute the tag from the discriminant. The mapping between variant index and tag only exists during compilation and can be adjusted without codegen impact. Currently, enums with Niche encoding make the tag equal to the variant index, but it doesn't have to be that way.

Anyway maybe @the8472 could clarify what they meant in that comment.


let niche_variants_start = niche_variants.start().index();
let niche_variants_start = match cast_to {
types::I128 => {
let lsb = fx.bcx.ins().iconst(types::I64, niche_variants_start as i64);
let msb = fx.bcx.ins().iconst(types::I64, 0);
fx.bcx.ins().iconcat(lsb, msb)
}
ty => fx.bcx.ins().iconst(ty, niche_variants_start as i64),
};
let tagged_discr = fx.bcx.ins().iadd(tagged_discr, niche_variants_start);
(is_niche, tagged_discr)
} else {
let is_niche = crate::common::codegen_icmp_imm(
fx,
IntCC::UnsignedLessThan,
discr,
(discr_len - 1) as i128,
);
let niche_variants_start = niche_variants.start().index();
let niche_variants_start = match cast_to {
types::I128 => {
let lsb = fx.bcx.ins().iconst(types::I64, niche_variants_start as i64);
let msb = fx.bcx.ins().iconst(types::I64, 0);
fx.bcx.ins().iconcat(lsb, msb)
}
ty => fx.bcx.ins().iconst(ty, niche_variants_start as i64),
};
let tagged_discr = fx.bcx.ins().iadd(tagged_discr, niche_variants_start);
(is_niche, tagged_discr)
}
};

let untagged_variant = if cast_to == types::I128 {
18 changes: 15 additions & 3 deletions compiler/rustc_codegen_llvm/src/debuginfo/metadata/enums/mod.rs
@@ -391,9 +391,21 @@ fn compute_discriminant_value<'ll, 'tcx>(

DiscrResult::Range(min, max)
} else {
let value = (variant_index.as_u32() as u128)
.wrapping_sub(niche_variants.start().as_u32() as u128)
.wrapping_add(niche_start);
let discr_len = niche_variants.end().as_u32() as u128
- niche_variants.start().as_u32() as u128
+ 1;
// FIXME: Why do we even return a discriminant for absent variants?
let adj_idx = (variant_index.as_u32() as u128)
.wrapping_sub(niche_variants.start().as_u32() as u128);

let discr = if niche_variants.contains(&untagged_variant) {
let adj_untagged_idx =
(untagged_variant.as_u32() - niche_variants.start().as_u32()) as u128;
(adj_idx + discr_len - adj_untagged_idx) % discr_len - 1
} else {
adj_idx
};
let value = discr.wrapping_add(niche_start);
let value = tag.size(cx).truncate(value);
DiscrResult::Value(value)
}