Refactoring Llama attention and MLP layers (#589)
bgoldberg-habana authored Dec 11, 2023
1 parent 4ffb9d7 commit afea217
Showing 6 changed files with 160 additions and 32 deletions.
15 changes: 15 additions & 0 deletions examples/text-generation/utils.py
@@ -120,6 +120,19 @@ def setup_device(args):
return torch.device(args.device)


# Patch DeepSpeed's LinearAllreduce modules to use ScopedLinearAllReduce
def patch_scoped_linear_all_reduce(model):
from deepspeed.module_inject.layers import LinearAllreduce

from optimum.habana.transformers.models.modeling_all_models import ScopedLinearAllReduce

for name, module in model.named_children():
if type(module) is LinearAllreduce:
SL = ScopedLinearAllReduce(mod=module)
setattr(model, name, SL)
patch_scoped_linear_all_reduce(module)


def setup_model(args, model_dtype, model_kwargs, logger):
logger.info("Single-device run.")

@@ -194,6 +207,8 @@ def setup_distributed_model(args, model_dtype, model_kwargs, logger):

model = deepspeed.init_inference(model, **ds_inference_kwargs)
model = model.module
if model.config.model_type == "llama":
patch_scoped_linear_all_reduce(model)
return model


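As a sanity check on the recursion above, a small helper (hypothetical, not part of this commit) can confirm that no plain LinearAllreduce modules survive the patch; note that the exact-type test "type(module) is LinearAllreduce" deliberately skips subclasses and layers that were already replaced:

# Hypothetical verification helper, assuming DeepSpeed is installed.
def assert_fully_patched(model):
    from deepspeed.module_inject.layers import LinearAllreduce

    for name, module in model.named_modules():
        # After patching, every LinearAllreduce should have been swapped
        # for a ScopedLinearAllReduce, so none should remain.
        assert type(module) is not LinearAllreduce, f"{name} was not patched"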
2 changes: 2 additions & 0 deletions optimum/habana/transformers/modeling_utils.py
@@ -32,6 +32,7 @@
GaudiLlamaAttention,
GaudiLlamaDecoderLayer,
GaudiLlamaForCausalLM,
GaudiLlamaMLP,
GaudiLlamaModel,
GaudiMistralForCausalLM,
GaudiMptForCausalLM,
@@ -240,6 +241,7 @@ def adapt_transformers_to_gaudi():
transformers.models.llama.modeling_llama.LlamaForCausalLM = GaudiLlamaForCausalLM
transformers.models.llama.modeling_llama.LlamaModel = GaudiLlamaModel
transformers.models.llama.modeling_llama.LlamaAttention = GaudiLlamaAttention
transformers.models.llama.modeling_llama.LlamaMLP = GaudiLlamaMLP
transformers.models.llama.modeling_llama.LlamaDecoderLayer = GaudiLlamaDecoderLayer

transformers.models.llama.modeling_llama.LlamaRMSNorm.forward = gaudi_llama_rmsnorm_forward
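Since adapt_transformers_to_gaudi swaps the classes on the transformers module itself, the swap only affects models instantiated afterwards; a minimal usage sketch (the checkpoint name is a placeholder):

from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
from transformers import AutoModelForCausalLM

adapt_transformers_to_gaudi()  # must run before the model is instantiated
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")  # placeholder checkpoint

# Decoder layers should now carry the Gaudi MLP with the split forward path.
assert type(model.model.layers[0].mlp).__name__ == "GaudiLlamaMLP"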
1 change: 1 addition & 0 deletions optimum/habana/transformers/models/__init__.py
@@ -62,6 +62,7 @@
GaudiLlamaAttention,
GaudiLlamaDecoderLayer,
GaudiLlamaForCausalLM,
GaudiLlamaMLP,
GaudiLlamaModel,
gaudi_llama_rmsnorm_forward,
)
1 change: 1 addition & 0 deletions optimum/habana/transformers/models/llama/__init__.py
@@ -2,6 +2,7 @@
GaudiLlamaAttention,
GaudiLlamaDecoderLayer,
GaudiLlamaForCausalLM,
GaudiLlamaMLP,
GaudiLlamaModel,
gaudi_llama_rmsnorm_forward,
)
151 changes: 119 additions & 32 deletions optimum/habana/transformers/models/llama/modeling_llama.py
@@ -9,11 +9,14 @@
LlamaAttention,
LlamaDecoderLayer,
LlamaForCausalLM,
LlamaMLP,
LlamaModel,
apply_rotary_pos_emb,
logger,
)

from ..modeling_all_models import ScopedLinearAllReduce


try:
from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE
@@ -77,10 +80,20 @@ def gaudi_llama_repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


class Matmul(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x, y):
return torch.matmul(x, y)


class GaudiLlamaAttention(LlamaAttention):
def __init__(self, config: LlamaConfig):
super().__init__(config)

self.matmul_qk = Matmul()
self.matmul_av = Matmul()
self.past_key = None
self.past_value = None
self.inp_seq_len = -1
@@ -126,7 +139,7 @@ def reorder_kv_cache(self, beam_idx: torch.LongTensor):
self.reorder(self.past_value, beam_idx, seq_length, head_dim)
return (self.past_key.shape, self.past_value.shape)

-def forward(
+def pre_attn_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
@@ -137,7 +150,7 @@ def forward(
token_idx: Optional[torch.Tensor] = None,
attn_softmax_bf16: Optional[bool] = False,
reuse_cache: Optional[bool] = False,
-) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+):
"""
Copied from LlamaAttention.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
The only differences are:
@@ -209,7 +222,7 @@ def forward(
key_states = gaudi_llama_repeat_kv(key_states, self.num_key_value_groups)
value_states = gaudi_llama_repeat_kv(value_states, self.num_key_value_groups)

-attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.norm_factor
+attn_weights = self.matmul_qk(query_states, key_states.transpose(2, 3)) * self.norm_factor

if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
@@ -232,7 +245,7 @@ def forward(
query_states.dtype
)

-attn_output = torch.matmul(attn_weights, value_states)
+attn_output = self.matmul_av(attn_weights, value_states)

if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
@@ -244,18 +257,57 @@ def forward(

attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

-if self.config.pretraining_tp > 1:
-    attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
-    o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
-    attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
-else:
-    attn_output = self.o_proj(attn_output)
+attn_output = self.o_proj(attn_output)

if not output_attentions:
attn_weights = None

return attn_output, attn_weights, past_key_value

def attention_all_reduce(self, attn_output):
if self.o_proj.__class__ is ScopedLinearAllReduce:
self.o_proj.all_reduce(attn_output)

def post_attn_forward(self, attn_output):
if self.o_proj.__class__ is ScopedLinearAllReduce:
self.o_proj.post_all_reduce(attn_output)
return attn_output


class GaudiLlamaMLP(LlamaMLP):
def pre_mlp_forward(self, x):
if self.config.pretraining_tp > 1:
slice = self.intermediate_size // self.config.pretraining_tp
gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
up_proj_slices = self.up_proj.weight.split(slice, dim=0)
down_proj_slices = self.down_proj.weight.split(slice, dim=1)

gate_proj = torch.cat(
[F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1
)
up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)

intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
down_proj = [
F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)
]
output = sum(down_proj)
else:
input = self.act_fn(self.gate_proj(x)) * self.up_proj(x)
output = self.down_proj(input)
return output

def mlp_all_reduce(self, x):
if self.down_proj.__class__ is ScopedLinearAllReduce:
self.down_proj.all_reduce(x)

def post_mlp_forward(self, x):
if self.config.pretraining_tp > 1:
return x
if self.down_proj.__class__ is ScopedLinearAllReduce:
return self.down_proj.post_all_reduce(x)
return x


class GaudiLlamaDecoderLayer(LlamaDecoderLayer):
def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len, kv_cache_fp8):
@@ -287,30 +339,23 @@ def forward(
- add new args reuse_cache
"""
residual = hidden_states

-hidden_states = self.input_layernorm(hidden_states)
-
-# Self Attention
-hidden_states, self_attn_weights, present_key_value = self.self_attn(
-    hidden_states=hidden_states,
-    attention_mask=attention_mask,
-    position_ids=position_ids,
-    past_key_value=past_key_value,
-    output_attentions=output_attentions,
-    use_cache=use_cache,
-    token_idx=token_idx,
-    attn_softmax_bf16=attn_softmax_bf16,
-    reuse_cache=reuse_cache,
+output_pre_attn, self_attn_weights, present_key_value = self.pre_attn(
+    hidden_states,
+    attention_mask,
+    position_ids,
+    past_key_value,
+    output_attentions,
+    use_cache,
+    token_idx,
+    attn_softmax_bf16,
+    reuse_cache,
)
-hidden_states = residual + hidden_states
+self.self_attn.attention_all_reduce(output_pre_attn)
+output_post_attn_pre_mlp, residual_mlp = self.post_attn_pre_mlp(output_pre_attn, residual)
+self.mlp.mlp_all_reduce(output_post_attn_pre_mlp)
+output_post_mlp = self.post_mlp(output_post_attn_pre_mlp, residual_mlp)

-# Fully Connected
-residual = hidden_states
-hidden_states = self.post_attention_layernorm(hidden_states)
-hidden_states = self.mlp(hidden_states)
-hidden_states = residual + hidden_states
-
-outputs = (hidden_states,)
+outputs = (output_post_mlp,)

if output_attentions:
outputs += (self_attn_weights,)
@@ -319,6 +364,48 @@ def forward(

return outputs

def pre_attn(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
token_idx: Optional[torch.Tensor] = None,
attn_softmax_bf16: Optional[bool] = False,
reuse_cache: Optional[bool] = False,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
hidden_states = self.input_layernorm(hidden_states)
output_attn, attn_weights, present_key_value = self.self_attn.pre_attn_forward(
hidden_states,
attention_mask,
position_ids,
past_key_value,
output_attentions,
use_cache,
token_idx,
attn_softmax_bf16,
reuse_cache,
)
return output_attn, attn_weights, present_key_value

def post_attn_pre_mlp(self, input, residual):
output_post_attn = self.self_attn.post_attn_forward(input)

hidden_states = residual + output_post_attn
residual = hidden_states

hidden_states = self.post_attention_layernorm(hidden_states)

hidden_states = self.mlp.pre_mlp_forward(hidden_states)
return hidden_states, residual

def post_mlp(self, input, residual):
output_post_mlp = self.mlp.post_mlp_forward(input)
output = output_post_mlp + residual
return output


class GaudiLlamaModel(LlamaModel):
def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len, kv_cache_fp8):
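Put together, the refactored decoder layer is the original computation re-ordered into five explicit stages, with each tensor-parallel collective hoisted out of the attention and MLP bodies (the Matmul wrapper modules above likely serve a related purpose, exposing the two attention matmuls as child modules that module-level tooling can target). Paraphrased as a sketch, not part of the commit:

# Sketch of the new per-layer pipeline; "layer" is a GaudiLlamaDecoderLayer
# and kwargs stands in for attention_mask, position_ids, token_idx, etc.
def decoder_layer_pipeline(layer, hidden_states, **kwargs):
    residual = hidden_states
    # 1) input layernorm + attention math + local o_proj matmul (no collective)
    attn_out, attn_weights, present_kv = layer.pre_attn(hidden_states, **kwargs)
    # 2) sum o_proj partial outputs across the tensor-parallel group
    layer.self_attn.attention_all_reduce(attn_out)
    # 3) o_proj bias + residual add + post-attention layernorm + local MLP matmuls
    mlp_out, mlp_residual = layer.post_attn_pre_mlp(attn_out, residual)
    # 4) sum down_proj partial outputs across the tensor-parallel group
    layer.mlp.mlp_all_reduce(mlp_out)
    # 5) down_proj bias + final residual add
    return layer.post_mlp(mlp_out, mlp_residual)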
22 changes: 22 additions & 0 deletions optimum/habana/transformers/models/modeling_all_models.py
@@ -109,3 +109,25 @@ def gaudi_conv1d_forward(self, x):
bias = self.bias.view(bias_shape)
x = x + bias
return x


# Split DeepSpeed's LinearAllreduce into three parts (matmul, all-reduce, bias add) to avoid redundant memory consumption
class ScopedLinearAllReduce(torch.nn.Module):
def __init__(self, mod, *args, **kwargs):
self.__dict__.update(mod.__dict__)

def forward(self, input):
# pre_all_reduce: compute only the local partial matmul; the collective is deferred to all_reduce()

output = torch.matmul(input, self.weight.transpose(-1, -2))
return output

def all_reduce(self, input):
if self.mp_group is not None:
from deepspeed import comm as dist

dist.inference_all_reduce(input, group=self.mp_group)

def post_all_reduce(self, input):
output = input + self.bias if (self.bias is not None) else input
return output

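In effect, one fused DeepSpeed layer becomes three explicit calls, which is what lets the attention and MLP code above schedule each collective separately from the math. A minimal usage sketch, assuming layer wraps an existing LinearAllreduce:

# Local partial matmul only: x @ W^T on this rank, no communication yet.
partial = layer(x)
# In-place collective: sum the partial results across layer.mp_group.
layer.all_reduce(partial)
# The bias is added after the reduction so it is applied exactly once,
# not once per tensor-parallel rank.
output = layer.post_all_reduce(partial)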