26 changes: 21 additions & 5 deletions gpt_oss/torch/utils.py
@@ -23,18 +23,34 @@ def init_distributed() -> torch.device:
     # Initialize distributed inference
     world_size = int(os.environ.get("WORLD_SIZE", 1))
     rank = int(os.environ.get("RANK", 0))
+    xpu_available = hasattr(torch, "xpu") and torch.xpu.is_available()
+
+    if xpu_available:
+        backend = "xccl"
+        device_type = "xpu"
+    else:
+        backend = "nccl"
+        device_type = "cuda"
+
     if world_size > 1:
         dist.init_process_group(
-            backend="nccl", init_method="env://", world_size=world_size, rank=rank
+            backend=backend, init_method="env://", world_size=world_size, rank=rank
         )
-    torch.cuda.set_device(rank)
-    device = torch.device(f"cuda:{rank}")
+    if xpu_available:
+        torch.xpu.set_device(rank)
+    else:
+        torch.cuda.set_device(rank)
+    device = torch.device(f"{device_type}:{rank}")
 
-    # Warm up NCCL to avoid first-time latency
+    # Warm up backend to avoid first-time latency
     if world_size > 1:
         x = torch.ones(1, device=device)
         dist.all_reduce(x)
-        torch.cuda.synchronize(device)
+        if xpu_available:
+            torch.xpu.synchronize(device)
+        else:
+            torch.cuda.synchronize(device)
 
     suppress_output(rank)
     return device
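A usage sketch of the patched helper (hypothetical example.py; assumes a multi-GPU or multi-XPU host and a PyTorch build that provides the matching nccl/xccl backend):

import torch
import torch.distributed as dist

from gpt_oss.torch.utils import init_distributed

# torchrun supplies WORLD_SIZE and RANK, e.g.:
#   torchrun --nproc-per-node=2 example.py
device = init_distributed()       # cuda:<rank> or xpu:<rank>
x = torch.ones(1, device=device)
if dist.is_initialized():
    dist.all_reduce(x)            # runs over whichever backend was selected
print(device, x.item())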
20 changes: 13 additions & 7 deletions gpt_oss/triton/attention.py
@@ -59,9 +59,10 @@ def _attn_fwd(
     q = Q.load([off_z, off_h, start_m * BLOCK_M, 0]).reshape([BLOCK_M, HEAD_DIM])
 
     if BANDWIDTH:
-        lo, hi = tl.maximum(start_q, start_q + start_m * BLOCK_M - BANDWIDTH), start_q + (start_m + 1) * BLOCK_M
+        lo, hi = tl.maximum(0, start_q + start_m * BLOCK_M - BANDWIDTH + 1), start_q + (start_m + 1) * BLOCK_M
     else:
-        lo, hi = start_q, start_q + (start_m + 1) * BLOCK_M
+        lo, hi = 0, start_q + (start_m + 1) * BLOCK_M
+    hi = tl.minimum(N_KV_CTX, hi)
 
     for start_n in range(lo, hi, BLOCK_N):
         start_n = tl.multiple_of(start_n, BLOCK_N)
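The bounds fix does three things: the key scan now starts at 0 rather than start_q, so queries can see KV-cache entries written before the current chunk; hi is clamped to N_KV_CTX so the scan never runs past the real KV length; and the banded lower bound gains a + 1 so a window of BANDWIDTH covers exactly BANDWIDTH key positions ending at the query (the old form admitted BANDWIDTH + 1). A plain-Python sketch of the fixed arithmetic, collapsing the BLOCK_M/BLOCK_N tiling to single positions (names are illustrative, not from the kernel):

def key_range(q_idx, start_q, bandwidth, n_kv_ctx):
    # Absolute position of this query among the keys.
    q_pos = start_q + q_idx
    if bandwidth:
        # Sliding window: exactly `bandwidth` keys, ending at the query.
        lo = max(0, q_pos - bandwidth + 1)
    else:
        # Full causal attention over the entire cache.
        lo = 0
    hi = min(n_kv_ctx, q_pos + 1)  # never read past the real KV length
    return lo, hi

# 10 cached tokens, window of 4: the query at absolute position 11
# attends keys 8..11 (exactly 4 positions), reaching into the cache.
assert key_range(q_idx=1, start_q=10, bandwidth=4, n_kv_ctx=16) == (8, 12)
# No window: the same query attends every key up to and including itself.
assert key_range(q_idx=1, start_q=10, bandwidth=0, n_kv_ctx=16) == (0, 12)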
@@ -216,12 +217,17 @@ def test_eq(batch_size, num_queries, num_keys, num_key_value_heads, num_key_valu
     if num_queries > num_keys:
         pytest.skip("too many queries")
 
-    q = torch.randn(batch_size, num_queries, num_key_value_heads, num_key_value_groups, head_dim).bfloat16().cuda()
-    k = torch.randn(batch_size, num_keys, num_key_value_heads, head_dim).bfloat16().cuda()
-    v = torch.randn(batch_size, num_keys, num_key_value_heads, head_dim).bfloat16().cuda()
-    sinks = torch.randn(num_key_value_heads * num_key_value_groups).bfloat16().cuda()
+    if torch.xpu.is_available():
+        device = "xpu"
+    else:
+        device = "cuda"
+
+    q = torch.randn(batch_size, num_queries, num_key_value_heads, num_key_value_groups, head_dim).bfloat16().to(device)
+    k = torch.randn(batch_size, num_keys, num_key_value_heads, head_dim).bfloat16().to(device)
+    v = torch.randn(batch_size, num_keys, num_key_value_heads, head_dim).bfloat16().to(device)
+    sinks = torch.randn(num_key_value_heads * num_key_value_groups).bfloat16().to(device)
 
-    start_q = torch.tensor([start_q], dtype=torch.int32).cuda()
+    start_q = torch.tensor([start_q], dtype=torch.int32).to(device)
 
     o1 = attention(q, k, v, sinks, sm_scale, sliding_window, start_q)
     o2 = attention_ref(q, k, v, sinks, sm_scale, sliding_window, start_q)
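One note on the test setup: torch.xpu.is_available() is called unguarded here, whereas utils.py first checks hasattr(torch, "xpu"); on older PyTorch builds that lack the xpu module the guarded form avoids an AttributeError. Allocating directly on the target device also skips the intermediate CPU tensor that .bfloat16().to(device) creates. A self-contained sketch with illustrative shapes:

import torch

# Guarded device pick (mirrors utils.py) plus direct on-device allocation.
device = "xpu" if hasattr(torch, "xpu") and torch.xpu.is_available() else "cuda"

batch_size, num_queries, n_kv_heads, n_groups, head_dim = 1, 8, 2, 4, 64
q = torch.randn(batch_size, num_queries, n_kv_heads, n_groups, head_dim,
                dtype=torch.bfloat16, device=device)
start_q = torch.tensor([0], dtype=torch.int32, device=device)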
26 changes: 20 additions & 6 deletions gpt_oss/triton/model.py
@@ -437,7 +437,10 @@ def from_checkpoint(
     checkpoint = Checkpoint(path, device)
 
     for name, param in model.named_parameters():
-        torch.cuda.empty_cache()
+        if device.type == "xpu":
+            torch.xpu.empty_cache()
+        else:
+            torch.cuda.empty_cache()
         loaded_tensor = checkpoint.get(name)
 
         if "mlp1" in name:
@@ -463,7 +466,10 @@ def from_checkpoint(
             param.data.copy_(loaded_tensor)
 
     # NOTE: Required to avoid OOM errors
-    torch.cuda.empty_cache()
+    if device.type == "xpu":
+        torch.xpu.empty_cache()
+    else:
+        torch.cuda.empty_cache()
     return model
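The same xpu/cuda branch now appears twice in from_checkpoint; a small helper would keep each call site to one line. A sketch (the helper name is ours, not part of this PR):

import torch

def empty_device_cache(device: torch.device) -> None:
    # Release cached allocator blocks on the active accelerator.
    if device.type == "xpu":
        torch.xpu.empty_cache()
    else:
        torch.cuda.empty_cache()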


@@ -476,10 +482,15 @@ def __init__(self, checkpoint: str, context: int, device: torch.device):
         self.input_token = torch.zeros(1, dtype=torch.int32, device=self.device)
         # warmup
         self.model(self.input_token[None, :], caches=self.caches)
-        # capture for sampling
-        self.graph = torch.cuda.CUDAGraph()
-        with torch.cuda.graph(self.graph):
+
+        if self.device.type == "xpu":
+            self.graph = None
             self.logits = self.model(self.input_token[None, :], caches=self.caches)[0]
+        else:
+            # capture for sampling
+            self.graph = torch.cuda.CUDAGraph()
+            with torch.cuda.graph(self.graph):
+                self.logits = self.model(self.input_token[None, :], caches=self.caches)[0]
 
     @torch.inference_mode()
     def generate(self,
@@ -497,7 +508,10 @@ def generate(self,
         num_generated_tokens = 0
         while max_tokens == 0 or num_generated_tokens < max_tokens:
             self.input_token[0] = predicted_token
-            self.graph.replay()
+            if self.graph is not None:
+                self.graph.replay()
+            else:
+                self.logits = self.model(self.input_token[None, :], caches=self.caches)[0]
             if temperature == 0.0:
                 predicted_token = torch.argmax(self.logits[-1, :], dim=-1).item()
             else:
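CUDA graphs are CUDA-specific, so on XPU the generator keeps self.graph = None and simply re-runs the model each token; both paths refresh self.logits from the same static input_token buffer. A standalone sketch of the capture-or-eager pattern (step and inp are stand-ins, not PR code):

import torch

def make_decode_fn(step, inp, device):
    # step: callable that runs one decode step on `inp` and returns logits.
    if device.type == "cuda":
        step(inp)  # warm-up before capture, as TokenGenerator does above
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph):
            logits = step(inp)  # recorded once; `logits` is a static buffer

        def decode():
            graph.replay()      # re-run the recorded kernels in place
            return logits
    else:
        def decode():
            return step(inp)    # eager fallback (e.g. XPU)

    return decode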
5 changes: 3 additions & 2 deletions gpt_oss/triton/moe.py
@@ -9,13 +9,14 @@
 from triton_kernels.numerics import InFlexData
 from triton_kernels.routing import routing
 from triton_kernels.tensor import convert_layout
-from triton_kernels.tensor_details.layout import StridedLayout, HopperMXScaleLayout, HopperMXValueLayout
+from triton_kernels.tensor_details.layout import StridedLayout, HopperMXScaleLayout, HopperMXValueLayout, make_default_matmul_mxfp4_w_layout
 from triton_kernels.tensor import wrap_torch_tensor, FP4
 
 
 def quantize_mx4(w):
     w, w_scale = downcast_to_mxfp(w.to(torch.bfloat16), torch.uint8, axis=1)
-    w = convert_layout(wrap_torch_tensor(w, dtype=FP4), HopperMXValueLayout, mx_axis=1)
+    w_layout_cls, w_layout_kwargs = make_default_matmul_mxfp4_w_layout(mx_axis=1)
+    w = convert_layout(wrap_torch_tensor(w, dtype=FP4), w_layout_cls, **w_layout_kwargs)
     w_scale = convert_layout(wrap_torch_tensor(w_scale), StridedLayout)
     return w, w_scale
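Swapping the hardcoded HopperMXValueLayout for make_default_matmul_mxfp4_w_layout(mx_axis=1) defers the weight-layout choice to triton_kernels, which can return a layout suited to the current hardware rather than assuming Hopper. A minimal sketch of the call as used above (it may need a live GPU/XPU context; the printed value is illustrative):

from triton_kernels.tensor_details.layout import make_default_matmul_mxfp4_w_layout

# Returns a layout class plus the kwargs to instantiate it on this machine;
# quantize_mx4 forwards both straight into convert_layout.
layout_cls, layout_kwargs = make_default_matmul_mxfp4_w_layout(mx_axis=1)
print(layout_cls.__name__, layout_kwargs)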