10 changes: 5 additions & 5 deletions AGENTS.md
@@ -188,10 +188,11 @@ torchada supports overriding ATen operators at the C++ level for better performance

**See [docs/custom_musa_ops.md](docs/custom_musa_ops.md) for detailed documentation.**

**Quick start**:
```bash
export TORCHADA_ENABLE_CPP_OPS=1
```
**C++ extensions are automatically loaded on MUSA platform when torchada is imported.**

The C++ extension provides CUDA-compatible APIs including:
- Memory pool functions (`_cuda_beginAllocateCurrentThreadToPool`, `_cuda_endAllocateToPool`, `_cuda_releasePool`)
- These are injected into `torch.cuda.memory` to enable CUDA code using memory pools to work on MUSA
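
A quick way to confirm the injection, as a minimal sketch (assumes a MUSA machine where the extension built and loaded successfully):

```python
import torch
import torchada  # loads the C++ extension and applies patches on import

# After import, torch.cuda.memory should expose the CUDA-compatible
# memory pool functions, backed by torchada's C++ extension.
for name in (
    "_cuda_beginAllocateCurrentThreadToPool",
    "_cuda_endAllocateToPool",
    "_cuda_releasePool",
):
    print(name, "->", hasattr(torch.cuda.memory, name))
```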

**Adding a new operator override**:

@@ -202,7 +203,6 @@ export TORCHADA_ENABLE_CPP_OPS=1
3. The extension is JIT-compiled on first use
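
Once registered, a minimal smoke test might look like this (a sketch: `neg` stands in for whichever operator you overrode, and a MUSA device is assumed to be available):

```python
import torch
import torchada  # first import JIT-compiles and loads the extension

x = torch.randn(8, device="cuda")  # "cuda" is mapped to MUSA by torchada
y = torch.neg(x)                   # dispatches to the PrivateUse1 override
torch.testing.assert_close(y.cpu(), -x.cpu())
```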

**Environment variables**:
- `TORCHADA_ENABLE_CPP_OPS=1` - Enable C++ operator overrides
- `TORCHADA_CPP_OPS_VERBOSE=1` - Show compilation output
- `TORCHADA_DEBUG_CPP_OPS=1` - Log operator calls
- `TORCHADA_DISABLE_OP_OVERRIDE_<OP_NAME>=1` - Disable specific operator override
2 changes: 1 addition & 1 deletion README.md
@@ -297,7 +297,7 @@ See `src/torchada/_mapping.py` for the complete mapping table (380+ mappings).

```
# pyproject.toml or requirements.txt
torchada>=0.1.53
torchada>=0.1.54
```

### Step 2: Conditional Import
2 changes: 1 addition & 1 deletion README_CN.md
@@ -292,7 +292,7 @@ if torchada.is_gpu_device(device): # works on both CUDA and MUSA

```
# pyproject.toml or requirements.txt
torchada>=0.1.53
torchada>=0.1.54
```

### Step 2: Conditional Import
2 changes: 1 addition & 1 deletion benchmarks/benchmark_history.json
@@ -3,7 +3,7 @@
"description": "Historical benchmark results for torchada performance tracking",
"results": [
{
"version": "0.1.53",
"version": "0.1.54",
"date": "2026-01-29",
"platform": "MUSA",
"pytorch_version": "2.7.1",
48 changes: 37 additions & 11 deletions docs/custom_musa_ops.md
@@ -12,13 +12,9 @@ torchada allows you to override ATen operators at the C++ level for the `PrivateUse1` dispatch key

## Quick Start

### 1. Enable C++ Ops
C++ extensions are automatically loaded on MUSA platform when torchada is imported.

```bash
export TORCHADA_ENABLE_CPP_OPS=1
```

### 2. Write Your Kernel
### 1. Write Your Kernel

Edit `src/torchada/csrc/musa_ops.mu`:

@@ -69,10 +65,10 @@ TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
}
```
### 3. Test Your Kernel
### 2. Test Your Kernel
```bash
TORCHADA_ENABLE_CPP_OPS=1 TORCHADA_DEBUG_CPP_OPS=1 python -c "
TORCHADA_DEBUG_CPP_OPS=1 python -c "
import torch
import torchada
@@ -87,14 +83,44 @@ print('Result:', y.cpu()[:5])
| File | Purpose |
|------|---------|
| `src/torchada/csrc/ops.h` | Header with utilities (`log_op_call`, `is_override_enabled`) |
| `src/torchada/csrc/ops.cpp` | Python bindings and C++-only operator overrides |
| `src/torchada/csrc/ops.cpp` | Python bindings, C++-only implementations, and CUDA-compatible APIs |
| `src/torchada/csrc/musa_ops.mu` | MUSA kernel implementations |

## Built-in Functions

torchada's C++ extension provides CUDA-compatible implementations of some torch_musa memory management APIs:

### Memory Pool Functions

These functions are automatically injected into `torch.cuda.memory` and allow CUDA code using memory pools to work transparently on MUSA:

- `_cuda_beginAllocateCurrentThreadToPool(device, mempool_id)` - Begin allocating memory from the current thread to a memory pool
- `_cuda_endAllocateToPool(device, mempool_id)` - End allocating memory to a memory pool
- `_cuda_releasePool(device, mempool_id)` - Release a memory pool

**Usage example:**
```python
import torchada
import torch

# This works transparently on MUSA - no code changes needed
from torch.cuda.memory import _cuda_beginAllocateCurrentThreadToPool
from torch.cuda.memory import _cuda_endAllocateToPool
from torch.cuda.memory import _cuda_releasePool

# Use the functions as in CUDA code
device = 0
pool_id = torch.cuda.graph_pool_handle()
_cuda_beginAllocateCurrentThreadToPool(device, pool_id)
# ... allocations ...
_cuda_endAllocateToPool(device, pool_id)
_cuda_releasePool(device, pool_id)
```
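
Note that, as the name suggests, routing is scoped to the thread that called `_cuda_beginAllocateCurrentThreadToPool`: allocations made on other threads are unaffected (see the thread-id filter in `ops.cpp`).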

## Environment Variables

| Variable | Description |
|----------|-------------|
| `TORCHADA_ENABLE_CPP_OPS=1` | Enable C++ operator overrides |
| `TORCHADA_CPP_OPS_VERBOSE=1` | Show compilation output |
| `TORCHADA_DEBUG_CPP_OPS=1` | Log operator calls to stdout |
| `TORCHADA_DISABLE_OP_OVERRIDE_<NAME>=1` | Disable specific operator override |
@@ -106,7 +132,7 @@ To disable a specific operator override at runtime, set the environment variable

```bash
# Disable the 'neg' operator override, use torch_musa's default instead
TORCHADA_ENABLE_CPP_OPS=1 TORCHADA_DISABLE_OP_OVERRIDE_neg=1 python my_script.py
TORCHADA_DISABLE_OP_OVERRIDE_neg=1 python my_script.py
```

**Important**: The operator name in the environment variable should match the name passed to `is_override_enabled()` in the C++ code. For example, if the code uses `is_override_enabled("neg")`, set `TORCHADA_DISABLE_OP_OVERRIDE_neg=1`.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "torchada"
version = "0.1.53"
version = "0.1.54"
description = "Adapter package for torch_musa to act exactly like PyTorch CUDA"
readme = "README.md"
license = {text = "MIT"}
12 changes: 6 additions & 6 deletions src/torchada/__init__.py
@@ -24,9 +24,12 @@
from torch.utils.cpp_extension import CUDAExtension, BuildExtension, CUDA_HOME
"""

__version__ = "0.1.53"
__version__ = "0.1.54"

from . import cuda, utils

# C++ operator overrides are automatically loaded on MUSA platform
from ._cpp_ops import load_cpp_ops
from ._patch import apply_patches, get_original_init_process_group, is_patched
from ._platform import (
Platform,
@@ -48,17 +51,14 @@
from .triton.autotune.fused_moe import set_default_moe_config_dir
from .utils.cpp_extension import CUDA_HOME

load_cpp_ops()
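# Loading before apply_patches() matters: _patch.py injects the extension's
# memory pool bindings into torch.cuda.memory during patching.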

# Automatically apply patches on import
apply_patches()

# Set default MoE config path for SGL and vLLM
set_default_moe_config_dir()

# Load C++ operator overrides if enabled via TORCHADA_ENABLE_CPP_OPS=1
from ._cpp_ops import load_cpp_ops

load_cpp_ops()


def get_version() -> str:
"""Return the version of torchada."""
16 changes: 4 additions & 12 deletions src/torchada/_cpp_ops.py
@@ -6,12 +6,10 @@
This module handles building and loading C++ extensions that can override
ATen operator implementations for the PrivateUse1 (MUSA) dispatch key.

Usage:
# Enable C++ ops by setting environment variable
export TORCHADA_ENABLE_CPP_OPS=1
C++ extensions are automatically loaded on MUSA platform when torchada is imported.

# Then import torchada as usual
import torchada
Usage:
import torchada # C++ extensions are loaded automatically on MUSA

# Or explicitly load
from torchada._cpp_ops import load_cpp_ops
@@ -75,9 +73,7 @@ def load_cpp_ops(force_reload: bool = False) -> Optional[object]:
"""
Load the C++ operator overrides extension.

The extension is only loaded if:
1. Running on MUSA platform
2. TORCHADA_ENABLE_CPP_OPS=1 environment variable is set
The extension is automatically loaded on MUSA platform.

Args:
force_reload: If True, reload the extension even if already loaded.
@@ -90,10 +86,6 @@ def load_cpp_ops(force_reload: bool = False) -> Optional[object]:
if _cpp_ops_module is not None and not force_reload:
return _cpp_ops_module

# Check if enabled via environment variable
if os.environ.get("TORCHADA_ENABLE_CPP_OPS") != "1":
return None

# Check if on MUSA platform
from ._platform import is_musa_platform

16 changes: 16 additions & 0 deletions src/torchada/_patch.py
@@ -30,6 +30,7 @@

import torch

from ._cpp_ops import get_module
from ._platform import is_musa_platform

_patched = False
@@ -722,6 +723,21 @@ def _patch_torch_cuda_module():
musa_memory_module.MUSAPluggableAllocator
)

# Inject CUDA-compatible memory pool functions from C++ extension
# These functions (_cuda_beginAllocateCurrentThreadToPool, etc.) are
# implemented in torchada's C++ extension to provide CUDA API compatibility
# for torch_musa's memory pool allocator.
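# If the extension is unavailable (non-MUSA platform or a failed build),
# get_module() returns None and the injection is skipped.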
cpp_ops_module = get_module()
if cpp_ops_module is not None:
for func_name in [
"_cuda_beginAllocateCurrentThreadToPool",
"_cuda_endAllocateToPool",
"_cuda_releasePool",
]:
func = getattr(cpp_ops_module, func_name, None)
if func is not None:
setattr(musa_memory_module, func_name, func)

# Patch torch.cuda.graph context manager to accept cuda_graph= keyword
# MUSA's graph class uses musa_graph= but CUDA code uses cuda_graph=
_patch_graph_context_manager()
65 changes: 41 additions & 24 deletions src/torchada/csrc/ops.cpp
@@ -12,34 +12,40 @@

#include "ops.h"

#include "torch_musa/csrc/core/Device.h"
#include "torch_musa/csrc/aten/musa/MUSAContext.h"
#include "torch_musa/csrc/core/MUSAPluggableAllocator.h"
#include <thread>

namespace torchada {

// ============================================================================
// Example: Operator override template (commented out - for reference)
// Memory pool allocation functions (CUDA-compatible API on MUSA)
// ============================================================================
//
// To override an ATen operator, follow this pattern:
//
// static at::Tensor custom_add_impl(
// const at::Tensor& self,
// const at::Tensor& other,
// const at::Scalar& alpha) {
//
// log_op_call("add.Tensor");
//
// // Your custom implementation here
// // IMPORTANT: Avoid calling the same operator to prevent infinite recursion
// // Use in-place operations or lower-level primitives instead
// auto result = at::empty_like(self);
// result.copy_(self);
// result.add_(other, alpha);
// return result;
// }
//
// Then register it:
// TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
// m.impl("add.Tensor", custom_add_impl);
// }

static void _musa_beginAllocateCurrentThreadToPool(
c10::DeviceIndex device,
c10::musa::MempoolId_t mempool_id) {
auto tid = std::this_thread::get_id();
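// Capture the calling thread's id: the filter below routes only this
// thread's allocations to the pool, matching the semantics of CUDA's
// _cuda_beginAllocateCurrentThreadToPool.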

c10::musa::MUSACachingAllocator::beginAllocateToPool(
device, mempool_id, [=](musaStream_t) {
auto current_tid = std::this_thread::get_id();
return current_tid == tid;
});
}

static void _musa_endAllocateToPool(
c10::DeviceIndex device,
c10::musa::MempoolId_t mempool_id) {
c10::musa::MUSACachingAllocator::endAllocateToPool(device, mempool_id);
}

static void _musa_releasePool(
c10::DeviceIndex device,
c10::musa::MempoolId_t mempool_id) {
c10::musa::MUSACachingAllocator::releasePool(device, mempool_id);
}

// ============================================================================
// Utility functions exposed to Python
@@ -61,6 +67,7 @@ void mark_loaded() {

} // namespace torchada


// ============================================================================
// Python bindings
// ============================================================================
Expand All @@ -74,4 +81,14 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"Get the C++ ops extension version");
m.def("_mark_loaded", &torchada::mark_loaded,
"Mark the extension as loaded (internal use)");

m.def("_cuda_beginAllocateCurrentThreadToPool",
&torchada::_musa_beginAllocateCurrentThreadToPool,
"Begin allocating memory from the current thread to a memory pool");
m.def("_cuda_endAllocateToPool",
&torchada::_musa_endAllocateToPool,
"End allocating memory to a memory pool");
m.def("_cuda_releasePool",
&torchada::_musa_releasePool,
"Release a memory pool");
}
4 changes: 2 additions & 2 deletions src/torchada/triton/runtime/fp8_utils.py
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Optional
from typing import Optional, Tuple

import torch

@@ -48,7 +48,7 @@ def scaled_fp8_quant(
scale: Optional[torch.Tensor] = None,
num_token_padding: Optional[int] = None,
use_per_token_if_dynamic: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
) -> Tuple[torch.Tensor, torch.Tensor]:
assert input.ndim == 2, f"Expected 2D input tensor, got {input.ndim}D"
shape = input.shape
if num_token_padding:
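
(Switching the return annotation from `tuple[...]` to `typing.Tuple` keeps it usable on Python 3.8, where builtin generics are not subscriptable; since this file uses `from __future__ import annotations`, the difference only surfaces when the hints are evaluated at runtime, e.g. via `typing.get_type_hints`.)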
1 change: 0 additions & 1 deletion src/torchada/triton/runtime/fused_moe/router.py
@@ -358,7 +358,6 @@ def select_experts(
layer_id: Optional[int] = None,
num_token_non_padded: Optional[torch.Tensor] = None,
) -> StandardTopKOutput:

top_k = topk_config.top_k
use_grouped_topk = topk_config.use_grouped_topk
topk_group = topk_config.topk_group