You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
# Summary
This pull request introduces an fp8 row-scaling kernel as an optional implementation for `scaled_mm`. The kernel selection is based on the scaling tensors of the inputs. For inputs `x` and `y` of shape `[M, K]` and `[K, N]` respectively, the following conditions must be met:
- `x`'s scale should be a 1-dimensional tensor of length `M`.
- `y`'s scale should be a 1-dimensional tensor of length `N`.
It's important to note that this kernel is not called "rowwise, columnwise" scaling because, although the scales for `y` are semantically along its columns, this implementation only supports the TN format. This means the scaling is along the faster-moving dimension, or the "row".
The following two PRs were required to enable local builds:
- [PR pytorch#126185](pytorch#126185)
- [PR pytorch#125523](pytorch#125523)
### Todo
We still do not build our Python wheels with this architecture.
@ptrblck@malfet, should we replace `sm_90` with `sm_90a`?
The NVRTC TMA shadowing feels wrong, but I a not sure the right way to spoof the symbol for this compilation unit:
https://github.com/pytorch/pytorch/pull/125204/files#r1586986954
#### ifdef
I tried to use : `#if !defined(USE_ROCM) && defined(CUDA_VERSION) && CUDA_VERSION >= 12000 && \
defined(__CUDA_ARCH__) && __CUDA_ARCH__ > 900` to gate the building of the kernel. I was having a hell of a time with this.. so I am not really sure the right way to do this
Kernel Credit:
@jwfromm
Pull Request resolved: pytorch#125204
Approved by: https://github.com/lw, https://github.com/malfet
0 commit comments