Remove workarounds to have causal_mask in uint8 for GPT2, GPT-J and CodeGen #592

Merged 2 commits on Dec 12, 2023
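Before the per-file diffs, here is a minimal, illustrative sketch (not taken from the PR itself) of what the removed workaround did versus what the upstream modules now provide: the causal mask used to be registered as uint8 and cast to bool at use time, whereas the upstream buffers are already boolean, so both the custom registration and the cast can go.

# Illustrative sketch only (not part of the diff): contrast between the removed
# uint8 workaround and the boolean buffer that upstream Transformers registers.
import torch

max_positions = 8
query_length, key_length = 2, 5

# Old workaround: causal mask registered as uint8, cast to bool at use time.
mask_uint8 = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
    1, 1, max_positions, max_positions
)
causal_mask_old = mask_uint8[:, :, key_length - query_length : key_length, :key_length].bool()

# New behaviour: rely on the boolean buffer directly, no cast needed.
mask_bool = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
    1, 1, max_positions, max_positions
)
causal_mask_new = mask_bool[:, :, key_length - query_length : key_length, :key_length]

assert torch.equal(causal_mask_old, causal_mask_new)  # identical masks either way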
8 changes: 2 additions & 6 deletions optimum/habana/transformers/modeling_utils.py
@@ -179,9 +179,8 @@ def adapt_transformers_to_gaudi():
transformers.models.bart.modeling_bart.BartForConditionalGeneration.prepare_inputs_for_generation = (
gaudi_BartForConditionalGeneration_prepare_inputs_for_generation
)

# Optimization for codegen generation on Gaudi
# The bias in the CodeGenAttention layer is a Boolean
# Since HCCL cannot handle this dtype, we revert it back to uint8
transformers.models.codegen.modeling_codegen.CodeGenAttention = GaudiCodeGenAttention
transformers.models.codegen.modeling_codegen.CodeGenForCausalLM = GaudiCodeGenForCausalLM
transformers.models.codegen.modeling_codegen.CodeGenModel.forward = gaudi_codegen_model_forward
@@ -194,8 +193,7 @@ def adapt_transformers_to_gaudi():
# AlbertModel.forward does not rely on get_extended_attention_mask so it also needs to be replaced
transformers.models.albert.modeling_albert.AlbertModel.forward = gaudi_albert_forward

# From Transformers 4.27, the bias in the GPT2Attention layer is a Boolean
# Since HCCL cannot handle this dtype, we revert it back to uint8 (same behaviour as Transformers <= 4.26)
# Optimization for GPT2 on Gaudi
transformers.models.gpt2.modeling_gpt2.GPT2Attention = GaudiGPT2Attention
transformers.models.gpt2.modeling_gpt2.GPT2Model.forward = gaudi_gpt2_forward
transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel = GaudiGPT2LMHeadModel
@@ -216,8 +214,6 @@ def adapt_transformers_to_gaudi():
transformers.models.opt.modeling_opt.OPTLearnedPositionalEmbedding = GaudiOPTLearnedPositionalEmbedding

# Optimization for GPTJ on Gaudi
# The bias in the GPTJAttention layer is a Boolean
# Since HCCL cannot handle this dtype, we revert it back to uint8 (same behaviour as Transformers <= 4.26)
transformers.models.gptj.modeling_gptj.GPTJAttention = GaudiGPTJAttention
transformers.models.gptj.modeling_gptj.GPTJForCausalLM = GaudiGPTJForCausalLM
transformers.models.gptj.modeling_gptj.GPTJBlock.forward = gaudi_gptj_block_forward
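The hunks above only trim the now-outdated comments; the attribute-replacement pattern itself is unchanged. For readers unfamiliar with it, a hedged sketch of how this kind of patching works is below (the names prefixed with `My`/`my_` are hypothetical stand-ins, not real optimum-habana symbols).

# Hedged sketch of the patching pattern used by adapt_transformers_to_gaudi():
# module-level classes and methods in `transformers` are swapped for Gaudi-aware
# versions before any model is instantiated.
import transformers

_original_gpt2_forward = transformers.models.gpt2.modeling_gpt2.GPT2Model.forward


class MyGaudiGPT2Attention(transformers.models.gpt2.modeling_gpt2.GPT2Attention):
    # A real replacement would override _attn/forward with Gaudi-friendly logic.
    pass


def my_gaudi_gpt2_forward(self, *args, **kwargs):
    # A real replacement would add Gaudi-specific handling around this call.
    return _original_gpt2_forward(self, *args, **kwargs)


# New GPT2 blocks will now build MyGaudiGPT2Attention instead of GPT2Attention...
transformers.models.gpt2.modeling_gpt2.GPT2Attention = MyGaudiGPT2Attention
# ...and GPT2Model instances will run the patched forward.
transformers.models.gpt2.modeling_gpt2.GPT2Model.forward = my_gaudi_gpt2_forward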
95 changes: 2 additions & 93 deletions optimum/habana/transformers/models/codegen/modeling_codegen.py
@@ -2,108 +2,17 @@

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.models.codegen.modeling_codegen import (
CodeGenAttention,
CodeGenForCausalLM,
apply_rotary_pos_emb,
create_sinusoidal_positions,
logger,
)


class GaudiCodeGenAttention(nn.Module):
def __init__(self, config):
super().__init__()

max_positions = config.max_position_embeddings
self.register_buffer(
"causal_mask",
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
1, 1, max_positions, max_positions
),
persistent=False,
)

self.attn_dropout = nn.Dropout(config.attn_pdrop)
self.resid_dropout = nn.Dropout(config.resid_pdrop)

self.embed_dim = config.hidden_size
self.num_attention_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_attention_heads
if self.head_dim * self.num_attention_heads != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_attention_heads (got `embed_dim`: {self.embed_dim} and"
f" `num_attention_heads`: {self.num_attention_heads})."
)
self.scale_attn = torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype())
self.qkv_proj = nn.Linear(self.embed_dim, self.embed_dim * 3, bias=False)

self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
self.rotary_dim = config.rotary_dim
pos_embd_dim = self.rotary_dim or self.embed_dim
self.embed_positions = create_sinusoidal_positions(max_positions, pos_embd_dim)

def _split_heads(self, x, n_head, dim_head, mp_num):
reshaped = x.reshape(x.shape[:-1] + (n_head // mp_num, dim_head))
reshaped = reshaped.reshape(x.shape[:-2] + (-1,) + reshaped.shape[-1:])
return reshaped

def _merge_heads(self, tensor, num_attention_heads, attn_head_size):
"""
Merges attn_head_size dim and num_attn_heads dim into n_ctx
"""
if len(tensor.shape) == 5:
tensor = tensor.permute(0, 1, 3, 2, 4).contiguous()
elif len(tensor.shape) == 4:
tensor = tensor.permute(0, 2, 1, 3).contiguous()
else:
raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}")
new_shape = tensor.size()[:-2] + (num_attention_heads * attn_head_size,)
return tensor.view(new_shape)

def _attn(
self,
query,
key,
value,
attention_mask=None,
head_mask=None,
):
# compute causal mask from causal mask buffer
query_length, key_length = query.size(-2), key.size(-2)
causal_mask = self.causal_mask[:, :, key_length - query_length : key_length, :key_length].bool()

# Keep the attention weights computation in fp32 to avoid overflow issues
query = query.to(torch.float32)
key = key.to(torch.float32)

attn_weights = torch.matmul(query, key.transpose(-1, -2))

attn_weights = attn_weights / self.scale_attn
mask_value = torch.finfo(attn_weights.dtype).min
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
attn_weights = torch.where(causal_mask, attn_weights, mask_value)

if attention_mask is not None:
# Apply the attention mask
attn_weights = attn_weights + attention_mask

attn_weights = nn.Softmax(dim=-1)(attn_weights)
attn_weights = attn_weights.to(value.dtype)
attn_weights = self.attn_dropout(attn_weights)

# Mask heads if we want to
if head_mask is not None:
attn_weights = attn_weights * head_mask

attn_output = torch.matmul(attn_weights, value)

return attn_output, attn_weights

class GaudiCodeGenAttention(CodeGenAttention):
def forward(
self,
hidden_states: Optional[torch.FloatTensor],
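One detail worth calling out from the `_attn` code kept by this file: the slice of the causal-mask buffer selects only the rows for the current query positions, which is what makes incremental decoding with a KV cache work. A small standalone illustration (assuming a boolean lower-triangular buffer, as upstream registers) is below.

# Standalone illustration of the causal-mask slice used in _attn (not from the diff).
import torch

max_positions = 6
bias = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
    1, 1, max_positions, max_positions
)

# Incremental decoding step: one new query token attending over three cached keys plus itself.
query_length, key_length = 1, 4
causal_mask = bias[:, :, key_length - query_length : key_length, :key_length]
print(causal_mask.shape)  # torch.Size([1, 1, 1, 4]); all True, the new token sees every past key

# Masked positions are filled with the dtype's minimum before the softmax.
attn_weights = torch.randn(1, 1, query_length, key_length)
mask_value = torch.tensor(torch.finfo(attn_weights.dtype).min, dtype=attn_weights.dtype)
attn_weights = torch.where(causal_mask, attn_weights, mask_value)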
85 changes: 4 additions & 81 deletions optimum/habana/transformers/models/gpt2/modeling_gpt2.py
@@ -3,77 +3,16 @@
import torch
from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel, logger
from transformers.pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2LMHeadModel, logger


class GaudiGPT2Attention(torch.nn.Module):
class GaudiGPT2Attention(GPT2Attention):
"""
Copied from GPT2Attention: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py
The only differences are:
- `self.bias` is a torch.uint8 and not a torch.bool
- it is casted to bool before being used in torch.where
- optimize KV cache
"""

def __init__(self, config, is_cross_attention=False, layer_idx=None):
super().__init__()

max_positions = config.max_position_embeddings
self.register_buffer(
"bias",
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
1, 1, max_positions, max_positions
),
persistent=False,
)
self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False)

self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
self.split_size = self.embed_dim
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
f" {self.num_heads})."
)

self.scale_attn_weights = config.scale_attn_weights
self.is_cross_attention = is_cross_attention

# Layer-wise attention scaling, reordering, and upcasting
self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx
self.layer_idx = layer_idx
self.reorder_and_upcast_attn = config.reorder_and_upcast_attn

if self.is_cross_attention:
self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim)
self.q_attn = Conv1D(self.embed_dim, self.embed_dim)
else:
self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
self.c_proj = Conv1D(self.embed_dim, self.embed_dim)

self.attn_dropout = torch.nn.Dropout(config.attn_pdrop)
self.resid_dropout = torch.nn.Dropout(config.resid_pdrop)

self.pruned_heads = set()

def prune_heads(self, heads):
if len(heads) == 0:
return
heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, self.head_dim, self.pruned_heads)
index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])

# Prune conv1d layers
self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)

# Update hyper params
self.split_size = (self.split_size // self.num_heads) * (self.num_heads - len(heads))
self.num_heads = self.num_heads - len(heads)
self.pruned_heads = self.pruned_heads.union(heads)

def _attn(self, query, key, value, attention_mask=None, head_mask=None):
key = key.contiguous()
value = value.contiguous()
@@ -91,7 +30,7 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
if not self.is_cross_attention:
# if only "normal" attention layer implements causal mask
query_length, key_length = query.size(-2), key.size(-2)
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
mask_value = torch.finfo(attn_weights.dtype).min
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
@@ -141,7 +80,7 @@ def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None):
if not self.is_cross_attention:
# if only "normal" attention layer implements causal mask
query_length, key_length = query.size(-2), key.size(-2)
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
mask_value = torch.finfo(attn_weights.dtype).min
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
@@ -168,22 +107,6 @@ def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None):

return attn_output, attn_weights

def _split_heads(self, tensor, num_heads, attn_head_size):
"""
Splits hidden_size dim into attn_head_size and num_heads
"""
new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
tensor = tensor.view(new_shape)
return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)

def _merge_heads(self, tensor, num_heads, attn_head_size):
"""
Merges attn_head_size dim and num_attn_heads dim into hidden_size
"""
tensor = tensor.permute(0, 2, 1, 3).contiguous()
new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
return tensor.view(new_shape)

def forward(
self,
hidden_states: Optional[Tuple[torch.FloatTensor]],
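The dropped `.bool()` casts in the two hunks above rely on the upstream buffer already being boolean. A quick, hedged sanity check (assuming a recent transformers release, 4.27 or later per the comment this PR removes):

# Quick check that GPT2Attention's `bias` buffer is already torch.bool upstream,
# which is why the explicit .bool() cast can be dropped. Assumes transformers >= 4.27.
from transformers import GPT2Config
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention

attn = GPT2Attention(GPT2Config())
print(attn.bias.dtype)  # expected: torch.bool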
64 changes: 3 additions & 61 deletions optimum/habana/transformers/models/gptj/modeling_gptj.py
Expand Up @@ -5,73 +5,15 @@
from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.models.gptj.modeling_gptj import (
GPTJAttention,
GPTJForCausalLM,
apply_rotary_pos_emb,
create_sinusoidal_positions,
logger,
)


class GaudiGPTJAttention(nn.Module):
def __init__(self, config):
super().__init__()

max_positions = config.max_position_embeddings
self.register_buffer(
"bias",
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
1, 1, max_positions, max_positions
),
)
self.register_buffer("masked_bias", torch.tensor(-1e9))

self.attn_dropout = nn.Dropout(config.attn_pdrop)
self.resid_dropout = nn.Dropout(config.resid_pdrop)

self.embed_dim = config.hidden_size
self.num_attention_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_attention_heads
if self.head_dim * self.num_attention_heads != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_attention_heads (got `embed_dim`: {self.embed_dim} and"
f" `num_attention_heads`: {self.num_attention_heads})."
)
self.scale_attn = torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype())

self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
self.rotary_dim = config.rotary_dim

def _split_heads(self, tensor, num_attention_heads, attn_head_size, rotary):
"""
Splits hidden dim into attn_head_size and num_attention_heads
"""
new_shape = tensor.size()[:-1] + (num_attention_heads, attn_head_size)
tensor = tensor.view(new_shape)
if rotary:
return tensor
if len(tensor.shape) == 5:
return tensor.permute(0, 1, 3, 2, 4) # (batch, blocks, head, block_length, head_features)
elif len(tensor.shape) == 4:
return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
else:
raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}")

def _merge_heads(self, tensor, num_attention_heads, attn_head_size):
"""
Merges attn_head_size dim and num_attn_heads dim into hidden dim
"""
if len(tensor.shape) == 5:
tensor = tensor.permute(0, 1, 3, 2, 4).contiguous()
elif len(tensor.shape) == 4:
tensor = tensor.permute(0, 2, 1, 3).contiguous()
else:
raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}")
new_shape = tensor.size()[:-2] + (num_attention_heads * attn_head_size,)
return tensor.view(new_shape)

class GaudiGPTJAttention(GPTJAttention):
def _attn(
self,
query,
@@ -82,7 +24,7 @@ def _attn(
):
# compute causal mask from causal mask buffer
query_length, key_length = query.size(-2), key.size(-2)
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]

query = query.contiguous()
key = key.contiguous()
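For completeness, a hedged usage sketch: none of these patches take effect until adapt_transformers_to_gaudi() has run, so it must be called before models are built (the higher-level optimum-habana entry points normally do this for you).

# Hedged usage sketch: apply the patches before instantiating a model.
from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
from transformers import AutoModelForCausalLM

adapt_transformers_to_gaudi()  # swaps in the Gaudi classes/methods shown in this PR
model = AutoModelForCausalLM.from_pretrained("gpt2")  # now built with the patched attention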