From a0d14d2998ca40adb30a866a880e8af4a217f8d4 Mon Sep 17 00:00:00 2001
From: Chetan Kumar Verma <39086835+ckvermaAI@users.noreply.github.com>
Date: Fri, 7 Feb 2025 16:18:11 +0530
Subject: [PATCH] Accuracy fix for llama3.1-70B in eager/torch.compile mode
 (#1746)

Co-authored-by: Vivek Goel
---
 optimum/habana/transformers/models/llama/modeling_llama.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py
index 130490d5c9..77659b15d1 100755
--- a/optimum/habana/transformers/models/llama/modeling_llama.py
+++ b/optimum/habana/transformers/models/llama/modeling_llama.py
@@ -136,7 +136,8 @@ def __init__(
 
     def _set_cos_sin_cache(self, seq_len, device, dtype):
         self.max_seq_len_cached = seq_len
-        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
+        # Use torch.int32 to avoid loss due to low precision with BF16 (refer to SW-215204)
+        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int32)
 
         freqs = torch.outer(t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
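
Note on why the dtype change matters (a minimal standalone sketch, not part of the patch): BF16 has only an 8-bit mantissa, so integers above 256 are not all exactly representable. If the position indices `t` are materialized in BF16 (which the old code allows whenever `self.inv_freq` is BF16), distinct positions in a long context round to the same value and the resulting RoPE angles are skewed. The snippet below only demonstrates that rounding effect; the 131072-token sequence length is illustrative and it does not reproduce the Gaudi/torch.compile execution path.

import torch

# Illustrative only: compare position indices built in BF16 (old behavior when
# inv_freq is BF16) against exact int32 indices (behavior after the fix).
seq_len = 131072  # llama3.1-style long context, chosen for illustration

t_bf16 = torch.arange(seq_len, dtype=torch.bfloat16)
t_int32 = torch.arange(seq_len, dtype=torch.int32)

# Count positions whose BF16 representation no longer equals the exact integer.
mismatches = (t_bf16.to(torch.float64) != t_int32.to(torch.float64)).sum().item()
print(f"positions misrepresented in bf16: {mismatches} / {seq_len}")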