Skip to content

Commit

Permalink
Revert "feat: incorporate fallback allocation into infer_auto_device_…
Browse files Browse the repository at this point in the history
…map"

This reverts commit d607bfb.
  • Loading branch information
Nech-C committed Oct 6, 2024
1 parent d607bfb commit f040302
Showing 1 changed file with 0 additions and 39 deletions.
39 changes: 0 additions & 39 deletions src/accelerate/utils/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -1277,7 +1277,6 @@ def infer_auto_device_map(
verbose: bool = False,
clean_result: bool = True,
offload_buffers: bool = False,
fallback_allocation: bool = False,
):
"""
Compute a device map for a given model giving priority to GPUs, then offload on CPU and finally offload to disk,
Expand Down Expand Up @@ -1352,7 +1351,6 @@ def infer_auto_device_map(
device_memory_used = {}
device_buffer_sizes = {}
device_minimum_assignment_memory = {}
fallback_attempted = False

# Direct submodules and parameters
modules_to_treat = (
Expand Down Expand Up @@ -1424,27 +1422,8 @@ def infer_auto_device_map(
# -> no split, we go to the next device
if verbose:
print("This module cannot be split, going to the next device.")

if fallback_allocation and devices[current_device] in main_devices and \
current_memory_used == 0 and not fallback_attempted:
fallback_module_name, fallback_module, remaining_modules = fallback_allocate(
modules_to_treat,
module_sizes,
current_max_size - current_memory_used,
no_split_module_classes,
tied_parameters,
)

# use the next iteration to put the fallback module on the next device to avoid code duplication
if fallback_module is not None:
modules_to_treat = [(fallback_module_name, fallback_module)]\
+ [(name, module)]\
+ remaining_modules
continue

if current_memory_used == 0:
device_minimum_assignment_memory[device] = module_size + current_memory_reserved

device_memory_used[device] = current_memory_used + current_memory_reserved
current_device += 1
modules_to_treat = [(name, module)] + modules_to_treat
Expand Down Expand Up @@ -1543,24 +1522,6 @@ def infer_auto_device_map(
# If the tied module is not split, we go to the next device
if verbose:
print("None of the tied module can be split, going to the next device.")

if fallback_allocation and devices[current_device] in main_devices and \
current_memory_used == 0 and not fallback_attempted:
fallback_module_name, fallback_module, remaining_modules = fallback_allocate(
modules_to_treat,
module_sizes,
current_max_size - current_memory_used,
no_split_module_classes,
tied_parameters,
)

# use the next iteration to put the fallback module on the next device to avoid code duplication
if fallback_module is not None:
modules_to_treat = [(fallback_module_name, fallback_module)] \
+ [(name, module)] \
+ remaining_modules
continue

if current_memory_used == 0:
device_minimum_assignment_memory[device] = module_size_with_ties + current_memory_reserved

Expand Down

0 comments on commit f040302

Please sign in to comment.