diff --git a/README.md b/README.md
index 4c1271396..d843158d2 100644
--- a/README.md
+++ b/README.md
@@ -39,6 +39,7 @@ All notebooks are **beginner friendly**! Add your dataset, click "Run All", and
 - Click [here](https://github.com/unslothai/unsloth/wiki) for detailed documentation for Unsloth.
 
 ## 🦥 Unsloth.ai News
+- 📣 NEW! [Gemma-2-2b](https://colab.research.google.com/drive/1weTpKOjBZxZJ5PQ-Ql8i6ptAY2x-FWVA?usp=sharing) now supported! Gemma-2-9b and Gemma-2-27b are already supported!
 - 📣 NEW! [Llama 3.1 8b, 70b](https://colab.research.google.com/drive/1Ys44kVvmeZtnICzWz0xgpRnrIOjZAuxp?usp=sharing) both Base and Instruct now supported
 - 📣 NEW! [Mistral Nemo-12b](https://colab.research.google.com/drive/17d3U-CAIwzmbDRqbZ9NnpHxCkmXB6LZ0?usp=sharing) both Base and Instruct now supported
 - 📣 NEW! [Gemma-2-9b](https://colab.research.google.com/drive/1vIrqH5uYDQwsJ4-OO3DErvuv4pBgVwk4?usp=sharing) and Gemma-2-27b now supported
diff --git a/pyproject.toml b/pyproject.toml
index 6777f7c26..e711325be 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -171,7 +171,7 @@ colab-ampere-torch211 = [
     "unsloth[cu121onlytorch211]",
     "packaging",
     "ninja",
-    "flash-attn",
+    "flash-attn>=2.6.3",
 ]
 colab-torch220 = [
     "unsloth[huggingface]",
@@ -184,7 +184,7 @@ colab-ampere-torch220 = [
     "unsloth[cu121onlytorch220]",
     "packaging",
     "ninja",
-    "flash-attn",
+    "flash-attn>=2.6.3",
 ]
 colab-new = [
     "packaging",
@@ -215,7 +215,7 @@ colab-ampere = [
     "unsloth[colab-ampere-torch220]",
     "packaging",
     "ninja",
-    "flash-attn",
+    "flash-attn>=2.6.3",
 ]
 cu118-ampere = [
     "unsloth[huggingface]",
@@ -223,7 +223,7 @@ cu118-ampere = [
     "unsloth[cu118only]",
     "packaging",
     "ninja",
-    "flash-attn",
+    "flash-attn>=2.6.3",
 ]
 cu121-ampere = [
     "unsloth[huggingface]",
@@ -231,7 +231,7 @@ cu121-ampere = [
     "unsloth[cu121only]",
     "packaging",
     "ninja",
-    "flash-attn",
+    "flash-attn>=2.6.3",
 ]
 cu118-ampere-torch211 = [
     "unsloth[huggingface]",
@@ -239,7 +239,7 @@ cu118-ampere-torch211 = [
     "unsloth[cu118onlytorch211]",
     "packaging",
     "ninja",
-    "flash-attn",
+    "flash-attn>=2.6.3",
 ]
 cu121-ampere-torch211 = [
     "unsloth[huggingface]",
@@ -247,7 +247,7 @@ cu121-ampere-torch211 = [
     "unsloth[cu121onlytorch211]",
     "packaging",
     "ninja",
-    "flash-attn",
+    "flash-attn>=2.6.3",
 ]
 cu118-ampere-torch220 = [
     "unsloth[huggingface]",
@@ -255,7 +255,7 @@ cu118-ampere-torch220 = [
     "unsloth[cu118onlytorch220]",
     "packaging",
     "ninja",
-    "flash-attn",
+    "flash-attn>=2.6.3",
 ]
 cu121-ampere-torch220 = [
     "unsloth[huggingface]",
@@ -263,7 +263,7 @@ cu121-ampere-torch220 = [
     "unsloth[cu121onlytorch220]",
     "packaging",
     "ninja",
-    "flash-attn",
+    "flash-attn>=2.6.3",
 ]
 cu118-ampere-torch230 = [
     "unsloth[huggingface]",
@@ -271,7 +271,7 @@ cu118-ampere-torch230 = [
     "unsloth[cu118onlytorch230]",
     "packaging",
     "ninja",
-    "flash-attn",
+    "flash-attn>=2.6.3",
 ]
 cu121-ampere-torch230 = [
     "unsloth[huggingface]",
@@ -279,7 +279,7 @@ cu121-ampere-torch230 = [
     "unsloth[cu121onlytorch230]",
     "packaging",
     "ninja",
-    "flash-attn",
+    "flash-attn>=2.6.3",
 ]
 
 [project.urls]
diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py
index 994f97ab7..fe3aa9040 100644
--- a/unsloth/models/_utils.py
+++ b/unsloth/models/_utils.py
@@ -21,6 +21,7 @@
     "xformers_version",
     "__version__",
     "HAS_FLASH_ATTENTION",
+    "HAS_FLASH_ATTENTION_SOFTCAPPING",
     "PRE_CHECK",
     "platform_system",
     "patch_tokenizer",
@@ -140,6 +141,8 @@ def patch_mistral_nemo_config(config):
 
 major_version, minor_version = torch.cuda.get_device_capability()
 SUPPORTS_BFLOAT16 = False
+HAS_FLASH_ATTENTION = False
+HAS_FLASH_ATTENTION_SOFTCAPPING = False
 
 if major_version >= 8:
     SUPPORTS_BFLOAT16 = True
@@ -148,6 +151,17 @@ def patch_mistral_nemo_config(config):
     try:
         from flash_attn.flash_attn_interface import flash_attn_cuda
         HAS_FLASH_ATTENTION = True
+
+        # Also check for softcapping
+        from flash_attn import __version__ as flash_attn_version
+        HAS_FLASH_ATTENTION_SOFTCAPPING = Version(flash_attn_version) >= Version("2.6.3")
+        if not HAS_FLASH_ATTENTION_SOFTCAPPING:
+            print(
+                "Unsloth: If you want to finetune Gemma 2, upgrade flash-attn to version 2.6.3 or higher!\n"\
+                "Newer versions have faster kernels that use less memory for Gemma 2's attention softcapping!\n"\
+                "To update flash-attn, do the below:\n"\
+                '\npip install --no-deps --upgrade "flash-attn>=2.6.3"'
+            )
     except:
         print(
             "Unsloth: Your Flash Attention 2 installation seems to be broken?\n"\
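The `_utils.py` hunk above gates the new Gemma 2 fast path on the installed flash-attn version. A minimal standalone sketch of the same detection logic, assuming only that `packaging` is installed and that `flash_attn` may or may not be importable:

```python
# Minimal sketch of the flash-attn capability detection used above.
# `packaging` provides PEP 440 version comparison; flash_attn itself is optional.
from packaging.version import Version

HAS_FLASH_ATTENTION = False
HAS_FLASH_ATTENTION_SOFTCAPPING = False
try:
    import flash_attn
    HAS_FLASH_ATTENTION = True
    # Softcapping kernels (needed for Gemma 2's tanh-capped attention) shipped in 2.6.3.
    HAS_FLASH_ATTENTION_SOFTCAPPING = Version(flash_attn.__version__) >= Version("2.6.3")
except ImportError:
    pass

print(HAS_FLASH_ATTENTION, HAS_FLASH_ATTENTION_SOFTCAPPING)
```

When the second flag ends up `False`, the `gemma2.py` changes below fall back to `slow_attention_softcapping`.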
diff --git a/unsloth/models/gemma2.py b/unsloth/models/gemma2.py
index 0d21c47b0..1cbaf5b16 100644
--- a/unsloth/models/gemma2.py
+++ b/unsloth/models/gemma2.py
@@ -56,6 +56,8 @@
     Gemma2FlashAttention2 = Gemma2Attention
 pass
 
+if HAS_FLASH_ATTENTION_SOFTCAPPING:
+    from flash_attn import flash_attn_func
 
 # [TODO] We must randomnly use torch.compile?
 # I checked the gradients and formulas and I'm sure it's correct.
@@ -126,8 +128,36 @@ def Gemma2Attention_fast_forward(
         V = torch.cat([past_key_value[1], V], dim = 2)
     pass
     past_key_value = (K, V) if use_cache else None
-
-    A = slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, kv_seq_len)
+
+    # Only enable the sliding window when causal_mask is the boolean flag True
+    has_sliding_window = type(causal_mask) is bool and causal_mask is True
+    if HAS_FLASH_ATTENTION_SOFTCAPPING and attention_mask is None:
+        window = (-1, -1)
+        if has_sliding_window:
+            sw = getattr(self.config, "sliding_window", None)
+            sw = kv_seq_len if (sw is None or sw == "null") else sw
+            window = (-1, -1) if (kv_seq_len <= sw) else (sw, sw)
+        pass
+
+        # FA uses 1 / sqrt for softmax_scale!
+        if not hasattr(self, "_flash_attention_softmax_scale"):
+            self._flash_attention_softmax_scale = 1.0 / (self.config.query_pre_attn_scalar**0.5)
+        pass
+
+        Q = Q.transpose(1, 2)
+        K = K.transpose(1, 2)
+        V = V.transpose(1, 2)
+        A = flash_attn_func(
+            Q, K, V,
+            causal = True,
+            softcap = self.config.attn_logit_softcapping,
+            softmax_scale = self._flash_attention_softmax_scale,
+            window_size = window,
+        )
+        A = A.reshape(bsz, q_len, n_heads*head_dim)
+    else:
+        A = slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, kv_seq_len)
+    pass
     A = self.apply_o(self, A)
     return A, None, past_key_value
 pass
@@ -205,6 +235,8 @@ def Gemma2DecoderLayer_fast_forward(
 from math import sqrt as math_sqrt
 KV_CACHE_INCREMENT = 256 # KV Cache update size
 torch_nn_functional_softmax = torch.nn.functional.softmax
+torch_matmul = torch.matmul
+torch_tanh = torch.tanh
 
 def Gemma2Attention_fast_forward_inference(
     self,
@@ -322,13 +354,13 @@ def Gemma2Attention_fast_forward_inference(
 
         # if bsz == 1:
         Qn *= self.scalar # See https://github.com/ggerganov/llama.cpp/issues/7805#issuecomment-2153349963
         # It seems like doing (Q * scalar) @ K is better than (Q @ K) * scalar to stop overflows
-        A = torch.matmul(Qn, Knn.transpose(2, 3), out = self.attention[:,:,:,:cached_len])
+        A = torch_matmul(Qn, Knn.transpose(2, 3), out = self.attention[:,:,:,:cached_len])
         # if attention_mask is not None: A += attention_mask # Must add attention_mask for batched
-        A *= self.reciprocal_t; torch.tanh(A, out = A); A *= self.t; # Logit softcapping
+        A *= self.reciprocal_t; torch_tanh(A, out = A); A *= self.t; # Logit softcapping
         A[:] = torch_nn_functional_softmax(A, dim = -1, dtype = torch.float32)#.to(A.dtype)
-        A = torch.matmul(A, Vnn, out = Qn)
+        A = torch_matmul(A, Vnn, out = Qn)
         # else:
         #     A = scaled_dot_product_attention(Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = False)
         # pass
 
@@ -359,19 +391,24 @@ def Gemma2Model_fast_forward_inference(
     bsz, q_len, hd = hidden_states.shape
     seq_len = past_key_values[0][0].shape[-2]
     if bsz != 1:
-        SWA = _prepare_4d_causal_attention_mask_for_sdpa(
-            attention_mask,
-            (bsz, q_len),
-            hidden_states,
-            seq_len,
-            sliding_window = self.config.sliding_window,
-        )
-        GA = _prepare_4d_causal_attention_mask_for_sdpa(
-            attention_mask,
-            (bsz, q_len),
-            hidden_states,
-            seq_len,
-        )
+        if HAS_FLASH_ATTENTION_SOFTCAPPING:
+            SWA = True
+            GA = False
+        else:
+            SWA = _prepare_4d_causal_attention_mask_for_sdpa(
+                attention_mask,
+                (bsz, q_len),
+                hidden_states,
+                seq_len,
+                sliding_window = self.config.sliding_window,
+            )
+            GA = _prepare_4d_causal_attention_mask_for_sdpa(
+                attention_mask,
+                (bsz, q_len),
+                hidden_states,
+                seq_len,
+            )
+        pass
     else:
         SWA = attention_mask
         GA = attention_mask
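For reference, a self-contained sketch of the `flash_attn_func` call pattern used in the fast path above. It needs a CUDA GPU with flash-attn >= 2.6.3; the scalar values are illustrative stand-ins for a Gemma 2 config's `query_pre_attn_scalar`, `attn_logit_softcapping`, and `sliding_window`, not numbers read from a real checkpoint:

```python
# Sketch of the softcapped flash-attn call (flash-attn >= 2.6.3, CUDA only).
import torch
from flash_attn import flash_attn_func

bsz, q_len, n_heads, head_dim = 1, 128, 8, 256
query_pre_attn_scalar  = 256.0   # scale is 1/sqrt(query_pre_attn_scalar), not 1/sqrt(head_dim)
attn_logit_softcapping = 50.0    # logits are squashed into (-50, 50) via tanh inside the kernel
sliding_window         = 4096    # only applied once the KV length exceeds it

Q = torch.randn(bsz, q_len, n_heads, head_dim, dtype = torch.bfloat16, device = "cuda")
K = torch.randn(bsz, q_len, n_heads, head_dim, dtype = torch.bfloat16, device = "cuda")
V = torch.randn(bsz, q_len, n_heads, head_dim, dtype = torch.bfloat16, device = "cuda")

window = (-1, -1) if q_len <= sliding_window else (sliding_window, sliding_window)
A = flash_attn_func(
    Q, K, V,
    causal        = True,
    softcap       = attn_logit_softcapping,
    softmax_scale = 1.0 / query_pre_attn_scalar ** 0.5,
    window_size   = window,
)
print(A.shape)  # (bsz, q_len, n_heads, head_dim), reshaped to (bsz, q_len, n_heads*head_dim) above
```

Note that flash-attn expects `(batch, seq, heads, head_dim)` tensors, which is why the hunk transposes Q, K, V before the call.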
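The inference path above keeps the eager softcapping, squashing the attention logits with `tanh` before the softmax. A small CPU-only sketch of the identity it computes in place, `softcap(x) = t * tanh(x / t)`, where `t` stands in for the cached `self.t` / `self.reciprocal_t` pair (assumed to hold `attn_logit_softcapping` and its reciprocal):

```python
# CPU-only sketch of in-place logit softcapping: softcap(x) = t * tanh(x / t).
import torch

t = 50.0                                # illustrative attn_logit_softcapping value
reciprocal_t = 1.0 / t
A = torch.randn(1, 8, 1, 32) * 200.0    # fake attention logits with large magnitudes

A_ref = t * torch.tanh(A / t)           # reference formula

# In-place version, mirroring `A *= reciprocal_t; tanh(A, out = A); A *= t`
A *= reciprocal_t
torch.tanh(A, out = A)
A *= t

print(torch.allclose(A, A_ref))         # True
print(bool(A.abs().max() < t))          # True: logits are now bounded to (-t, t)
```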
diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py
index 496a37e7a..b5244ed4e 100644
--- a/unsloth/models/llama.py
+++ b/unsloth/models/llama.py
@@ -682,23 +682,28 @@ def LlamaModel_fast_forward(
 
     # Gemma2 has alternating SWA and global attn
     if IS_GEMMA2 and not hasattr(self, "SWA_mask"):
-        n = self.config.max_position_embeddings
-        # masked_fill is making stuff slower!
-        # self.GA_mask = create_boolean_mask(n = n, sliding_window = 0)
-        # self.SWA_mask = create_boolean_mask(n = n, sliding_window = self.config.sliding_window)
-        from transformers.modeling_attn_mask_utils import AttentionMaskConverter
-        self.SWA_mask = AttentionMaskConverter(
-            is_causal = True,
-            sliding_window = self.config.sliding_window,
-        )\
-            .to_causal_4d(1, n, n, dtype = inputs_embeds.dtype, device = "cuda:0",)\
-            .squeeze(0).squeeze(0)
-
-        self.GA_mask = AttentionMaskConverter(
-            is_causal = True,
-        )\
-            .to_causal_4d(1, n, n, dtype = inputs_embeds.dtype, device = "cuda:0",)\
-            .squeeze(0).squeeze(0)
+        if HAS_FLASH_ATTENTION_SOFTCAPPING:
+            self.SWA_mask = True
+            self.GA_mask = False
+        else:
+            n = self.config.max_position_embeddings
+            # masked_fill is making stuff slower!
+            # self.GA_mask = create_boolean_mask(n = n, sliding_window = 0)
+            # self.SWA_mask = create_boolean_mask(n = n, sliding_window = self.config.sliding_window)
+            from transformers.modeling_attn_mask_utils import AttentionMaskConverter
+            self.SWA_mask = AttentionMaskConverter(
+                is_causal = True,
+                sliding_window = self.config.sliding_window,
+            )\
+                .to_causal_4d(1, n, n, dtype = inputs_embeds.dtype, device = "cuda:0",)\
+                .squeeze(0).squeeze(0)
+
+            self.GA_mask = AttentionMaskConverter(
+                is_causal = True,
+            )\
+                .to_causal_4d(1, n, n, dtype = inputs_embeds.dtype, device = "cuda:0",)\
+                .squeeze(0).squeeze(0)
+        pass
     pass
 
     # Go through every layer!
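When flash-attn's softcapping kernels are unavailable, the fallback above still materialises Gemma 2's two masks with `AttentionMaskConverter`. A small sketch of that construction, shrunk to `n = 8` on CPU so the masks can be inspected (the real code builds them once at `max_position_embeddings` size on `"cuda:0"`):

```python
# Sketch of the fallback mask construction (transformers only, no flash-attn needed).
import torch
from transformers.modeling_attn_mask_utils import AttentionMaskConverter

n, sliding_window = 8, 4

# Sliding-window causal mask: keys more than `sliding_window` positions behind the query are masked.
SWA_mask = AttentionMaskConverter(is_causal = True, sliding_window = sliding_window)\
    .to_causal_4d(1, n, n, dtype = torch.float32, device = "cpu")\
    .squeeze(0).squeeze(0)

# Plain causal mask used by the global-attention layers.
GA_mask = AttentionMaskConverter(is_causal = True)\
    .to_causal_4d(1, n, n, dtype = torch.float32, device = "cpu")\
    .squeeze(0).squeeze(0)

print(SWA_mask.shape, GA_mask.shape)   # both (n, n); Gemma 2 alternates between the two per layer
```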
diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py
index f22e81efa..47152d676 100644
--- a/unsloth/models/loader.py
+++ b/unsloth/models/loader.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from ._utils import is_bfloat16_supported, HAS_FLASH_ATTENTION, HAS_FLASH_ATTENTION_SOFTCAPPING
 from .llama import FastLlamaModel, logger
 from .mistral import FastMistralModel
 from .qwen2 import FastQwen2Model
@@ -42,6 +43,7 @@ def __get_model_name(
     FLOAT_TO_INT_MAPPER = None,
 ):
 
+    model_name = str(model_name)
     if not SUPPORTS_FOURBIT and model_name.lower() in INT_TO_FLOAT_MAPPER:
         model_name = INT_TO_FLOAT_MAPPER[model_name.lower()]
         logger.warning_once(
@@ -232,6 +234,21 @@ def from_pretrained(
                 f'Try `pip install --upgrade "transformers>=4.42.3"`\n'\
                 f"to obtain the latest transformers build, then restart this session."\
             )
+            # Also check for softcapping support in flash-attn which is faster!
+            if is_bfloat16_supported() and not HAS_FLASH_ATTENTION:
+                print(
+                    "Unsloth: If you want to finetune Gemma 2, install flash-attn to make it faster!\n"\
+                    "To install flash-attn, do the below:\n"\
+                    '\npip install --no-deps --upgrade "flash-attn>=2.6.3"'
+                )
+            elif HAS_FLASH_ATTENTION and not HAS_FLASH_ATTENTION_SOFTCAPPING:
+                print(
+                    "Unsloth: If you want to finetune Gemma 2, upgrade flash-attn to version 2.6.3 or higher!\n"\
+                    "Newer versions have faster kernels that use less memory for Gemma 2's attention softcapping!\n"\
+                    "To update flash-attn, do the below:\n"\
+                    '\npip install --no-deps --upgrade "flash-attn>=2.6.3"'
+                )
+
             dispatch_model = FastGemma2Model
         elif model_type == "qwen2":
             dispatch_model = FastQwen2Model
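Condensed, the loader's new advice reduces to three cases. A sketch of that decision logic (the helper name `gemma2_flash_attn_advice` is purely illustrative; the real code prints the messages inline using the flags imported from `._utils`):

```python
# Illustrative condensation of the Gemma 2 flash-attn advice printed above.
def gemma2_flash_attn_advice(bf16_supported: bool, has_flash_attn: bool, has_softcapping: bool):
    if bf16_supported and not has_flash_attn:
        return 'Install flash-attn:  pip install --no-deps --upgrade "flash-attn>=2.6.3"'
    if has_flash_attn and not has_softcapping:
        return 'Upgrade flash-attn:  pip install --no-deps --upgrade "flash-attn>=2.6.3"'
    return None  # flash-attn >= 2.6.3 present: the fast softcapping path in gemma2.py becomes available

print(gemma2_flash_attn_advice(True, False, False))   # install hint
print(gemma2_flash_attn_advice(True, True,  False))   # upgrade hint
print(gemma2_flash_attn_advice(True, True,  True))    # None
```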
diff --git a/unsloth/models/mapper.py b/unsloth/models/mapper.py
index 462555f31..57ba67658 100644
--- a/unsloth/models/mapper.py
+++ b/unsloth/models/mapper.py
@@ -241,6 +241,14 @@
     "unsloth/Mistral-Large-Instruct-2407-bnb-4bit" : (
         "mistralai/Mistral-Large-Instruct-2407",
     ),
+    "unsloth/gemma-2-2b-bnb-4bit" : (
+        "unsloth/gemma-2-2b",
+        "google/gemma-2-2b",
+    ),
+    "unsloth/gemma-2-2b-it-bnb-4bit" : (
+        "unsloth/gemma-2-2b-it",
+        "google/gemma-2-2b-it",
+    ),
 }
 
 INT_TO_FLOAT_MAPPER = {}
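With the mapper entries above in place, the new 2b checkpoints resolve through the usual Unsloth entry point. A usage sketch following the README's standard loading pattern (requires a CUDA GPU; `max_seq_length` is just an example value):

```python
# Load the newly mapped Gemma-2-2b 4-bit checkpoint with Unsloth.
from unsloth import FastLanguageModel

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name     = "unsloth/gemma-2-2b-bnb-4bit",  # resolved via the mapper entries above
    max_seq_length = 2048,
    dtype          = None,         # auto-detect (bfloat16 on Ampere or newer)
    load_in_4bit   = True,
)
```

If flash-attn >= 2.6.3 is installed, the softcapped flash-attention path added in `gemma2.py` is used; otherwise loading still works and the eager softcapping fallback is kept.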