Skip to content

Sema: Fix switch loop OPV cond lowering #24720

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/Air/Liveness.zig
Original file line number Diff line number Diff line change
Expand Up @@ -1643,7 +1643,7 @@ fn analyzeInstLoop(
.live_set = data.live_set.move(),
});
defer {
log.debug("[{}] %{f}: popped loop block scop", .{ pass, inst });
log.debug("[{}] %{f}: popped loop block scope", .{ pass, inst });
var scope = data.block_scopes.fetchRemove(inst).?.value;
scope.live_set.deinit(gpa);
}
Expand Down
149 changes: 131 additions & 18 deletions src/Sema.zig
Original file line number Diff line number Diff line change
Expand Up @@ -11255,6 +11255,7 @@ fn zirSwitchBlockErrUnion(sema: *Sema, block: *Block, inst: Zir.Inst.Index) Comp
undefined,
&.{},
&.{},
try sema.typeHasOnePossibleValue(operand_err_set_ty) != null,
);

try sema.air_extra.ensureUnusedCapacity(gpa, @typeInfo(Air.CondBr).@"struct".fields.len +
Expand Down Expand Up @@ -11954,6 +11955,8 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index, operand_is_r
defer child_block.instructions.deinit(gpa);
defer merges.deinit(gpa);

const cond_has_opv = try sema.typeHasOnePossibleValue(cond_ty) != null;

if (scalar_cases_len + multi_cases_len == 0 and
special_members_only == null and
!special_generic.is_inline)
Expand All @@ -11969,7 +11972,8 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index, operand_is_r
.loop => |l| l.init_cond,
};
if (zcu.backendSupportsFeature(.is_named_enum_value) and block.wantSafety() and
raw_operand_ty.zigTypeTag(zcu) == .@"enum" and !raw_operand_ty.isNonexhaustiveEnum(zcu))
raw_operand_ty.zigTypeTag(zcu) == .@"enum" and !cond_has_opv and
!raw_operand_ty.isNonexhaustiveEnum(zcu))
{
try sema.zirDbgStmt(block, cond_dbg_node_index);
const ok = try block.addUnOp(.is_named_enum_value, init_cond);
Expand Down Expand Up @@ -12147,8 +12151,17 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index, operand_is_r
special_members_only_src,
extra_case_vals.items.items,
extra_case_vals.ranges.items,
cond_has_opv,
);

assert(merges.extra_insts.items.len == 0 or operand == .loop);

const simplified = switch (sema.air_instructions.items(.tag)[@intFromEnum(air_switch_ref.toIndex().?)]) {
.loop_switch_br, .switch_br => false,
.loop, .block => true,
else => unreachable,
};

for (merges.extra_insts.items, merges.extra_src_locs.items) |placeholder_inst, dispatch_src| {
var replacement_block = block.makeSubBlock();
defer replacement_block.instructions.deinit(gpa);
Expand All @@ -12169,19 +12182,28 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index, operand_is_r

if (zcu.backendSupportsFeature(.is_named_enum_value) and block.wantSafety() and
cond_ty.zigTypeTag(zcu) == .@"enum" and !cond_ty.isNonexhaustiveEnum(zcu) and
!try sema.isComptimeKnown(new_cond))
!try sema.isComptimeKnown(new_cond) and !cond_has_opv)
{
const ok = try replacement_block.addUnOp(.is_named_enum_value, new_cond);
try sema.addSafetyCheck(&replacement_block, src, ok, .corrupt_switch);
}

_ = try replacement_block.addInst(.{
.tag = .switch_dispatch,
.data = .{ .br = .{
.block_inst = air_switch_ref.toIndex().?,
.operand = new_cond,
} },
});
if (simplified) {
_ = try replacement_block.addInst(.{
.tag = .repeat,
.data = .{ .repeat = .{
.loop_inst = air_switch_ref.toIndex().?,
} },
});
} else {
_ = try replacement_block.addInst(.{
.tag = .switch_dispatch,
.data = .{ .br = .{
.block_inst = air_switch_ref.toIndex().?,
.operand = new_cond,
} },
});
}

