@@ -209,8 +209,9 @@ def LlamaAttention_fast_forward_inference(
 
     # Attention
     if bsz == 1:
+        Qn *= self.scalar # See https://github.com/ggerganov/llama.cpp/issues/7805#issuecomment-2153349963
+        # It seems like doing (Q * scalar) @ K is better than (Q @ K) * scalar to stop overflows
         A = torch.matmul(Qn, Knn.transpose(2, 3), out = self.attention[:,:,:,:cached_len])
-        A *= self.scalar
         # if attention_mask is not None: A += attention_mask # Must add attention_mask for batched
         A[:] = torch_nn_functional_softmax(A, dim = -1, dtype = torch.float32)#.to(A.dtype)
         A = torch.matmul(A, Vnn, out = Qn)
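The reordering in this hunk matters in float16: with head_dim = d, an unscaled Q @ K^T dot product can exceed fp16's maximum (65504) and overflow to inf before the 1/sqrt(d) scalar is ever applied, whereas pre-scaling Q keeps every partial product small. A minimal sketch of the effect, with made-up shapes and values (the dot product is emulated with an elementwise product and sum so it runs on any device):

import torch

head_dim = 128
scalar = head_dim ** -0.5                      # 1/sqrt(d), standard attention scaling
q = torch.full((head_dim,), 30.0, dtype = torch.float16)
k = torch.full((head_dim,), 30.0, dtype = torch.float16)

scale_after  = (q * k).sum() * scalar          # raw dot product = 115200 > 65504 -> inf in fp16
scale_before = ((q * scalar) * k).sum()        # each term ~80, sum ~10184 -> finite

print(scale_after.item(), scale_before.item())  # inf vs ~10184.0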
@@ -791,7 +792,7 @@ def _CausalLM_fast_forward(
         *args, **kwargs,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
 
-        if past_key_values is not None and self.config.model_type != "qwen2":
+        if past_key_values is not None:
             outputs = fast_forward_inference(
                 self,
                 input_ids,
@@ -1195,7 +1196,13 @@ def from_pretrained(
         f"\\        /    Total batch size = {total_train_batch_size:,} | Total steps = {max_steps:,}\\n"\\
         f' "-____-"     Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}'
         logger.warning(debug_info)
-        import gc
+        import subprocess, re, gc
+        output = subprocess.check_output(
+            'nvidia-smi --query-gpu=memory.used --format=csv', shell = True)
+        output = re.findall(rb'([\\d]{1,})[\\s]{1,}M', output)
+        output = sum(int(x.decode('utf-8'))/1024 > 4 for x in output)
+        if output > 1: raise RuntimeError(
+            'Error: More than 1 GPUs have a lot of VRAM usage. Please obtain a commercial license.')
         for _ in range(3):
             gc.collect()
             torch.cuda.empty_cache()"""
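The added check shells out to nvidia-smi, regex-parses the per-GPU used-memory column, and counts how many GPUs are using more than 4 GiB. A standalone sketch of the same parsing with fabricated nvidia-smi output, so it runs without a GPU:

import re

# Fabricated `nvidia-smi --query-gpu=memory.used --format=csv` output for 3 hypothetical GPUs
sample = b"memory.used [MiB]\n6144 MiB\n2048 MiB\n8192 MiB\n"

used_mib = re.findall(rb'([\d]{1,})[\s]{1,}M', sample)           # -> [b'6144', b'2048', b'8192']
busy = sum(int(x.decode('utf-8')) / 1024 > 4 for x in used_mib)  # GPUs using more than 4 GiB
print(busy)  # 2 -> more than 1 busy GPU, so the patched code would raise RuntimeError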
@@ -1206,12 +1213,12 @@ def from_pretrained(
 
         debug_info = """n_total_devices = total_train_batch_size // \\
             args.gradient_accumulation_steps // self._train_batch_size
-        if n_total_devices > 2:
+        if n_total_devices > 1:
             logger.warning_once(
-                "Our OSS was designed for people with few GPU resources to level the playing field.\\n"
-                "The OSS Apache 2 license only supports one GPU - please obtain a commercial license.\\n"
-                "We're a 2 person team, so we still have to fund our development costs - thanks!\\n"
-                "If you don't, please consider at least sponsoring us through Ko-fi! Appreciate it!",
+                "* Our OSS was designed for people with few GPU resources to level the playing field.\\n"
+                "* The OSS Apache 2 license only supports one GPU - please obtain a commercial license.\\n"
+                "* We're a 2 person team, so we still have to fund our development costs - thanks!\\n"
+                "* If you don't, please consider at least sponsoring us through Ko-fi! Appreciate it!",
             )
         debug_info ="""
         debug_info = debug_info.split('\n')
@@ -1236,17 +1243,17 @@ def from_pretrained(
         bsz = self._train_batch_size
         total_batches = bsz * ga * args.world_size
         n_total_devices = total_batches // ga // bsz
-        if n_total_devices > 2:
+        if n_total_devices > 1:
             logger.warning_once(
-                "Our OSS was designed for people with few GPU resources to level the playing field.\\n"
-                "The OSS Apache 2 license only supports one GPU - please obtain a commercial license.\\n"
-                "We're a 2 person team, so we still have to fund our development costs - thanks!\\n"
-                "If you don't, please consider at least sponsoring us through Ko-fi! Appreciate it!",
+                "* Our OSS was designed for people with few GPU resources to level the playing field.\\n"
+                "* The OSS Apache 2 license only supports one GPU - please obtain a commercial license.\\n"
+                "* We're a 2 person team, so we still have to fund our development costs - thanks!\\n"
+                "* If you don't, please consider at least sponsoring us through Ko-fi! Appreciate it!",
             )
-            divisor = n_total_devices / 2
+            divisor = n_total_devices / 1
             bsz = self._train_batch_size = max(int(bsz / divisor), 1)
-            if total_batches // ga // bsz > 2:
-                divisor = n_total_devices / 2
+            if total_batches // ga // bsz > 1:
+                divisor = n_total_devices / 1
                 ga = args.gradient_accumulation_steps = max(int(ga / divisor), 1)"""
         check_batches = check_batches.split('\n')
         check_batches = "\n".join([check_batches[0]] + [front_spaces + x[8:] for x in check_batches[1:]])
@@ -1830,10 +1837,10 @@ def patch_peft_model(
 
     @staticmethod
     def for_inference(model):
-        if model.config.model_type == "qwen2":
-            FastLlamaModel.for_training(model)
-            return
-        pass
+        # if model.config.model_type == "qwen2":
+        #     FastLlamaModel.for_training(model)
+        #     return
+        # pass
 
         internal_model = model
         internal_model.gradient_checkpointing = False