Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Analysis Tool #482

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open

Analysis Tool #482

wants to merge 4 commits into from

Conversation

yxyOo
Copy link

@yxyOo yxyOo commented Sep 1, 2023

Introduction

Offline analysis of memory requirements and communication information of Megatron-LM GPTModel training under hybrid parallel strategies

Features

Given the GPT model configuration and parallel training configuration, this tool will output the following:

  • Detail the memory requirements for Parameters, Gradients, Optimizer States and Activations at the Transformer granularity level on each GPU.
  • Provide an estimate predicting the least amount of memory a GPU needs to train the GPT model without causing Out-of-Memory (OOM) errors.
  • Describe the communication requirements when implementing Data Parallelism, Pipeline Parallelism and Tensor Parallelism. State how many times each dimension needs to communicate, the amount of data transmitted each time and the members of the communication group, among others.
  • Describe the changes in the size of the Transformer model before and after parallel and how these changes impact GPU utilization.

We randomly selected some parallel configurations and used the "Memory Requirement" output in this tool as the estimated value, and the output of torch.cuda.max_memory_allocated() in Megatron-LM report_memory after training several iterations as the actual value. The parallel configurations in the x-axis of the following figure correspond to the four model parallel configurations in the table below in order.

This can give users insight into whether their planned parallel configuration is trainable, and if it potentially could trigger OOM errors.

图片描述
Model Precision MBS GBS DP PP TP Peak_Memory_Actual Peak_Memory_Estimated Error (%)
Llama2 7B bf16 2 2048 8 1 1 69.1 68.8 0.4
Llama2 7B bf16 4 512 4 1 2 55.7 55.5 0.4
Llama2 7B bf16 2 2048 4 2 1 49.8 49.5 0.6
Llama2 7B bf16 4 128 1 1 8 28.8 28.6 0.7
In this table, MBS refers to micro batch size, GBS refers to global batch size, DP denotes data parallelism size, PP denotes pipeline parallelism size, and TP denotes tensor parallelism size.

Calculation Method Explanation

We analyze the memory requirements of the model parameters, gradients, and optimizer states and the communication behavior of different parallel dimensions based on Megatron(1, 2, and 3)

To estimate the memory requirements for the activation portion, given that Megatron supports FlashAttention and Fusion computations, we have adopted a distinctive approach. This method involves collecting the memory address and size information of the corresponding operations each time the cudaMalloc and cudaFree functions are executed, and then conducting line-by-line analysis of this information to derive a computational formula. To implement this method, we used the torch.cuda.CUDAPluggableAllocator to customize the memory allocator.

We will observe the changes in torch.cuda.max_memory_allocated during the model training process, then summarize these changes in order to estimate peak memory.

Limitations

  • Supported
    • GPTModel
    • Tensor parallelism, Pipeline parallelism, Data parallelism
    • Using --bf16, --fp16, --use-flash-attn, --use-distributed-optimizer, --swiglu
  • To be supported
    • Other Transformer-based models
    • Using --sequence-parallel, --num_layers_per_virtual_pipeline_stage, --recompute-activations
    • Enable --use-flash-attn, --use-distributed-optimizer, --swiglu--bf16

Usage

In the examples directory, we've provided scripts to get pretraining GPT information. Users can generate their scripts by using the following command:

sed 's%torchrun $DISTRIBUTED_ARGS pretrain_gpt.py%python ../get_training_info.py $DISTRIBUTED_ARGS %g' pretrain_gpt_distributed_with_mp.sh > get_pretrain_gpt_distributed_with_mp_info.sh

The function of this command is to replace "torchrun $DISTRIBUTED_ARGS pretrain_gpt.py" with "python ../get_training_info.py $DISTRIBUTED_ARGS" in the "pretrain_gpt_distributed_with_mp.sh" which is your script for launching the training.

Moreover, we've added the following training parameters:

  • --use-flash-attn
  • --use-distributed-optimizer
  • --swiglu
  • --bf16

Example of output

GPUS_PER_NODE=8
NNODES=2

