feat: add support for custom compile options in torch_xla.compile and… #9575
base: master
Conversation
Force-pushed b3e8b3e to 2a4b284
@qihqi Thank you for your review! I realized there was a lint issue and have fixed it. Would you mind re-running the tests?
Hi @sshonTT, would you rebase to the latest HEAD? The 2 CI issues should be fixed.
Head branch was pushed to by a user without write access
Force-pushed 2a4b284 to 2a2eed5
Hi @qihqi, gentle reminder.
Hi @zhanyong-wan — not sure if this is the right tag. I’m waiting on @qihqi’s follow-up (he previously approved). I rebased per his suggestion and want to confirm CI is green. Could you trigger CI for this PR and, if you have time, give it a review? Thanks!
Thanks!
Force-pushed 2bdd9ed to 7be5b39
Thanks!
torch_xla/torch_xla.py
Outdated
@@ -136,6 +137,9 @@ def compile(
    max_different_graphs (Optional[int]): number of different traced graphs of the given
      model/function that we are allowed to have. An error will be raised in case this limit
      is exceeded.
    custom_compile_options (Optional[dict[str, Any]]): XLA **compiler flag overrides**
Document how None and an empty dict are different?
They have the same behaviour (a no-op).
Add doc about this.
torch_xla/torch_xla.py
Outdated
@@ -215,6 +219,9 @@ def _compile():
    torch_xla._XLAC._set_use_eager_mode(saved_eager_mode_status)
    torch_xla._XLAC._set_current_graph_name(saved_current_graph_name)

  custom_compile_options = custom_compile_options or {}
  if len(custom_compile_options) > 0:
This means there's no way for the user to clear the custom compile options?
I’m exposing torch_xla.set_custom_compile_options for that purpose.
Users can clear the options by calling this API with an empty dict {}.
The behavior of compile() should be consistent with set_custom_compile_options(). If given an empty dict, both should do the same thing. Otherwise the API is very confusing and error-prone.
I suggest:
- keep set_custom_compile_options() unchanged,
- treat None as no-op here,
- treat {} as "clear options" here.
Force-pushed 7be5b39 to 4278f97
In the future, could you avoid git rebase or git merge during a code review? Otherwise it's hard to track the changes. Thanks!
torch_xla/torch_xla.py
Outdated
    they will be stringified before being passed to XLA.
    Pass an empty dict `{}` to clear previously set options.
    """
  torch_xla._XLAC._set_custom_compile_options(options or {})
`or {}` is unnecessary here. `options` cannot be None according to the type annotation.
Deleted it.
Force-pushed 4278f97 to f3ad8bd
@@ -554,13 +554,18 @@ std::vector<ComputationClient::ComputationPtr> PjRtComputationClient::Compile(

  for (auto& instance : instances) {
    xla::CompileOptions compile_options;
    for (auto& option : custom_compile_options_) {
Make the reference const and use a structured binding: `const auto& [name, value] : ...`
… PJRT backend

This change introduces the ability to pass custom compile options from Python down to the PJRT backend, allowing users to fine-tune XLA compilation behavior without modifying core code.

Key changes:

* Python API
  * Added a custom_compile_options parameter to torch_xla.compile for passing compile-time options as a dict (supports bool, float, int, and str values).
  * Added a torch_xla.set_custom_compile_options() utility for setting compile options globally.
  * Added the internal binding _XLAC._set_custom_compile_options().
* C++ Runtime
  * Added a SetCustomCompileOptions() virtual method to ComputationClient and implemented it in PjRtComputationClient.
  * PjRtComputationClient now stores custom_compile_options_ and injects them into xla::CompileOptions.env_option_overrides during compilation.
  * Options are stringified before being passed to XLA for compatibility.

Motivation: This enables advanced users to pass through backend-specific tuning flags (e.g., enabling experimental optimizations, toggling partitioning strategies) without hardcoding them, improving flexibility for research and debugging workflows.
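The stringification step described above can be sketched in Python (illustrative only; the lowercase spelling for bools and the exact conversions are assumptions, not the actual C++ implementation):

```python
from typing import Union

OptionValue = Union[bool, float, int, str]


def stringify_options(options: dict[str, OptionValue]) -> dict[str, str]:
    """Convert each option value to a string before handing it to XLA.

    Mirrors the described behavior that bool/float/int/str option values
    are stringified for compatibility; the chosen string forms here are
    an assumption for illustration.
    """
    result: dict[str, str] = {}
    for name, value in options.items():
        if isinstance(value, bool):
            # bool must be checked before int: bool is a subclass of int.
            result[name] = "true" if value else "false"
        else:
            result[name] = str(value)
    return result
```

For example, `{"opt_a": True, "opt_b": 2}` would become `{"opt_a": "true", "opt_b": "2"}` before being injected into the compile options.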
Head branch was pushed to by a user without write access
Force-pushed f3ad8bd to a27f105
… PJRT backend
Issue description: #9555
This change introduces the ability to pass custom compile options from Python down to the PJRT backend, allowing users to fine-tune XLA compilation behavior without modifying core code. Key changes: