Control trace cache warnings #7039

Merged · 3 commits · Feb 18, 2025
5 changes: 5 additions & 0 deletions deepspeed/runtime/engine.py
@@ -983,6 +983,9 @@ def zero_quantized_gradients(self):
def zeropp_loco_param(self):
return self._config.zero_config.zeropp_loco_param

def zero_log_trace_cache_warnings(self):
return self._config.zero_config.log_trace_cache_warnings

def dump_state(self):
return self._config.dump_state

@@ -1692,6 +1695,7 @@ def _configure_zero_optimizer(self, optimizer):
zero_quantized_weights=self.zero_quantized_weights(),
zero_quantized_nontrainable_weights=self.zero_quantized_nontrainable_weights(),
zero_module_granularity_threshold=self.zero_module_granularity_threshold(),
log_trace_cache_warnings=self.zero_log_trace_cache_warnings(),
)
else:
log_dist(
@@ -1740,6 +1744,7 @@ def _configure_zero_optimizer(self, optimizer):
zero_quantized_nontrainable_weights=self.zero_quantized_nontrainable_weights(),
zero_module_granularity_threshold=self.zero_module_granularity_threshold(),
zeropp_loco_param=self.zeropp_loco_param(),
log_trace_cache_warnings=self.zero_log_trace_cache_warnings(),
)

else:
6 changes: 6 additions & 0 deletions deepspeed/runtime/zero/config.py
@@ -45,6 +45,7 @@
"memory_efficient_linear": [true|false],
"override_module_apply": [true|false],
"zeropp_loco_param": {...},
"log_trace_cache_warnings" : [true|false],
}
}
"""
Expand Down Expand Up @@ -340,6 +341,11 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel):
Override nn.Module apply function, for Stage 3.
"""

log_trace_cache_warnings: bool = False
"""
Whether to log warnings from the trace cache, such as invalidation events.
"""

# Validators
@model_validator(mode="after")
def overlap_comm_valid(self):
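For context, a minimal sketch of how a user might turn the new warnings on. The `log_trace_cache_warnings` key is the field added above; the model, optimizer, and batch-size values are hypothetical placeholders:

```python
# Hypothetical usage sketch: enable trace cache warnings via the ZeRO config.
# "log_trace_cache_warnings" is the field added in this PR; the model,
# optimizer, and batch-size values below are illustrative placeholders.
import torch
import deepspeed

model = torch.nn.Linear(8, 8)  # stand-in model for illustration

ds_config = {
    "train_batch_size": 8,
    "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
    "zero_optimization": {
        "stage": 3,
        "log_trace_cache_warnings": True,  # default is False: warnings stay silent
    },
}

engine, optimizer, _, _ = deepspeed.initialize(model=model,
                                               model_parameters=model.parameters(),
                                               config=ds_config)
```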
6 changes: 5 additions & 1 deletion deepspeed/runtime/zero/parameter_offload.py
@@ -103,6 +103,7 @@ def __init__(
zero_quantized_weights=False,
zero_quantized_nontrainable_weights=False,
zero_module_granularity_threshold=0,
log_trace_cache_warnings=False,
):

see_memory_usage("DeepSpeedZeRoOffload initialize [begin]", force=True)
@@ -118,6 +119,7 @@ def __init__(
self.zero_param_parallel_group = zero_param_parallel_group
self.zero_quantized_weights = zero_quantized_weights
self.zero_quantized_nontrainable_weights = zero_quantized_nontrainable_weights
self.log_trace_cache_warnings = log_trace_cache_warnings

if offload_param_config is not None and offload_param_config.device != OffloadDeviceEnum.none:
self.offload_device = offload_param_config.device
@@ -165,7 +167,9 @@ def __init__(
timers=self.timers,
zero_quantized_weights=self.zero_quantized_weights,
zero_quantized_nontrainable_weights=self.zero_quantized_nontrainable_weights,
fast_sharding_for_leaf_module=self.fast_sharding_for_leaf_module)
fast_sharding_for_leaf_module=self.fast_sharding_for_leaf_module,
log_trace_cache_warnings=self.log_trace_cache_warnings,
)

self.forward_hooks = []
self.backward_hooks = []
32 changes: 19 additions & 13 deletions deepspeed/runtime/zero/partitioned_param_coordinator.py
@@ -76,17 +76,20 @@ class __ParamInTrace:
param: Parameter
step_id_last_used_at: int

def __init__(self,
prefetch_bucket_sz: int,
max_reuse_distance_in_numel: int,
max_available_parameters_in_numel: int,
allgather_stream: get_accelerator().Stream,
inflight_param_registry: InflightParamRegistry,
prefetch_nvme: bool = False,
timers=None,
zero_quantized_weights=False,
zero_quantized_nontrainable_weights=False,
fast_sharding_for_leaf_module=False) -> None:
def __init__(
self,
prefetch_bucket_sz: int,
max_reuse_distance_in_numel: int,
max_available_parameters_in_numel: int,
allgather_stream: get_accelerator().Stream,
inflight_param_registry: InflightParamRegistry,
prefetch_nvme: bool = False,
timers=None,
zero_quantized_weights=False,
zero_quantized_nontrainable_weights=False,
fast_sharding_for_leaf_module=False,
log_trace_cache_warnings=False,
) -> None:
# mapping of param -> handle for each param that is currently in flight
self.__inflight_param_registry = inflight_param_registry
# keeps track of the number of submodules invoked so far.
@@ -129,6 +132,9 @@ def __init__(self,
self.__max_ongoing_fetch_events: int = 2
self.__profiler = PartitionedParameterProfiler(timers if ENABLE_PROFILER else None)

# Whether to log trace cache warnings, e.g. invalidation events
self.__log_trace_cache_warnings = log_trace_cache_warnings

# whether to enable fast fetch for the z3 leaf module.
# this will improve fetch speed but will not break down leaf module parameters to alleviate memory pressure.
self.fast_sharding_for_leaf_module = fast_sharding_for_leaf_module
@@ -177,7 +183,7 @@ def trace_prologue(self, sub_module: Module) -> None:
print_rank_0(
f"Invalidate trace cache @ step {self.__step_id} and module {sub_module.ds_id}: "
f"cache has only {len(self.__submodule_order)} modules",
force=True)
force=self.__log_trace_cache_warnings)
self._invalidate_trace()
return

@@ -186,7 +192,7 @@ def trace_prologue(self, sub_module: Module) -> None:
print_rank_0(
f"Invalidate trace cache @ step {self.__step_id}: "
f"expected module {expected_module_id}, but got module {sub_module.ds_id}",
force=True)
force=self.__log_trace_cache_warnings)
self._invalidate_trace()

@compiler.disable
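The functional change in this file is the `force=` argument to `print_rank_0`: the invalidation messages, previously always printed, are now gated behind the flag. A simplified standalone sketch of that pattern (names here are illustrative, not the actual DeepSpeed internals):

```python
# Simplified sketch of the gating pattern applied above: trace cache
# invalidation messages are emitted only when the flag is set, rather
# than unconditionally (force=True).
def print_rank_0(message: str, force: bool = False) -> None:
    # Stand-in for DeepSpeed's rank-0 logger: in this sketch, print only when forced.
    if force:
        print(message)

class TraceCacheSketch:
    def __init__(self, log_trace_cache_warnings: bool = False) -> None:
        self._log_warnings = log_trace_cache_warnings

    def invalidate(self, step_id: int, module_id: int) -> None:
        # Before this PR the message was always forced; now it is opt-in.
        print_rank_0(f"Invalidate trace cache @ step {step_id} and module {module_id}",
                     force=self._log_warnings)

# Emits the warning only because the flag is enabled:
TraceCacheSketch(log_trace_cache_warnings=True).invalidate(step_id=3, module_id=17)
```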
9 changes: 7 additions & 2 deletions deepspeed/runtime/zero/stage3.py
@@ -160,6 +160,7 @@ def __init__(
zero_quantized_nontrainable_weights=False,
zero_module_granularity_threshold=0,
zeropp_loco_param=None,
log_trace_cache_warnings=False,
):
see_memory_usage("Stage 3 initialize beginning", force=True)

@@ -231,7 +232,9 @@ def __init__(
zero_param_parallel_group=zero_param_parallel_group,
zero_quantized_weights=zero_quantized_weights,
zero_quantized_nontrainable_weights=zero_quantized_nontrainable_weights,
zero_module_granularity_threshold=zero_module_granularity_threshold)
zero_module_granularity_threshold=zero_module_granularity_threshold,
log_trace_cache_warnings=log_trace_cache_warnings,
)

self.persistent_parameters = self.parameter_offload.persistent_parameters
self._configure_offloading(offload_optimizer_config, offload_param_config)
@@ -465,6 +468,7 @@ def initialize_ds_offload(
zero_quantized_weights,
zero_quantized_nontrainable_weights,
zero_module_granularity_threshold,
log_trace_cache_warnings,
):
return DeepSpeedZeRoOffload(module=module,
timers=timers,
@@ -481,7 +485,8 @@ def initialize_ds_offload(
zero_param_parallel_group=zero_param_parallel_group,
zero_quantized_weights=zero_quantized_weights,
zero_quantized_nontrainable_weights=zero_quantized_nontrainable_weights,
zero_module_granularity_threshold=zero_module_granularity_threshold)
zero_module_granularity_threshold=zero_module_granularity_threshold,
log_trace_cache_warnings=log_trace_cache_warnings)

def _get_trainable_parameter_groups(self):
param_groups = []
17 changes: 12 additions & 5 deletions docs/_pages/config-json.md
@@ -371,11 +371,12 @@ Enabling and configuring ZeRO memory optimizations
"sub_group_size" : 1e12,
"elastic_checkpoint" : [true|false],
"stage3_gather_16bit_weights_on_model_save": [true|false],
"ignore_unused_parameters": [true|false]
"round_robin_gradients": [true|false]
"zero_hpz_partition_size": 1
"zero_quantized_weights": [true|false]
"zero_quantized_gradients": [true|false]
"ignore_unused_parameters": [true|false],
"round_robin_gradients": [true|false],
"zero_hpz_partition_size": 1,
"zero_quantized_weights": [true|false],
"zero_quantized_gradients": [true|false],
"log_trace_cache_warnings": [true|false],
}
```

@@ -512,6 +513,12 @@ Enabling and configuring ZeRO memory optimizations
| ----------------------------------------------------------------------------------------------------------------------------------- | ------- |
| Boolean indicating whether to enable communication efficient quantized gradients of ZeRO++. | `False` |

<i>**log_trace_cache_warnings**</i>: [boolean]

| Description | Default |
| ------------------------------------------------------------------------------------------------------------------- | ------- |
| Boolean indicating whether to log warnings from the trace cache optimization of parameter sharding, such as cache invalidation events. | `False` |
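
For example, a minimal `zero_optimization` block that turns the warnings on (other fields omitted for brevity):

```json
"zero_optimization": {
    "stage": 3,
    "log_trace_cache_warnings": true
}
```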

***cpu_offload***: [boolean]

**Deprecated:** **cpu_offload** is deprecated and will be removed in the future; please use `offload_optimizer` instead.