Gemma (unslothai#843)
* bugs

* Update _utils.py

* flash-attn softcapping

* Update gemma2.py

* Update gemma2.py

* Update gemma2.py

* Update gemma2.py

* Update mapper.py

* Update README.md

* Update _utils.py
danielhanchen authored Jul 31, 2024
1 parent 4285d1b commit b85670d
Showing 7 changed files with 128 additions and 46 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -39,6 +39,7 @@ All notebooks are **beginner friendly**! Add your dataset, click "Run All", and
- Click [here](https://github.com/unslothai/unsloth/wiki) for detailed documentation for Unsloth.

## 🦥 Unsloth.ai News
- 📣 NEW! [Gemma-2-2b](https://colab.research.google.com/drive/1weTpKOjBZxZJ5PQ-Ql8i6ptAY2x-FWVA?usp=sharing) now supported! Gemma-2-9b and Gemma-2-27b are already supported!
- 📣 NEW! [Llama 3.1 8b, 70b](https://colab.research.google.com/drive/1Ys44kVvmeZtnICzWz0xgpRnrIOjZAuxp?usp=sharing) both Base and Instruct now supported
- 📣 NEW! [Mistral Nemo-12b](https://colab.research.google.com/drive/17d3U-CAIwzmbDRqbZ9NnpHxCkmXB6LZ0?usp=sharing) both Base and Instruct now supported
- 📣 NEW! [Gemma-2-9b](https://colab.research.google.com/drive/1vIrqH5uYDQwsJ4-OO3DErvuv4pBgVwk4?usp=sharing) and Gemma-2-27b now supported
22 changes: 11 additions & 11 deletions pyproject.toml
@@ -171,7 +171,7 @@ colab-ampere-torch211 = [
"unsloth[cu121onlytorch211]",
"packaging",
"ninja",
"flash-attn",
"flash-attn>=2.6.3",
]
colab-torch220 = [
"unsloth[huggingface]",
@@ -184,7 +184,7 @@ colab-ampere-torch220 = [
"unsloth[cu121onlytorch220]",
"packaging",
"ninja",
"flash-attn",
"flash-attn>=2.6.3",
]
colab-new = [
"packaging",
@@ -215,71 +215,71 @@ colab-ampere = [
"unsloth[colab-ampere-torch220]",
"packaging",
"ninja",
"flash-attn",
"flash-attn>=2.6.3",
]
cu118-ampere = [
"unsloth[huggingface]",
"bitsandbytes",
"unsloth[cu118only]",
"packaging",
"ninja",
"flash-attn",
"flash-attn>=2.6.3",
]
cu121-ampere = [
"unsloth[huggingface]",
"bitsandbytes",
"unsloth[cu121only]",
"packaging",
"ninja",
"flash-attn",
"flash-attn>=2.6.3",
]
cu118-ampere-torch211 = [
"unsloth[huggingface]",
"bitsandbytes",
"unsloth[cu118onlytorch211]",
"packaging",
"ninja",
"flash-attn",
"flash-attn>=2.6.3",
]
cu121-ampere-torch211 = [
"unsloth[huggingface]",
"bitsandbytes",
"unsloth[cu121onlytorch211]",
"packaging",
"ninja",
"flash-attn",
"flash-attn>=2.6.3",
]
cu118-ampere-torch220 = [
"unsloth[huggingface]",
"bitsandbytes",
"unsloth[cu118onlytorch220]",
"packaging",
"ninja",
"flash-attn",
"flash-attn>=2.6.3",
]
cu121-ampere-torch220 = [
"unsloth[huggingface]",
"bitsandbytes",
"unsloth[cu121onlytorch220]",
"packaging",
"ninja",
"flash-attn",
"flash-attn>=2.6.3",
]
cu118-ampere-torch230 = [
"unsloth[huggingface]",
"bitsandbytes",
"unsloth[cu118onlytorch230]",
"packaging",
"ninja",
"flash-attn",
"flash-attn>=2.6.3",
]
cu121-ampere-torch230 = [
"unsloth[huggingface]",
"bitsandbytes",
"unsloth[cu121onlytorch230]",
"packaging",
"ninja",
"flash-attn",
"flash-attn>=2.6.3",
]

[project.urls]
14 changes: 14 additions & 0 deletions unsloth/models/_utils.py
@@ -21,6 +21,7 @@
"xformers_version",
"__version__",
"HAS_FLASH_ATTENTION",
"HAS_FLASH_ATTENTION_SOFTCAPPING",
"PRE_CHECK",
"platform_system",
"patch_tokenizer",
@@ -140,6 +141,8 @@ def patch_mistral_nemo_config(config):

major_version, minor_version = torch.cuda.get_device_capability()
SUPPORTS_BFLOAT16 = False
HAS_FLASH_ATTENTION = False
HAS_FLASH_ATTENTION_SOFTCAPPING = False

if major_version >= 8:
SUPPORTS_BFLOAT16 = True
@@ -148,6 +151,17 @@ def patch_mistral_nemo_config(config):
try:
from flash_attn.flash_attn_interface import flash_attn_cuda
HAS_FLASH_ATTENTION = True

# Also check for softcapping
from flash_attn import __version__ as flash_attn_version
HAS_FLASH_ATTENTION_SOFTCAPPING = Version(flash_attn_version) >= Version("2.6.3")
if not HAS_FLASH_ATTENTION_SOFTCAPPING:
print(
"Unsloth: If you want to finetune Gemma 2, upgrade flash-attn to version 2.6.3 or higher!\n"\
"Newer versions support faster and less memory usage kernels for Gemma 2's attention softcapping!\n"\
"To update flash-attn, do the below:\n"\
'\npip install --no-deps --upgrade "flash-attn>=2.6.3"'
)
except:
print(
"Unsloth: Your Flash Attention 2 installation seems to be broken?\n"\
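
For context, a minimal, hedged sketch (not part of this commit) of the logit-softcapping operation the new HAS_FLASH_ATTENTION_SOFTCAPPING flag gates; the cap of 50.0 is an illustrative Gemma-2-style attn_logit_softcapping value. flash-attn >= 2.6.3 fuses this into its kernel, while the slow path in gemma2.py applies it manually.

```python
import torch

def softcap(scores: torch.Tensor, cap: float) -> torch.Tensor:
    # Squash attention logits into (-cap, cap): cap * tanh(scores / cap).
    # Mirrors the slow path in gemma2.py (A *= 1/t; tanh(A); A *= t);
    # flash-attn >= 2.6.3 performs the same operation inside its fused kernel.
    return cap * torch.tanh(scores / cap)

scores = torch.tensor([-120.0, -5.0, 0.0, 5.0, 120.0])
print(softcap(scores, cap=50.0))  # extreme logits saturate near +/-50
```
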
73 changes: 55 additions & 18 deletions unsloth/models/gemma2.py
@@ -56,6 +56,8 @@
Gemma2FlashAttention2 = Gemma2Attention
pass

if HAS_FLASH_ATTENTION_SOFTCAPPING:
from flash_attn import flash_attn_func

# [TODO] We must randomly use torch.compile?
# I checked the gradients and formulas and I'm sure it's correct.
@@ -126,8 +128,36 @@ def Gemma2Attention_fast_forward(
V = torch.cat([past_key_value[1], V], dim = 2)
pass
past_key_value = (K, V) if use_cache else None

A = slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, kv_seq_len)

    # Only use the sliding window when causal_mask is the boolean True
has_sliding_window = type(causal_mask) is bool and causal_mask is True
if HAS_FLASH_ATTENTION_SOFTCAPPING and attention_mask is None:
window = (-1, -1)
if has_sliding_window:
sw = getattr(self.config, "sliding_window", None)
sw = kv_seq_len if (sw is None or sw == "null") else sw
window = (-1, -1) if (kv_seq_len <= sw) else (sw, sw)
pass

        # FA expects softmax_scale = 1/sqrt(query_pre_attn_scalar), not the default 1/sqrt(head_dim)!
if not hasattr(self, "_flash_attention_softmax_scale"):
self._flash_attention_softmax_scale = 1.0 / (self.config.query_pre_attn_scalar**0.5)
pass

Q = Q.transpose(1, 2)
K = K.transpose(1, 2)
V = V.transpose(1, 2)
A = flash_attn_func(
Q, K, V,
causal = True,
softcap = self.config.attn_logit_softcapping,
softmax_scale = self._flash_attention_softmax_scale,
window_size = window,
)
A = A.reshape(bsz, q_len, n_heads*head_dim)
else:
A = slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, kv_seq_len)
pass
A = self.apply_o(self, A)
return A, None, past_key_value
pass
@@ -205,6 +235,8 @@ def Gemma2DecoderLayer_fast_forward(
from math import sqrt as math_sqrt
KV_CACHE_INCREMENT = 256 # KV Cache update size
torch_nn_functional_softmax = torch.nn.functional.softmax
torch_matmul = torch.matmul
torch_tanh = torch.tanh

def Gemma2Attention_fast_forward_inference(
self,
@@ -322,13 +354,13 @@ def Gemma2Attention_fast_forward_inference(
# if bsz == 1:
Qn *= self.scalar # See https://github.com/ggerganov/llama.cpp/issues/7805#issuecomment-2153349963
# It seems like doing (Q * scalar) @ K is better than (Q @ K) * scalar to stop overflows
A = torch.matmul(Qn, Knn.transpose(2, 3), out = self.attention[:,:,:,:cached_len])
A = torch_matmul(Qn, Knn.transpose(2, 3), out = self.attention[:,:,:,:cached_len])
# if attention_mask is not None: A += attention_mask # Must add attention_mask for batched

A *= self.reciprocal_t; torch.tanh(A, out = A); A *= self.t; # Logit softcapping
A *= self.reciprocal_t; torch_tanh(A, out = A); A *= self.t; # Logit softcapping

A[:] = torch_nn_functional_softmax(A, dim = -1, dtype = torch.float32)#.to(A.dtype)
A = torch.matmul(A, Vnn, out = Qn)
A = torch_matmul(A, Vnn, out = Qn)
# else:
# A = scaled_dot_product_attention(Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = False)
# pass
@@ -359,19 +391,24 @@ def Gemma2Model_fast_forward_inference(
bsz, q_len, hd = hidden_states.shape
seq_len = past_key_values[0][0].shape[-2]
if bsz != 1:
SWA = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
(bsz, q_len),
hidden_states,
seq_len,
sliding_window = self.config.sliding_window,
)
GA = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
(bsz, q_len),
hidden_states,
seq_len,
)
if HAS_FLASH_ATTENTION_SOFTCAPPING:
SWA = True
GA = False
else:
SWA = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
(bsz, q_len),
hidden_states,
seq_len,
sliding_window = self.config.sliding_window,
)
GA = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
(bsz, q_len),
hidden_states,
seq_len,
)
pass
else:
SWA = attention_mask
GA = attention_mask
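
For reference, a self-contained sketch of the call the new fast path makes into flash_attn_func. The tensor sizes are toy values, and softcap = 50.0 / query_pre_attn_scalar = 256 are illustrative Gemma-2-style settings rather than values read from a real config; running it needs a CUDA GPU, fp16 or bf16 inputs, and flash-attn >= 2.6.3.

```python
import torch
from flash_attn import flash_attn_func

# Toy sizes; flash_attn_func expects (batch, seq_len, n_heads, head_dim) layout.
bsz, q_len, n_heads, n_kv_heads, head_dim = 1, 512, 8, 4, 256
q = torch.randn(bsz, q_len, n_heads,    head_dim, dtype=torch.float16, device="cuda")
k = torch.randn(bsz, q_len, n_kv_heads, head_dim, dtype=torch.float16, device="cuda")
v = torch.randn(bsz, q_len, n_kv_heads, head_dim, dtype=torch.float16, device="cuda")

out = flash_attn_func(
    q, k, v,
    causal        = True,
    softcap       = 50.0,               # illustrative attn_logit_softcapping
    softmax_scale = 1.0 / (256 ** 0.5), # 1/sqrt(query_pre_attn_scalar), not 1/sqrt(head_dim)
    window_size   = (-1, -1),           # sliding-window layers would pass (sw, sw) instead
)
out = out.reshape(bsz, q_len, n_heads * head_dim)  # same reshape as the patched forward
```
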
39 changes: 22 additions & 17 deletions unsloth/models/llama.py
@@ -682,23 +682,28 @@ def LlamaModel_fast_forward(

# Gemma2 has alternating SWA and global attn
if IS_GEMMA2 and not hasattr(self, "SWA_mask"):
n = self.config.max_position_embeddings
# masked_fill is making stuff slower!
# self. GA_mask = create_boolean_mask(n = n, sliding_window = 0)
# self.SWA_mask = create_boolean_mask(n = n, sliding_window = self.config.sliding_window)
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
self.SWA_mask = AttentionMaskConverter(
is_causal = True,
sliding_window = self.config.sliding_window,
)\
.to_causal_4d(1, n, n, dtype = inputs_embeds.dtype, device = "cuda:0",)\
.squeeze(0).squeeze(0)

self.GA_mask = AttentionMaskConverter(
is_causal = True,
)\
.to_causal_4d(1, n, n, dtype = inputs_embeds.dtype, device = "cuda:0",)\
.squeeze(0).squeeze(0)
if HAS_FLASH_ATTENTION_SOFTCAPPING:
self.SWA_mask = True
self.GA_mask = False
else:
n = self.config.max_position_embeddings
# masked_fill is making stuff slower!
# self. GA_mask = create_boolean_mask(n = n, sliding_window = 0)
# self.SWA_mask = create_boolean_mask(n = n, sliding_window = self.config.sliding_window)
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
self.SWA_mask = AttentionMaskConverter(
is_causal = True,
sliding_window = self.config.sliding_window,
)\
.to_causal_4d(1, n, n, dtype = inputs_embeds.dtype, device = "cuda:0",)\
.squeeze(0).squeeze(0)

self.GA_mask = AttentionMaskConverter(
is_causal = True,
)\
.to_causal_4d(1, n, n, dtype = inputs_embeds.dtype, device = "cuda:0",)\
.squeeze(0).squeeze(0)
pass
pass

# Go through every layer!
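
For reference, a small CPU sketch (toy sizes, not code from this commit) of the slow-path masks built above, which the flash-attn softcapping path now replaces with the plain booleans SWA_mask = True and GA_mask = False.

```python
import torch
from transformers.modeling_attn_mask_utils import AttentionMaskConverter

n, window = 6, 3  # toy sequence length and sliding window

# Additive masks: 0 where attention is allowed, a very negative value otherwise.
swa = AttentionMaskConverter(is_causal=True, sliding_window=window) \
    .to_causal_4d(1, n, n, dtype=torch.float32, device="cpu").squeeze(0).squeeze(0)
ga = AttentionMaskConverter(is_causal=True) \
    .to_causal_4d(1, n, n, dtype=torch.float32, device="cpu").squeeze(0).squeeze(0)

print((swa == 0).int())  # causal AND within the last `window` positions
print((ga  == 0).int())  # plain causal (global attention)
```
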
17 changes: 17 additions & 0 deletions unsloth/models/loader.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from ._utils import is_bfloat16_supported, HAS_FLASH_ATTENTION, HAS_FLASH_ATTENTION_SOFTCAPPING
from .llama import FastLlamaModel, logger
from .mistral import FastMistralModel
from .qwen2 import FastQwen2Model
@@ -42,6 +43,7 @@ def __get_model_name(
FLOAT_TO_INT_MAPPER = None,
):

model_name = str(model_name)
if not SUPPORTS_FOURBIT and model_name.lower() in INT_TO_FLOAT_MAPPER:
model_name = INT_TO_FLOAT_MAPPER[model_name.lower()]
logger.warning_once(
@@ -232,6 +234,21 @@ def from_pretrained(
f'Try `pip install --upgrade "transformers>=4.42.3"`\n'\
f"to obtain the latest transformers build, then restart this session."\
)
# Also check for softcapping support in flash-attn which is faster!
if is_bfloat16_supported() and not HAS_FLASH_ATTENTION:
print(
"Unsloth: If you want to finetune Gemma 2, install flash-attn to make it faster!\n"\
"To install flash-attn, do the below:\n"\
'\npip install --no-deps --upgrade "flash-attn>=2.6.3"'
)
elif HAS_FLASH_ATTENTION and not HAS_FLASH_ATTENTION_SOFTCAPPING:
print(
"Unsloth: If you want to finetune Gemma 2, upgrade flash-attn to version 2.6.3 or higher!\n"\
"Newer versions support faster and less memory usage kernels for Gemma 2's attention softcapping!\n"\
"To update flash-attn, do the below:\n"\
'\npip install --no-deps --upgrade "flash-attn>=2.6.3"'
)

dispatch_model = FastGemma2Model
elif model_type == "qwen2":
dispatch_model = FastQwen2Model
8 changes: 8 additions & 0 deletions unsloth/models/mapper.py
@@ -241,6 +241,14 @@
"unsloth/Mistral-Large-Instruct-2407-bnb-4bit" : (
"mistralai/Mistral-Large-Instruct-2407",
),
"unsloth/gemma-2-2b-bnb-4bit" : (
"unsloth/gemma-2-2b",
"google/gemma-2-2b",
),
"unsloth/gemma-2-2b-it-bnb-4bit" : (
"unsloth/gemma-2-2b-it",
"google/gemma-2-2b-it",
),
}

INT_TO_FLOAT_MAPPER = {}
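
For illustration, a hedged sketch of how entries like the new Gemma-2-2b ones are typically consumed: loader.py (above) looks a lowercased 4-bit repo name up in INT_TO_FLOAT_MAPPER to fall back to a full-precision repo. The helper dict name and the inversion loop below are assumptions for the sketch, not code from this commit.

```python
# Hypothetical raw table in the same shape as the mapper.py entries above.
_RAW_MAPPER = {
    "unsloth/gemma-2-2b-bnb-4bit"   : ("unsloth/gemma-2-2b",    "google/gemma-2-2b"),
    "unsloth/gemma-2-2b-it-bnb-4bit": ("unsloth/gemma-2-2b-it", "google/gemma-2-2b-it"),
}

INT_TO_FLOAT_MAPPER = {}  # 4-bit repo -> canonical full-precision repo
FLOAT_TO_INT_MAPPER = {}  # any full-precision alias -> 4-bit repo

for four_bit, full_precision in _RAW_MAPPER.items():
    INT_TO_FLOAT_MAPPER[four_bit.lower()] = full_precision[0]
    for alias in full_precision:
        FLOAT_TO_INT_MAPPER[alias.lower()] = four_bit

print(INT_TO_FLOAT_MAPPER["unsloth/gemma-2-2b-bnb-4bit"])  # unsloth/gemma-2-2b
print(FLOAT_TO_INT_MAPPER["google/gemma-2-2b-it"])         # unsloth/gemma-2-2b-it-bnb-4bit
```
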
