Skip to content

Conversation

sshonTT
Copy link

@sshonTT sshonTT commented Aug 21, 2025

… PJRT backend

Issue description : #9555

This change introduces the ability to pass custom compile options from Python down to the PJRT backend, allowing users to fine-tune XLA compilation behavior without modifying core code. Key changes:

  • Python API
    • Added custom_compile_options parameter to torch_xla.compile for passing compile-time options as a dict (supports bool, float, int, and str values).
    • Added torch_xla.set_custom_compile_options() utility for setting compile options globally.
    • Added internal binding _XLAC._set_custom_compile_options().
  • C++ Runtime
    • Added SetCustomCompileOptions() virtual method to ComputationClient and implemented it in PjRtComputationClient.
    • PjRtComputationClient now stores custom_compile_options_ and injects them into xla::CompileOptions.env_option_overrides during compilation.
    • Options are stringified before being passed to XLA for compatibility. Motivation:
This enables advanced users to pass through backend-specific tuning flags (e.g., enabling experimental optimizations, toggling partitioning strategies) without hardcoding them, improving flexibility for research and debugging workflows.

@sshonTT sshonTT force-pushed the sshon/custom-compiler-options-upstream branch from b3e8b3e to 2a4b284 Compare August 21, 2025 20:25
@sshonTT
Copy link
Author

sshonTT commented Aug 21, 2025

@qihqi Thank you for your review! I realized there was a lint issue and have fixed it. Would you mind re-running the tests?

@qihqi qihqi enabled auto-merge (squash) August 23, 2025 04:04
@qihqi
Copy link
Collaborator

qihqi commented Aug 25, 2025

hi @sshonTT would you rebase to latest HEAD? the 2 CI issue should be fixed

auto-merge was automatically disabled August 26, 2025 01:02

Head branch was pushed to by a user without write access

@sshonTT sshonTT force-pushed the sshon/custom-compiler-options-upstream branch from 2a4b284 to 2a2eed5 Compare August 26, 2025 01:02
@sshonTT
Copy link
Author

sshonTT commented Aug 26, 2025

hi @sshonTT would you rebase to latest HEAD? the 2 CI issue should be fixed

@qihqi thank you for the info. I had rebased so could you re-run test?

@sshonTT
Copy link
Author

sshonTT commented Aug 29, 2025

hi @qihqi , gentle reminder.

@sshonTT
Copy link
Author

sshonTT commented Sep 2, 2025

Hi @zhanyong-wan — not sure if this is the right tag. I’m waiting on @qihqi’s follow-up (he previously approved). I rebased per his suggestion and want to confirm CI is green. Could you trigger CI for this PR and, if you have time, give it a review? Thanks!

Copy link
Collaborator

@zhanyong-wan zhanyong-wan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

@sshonTT sshonTT force-pushed the sshon/custom-compiler-options-upstream branch 3 times, most recently from 2bdd9ed to 7be5b39 Compare September 2, 2025 19:21
Copy link
Collaborator

@zhanyong-wan zhanyong-wan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

@@ -136,6 +137,9 @@ def compile(
max_different_graphs (Optional[int]): number of different traced graphs of the given
model/function that we are allowed to have. An error will be raised in case this limit
is exceeded.
custom_compile_options (Optional[dict[str, Any]]): XLA **compiler flag overrides**
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Document how None and an empty dict are different?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They are doing same behaviour (no-op).

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add doc about this.

@@ -215,6 +219,9 @@ def _compile():
torch_xla._XLAC._set_use_eager_mode(saved_eager_mode_status)
torch_xla._XLAC._set_current_graph_name(saved_current_graph_name)

custom_compile_options = custom_compile_options or {}
if len(custom_compile_options) > 0:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This means there's no way for the user to clear the custom compile options?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I’m exposing torch_xla.set_custom_compile_options for that purpose.
Users can clear the options by calling this API with an empty dict {}.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The behavior of compile() should be consistent with set_custom_compile_options(). If given an empty dict, both should do the same thing. Otherwise the API is very confusing and error-prone.

I suggest:

  • keep set_custom_compile_options() unchanged,
  • treat None as no-op here,
  • treat {} as "clear options" here.

@sshonTT sshonTT force-pushed the sshon/custom-compiler-options-upstream branch from 7be5b39 to 4278f97 Compare September 4, 2025 18:22
Copy link
Collaborator

@zhanyong-wan zhanyong-wan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the future, could you avoid git rebase or git merge during a code review? Otherwise it's hard to track the changes. Thanks!

@@ -215,6 +219,9 @@ def _compile():
torch_xla._XLAC._set_use_eager_mode(saved_eager_mode_status)
torch_xla._XLAC._set_current_graph_name(saved_current_graph_name)

custom_compile_options = custom_compile_options or {}
if len(custom_compile_options) > 0:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The behavior of compile() should be consistent with set_custom_compile_options(). If given an empty dict, both should do the same thing. Otherwise the API is very confusing and error-prone.

I suggest:

  • keep set_custom_compile_options() unchanged,
  • treat None as no-op here,
  • treat {} as "clear options" here.

they will be stringified before being passed to XLA.
Pass an empty dict `{}` to clear previously set options.
"""
torch_xla._XLAC._set_custom_compile_options(options or {})
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

or {} is unnecessary here. options cannot be None according to the type annotation.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

deleted it.

@sshonTT sshonTT force-pushed the sshon/custom-compiler-options-upstream branch from 4278f97 to f3ad8bd Compare September 5, 2025 15:59
@@ -554,13 +554,18 @@ std::vector<ComputationClient::ComputationPtr> PjRtComputationClient::Compile(

for (auto& instance : instances) {
xla::CompileOptions compile_options;
for (auto& option : custom_compile_options_) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Make the reference const and use structured binding:

const auto& [name, value] : ...

@qihqi qihqi enabled auto-merge (squash) September 9, 2025 01:12
… PJRT backend

This change introduces the ability to pass custom compile options from Python down to the PJRT backend, allowing users to fine-tune XLA compilation behavior without modifying core code.
Key changes:
* Python API
    * Added custom_compile_options parameter to torch_xla.compile for passing compile-time options as a dict (supports bool, float, int, and str values).
    * Added torch_xla.set_custom_compile_options() utility for setting compile options globally.
    * Added internal binding _XLAC._set_custom_compile_options().
* C++ Runtime
    * Added SetCustomCompileOptions() virtual method to ComputationClient and implemented it in PjRtComputationClient.
    * PjRtComputationClient now stores custom_compile_options_ and injects them into xla::CompileOptions.env_option_overrides during compilation.
    * Options are stringified before being passed to XLA for compatibility.
Motivation:
This enables advanced users to pass through backend-specific tuning flags (e.g., enabling experimental optimizations, toggling partitioning strategies) without hardcoding them, improving flexibility for research and debugging workflows.
auto-merge was automatically disabled September 11, 2025 20:10

Head branch was pushed to by a user without write access

@sshonTT sshonTT force-pushed the sshon/custom-compiler-options-upstream branch from f3ad8bd to a27f105 Compare September 11, 2025 20:10
@sshonTT
Copy link
Author

sshonTT commented Sep 11, 2025

hi @qihqi, I see #9634 fixes build failure in ci, so I rebased it.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants