15 changes: 12 additions & 3 deletions src/transformers/generation/logits_process.py
@@ -1776,13 +1776,22 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:

 class InfNanRemoveLogitsProcessor(LogitsProcessor):
     r"""
-    [`LogitsProcessor`] that removes all `nan` and `inf` values to avoid the generation method to fail. Note that using
-    the logits processor should only be used if necessary since it can slow down the generation method.
+    [`LogitsProcessor`] that removes all `nan` and `inf` values to keep the generation method from failing. This version
+    has been extended to sanitize both the logits and the hidden-state output tensors, to handle instabilities in very
+    wide models or in models sharded across many devices.
+
+    Note that this logits processor should only be used if necessary, since it can slow down the generation method.
 
     This logits processor has no `generate` example, as there shouldn't be a correct combination of flags that warrants
-    its use.
+    its use. However, when dealing with models sharded across many GPUs, or with models whose very wide hidden
+    dimensions can produce unstable values, setting `remove_invalid_values=True` in the generation config will activate
+    this processor automatically.
     """
 
+    def __init__(self, hidden_states_aware: bool = True):
+        # flag controlling whether hidden states are sanitized in addition to the logits
+        self.hidden_states_aware = hidden_states_aware
+
     @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         # set all nan values to 0.0
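For reference, a minimal sketch of the sanitization the truncated `__call__` body performs, mirroring the upstream processor: `nan` becomes 0.0 and `+inf` is clamped to the dtype maximum, while `-inf` is left alone since it legitimately marks banned tokens. The helper name is illustrative:

import torch

def sanitize_scores(scores: torch.FloatTensor) -> torch.FloatTensor:
    # replace nan entries with 0.0
    scores = torch.where(torch.isnan(scores), torch.zeros_like(scores), scores)
    # clamp +inf to the largest finite value of the dtype; -inf stays, since it
    # encodes tokens that were deliberately banned (softmax probability zero)
    max_val = torch.finfo(scores.dtype).max
    scores = torch.where(torch.isposinf(scores), torch.full_like(scores, max_val), scores)
    return scores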
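A hedged usage sketch of the flag mentioned in the docstring (checkpoint and prompt are illustrative): passing `remove_invalid_values=True` wires this processor into the logits-processor list.

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("Hello, world", return_tensors="pt")

# remove_invalid_values=True adds InfNanRemoveLogitsProcessor to the
# processor list, sanitizing scores before each sampling step
outputs = model.generate(**inputs, do_sample=True, max_new_tokens=20, remove_invalid_values=True)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))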
33 changes: 33 additions & 0 deletions src/transformers/generation/utils.py
@@ -1844,6 +1844,39 @@ def _prepare_generation_config(

         # Finally, apply any passed kwargs
         model_kwargs = generation_config.update(**kwargs)
+
+        # Safety: if the model is sharded across multiple devices (`hf_device_map`/`device_map`) and we are
+        # sampling, enable `remove_invalid_values` by default so that NaN/Inf logits cannot trigger CUDA
+        # asserts during multinomial sampling. An explicitly passed `remove_invalid_values=False` cannot be
+        # distinguished from the default, so the guard engages whenever a sharded model samples.
+        try:
+            is_sharded_map = False
+            hf_map = getattr(self, "hf_device_map", None)
+            if isinstance(hf_map, dict):
+                # consider the model sharded if its modules span more than one device,
+                # not counting "cpu"/"disk" offload placements
+                gpu_devices = {d for d in hf_map.values() if d not in {"cpu", "disk"}}
+                if len(gpu_devices) > 1:
+                    is_sharded_map = True
+
+            # also accept the legacy `device_map` attribute set by accelerate hooks
+            device_map_attr = getattr(self, "device_map", None)
+            if not is_sharded_map and isinstance(device_map_attr, dict):
+                # `device_map` can map module names to devices or hold other structures; a
+                # dict mapping to multiple accelerator devices also counts as sharded
+                gpu_devices = {d for d in device_map_attr.values() if d not in {"cpu", "disk"}}
+                if len(gpu_devices) > 1:
+                    is_sharded_map = True
+
+            if is_sharded_map and generation_config.do_sample and not generation_config.remove_invalid_values:
+                generation_config.remove_invalid_values = True
+                logger.info(
+                    "Enabling `remove_invalid_values=True` for sharded sampling to avoid NaN/Inf logits during sampling."
+                )
+        except Exception:
+            # never fail generation config preparation because of this best-effort safety check
+            pass
+
         # And keep in model_kwargs variable output controls
         output_attentions = generation_config.output_attentions
         output_hidden_states = generation_config.output_hidden_states
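With the change above, the same guard engages automatically once a model is sharded; a hedged sketch, assuming a multi-GPU machine and an illustrative checkpoint:

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative checkpoint
# device_map="auto" shards the model across available devices and
# populates `hf_device_map` with module -> device placements
model = AutoModelForCausalLM.from_pretrained("gpt2", device_map="auto")

inputs = tokenizer("Hello, world", return_tensors="pt").to(model.device)
# with more than one accelerator in `hf_device_map`, do_sample=True now
# flips `remove_invalid_values` on automatically (see the logger.info above)
outputs = model.generate(**inputs, do_sample=True, max_new_tokens=20)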