[float8] add auto_filter_for_recipe to float8 #2410
Part of pytorch/torchtitan#1207
Problem
The default filter_fqns for float8 model conversion work fine for the fp8 tensorwise recipe, but are a poor fit for the float8 rowwise recipe.
Solution
This has been a footgun for various users (including Poolside), so I created an "auto filter" (#2410) which automatically filters Linears for a given float8 recipe, by checking the following criteria:
- filter_fqns
I integrated a PoC into torchtitan, and the auto filter improved fp8 rowwise perf in both a local Llama3 8b run and a Llama3 70b MAST run, compared to the default filter_fn we have now.
It prevents users from hitting this common footgun, while preserving the flexibility to define their own model-specific fqns.
Results
See pytorch/torchtitan#1207 for Llama3 70b results; TL;DR: filtering wk and wv improves TPS by ~10% for vanilla TP and ~15% for async TP.