From a1dc4e7374ba18d18fac310e14f05a28dd104fc1 Mon Sep 17 00:00:00 2001 From: davidberenstein1957 Date: Mon, 19 Jun 2023 14:29:22 +0200 Subject: [PATCH] chore: removed print statement for logging --- spacy_setfit/models.py | 46 ++++++++++++++++++++++++++++++------------ 1 file changed, 33 insertions(+), 13 deletions(-) diff --git a/spacy_setfit/models.py b/spacy_setfit/models.py index 2a1484f..3e8168d 100644 --- a/spacy_setfit/models.py +++ b/spacy_setfit/models.py @@ -1,3 +1,4 @@ +import logging import types from setfit import SetFitModel, SetFitTrainer @@ -7,9 +8,11 @@ from spacy_setfit.schemas import SetFitTrainerArgs +__LOGGER__ = logging.getLogger(__name__) + class SpacySetFit: - def __init__(self, nlp: Language, model: SetFitModel, labels = None): + def __init__(self, nlp: Language, model: SetFitModel, labels=None): self.nlp = nlp self.model = model self.multi_label = self._check_multi_label(model) @@ -27,7 +30,9 @@ def _check_multi_label(model: SetFitModel): return False @staticmethod - def _from_pretrained(pretrained_model_name_or_path: str, setfit_from_pretrained_args: dict = None): + def _from_pretrained( + pretrained_model_name_or_path: str, setfit_from_pretrained_args: dict = None + ): if setfit_from_pretrained_args is None: setfit_from_pretrained_args = {} model = SetFitModel.from_pretrained( @@ -36,32 +41,47 @@ def _from_pretrained(pretrained_model_name_or_path: str, setfit_from_pretrained_ return model @classmethod - def from_pretrained(cls, nlp: Language, pretrained_model_name_or_path: str, setfit_from_pretrained_args: dict = None): - model = cls._from_pretrained(pretrained_model_name_or_path, setfit_from_pretrained_args) + def from_pretrained( + cls, + nlp: Language, + pretrained_model_name_or_path: str, + setfit_from_pretrained_args: dict = None, + ): + model = cls._from_pretrained( + pretrained_model_name_or_path, setfit_from_pretrained_args + ) return cls(nlp, model) @classmethod - def from_trained(cls, nlp: Language, pretrained_model_name_or_path: str, setfit_trainer_args: SetFitTrainerArgs, setfit_from_pretrained_args: dict = None): - - setfit_from_pretrained_args["multi_target_strategy"] = setfit_from_pretrained_args.get("multi_target_strategy") + def from_trained( + cls, + nlp: Language, + pretrained_model_name_or_path: str, + setfit_trainer_args: SetFitTrainerArgs, + setfit_from_pretrained_args: dict = None, + ): + setfit_from_pretrained_args[ + "multi_target_strategy" + ] = setfit_from_pretrained_args.get("multi_target_strategy") if setfit_trainer_args.multi_label: setfit_from_pretrained_args["multi_target_strategy"] = "one-vs-rest" - model = cls._from_pretrained(pretrained_model_name_or_path, setfit_from_pretrained_args) - trainer = SetFitTrainer( - model=model, - **setfit_trainer_args.dict() + model = cls._from_pretrained( + pretrained_model_name_or_path, setfit_from_pretrained_args ) + trainer = SetFitTrainer(model=model, **setfit_trainer_args.dict()) trainer.train() if setfit_trainer_args.eval_dataset: evaluation = trainer.evaluate() - print(evaluation) + __LOGGER__.info(f"evaluation: {evaluation}") return cls(nlp, model, labels=setfit_trainer_args.labels) def _assign_labels(self, doc, prediction: list): if self.id2label: - doc.cats = {self.id2label[idx]: float(score) for idx, score in enumerate(prediction)} + doc.cats = { + self.id2label[idx]: float(score) for idx, score in enumerate(prediction) + } else: doc.cats = {idx: float(score) for idx, score in enumerate(prediction)} return doc