Skip to content

Commit f3ad8bd

Browse files
committed
feat: add support for custom compile options in torch_xla.compile and PJRT backend
This change introduces the ability to pass custom compile options from Python down to the PJRT backend, allowing users to fine-tune XLA compilation behavior without modifying core code. Key changes: * Python API * Added custom_compile_options parameter to torch_xla.compile for passing compile-time options as a dict (supports bool, float, int, and str values). * Added torch_xla.set_custom_compile_options() utility for setting compile options globally. * Added internal binding _XLAC._set_custom_compile_options(). * C++ Runtime * Added SetCustomCompileOptions() virtual method to ComputationClient and implemented it in PjRtComputationClient. * PjRtComputationClient now stores custom_compile_options_ and injects them into xla::CompileOptions.env_option_overrides during compilation. * Options are stringified before being passed to XLA for compatibility. Motivation:
This enables advanced users to pass through backend-specific tuning flags (e.g., enabling experimental optimizations, toggling partitioning strategies) without hardcoding them, improving flexibility for research and debugging workflows.
1 parent 89f929b commit f3ad8bd

File tree

6 files changed

+61
-0
lines changed

6 files changed

+61
-0
lines changed

torch_xla/csrc/init_python_bindings.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3247,6 +3247,17 @@ void InitXlaModuleBindings(py::module m) {
32473247
XLA_ERROR() << "Could not get the buffer pointer for XLATensor "
32483248
"without a data handle or an IR.";
32493249
})
3250+
.def("_set_custom_compile_options",
3251+
[](const py::dict& compile_options) {
3252+
std::unordered_map<std::string, std::string> options;
3253+
for (const auto& item : compile_options) {
3254+
// Keys must be strings; values are stringified.
3255+
const std::string key = py::str(item.first);
3256+
options[key] = py::str(item.second);
3257+
}
3258+
runtime::GetComputationClientOrDie()->SetCustomCompileOptions(
3259+
options);
3260+
})
32503261
.def(
32513262
// from an XLA tensor to a PyCapsule.
32523263
// When consuming the PyCapsule, we should synchronize

torch_xla/csrc/runtime/computation_client.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -446,6 +446,14 @@ class ComputationClient {
446446
// after the last ':' character of the device string.
447447
static int64_t GetDeviceOrdinal(const std::string& device);
448448

449+
// Sets XLA compile option overrides used by the backend compiler.
450+
// - The map keys are XLA compiler flag names (env option override keys).
451+
// - The values are stringified flag values.
452+
// - Calling this method **overwrites** any previously set options.
453+
// (Pass an empty map to clear.)
454+
virtual void SetCustomCompileOptions(
455+
const std::unordered_map<std::string, std::string>& options) = 0;
456+
449457
protected:
450458
static constexpr auto spmd_device_str = "SPMD:0";
451459

torch_xla/csrc/runtime/ifrt_computation_client.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,11 @@ class IfrtComputationClient : public ComputationClient {
172172
XLA_ERROR() << __FUNCTION__ << " not implemented";
173173
}
174174

175+
void SetCustomCompileOptions(
176+
const std::unordered_map<std::string, std::string>& options) override {
177+
XLA_ERROR() << __FUNCTION__ << " not implemented";
178+
}
179+
175180
// Creates a new instance of IfrtComputationClient and initializes it.
176181
static absl::StatusOr<absl_nonnull std::unique_ptr<IfrtComputationClient>>
177182
Create();

torch_xla/csrc/runtime/pjrt_computation_client.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -554,13 +554,18 @@ std::vector<ComputationClient::ComputationPtr> PjRtComputationClient::Compile(
554554

555555
for (auto& instance : instances) {
556556
xla::CompileOptions compile_options;
557+
for (auto& option : custom_compile_options_) {
558+
compile_options.env_option_overrides.push_back(
559+
{option.first, option.second});
560+
}
557561
if (enable_cm_in_mp) {
558562
compile_options.executable_build_options.set_use_spmd_partitioning(true);
559563
compile_options.env_option_overrides.push_back(
560564
{"xla_tpu_decompose_all_gather_einsum", true});
561565
compile_options.env_option_overrides.push_back(
562566
{"xla_tpu_decompose_einsum_reduce_scatter", true});
563567
}
568+
564569
if (instance.is_sharded) {
565570
// TODO(yeounoh) multi-host, multi-slice configurations
566571
compile_options.executable_build_options.set_use_spmd_partitioning(true);
@@ -1052,5 +1057,13 @@ void PjRtComputationClient::OnReadyCallback(
10521057
[callback](absl::Status unused) { callback(); });
10531058
}
10541059

1060+
void PjRtComputationClient::SetCustomCompileOptions(
1061+
const std::unordered_map<std::string, std::string>& options) {
1062+
custom_compile_options_.clear();
1063+
for (const auto& [key, value] : options) {
1064+
custom_compile_options_[key] = value;
1065+
}
1066+
}
1067+
10551068
} // namespace runtime
10561069
} // namespace torch_xla

torch_xla/csrc/runtime/pjrt_computation_client.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,10 @@ class PjRtComputationClient : public ComputationClient {
165165
void OnReadyCallback(DataPtr data,
166166
const std::function<void()>& callback) override;
167167

168+
// See base class for semantics. This call overwrites previously set options.
169+
void SetCustomCompileOptions(
170+
const std::unordered_map<std::string, std::string>& options) override;
171+
168172
// Creates a new instance of PjRtComputationClient and initializes it.
169173
static absl::StatusOr<absl_nonnull std::unique_ptr<PjRtComputationClient>>
170174
Create();
@@ -197,6 +201,7 @@ class PjRtComputationClient : public ComputationClient {
197201
// If not nullptr, invoke this instead of the actual XLA compilation. Used
198202
// only for testing.
199203
std::function<absl::Status()> fake_xla_compile_ = nullptr;
204+
std::unordered_map<std::string, std::string> custom_compile_options_;
200205

201206
xla::PjRtDevice* StringToPjRtDevice(const std::string& device);
202207

torch_xla/torch_xla.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ def compile(
116116
full_graph: Optional[bool] = False,
117117
name: Optional[str] = None,
118118
max_different_graphs: Optional[int] = None,
119+
custom_compile_options: Optional[dict[str, Any]] = None,
119120
):
120121
"""
121122
Optimizes given model/function using torch_xla's LazyTensor tracing mode.
@@ -136,6 +137,11 @@ def compile(
136137
max_different_graphs (Optional[int]): number of different traced graphs of the given
137138
model/function that we are allowed to have. An error will be raised in case this limit
138139
is exceeded.
140+
custom_compile_options (Optional[dict[str, Any]]): XLA compiler flag overrides.
141+
Keys are XLA compiler flag names (forwarded to xla::CompileOptions.env_option_overrides),
142+
and values may be bool, int, float, or str (internally stringified).
143+
- {} (empty dict): clear previously set options.
144+
- None (default): do not change previously set options (no-op).
139145
140146
Example::
141147
@@ -215,6 +221,8 @@ def _compile():
215221
torch_xla._XLAC._set_use_eager_mode(saved_eager_mode_status)
216222
torch_xla._XLAC._set_current_graph_name(saved_current_graph_name)
217223

224+
if custom_compile_options is not None:
225+
torch_xla._XLAC._set_custom_compile_options(custom_compile_options)
218226
return _compile() if f is None else _compile()(f)
219227

220228

@@ -264,3 +272,14 @@ def launch(
264272
fn(xu.getenv_as(xenv.LOCAL_RANK, int), *args)
265273
else:
266274
xmp.spawn(fn, args=args, nprocs=nprocs, start_method=start_method)
275+
276+
277+
def set_custom_compile_options(options: dict[str, Any]) -> None:
278+
"""Set XLA **compiler flag overrides** (env option overrides) for compilation.
279+
280+
Args:
281+
options: Dict mapping XLA flag names to values. Values may be bool/float/int/str;
282+
they will be stringified before being passed to XLA.
283+
Pass an empty dict `{}` to clear previously set options.
284+
"""
285+
torch_xla._XLAC._set_custom_compile_options(options)

0 commit comments

Comments
 (0)