if (replacement_block.instructions.items.len == 1) {
// Optimization: we don't need a block!
Expand Down Expand Up @@ -12251,6 +12273,7 @@ fn analyzeSwitchRuntimeBlock(
extra_prong_src: LazySrcLoc,
extra_prong_items: []const Air.Inst.Ref,
extra_prong_ranges: []const [2]Air.Inst.Ref,
operand_has_opv: bool,
) CompileError!Air.Inst.Ref {
const pt = sema.pt;
const zcu = pt.zcu;
Expand All @@ -12259,9 +12282,17 @@ fn analyzeSwitchRuntimeBlock(

const block = child_block.parent.?;

const single_prong = operand_has_opv or
(scalar_cases_len + multi_cases_len == 0 and !else_prong.is_inline and extra_prong == null);
var single_prong_payload_index: u32 = undefined;
assert(!(single_prong and multi_cases_len > 0));

const estimated_cases_extra = (scalar_cases_len + multi_cases_len) *
@typeInfo(Air.SwitchBr.Case).@"struct".fields.len + 2;
var cases_extra = try std.ArrayListUnmanaged(u32).initCapacity(gpa, estimated_cases_extra);
var cases_extra: std.ArrayListUnmanaged(u32) = if (single_prong)
.empty
else
try .initCapacity(gpa, estimated_cases_extra);
defer cases_extra.deinit(gpa);

var branch_hints = try std.ArrayListUnmanaged(std.builtin.BranchHint).initCapacity(gpa, scalar_cases_len);
Expand Down Expand Up @@ -12291,7 +12322,7 @@ fn analyzeSwitchRuntimeBlock(
// `item` is already guaranteed to be constant known.

const analyze_body = if (union_originally) blk: {
const unresolved_item_val = sema.resolveConstDefinedValue(block, LazySrcLoc.unneeded, item, undefined) catch unreachable;
const unresolved_item_val = sema.resolveConstDefinedValue(block, .unneeded, item, undefined) catch unreachable;
const item_val = sema.resolveLazyValue(unresolved_item_val) catch unreachable;
const field_ty = maybe_union_ty.unionFieldType(item_val, zcu).?;
break :blk field_ty.zigTypeTag(zcu) != .noreturn;
Expand Down Expand Up @@ -12321,6 +12352,25 @@ fn analyzeSwitchRuntimeBlock(
break :h .none;
};

if (single_prong) {
@branchHint(.unlikely);
assert(operand_has_opv);

term: {
if (case_block.instructions.getLastOrNull()) |last_inst| {
if (sema.isNoReturn(last_inst.toRef())) break :term;
}
_ = try case_block.addNoOp(.unreach);
}
try sema.air_extra.ensureUnusedCapacity(gpa, @typeInfo(Air.Block).@"struct".fields.len +
case_block.instructions.items.len);
single_prong_payload_index = sema.addExtraAssumeCapacity(Air.Block{
.body_len = @intCast(case_block.instructions.items.len),
});
sema.air_extra.appendSliceAssumeCapacity(@ptrCast(case_block.instructions.items));
break;
}

try branch_hints.append(gpa, prong_hint);
try cases_extra.ensureUnusedCapacity(gpa, @typeInfo(Air.SwitchBr.Case).@"struct".fields.len +
1 + // `item`, no ranges
Expand Down Expand Up @@ -12568,7 +12618,39 @@ fn analyzeSwitchRuntimeBlock(
cases_extra.appendSliceAssumeCapacity(@ptrCast(case_block.instructions.items));
}

const else_body: []const Air.Inst.Index = if (else_prong.body.len != 0 or case_block.wantSafety()) else_body: {
const else_body: []const Air.Inst.Index = if (else_prong.body.len != 0 or
(case_block.wantSafety() and !operand_has_opv))
else_body: {
if (single_prong and else_prong.is_inline) {
case_block.instructions.shrinkRetainingCapacity(0);
case_block.error_return_trace_index = child_block.error_return_trace_index;

const operand_opv = (try sema.typeHasOnePossibleValue(operand_ty)).?;
const analyze_body = if (union_originally) blk: {
const field_ty = maybe_union_ty.unionFieldType(operand_opv, zcu).?;
break :blk field_ty.zigTypeTag(zcu) != .noreturn;
} else true;

if (analyze_body) {
_ = try spa.analyzeProngRuntime(
&case_block,
.special,
else_prong.body,
else_prong.capture,
child_block.src(.{ .switch_capture = .{
.switch_node_offset = switch_node_offset,
.case_idx = .special_else,
} }),
&.{.fromValue(operand_opv)},
.fromValue(operand_opv),
else_prong.has_tag_capture,
);
} else {
_ = try case_block.addNoOp(.unreach);
}
break :else_body case_block.instructions.items;
}

var emit_bb = false;
// If this is true we must have a 'true' else prong and not an underscore because
// underscore prongs can never be inlined. We've already checked for this.
Expand Down Expand Up @@ -12802,7 +12884,7 @@ fn analyzeSwitchRuntimeBlock(

if (zcu.backendSupportsFeature(.is_named_enum_value) and
else_prong.body.len != 0 and block.wantSafety() and
operand_ty.zigTypeTag(zcu) == .@"enum" and
operand_ty.zigTypeTag(zcu) == .@"enum" and !operand_has_opv and
(!operand_ty.isNonexhaustiveEnum(zcu) or union_originally))
{
try sema.zirDbgStmt(&case_block, cond_dbg_node_index);
Expand Down Expand Up @@ -12856,14 +12938,47 @@ fn analyzeSwitchRuntimeBlock(
break :h .cold;
};

try branch_hints.append(gpa, else_hint);
if (!single_prong) try branch_hints.append(gpa, else_hint);
break :else_body case_block.instructions.items;
} else else_body: {
try branch_hints.append(gpa, .none);
if (!single_prong) try branch_hints.append(gpa, .none);
break :else_body &.{};
};

assert(branch_hints.items.len == cases_len + 1);
assert(branch_hints.items.len == cases_len + 1 or (single_prong and branch_hints.items.len == 0));

const has_any_continues = spa.operand == .loop and child_block.label.?.merges.extra_insts.items.len > 0;

if (single_prong) {
if (else_body.len > 0) {
assert(scalar_cases_len + multi_cases_len == 0);

const needs_terminator = !sema.isNoReturn(else_body[else_body.len - 1].toRef());
try sema.air_instructions.ensureUnusedCapacity(gpa, @intFromBool(needs_terminator));
try sema.air_extra.ensureUnusedCapacity(gpa, @typeInfo(Air.Block).@"struct".fields.len +
else_body.len + @intFromBool(needs_terminator));
single_prong_payload_index = sema.addExtraAssumeCapacity(Air.Block{
.body_len = @intCast(else_body.len + @intFromBool(needs_terminator)),
});

sema.air_extra.appendSliceAssumeCapacity(@ptrCast(else_body));
if (needs_terminator) {
const terminator_index: Air.Inst.Index = @enumFromInt(sema.air_instructions.len);
sema.air_instructions.appendAssumeCapacity(.{
.tag = .unreach,
.data = .{ .no_op = {} },
});
sema.air_extra.appendAssumeCapacity(@intFromEnum(terminator_index));
}
}
return try child_block.addInst(.{
.tag = if (has_any_continues) .loop else .block,
.data = .{ .ty_pl = .{
.ty = .noreturn_type,
.payload = single_prong_payload_index,
} },
});
}

try sema.air_extra.ensureUnusedCapacity(gpa, @typeInfo(Air.SwitchBr).@"struct".fields.len +
cases_extra.items.len + else_body.len +
Expand Down Expand Up @@ -12892,8 +13007,6 @@ fn analyzeSwitchRuntimeBlock(
sema.air_extra.appendSliceAssumeCapacity(@ptrCast(cases_extra.items));
sema.air_extra.appendSliceAssumeCapacity(@ptrCast(else_body));

const has_any_continues = spa.operand == .loop and child_block.label.?.merges.extra_insts.items.len > 0;

return try child_block.addInst(.{
.tag = if (has_any_continues) .loop_switch_br else .switch_br,
.data = .{ .pl_op = .{
Expand Down
85 changes: 85 additions & 0 deletions test/behavior/switch_loop.zig
Original file line number Diff line number Diff line change
Expand Up @@ -273,3 +273,88 @@ test "switch loop on non-exhaustive enum" {
try S.doTheTest();
try comptime S.doTheTest();
}

test "switch loop on type with opv" {
if (builtin.zig_backend == .stage2_spirv) return error.SkipZigTest;

const S = struct {
const E = enum { opv };
const U = union(E) { opv: u32 };

fn doTheTest() !void {
var x: usize = 0;
label: switch (E.opv) {
.opv => {
x += 1;
if (x == 15) continue :label .opv;
if (x == 10) break :label;
continue :label .opv;
},
}
try expect(x == 10);

label: switch (E.opv) {
else => {
x += 1;
if (x == 25) continue :label .opv;
if (x == 20) break :label;
continue :label .opv;
},
}
try expect(x == 20);

label: switch (E.opv) {
.opv => if (false) continue :label true,
}

const ok = label: switch (U{ .opv = 123 }) {
.opv => |u| {
if (u == 456) break :label true;
continue :label .{ .opv = 456 };
},
};
try expect(ok);
}
};
try S.doTheTest();
try comptime S.doTheTest();
}

test "switch loop with only else prong" {
if (builtin.zig_backend == .stage2_spirv) return error.SkipZigTest;

const S = struct {
const E = enum { a, b, c };
const U = union(E) { a: u32, b: u16, c: u8 };

fn doTheTest() !void {
var x: usize = 0;
label: switch (E.a) {
else => {
x += 1;
if (x == 15) continue :label .b;
if (x == 10) break :label;
continue :label .c;
},
}
try expect(x == 10);

label: switch (E.a) {
else => if (false) continue :label true,
}

const ok = label: switch (U{ .a = 123 }) {
else => |u| {
const y: u32 = switch (u) {
inline else => |y| y,
};
if (y == 456) break :label true;
continue :label .{ .b = 456 };
},
};
try expect(ok);
}
};
try S.doTheTest();
try comptime S.doTheTest();
}