fix(torch): make gemm fallback portable #611

Merged
voltjia merged 1 commit into master from fix/torch-gemm-fallback on May 16, 2026

Conversation


@voltjia voltjia commented May 16, 2026

Summary

  • Keeps the PyTorch Gemm fallback on fused `addmm_out` / `baddbmm_out` for CPU and NVIDIA, preserving their existing numerical behavior.
  • Uses a portable matmul plus an explicit alpha / beta update for other vendor PyTorch backends, where fused `out=` GEMM variants are unavailable or numerically inconsistent.
  • Updates the `tests/test_gemm.py` reference path so vendor fallback tests compare against the same portable execution order.
  • Adds a targeted Iluvatar fp16 skip for the existing Gemm case whose backend execution is known to be unstable.
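The portable execution order described above — materialize the matmul first, then apply the alpha / beta update — can be sketched in plain Python. This is a minimal illustration of the semantics, not the C++ implementation in this PR; the function name and list-of-lists representation are illustrative only.

```python
def portable_gemm(a, b, c, alpha, beta):
    """Compute alpha * (a @ b) + beta * c using plain Python lists.

    Mirrors the portable execution order: the matmul result is
    produced first, then scaled and combined with `c` explicitly,
    instead of relying on a fused `out=` GEMM variant.
    """
    m, k = len(a), len(a[0])
    n = len(b[0])

    # Step 1: plain matmul, no fusion.
    ab = [
        [sum(a[i][p] * b[p][j] for p in range(k)) for j in range(n)]
        for i in range(m)
    ]

    # Step 2: explicit alpha / beta update.
    return [
        [alpha * ab[i][j] + beta * c[i][j] for j in range(n)]
        for i in range(m)
    ]
```

Because the reference path in `tests/test_gemm.py` follows this same two-step order, vendor backends that lack a matching fused kernel compare against an identically ordered computation.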

Motivation

The PyTorch-backed Gemm fallback serves as a cross-platform fallback path, but some vendor PyTorch forks reject, or diverge on, fused `out=` GEMM variants. The previous implementation also relied on exception-based fallback logic, which conflicts with the project's C++ rule forbidding exceptions. Keeping fused calls only where they are known to match the current baseline, and using an explicit portable path everywhere else, keeps the fallback deterministic across supported backends.

Closes N/A.

Type of Change

  • feat — new feature / new operator / new platform
  • fix — bug fix
  • perf — performance improvement (no behavioral change)
  • refactor — code restructuring without behavior change
  • test — adding or fixing tests only
  • docs — documentation only
  • build / ci — build system or CI configuration
  • chore — tooling, formatting, or other non-code changes
  • Breaking change (requires a ! in the Conventional Commits prefix or a BREAKING CHANGE: footer)

Platforms Affected

  • CPU (WITH_CPU)
  • NVIDIA (WITH_NVIDIA)
  • Iluvatar (WITH_ILUVATAR)
  • MetaX (WITH_METAX)
  • Cambricon (WITH_CAMBRICON)
  • Moore (WITH_MOORE)
  • Ascend (WITH_ASCEND)
  • PyTorch C++ bindings (WITH_TORCH)
  • Build system / CMake / CI
  • Python bindings / user-facing API

Test Results on Supported Platforms

The higher item counts compared with #610 are expected. This validation was run with the existing hand-written PyTorch C++ backend enabled (WITH_TORCH=ON via backend auto-detection), so the pre-existing Add / Gemm PyTorch slot is included in collection. This PR does not add new parametrized test coverage; the count increase comes from enabling that existing backend during validation.

The clean full-suite validation was run on a temporary branch containing this PR plus #612, because the two fixes address independent failures that otherwise mask each other in full-platform runs.

| Platform  | Built | pytest Result                                                        | Notes / Hardware                                         |
| --------- | ----- | -------------------------------------------------------------------- | -------------------------------------------------------- |
| NVIDIA    | Yes   | 6295 passed, 2447 skipped in 345.48s                                  | Full suite passed.                                       |
| Iluvatar  | Yes   | 4795 passed, 2447 skipped in 284.25s                                  | Full suite passed with #612 included.                    |
| MetaX     | Yes   | 5795 passed, 1447 skipped in 361.81s                                  | Full suite passed.                                       |
| Cambricon | Yes   | 3073 passed, 3857 skipped in 920.38s                                  | Full suite passed.                                       |
| Moore     | Yes   | 5759 passed, 1483 skipped in 574.92s                                  | Full suite passed.                                       |
| Ascend    | Yes   | 4472 passed, 2710 skipped in 527.68s; wrapper exit code 137           | Pytest summary passed; the container exited after the test summary. |

Additional targeted validation for this PR:

MetaX `tests/test_gemm.py`: 3000 passed in 21.81s

Benchmark / Performance Impact

N/A. This PR prioritizes portability of the fallback implementation; it is not intended as a performance optimization.

Notes for Reviewers

The fallback intentionally avoids C++ exceptions. CPU and NVIDIA keep the fused PyTorch path to preserve the current baseline. Other device specializations use the portable matmul path because several vendor PyTorch forks do not provide the same fused `out=` behavior.

This PR is independent from #612, but a fully clean all-platform suite currently requires both fixes: this PR removes the unrelated `tests/test_gemm.py` fallback failures, while #612 removes the Iluvatar `tests/test_causal_softmax.py` reference failure.
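The key design point is that the fused-vs-portable choice is made up front per backend, so no try/catch fallback is ever needed at call time. A hypothetical Python sketch of that selection (the backend names and function are illustrative, not the actual C++ dispatch):

```python
# Backends known to match the fused addmm_out / baddbmm_out baseline.
FUSED_BASELINE_BACKENDS = {"cpu", "nvidia"}

def select_gemm_path(backend: str) -> str:
    """Pick the Gemm fallback path for a backend ahead of time.

    Because the decision is static per backend, there is no need for
    exception-based "try fused, fall back on failure" logic, which the
    project's C++ rules forbid.
    """
    if backend in FUSED_BASELINE_BACKENDS:
        # Keep the fused out= path to preserve the current baseline.
        return "fused"

    # Vendor forks without consistent fused out= GEMM variants get the
    # portable matmul + explicit alpha/beta path.
    return "portable"
```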


Checklist

Title, Branch, and Commits

  • PR title follows Conventional Commits (e.g. feat(nvidia): …, fix(cuda/gemm): …).
  • Branch name follows <type>/xxx-yyyy-zzzz where <type> matches the PR title's Conventional Commits type and words are joined with hyphens (see CONTRIBUTING.md §Branches).
  • Each commit message follows Conventional Commits.
  • Small PR is a single squashable commit; or, for a large PR, every commit is meaningful, well-formed, and independently reviewable (see CONTRIBUTING.md §Pull Requests).
  • No stray merge commits from master — the branch is rebased cleanly on top of the current master.
  • No fixup! / squash! / wip commits remain.

Scope and Design

  • Changes are minimal — nothing unrelated to the stated motivation was added (CONTRIBUTING.md §Code/General).
  • No dead code, commented-out blocks, debug prints, printf/std::cout/print(...) left behind, or TODO without an owner and issue link.
  • No unrelated formatting churn that would obscure the diff.
  • N/A — No public API changes.

General Code Hygiene (applies to all languages)

  • The code is self-explanatory; comments were added only where the why is non-obvious (CONTRIBUTING.md §Code/General).
  • Every modified or added file ends with a single trailing newline (CONTRIBUTING.md §Code/General).
  • No trailing whitespace, tab/space mixing, or stray BOMs.
  • Identifiers in comments and error messages are wrapped in backticks (e.g. the `seqlens_k` tensor) (CONTRIBUTING.md §Code/General).
  • All comments and error messages are in English (CONTRIBUTING.md §Code/General).
  • Comments and error messages are complete sentences — capitalized first letter, terminal punctuation — unless the language/framework convention says otherwise (CONTRIBUTING.md §Code/General; §Python).

C++ Specific (if C++ files changed)

  • Code follows the Google C++ Style Guide strictly.
  • clang-format (version 21, per .github/workflows/clang-format.yml) has been run against all modified .h, .cc, .cuh, and .mlu files; the diff is clean.
  • clang-tidy concerns (per .clang-tidy) have been reviewed — no new warnings beyond the existing baseline.
  • Operator parameter order is inputs first, outputs last; attributes are between inputs and outputs; naming follows PyTorch → ONNX → CUDA API precedence (CONTRIBUTING.md §C++).
  • No exceptions are thrown. Error paths use assert with messages that include at least __FILE__, __LINE__, and __func__ (CONTRIBUTING.md §C++).
  • N/A — No new error or warning messages.
  • N/A — No kernel files changed.
  • N/A — No kernel launchers changed.
  • Constructor initializer list order matches member declaration order (CONTRIBUTING.md §C++).
  • Exactly one blank line between classes, between classes and functions, and between functions (CONTRIBUTING.md §C++).
  • Exactly one blank line between members (functions and variables) within a class (CONTRIBUTING.md §C++).
  • Exactly one blank line before and after the contents of a namespace (CONTRIBUTING.md §C++).
  • N/A — No new operators added.
  • No raw new/delete; RAII / smart pointers / existing allocators are used.

Python Specific (if Python files changed)

  • Code is PEP 8 compliant; ruff check passes cleanly on CI (see .github/workflows/ruff.yml).
  • ruff format --check passes cleanly — if not, run ruff format and commit the result.
  • Comments are complete English sentences, starting with a capital letter and ending with punctuation; Markdown backticks are used for code references (CONTRIBUTING.md §Python).
  • Framework-specific conventions (e.g. lowercase pytest.skip messages without terminal period) are honored where applicable (CONTRIBUTING.md §Python).
  • No blank line between the function signature and the body when there is no docstring or comment (CONTRIBUTING.md §Python).
  • A blank line is present before and after if, for, and similar control-flow statements (CONTRIBUTING.md §Python).
  • A blank line appears before each return, except when it directly follows a control-flow statement like if or for (CONTRIBUTING.md §Python).
  • N/A — No Python docstrings added.
  • N/A — No type hints changed.

Testing

  • pytest was run locally on every supported platform that this PR can affect, and the results are recorded in the "Test Results" table above (CONTRIBUTING.md §Pull Requests).
  • N/A — No platform is intentionally omitted.
  • New functionality has matching tests under tests/ following tests/test_add.py / tests/test_gemm.py patterns (CONTRIBUTING.md §Adding an Operator).
  • Tests use pytest.mark.parametrize correctly: dependent parameters share one decorator (e.g. @pytest.mark.parametrize("dtype, rtol, atol", …)); independent parameters use separate decorators ordered by parameter declaration.
  • Where appropriate, pytest.mark.auto_act_and_assert is used and the test returns a Payload whose func and ref share the same calling convention.
  • Default dtype / device parameterization is relied on, or overridden with an explicit pytest.mark.parametrize when necessary.
  • Any new test that is flaky under parallelism is marked so, or documented to require pytest -n 1.
  • For bug fixes: a regression test has been added that fails on master and passes with this PR.

Build, CI, and Tooling

  • The project builds cleanly from a fresh directory with pip install .[dev] on at least one affected platform.
  • compile_commands.json still regenerates (CMake option CMAKE_EXPORT_COMPILE_COMMANDS=ON in pyproject.toml — required by the code-lint skill and clang-tidy -p).
  • N/A — No new backend or device auto-detection.
  • Only one CUDA-like GPU backend is selectable at a time — the existing mutual-exclusion check in CMakeLists.txt is not broken.
  • Both CI workflows (clang-format.yml, ruff.yml) are green locally (or expected to be green on CI).
  • No new runtime dependency was added without updating pyproject.toml's [project.optional-dependencies] (or justified in the PR description).

Documentation

  • N/A — No user-facing docs or workflow changed.
  • N/A — No new operators, dispatch helpers, or public utilities.
  • N/A — No user-visible breaking change.

Security and Safety

  • No secrets, access tokens, internal URLs, customer data, or personal hardware identifiers have been committed.
  • N/A — No third-party code added.
  • No unsafe pointer arithmetic, uninitialized reads, or missing bounds checks were introduced.

@voltjia voltjia changed the title fix(torch): make gemm fallback portable fix(torch): make gemm fallback portable May 16, 2026
@voltjia voltjia force-pushed the fix/torch-gemm-fallback branch from fb172b6 to 7bcc2d5 Compare May 16, 2026 02:22
@voltjia voltjia force-pushed the fix/torch-gemm-fallback branch from 7bcc2d5 to a48c4e8 Compare May 16, 2026 03:06
@voltjia voltjia marked this pull request as ready for review May 16, 2026 05:36
@voltjia voltjia requested review from a team, Ziminli and crapromer May 16, 2026 05:36

voltjia commented May 16, 2026

@crapromer for initial review, @Ziminli for final review.

@voltjia voltjia merged commit 80c942a into master May 16, 2026
4 checks passed
@voltjia voltjia deleted the fix/torch-gemm-fallback branch May 16, 2026 13:31