Refactoring Llama attention and MLP layers #589

Merged
merged 3 commits on Dec 11, 2023
Changes from 1 commit
15 changes: 15 additions & 0 deletions examples/text-generation/utils.py
@@ -120,6 +120,19 @@ def setup_device(args):
return torch.device(args.device)


# Patch DeepSpeed's LinearAllreduce modules to use ScopedLinearAllReduce
def patch_scoped_linear_all_reduce(model):
from deepspeed.module_inject.layers import LinearAllreduce

from optimum.habana.transformers.models.modeling_all_models import ScopedLinearAllReduce

for name, module in model.named_children():
if type(module) is LinearAllreduce:
SL = ScopedLinearAllReduce(mod=module)
setattr(model, name, SL)
patch_scoped_linear_all_reduce(module)


def setup_model(args, model_dtype, model_kwargs, logger):
logger.info("Single-device run.")

@@ -194,6 +207,8 @@ def setup_distributed_model(args, model_dtype, model_kwargs, logger):

model = deepspeed.init_inference(model, **ds_inference_kwargs)
model = model.module
if model.config.model_type == "llama":
patch_scoped_linear_all_reduce(model)
return model
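
A hypothetical sanity check (not part of this change) of what the patch accomplishes once DeepSpeed has injected its tensor-parallel layers; model is assumed to be the DeepSpeed-initialized module from setup_distributed_model above, and the imports mirror the ones used in the patch itself:

# After patching, no module in the tree should still be a raw LinearAllreduce;
# each one has been swapped in place for a ScopedLinearAllReduce that keeps its
# weight, bias, and mp_group.
from deepspeed.module_inject.layers import LinearAllreduce

from optimum.habana.transformers.models.modeling_all_models import ScopedLinearAllReduce

patch_scoped_linear_all_reduce(model)
assert not any(type(m) is LinearAllreduce for m in model.modules())
# The check below additionally assumes a multi-device run in which DeepSpeed
# actually injected LinearAllreduce layers.
assert any(isinstance(m, ScopedLinearAllReduce) for m in model.modules())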


2 changes: 2 additions & 0 deletions optimum/habana/transformers/modeling_utils.py
@@ -32,6 +32,7 @@
GaudiLlamaAttention,
GaudiLlamaDecoderLayer,
GaudiLlamaForCausalLM,
GaudiLlamaMLP,
GaudiLlamaModel,
GaudiMistralForCausalLM,
GaudiMptForCausalLM,
@@ -240,6 +241,7 @@ def adapt_transformers_to_gaudi():
transformers.models.llama.modeling_llama.LlamaForCausalLM = GaudiLlamaForCausalLM
transformers.models.llama.modeling_llama.LlamaModel = GaudiLlamaModel
transformers.models.llama.modeling_llama.LlamaAttention = GaudiLlamaAttention
transformers.models.llama.modeling_llama.LlamaMLP = GaudiLlamaMLP
transformers.models.llama.modeling_llama.LlamaDecoderLayer = GaudiLlamaDecoderLayer

transformers.models.llama.modeling_llama.LlamaRMSNorm.forward = gaudi_llama_rmsnorm_forward
1 change: 1 addition & 0 deletions optimum/habana/transformers/models/__init__.py
@@ -62,6 +62,7 @@
GaudiLlamaAttention,
GaudiLlamaDecoderLayer,
GaudiLlamaForCausalLM,
GaudiLlamaMLP,
GaudiLlamaModel,
gaudi_llama_rmsnorm_forward,
)
1 change: 1 addition & 0 deletions optimum/habana/transformers/models/llama/__init__.py
@@ -2,6 +2,7 @@
GaudiLlamaAttention,
GaudiLlamaDecoderLayer,
GaudiLlamaForCausalLM,
GaudiLlamaMLP,
GaudiLlamaModel,
gaudi_llama_rmsnorm_forward,
)
160 changes: 120 additions & 40 deletions optimum/habana/transformers/models/llama/modeling_llama.py
@@ -3,17 +3,21 @@

import torch
import torch.nn.functional as F
from torch import nn
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaDecoderLayer,
LlamaForCausalLM,
LlamaMLP,
LlamaModel,
apply_rotary_pos_emb,
logger,
)

from ..modeling_all_models import ScopedLinearAllReduce


try:
from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE
@@ -77,10 +81,20 @@ def gaudi_llama_repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


# nn.Module wrapper around torch.matmul; exposes the attention matmuls as named submodules
class Matmul(nn.Module):
def __init__(self):
super().__init__()

def forward(self, x, y):
return torch.matmul(x, y)


class GaudiLlamaAttention(LlamaAttention):
def __init__(self, config: LlamaConfig):
super().__init__(config)

self.matmul_qk = Matmul()
self.matmul_av = Matmul()
self.past_key = None
self.past_value = None
self.inp_seq_len = -1
@@ -126,7 +140,7 @@ def reorder_kv_cache(self, beam_idx: torch.LongTensor):
self.reorder(self.past_value, beam_idx, seq_length, head_dim)
return (self.past_key.shape, self.past_value.shape)

def forward(
def pre_attn_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
@@ -137,15 +151,7 @@ def forward(
token_idx: Optional[torch.Tensor] = None,
attn_softmax_bf16: Optional[bool] = False,
reuse_cache: Optional[bool] = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""
Copied from LlamaAttention.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
The only differences are:
- add new args token_idx
- optimize KV cache
- add new args attn_softmax_bf16
- add new args reuse_cache
"""
):
bsz, q_len, _ = hidden_states.size()

if self.config.pretraining_tp > 1:
@@ -209,7 +215,7 @@ def forward(
key_states = gaudi_llama_repeat_kv(key_states, self.num_key_value_groups)
value_states = gaudi_llama_repeat_kv(value_states, self.num_key_value_groups)

attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.norm_factor
attn_weights = self.matmul_qk(query_states, key_states.transpose(2, 3)) * self.norm_factor

if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
@@ -232,7 +238,7 @@ def forward(
query_states.dtype
)

attn_output = torch.matmul(attn_weights, value_states)
attn_output = self.matmul_av(attn_weights, value_states)

if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
@@ -244,18 +250,57 @@

attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

if self.config.pretraining_tp > 1:
attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
else:
attn_output = self.o_proj(attn_output)
attn_output = self.o_proj(attn_output)

if not output_attentions:
attn_weights = None

return attn_output, attn_weights, past_key_value

def attention_all_reduce(self, attn_output):
if self.o_proj.__class__ is ScopedLinearAllReduce:
self.o_proj.all_reduce(attn_output)

def post_attn_forward(self, attn_output):
if self.o_proj.__class__ is ScopedLinearAllReduce:
self.o_proj.post_all_reduce(attn_output)
return attn_output


class GaudiLlamaMLP(LlamaMLP):
def pre_mlp_forward(self, x):
if self.config.pretraining_tp > 1:
slice = self.intermediate_size // self.config.pretraining_tp
gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
up_proj_slices = self.up_proj.weight.split(slice, dim=0)
down_proj_slices = self.down_proj.weight.split(slice, dim=1)

gate_proj = torch.cat(
[F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1
)
up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)

intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
down_proj = [
F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)
]
output = sum(down_proj)
else:
input = self.act_fn(self.gate_proj(x)) * self.up_proj(x)
output = self.down_proj(input)
return output

def mlp_all_reduce(self, x):
if self.down_proj.__class__ is ScopedLinearAllReduce:
self.down_proj.all_reduce(x)

def post_mlp_forward(self, x):
if self.config.pretraining_tp > 1:
return x
if self.down_proj.__class__ is ScopedLinearAllReduce:
return self.down_proj.post_all_reduce(x)
return x


class GaudiLlamaDecoderLayer(LlamaDecoderLayer):
def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len, kv_cache_fp8):
@@ -287,30 +332,23 @@ def forward(
- add new args reuse_cache
"""
residual = hidden_states

hidden_states = self.input_layernorm(hidden_states)

# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
token_idx=token_idx,
attn_softmax_bf16=attn_softmax_bf16,
reuse_cache=reuse_cache,
output_pre_attn, self_attn_weights, present_key_value = self.pre_attn(
hidden_states,
attention_mask,
position_ids,
past_key_value,
output_attentions,
use_cache,
token_idx,
attn_softmax_bf16,
reuse_cache,
)
hidden_states = residual + hidden_states
self.self_attn.attention_all_reduce(output_pre_attn)
output_post_attn_pre_mlp, residual_mlp = self.post_attn_pre_mlp(output_pre_attn, residual)
self.mlp.mlp_all_reduce(output_post_attn_pre_mlp)
output_post_mlp = self.post_mlp(output_post_attn_pre_mlp, residual_mlp)

# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states

outputs = (hidden_states,)
outputs = (output_post_mlp,)

if output_attentions:
outputs += (self_attn_weights,)
@@ -319,6 +357,48 @@

return outputs

def pre_attn(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
token_idx: Optional[torch.Tensor] = None,
attn_softmax_bf16: Optional[bool] = False,
reuse_cache: Optional[bool] = False,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
hidden_states = self.input_layernorm(hidden_states)
output_attn, attn_weights, present_key_value = self.self_attn.pre_attn_forward(
hidden_states,
attention_mask,
position_ids,
past_key_value,
output_attentions,
use_cache,
token_idx,
attn_softmax_bf16,
reuse_cache,
)
return output_attn, attn_weights, present_key_value

def post_attn_pre_mlp(self, input, residual):
output_post_attn = self.self_attn.post_attn_forward(input)

hidden_states = residual + output_post_attn
residual = hidden_states

hidden_states = self.post_attention_layernorm(hidden_states)

hidden_states = self.mlp.pre_mlp_forward(hidden_states)
return hidden_states, residual

def post_mlp(self, input, residual):
output_post_mlp = self.mlp.post_mlp_forward(input)
output = output_post_mlp + residual
return output
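
For readers of this diff, a minimal sketch (illustrative only, not part of the change) of how the new pieces compose inside the refactored GaudiLlamaDecoderLayer.forward above; the method names come from this file, while decoder_layer_flow and pre_attn_kwargs are placeholder names:

def decoder_layer_flow(layer, hidden_states, **pre_attn_kwargs):
    residual = hidden_states
    # 1. input layernorm + self-attention through o_proj's local matmul
    attn_out, attn_weights, present_kv = layer.pre_attn(hidden_states, **pre_attn_kwargs)
    # 2. all-reduce the attention output across the tensor-parallel group (no-op when o_proj is not a ScopedLinearAllReduce)
    layer.self_attn.attention_all_reduce(attn_out)
    # 3. deferred o_proj bias (if any), residual add, post-attention layernorm, MLP through down_proj's local matmul
    mlp_out, mlp_residual = layer.post_attn_pre_mlp(attn_out, residual)
    # 4. all-reduce the MLP output (again a no-op when down_proj is not sharded)
    layer.mlp.mlp_all_reduce(mlp_out)
    # 5. deferred down_proj bias + final residual add
    return layer.post_mlp(mlp_out, mlp_residual), attn_weights, present_kv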


class GaudiLlamaModel(LlamaModel):
def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len, kv_cache_fp8):
22 changes: 22 additions & 0 deletions optimum/habana/transformers/models/modeling_all_models.py
@@ -109,3 +109,25 @@ def gaudi_conv1d_forward(self, x):
bias = self.bias.view(bias_shape)
x = x + bias
return x


# Split DeepSpeed's LinearAllreduce into three parts (local matmul, all-reduce, bias add) to avoid redundant memory consumption
class ScopedLinearAllReduce(torch.nn.Module):
def __init__(self, mod, *args, **kwargs):
self.__dict__.update(mod.__dict__)

def forward(self, input):
# pre_all_reduce

output = torch.matmul(input, self.weight.transpose(-1, -2))
return output

def all_reduce(self, input):
if self.mp_group is not None:
from deepspeed import comm as dist

dist.inference_all_reduce(input, group=self.mp_group)

def post_all_reduce(self, input):
output = input + self.bias if (self.bias is not None) else input
return output
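
The three methods are meant to be called in sequence by the attention and MLP wrappers above; a minimal sketch of that composition, assuming layer is a ScopedLinearAllReduce built from an injected LinearAllreduce (scoped_linear_reference is an illustrative helper, not part of this change):

def scoped_linear_reference(layer, hidden_states):
    # Phase 1: local matmul against the (possibly sharded) weight, no bias, no communication
    partial = layer(hidden_states)
    # Phase 2: in-place all-reduce across layer.mp_group, if one was set up
    layer.all_reduce(partial)
    # Phase 3: add the bias (if any) only after the reduction
    return layer.post_all_reduce(partial)

Keeping the matmul, the collective, and the bias add as separate methods is what lets pre_attn_forward / attention_all_reduce / post_attn_forward and pre_mlp_forward / mlp_all_reduce / post_mlp_forward above run each stage at the point in the decoder layer where it is needed.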