Skip to content

Commit 14d5c49

Browse files
committed
remove redudant code
1 parent adbee68 commit 14d5c49

File tree

2 files changed

+41
-169
lines changed

2 files changed

+41
-169
lines changed

src/llmcompressor/transformers/utils/helpers.py

+39-167
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import os
77
from pathlib import Path
8-
from typing import TYPE_CHECKING, Optional, Tuple, Union
8+
from typing import TYPE_CHECKING, Optional, Union
99

1010
import requests
1111
from huggingface_hub import HUGGINGFACE_CO_URL_HOME, hf_hub_download
@@ -77,211 +77,83 @@ def is_model_quantized_from_path(path: str) -> bool:
7777
return False
7878

7979

80-
def resolve_recipe(
81-
model_path: Union[str, Path],
82-
recipe: Union[str, Path, None] = None,
83-
) -> Union[str, None]:
84-
"""
85-
Resolve the recipe to apply to the model.
86-
:param recipe: the recipe to apply to the model.
87-
88-
It can be one of the following:
89-
- None
90-
This means that we are not either not applying
91-
any recipe and allowing the model to potentially
92-
infer the appropriate pre-existing recipe
93-
from the model_path
94-
- a path to the recipe file
95-
This can be a string or Path object pointing
96-
to a recipe file. If the specified recipe file
97-
is different from the potential pre-existing
98-
recipe for that model (stored in the model_path),
99-
the function will raise an warning
100-
- name of the recipe file (e.g. "recipe.yaml")
101-
Recipe file name specific is assumed to be stored
102-
in the model_path
103-
- a string containing the recipe
104-
Needs to adhere to the SparseML recipe format
105-
106-
:param model_path: the path to the model to load.
107-
It can be one of the following:
108-
- a path to the model directory
109-
- a path to the model file
110-
- Hugging face model id
111-
112-
:return: the resolved recipe
113-
"""
114-
115-
if recipe is None:
116-
return infer_recipe_from_model_path(model_path)
117-
118-
elif os.path.isfile(recipe):
119-
# recipe is a path to a recipe file
120-
return resolve_recipe_file(recipe, model_path)
121-
122-
elif os.path.isfile(os.path.join(model_path, recipe)):
123-
# recipe is a name of a recipe file
124-
recipe = os.path.join(model_path, recipe)
125-
return resolve_recipe_file(recipe, model_path)
126-
127-
elif isinstance(recipe, str):
128-
# recipe is a string containing the recipe
129-
logger.debug(
130-
"Applying the recipe string directly to the model, without "
131-
"checking for a potential existing recipe in the model_path."
132-
)
133-
return recipe
134-
135-
logger.info(
136-
"No recipe requested and no default recipe "
137-
f"found in {model_path}. Skipping recipe resolution."
138-
)
139-
return None
140-
141-
14280
def infer_recipe_from_model_path(model_path: Union[str, Path]) -> Optional[str]:
14381
"""
14482
Infer the recipe from the model_path.
145-
:param model_path: the path to the model to load.
146-
It can be one of the following:
83+
84+
:param model_path: The path to the model to load. It can be one of the following:
14785
- a path to the model directory
14886
- a path to the model file
149-
- Hugging face model id
150-
:return the path to the recipe file if found, None otherwise
87+
- Hugging face model ID
88+
:return: The path to the recipe file if found, None otherwise.
15189
"""
15290
model_path = model_path.as_posix() if isinstance(model_path, Path) else model_path
15391

15492
if os.path.isdir(model_path) or os.path.isfile(model_path):
155-
# model_path is a local path to the model directory or model file
156-
# attempting to find the recipe in the model_directory
93+
# Model path is a local path to the model directory or file
15794
model_path = (
15895
os.path.dirname(model_path) if os.path.isfile(model_path) else model_path
15996
)
16097
recipe = os.path.join(model_path, RECIPE_FILE_NAME)
98+
16199
if os.path.isfile(recipe):
162100
logger.info(f"Found recipe in the model_path: {recipe}")
163101
return recipe
164102
logger.debug(f"No recipe found in the model_path: {model_path}")
165103
return None
166104

167-
recipe = recipe_from_huggingface_model_id(model_path)[0]
105+
# If the model path is a Hugging Face model ID
106+
recipe = recipe_from_huggingface_model_id(model_path)
168107

169108
if recipe is None:
170109
logger.info("Failed to infer the recipe from the model_path")
110+
171111
return recipe
172112

173113

174114
def recipe_from_huggingface_model_id(
175115
model_path: str, recipe_file_name: str = RECIPE_FILE_NAME
176-
) -> Tuple[Optional[str], bool]:
116+
) -> Optional[str]:
177117
"""
178-
Attempts to download the recipe from the huggingface model id.
118+
Attempts to download the recipe from the Hugging Face model ID.
179119
180-
:param model_path: Assumed to be the huggingface model id.
181-
If it is not, this function will return None.
120+
:param model_path: Assumed to be the Hugging Face model ID.
182121
:param recipe_file_name: The name of the recipe file to download.
183-
Defaults to recipe_file_name.
184-
:return: tuple:
185-
- the path to the recipe file if found, None otherwise
186-
- True if model_path is a valid huggingface model id, False otherwise
122+
Defaults to RECIPE_FILE_NAME.
123+
:return: A tuple:
124+
- The path to the recipe file if found, None otherwise.
125+
- True if model_path is a valid Hugging Face model ID, False otherwise.
187126
"""
188-
model_id = os.path.join(HUGGINGFACE_CO_URL_HOME, model_path)
189-
request = requests.get(model_id)
190-
if not request.status_code == 200:
127+
model_id_url = os.path.join(HUGGINGFACE_CO_URL_HOME, model_path)
128+
request = requests.get(model_id_url)
129+
130+
if request.status_code != 200:
191131
logger.debug(
192-
"model_path is not a valid huggingface model id. "
193-
"Skipping recipe resolution."
132+
(
133+
"model_path is not a valid Hugging Face model ID. ",
134+
"Skipping recipe resolution.",
135+
)
194136
)
195-
return None, False
137+
return None
196138

