Add Autoconfig, Coordinated_Optimizer and Sharding keras implementations for Tensor Parallel Autosharding #21707
`keras/src/distribution/tensor_parallel/autoconfig.py` (new file, +220 lines; the path follows the module imported in the tests below):
```python
from typing import Sequence

from keras.src.distribution.tensor_parallel.config import ConfigKeras
from keras.src.distribution.tensor_parallel.state_action_keras import SplitKeras


def analyze_dense_layer_directly(layer, module, prefix: str) -> str:
    from keras.src import layers

    """Analyzes a Dense layer to classify it for tensor parallelism sharding.

    This function inspects the layer's weight shapes to determine if it's an
    "up-projection" (expanding feature dimensions), a "down-projection"
    (contracting feature dimensions), or a generic layer. This classification
    helps in deciding whether to apply column-wise or row-wise parallelism.

    Args:
        layer: The keras.layers.Dense instance to analyze.
        module: The parent Keras model containing the layer.
        prefix: The hierarchical name prefix for the layer.

    Returns:
        A string indicating the layer's classification: 'up_projection',
        'down_projection', or 'generic_dense'.
    """
    if not isinstance(layer, layers.Dense):
        return "generic_dense"

    input_dim = None
    output_dim = None

    if hasattr(layer, "kernel"):
        kernel_shape = layer.kernel.shape
        if len(kernel_shape) == 2:
            input_dim = kernel_shape[0]
            output_dim = kernel_shape[1]
    else:
        if hasattr(layer, "units"):
            output_dim = layer.units

        if (
            hasattr(layer, "input_shape")
            and layer.input_shape
            and len(layer.input_shape) > 1
        ):
            input_dim = layer.input_shape[-1]

    if not input_dim or not output_dim:
        return "generic_dense"

    expansion_threshold = 1.5
    is_expansion = output_dim > input_dim * expansion_threshold
    is_contraction = input_dim > output_dim * expansion_threshold

    if is_expansion:
        return "up_projection"
    elif is_contraction:
        return "down_projection"
    else:
        return "generic_dense"


def _traverse_and_shard_layer(
    current_layer,
    module,
    world_size: int,
    state_rules: dict,
    output_rules: dict,
    processed_layers: set,
    prefix: str = "",
):
    from keras.src import layers

    """Traverses a layer and its sub-layers to apply sharding rules.

    This function navigates through the model's layer hierarchy. For each
    layer, it identifies its type and applies appropriate sharding logic,
    populating the `state_rules` and `output_rules` dictionaries.

    Args:
        current_layer: The current keras.Layer object to be processed.
        module: The top-level Keras Model, used for context analysis.
        world_size: The total number of devices for sharding.
        state_rules: The dictionary of state sharding rules to populate.
        output_rules: The dictionary of output sharding rules to populate.
        processed_layers: A set of layer IDs that have already been processed
            to avoid redundant computation and infinite loops.
        prefix: The hierarchical name prefix from parent layers, used to
            construct the full unique name for the current layer.
    """
    if id(current_layer) in processed_layers:
        return
    processed_layers.add(id(current_layer))
```
Review comment on lines +103 to +105: Per my comment below about not needing a recursion, this is not needed.
```python
    name = current_layer.name
    full_name = f"{prefix}.{name}" if prefix else name
```
Review comment: Because you will never really recurse, the prefix won't work.
```python
    if isinstance(current_layer, layers.Dense):
        mlp_type = analyze_dense_layer_directly(
            current_layer, module, full_name
        )

        if mlp_type == "down_projection":
            state_rules[f"^{full_name}.kernel$"] = SplitKeras(
                world_size, 0, "row"
            )
            output_rules[f"^{full_name}$"] = {0: "allreduce"}

        else:
            state_rules[f"^{full_name}.kernel$"] = SplitKeras(
                world_size, 1, "column"
            )
            if current_layer.use_bias:
                state_rules[f"^{full_name}.bias$"] = SplitKeras(
                    world_size, 0, "column"
                )
            output_rules[f"^{full_name}$"] = {0: "no_comm"}
        return

    elif isinstance(current_layer, layers.EinsumDense):
        is_row_parallel = False
        if "->" in current_layer.equation:
            equation_parts = current_layer.equation.split("->")
            if len(equation_parts) == 2:
                input_spec = equation_parts[0].split(",")[0].strip()
                output_spec = equation_parts[1].strip()
                if (
                    input_spec
                    and output_spec
                    and len(output_spec) < len(input_spec)
                ):
                    is_row_parallel = True

        if is_row_parallel:
            state_rules[f"^{full_name}.kernel$"] = SplitKeras(
                world_size, 0, "row"
            )
            output_rules[f"^{full_name}$"] = {0: "allreduce"}
        else:
            state_rules[f"^{full_name}.kernel$"] = SplitKeras(
                world_size, 1, "column"
            )
            if (
                hasattr(current_layer, "bias")
                and current_layer.bias is not None
            ):
                state_rules[f"^{full_name}.bias$"] = SplitKeras(
                    world_size, 0, "column"
                )
            output_rules[f"^{full_name}$"] = {0: "no_comm"}
        return

    elif isinstance(current_layer, layers.Embedding):
        weight_name = (
            "embeddings" if hasattr(current_layer, "embeddings") else None
        )
        if weight_name:
            state_rules[f"^{full_name}\.{weight_name}$"] = SplitKeras(
                world_size, 1, "column"
            )
            output_rules[f"^{full_name}$"] = {0: "no_comm"}
        return
```
Review comment on lines +177 to +178 (the normalization-layer branch below): What about other layer types?

Author reply: The function is set up to only worry about the biggest layers in the model (Dense, Embedding, etc.). These are the only ones big enough to cause memory problems and need splitting (sharding). We skip the smaller layers for a few reasons:
- Normalization layers (like LayerNormalization): their weights are small, so we leave them alone and avoid slowing things down with extra communication.
- Layers with no weights (like Dropout, Activation): they don't have anything to split; they just use the sharded data that comes from the layer before them.
```python
    elif isinstance(
        current_layer,
        (
            layers.LayerNormalization,
            layers.BatchNormalization,
            layers.GroupNormalization,
        ),
    ):
        return
    else:
        if hasattr(current_layer, "layers"):
            for sub_layer in current_layer.layers:
                _traverse_and_shard_layer(
                    sub_layer,
                    module,
                    world_size,
                    state_rules,
                    output_rules,
                    processed_layers,
                    full_name,
                )


def get_default_config_keras(module, device_ids: Sequence[str]) -> ConfigKeras:
    """Generates a smart, recursive sharding configuration for a Keras model.

    This function traverses the layers of a given Keras model and applies a
    set of heuristics to automatically determine how each layer's weights
    and outputs should be sharded for tensor parallelism. It uses a helper
    function to perform the recursive traversal.

    Args:
        module: The Keras Model to generate a sharding configuration for.
        device_ids: A sequence of device identifiers, used to determine the
            world size (number of devices) for sharding.

    Returns:
        A ConfigKeras object containing the generated 'state_rules' (for model
        parameters) and 'output_rules' (for layer outputs).
    """
    world_size = len(device_ids)

    state_rules = {}
    output_rules = {}
    processed_layers = set()

    for layer in module.layers:
        _traverse_and_shard_layer(
            current_layer=layer,
            module=module,
            world_size=world_size,
            state_rules=state_rules,
            output_rules=output_rules,
            processed_layers=processed_layers,
            prefix="",
        )

    return ConfigKeras(state_rules=state_rules, output_rules=output_rules)
```
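For orientation, here is a minimal usage sketch of the API added above. The two-device `device_ids` list and the toy model are illustrative assumptions, not part of the PR:

```python
# Minimal sketch (assumed setup): generate a tensor-parallel sharding config
# for a toy MLP on two hypothetical devices.
from keras import Input, Model, layers

from keras.src.distribution.tensor_parallel.autoconfig import (
    get_default_config_keras,
)

inputs = Input(shape=(64,))
x = layers.Dense(256, name="up_proj")(inputs)  # 64 -> 256: up-projection
outputs = layers.Dense(64, name="down_proj")(x)  # 256 -> 64: down-projection
model = Model(inputs, outputs)

config = get_default_config_keras(model, device_ids=["device:0", "device:1"])

# state_rules maps weight-name regexes to SplitKeras actions, e.g.
#   "^up_proj.kernel$"   -> column split (axis 1)
#   "^down_proj.kernel$" -> row split (axis 0)
# output_rules maps layer-name regexes to communication ops
# ("no_comm" or "allreduce").
for pattern, action in config.state_rules.items():
    print(pattern, action)
```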
Test file for the autoconfig module (new file, +151 lines):
```python
import os

os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=2"

from keras import Input
from keras import Model
from keras import layers
from keras.src import testing
from keras.src.backend.distributed import backend_resolver
from keras.src.distribution.tensor_parallel.autoconfig import (
    analyze_dense_layer_directly,
)
from keras.src.distribution.tensor_parallel.autoconfig import (
    get_default_config_keras,
)
from keras.src.distribution.tensor_parallel.state_action_keras import SplitKeras


class TestAutoConfigKeras(testing.TestCase):
    def setUp(self):
        """Set up the test case and common variables."""
        super().setUp()
        backend = backend_resolver.get_distributed_backend()
        device_info = backend.get_device_info()
        self.world_size = device_info["device_count"]
        self.device_ids = [f"device:{i}" for i in range(self.world_size)]

        self.assertGreater(
            self.world_size, 1, "Distribution tests require more than 1 device."
        )

    def _assert_split_keras_equal(self, rule1, rule2):
        """
        Helper to compare two SplitKeras objects by their attributes.
        """
        self.assertIsInstance(rule1, SplitKeras)
        self.assertIsInstance(rule2, SplitKeras)
        self.assertDictEqual(vars(rule1), vars(rule2))

    def _assert_rules_equal(self, actual_rules, expected_rules):
        """Helper to compare two dictionaries of sharding rules."""
        self.assertSetEqual(
            set(actual_rules.keys()), set(expected_rules.keys())
        )
        for key in expected_rules:
            actual_val = actual_rules[key]
            expected_val = expected_rules[key]
            if isinstance(expected_val, SplitKeras):
                self._assert_split_keras_equal(actual_val, expected_val)
            else:
                self.assertEqual(actual_val, expected_val)

    def test_analyze_dense_layer(self):
        """Tests the direct analysis and classification of Dense layers."""
        up_proj_layer = layers.Dense(32)
        up_proj_layer.build(input_shape=(None, 16))
        self.assertEqual(
            analyze_dense_layer_directly(up_proj_layer, None, ""),
            "up_projection",
        )

        down_proj_layer = layers.Dense(16)
        down_proj_layer.build(input_shape=(None, 32))
        self.assertEqual(
            analyze_dense_layer_directly(down_proj_layer, None, ""),
            "down_projection",
        )

    def test_simple_mlp_sharding(self):
        """Tests a simple MLP with up and down projection layers."""
        inputs = Input(shape=(64,))
        x = layers.Dense(256, name="up_projection_layer", use_bias=True)(inputs)
        outputs = layers.Dense(
            64, name="down_projection_layer", use_bias=False
        )(x)
        model = Model(inputs=inputs, outputs=outputs, name="simple_mlp")

        config = get_default_config_keras(model, self.device_ids)

        expected_state_rules = {
            r"^up_projection_layer.kernel$": SplitKeras(
                self.world_size, 1, "column"
            ),
            r"^up_projection_layer.bias$": SplitKeras(
                self.world_size, 0, "column"
            ),
            r"^down_projection_layer.kernel$": SplitKeras(
                self.world_size, 0, "row"
            ),
        }
        expected_output_rules = {
            r"^up_projection_layer$": {0: "no_comm"},
            r"^down_projection_layer$": {0: "allreduce"},
        }

        self._assert_rules_equal(config.state_rules, expected_state_rules)
        self._assert_rules_equal(config.output_rules, expected_output_rules)

    def test_embedding_sharding(self):
        """Tests an Embedding layer."""
        inputs = Input(shape=(10,), dtype="int32")
        outputs = layers.Embedding(
            input_dim=1000, output_dim=128, name="token_embedding"
        )(inputs)
        model = Model(inputs=inputs, outputs=outputs, name="embed_model")

        config = get_default_config_keras(model, self.device_ids)

        expected_state_rules = {
            r"^token_embedding\.embeddings$": SplitKeras(
                self.world_size, 1, "column"
            )
        }
        expected_output_rules = {r"^token_embedding$": {0: "no_comm"}}

        self._assert_rules_equal(config.state_rules, expected_state_rules)
        self._assert_rules_equal(config.output_rules, expected_output_rules)

    def test_nested_model_sharding(self):
        """Tests that the traversal logic correctly handles nested models."""
        inner_inputs = Input(shape=(32,))
        inner_outputs = layers.Dense(128, name="inner_dense")(inner_inputs)
        inner_model = Model(
            inputs=inner_inputs, outputs=inner_outputs, name="inner_block"
        )

        outer_inputs = Input(shape=(32,))
        x = inner_model(outer_inputs)
        outer_outputs = layers.Dense(32, name="outer_dense")(x)
        outer_model = Model(
            inputs=outer_inputs, outputs=outer_outputs, name="outer_model"
        )

        config = get_default_config_keras(outer_model, self.device_ids)

        expected_state_rules = {
            r"^inner_block.inner_dense.kernel$": SplitKeras(
                self.world_size, 1, "column"
            ),
            r"^inner_block.inner_dense.bias$": SplitKeras(
                self.world_size, 0, "column"
            ),
            r"^outer_dense.kernel$": SplitKeras(self.world_size, 0, "row"),
        }
        expected_output_rules = {
            r"^inner_block.inner_dense$": {0: "no_comm"},
            r"^outer_dense$": {0: "allreduce"},
        }

        self._assert_rules_equal(config.state_rules, expected_state_rules)
        self._assert_rules_equal(config.output_rules, expected_output_rules)
```
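The tests above cover the Dense, Embedding, and nested-model paths but not the EinsumDense branch of autoconfig. A hedged sketch of what such a test might look like, as an additional method on `TestAutoConfigKeras` (the equation, shapes, and layer name are illustrative assumptions, not part of the PR):

```python
    def test_einsum_dense_sharding(self):
        """Sketch: column-parallel split for an EinsumDense projection."""
        inputs = Input(shape=(4, 16))
        # For "btd,de->bte" the first input spec and the output spec have the
        # same length, so autoconfig should take the column-parallel branch.
        outputs = layers.EinsumDense(
            "btd,de->bte", output_shape=(4, 32), name="einsum_proj"
        )(inputs)
        model = Model(inputs=inputs, outputs=outputs, name="einsum_model")

        config = get_default_config_keras(model, self.device_ids)

        expected_state_rules = {
            r"^einsum_proj.kernel$": SplitKeras(self.world_size, 1, "column")
        }
        expected_output_rules = {r"^einsum_proj$": {0: "no_comm"}}

        self._assert_rules_equal(config.state_rules, expected_state_rules)
        self._assert_rules_equal(config.output_rules, expected_output_rules)
```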