Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 93eb02e

Browse files
committedOct 15, 2024·
make enum size not depend on the order of variants
1 parent 9322d18 commit 93eb02e

File tree

8 files changed

+315
-140
lines changed

8 files changed

+315
-140
lines changed
 

‎compiler/rustc_abi/src/layout.rs

+12-4
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ impl<Cx: HasDataLayout> LayoutCalculator<Cx> {
196196
pub fn layout_of_struct_or_enum<
197197
'a,
198198
FieldIdx: Idx,
199-
VariantIdx: Idx,
199+
VariantIdx: Idx + PartialOrd,
200200
F: Deref<Target = &'a LayoutS<FieldIdx, VariantIdx>> + fmt::Debug + Copy,
201201
>(
202202
&self,
@@ -468,7 +468,7 @@ impl<Cx: HasDataLayout> LayoutCalculator<Cx> {
468468
fn layout_of_enum<
469469
'a,
470470
FieldIdx: Idx,
471-
VariantIdx: Idx,
471+
VariantIdx: Idx + PartialOrd,
472472
F: Deref<Target = &'a LayoutS<FieldIdx, VariantIdx>> + fmt::Debug + Copy,
473473
>(
474474
&self,
@@ -528,8 +528,16 @@ impl<Cx: HasDataLayout> LayoutCalculator<Cx> {
528528
let niche_variants = all_indices.clone().find(|v| needs_disc(*v)).unwrap()
529529
..=all_indices.rev().find(|v| needs_disc(*v)).unwrap();
530530

531-
let count =
532-
(niche_variants.end().index() as u128 - niche_variants.start().index() as u128) + 1;
531+
let count = {
532+
let niche_variants_len = (niche_variants.end().index() as u128
533+
- niche_variants.start().index() as u128)
534+
+ 1;
535+
if niche_variants.contains(&largest_variant_index) {
536+
niche_variants_len - 1
537+
} else {
538+
niche_variants_len
539+
}
540+
};
533541

534542
// Use the largest niche in the largest variant.
535543
let niche = variant_layouts[largest_variant_index].largest_niche?;

‎compiler/rustc_abi/src/lib.rs

+50-6
Original file line numberDiff line numberDiff line change
@@ -1498,15 +1498,59 @@ pub enum TagEncoding<VariantIdx: Idx> {
14981498
Direct,
14991499

15001500
/// Niche (values invalid for a type) encoding the discriminant:
1501-
/// Discriminant and variant index coincide.
1501+
/// Discriminant and variant index doesn't always coincide.
1502+
///
15021503
/// The variant `untagged_variant` contains a niche at an arbitrary
15031504
/// offset (field `tag_field` of the enum), which for a variant with
1504-
/// discriminant `d` is set to
1505-
/// `(d - niche_variants.start).wrapping_add(niche_start)`.
1505+
/// discriminant `d` is set to `d.wrapping_add(niche_start)`.
15061506
///
1507-
/// For example, `Option<(usize, &T)>` is represented such that
1508-
/// `None` has a null pointer for the second tuple field, and
1509-
/// `Some` is the identity function (with a non-null reference).
1507+
/// As for how to compute the discriminant, we have an optimization here that we allocate discriminant
1508+
/// value starting from the variant after the `untagged_variant` when the `untagged_variant` is
1509+
/// contained in `niche_variants`' range. Thus the `untagged_variant` won't be allocated with a
1510+
/// unneeded discriminant. Motivation for this is issue #117238.
1511+
/// For example,
1512+
/// ```rust
1513+
/// enum {
1514+
/// A, // 1
1515+
/// B, // 2
1516+
/// C(bool), // untagged_variant, no discriminant
1517+
/// D, // has a discriminant of 0
1518+
/// }
1519+
/// ```
1520+
/// The algorithm is as follows:
1521+
/// ```rust
1522+
/// // We ignore leading and trailing variants that don't need discriminants.
1523+
/// adjusted_len = niche_variants.end - niche_variants.start + 1
1524+
/// adjusted_index = variant_index - niche_variants.start
1525+
/// d = if niche_variants.contains(untagged_variant) {
1526+
/// adjusted_untagged_index = untagged_variant - niche_variants.start
1527+
/// (adjusted_index + adjusted_len - adjusted_untagged_index) % adjusted_len - 1
1528+
/// } else {
1529+
/// adjusted_index
1530+
/// }
1531+
/// tag_value = d.wrapping_add(niche_start)
1532+
/// ```
1533+
/// To load variant index from tag value:
1534+
/// ```rust
1535+
/// adjusted_len = niche_variants.end - niche_variants.start + 1
1536+
/// d = tag_value.wrapping_sub(niche_start)
1537+
/// variant_index = if niche_variants.contains(untagged_variant) {
1538+
/// if d < adjusted_len - 1 {
1539+
/// adjusted_untagged_index = untagged_variant - niche_variants.start
1540+
/// (d + 1 + adjusted_untagged_index) % adjusted_len + niche_variants.start
1541+
/// } else {
1542+
/// // When the discriminant is larger than the number of variants having
1543+
/// // discriminant, we know it represents the untagged_variant.
1544+
/// untagged_variant
1545+
/// }
1546+
/// } else {
1547+
/// if d < adjusted_len {
1548+
/// d + niche_variants.start
1549+
/// } else {
1550+
/// untagged_variant
1551+
/// }
1552+
/// }
1553+
/// ```
15101554
Niche {
15111555
untagged_variant: VariantIdx,
15121556
niche_variants: RangeInclusive<VariantIdx>,

‎compiler/rustc_codegen_cranelift/src/discriminant.rs

+92-63
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,20 @@ pub(crate) fn codegen_set_discriminant<'tcx>(
5252
variants: _,
5353
} => {
5454
if variant_index != untagged_variant {
55+
let discr_len = niche_variants.end().index() - niche_variants.start().index() + 1;
56+
let adj_idx = variant_index.index() - niche_variants.start().index();
57+
5558
let niche = place.place_field(fx, FieldIdx::new(tag_field));
5659
let niche_type = fx.clif_type(niche.layout().ty).unwrap();
57-
let niche_value = variant_index.as_u32() - niche_variants.start().as_u32();
58-
let niche_value = (niche_value as u128).wrapping_add(niche_start);
60+
61+
let discr = if niche_variants.contains(&untagged_variant) {
62+
let adj_untagged_idx =
63+
untagged_variant.index() - niche_variants.start().index();
64+
(adj_idx + discr_len - adj_untagged_idx) % discr_len - 1
65+
} else {
66+
adj_idx
67+
};
68+
let niche_value = (discr as u128).wrapping_add(niche_start);
5969
let niche_value = match niche_type {
6070
types::I128 => {
6171
let lsb = fx.bcx.ins().iconst(types::I64, niche_value as u64 as i64);
@@ -131,72 +141,91 @@ pub(crate) fn codegen_get_discriminant<'tcx>(
131141
dest.write_cvalue(fx, res);
132142
}
133143
TagEncoding::Niche { untagged_variant, ref niche_variants, niche_start } => {
134-
let relative_max = niche_variants.end().as_u32() - niche_variants.start().as_u32();
135-
136-
// We have a subrange `niche_start..=niche_end` inside `range`.
137-
// If the value of the tag is inside this subrange, it's a
138-
// "niche value", an increment of the discriminant. Otherwise it
139-
// indicates the untagged variant.
140-
// A general algorithm to extract the discriminant from the tag
141-
// is:
142-
// relative_tag = tag - niche_start
143-
// is_niche = relative_tag <= (ule) relative_max
144-
// discr = if is_niche {
145-
// cast(relative_tag) + niche_variants.start()
146-
// } else {
147-
// untagged_variant
148-
// }
149-
// However, we will likely be able to emit simpler code.
150-
151-
let (is_niche, tagged_discr, delta) = if relative_max == 0 {
152-
// Best case scenario: only one tagged variant. This will
153-
// likely become just a comparison and a jump.
154-
// The algorithm is:
155-
// is_niche = tag == niche_start
156-
// discr = if is_niche {
157-
// niche_start
158-
// } else {
159-
// untagged_variant
160-
// }
144+
// See the algorithm explanation in the definition of `TagEncoding::Niche`.
145+
let discr_len = niche_variants.end().index() - niche_variants.start().index() + 1;
146+
147+
let niche_start_value = match fx.bcx.func.dfg.value_type(tag) {
148+
types::I128 => {
149+
let lsb = fx.bcx.ins().iconst(types::I64, niche_start as u64 as i64);
150+
let msb = fx.bcx.ins().iconst(types::I64, (niche_start >> 64) as u64 as i64);
151+
fx.bcx.ins().iconcat(lsb, msb)
152+
}
153+
ty => fx.bcx.ins().iconst(ty, niche_start as i64),
154+
};
155+
156+
let (is_niche, tagged_discr) = if discr_len == 1 {
157+
// Special case where we only have a single tagged variant.
158+
// The untagged variant can't be contained in niche_variant's range in this case.
159+
// Thus the discriminant of the only tagged variant is 0 and its variant index
160+
// is the start of niche_variants.
161161
let is_niche = codegen_icmp_imm(fx, IntCC::Equal, tag, niche_start as i128);
162162
let tagged_discr =
163163
fx.bcx.ins().iconst(cast_to, niche_variants.start().as_u32() as i64);
164-
(is_niche, tagged_discr, 0)
164+
(is_niche, tagged_discr)
165165
} else {
166-
// The special cases don't apply, so we'll have to go with
167-
// the general algorithm.
168-
let niche_start = match fx.bcx.func.dfg.value_type(tag) {
169-
types::I128 => {
170-
let lsb = fx.bcx.ins().iconst(types::I64, niche_start as u64 as i64);
171-
let msb =
172-
fx.bcx.ins().iconst(types::I64, (niche_start >> 64) as u64 as i64);
173-
fx.bcx.ins().iconcat(lsb, msb)
174-
}
175-
ty => fx.bcx.ins().iconst(ty, niche_start as i64),
176-
};
177-
let relative_discr = fx.bcx.ins().isub(tag, niche_start);
178-
let cast_tag = clif_intcast(fx, relative_discr, cast_to, false);
179-
let is_niche = crate::common::codegen_icmp_imm(
180-
fx,
181-
IntCC::UnsignedLessThanOrEqual,
182-
relative_discr,
183-
i128::from(relative_max),
184-
);
185-
(is_niche, cast_tag, niche_variants.start().as_u32() as u128)
186-
};
166+
// General case.
167+
let discr = fx.bcx.ins().isub(tag, niche_start_value);
168+
let tagged_discr = clif_intcast(fx, discr, cast_to, false);
169+
if niche_variants.contains(&untagged_variant) {
170+
let is_niche = crate::common::codegen_icmp_imm(
171+
fx,
172+
IntCC::UnsignedLessThan,
173+
discr,
174+
(discr_len - 1) as i128,
175+
);
176+
let adj_untagged_idx =
177+
untagged_variant.index() - niche_variants.start().index();
178+
let untagged_delta = 1 + adj_untagged_idx;
179+
let untagged_delta = match cast_to {
180+
types::I128 => {
181+
let lsb = fx.bcx.ins().iconst(types::I64, untagged_delta as i64);
182+
let msb = fx.bcx.ins().iconst(types::I64, 0);
183+
fx.bcx.ins().iconcat(lsb, msb)
184+
}
185+
ty => fx.bcx.ins().iconst(ty, untagged_delta as i64),
186+
};
187+
let tagged_discr = fx.bcx.ins().iadd(tagged_discr, untagged_delta);
187188

188-
let tagged_discr = if delta == 0 {
189-
tagged_discr
190-
} else {
191-
let delta = match cast_to {
192-
types::I128 => {
193-
let lsb = fx.bcx.ins().iconst(types::I64, delta as u64 as i64);
194-
let msb = fx.bcx.ins().iconst(types::I64, (delta >> 64) as u64 as i64);
195-
fx.bcx.ins().iconcat(lsb, msb)
196-
}
197-
ty => fx.bcx.ins().iconst(ty, delta as i64),
198-
};
199-
fx.bcx.ins().iadd(tagged_discr, delta)
189+
let discr_len = match cast_to {
190+
types::I128 => {
191+
let lsb = fx.bcx.ins().iconst(types::I64, discr_len as i64);
192+
let msb = fx.bcx.ins().iconst(types::I64, 0);
193+
fx.bcx.ins().iconcat(lsb, msb)
194+
}
195+
ty => fx.bcx.ins().iconst(ty, discr_len as i64),
196+
};
197+
let tagged_discr = fx.bcx.ins().urem(tagged_discr, discr_len);
198+
199+
let niche_variants_start = niche_variants.start().index();
200+
let niche_variants_start = match cast_to {
201+
types::I128 => {
202+
let lsb = fx.bcx.ins().iconst(types::I64, niche_variants_start as i64);
203+
let msb = fx.bcx.ins().iconst(types::I64, 0);
204+
fx.bcx.ins().iconcat(lsb, msb)
205+
}
206+
ty => fx.bcx.ins().iconst(ty, niche_variants_start as i64),
207+
};
208+
let tagged_discr = fx.bcx.ins().iadd(tagged_discr, niche_variants_start);
209+
(is_niche, tagged_discr)
210+
} else {
211+
let is_niche = crate::common::codegen_icmp_imm(
212+
fx,
213+
IntCC::UnsignedLessThan,
214+
discr,
215+
(discr_len - 1) as i128,
216+
);
217+
let niche_variants_start = niche_variants.start().index();
218+
let niche_variants_start = match cast_to {
219+
types::I128 => {
220+
let lsb = fx.bcx.ins().iconst(types::I64, niche_variants_start as i64);
221+
let msb = fx.bcx.ins().iconst(types::I64, 0);
222+
fx.bcx.ins().iconcat(lsb, msb)
223+
}
224+
ty => fx.bcx.ins().iconst(ty, niche_variants_start as i64),
225+
};
226+
let tagged_discr = fx.bcx.ins().iadd(tagged_discr, niche_variants_start);
227+
(is_niche, tagged_discr)
228+
}
200229
};
201230

202231
let untagged_variant = if cast_to == types::I128 {

‎compiler/rustc_codegen_llvm/src/debuginfo/metadata/enums/mod.rs

+15-3
Original file line numberDiff line numberDiff line change
@@ -391,9 +391,21 @@ fn compute_discriminant_value<'ll, 'tcx>(
391391

392392
DiscrResult::Range(min, max)
393393
} else {
394-
let value = (variant_index.as_u32() as u128)
395-
.wrapping_sub(niche_variants.start().as_u32() as u128)
396-
.wrapping_add(niche_start);
394+
let discr_len = niche_variants.end().as_u32() as u128
395+
- niche_variants.start().as_u32() as u128
396+
+ 1;
397+
// FIXME: Why do we even return discriminant for absent variants?
398+
let adj_idx = (variant_index.as_u32() as u128)
399+
.wrapping_sub(niche_variants.start().as_u32() as u128);
400+
401+
let discr = if niche_variants.contains(&untagged_variant) {
402+
let adj_untagged_idx =
403+
(untagged_variant.as_u32() - niche_variants.start().as_u32()) as u128;
404+
(adj_idx + discr_len - adj_untagged_idx) % discr_len - 1
405+
} else {
406+
adj_idx
407+
};
408+
let value = discr.wrapping_add(niche_start);
397409
let value = tag.size(cx).truncate(value);
398410
DiscrResult::Value(value)
399411
}

‎compiler/rustc_codegen_ssa/src/mir/place.rs

+55-46
Original file line numberDiff line numberDiff line change
@@ -287,54 +287,53 @@ impl<'a, 'tcx, V: CodegenObject> PlaceRef<'tcx, V> {
287287
_ => (tag_imm, bx.cx().immediate_backend_type(tag_op.layout)),
288288
};
289289

290-
let relative_max = niche_variants.end().as_u32() - niche_variants.start().as_u32();
291-
292-
// We have a subrange `niche_start..=niche_end` inside `range`.
293-
// If the value of the tag is inside this subrange, it's a
294-
// "niche value", an increment of the discriminant. Otherwise it
295-
// indicates the untagged variant.
296-
// A general algorithm to extract the discriminant from the tag
297-
// is:
298-
// relative_tag = tag - niche_start
299-
// is_niche = relative_tag <= (ule) relative_max
300-
// discr = if is_niche {
301-
// cast(relative_tag) + niche_variants.start()
302-
// } else {
303-
// untagged_variant
304-
// }
305-
// However, we will likely be able to emit simpler code.
306-
let (is_niche, tagged_discr, delta) = if relative_max == 0 {
307-
// Best case scenario: only one tagged variant. This will
308-
// likely become just a comparison and a jump.
309-
// The algorithm is:
310-
// is_niche = tag == niche_start
311-
// discr = if is_niche {
312-
// niche_start
313-
// } else {
314-
// untagged_variant
315-
// }
316-
let niche_start = bx.cx().const_uint_big(tag_llty, niche_start);
290+
// See the algorithm explanation in the definition of `TagEncoding::Niche`.
291+
let discr_len = niche_variants.end().index() - niche_variants.start().index() + 1;
292+
let niche_start = bx.cx().const_uint_big(tag_llty, niche_start);
293+
let (is_niche, tagged_discr) = if discr_len == 1 {
294+
// Special case where we only have a single tagged variant.
295+
// The untagged variant can't be contained in niche_variant's range in this case.
296+
// Thus the discriminant of the only tagged variant is 0 and its variant index
297+
// is the start of niche_variants.
317298
let is_niche = bx.icmp(IntPredicate::IntEQ, tag, niche_start);
318299
let tagged_discr =
319300
bx.cx().const_uint(cast_to, niche_variants.start().as_u32() as u64);
320-
(is_niche, tagged_discr, 0)
301+
(is_niche, tagged_discr)
321302
} else {
322-
// The special cases don't apply, so we'll have to go with
323-
// the general algorithm.
324-
let relative_discr = bx.sub(tag, bx.cx().const_uint_big(tag_llty, niche_start));
325-
let cast_tag = bx.intcast(relative_discr, cast_to, false);
326-
let is_niche = bx.icmp(
327-
IntPredicate::IntULE,
328-
relative_discr,
329-
bx.cx().const_uint(tag_llty, relative_max as u64),
330-
);
331-
(is_niche, cast_tag, niche_variants.start().as_u32() as u128)
332-
};
333-
334-
let tagged_discr = if delta == 0 {
335-
tagged_discr
336-
} else {
337-
bx.add(tagged_discr, bx.cx().const_uint_big(cast_to, delta))
303+
// General case.
304+
let discr = bx.sub(tag, niche_start);
305+
let tagged_discr = bx.intcast(discr, cast_to, false);
306+
if niche_variants.contains(&untagged_variant) {
307+
let is_niche = bx.icmp(
308+
IntPredicate::IntULT,
309+
discr,
310+
bx.cx().const_uint(tag_llty, (discr_len - 1) as u64),
311+
);
312+
let adj_untagged_idx =
313+
untagged_variant.index() - niche_variants.start().index();
314+
let tagged_discr = bx.add(
315+
tagged_discr,
316+
bx.cx().const_uint_big(cast_to, (1 + adj_untagged_idx) as u128),
317+
);
318+
let tagged_discr = bx
319+
.urem(tagged_discr, bx.cx().const_uint_big(cast_to, discr_len as u128));
320+
let tagged_discr = bx.add(
321+
tagged_discr,
322+
bx.cx().const_uint_big(cast_to, niche_variants.start().index() as u128),
323+
);
324+
(is_niche, tagged_discr)
325+
} else {
326+
let is_niche = bx.icmp(
327+
IntPredicate::IntULT,
328+
discr,
329+
bx.cx().const_uint(tag_llty, discr_len as u64),
330+
);
331+
let tagged_discr = bx.add(
332+
tagged_discr,
333+
bx.cx().const_uint_big(cast_to, niche_variants.start().index() as u128),
334+
);
335+
(is_niche, tagged_discr)
336+
}
338337
};
339338

340339
let discr = bx.select(
@@ -384,10 +383,20 @@ impl<'a, 'tcx, V: CodegenObject> PlaceRef<'tcx, V> {
384383
..
385384
} => {
386385
if variant_index != untagged_variant {
386+
let discr_len =
387+
niche_variants.end().index() - niche_variants.start().index() + 1;
388+
let adj_idx = variant_index.index() - niche_variants.start().index();
389+
387390
let niche = self.project_field(bx, tag_field);
388391
let niche_llty = bx.cx().immediate_backend_type(niche.layout);
389-
let niche_value = variant_index.as_u32() - niche_variants.start().as_u32();
390-
let niche_value = (niche_value as u128).wrapping_add(niche_start);
392+
let discr = if niche_variants.contains(&untagged_variant) {
393+
let adj_untagged_idx =
394+
untagged_variant.index() - niche_variants.start().index();
395+
(adj_idx + discr_len - adj_untagged_idx) % discr_len - 1
396+
} else {
397+
adj_idx
398+
};
399+
let niche_value = (discr as u128).wrapping_add(niche_start);
391400
// FIXME(eddyb): check the actual primitive type here.
392401
let niche_llval = if niche_value == 0 {
393402
// HACK(eddyb): using `c_null` as it works on all types.

‎compiler/rustc_const_eval/src/interpret/discriminant.rs

+57-17
Original file line numberDiff line numberDiff line change
@@ -166,31 +166,59 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
166166
untagged_variant
167167
}
168168
Ok(tag_bits) => {
169+
// See the algorithm explanation in the definition of `TagEncoding::Niche`.
170+
let discr_len = (variants_end - variants_start)
171+
.checked_add(1)
172+
.expect("the number of niche variants fits into u32");
173+
169174
let tag_bits = tag_bits.to_bits(tag_layout.size);
170175
// We need to use machine arithmetic to get the relative variant idx:
171-
// variant_index_relative = tag_val - niche_start_val
172176
let tag_val = ImmTy::from_uint(tag_bits, tag_layout);
173177
let niche_start_val = ImmTy::from_uint(niche_start, tag_layout);
174178
let variant_index_relative_val =
175179
self.binary_op(mir::BinOp::Sub, &tag_val, &niche_start_val)?;
176180
let variant_index_relative =
177181
variant_index_relative_val.to_scalar().to_bits(tag_val.layout.size)?;
178182
// Check if this is in the range that indicates an actual discriminant.
179-
if variant_index_relative <= u128::from(variants_end - variants_start) {
180-
let variant_index_relative = u32::try_from(variant_index_relative)
181-
.expect("we checked that this fits into a u32");
182-
// Then computing the absolute variant idx should not overflow any more.
183-
let variant_index = VariantIdx::from_u32(
184-
variants_start
185-
.checked_add(variant_index_relative)
186-
.expect("overflow computing absolute variant idx"),
187-
);
188-
let variants =
189-
ty.ty_adt_def().expect("tagged layout for non adt").variants();
190-
assert!(variant_index < variants.next_index());
191-
variant_index
183+
if niche_variants.contains(&untagged_variant) {
184+
if variant_index_relative < u128::from(discr_len) {
185+
let adj_untagged_idx = untagged_variant.as_u32() - variants_start;
186+
let variant_index_relative = u32::try_from(variant_index_relative)
187+
.expect("we checked that this fits into a u32");
188+
let variant_index_to_modulo = variant_index_relative
189+
.checked_add(1)
190+
.expect("overflow computing absolute variant idx")
191+
.checked_add(adj_untagged_idx)
192+
.expect("overflow computing absolute variant idx");
193+
let variant_index = VariantIdx::from_u32(
194+
variants_start
195+
.checked_add(variant_index_to_modulo % discr_len)
196+
.expect("overflow computing absolute variant idx"),
197+
);
198+
let variants =
199+
ty.ty_adt_def().expect("tagged layout for non adt").variants();
200+
assert!(variant_index < variants.next_index());
201+
variant_index
202+
} else {
203+
untagged_variant
204+
}
192205
} else {
193-
untagged_variant
206+
if variant_index_relative < u128::from(discr_len) {
207+
let variant_index_relative = u32::try_from(variant_index_relative)
208+
.expect("we checked that this fits into a u32");
209+
// Then computing the absolute variant idx should not overflow any more.
210+
let variant_index = VariantIdx::from_u32(
211+
variants_start
212+
.checked_add(variant_index_relative)
213+
.expect("overflow computing absolute variant idx"),
214+
);
215+
let variants =
216+
ty.ty_adt_def().expect("tagged layout for non adt").variants();
217+
assert!(variant_index < variants.next_index());
218+
variant_index
219+
} else {
220+
untagged_variant
221+
}
194222
}
195223
}
196224
};
@@ -286,13 +314,25 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
286314
..
287315
} => {
288316
assert!(variant_index != untagged_variant);
317+
let discr_len = (niche_variants.end().as_u32() - niche_variants.start().as_u32())
318+
.checked_add(1)
319+
.expect("the number of niche variants fits into u32");
289320
let variants_start = niche_variants.start().as_u32();
290-
let variant_index_relative = variant_index
321+
let adj_idx = variant_index
291322
.as_u32()
292323
.checked_sub(variants_start)
293324
.expect("overflow computing relative variant idx");
325+
326+
let variant_index_relative = if niche_variants.contains(&untagged_variant) {
327+
let adj_untagged_idx = untagged_variant.as_u32() - variants_start;
328+
let adj_idx_to_modulo = adj_idx
329+
.checked_add(discr_len - adj_untagged_idx)
330+
.expect("overflow computing relative variant idx");
331+
adj_idx_to_modulo % discr_len - 1
332+
} else {
333+
adj_idx
334+
};
294335
// We need to use machine arithmetic when taking into account `niche_start`:
295-
// tag_val = variant_index_relative + niche_start_val
296336
let tag_layout = self.layout_of(tag_layout.primitive().to_int_ty(*self.tcx))?;
297337
let niche_start_val = ImmTy::from_uint(niche_start, tag_layout);
298338
let variant_index_relative_val =

‎src/tools/rust-analyzer/crates/hir-ty/src/layout.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ pub use self::{
3434
mod adt;
3535
mod target;
3636

37-
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
37+
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd)]
3838
pub struct RustcEnumVariantIdx(pub usize);
3939

4040
impl rustc_index::Idx for RustcEnumVariantIdx {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
//@ run-pass
2+
3+
#![feature(rustc_attrs)]
4+
#![allow(internal_features)]
5+
#![allow(dead_code)]
6+
7+
#[rustc_layout_scalar_valid_range_start(2)]
8+
struct U8WithTwoNiches(u8);
9+
10+
// 1 bytes.
11+
enum Order1 {
12+
A(U8WithTwoNiches),
13+
B,
14+
C,
15+
}
16+
17+
enum Order2 {
18+
A,
19+
B(U8WithTwoNiches),
20+
C,
21+
}
22+
23+
enum Order3 {
24+
A,
25+
B,
26+
C(U8WithTwoNiches),
27+
}
28+
29+
fn main() {
30+
assert_eq!(std::mem::size_of::<Order1>(), 1);
31+
assert_eq!(std::mem::size_of::<Order2>(), 1);
32+
assert_eq!(std::mem::size_of::<Order3>(), 1);
33+
}

0 commit comments

Comments
 (0)
Please sign in to comment.