
fix llama FP8 perf issue, kvcache.update should be used since FP8 patches KVCache

Signed-off-by: Wang, Yi A <[email protected]>
sywangyi committed Feb 7, 2025
1 parent 27d1495 commit 38e4777
Showing 2 changed files with 7 additions and 7 deletions.
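
Context for the first file's change: under FP8 inference, the quantization toolkit swaps the model's KVCache modules for FP8-aware variants, so cache writes only take the FP8 path when they are routed through k_cache.update / v_cache.update. The removed past_key.copy_ / past_value.copy_ calls wrote into the cache tensors directly, bypassing the patched module and leaving that path in high precision, which is the perf issue the commit title refers to. Below is a minimal Python sketch of the assumed mechanism; KVCache.update is simplified from optimum-habana's cache module, and PatchedFP8KVCache / quantize_to_fp8 are hypothetical illustrations, not the toolkit's real API.

    import torch

    class KVCache(torch.nn.Module):
        """Simplified sketch of the cache module this diff calls into.
        Being an nn.Module is what makes it swappable by a quantizer."""

        @staticmethod
        def update(prev, cur, dim, idx, inp_seq_len):
            if idx is not None and cur.shape[dim] > 1 and cur.shape[dim] <= prev.shape[dim]:
                # Prompt fill (assumes dim == 2): copy the whole prefix into
                # the preallocated cache; first-token path in the hunk below.
                prev[:, :, :inp_seq_len, :].copy_(cur)
                return cur
            if idx is not None:
                # Decode step: write the new token(s) at position token_idx.
                return prev.index_copy_(dim, idx - 1, cur)
            # No index: grow the cache by concatenation (prefix-tuning path).
            return torch.cat((prev, cur), dim=dim)

    def quantize_to_fp8(t):
        # Hypothetical placeholder for the toolkit's FP8 conversion; a no-op
        # here so the sketch runs anywhere.
        return t

    class PatchedFP8KVCache(KVCache):
        """Illustrative stand-in for what an FP8 quantizer swaps in. Only
        writes routed through update() hit the FP8 path; tensor.copy_() on
        the raw cache bypasses the patched module entirely."""

        @staticmethod
        def update(prev, cur, dim, idx, inp_seq_len):
            return KVCache.update(prev, quantize_to_fp8(cur), dim, idx, inp_seq_len)
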
11 changes: 7 additions & 4 deletions optimum/habana/transformers/models/llama/modeling_llama.py
@@ -668,18 +668,21 @@ def pre_attn_forward(
                     else key_states.dtype,
                     device=key_states.device,
                 )
-                past_key.copy_(key_states)
-                past_value.copy_(value_states)
                 # Return list instead of tuple
                 past_key_value = [past_key, past_value]
+                key_states = self.k_cache.update(past_key_value[0], key_states, 2, token_idx, key_states.shape[-2])
+                value_states = self.v_cache.update(
+                    past_key_value[1], value_states, 2, token_idx, value_states.shape[-2]
+                )
+
             elif (
                 token_idx is not None
                 and num_virtual_tokens is not None
                 and num_virtual_tokens == past_key_value[0].shape[-2]
             ):
                 # prefix tuning case. attach past_key_value to generate first token.
-                key_states = torch.cat((past_key_value[0], key_states), -2)
-                value_states = torch.cat((past_key_value[1], value_states), -2)
+                key_states = self.k_cache.update(past_key_value[0], key_states, 2, None, -1)
+                value_states = self.v_cache.update(past_key_value[1], value_states, 2, None, -1)
                 past_key_value = (key_states, value_states)
             else:
                 key_states = self.k_cache.update(past_key_value[0], key_states, 2, token_idx, self.inp_seq_len)
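To connect the hunk above to the sketch: the first-token path passes token_idx with inp_seq_len=key_states.shape[-2], which takes the prompt-fill branch, while the prefix-tuning path passes idx=None and inp_seq_len=-1, which falls through to concatenation and so reproduces the removed torch.cat behavior while still going through the patched module. A small usage illustration with made-up shapes:

    cache = torch.zeros(1, 8, 16, 64)    # (batch, heads, max_seq_len, head_dim)
    prompt = torch.randn(1, 8, 5, 64)    # 5 prompt tokens
    # First-token path: idx=token_idx, inp_seq_len=prompt length -> prompt fill.
    KVCache.update(cache, prompt, 2, torch.tensor([5]), 5)
    step = torch.randn(1, 8, 1, 64)      # one decode token
    # Decode path elsewhere in the file: write at token_idx.
    KVCache.update(cache, step, 2, torch.tensor([6]), 5)
    prefix = torch.randn(1, 8, 4, 64)    # prefix-tuning virtual tokens
    # Prefix-tuning path: idx=None, inp_seq_len=-1 -> concatenation, matching
    # the removed torch.cat calls while still routing through the module.
    KVCache.update(prefix, step, 2, None, -1)
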
3 changes: 0 additions & 3 deletions tests/baselines/llama_7b.json
@@ -376,7 +376,6 @@
             "--report_to none",
             "--max_steps 100",
             "--peft_type prompt_tuning",
-            "--max_seq_length 64",
             "--lr_scheduler_type cosine",
             "--warmup_steps 0",
             "--weight_decay 0.05",
@@ -402,7 +401,6 @@
             "--report_to none",
             "--max_steps 100",
             "--peft_type prefix_tuning",
-            "--max_seq_length 64",
             "--lr_scheduler_type cosine",
             "--warmup_steps 0",
             "--weight_decay 0.05",
@@ -428,7 +426,6 @@
             "--report_to none",
             "--max_steps 100",
             "--peft_type p_tuning",
-            "--max_seq_length 64",
             "--lr_scheduler_type cosine",
             "--warmup_steps 0",
             "--weight_decay 0.05",