[QUESTION] Even when training a small model, Megatron-LM needs huge memory space (OOM error) #525

Closed
jindajia opened this issue Sep 29, 2023 · 4 comments

@jindajia

I'm trying to train a GPT-2 Large (774M) model on a V100-32GB GPU. Even though this model is not big, I can't fit it into a single GPU; training always fails with the error below, copied from my terminal output.

"torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 786.00 MiB (GPU 0; 31.74 GiB total capacity; 28.82 GiB already allocated; 577.12 MiB free; 29.97 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF"

Below are my training config and model-size calculation. Based on my calculation, done especially with the help of @yxyOo and this tool, the model should only need about 23.3GB, which is much smaller than the 32GB of GPU memory. However, I still encounter OOM errors, so I'm confused about why Megatron needs so much memory during training.

Also, I didn't use --use-flash-attn to save memory because I only have access to V100 GPUs, which don't support it. Could that be why the actual memory usage is larger than the theoretical estimate?

DISTRIBUTED_ARGS="
    --nproc_per_node $GPUS_PER_NODE \
    --nnodes $NNODES \
    --node_rank $NODE_RANK \
    --master_addr $MASTER_ADDR \
    --master_port $MASTER_PORT
"

GPT_ARGS="
    --tensor-model-parallel-size 1 \
    --pipeline-model-parallel-size 1 \
    --num-layers 36 \
    --hidden-size 1280 \
    --num-attention-heads 20 \
    --seq-length 1024 \
    --max-position-embeddings 1024 \
    --micro-batch-size 4 \
    --global-batch-size 16 \
    --lr 0.00015 \
    --train-iters 500000 \
    --lr-decay-iters 320000 \
    --lr-decay-style cosine \
    --min-lr 1.0e-5 \
    --weight-decay 1e-2 \
    --lr-warmup-fraction .01 \
    --clip-grad 1.0 \
    --fp16 \
    --swiglu \
    --use-distributed-optimizer
"

Model Memory Size

----------------------------------------------------------------------------------------------------
***Full Model without Parallel***
====================================================================================================
Layer                                       Param. (shape)              Param. (MB)   Act. (MB)
----------------------------------------------------------------------------------------------------
GPTModel
├─TransformerLanguageModel
│  └─Embedding                                                                        15.0
│  │  └─word_embeddings                     w=[50304,1280]              122.8
│  │  └─position_embeddings                 w=[1024,1280]               2.5
│  └─ParallelTransformer: x 36 (layer_num)                                            206.0/layer
│  │  └─input_layernorm                     w=[1280],b=[1280]           0.0           10.0
│  │  └─self_attention                                                                60.0
│  │  │  └─query_key_value                  w=[3840,1280],b=[1280]      9.4
│  │  │  └─rearrange                                                                  30.0
│  │  │  └─core_attention_flash                                                       10.0
│  │  │  └─rearrange                                                                  10.0
│  │  │  └─dense                            w=[1280,1280],b=[1280]      3.1           10.0
│  │  └─post_attention_layernorm            w=[1280],b=[1280]           0.0           10.0
│  │  └─mlp                                                                           116.0
│  │  │  └─dense_h_to_4h                    w=[6784,1280],b=[6784]      16.6
│  │  │  └─bias_glue                                                                  106.0
│  │  │  └─dense_4h_to_h                    w=[1280,3392],b=[1280]      8.3           10.0
│  │  └─drop_add_fusion                                                               15.0
----------------------------------------------------------------------------------------------------
Amount of Parameters: 771,106,304
Parameters: 1.4GB
Gradients: 1.4GB
Optimizer (Adam) states: 11.5GB
Activations: 9.0GB
Total memory demand: 23.3GB
====================================================================================================
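
For reference, the rough arithmetic behind these totals (my own reconstruction, not the tool's exact code): fp16 parameters and gradients take 2 bytes per parameter each, and the mixed-precision Adam optimizer keeps fp32 master weights, fp32 main gradients and two fp32 moments, i.e. about 16 bytes per parameter; the 9.0GB activation number is the tool's own figure.

# Rough reconstruction of the summary above (my own sketch, not the tool's code).
# Assumes fp16 params/grads (2 bytes each) and ~16 bytes/param of optimizer state
# (fp32 master weights + fp32 main grads + two fp32 Adam moments).
GiB = 1024 ** 3
n_params = 771_106_304                     # "Amount of Parameters" from the table

params_fp16 = 2 * n_params / GiB           # ~1.4 GB
grads_fp16  = 2 * n_params / GiB           # ~1.4 GB
optim_state = 16 * n_params / GiB          # ~11.5 GB
activations = 9.0                          # tool's own figure (assumes flash-attn)

parts = [round(x, 1) for x in (params_fp16, grads_fp16, optim_state, activations)]
print(parts, round(sum(parts), 1))         # [1.4, 1.4, 11.5, 9.0] 23.3
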
@deepakn94
Collaborator

What happens if you halve the micro-batch size? I am wondering if the tool you are using calculates the memory estimates for the activations correctly.

Also, unfortunately PyTorch consumes a fair bit of auxiliary memory that is not reflected in any of these calculations.
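
If you want to see how much of the gap is allocator overhead rather than live tensors, a quick check like this (just a diagnostic sketch, not something in Megatron) printed at the start of each training step can help:

# Quick diagnostic (not part of Megatron): compare live-tensor memory with what
# the PyTorch caching allocator has actually reserved on the GPU.
import torch

def report_cuda_memory(tag=""):
    alloc = torch.cuda.memory_allocated() / 1024**3     # live tensors, in GiB
    reserved = torch.cuda.memory_reserved() / 1024**3   # reserved by the allocator, in GiB
    print(f"[{tag}] allocated={alloc:.2f} GiB  reserved={reserved:.2f} GiB  "
          f"gap={reserved - alloc:.2f} GiB")
    # torch.cuda.memory_summary() prints a much more detailed breakdown if needed.

A large reserved-minus-allocated gap is the fragmentation case that the max_split_size_mb hint in your error message refers to.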

@yxyOo

yxyOo commented Oct 7, 2023

The tool's default output assumes use-flash-attn, which does not match your usage scenario.
Please refer to the "Limitations" section of the Analysis Tool.
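
For a rough sense of the gap: without flash attention, the standard attention path materializes [micro-batch, heads, seq, seq] tensors (softmax output, attention-dropout output, etc.) for the backward pass, which the flash-attn estimate above does not count. A back-of-the-envelope sketch with your config (my own estimate, not the tool's output):

# Back-of-the-envelope (not the tool's output): extra activation memory when the
# [b, heads, s, s] attention tensors are materialized instead of using flash-attn.
GiB = 1024 ** 3
micro_batch, heads, seq, layers = 4, 20, 1024, 36       # values from your config

one_tensor = micro_batch * heads * seq * seq * 2        # one fp16 [b, a, s, s] tensor, in bytes
per_layer = 2 * one_tensor                              # assume ~2 such tensors saved for
                                                        # backward (softmax out, dropout out)
print(per_layer / GiB, layers * per_layer / GiB)        # 0.3125 GiB/layer, 11.25 GiB total

With roughly two such fp16 tensors kept per layer, that is on the order of 11GB of extra activations on top of the 23.3GB estimate, which by itself would push the run past 32GB.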

@deepakn94
Collaborator

We now also have a report_theoretical_memory.py script that should take the same set of arguments as pretrain_gpt.py.

You can use it like this:

CUDA_DEVICE_MAX_CONNECTIONS=1 WORLD_SIZE=<WORLD_SIZE> python -u report_theoretical_memory.py ${options}


Marking as stale. No activity in 60 days.

The github-actions bot added the stale (No activity in 60 days on issue or PR) label on Jan 28, 2024.