Commit 7b8ab2f

Resolved merge conflict in QEfficient/base/common.py

2 parents: 3b5466e + bdcd7e5
File tree: 11 files changed, +193 −244 lines

QEfficient/base/common.py

+4

@@ -12,12 +12,14 @@
 QEFFAutoModel provides a common interface for loading the HuggingFace models using either the HF card name or local path of downloaded model.
 """

+import os
 from typing import Any

 from transformers import AutoConfig

 from QEfficient.base.modeling_qeff import QEFFBaseModel
 from QEfficient.transformers.modeling_utils import MODEL_CLASS_MAPPING
+from QEfficient.utils import login_and_download_hf_lm


 class QEFFCommonLoader:
@@ -51,6 +53,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, *args, **kwargs) ->
         )

         local_model_dir = kwargs.pop("local_model_dir", None)
+        if not os.path.isdir(pretrained_model_name_or_path) and local_model_dir is None:
+            pretrained_model_name_or_path = login_and_download_hf_lm(pretrained_model_name_or_path, *args, **kwargs)
         hf_token = kwargs.pop("hf_token", None)
         continuous_batching = True if kwargs.pop("full_batch_size", None) else False
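
In effect, from_pretrained now accepts either an HF model card name or a local directory path. A minimal usage sketch, assuming network access for the download path; the "gpt2" card name is illustrative and not part of this commit:

    from QEfficient.base.common import QEFFCommonLoader

    # HF card name: no local directory of that name exists, so the new branch
    # first fetches the weights via login_and_download_hf_lm.
    model = QEFFCommonLoader.from_pretrained("gpt2")

    # Local directory: os.path.isdir(...) is True, so the download step is
    # skipped and the path is used as-is.
    # model = QEFFCommonLoader.from_pretrained("/path/to/downloaded/model")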

QEfficient/base/modeling_qeff.py

+5 −5

@@ -245,8 +245,11 @@ def _compile(
         qpc_path = compile_dir / "qpc"
         if not onnx_path.is_file():
             raise FileNotFoundError(f"ONNX file not found at: {onnx_path}")
-
         command = constants.COMPILER + [f"-m={onnx_path}"]
+        if mdp_ts_json_path := compiler_options.pop("mdp_ts_json_path", None):
+            mdp_ts_num_devices = None
+            command.append(f"-mdp-load-partition-config={mdp_ts_json_path}")
+
         for key, value in compiler_options.items():
             option = "-" + key.replace("_", "-")
             if isinstance(value, bool):
@@ -262,9 +265,6 @@ def _compile(
         if custom_io is not None:
             compile_hash.update(to_hashable(custom_io))

-        if mdp_ts_num_devices > 1:
-            compile_hash.update(to_hashable({"mdp_ts_num_devices": mdp_ts_num_devices}))
-
         if num_speculative_tokens:
             compile_hash.update(to_hashable({"num_speculative_tokens": num_speculative_tokens}))

@@ -300,7 +300,7 @@ def _compile(
             command.append(f"-custom-IO-list-file={custom_io_yaml}")

         # Write mdp_config.json file
-        if mdp_ts_num_devices > 1:
+        if not mdp_ts_json_path and mdp_ts_num_devices > 1:
             num_cores = compiler_options.get("aic_num_cores", 16)
             mdp_ts_json = compile_dir / f"mdp_ts_{mdp_ts_num_devices}.json"
             with open(mdp_ts_json, "w") as fp:
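
Net effect: a caller-supplied partition config takes precedence over the auto-generated mdp_ts_<n>.json, since the key is popped before the generic option loop and mdp_ts_num_devices is cleared so the generation branch is skipped. A standalone sketch of that option flow; the compiler binary name and paths are assumptions for illustration, not taken from this commit:

    # Mirrors the walrus-pop pattern from the diff, with made-up inputs.
    compiler_options = {"aic_num_cores": 16, "mdp_ts_json_path": "/tmp/my_mdp.json"}
    command = ["qaic-exec", "-m=model.onnx"]  # assumed compiler invocation

    # pop() removes the key so the generic "-key=value" loop below never
    # sees it; clearing mdp_ts_num_devices disables auto-generation.
    if mdp_ts_json_path := compiler_options.pop("mdp_ts_json_path", None):
        mdp_ts_num_devices = None
        command.append(f"-mdp-load-partition-config={mdp_ts_json_path}")

    for key, value in compiler_options.items():
        command.append(f"-{key.replace('_', '-')}={value}")

    print(command)
    # ['qaic-exec', '-m=model.onnx',
    #  '-mdp-load-partition-config=/tmp/my_mdp.json', '-aic-num-cores=16']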

QEfficient/transformers/models/mllama/modeling_mllama.py

+1 −14

@@ -55,19 +55,8 @@ class QEffMllamaRotaryEmbedding(MllamaRotaryEmbedding):
     - Add static sin/cos computations.
     """

-    def __init__(
-        self,
-        dim=None,
-        max_position_embeddings=2048,
-        base=10000,
-        device=None,
-        scaling_factor=1.0,
-        rope_type="default",
-        config: Optional[MllamaConfig] = None,
-    ):
+    def __init__(self, config: MllamaConfig, device=None):
         super().__init__(config=config)
-        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
-        self.register_buffer("inv_freq", inv_freq, persistent=False)

         # Build here to make `torch.jit.trace` work.
         self._set_cos_sin_cache(
@@ -868,7 +857,6 @@ def forward(
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        num_logits_to_keep: int = 0,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -935,7 +923,6 @@ def forward(
             output_attentions=output_attentions,
             return_dict=return_dict,
             cache_position=cache_position,
-            num_logits_to_keep=num_logits_to_keep,
         )

         return outputs
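
The rotary-embedding constructor now defers inv_freq setup to the upstream MllamaRotaryEmbedding parent instead of recomputing and re-registering it. A hedged instantiation sketch, assuming a transformers release that ships Mllama; which sub-config the embedding expects may vary by version:

    from transformers import MllamaConfig
    from QEfficient.transformers.models.mllama.modeling_mllama import (
        QEffMllamaRotaryEmbedding,
    )

    # Default config values, purely for illustration.
    config = MllamaConfig()
    rope = QEffMllamaRotaryEmbedding(config=config.text_config, device="cpu")
    # inv_freq is registered by the parent __init__; the subclass then only
    # builds its static sin/cos caches via _set_cos_sin_cache.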
