2 changes: 1 addition & 1 deletion README.md
@@ -103,7 +103,7 @@ if __name__ == "__main__":
### Quick Installation

> [!NOTE]
-> **Requirements**: Python 3.10+, PyTorch 2.0+ (ROCm version), ROCm 6.3.1+ HIP runtime, Triton, and setuptools>=61
+> **Requirements**: Python 3.10+, PyTorch 2.0+ (ROCm version), ROCm 6.3.1+ HIP runtime, and Triton

For a quick installation directly from the repository:

2 changes: 1 addition & 1 deletion docs/conf.py
@@ -95,7 +95,7 @@
"triton.language",
"numpy",
"iris._distributed_helpers",
"iris.hip",
"iris.backend",
]

# Napoleon settings for Google/NumPy docstring parsing
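For context, this hunk presumably edits Sphinx's `autodoc_mock_imports` list in `docs/conf.py` (the variable name itself sits outside the hunk): these modules are mocked so the documentation can build without a GPU runtime installed. A sketch of the assumed surrounding configuration:

```python
# docs/conf.py (assumed context; only the list entries appear in the hunk above)
# Sphinx substitutes mock objects for these modules during the docs build, so
# autodoc can import iris without CUDA/ROCm or Triton installed.
autodoc_mock_imports = [
    "triton",
    "triton.language",
    "numpy",
    "iris._distributed_helpers",
    "iris.backend",  # renamed from "iris.hip" in this PR
]
```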
2 changes: 1 addition & 1 deletion examples/07_gemm_all_scatter/benchmark.py
@@ -236,7 +236,7 @@ def run_experiment():
    json_writer.display()

    if args["trace_tiles"] and rank == 0:
-        gpu_freq = iris.hip.get_wall_clock_rate(rank) * 1e-3
+        gpu_freq = iris.backend.get_wall_clock_rate(rank) * 1e-3
        algo_string = "all_scatter"
        filename = f"gemm_tiles_{algo_string}_trace_rank{rank}.json"
        timestamps.to_json(filename, gpu_freq)
2 changes: 1 addition & 1 deletion examples/07_gemm_all_scatter/matmul_wrapper.py
@@ -17,7 +17,7 @@ class matmul(torch.autograd.Function):
    _registers = None
    _spills = None

-    _num_xcds = iris.hip.get_num_xcc()
+    _num_xcds = iris.backend.get_num_xcc()

    @staticmethod
    def set_debug(debug: bool):
2 changes: 1 addition & 1 deletion examples/08_gemm_atomics_all_reduce/benchmark.py
@@ -279,7 +279,7 @@ def run_experiment():
    json_writer.display()

    if args["trace_tiles"] and rank == 0:
-        gpu_freq = iris.hip.get_wall_clock_rate(rank) * 1e-3
+        gpu_freq = iris.backend.get_wall_clock_rate(rank) * 1e-3
        filename = f"gemm_all_reduce_tiles_trace_rank{rank}.json"
        timestamps.to_json(filename, gpu_freq)

2 changes: 1 addition & 1 deletion examples/08_gemm_atomics_all_reduce/matmul_wrapper.py
@@ -20,7 +20,7 @@
class matmul(torch.autograd.Function):
    _debug = True

-    _num_xcds = iris.hip.get_num_xcc()
+    _num_xcds = iris.backend.get_num_xcc()

    @staticmethod
    def set_debug(debug: bool):
2 changes: 1 addition & 1 deletion examples/09_gemm_one_shot_all_reduce/benchmark.py
@@ -275,7 +275,7 @@ def run_experiment():
    json_writer.display()

    if args["trace_tiles"] and rank == 0:
-        gpu_freq = iris.hip.get_wall_clock_rate(rank) * 1e-3
+        gpu_freq = iris.backend.get_wall_clock_rate(rank) * 1e-3
        filename = f"gemm_all_reduce_tiles_trace_rank{rank}.json"
        timestamps.to_json(filename, gpu_freq)

2 changes: 1 addition & 1 deletion examples/09_gemm_one_shot_all_reduce/matmul_wrapper.py
@@ -20,7 +20,7 @@
class matmul(torch.autograd.Function):
    _debug = True

-    _num_xcds = iris.hip.get_num_xcc()
+    _num_xcds = iris.backend.get_num_xcc()

    @staticmethod
    def set_debug(debug: bool):
@@ -244,7 +244,7 @@ def run_experiment():
    json_writer.display()

    if args["trace_tiles"] and rank == 0:
-        gpu_freq = iris.hip.get_wall_clock_rate(rank) * 1e-3
+        gpu_freq = iris.backend.get_wall_clock_rate(rank) * 1e-3
        algo_string = "all_scatter"
        filename = f"gemm_tiles_{algo_string}_trace_rank{rank}.json"
        timestamps.to_json(filename, gpu_freq)
@@ -19,7 +19,7 @@ class matmul(torch.autograd.Function):
    _registers = None
    _spills = None

-    _num_xcds = iris.hip.get_num_xcc()
+    _num_xcds = iris.backend.get_num_xcc()

    @staticmethod
    def set_debug(debug: bool):
4 changes: 2 additions & 2 deletions examples/11_gemm_all_scatter_producer_consumer/benchmark.py
@@ -118,7 +118,7 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict):

    bias = None

-    num_xcds = iris.hip.get_num_xcc()
+    num_xcds = iris.backend.get_num_xcc()

    gemm_stream = torch.cuda.Stream()
    comm_stream = torch.cuda.Stream()
@@ -275,7 +275,7 @@ def run_experiment():
    json_writer.display()

    if args["trace_tiles"] and rank == 0:
-        gpu_freq = iris.hip.get_wall_clock_rate(rank) * 1e-3
+        gpu_freq = iris.backend.get_wall_clock_rate(rank) * 1e-3
        algo_string = "all_scatter"
        filename = f"gemm_tiles_{algo_string}_trace_rank{rank}.json"
        timestamps.to_json(filename, gpu_freq)
@@ -17,7 +17,7 @@ class matmul(torch.autograd.Function):
    _registers = None
    _spills = None

-    _num_xcds = iris.hip.get_num_xcc()
+    _num_xcds = iris.backend.get_num_xcc()

    @staticmethod
    def set_debug(debug: bool):
4 changes: 2 additions & 2 deletions examples/12_gemm_all_scatter_bulk_synchronous/benchmark.py
@@ -116,7 +116,7 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict):

    bias = None

-    num_xcds = iris.hip.get_num_xcc()
+    num_xcds = iris.backend.get_num_xcc()

    # This is one after another.
    main_stream = torch.cuda.Stream()
@@ -271,7 +271,7 @@ def run_experiment():
    json_writer.display()

    if args["trace_tiles"] and rank == 0:
-        gpu_freq = iris.hip.get_wall_clock_rate(rank) * 1e-3
+        gpu_freq = iris.backend.get_wall_clock_rate(rank) * 1e-3
        algo_string = "all_scatter"
        filename = f"gemm_tiles_{algo_string}_trace_rank{rank}.json"
        timestamps.to_json(filename, gpu_freq)
@@ -16,7 +16,7 @@ class matmul(torch.autograd.Function):
    _debug = False
    _registers = None
    _spills = None
-    _num_xcds = iris.hip.get_num_xcc()
+    _num_xcds = iris.backend.get_num_xcc()

    @staticmethod
    def set_debug(debug: bool):
4 changes: 2 additions & 2 deletions iris/__init__.py
@@ -46,7 +46,7 @@
    do_bench,
)

-from . import hip
+from . import backend

# Import logging functionality
from .logging import (
@@ -77,7 +77,7 @@
"atomic_min",
"atomic_max",
"do_bench",
"hip",
"backend",
"set_logger_level",
"logger",
"DEBUG",
46 changes: 46 additions & 0 deletions iris/backend/README.md
@@ -0,0 +1,46 @@
# Backend Interface

The backend provides two ways to work with GPU platforms:

For portable code, use the unified API that works across all backends:

```python
import iris.backend as backend

num_gpus = backend.count_devices()
backend.set_device(0)
ptr = backend.malloc(size)
```

For platform-specific code, import directly:

```python
from iris.backend import hip
ptr = hip.malloc_fine_grained(size) # HIP-only: cache-coherent shared memory

from iris.backend import cuda
ptr = cuda.malloc_managed(size) # CUDA-only: unified memory with page migration
```

## Implementing a New Backend

Backends must implement these functions:

```python
def set_device(gpu_id: int) -> None
def get_device_id() -> int
def count_devices() -> int
def get_cu_count(device_id: int | None = None) -> int
def get_wall_clock_rate(device_id: int) -> int
def get_arch_string(device_id: int | None = None) -> str
def get_num_xcc(device_id: int | None = None) -> int
def get_runtime_version() -> tuple[int, int] # (major, minor)

def get_ipc_handle(ptr: int | ctypes.c_void_p, rank: int) -> Any
def open_ipc_handle(ipc_handle_data: np.ndarray, rank: int) -> int

def malloc(size: int) -> ctypes.c_void_p
def free(ptr: int | ctypes.c_void_p) -> None
```

See `cuda.py` and `hip.py` for reference implementations.
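
To make the contract concrete, here is a minimal sketch of a stub backend satisfying this interface. It is a hypothetical illustration, not part of the PR: the placeholder return values and the host-memory `malloc` stand-in are assumptions, and unimplemented entry points raise `NotImplementedError`.

```python
# Hypothetical skeleton for a new backend; signatures follow the interface above.
import ctypes
from typing import Any

import numpy as np

_current_device = 0
_allocations: dict[int, ctypes.Array] = {}


def set_device(gpu_id: int) -> None:
    global _current_device
    _current_device = gpu_id


def get_device_id() -> int:
    return _current_device


def count_devices() -> int:
    return 1  # a real backend queries its runtime here


def get_cu_count(device_id: int | None = None) -> int:
    raise NotImplementedError


def get_wall_clock_rate(device_id: int) -> int:
    raise NotImplementedError  # units must match the hip.py/cuda.py implementations


def get_arch_string(device_id: int | None = None) -> str:
    return "generic"


def get_num_xcc(device_id: int | None = None) -> int:
    return 1  # single-die devices report one XCC


def get_runtime_version() -> tuple[int, int]:
    return (0, 0)  # (major, minor)


def get_ipc_handle(ptr: int | ctypes.c_void_p, rank: int) -> Any:
    raise NotImplementedError  # export a handle that peer processes can open


def open_ipc_handle(ipc_handle_data: np.ndarray, rank: int) -> int:
    raise NotImplementedError  # map a peer allocation and return its address


def malloc(size: int) -> ctypes.c_void_p:
    # Host-memory stand-in; a real backend allocates device memory.
    buf = ctypes.create_string_buffer(size)
    ptr = ctypes.cast(buf, ctypes.c_void_p)
    _allocations[ptr.value] = buf  # hold a reference so Python does not free it
    return ptr


def free(ptr: int | ctypes.c_void_p) -> None:
    addr = ptr.value if isinstance(ptr, ctypes.c_void_p) else ptr
    _allocations.pop(addr, None)
```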
94 changes: 94 additions & 0 deletions iris/backend/__init__.py
@@ -0,0 +1,94 @@
# SPDX-License-Identifier: MIT

"""
Backend auto-detection for Iris.

Automatically detects and loads the appropriate GPU backend (CUDA or HIP)
based on what is available on the system: CUDA is tried first, then HIP.

The backend can be forced by setting the IRIS_BACKEND environment variable to 'cuda' or 'hip'.
"""

import ctypes
import os
from types import ModuleType
from typing import NamedTuple


class _Backend(NamedTuple):
    name: str
    module: ModuleType

    def get_name(self):  # for re-export
        return self.name


def _detect_and_load() -> _Backend:
    """Detect the available GPU runtime and load the corresponding backend."""

    def library_exists(lib_path):
        try:
            ctypes.cdll.LoadLibrary(lib_path)
            return True
        except OSError:
            return False

    def backend_allowed(name):
        """Return True if this backend is permitted by the IRIS_BACKEND env var."""
        return not forced or forced == name

    forced = os.getenv("IRIS_BACKEND", "").lower()
    if forced and forced not in ("cuda", "hip"):
        raise ValueError(f"Invalid IRIS_BACKEND='{forced}'. Must be 'cuda' or 'hip'.")

    if library_exists("libcudart.so") and backend_allowed("cuda"):
        from . import cuda

        return _Backend("cuda", cuda)

    if library_exists("libamdhip64.so") and backend_allowed("hip"):
        from . import hip

        return _Backend("hip", hip)

    forced_msg = f"IRIS_BACKEND={forced} but {forced.upper()} runtime not found. " if forced else ""
    raise RuntimeError(
        f"No GPU backend available. {forced_msg}"
        "Iris requires either the CUDA or the HIP runtime. "
        "Please install CUDA (NVIDIA) or ROCm (AMD) to use Iris."
    )


_backend = _detect_and_load() # Load backend at import time
# Re-export backend funcs
backend_name = _backend.get_name
set_device = _backend.module.set_device
get_cu_count = _backend.module.get_cu_count
count_devices = _backend.module.count_devices
get_ipc_handle = _backend.module.get_ipc_handle
open_ipc_handle = _backend.module.open_ipc_handle
get_wall_clock_rate = _backend.module.get_wall_clock_rate
get_device_id = _backend.module.get_device_id
get_arch_string = _backend.module.get_arch_string
get_num_xcc = _backend.module.get_num_xcc
malloc = _backend.module.malloc
free = _backend.module.free
get_runtime_version = _backend.module.get_runtime_version


__all__ = [
    "backend_name",
    "set_device",
    "get_cu_count",
    "count_devices",
    "get_ipc_handle",
    "open_ipc_handle",
    "get_wall_clock_rate",
    "get_device_id",
    "get_arch_string",
    "get_num_xcc",
    "malloc",
    "free",
    "get_runtime_version",
]
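
A minimal usage sketch of the selection logic above (hypothetical; it assumes at least one GPU runtime is installed, and that the override is set before `iris` is first imported):

```python
# Hypothetical usage of the backend auto-detection.
import os

# Optional override; read once when iris is first imported. If the forced
# runtime is not installed, `import iris` raises RuntimeError.
os.environ["IRIS_BACKEND"] = "hip"

import iris

print(iris.backend.backend_name())   # -> "hip" (re-exported as a callable)
print(iris.backend.count_devices())  # device count from the selected runtime
```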