Description
I am writing a simple script to run FSDP2 (`fully_shard`) on the `pythia-1b` model available on HuggingFace. I am currently running the model on 1 node with 2 devices, following the meta-device initialisation from the FSDP2 docs. However, I think there is something wrong with my implementation, since the peak memory usage with FSDP is the same as without FSDP (~1 GB). Further, I get an OOM on my device when I try the `pythia-2.8b` model. Following is a snippet showing how I am initialising the model on a meta device using the HuggingFace APIs:
model_name = "EleutherAI/pythia-14m"
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
config = AutoConfig.from_pretrained(model_name)
with init_empty_weights():
model = AutoModelForCausalLM.from_config(config)
for module in model.modules():
if isinstance(module, GPTNeoXLayer):
fully_shard(module)
model = fully_shard(model, reshard_after_forward=True)
model = load_checkpoint_and_dispatch(
model, path_to_safe_tensors
)
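As a quick sanity check (a sketch I would add after the sharding step, not part of my original script), one can confirm that `fully_shard` has replaced the parameters with `DTensor`s:

```python
# Sanity check (sketch): after fully_shard, each parameter should be a
# DTensor; on the meta device the local shards are still unmaterialized.
from torch.distributed.tensor import DTensor

for name, param in model.named_parameters():
    assert isinstance(param, DTensor), f"{name} is not sharded"
    print(name, param.placements, param.to_local().shape)
    break  # one parameter is enough to confirm
```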
This is not very straightforward, since the shards expect `DTensor`s when the weights are being loaded via `load_checkpoint_and_dispatch`. I am looking for suggestions on a good way to make FSDP2 work with HuggingFace models. I don't think `accelerate` supports FSDP2 yet.
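For context, this is the kind of DTensor-aware loading I imagine would be needed instead of `load_checkpoint_and_dispatch` (a rough sketch adapted from the meta-device example in the FSDP2 docs; it assumes the full state dict fits in CPU memory and that the checkpoint keys match the model's parameter names):

```python
# Rough sketch (my assumption, following the FSDP2 docs' meta-device example):
# load the full weights on CPU, then shard each one into a DTensor matching
# the corresponding parameter's mesh and placements.
import torch
from safetensors.torch import load_file
from torch.distributed.tensor import distribute_tensor

full_sd = load_file(path_to_safe_tensors)  # plain CPU tensors
meta_sharded_sd = model.state_dict()       # DTensors, still on the meta device
sharded_sd = {}
for name, full_tensor in full_sd.items():
    meta_param = meta_sharded_sd[name]
    # Shard the full tensor the same way fully_shard sharded the parameter.
    sharded_sd[name] = torch.nn.Parameter(
        distribute_tensor(full_tensor, meta_param.device_mesh, meta_param.placements)
    )
# assign=True swaps the meta DTensors for the materialized ones.
model.load_state_dict(sharded_sd, assign=True)
```

Even so, I was hoping there is a less manual path through the HuggingFace loading utilities.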