[torch.compile] PyTorch 2.6 and nightly compatibility #12393

Open · wants to merge 24 commits into base: main

Conversation

youkaichao (Member) commented Jan 24, 2025

Manually tested for 2.7.0.dev20250121+cu126


👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small, essential subset of tests to catch errors quickly. You can run additional CI tests on top of that by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add the ready label to the PR
  • Enable auto-merge.

🚀

youkaichao changed the title from "[torch.compile] add compiler backend abstraction" to "[torch.compile] PyTorch 2.6 and nightly compatibility" on Jan 24, 2025
youkaichao marked this pull request as ready for review on January 24, 2025, 13:46
houseroad (Contributor) left a comment

Do we need to bump up the torch version in the dep config?

youkaichao (Member, Author)

> Do we need to bump up the torch version in the dep config?

You can test it by following "Use an existing PyTorch installation" in https://docs.vllm.ai/en/latest/getting_started/installation/gpu/index.html .
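
For reference, that workflow looks roughly like the following (paraphrased from the linked docs; the helper script and requirements file names are as they were in early 2025 and may differ in newer vLLM versions, so treat this as a sketch rather than canonical steps):

    # assumes a PyTorch 2.6 / nightly build is already installed in this environment
    git clone https://github.com/vllm-project/vllm.git
    cd vllm
    python use_existing_torch.py           # drop the pinned torch requirements
    pip install -r requirements-build.txt
    pip install -e . --no-build-isolation  # build vLLM against the already-installed torch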

houseroad (Contributor) left a comment

Internally, we still hit similar issues before this PR. I will dig a bit and see what's going on.

@@ -53,6 +64,9 @@ def add(self, pass_: InductorPass):
        assert isinstance(pass_, InductorPass)
        self.passes.append(pass_)

    def uuid(self):
        return self.__getstate__()

Contributor

Actually, it failed on the unpickling of this class...

Member Author

Is it still relevant? Do we need any change for this function?
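
For context on the unpickling failure mentioned above: uuid() simply returns self.__getstate__(), so whatever __getstate__ produces is what gets hashed and, per the comment above, pickled and unpickled somewhere in the caching path. A minimal, self-contained sketch of that constraint (illustrative only; TinyPassManager and its attributes are made up and are not vLLM's actual pass manager):

    import pickle

    class TinyPassManager:
        """Toy stand-in for a custom pass container (hypothetical)."""

        def __init__(self):
            self.passes = []

        def add(self, pass_):
            self.passes.append(pass_)

        def __getstate__(self):
            # Whatever is returned here must pickle cleanly: per the discussion
            # above, the object is serialized (and later unpickled) as part of
            # compilation caching, and uuid() reuses the same state.
            return {"pass_names": [type(p).__name__ for p in self.passes]}

        def __setstate__(self, state):
            # Unpickling goes through here; if __getstate__ had returned a live
            # callable or another non-serializable object, this load would fail.
            self.passes = state["pass_names"]

        def uuid(self):
            return self.__getstate__()

    pm = TinyPassManager()
    pm.add(object())  # stand-in for a real pass object
    restored = pickle.loads(pickle.dumps(pm))
    print(restored.uuid())  # {'pass_names': ['object']}

The failure being discussed is the opposite case: if something reachable from __getstate__ does not serialize, the cached artifact cannot be round-tripped and the load breaks.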

houseroad (Contributor)

Btw, @chenyang78 is one of our Inductor experts. If @zou3519 or @zhxchen17 can give it a pass, that would be great!

fialhocoelho (Contributor)

An interesting point is that, to use PT 2.6, we previously had a compatibility issue with xformers, which required PT 2.4. Over the weekend, they updated to v0.0.29.post2, which now supports version 2.6 🚀

youkaichao requested a review from mgoin on February 4, 2025, 16:45
tlrmchlsmth (Collaborator) left a comment

I'm reading through the PR now, but a bit of context would make it easier to follow. Could you say a bit about why the backend and cache changes were needed to support PyTorch >= 2.6?

youkaichao (Member, Author)

> I'm reading through the PR now, but a bit of context would make it easier to follow. Could you say a bit about why the backend and cache changes were needed to support PyTorch >= 2.6?

Because PyTorch's internal functions changed between 2.5 and 2.6.

Comment on lines +196 to +220
    if torch.__version__.startswith("2.5"):
        original_load = FxGraphCache.load
        original_load_name = "torch._inductor.codecache.FxGraphCache.load"

        def hijack_load(*args, **kwargs):
            inductor_compiled_graph = original_load(*args, **kwargs)
            nonlocal file_path
            file_path = inductor_compiled_graph.current_callable.__code__.co_filename  # noqa
            return inductor_compiled_graph

        hijacked_compile_fx_inner = torch._inductor.compile_fx.compile_fx_inner  # noqa
    elif torch.__version__ >= "2.6":
        # function renamed in 2.6
        original_load_name = None

        def hijacked_compile_fx_inner(*args, **kwargs):
            output = torch._inductor.compile_fx.compile_fx_inner(
                *args, **kwargs)
            nonlocal hash_str
            inductor_compiled_graph = output
            if inductor_compiled_graph is not None:
                nonlocal file_path
                file_path = inductor_compiled_graph.current_callable.__code__.co_filename  # noqa
                hash_str = inductor_compiled_graph._fx_graph_cache_key
            return output
Member Author

e.g. here, @tlrmchlsmth

Labels: None yet
Projects: None yet
5 participants