15 changes: 9 additions & 6 deletions model2vec/train/base.py
@@ -16,14 +16,17 @@


class FinetunableStaticModel(nn.Module):
def __init__(self, *, vectors: torch.Tensor, tokenizer: Tokenizer, out_dim: int = 2, pad_id: int = 0) -> None:
def __init__(
self, *, vectors: torch.Tensor, tokenizer: Tokenizer, out_dim: int = 2, pad_id: int = 0, freeze: bool = False
) -> None:
"""
Initialize a trainable StaticModel from a StaticModel.

:param vectors: The embeddings of the staticmodel.
:param tokenizer: The tokenizer.
:param out_dim: The output dimension of the head.
:param pad_id: The padding id. This is set to 0 in almost all model2vec models
:param freeze: Whether to freeze the embeddings. This should be set to False in most cases.
"""
super().__init__()
self.pad_id = pad_id
@@ -37,8 +40,8 @@ def __init__(self, *, vectors: torch.Tensor, tokenizer: Tokenizer, out_dim: int
f"Your vectors are {dtype} precision, converting to to torch.float32 to avoid compatibility issues."
)
self.vectors = vectors.float()

self.embeddings = nn.Embedding.from_pretrained(vectors.clone(), freeze=False, padding_idx=pad_id)
self.freeze = freeze
self.embeddings = nn.Embedding.from_pretrained(vectors.clone(), freeze=self.freeze, padding_idx=pad_id)
self.head = self.construct_head()
self.w = self.construct_weights()
self.tokenizer = tokenizer
@@ -47,7 +50,7 @@ def construct_weights(self) -> nn.Parameter:
"""Construct the weights for the model."""
weights = torch.zeros(len(self.vectors))
weights[self.pad_id] = -10_000
return nn.Parameter(weights)
return nn.Parameter(weights, requires_grad=not self.freeze)

def construct_head(self) -> nn.Sequential:
"""Method should be overridden for various other classes."""
@@ -118,7 +121,7 @@ def tokenize(self, texts: list[str], max_length: int | None = 512) -> torch.Tens
return pad_sequence(encoded_ids, batch_first=True, padding_value=self.pad_id)

@property
def device(self) -> str:
def device(self) -> torch.device:
"""Get the device of the model."""
return self.embeddings.weight.device

@@ -157,7 +160,7 @@ def collate_fn(batch: list[tuple[list[list[int]], int]]) -> tuple[torch.Tensor,
"""Collate function."""
texts, targets = zip(*batch)

tensors = [torch.LongTensor(x) for x in texts]
tensors: list[torch.Tensor] = [torch.LongTensor(x) for x in texts]
padded = pad_sequence(tensors, batch_first=True, padding_value=0)

return padded, torch.stack(targets)
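For context, here is a minimal sketch (plain PyTorch only; the vocabulary size and embedding dimension below are made up, not taken from this PR) of what the new `freeze` flag controls in `base.py`: with `freeze=True`, both the embedding table built via `nn.Embedding.from_pretrained` and the per-token weights built in `construct_weights` are excluded from gradient updates.

```python
# Illustrative sketch of the freeze behaviour added in this PR, using plain torch.
import torch
import torch.nn as nn

vectors = torch.randn(100, 32)  # stand-in for the static model's embedding matrix
pad_id = 0

for freeze in (False, True):
    # Mirrors the __init__ change: freeze is forwarded to from_pretrained.
    embeddings = nn.Embedding.from_pretrained(vectors.clone(), freeze=freeze, padding_idx=pad_id)

    # Mirrors construct_weights: the pad token gets a large negative weight,
    # and the parameter's requires_grad follows the freeze flag.
    weights = torch.zeros(len(vectors))
    weights[pad_id] = -10_000
    w = nn.Parameter(weights, requires_grad=not freeze)

    print(f"freeze={freeze}: embeddings trainable={embeddings.weight.requires_grad}, w trainable={w.requires_grad}")
```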
23 changes: 16 additions & 7 deletions model2vec/train/classifier.py
@@ -38,6 +38,7 @@ def __init__(
hidden_dim: int = 512,
out_dim: int = 2,
pad_id: int = 0,
freeze: bool = False,
) -> None:
"""Initialize a standard classifier model."""
self.n_layers = n_layers
@@ -46,7 +47,7 @@ def __init__(
self.classes_: list[str] = [str(x) for x in range(out_dim)]
# multilabel flag will be set based on the type of `y` passed to fit.
self.multilabel: bool = False
super().__init__(vectors=vectors, out_dim=out_dim, pad_id=pad_id, tokenizer=tokenizer)
super().__init__(vectors=vectors, out_dim=out_dim, pad_id=pad_id, tokenizer=tokenizer, freeze=freeze)

@property
def classes(self) -> np.ndarray:
@@ -124,7 +125,7 @@ def predict_proba(self, X: list[str], show_progress_bar: bool = False, batch_siz
pred.append(torch.softmax(logits, dim=1).cpu().numpy())
return np.concatenate(pred, axis=0)

def fit(
def fit( # noqa: C901 # Complexity is bad.
self,
X: list[str],
y: LabelType,
@@ -165,7 +166,7 @@ def fit(
:param device: The device to train on. If this is "auto", the device is chosen automatically.
:param X_val: The texts to be used for validation.
:param y_val: The labels to be used for validation.
:param class_weight: The weight of the classes. If None, all classes are weighted equally. Must
:param class_weight: The weight of the classes. If None, all classes are weighted equally. Must
have the same length as the number of classes.
:return: The fitted model.
:raises ValueError: If either X_val or y_val are provided, but not both.
@@ -201,7 +202,7 @@ def fit(
base_number = int(min(max(1, (len(train_texts) / 30) // 32), 16))
batch_size = int(base_number * 32)
logger.info("Batch size automatically set to %d.", batch_size)

if class_weight is not None:
if len(class_weight) != len(self.classes_):
raise ValueError("class_weight must have the same length as the number of classes.")
@@ -300,7 +301,9 @@ def _initialize(self, y: LabelType) -> None:
self.classes_ = classes
self.out_dim = len(self.classes_) # Update output dimension
self.head = self.construct_head()
self.embeddings = nn.Embedding.from_pretrained(self.vectors.clone(), freeze=False, padding_idx=self.pad_id)
self.embeddings = nn.Embedding.from_pretrained(
self.vectors.clone(), freeze=self.freeze, padding_idx=self.pad_id
)
self.w = self.construct_weights()
self.train()

@@ -383,12 +386,18 @@ def to_pipeline(self) -> StaticModelPipeline:


class _ClassifierLightningModule(pl.LightningModule):
def __init__(self, model: StaticModelForClassification, learning_rate: float, class_weight: torch.Tensor | None = None) -> None:
def __init__(
self, model: StaticModelForClassification, learning_rate: float, class_weight: torch.Tensor | None = None
) -> None:
"""Initialize the LightningModule."""
super().__init__()
self.model = model
self.learning_rate = learning_rate
self.loss_function = nn.CrossEntropyLoss(weight=class_weight) if not model.multilabel else nn.BCEWithLogitsLoss(pos_weight=class_weight)
self.loss_function = (
nn.CrossEntropyLoss(weight=class_weight)
if not model.multilabel
else nn.BCEWithLogitsLoss(pos_weight=class_weight)
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Simple forward pass."""
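A hypothetical usage sketch of the new options from the classifier side. It assumes `from_pretrained` forwards the `freeze` keyword to `__init__` and that `class_weight` accepts a sequence with one entry per class; the model name and training data are placeholders, not taken from this PR.

```python
from model2vec.train import StaticModelForClassification

# Placeholder model name; freeze=True keeps the static embeddings fixed so
# only the classifier head is updated during fit.
classifier = StaticModelForClassification.from_pretrained(
    model_name="minishlab/potion-base-8M",
    freeze=True,
)

# Toy data for illustration only.
X = ["great product", "terrible service", "works as expected", "broke after a day"]
y = ["pos", "neg", "pos", "neg"]

# class_weight must have the same length as the number of classes,
# matching the validation added in fit above.
classifier = classifier.fit(X, y, class_weight=[1.0, 2.0])
```

With the embeddings frozen, the model effectively becomes a probe (the MLP head) over fixed static vectors, which tends to train faster and is less prone to overfitting on small datasets.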