Skip to content

Commit

Permalink
chore: removed print statement for logging
Browse files Browse the repository at this point in the history
  • Loading branch information
davidberenstein1957 committed Jun 19, 2023
1 parent 3353db7 commit a1dc4e7
Showing 1 changed file with 33 additions and 13 deletions.
46 changes: 33 additions & 13 deletions spacy_setfit/models.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
import types

from setfit import SetFitModel, SetFitTrainer
Expand All @@ -7,9 +8,11 @@

from spacy_setfit.schemas import SetFitTrainerArgs

__LOGGER__ = logging.getLogger(__name__)


class SpacySetFit:
def __init__(self, nlp: Language, model: SetFitModel, labels = None):
def __init__(self, nlp: Language, model: SetFitModel, labels=None):
self.nlp = nlp
self.model = model
self.multi_label = self._check_multi_label(model)
Expand All @@ -27,7 +30,9 @@ def _check_multi_label(model: SetFitModel):
return False

@staticmethod
def _from_pretrained(pretrained_model_name_or_path: str, setfit_from_pretrained_args: dict = None):
def _from_pretrained(
pretrained_model_name_or_path: str, setfit_from_pretrained_args: dict = None
):
if setfit_from_pretrained_args is None:
setfit_from_pretrained_args = {}
model = SetFitModel.from_pretrained(
Expand All @@ -36,32 +41,47 @@ def _from_pretrained(pretrained_model_name_or_path: str, setfit_from_pretrained_
return model

@classmethod
def from_pretrained(cls, nlp: Language, pretrained_model_name_or_path: str, setfit_from_pretrained_args: dict = None):
model = cls._from_pretrained(pretrained_model_name_or_path, setfit_from_pretrained_args)
def from_pretrained(
cls,
nlp: Language,
pretrained_model_name_or_path: str,
setfit_from_pretrained_args: dict = None,
):
model = cls._from_pretrained(
pretrained_model_name_or_path, setfit_from_pretrained_args
)
return cls(nlp, model)

@classmethod
def from_trained(cls, nlp: Language, pretrained_model_name_or_path: str, setfit_trainer_args: SetFitTrainerArgs, setfit_from_pretrained_args: dict = None):

setfit_from_pretrained_args["multi_target_strategy"] = setfit_from_pretrained_args.get("multi_target_strategy")
def from_trained(
cls,
nlp: Language,
pretrained_model_name_or_path: str,
setfit_trainer_args: SetFitTrainerArgs,
setfit_from_pretrained_args: dict = None,
):
setfit_from_pretrained_args[
"multi_target_strategy"
] = setfit_from_pretrained_args.get("multi_target_strategy")
if setfit_trainer_args.multi_label:
setfit_from_pretrained_args["multi_target_strategy"] = "one-vs-rest"

model = cls._from_pretrained(pretrained_model_name_or_path, setfit_from_pretrained_args)
trainer = SetFitTrainer(
model=model,
**setfit_trainer_args.dict()
model = cls._from_pretrained(
pretrained_model_name_or_path, setfit_from_pretrained_args
)
trainer = SetFitTrainer(model=model, **setfit_trainer_args.dict())
trainer.train()
if setfit_trainer_args.eval_dataset:
evaluation = trainer.evaluate()
print(evaluation)
__LOGGER__.info(f"evaluation: {evaluation}")

return cls(nlp, model, labels=setfit_trainer_args.labels)

def _assign_labels(self, doc, prediction: list):
if self.id2label:
doc.cats = {self.id2label[idx]: float(score) for idx, score in enumerate(prediction)}
doc.cats = {
self.id2label[idx]: float(score) for idx, score in enumerate(prediction)
}
else:
doc.cats = {idx: float(score) for idx, score in enumerate(prediction)}
return doc
Expand Down

0 comments on commit a1dc4e7

Please sign in to comment.