Conversation
Many thanks for the contribution.

Thanks for the contribution! Any benchmarking results for the GPT-OSS models?
shimizust left a comment:

Thank you for the contribution! What is the discrepancy you found in the bf16 convergence tests?
| pytest.fail(f"An exception occured in extra_expr: {type(e).__name__} - {e}") | ||
|
|
||
|
|
||
> `@pytest.mark.skipif(not is_qwen3_available(), reason="gpt oss module not available")`

Should this be `is_gpt_oss_available`?
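As a minimal sketch of the suggested fix, assuming an `is_gpt_oss_available` helper exists (or is added) next to the other availability checks; the import path below is hypothetical:

```python
import pytest

# Hypothetical import path; the availability helpers in this repo live in the test utilities.
from test.utils import is_gpt_oss_available


@pytest.mark.skipif(not is_gpt_oss_available(), reason="gpt oss module not available")
def test_gpt_oss_monkey_patch():
    ...
```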
> `cross_entropy: bool = False,`
> `fused_linear_cross_entropy: bool = True,`
> `rms_norm: bool = True,`
> `swiglu: bool = True,`

If swiglu can't be implemented now, let's set it to `False` by default and raise `NotImplementedError` if it is set to `True`.
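A rough sketch of the suggested guard, assuming the patch entry point follows the existing `apply_liger_kernel_to_*` naming convention; the function name and body here are illustrative, not the final API:

```python
def apply_liger_kernel_to_gpt_oss(
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    rms_norm: bool = True,
    swiglu: bool = False,  # default off until the GPT-OSS swiglu kernel is implemented
    **kwargs,
) -> None:
    if swiglu:
        # Fail loudly rather than silently skipping the unsupported patch.
        raise NotImplementedError("swiglu patching is not yet supported for GPT-OSS.")
    # ... apply the RoPE, RMSNorm, cross_entropy, and fused_linear_cross_entropy patches ...
```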
@PKUWZP I'll add the benchmark results as soon as the swiglu implementation is complete.
@shimizust During the convergence test, the loss values for the two models running in bf16 diverged significantly at certain steps. This is likely related to the issue discussed here: #742.
## Summary

This PR is a follow-up to #830, exposing the `accum_dtype` option for monkey patch functions. All bf16 convergence tests related to fused linear cross entropy are also enforced to run with `accum_dtype=torch.float32` for numerical stability.

Related: #512, #742, #827, #850

## Testing Done

- Hardware Type: <BLANK>
- [ ] run `make test` to ensure correctness
- [ ] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence

Signed-off-by: Tcc0403 <76503978+Tcc0403@users.noreply.github.com>
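For context, a hedged usage sketch of the option described above, assuming the monkey patch functions accept the new keyword as described (the exact signature may differ):

```python
import torch
from liger_kernel.transformers import apply_liger_kernel_to_llama

# Assumption: the follow-up exposes `accum_dtype` on the patch functions so the fused
# linear cross entropy accumulates in fp32 even when the model itself runs in bf16.
apply_liger_kernel_to_llama(
    fused_linear_cross_entropy=True,
    accum_dtype=torch.float32,
)
```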
For future contributors: when patching RMSNorm, there are 4 init args that are easily overlooked. Take GptOss for example, it requires …

Create …
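The four arguments were truncated in the comment above, so the following is only an illustrative sketch, assuming `LigerRMSNorm` exposes kwargs such as `offset`, `casting_mode`, and `in_place`; the values shown are placeholders, not the actual GptOss settings:

```python
from liger_kernel.transformers.rms_norm import LigerRMSNorm

norm = LigerRMSNorm(
    hidden_size=2880,      # placeholder; must match config.hidden_size of the patched model
    eps=1e-5,              # placeholder; must match the model's rms_norm eps, not Liger's default
    offset=0.0,            # some models (e.g. Gemma) fold a +1 offset into the weight
    casting_mode="llama",  # controls where the bf16 <-> fp32 casts happen
    in_place=False,        # whether the backward mutates buffers in place
)
```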
@Comet0322 Thanks for updating. Do you think we can check in what you have and figure out swiglu after? After @Tcc0403's accum_dtype changes to the bf16 tests, do they pass now?
@Comet0322 and I have discussed this elsewhere. It seems the router's top-k expert choice is quite sensitive in bf16, leading to the discrepancy in final losses. I suggest just skipping the bf16 convergence test for now. cc @shimizust
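To make the failure mode concrete, here is a small, self-contained illustration (not the actual test): when two router logits are nearly equal, bf16 rounding can collapse them to the same value, so the top-k expert selection may differ from the fp32 choice and the losses drift apart.

```python
import torch

# One token's router logits for four experts; the top two are very close.
logits_fp32 = torch.tensor([1.2345, 1.2351, -0.7, -2.1])

# bf16 keeps only ~8 significand bits, so both leading logits round to the same value
# and the tie is no longer broken by the true fp32 ordering.
logits_bf16 = logits_fp32.to(torch.bfloat16)

print(torch.topk(logits_fp32, k=1).indices)  # tensor([1])
print(torch.topk(logits_bf16, k=1).indices)  # may return tensor([0]) instead
```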

## Summary

Add GPT-OSS model support, addressing #848.

Completed patching for RoPE, RMSNorm, cross_entropy, and fused_linear_cross_entropy.

## Known Issues

- swiglu patching is not yet implemented (see the discussion above).
- bf16 convergence tests show loss discrepancies at certain steps (see #742).

## Testing Done

FP32 Log

BF16 Log

Env: torch 2.8.0, triton 3.4.0, transformers 4.55.0

- [ ] `make test` to ensure correctness
- [ ] `make checkstyle` to ensure code style
- [ ] `make test-convergence` to ensure convergence