GPT_ARGS="
    --tensor-model-parallel-size 2 \
    --pipeline-model-parallel-size 2 \
    --num-layers 24 \
    --hidden-size 4096 \
    --num-attention-heads 32 \
    --seq-length 2048 \
    --max-position-embeddings 2048 \
    --micro-batch-size 4 \
    --global-batch-size 512 \
    --lr 0.00015 \
    --train-iters 500000 \
    --lr-decay-iters 320000 \
    --lr-decay-style cosine \
    --min-lr 1.0e-5 \
    --weight-decay 1e-2 \
    --lr-warmup-fraction .01 \
    --clip-grad 1.0 \
    --use-flash-attn \
    --use-distributed-optimizer \
    --swiglu \
    --bf16
"

Assuming there are two nodes, each equipped with eight cards, and training a model according to the above configuration, the following output will be produced.

Full Model without Parallel

Full model information without parallel training enabled.

***Full Model without Parallel***
===========================================================================================================
Layer                                      Param.(shape)           Param.(Mem. MB)  Act.(Mem. MB)        
----------------------------------------------------------------------------------------------------------
GPTModel                                                         
├─TransformerLanguageModel                 
│    └─Embedding                                                    	               	96.0           	
│    │    └─word_embeddings                w=[50432,4096]           	394.0          	  
│    │    └─position_embeddings            w=[2048,4096]            	16.0           	
│    └─ParallelTransformer: X 32(layer_num)                                        	1320.0/layer   	
│    │    └─input_layernorm                w=[4096],b=[4096]        	0.0            	64.0           	
│    │    └─self_attention                                          	               	384.0          	
│    │    |     └─query_key_value          w=[12288,4096],b=[4096]  	96.0           	
│    │    |     └─rearrange                                         	               	192.0          	
│    │    |     └─core_attention_flash                              	               	64.0           	
│    │    |     └─rearrange                                         	               	64.0           	
│    │    |     └─dense                    w=[4096,4096],b=[4096]   	32.0           	64.0           	
│    │    └─post_attention_layernorm       w=[4096],b=[4096]        	0.0            	64.0           	
│    │    └─mlp                                                     	               	744.0          	
│    │    |     └─dense_h_to_4h            w=[21760,4096],b=[21760] 	170.0          	
│    │    |     └─bias_glue                                         	               	680.0          	
│    │    |     └─dense_4h_to_h            w=[4096,10880],b=[4096]  	85.0           	64.0           	
│    │    └─drop_add_fusion                                         	               	96.0           	
-----------------------------------------------------------------------------------------------------------
Amount of Parameters: 6,642,245,632  
Parameters: 12.4GB
Gradients: 24.7GB
Optimizers(Adam) States: 74.2GB
Activations: 44.8GB
Total memory demand: 156.2GB
==============================================================================================================

Cluster Communication Summary

Given the model and parallel configuration, the total communication count and volume for each Pipeline Parallel, Data Parallel, and Tensor Parallel dimension in a single iteration, as well as the total communication count and volume for the entire cluster in the final training iteration.

***Cluster Communication Summary***
==============================
Pipeline Parallelism
│    └─frequency/iteration: 2048
│    └─volume/iteration: 128.0 GB
Data Parallelism
│    └─frequency/iteration: 2
│    └─volume/iteration: 12.8 GB
Tensor Parallelism
│    └─frequency/iteration: 32768
│    └─volume/iteration: 2048.0 GB
All Communication
│    └─frequency/iteration: 34818
│    └─volume/iteration: 2188.8 GB
==============================

Memory demand on each GPU in the cluster

Given the model and parallel configuration, the memory requirements on each GPU in the cluster for training one iteration.

***Memory demand on each GPU in the cluster***
==============================
Amount of Parameters: 1,718,898,688  
Parameters: 3.2GB
Gradients: 6.4GB
Optimizers(Adam) States: 4.8GB
Activations: 25.8GB
Memory Requirement: 40.2GB
==============================

Pipeline Parallel Communication

