From eefb4b3cea8a3d944631c9ccf7976a440d4c3a4b Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Tue, 18 Feb 2025 14:16:17 -0500 Subject: [PATCH] Control trace cache warnings (#7039) Make trace cache warnings configurable, and disabled by default. Fix #6985, #4081, #5033, #5006, #5662 --------- Signed-off-by: Olatunji Ruwase --- deepspeed/runtime/engine.py | 5 +++ deepspeed/runtime/zero/config.py | 6 ++++ deepspeed/runtime/zero/parameter_offload.py | 6 +++- .../zero/partitioned_param_coordinator.py | 32 +++++++++++-------- deepspeed/runtime/zero/stage3.py | 9 ++++-- docs/_pages/config-json.md | 17 +++++++--- 6 files changed, 54 insertions(+), 21 deletions(-) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 986b68dc1bb15..8575df9d1d5d8 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -983,6 +983,9 @@ def zero_quantized_gradients(self): def zeropp_loco_param(self): return self._config.zero_config.zeropp_loco_param + def zero_log_trace_cache_warnings(self): + return self._config.zero_config.log_trace_cache_warnings + def dump_state(self): return self._config.dump_state @@ -1692,6 +1695,7 @@ def _configure_zero_optimizer(self, optimizer): zero_quantized_weights=self.zero_quantized_weights(), zero_quantized_nontrainable_weights=self.zero_quantized_nontrainable_weights(), zero_module_granularity_threshold=self.zero_module_granularity_threshold(), + log_trace_cache_warnings=self.zero_log_trace_cache_warnings(), ) else: log_dist( @@ -1740,6 +1744,7 @@ def _configure_zero_optimizer(self, optimizer): zero_quantized_nontrainable_weights=self.zero_quantized_nontrainable_weights(), zero_module_granularity_threshold=self.zero_module_granularity_threshold(), zeropp_loco_param=self.zeropp_loco_param(), + log_trace_cache_warnings=self.zero_log_trace_cache_warnings(), ) else: diff --git a/deepspeed/runtime/zero/config.py b/deepspeed/runtime/zero/config.py index cbc6a15c20576..19ee9b51702e7 100644 --- 
a/deepspeed/runtime/zero/config.py +++ b/deepspeed/runtime/zero/config.py @@ -45,6 +45,7 @@ "memory_efficient_linear": [true|false], "override_module_apply": [true|false], "zeropp_loco_param": {...}, + "log_trace_cache_warnings" : [true|false], } } """ @@ -340,6 +341,11 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel): Override nn.Module apply function, for Stage 3. """ + log_trace_cache_warnings: bool = False + """ + Whether to log warnings from trace cache, such as invalidation events. + """ + # Validators @model_validator(mode="after") def overlap_comm_valid(self): diff --git a/deepspeed/runtime/zero/parameter_offload.py b/deepspeed/runtime/zero/parameter_offload.py index b5a75fce32fe0..09de21502c272 100644 --- a/deepspeed/runtime/zero/parameter_offload.py +++ b/deepspeed/runtime/zero/parameter_offload.py @@ -103,6 +103,7 @@ def __init__( zero_quantized_weights=False, zero_quantized_nontrainable_weights=False, zero_module_granularity_threshold=0, + log_trace_cache_warnings=False, ): see_memory_usage("DeepSpeedZeRoOffload initialize [begin]", force=True) @@ -118,6 +119,7 @@ def __init__( self.zero_param_parallel_group = zero_param_parallel_group self.zero_quantized_weights = zero_quantized_weights self.zero_quantized_nontrainable_weights = zero_quantized_nontrainable_weights + self.log_trace_cache_warnings = log_trace_cache_warnings if offload_param_config is not None and offload_param_config.device != OffloadDeviceEnum.none: self.offload_device = offload_param_config.device @@ -165,7 +167,9 @@ def __init__( timers=self.timers, zero_quantized_weights=self.zero_quantized_weights, zero_quantized_nontrainable_weights=self.zero_quantized_nontrainable_weights, - fast_sharding_for_leaf_module=self.fast_sharding_for_leaf_module) + fast_sharding_for_leaf_module=self.fast_sharding_for_leaf_module, + log_trace_cache_warnings=self.log_trace_cache_warnings, + ) self.forward_hooks = [] self.backward_hooks = [] diff --git 
a/deepspeed/runtime/zero/partitioned_param_coordinator.py b/deepspeed/runtime/zero/partitioned_param_coordinator.py index 415afd8d70268..d5b5db859e31d 100644 --- a/deepspeed/runtime/zero/partitioned_param_coordinator.py +++ b/deepspeed/runtime/zero/partitioned_param_coordinator.py @@ -76,17 +76,20 @@ class __ParamInTrace: param: Parameter step_id_last_used_at: int - def __init__(self, - prefetch_bucket_sz: int, - max_reuse_distance_in_numel: int, - max_available_parameters_in_numel: int, - allgather_stream: get_accelerator().Stream, - inflight_param_registry: InflightParamRegistry, - prefetch_nvme: bool = False, - timers=None, - zero_quantized_weights=False, - zero_quantized_nontrainable_weights=False, - fast_sharding_for_leaf_module=False) -> None: + def __init__( + self, + prefetch_bucket_sz: int, + max_reuse_distance_in_numel: int, + max_available_parameters_in_numel: int, + allgather_stream: get_accelerator().Stream, + inflight_param_registry: InflightParamRegistry, + prefetch_nvme: bool = False, + timers=None, + zero_quantized_weights=False, + zero_quantized_nontrainable_weights=False, + fast_sharding_for_leaf_module=False, + log_trace_cache_warnings=False, + ) -> None: # mapping of param -> handle for each param that is currently in flight self.__inflight_param_registry = inflight_param_registry # keeps track of the number of submodules invoked so far. @@ -129,6 +132,9 @@ def __init__(self, self.__max_ongoing_fetch_events: int = 2 self.__profiler = PartitionedParameterProfiler(timers if ENABLE_PROFILER else None) + # Whether to log trace cache warnings, e.g. invalidation events + self.__log_trace_cache_warnings = log_trace_cache_warnings + # whether to enable fast fetch for the z3 leaf module. # this will improve fetch speed but will not break down leaf module parameters to alleviate memory pressure. 
self.fast_sharding_for_leaf_module = fast_sharding_for_leaf_module @@ -177,7 +183,7 @@ def trace_prologue(self, sub_module: Module) -> None: print_rank_0( f"Invalidate trace cache @ step {self.__step_id} and module {sub_module.ds_id}: " f"cache has only {len(self.__submodule_order)} modules", - force=True) + force=self.__log_trace_cache_warnings) self._invalidate_trace() return @@ -186,7 +192,7 @@ def trace_prologue(self, sub_module: Module) -> None: print_rank_0( f"Invalidate trace cache @ step {self.__step_id}: " f"expected module {expected_module_id}, but got module {sub_module.ds_id}", - force=True) + force=self.__log_trace_cache_warnings) self._invalidate_trace() @compiler.disable diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 3627d4675a71f..9cc58fdbac010 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -160,6 +160,7 @@ def __init__( zero_quantized_nontrainable_weights=False, zero_module_granularity_threshold=0, zeropp_loco_param=None, + log_trace_cache_warnings=False, ): see_memory_usage("Stage 3 initialize beginning", force=True) @@ -231,7 +232,9 @@ def __init__( zero_param_parallel_group=zero_param_parallel_group, zero_quantized_weights=zero_quantized_weights, zero_quantized_nontrainable_weights=zero_quantized_nontrainable_weights, - zero_module_granularity_threshold=zero_module_granularity_threshold) + zero_module_granularity_threshold=zero_module_granularity_threshold, + log_trace_cache_warnings=log_trace_cache_warnings, + ) self.persistent_parameters = self.parameter_offload.persistent_parameters self._configure_offloading(offload_optimizer_config, offload_param_config) @@ -465,6 +468,7 @@ def initialize_ds_offload( zero_quantized_weights, zero_quantized_nontrainable_weights, zero_module_granularity_threshold, + log_trace_cache_warnings, ): return DeepSpeedZeRoOffload(module=module, timers=timers, @@ -481,7 +485,8 @@ def initialize_ds_offload( 
zero_param_parallel_group=zero_param_parallel_group, zero_quantized_weights=zero_quantized_weights, zero_quantized_nontrainable_weights=zero_quantized_nontrainable_weights, - zero_module_granularity_threshold=zero_module_granularity_threshold) + zero_module_granularity_threshold=zero_module_granularity_threshold, + log_trace_cache_warnings=log_trace_cache_warnings) def _get_trainable_parameter_groups(self): param_groups = [] diff --git a/docs/_pages/config-json.md b/docs/_pages/config-json.md index 51e3bbd6eaaa2..43de95b5210b0 100755 --- a/docs/_pages/config-json.md +++ b/docs/_pages/config-json.md @@ -371,11 +371,12 @@ Enabling and configuring ZeRO memory optimizations "sub_group_size" : 1e12, "elastic_checkpoint" : [true|false], "stage3_gather_16bit_weights_on_model_save": [true|false], - "ignore_unused_parameters": [true|false] - "round_robin_gradients": [true|false] - "zero_hpz_partition_size": 1 - "zero_quantized_weights": [true|false] - "zero_quantized_gradients": [true|false] + "ignore_unused_parameters": [true|false], + "round_robin_gradients": [true|false], + "zero_hpz_partition_size": 1, + "zero_quantized_weights": [true|false], + "zero_quantized_gradients": [true|false], + "log_trace_cache_warnings": [true|false] } ``` @@ -512,6 +513,12 @@ Enabling and configuring ZeRO memory optimizations | ----------------------------------------------------------------------------------------------------------------------------------- | ------- | |Boolean indicating whether to enable communication efficient quantized gradients of ZeRO++. | `False` | +**log_trace_cache_warnings**: [boolean] + +| Description | Default | +| ------------------------------------------------------------------------------------------------------------------- | ------- | +| Log warnings from trace cache optimization of parameter sharding, such as cache invalidation events. 
| `False` | + ***cpu_offload***: [boolean] **Deprecated:** **cpu_offload** is deprecated and will be removed in future, please use `offload_optimizer` instead.