Skip to content

Commit 0a4ce53

Browse files
authored
use inference-runner corresponding to reference model (#346)
1 parent 1c5164f commit 0a4ce53

File tree

5 files changed, with 22 additions and 22 deletions.

fast_llm/engine/multi_stage/config.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -31,6 +31,7 @@
3131

3232
if typing.TYPE_CHECKING:
3333
from fast_llm.engine.inference.huggingface import HuggingfaceBaseModelForCausalLM
34+
from fast_llm.engine.inference.runner import InferenceRunner
3435
from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel
3536

3637
logger = logging.getLogger(__name__)
@@ -241,6 +242,10 @@ def get_checkpoint_handler_class(cls, format: type[CheckpointFormat] | str) -> t
241242
def get_model_class(cls) -> type["FastLLMModel"]:
242243
raise NotImplementedError
243244

245+
@classmethod
246+
def get_inference_runner_class(cls) -> type["InferenceRunner"]:
247+
raise NotImplementedError
248+
244249
@classmethod
245250
def get_huggingface_model_for_causal_lm_class(cls) -> type["HuggingfaceBaseModelForCausalLM"]:
246251
raise NotImplementedError

fast_llm/engine/training/config.py

Lines changed: 0 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -32,7 +32,6 @@
3232
from fast_llm.utils import Assert
3333

3434
if typing.TYPE_CHECKING:
35-
from fast_llm.engine.inference.runner import InferenceRunner
3635
from fast_llm.engine.training.trainer import Trainer, TrainingEvaluator
3736

3837

@@ -403,10 +402,6 @@ def _setup(self):
403402
def get_trainer_class(cls) -> type["Trainer"]:
404403
raise NotImplementedError
405404

406-
@classmethod
407-
def get_inference_runner_class(cls) -> type["InferenceRunner"]:
408-
raise NotImplementedError
409-
410405
def _get_runnable(self) -> typing.Callable[[], None]:
411406
from fast_llm.engine.distributed.distributed import Distributed
412407

fast_llm/engine/training/trainer.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -142,7 +142,7 @@ def __init__(self, config: TrainerConfig):
142142
self._reference_models = {}
143143
for name, reference_config in self._config.reference_models.items():
144144
log_main_rank(f"Creating `{name} reference model...")
145-
self._reference_models[name] = self._config.get_inference_runner_class()(
145+
self._reference_models[name] = reference_config.model.get_inference_runner_class()(
146146
reference_config.model.get_model_class()(reference_config.model)
147147
)
148148
self._multi_stage.base_model.add_reference_model(name, self._reference_models[name])

fast_llm/models/gpt/config.py

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -171,6 +171,12 @@ def get_model_class(cls) -> type["GPTModel"]:
171171

172172
return GPTModel
173173

174+
@classmethod
175+
def get_inference_runner_class(cls) -> type["GPTInferenceRunner"]:
176+
from fast_llm.models.gpt.model import GPTInferenceRunner
177+
178+
return GPTInferenceRunner
179+
174180
@classmethod
175181
def get_huggingface_model_for_causal_lm_class(cls) -> type["HuggingfaceGPTModelForCausalLM"]:
176182
from fast_llm.models.gpt.huggingface import HuggingfaceGPTModelForCausalLM
@@ -254,9 +260,3 @@ def get_trainer_class(cls) -> type["GPTTrainer"]:
254260
from fast_llm.models.gpt.trainer import GPTTrainer
255261

256262
return GPTTrainer
257-
258-
@classmethod
259-
def get_inference_runner_class(cls) -> type["GPTInferenceRunner"]:
260-
from fast_llm.models.gpt.model import GPTInferenceRunner
261-
262-
return GPTInferenceRunner

fast_llm/models/ssm/config.py

Lines changed: 10 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -173,6 +173,16 @@ def get_model_class(cls) -> type["HybridSSMModel"]:
173173

174174
return HybridSSMModel
175175

176+
@classmethod
177+
def get_inference_runner_class(cls) -> type["HybridSSMInferenceRunner"]:
178+
from fast_llm.models.ssm.model import HybridSSMInferenceRunner
179+
180+
logger.warning(
181+
"HybridSSMInferenceRunner only supports training-style forward pass. Use generate with cache disabled."
182+
)
183+
184+
return HybridSSMInferenceRunner
185+
176186
@classmethod
177187
def get_huggingface_model_for_causal_lm_class(cls) -> type["HuggingfaceHybridSSMModelForCausalLM"]:
178188
from fast_llm.models.ssm.huggingface import HuggingfaceHybridSSMModelForCausalLM
@@ -227,13 +237,3 @@ def _validate(self) -> None:
227237
Assert.none(reference_model.model.base_model.cross_entropy_splits)
228238
Assert.eq(reference_model.model.base_model.parallel_embeddings, self.model.base_model.parallel_embeddings)
229239
Assert.geq(reference_model.model.base_model.prediction_heads, self.model.base_model.prediction_heads)
230-
231-
@classmethod
232-
def get_inference_runner_class(cls) -> type["HybridSSMInferenceRunner"]:
233-
from fast_llm.models.ssm.model import HybridSSMInferenceRunner
234-
235-
logger.warning(
236-
"HybridSSMInferenceRunner only supports training-style forward pass. Use generate with cache disabled."
237-
)
238-
239-
return HybridSSMInferenceRunner

0 commit comments

Comments
 (0)