***Pipeline Parallel Communications***
========================================================================================
GPTModel                                                         
├─TransformerLanguageModel                 
│    └─Embedding                                              
│    │    └─word_embeddings                
│    │    └─position_embeddings            
│    └─Stage0: ParallelTransformerLayer_Index0-15
│    │    └─stage_device_mappings
│    │    │      └─[n0_g0 n0_g1 n0_g2 n0_g3 n0_g4 n0_g5 n0_g6 n0_g7]
│    │    └─each single communication on each gpu
│    │    │    └─shape: [4,2048,4096]              
│    │    │    └─volume: 64.0MB         
│    │    │    └─func: isend, irecv
│    │    │    └─location: between stage in forward and backward process
│    │    └─each iteration communication on each gpu
│    │    │    └─frequency: 128 (num_gradient_accumulation_steps * 4)
│    │    │    └─volume: 8192.0MB       
│    └─Stage1: ParallelTransformerLayer_Index16-31
│    │    └─stage_device_mappings 
│    │    │      └─[n1_g0 n1_g1 n1_g2 n1_g3 n1_g4 n1_g5 n1_g6 n1_g7]

----------------------------------------------------------------------------------------
8 Pipeline Parallel Communication Groups:
│    └─[n0_g0 n1_g0]
│    └─[n0_g1 n1_g1]
│    └─[n0_g2 n1_g2]
│    └─[n0_g3 n1_g3]
│    └─[n0_g4 n1_g4]
│    └─[n0_g5 n1_g5]
│    └─[n0_g6 n1_g6]
│    └─[n0_g7 n1_g7]
All Communication of Cluster in Pipeline Parallelism
│    └─frequency/iteration: 2048
│    └─volume/iteration: 128.0GB
========================================================================================

Data Parallel Communications

***Data Parallel Communications***
========================================================================================
GPTModel                                                         
├─each iteration                
│    └─synchronize_gradient                                         
│    │    └─4 Data Parallel Groups 
│    │    │    └─[n0_g0 n0_g2 n0_g4 n0_g6]
│    │    │    └─[n0_g1 n0_g3 n0_g5 n0_g7]
│    │    │    └─[n1_g0 n1_g2 n1_g4 n1_g6]
│    │    │    └─[n1_g1 n1_g3 n1_g5 n1_g7]
│    │    └─communication 
│    │    │    └─volume: 6.4GB
│    │    │    └─func: reduce_scatter (using DistributedOptimizer) 
│    │    └─frequency/iteration: 1
│    │    └─location: after forward_and_backward_compute * 32 times/iteration 
│    └─gather_model_param (using DistributedOptimizer)                                          
│    │    └─4 Data Parallel Groups 
│    │    │    └─[n0_g0 n0_g2 n0_g4 n0_g6]
│    │    │    └─[n0_g1 n0_g3 n0_g5 n0_g7]
│    │    │    └─[n1_g0 n1_g2 n1_g4 n1_g6]
│    │    │    └─[n1_g1 n1_g3 n1_g5 n1_g7]
│    │    └─communication on each gpu
│    │    │    └─volume: 6.4GB
│    │    │    └─func: all_gather
│    │    └─frequency/iteration: 1
│    │    └─location: after optimizer.iteration
----------------------------------------------------------------------------------------
All Communication of Cluster in Data Parallelism
│    └─frequency/iteration: 2
│    └─volume/iteration: 12.8GB
========================================================================================

Tensor Parallel Communications

