220 changes: 220 additions & 0 deletions keras/src/distribution/tensor_parallel/autoconfig.py
@@ -0,0 +1,220 @@
from typing import Sequence

from keras.src.distribution.tensor_parallel.config import ConfigKeras
from keras.src.distribution.tensor_parallel.state_action_keras import SplitKeras


def analyze_dense_layer_directly(layer, module, prefix: str) -> str:
    """Analyzes a Dense layer to classify it for tensor parallelism sharding.

    This function inspects the layer's weight shapes to determine whether it
    is an "up-projection" (expanding the feature dimension), a
    "down-projection" (contracting the feature dimension), or a generic
    layer. This classification decides whether column-wise or row-wise
    parallelism is applied.

    Args:
        layer: The keras.layers.Dense instance to analyze.
        module: The parent Keras model containing the layer.
        prefix: The hierarchical name prefix for the layer.

    Returns:
        A string indicating the layer's classification: 'up_projection',
        'down_projection', or 'generic_dense'.
    """
    from keras.src import layers
if not isinstance(layer, layers.Dense):
return "generic_dense"

input_dim = None
output_dim = None

if hasattr(layer, "kernel"):
kernel_shape = layer.kernel.shape
if len(kernel_shape) == 2:
input_dim = kernel_shape[0]
output_dim = kernel_shape[1]
else:
if hasattr(layer, "units"):
output_dim = layer.units

if (
hasattr(layer, "input_shape")
and layer.input_shape
and len(layer.input_shape) > 1
):
input_dim = layer.input_shape[-1]

if not input_dim or not output_dim:
return "generic_dense"

expansion_threshold = 1.5
is_expansion = output_dim > input_dim * expansion_threshold
is_contraction = input_dim > output_dim * expansion_threshold

if is_expansion:
return "up_projection"
elif is_contraction:
return "down_projection"
else:
return "generic_dense"


def _traverse_and_shard_layer(
current_layer,
module,
world_size: int,
state_rules: dict,
output_rules: dict,
processed_layers: set,
prefix: str = "",
):
    """Traverses a layer and its sub-layers to apply sharding rules.

    This function navigates through the model's layer hierarchy. For each
    layer, it identifies the layer type and applies the appropriate sharding
    logic, populating the `state_rules` and `output_rules` dictionaries.

    Args:
        current_layer: The current keras.Layer object to be processed.
        module: The top-level Keras Model, used for context analysis.
        world_size: The total number of devices for sharding.
        state_rules: The dictionary of state sharding rules to populate.
        output_rules: The dictionary of output sharding rules to populate.
        processed_layers: A set of layer IDs that have already been processed
            to avoid redundant computation and infinite loops.
        prefix: The hierarchical name prefix from parent layers, used to
            construct the full unique name for the current layer.
    """
    from keras.src import layers
if id(current_layer) in processed_layers:
return
processed_layers.add(id(current_layer))

Comment on lines +103 to +105
Collaborator: Per my comment below about not needing a recursion, this is not needed.

name = current_layer.name
full_name = f"{prefix}.{name}" if prefix else name

Collaborator: Because you will never really recurse, the prefix won't work.

if isinstance(current_layer, layers.Dense):
mlp_type = analyze_dense_layer_directly(
current_layer, module, full_name
)

if mlp_type == "down_projection":
state_rules[f"^{full_name}.kernel$"] = SplitKeras(
world_size, 0, "row"
)
output_rules[f"^{full_name}$"] = {0: "allreduce"}

else:
state_rules[f"^{full_name}.kernel$"] = SplitKeras(
world_size, 1, "column"
)
if current_layer.use_bias:
state_rules[f"^{full_name}.bias$"] = SplitKeras(
world_size, 0, "column"
)
output_rules[f"^{full_name}$"] = {0: "no_comm"}
return

elif isinstance(current_layer, layers.EinsumDense):
is_row_parallel = False
if "->" in current_layer.equation:
equation_parts = current_layer.equation.split("->")
if len(equation_parts) == 2:
input_spec = equation_parts[0].split(",")[0].strip()
output_spec = equation_parts[1].strip()
if (
input_spec
and output_spec
and len(output_spec) < len(input_spec)
):
is_row_parallel = True

if is_row_parallel:
state_rules[f"^{full_name}.kernel$"] = SplitKeras(
world_size, 0, "row"
)
output_rules[f"^{full_name}$"] = {0: "allreduce"}
else:
state_rules[f"^{full_name}.kernel$"] = SplitKeras(
world_size, 1, "column"
)
if (
hasattr(current_layer, "bias")
and current_layer.bias is not None
):
state_rules[f"^{full_name}.bias$"] = SplitKeras(
world_size, 0, "column"
)
output_rules[f"^{full_name}$"] = {0: "no_comm"}
return

elif isinstance(current_layer, layers.Embedding):
weight_name = (
"embeddings" if hasattr(current_layer, "embeddings") else None
)
if weight_name:
state_rules[f"^{full_name}\.{weight_name}$"] = SplitKeras(
world_size, 1, "column"
)
output_rules[f"^{full_name}$"] = {0: "no_comm"}
return

    elif isinstance(
        current_layer,
        (
            layers.LayerNormalization,
            layers.BatchNormalization,
            layers.GroupNormalization,
        ),
    ):
        return

Comment on lines +177 to +178
Collaborator: What about other layer types?

Contributor Author: The function is set up to only worry about the biggest layers in the model (Dense, Embedding, etc.); these are the only ones large enough to cause memory problems and to need splitting (sharding). We skip the smaller layers for a few reasons:

Normalization layers (like LayerNormalization): their weights are small, so we leave them alone rather than add extra communication for little benefit.

Layers with no weights (like Dropout or Activation): they have nothing to split; they simply operate on the sharded data coming from the layer before them.

else:
if hasattr(current_layer, "layers"):
for sub_layer in current_layer.layers:
_traverse_and_shard_layer(
sub_layer,
module,
world_size,
state_rules,
output_rules,
processed_layers,
full_name,
)


def get_default_config_keras(module, device_ids: Sequence[str]) -> ConfigKeras:
"""Generates a smart, recursive sharding configuration for a Keras model.

This function traverses the layers of a given Keras model and applies a
set of heuristics to automatically determine how each layer's weights
and outputs should be sharded for tensor parallelism. It uses a helper
function to perform the recursive traversal.

Args:
module: The Keras Model to generate a sharding configuration for.
device_ids: A sequence of device identifiers, used to determine the
world size (number of devices) for sharding.

Returns:
A ConfigKeras object containing the generated 'state_rules' (for model
parameters) and 'output_rules' (for layer outputs).
"""
world_size = len(device_ids)
state_rules = {}
output_rules = {}
processed_layers = set()

for layer in module.layers:
_traverse_and_shard_layer(
current_layer=layer,
module=module,
world_size=world_size,
state_rules=state_rules,
output_rules=output_rules,
processed_layers=processed_layers,
prefix="",
)

return ConfigKeras(state_rules=state_rules, output_rules=output_rules)
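
For orientation, a minimal usage sketch of the generated configuration (not part of the PR; the model, layer names, and two-device setup below are hypothetical and mirror the test cases that follow):

from keras import Input
from keras import Model
from keras import layers

from keras.src.distribution.tensor_parallel.autoconfig import (
    get_default_config_keras,
)

# A tiny MLP: 64 -> 256 (up-projection) -> 64 (down-projection), with a
# LayerNormalization in between that the heuristics intentionally skip.
inputs = Input(shape=(64,))
x = layers.Dense(256, name="up", use_bias=True)(inputs)
x = layers.LayerNormalization(name="ln")(x)
outputs = layers.Dense(64, name="down", use_bias=False)(x)
model = Model(inputs, outputs)

config = get_default_config_keras(model, device_ids=["device:0", "device:1"])

# Expected shape of the result: the up-projection kernel and bias are split
# column-wise with no extra communication, the down-projection kernel is split
# row-wise and its output is all-reduced, and "ln" contributes no rules.
print(sorted(config.state_rules.keys()))
print(config.output_rules)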
151 changes: 151 additions & 0 deletions keras/src/distribution/tensor_parallel/autoconfig_test.py
@@ -0,0 +1,151 @@
import os

# Force the XLA host platform to expose two virtual devices so that the
# distribution tests below can run without physical accelerators. This has to
# be set before Keras and its backend are imported.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=2"

from keras import Input
from keras import Model
from keras import layers
from keras.src import testing
from keras.src.backend.distributed import backend_resolver
from keras.src.distribution.tensor_parallel.autoconfig import (
analyze_dense_layer_directly,
)
from keras.src.distribution.tensor_parallel.autoconfig import (
get_default_config_keras,
)
from keras.src.distribution.tensor_parallel.state_action_keras import SplitKeras


class TestAutoConfigKeras(testing.TestCase):
def setUp(self):
"""Set up the test case and common variables."""
super().setUp()
backend = backend_resolver.get_distributed_backend()
device_info = backend.get_device_info()
self.world_size = device_info["device_count"]
self.device_ids = [f"device:{i}" for i in range(self.world_size)]

self.assertGreater(
self.world_size, 1, "Distribution tests require more than 1 device."
)

def _assert_split_keras_equal(self, rule1, rule2):
"""
Helper to compare two SplitKeras objects by their attributes.
"""
self.assertIsInstance(rule1, SplitKeras)
self.assertIsInstance(rule2, SplitKeras)
self.assertDictEqual(vars(rule1), vars(rule2))

def _assert_rules_equal(self, actual_rules, expected_rules):
"""Helper to compare two dictionaries of sharding rules."""
self.assertSetEqual(
set(actual_rules.keys()), set(expected_rules.keys())
)
for key in expected_rules:
actual_val = actual_rules[key]
expected_val = expected_rules[key]
if isinstance(expected_val, SplitKeras):
self._assert_split_keras_equal(actual_val, expected_val)
else:
self.assertEqual(actual_val, expected_val)

def test_analyze_dense_layer(self):
"""Tests the direct analysis and classification of Dense layers."""
up_proj_layer = layers.Dense(32)
up_proj_layer.build(input_shape=(None, 16))
self.assertEqual(
analyze_dense_layer_directly(up_proj_layer, None, ""),
"up_projection",
)

down_proj_layer = layers.Dense(16)
down_proj_layer.build(input_shape=(None, 32))
self.assertEqual(
analyze_dense_layer_directly(down_proj_layer, None, ""),
"down_projection",
)

def test_simple_mlp_sharding(self):
"""Tests a simple MLP with up and down projection layers."""
inputs = Input(shape=(64,))
x = layers.Dense(256, name="up_projection_layer", use_bias=True)(inputs)
outputs = layers.Dense(
64, name="down_projection_layer", use_bias=False
)(x)
model = Model(inputs=inputs, outputs=outputs, name="simple_mlp")

config = get_default_config_keras(model, self.device_ids)

expected_state_rules = {
r"^up_projection_layer.kernel$": SplitKeras(
self.world_size, 1, "column"
),
r"^up_projection_layer.bias$": SplitKeras(
self.world_size, 0, "column"
),
r"^down_projection_layer.kernel$": SplitKeras(
self.world_size, 0, "row"
),
}
expected_output_rules = {
r"^up_projection_layer$": {0: "no_comm"},
r"^down_projection_layer$": {0: "allreduce"},
}

self._assert_rules_equal(config.state_rules, expected_state_rules)
self._assert_rules_equal(config.output_rules, expected_output_rules)

def test_embedding_sharding(self):
"""Tests an Embedding layer."""
inputs = Input(shape=(10,), dtype="int32")
outputs = layers.Embedding(
input_dim=1000, output_dim=128, name="token_embedding"
)(inputs)
model = Model(inputs=inputs, outputs=outputs, name="embed_model")

config = get_default_config_keras(model, self.device_ids)

expected_state_rules = {
r"^token_embedding\.embeddings$": SplitKeras(
self.world_size, 1, "column"
)
}
expected_output_rules = {r"^token_embedding$": {0: "no_comm"}}

self._assert_rules_equal(config.state_rules, expected_state_rules)
self._assert_rules_equal(config.output_rules, expected_output_rules)

def test_nested_model_sharding(self):
"""Tests that the traversal logic correctly handles nested models."""
inner_inputs = Input(shape=(32,))
inner_outputs = layers.Dense(128, name="inner_dense")(inner_inputs)
inner_model = Model(
inputs=inner_inputs, outputs=inner_outputs, name="inner_block"
)

outer_inputs = Input(shape=(32,))
x = inner_model(outer_inputs)
outer_outputs = layers.Dense(32, name="outer_dense")(x)
outer_model = Model(
inputs=outer_inputs, outputs=outer_outputs, name="outer_model"
)

config = get_default_config_keras(outer_model, self.device_ids)

expected_state_rules = {
r"^inner_block.inner_dense.kernel$": SplitKeras(
self.world_size, 1, "column"
),
r"^inner_block.inner_dense.bias$": SplitKeras(
self.world_size, 0, "column"
),
r"^outer_dense.kernel$": SplitKeras(self.world_size, 0, "row"),
}
expected_output_rules = {
r"^inner_block.inner_dense$": {0: "no_comm"},
r"^outer_dense$": {0: "allreduce"},
}

self._assert_rules_equal(config.state_rules, expected_state_rules)
self._assert_rules_equal(config.output_rules, expected_output_rules)
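
The EinsumDense branch of the traversal is not exercised by the tests above. As a rough sketch of the heuristic it applies (the helper below is hypothetical and only restates the equation check from autoconfig.py):

def einsum_equation_is_row_parallel(equation: str) -> bool:
    """Returns True when the output spec of the einsum equation is shorter
    than the first input spec, which autoconfig treats as a contraction that
    needs a row-wise kernel split and an all-reduce on the output."""
    parts = equation.split("->")
    if len(parts) != 2:
        return False
    input_spec = parts[0].split(",")[0].strip()
    output_spec = parts[1].strip()
    return bool(input_spec) and bool(output_spec) and len(output_spec) < len(input_spec)

# "btd,de->bte" keeps the rank, so the kernel would be split column-wise with
# no communication; "btd,d->bt" drops a dimension and would be sharded
# row-wise with an allreduce on the output.
assert not einsum_equation_is_row_parallel("btd,de->bte")
assert einsum_equation_is_row_parallel("btd,d->bt")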