Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions helion/exc.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,10 @@ class BreakpointInDeviceLoopRequiresInterpret(BaseError):
message = "breakpoint() inside an `hl.tile` or `hl.grid` loop requires TRITON_INTERPRET=1 or HELION_INTERPRET=1."


class BothInterpretModesActive(BaseError):
message = "Cannot have both TRITON_INTERPRET=1 and HELION_INTERPRET=1 active simultaneously. Please use only one interpret mode."


class UndefinedVariable(BaseError):
message = "{} is not defined."

Expand Down
10 changes: 8 additions & 2 deletions helion/runtime/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,8 +232,14 @@ def _get_autotune_random_seed() -> int:


def _get_ref_mode() -> RefMode:
interpret = _env_get_bool("HELION_INTERPRET", False)
return RefMode.EAGER if interpret else RefMode.OFF
triton_interpret = os.environ.get("TRITON_INTERPRET") == "1"
helion_interpret = _env_get_bool("HELION_INTERPRET", False)

# Ban having both interpret modes active
if triton_interpret and helion_interpret:
raise exc.BothInterpretModesActive

return RefMode.EAGER if helion_interpret else RefMode.OFF


@dataclasses.dataclass
Expand Down
22 changes: 22 additions & 0 deletions test/test_breakpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,9 @@ def _run_breakpoint_in_subprocess(
)

env = os.environ.copy()
# Set the env vars in the subprocess environment (needed for the ban check)
env["TRITON_INTERPRET"] = str(triton_interpret)
env["HELION_INTERPRET"] = str(helion_interpret)
result = subprocess.run(
[sys.executable, "-c", script],
env=env,
Expand Down Expand Up @@ -128,6 +131,16 @@ def _run_device_breakpoint_test(
out = bound(x)
torch.testing.assert_close(out, x)

def _run_device_breakpoint_both_interpret_test(
self, triton_interpret: int, helion_interpret: int
) -> None:
"""Test that having both interpret modes active is banned."""
# Environment variables are already set by subprocess, just verify they're both 1
assert triton_interpret == 1 and helion_interpret == 1
# When both are set to 1, creating a kernel should raise an error
with self.assertRaises(exc.BothInterpretModesActive):
self._make_device_breakpoint_kernel()

def test_device_breakpoint_no_interpret(self) -> None:
self._run_breakpoint_in_subprocess(
test_name=self._testMethodName,
Expand All @@ -152,6 +165,15 @@ def test_device_breakpoint_helion_interpret(self) -> None:
helion_interpret=1,
)

def test_device_breakpoint_both_interpret_banned(self) -> None:
"""Test that having both TRITON_INTERPRET and HELION_INTERPRET active is banned."""
self._run_breakpoint_in_subprocess(
test_name=self._testMethodName,
runner_method="_run_device_breakpoint_both_interpret_test",
triton_interpret=1,
helion_interpret=1,
)

def _run_host_breakpoint_test(
self, triton_interpret: int, helion_interpret: int
) -> None:
Expand Down
Loading