***Tensor Parallel Communications***
=================================================================================================================================================================================================================
Layer                                      Param(shape)           Param(Mem. MB)  Activations(Mem. MB)   TP_Fw.(Comm. Shape)  TP_Fw.(Comm. Mem. MB)   TP_Bw.(Comm. Shape)  TP_Bw.(Comm. Mem. MB)   TP(Comm. func)
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
GPTModel                                                         
├─TransformerLanguageModel                 
│    └─Embedding                                                    	               	96.0           	                    
│    │    └─word_embeddings                w=[25216],b=[4096]       	394.0          	               	[4,2048,4096]            	64.0           	                         	               	allreduce                	
│    │    └─position_embeddings            w=[2048],b=[4096]        	16.0           	
│    └─ParallelTransformer: X 16(layer_num)                                        	1320.0/layer   	
│    │    └─input_layernorm                w=[4096],b=[4096]        	0.0            	64.0           	
│    │    └─self_attention                                          	               	384.0          	
│    │    |     └─query_key_value          w=[6144,4096],b=[4096]   	48.0           	               	                         	               	[4,2048,4096]            	64.0           	allreduce                	
│    │    |     └─rearrange                                         	               	96.0           	
│    │    |     └─core_attention_flash                              	               	32.0           	
│    │    |     └─rearrange                                         	               	32.0           	
│    │    |     └─dense                    w=[2048,4096],b=[4096]   	16.0           	64.0           	[4,2048,4096]            	64.0           	                         	               	allreduce                	
│    │    └─post_attention_layernorm       w=[4096],b=[4096]        	0.0            	64.0           	
│    │    └─mlp                                                     	               	744.0          	
│    │    |     └─dense_h_to_4h            w=[10880,4096],b=[10880] 	85.0           	               	                         	               	[4,2048,4096]            	64.0           	allreduce                	
│    │    |     └─bias_glue                                         	               	680.0          	
│    │    |     └─dense_4h_to_h            w=[4096,5440],b=[4096]   	85.0           	64.0           	[4,2048,4096]            	64.0           	                         	               	allreduce                	
│    │    └─drop_add_fusion                                         	               	96.0           	
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
8 Tensor Parallel Communication Groups:
│    └─[n0_g0 n0_g1]
│    └─[n0_g2 n0_g3]
│    └─[n0_g4 n0_g5]
│    └─[n0_g6 n0_g7]
│    └─[n1_g0 n1_g1]
│    └─[n1_g2 n1_g3]
│    └─[n1_g4 n1_g5]
│    └─[n1_g6 n1_g7]
Communication in Tensor Parallel
│    └─each gpu:
│    │    └─each micro_batch:
│    │    │    └─frequency: 64
│    │    │    └─volume: 4.0GB
│    │    │    └─each transformer:
│    │    │    │    └─frequency: 2(forward)+2(backward)=4
│    │    │    │    └─volume: 0.25GB
│    │    └─each iteration:
│    │    │    └─frequency: 2048
│    │    │    └─volume: 128.0GB
│    └─cluster:
│    │    └─each micro_batch:
│    │    │    └─frequency: 1024
│    │    │    └─volume: 64.0GB
│    │    └─each iteration:
│    │    │    └─frequency: 32768
│    │    │    └─volume: 2048.0GB
=======================================================================================================================================================================================================================

@yxyOo yxyOo mentioned this pull request Sep 6, 2023
@deepakn94
Copy link
Collaborator

This looks interesting! How accurate is it?

@jindajia
Copy link

Really awesome!!

@yxyOo
Copy link
Author

yxyOo commented Oct 7, 2023

This looks interesting! How accurate is it?

We randomly selected several parallel configurations and conducted "Memory Requirement" tests on the 7B llama2 model using a single H800 machine with eight cards. The results showed that the error was within 1% for all measurements. All other values output by the tool were theoretical.

@zhipeng93
Copy link

+1 for Really awesome!

Copy link

@zhipeng93 zhipeng93 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR. It is really a nice feature. However, the Readme seems inconsistent with the implementation.

I left some comments below. Please take a look.