197139
logger.info(
198-
"model_path is a huggingface model id. "
199-
"Attempting to download recipe from "
200-
f"{HUGGINGFACE_CO_URL_HOME}"
140+
(
141+
"model_path is a Hugging Face model ID. ",
142+
f"Attempting to download recipe from {HUGGINGFACE_CO_URL_HOME}",
143+
)
201144
)
145+
202146
try:
203147
recipe = hf_hub_download(repo_id=model_path, filename=recipe_file_name)
204-
logger.info(f"Found recipe: {recipe_file_name} for model id: {model_path}.")
148+
logger.info(f"Found recipe: {recipe_file_name} for model ID: {model_path}.")
205149
except Exception as e:
206-
logger.info(
207-
f"Unable to to find recipe {recipe_file_name} "
208-
f"for model id: {model_path}: {e}. "
209-
"Skipping recipe resolution."
210-
)
211-
recipe = None
212-
return recipe, True
213-
214-
215-
def resolve_recipe_file(
216-
requested_recipe: Union[str, Path], model_path: Union[str, Path]
217-
) -> Union[str, Path, None]:
218-
"""
219-
Given the requested recipe and the model_path, return the path to the recipe file.
220-
221-
:param requested_recipe. Is a full path to the recipe file
222-
:param model_path: the path to the model to load.
223-
It can be one of the following:
224-
- a path to the model directory
225-
- a path to the model file
226-
- Hugging face model id
227-
:return the path to the recipe file if found, None otherwise
228-
"""
229-
# preprocess arguments so that they are all strings
230-
requested_recipe = (
231-
requested_recipe.as_posix()
232-
if isinstance(requested_recipe, Path)
233-
else requested_recipe
234-
)
235-
model_path = model_path.as_posix() if isinstance(model_path, Path) else model_path
236-
model_path = (
237-
os.path.dirname(model_path) if os.path.isfile(model_path) else model_path
238-
)
239-
240-
if not os.path.isdir(model_path):
241-
default_recipe, model_exists = recipe_from_huggingface_model_id(model_path)
242-
if not model_exists:
243-
raise ValueError(f"Unrecognized model_path: {model_path}")
244-
245-
if not default_recipe == requested_recipe and default_recipe is not None:
246-
logger.warning(
247-
f"Attempting to apply recipe: {requested_recipe} "
248-
f"to the model at: {model_path}, "
249-
f"but the model already has a recipe: {default_recipe}. "
250-
f"Using {requested_recipe} instead."
150+
logger.error(
151+
(
152+
f"Unable to find recipe {recipe_file_name} "
153+
f"for model ID: {model_path}: {e}."
154+
"Skipping recipe resolution."
251155
)
252-
return requested_recipe
253-
254-
# pathway for model_path that is a directory
255-
default_recipe = os.path.join(model_path, RECIPE_FILE_NAME)
256-
default_recipe_exists = os.path.isfile(default_recipe)
257-
default_and_request_recipes_identical = os.path.samefile(
258-
default_recipe, requested_recipe
259-
)
260-
261-
if (
262-
default_recipe_exists
263-
and requested_recipe
264-
and not default_and_request_recipes_identical
265-
):
266-
logger.warning(
267-
f"Attempting to apply recipe: {requested_recipe} "
268-
f"to the model located in {model_path}, "
269-
f"but the model already has a recipe stored as {default_recipe}. "
270-
f"Using {requested_recipe} instead."
271-
)
272-
273-
elif not default_recipe_exists and requested_recipe:
274-
logger.warning(
275-
f"Attempting to apply {requested_recipe} "
276-
f"to the model located in {model_path}."
277-
"However, it is expected that the model "
278-
f"has its target recipe stored as {default_recipe}."
279-
"Applying any recipe before the target recipe may "
280-
"result in unexpected behavior."
281-
f"Applying {requested_recipe} nevertheless."
282156
)
157+
recipe = None
283158

284-
elif default_recipe_exists:
285-
logger.info(f"Using the default recipe: {requested_recipe}")
286-
287-
return requested_recipe
159+
return recipe

tests/llmcompressor/transformers/obcq/test_consecutive_runs.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from transformers import AutoModelForCausalLM
99
from transformers.utils.quantization_config import CompressedTensorsConfig
1010

11-
from llmcompressor.transformers.utils.helpers import resolve_recipe
11+
from llmcompressor.transformers.utils.helpers import infer_recipe_from_model_path
1212
from tests.testing_utils import parse_params, requires_gpu
1313

1414
CONFIGS_DIRECTORY = "tests/llmcompressor/transformers/obcq/obcq_configs/consec_runs"
@@ -60,7 +60,7 @@ def _test_consecutive_runs(
6060
self.assertEqual(len(stages), 1)
6161
session.reset()
6262

63-
recipe = resolve_recipe(recipe=self.first_recipe, model_path=self.output_first)
63+
recipe = infer_recipe_from_model_path(model_path=self.output_first)
6464
if recipe:
6565
initialize_recipe(model=first_model, recipe_path=recipe)
6666

0 commit comments

Comments
 (0)