Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions shardy/dialect/mpmd/ir/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,11 @@ inline constexpr StringRef kIsSdyPartitioned = "mpmd.is_sdy_partitioned";
inline constexpr StringRef kIsGspmdPartitioned = "mpmd.is_gspmd_partitioned";

// The suffix of the mesh name for a CPU mesh.
// LINT.IfChange
constexpr StringRef kCpuMeshSuffix = "/cpu";
// LINT.ThenChange(
// https://github.com/openxla/shardy/blob/main/shardy/integrations/python/jax/mpmd/types.py
// )

// Memory kind attributes.
// Attr on func args and results to indicate whether the value lives on host or
Expand Down
46 changes: 27 additions & 19 deletions shardy/integrations/python/jax/mpmd/jaxlib/mpmd_program.cc
Original file line number Diff line number Diff line change
Expand Up @@ -171,29 +171,35 @@ PartitioningResult MpmdProgram::ApplyPartitioning(PartitioningPhase phases) {

func::FuncOp main_func = GetMainFunction(module);
SetTopology(named_meshes, main_func);
SetArgDonationAttributes(main_func, donate_argnums);

// It is not necessary to do this
// validation after the export pipeline because here we're only checking that
// the attributes set on the main func are consistent with the received donate
// args.
VerifyOnlyDonatedArgsHaveDonationAttributes(main_func, donate_argnums);
if (phases & PartitioningPhase::kImport) {
SetArgDonationAttributes(main_func, donate_argnums);

SDY_LOG(INFO) << "Importing function named " << func_name
<< " for MPMD partitioning.";
// It is not necessary to do this validation after the export pipeline
// because here we're only checking that the attributes set on the main func
// are consistent with the received donate args.
VerifyOnlyDonatedArgsHaveDonationAttributes(main_func, donate_argnums);

Import(module);
SDY_LOG(INFO) << "Importing function named " << func_name
<< " for MPMD partitioning.";

SDY_LOG(INFO) << "Optimizing function named " << func_name
<< " for pipeline parallelism.";
Optimize(module);
Import(module);
}

if (phases & PartitioningPhase::kOptimize) {
SDY_LOG(INFO) << "Optimizing function named " << func_name
<< " for pipeline parallelism.";
Optimize(module);
}

SDY_LOG(INFO) << "Applying SDY propagation to function named " << func_name
<< ".";
PropagateSharding(module);
if (phases & PartitioningPhase::kPartition) {
SDY_LOG(INFO) << "Applying SDY propagation to function named " << func_name
<< ".";
PropagateSharding(module);

SDY_LOG(INFO) << "Exporting MPMD function named " << func_name << ".";
Export(module);
SDY_LOG(INFO) << "Exporting MPMD function named " << func_name << ".";
Export(module);
}

return PartitioningResult(module);
}
Expand All @@ -208,7 +214,8 @@ void MpmdProgram::Import(ModuleOp module) {
ConvertMeshVectorToMap(input_meshes)};
import_options.outputIndexToMeshAssignment = {
ConvertMeshVectorToMap(output_meshes)};
import_options.mergeAfterScheduling = options.mpmd_merge_after_scheduling;
import_options.mergeAfterScheduling =
options.mpmd_merge_inferred_after_scheduling;
import_options.absorbInferredFragmentsOnEntryPointFunction =
options.mpmd_absorb_inferred_fragments_on_entry_point_function;
import_options.cloneInferredFragments =
Expand All @@ -229,7 +236,8 @@ void MpmdProgram::Optimize(ModuleOp module) {

OptimizeOptions optimize_options;
optimize_options.fragmentMergeRules = llvm::to_vector(fragment_merge_rules);
optimize_options.mergeAfterScheduling = options.mpmd_merge_after_scheduling;
optimize_options.mergeAfterScheduling =
options.mpmd_merge_inferred_after_scheduling;
optimize_options.applyFragmentRemat = options.mpmd_fragment_remat;
optimize_options.mergeRematFragments = options.mpmd_merge_remat_fragments;
optimize_options.absorbInferredFragmentsOnEntryPointFunction =
Expand Down
3 changes: 2 additions & 1 deletion shardy/integrations/python/jax/mpmd/jaxlib/mpmd_program.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ struct PartitioningOptions {
bool mpmd_absorb_inferred_fragments_on_entry_point_function = false;
bool mpmd_copy_constant_creation_from_producer_to_consumer = false;
bool mpmd_apply_merge_transfers_pass = false;
bool mpmd_merge_after_scheduling = false;
bool mpmd_merge_inferred_after_scheduling = false;
};

PartitioningOptions ParsePartitioningOptions(
Expand All @@ -122,6 +122,7 @@ struct MpmdProgram {
const std::vector<std::optional<std::string>>& output_meshes;
const std::vector<int64_t>& donate_argnums;
const mlir::mpmd::FragmentMergeRules& fragment_merge_rules;
const mlir::mpmd::FragmentScheduleRules& fragment_schedule_rules;

// Runs the PartIR MPMD partitioning passes on the MPMD program.
//
Expand Down
125 changes: 106 additions & 19 deletions shardy/integrations/python/jax/mpmd/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import typing_extensions

from shardy.integrations.python.jax.mpmd import ops
from shardy.integrations.python.jax.mpmd import pipeline
from shardy.integrations.python.jax.mpmd import stages
from shardy.integrations.python.jax.mpmd import types
from shardy.integrations.python.jax.mpmd import utils
Expand All @@ -38,7 +39,7 @@

@dataclasses.dataclass(frozen=True)
class _MpmdPartitioningArgs:
"""Arguments for mpmd_py.apply_mpmd_partitioning.
"""Arguments for jaxlib_mpmd.apply_mpmd_partitioning.

This is essentially a processed version of a MpmdConfig dataclass, but in a
format that is more convenient for the C++ function. Note that users should
Expand Down Expand Up @@ -72,8 +73,15 @@ class _MpmdPartitioningArgs:
tpu_topology_args_proto: See `types.MpmdConfig.tpu_info`. This is required
for TPUs when using GSPMD partitioning.
partitioning_options: See `types.MpmdConfig.partitioning_options`.
fragment_merge_rules: See `types.MpmdConfig.fragment_merge_rules`.
fragment_schedule_rules: See `types.MpmdConfig.fragment_schedule_rules`.
fragment_merge_rules: A sequence of fragment merge rules. Each merge rule
contains a sequence of fragment metadata objects that should be merged
into a single fragment, together with metadata for the resulting fragment.
These rules are generated from the `pipeline_schedule` in the
`MpmdConfig`.
fragment_schedule_rules: A sequence of fragment schedule rules. Each
schedule rule contains a sequence of fragment metadata objects in the
order that they should be scheduled. These rules are generated from the
`pipeline_schedule` in the `MpmdConfig`.
"""

func_name: str
Expand All @@ -83,8 +91,12 @@ class _MpmdPartitioningArgs:
output_meshes: Sequence[str | None]
donate_argnums: Sequence[int]
partitioning_options: types.PartitioningOptions | None
fragment_merge_rules: Sequence[jaxlib_mpmd.FragmentMergeRule]
fragment_schedule_rules: Sequence[jaxlib_mpmd.FragmentScheduleRule]
fragment_merge_rules: Sequence[jaxlib_mpmd.FragmentMergeRule] = (
dataclasses.field(default_factory=list)
)
fragment_schedule_rules: Sequence[jaxlib_mpmd.FragmentScheduleRule] = (
dataclasses.field(default_factory=list)
)


@dataclasses.dataclass(frozen=True)
Expand All @@ -99,6 +111,14 @@ class MpmdLoweredArgs:
flat_input_mesh_assignment: Sequence[str] | None = None


def _get_fragment_info(mlir_module: mlir.ir.Module) -> list[types.FragmentInfo]:
  """Lists the fragments of `mlir_module` as `types.FragmentInfo` objects.

  Args:
    mlir_module: The MLIR module to query via `jaxlib_mpmd.get_fragment_info`.

  Returns:
    One converted entry per item yielded by `jaxlib_mpmd.get_fragment_info`,
    in the same order.
  """
  pybind_infos = jaxlib_mpmd.get_fragment_info(mlir_module)
  # Each pybind object is converted to its Python-side counterpart.
  return list(map(types.convert_pybind_fragment_info_to_types, pybind_infos))


def _apply_partitioning(
mlir_module: mlir.ir.Module,
partitioning_args: _MpmdPartitioningArgs,
Expand All @@ -115,8 +135,7 @@ def _apply_partitioning(
donate_argnums=partitioning_args.donate_argnums,
partitioning_options=partitioning_args.partitioning_options,
fragment_merge_rules=partitioning_args.fragment_merge_rules,
# TODO: b/424385447 - Reenable fragment_schedule_rules once
# we update jaxlib.
fragment_schedule_rules=partitioning_args.fragment_schedule_rules,
phases=phases,
)

Expand Down Expand Up @@ -281,11 +300,6 @@ def _shaped_abstractify(x):
if arg_info.donated
]

assert not self._mpmd_config.fragment_merge_rules
fragment_merge_rules = []
assert not self._mpmd_config.fragment_schedule_rules
fragment_schedule_rules = []

partitioning_args = _MpmdPartitioningArgs(
func_name=func_name,
named_meshes=topology_shape,
Expand All @@ -294,8 +308,8 @@ def _shaped_abstractify(x):
output_meshes=flat_output_mesh_assignment,
donate_argnums=donate_argnums,
partitioning_options=self._mpmd_config.partitioning_options,
fragment_merge_rules=fragment_merge_rules,
fragment_schedule_rules=fragment_schedule_rules,
# `fragment_merge_rules` and `fragment_schedule_rules` are left at their
# empty defaults here. They are generated later, in
# `_import_and_generate_rules`, when a `pipeline_schedule` has been set on
# the `MpmdConfig`.
)
lowered_args = MpmdLoweredArgs(
stablehlo_mlir_module=stablehlo_mlir_module,
Expand All @@ -307,6 +321,72 @@ def _shaped_abstractify(x):
)
return mlir_module, partitioning_args, lowered_args

def _import_and_generate_rules(
    self,
    mlir_module: mlir.ir.Module,
    partitioning_args: _MpmdPartitioningArgs,
) -> tuple[mlir.ir.Module, _MpmdPartitioningArgs]:
  """Runs the import phase and builds fragment rules from the schedule.

  Args:
    mlir_module: The module to run the import phase on.
    partitioning_args: The partitioning arguments to start from. This object
      is not modified; updated copies are made with `dataclasses.replace`.

  Returns:
    A tuple `(imported_module, partitioning_args_with_rules)`:
      * `imported_module` is the `mpmd_module` of the import-phase result
        (the module itself, not the `PartitioningResult` wrapper).
      * `partitioning_args_with_rules` is a copy of `partitioning_args` whose
        `partitioning_options` are the output of
        `types.validate_and_merge_partitioning_options` and whose
        `fragment_schedule_rules` / `fragment_merge_rules` hold the pybind
        form of the rules built from the pipeline schedule.

  Raises:
    ValueError: If `self._mpmd_config.pipeline_schedule` is None.
  """
  if self._mpmd_config.pipeline_schedule is None:
    raise ValueError('Pipeline schedule is not defined')

  # Validate and merge partitioning options with options required by the
  # pipeline schedule.
  validated_options = types.validate_and_merge_partitioning_options(
      pipeline_required_options=self._mpmd_config.pipeline_schedule.required_mpmd_options,
      user_provided_options=partitioning_args.partitioning_options,
  )
  partitioning_args_with_pipeline_options = dataclasses.replace(
      partitioning_args,
      partitioning_options=validated_options,
  )

  # Only the IMPORT phase runs here; the rules below are built from the
  # fragments of the module this phase returns.
  imported_result = _apply_partitioning(
      mlir_module,
      partitioning_args_with_pipeline_options,
      jaxlib_mpmd.PartitioningPhase.IMPORT,
  )
  context = types.PipelineContext(
      num_meshes=len(types.get_schedulable_meshes(self._mpmd_config.topology))
  )
  schedule_rules, merge_rules = pipeline.build_rules_from_pipeline(
      _get_fragment_info(imported_result.mpmd_module),
      self._mpmd_config.pipeline_schedule,
      context,
  )

  types.validate_fragment_schedule_rules(schedule_rules)
  types.validate_fragment_merge_rules(merge_rules)
  # Populate the partitioning args with the generated rules.
  partitioning_args_with_rules = dataclasses.replace(
      partitioning_args_with_pipeline_options,
      fragment_schedule_rules=types.convert_fragment_schedule_rules_to_pybind(
          schedule_rules
      ),
      fragment_merge_rules=types.convert_fragment_merge_rules_to_pybind(
          merge_rules
      ),
  )

  return imported_result.mpmd_module, partitioning_args_with_rules

def _partition_with_pipeline_schedule(
    self,
    mlir_module: mlir.ir.Module,
    partitioning_args: _MpmdPartitioningArgs,
) -> jaxlib_mpmd.PartitioningResult:
  """Partitions `mlir_module` in two steps when a pipeline schedule is set.

  The first step is delegated to `_import_and_generate_rules`. The second
  step applies the OPTIMIZE and PARTITION phases to the module and arguments
  returned by the first.

  Args:
    mlir_module: The module to partition.
    partitioning_args: The partitioning arguments passed to the first step.

  Returns:
    The result of the second `_apply_partitioning` call.
  """
  module_after_import, args_with_rules = self._import_and_generate_rules(
      mlir_module, partitioning_args
  )
  # IMPORT has already run above, so only the remaining phases are requested.
  remaining_phases = (
      jaxlib_mpmd.PartitioningPhase.OPTIMIZE
      | jaxlib_mpmd.PartitioningPhase.PARTITION
  )
  return _apply_partitioning(
      module_after_import, args_with_rules, remaining_phases
  )

@typing_extensions.override
def lower(
self,
Expand All @@ -327,9 +407,14 @@ def lower(
self._prepare_partitioning_args(_private_parameters)
)

partitioning_result = _apply_partitioning(
mlir_module, partitioning_args, jaxlib_mpmd.PartitioningPhase.ALL
)
if self._mpmd_config.pipeline_schedule:
partitioning_result = self._partition_with_pipeline_schedule(
mlir_module, partitioning_args
)
else:
partitioning_result = _apply_partitioning(
mlir_module, partitioning_args, jaxlib_mpmd.PartitioningPhase.ALL
)
ifrt_ir_module = jaxlib_mpmd.clone_mlir_module(
partitioning_result.mpmd_module
)
Expand Down Expand Up @@ -514,9 +599,11 @@ def __init__(
"""Initializes an MpmdGspmdWrapped object."""

if override_func_name:

@functools.wraps(func)
def wrapper(*args, **kwargs):
return func(*args, **kwargs)

wrapper.__name__ = override_func_name
self.func = wrapper
else:
Expand Down Expand Up @@ -627,8 +714,8 @@ def jit(
out_shardings: See `jax.jit`.
donate_argnums: See `jax.jit`.
keep_unused: See `jax.jit`.
override_func_name: If provided, the function name will be overridden to
the provided value.
override_func_name: If provided, the function name will be overridden to the
provided value.

Returns:
An MpmdGspmdWrapped object.
Expand Down
Loading
Loading