# Calculation Method Explanation
We analyze the memory requirements of the model parameters, gradients, and optimizer states and the communication behavior of different parallel dimensions based on Megatron([1](https://arxiv.org/pdf/1909.08053.pdf), [2](https://arxiv.org/pdf/2104.04473.pdf), and [3](https://arxiv.org/pdf/2205.05198))

To estimate the memory requirements for the activation portion, given that Megatron supports FlashAttention and Fusion computations, we have adopted a distinctive approach. This method involves collecting the memory address and size information of the corresponding operations each time the cudaMalloc and cudaFree functions are executed, and then conducting line-by-line analysis of this information to derive a computational formula. To implement this method, we used the [torch.cuda.CUDAPluggableAllocator](https://pytorch.org/docs/stable/notes/cuda.html#using-custom-memory-allocators-for-cuda) to customize the memory allocator.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR. Could you point out where did you use torch.cuda.CUDAPluggableAllocator to estimate the activation memory? I did not find it.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@zhipeng93 It is not here. To use it, one must write a shared lib to implement the interface and set it at the beginning the python pytorch program (using ctype to load it).

tp_comm_size = tp_comm_count * s * b * h

dp_comm_count = 0 if d == 1 else 2
dp_comm_size = total_parameters_per_gpu * 4 if args.bf16 else total_parameters_per_gpu * 2

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you explain why dp_comm_size = total_parameters_per_gpu * 4 for bf16 while *2 for fp16 and fp32?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

7 * h + 4 * h * h / t + 3 * f * h / t + 2 * f) * per_stage_layer_num
total_parameters_per_gpu_formatted = f'{int(total_parameters_per_gpu):,}'

activations = n * (10 * s * h * b +
Copy link

@zhipeng93 zhipeng93 Oct 12, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you explain the formula of computating activation memory so that we can understand the intuition behind this?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The explanation for this formula can be found in the "Calculation Method Explanation" section

@zhipeng93
Copy link

This looks interesting! How accurate is it?

We randomly selected several parallel configurations and conducted "Memory Requirement" tests on the 7B llama2 model using a single H800 machine with eight cards. The results showed that the error was within 1% for all measurements. All other values output by the tool were theoretical.

Is this verification based on the code base here or that used torch.cuda.CUDAPluggableAllocator? I am wondering whether we really need to execute the code to get the peak memory usage.

@Ethan-yt
Copy link

Ethan-yt commented Oct 18, 2023

Hi @yxyOo:
I have a few questions about total_parameters computing. Since you mentioned your experiments on llama, but I find some inconsistency:

  1. llama doesn't have bias
  2. llama doesn't use shared embeddings
  3. llama doesn't have position_embeddings

Thanks

@V3RGANz
Copy link

V3RGANz commented Oct 27, 2023

Hi, @yxyOo this is a great feature! While my suggestion might seem a bit much, I believe it would be beneficial to use the default argument parser from training. This way, you could simply replace the training executable name with this tool and receive an analysis without additional effort (Or even print it before training if you want).

Moreover, it's particularly handy for LLaMa checkpoints, as most arguments are read directly from the checkpoints (--use-checkpoint-args).

@deepakn94
Copy link
Collaborator

Hi, @yxyOo this is a great feature! While my suggestion might seem a bit much, I believe it would be beneficial to use the default argument parser from training. This way, you could simply replace the training executable name with this tool and receive an analysis without additional effort (Or even print it before training if you want).

Here is a simple training script that computes the "theoretical" memory usage of a model: https://github.com/NVIDIA/Megatron-LM/blob/main/compute_memory_usage.py. It re-uses the existing argument parser so we can easily do precisely what you ask for.

It is under active development and should get better in the coming days.

Copy link

@yiakwy-xpu-ml-framework-team yiakwy-xpu-ml-framework-team left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It has a nice representation !

As we have gone though the patch, I believe this theoretical estimation shares common pitfall as the megatron memory report developed by @deepakn94 and use the same approach of deepspeed/onnx flops profiler for dry-run estimation.

The two ranks of SXM (A100 compute ability) for GPT alike model shows the gaps between the estimated and one : 70 (estimated) vs (20 rank0, 50 rank 1):

Model Precision MBS GBS NODES GPU/WORKER DP PP TP Peak_Memory_Actual Peak_Memory_Estimated avg Error (%)
GPT alike (16 layers) bf16 2 2048 2 8 8 2 1 rank#0: 23.8, rank#1: 56.3 70 > 75%

Reproduce this with this estimator:

***Memory demand on each GPU in the cluster***
==============================
Amount of Parameters: 431,403,008  
Parameters: 0.8GB
Gradients: 1.6GB
Optimizers(Adam) States: 0.6GB
Activations: 72.7GB
Memory Requirement: 75.7GB
==============================

GAP analysis

Liveness of tensors

An activation can simulated with an array of liveness tensor :

using LivenessInfo  = map<Key, Val>
/* 
where 

Key : [start_step, end_step]
Val : bytes
*/

An unary accumulated op (+=) or binary op (+) can be defined over this liveness info.

How you define "start_step" and "end_step" is dependent on the compiler. It does not work if two activation are simply added together.

Hence for non-always live (activation), a special algorithm has been explored and to be patented for static graph almost 2 years ago for static graph.

For imperative graph in megatron, since pytorch cached memory allocator will not release memory as soon as the tensor's life end, liveness plays a great role.

This means the peak memory of pytorch must be slower than what flops profiler or this memory estimator, and that megatron memory reporter profile.

I have raised an ticket for this purpose, and hope this is useful for the community.

Ranks

PP is the outer most dimension of GPUs partition groups, DP and TP are inner dimensions. We observed that memory imbalance between ranks. Experts shared GPUs in DP group and more gather/broadcast needed.

Hence you cannot simply divide the total amount of parameters needed for communication to decide which and when a gpu goes out of memory.

----------------------------------------------------------------------------------------------------------
GPTModel
├─TransformerLanguageModel
│ └─Embedding {space()}\t{space(pad=15)}\t{memory_mega_bytes(1.5*s*b*h)}\t
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi I guess many of the parameters guess for weights are hard coded, why don't we make a small named function for that estimation ? (MS deepspeed flops profiler)

@yxyOo
Copy link
Author

yxyOo commented Dec 18, 2023

Hi @yxyOo: I have a few questions about total_parameters computing. Since you mentioned your experiments on llama, but I find some inconsistency:

  1. llama doesn't have bias
  2. llama doesn't use shared embeddings
  3. llama doesn't have position_embeddings

Thanks

Thank you for pointing out that it should be the GPT model.

@yxyOo
Copy link
Author

yxyOo commented Dec 18, 2023

This looks interesting! How accurate is it?

We randomly selected several parallel configurations and conducted "Memory Requirement" tests on the 7B llama2 model using a single H800 machine with eight cards. The results showed that the error was within 1% for all measurements. All other values output by the tool were theoretical.

Is this verification based on the code base here or that used torch.cuda.CUDAPluggableAllocator? I am wondering whether we really need to execute the code to get the peak memory usage.

It is based on the code base. Before training your model, you can use this tool to determine the minimum amount of memory the model will consume.

@yxyOo
Copy link
Author

yxyOo commented Dec 18, 2023

Hi, @yxyOo this is a great feature! While my suggestion might seem a bit much, I believe it would be beneficial to use the default argument parser from training. This way, you could simply replace the training executable name with this tool and receive an analysis without additional effort (Or even print it before training if you want).

Moreover, it's particularly handy for LLaMa checkpoints, as most arguments are read directly from the checkpoints (--use-checkpoint-args).

itiona

Thank you for your suggestion, I did it this way at the time to quickly develop this tool, haha. If needed, I will consider supporting related features in the future.

if args.bf16:
loss_logits_mem = 5 * s * b * v / t if p == 1 else 0
peak_mem = max(
memory_giga_bytes(total_parameters_per_gpu * (1 + 2 + 2 / d + 2)),
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey! Thanks for this very useful script, but I am struggling to understand where this term (total_parameters_per_gpu * (1 + 2 + 2 / d + 2))) for the peak memory comes from. Would it be possible to get more details?

activations_per_gpu + loss_logits_mem))
gradient = total_parameters
gradient_per_gpu = total_parameters_per_gpu
optimizer = total_parameters * 8
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why optimizer = total_parameters * 6? For each parameter in the model, AdamW keeps track of First Moment Vector and Second Moment Vector.As a result, the CUDA memory requirement for using the AdamW optimizer is approximately 2 times the memory required for the model parameters themselves.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the super fast reply! I meant line 228, about the factor (1 + 2 + 2 / d + 2). The first terms 1 and 2 are params in bfloat16 and gradients in fp32 but I cannot derive 2/d +2. Are those the optimizer states distributed in some way? Where does the second term 2 come from?

Copy link

Marking as stale. No activity in 60 days.

@github-actions github-actions bot added the stale No activity in 60 days on issue or PR label Mar 22, 2024
@yiakwy-xpu-ml-framework-team

Hi @yxyOo: I have a few questions about total_parameters computing. Since you mentioned your experiments on llama, but I find some inconsistency:

  1. llama doesn't have bias
  2. llama doesn't use shared embeddings
  3. llama doesn't have position_embeddings

Thanks

Thank you for pointing out that it should be the GPT model.

Note if flash attention used, memory cost is O(bhsd) not O(bhss*d).

@github-actions github-actions bot removed the stale No activity in 60 days on issue or PR label Mar 23, 2024
Copy link

Marking as stale. No activity in 60 days.

@github-actions github-actions bot added the stale No activity in 60 days on issue or PR label May 22, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
stale No activity in 60 days on issue or PR
Projects
None yet
Development

Successfully merging this pull request may close these issues.

9 participants