21 commits
47fd13e
Reimplement average_checkpoints which uses average_state_dicts now
BloodAxe Feb 6, 2025
e67ffa5
convert_2d_to_3d
BloodAxe Feb 6, 2025
28a3791
convert_2d_to_3d
BloodAxe Feb 6, 2025
07d4100
Added missing import
BloodAxe Feb 6, 2025
513c6b2
A transfer_weights method now returns TransferWeightsOuptut which can…
BloodAxe Mar 1, 2025
b427f76
Merge remote-tracking branch 'origin/develop' into develop
BloodAxe Mar 1, 2025
47a1273
vstack_header now can wrap text into multiple lines
BloodAxe Mar 1, 2025
531822b
Replace deprecated torch.cuda.amp.autocast with torch.amp.autocast
BloodAxe Mar 1, 2025
56e2ef0
Bump version
BloodAxe Mar 1, 2025
4b4ad80
Bump black & reformat
BloodAxe Mar 1, 2025
710ef17
Remove activation replacement
BloodAxe Mar 1, 2025
3c21031
Fix flake8
BloodAxe Mar 1, 2025
cb46d42
Fix wrong argument name
BloodAxe Jun 9, 2025
070f2d5
master_node_first
BloodAxe Jun 22, 2025
316360f
Add keepdim parameter to geometric_mean and harmonic_mean functions
BloodAxe Jun 22, 2025
06d4de5
Add contextmanager import to distributed.py
BloodAxe Jun 22, 2025
ec97142
Rename wait_for_the_master to master_node_first in distributed.py
BloodAxe Jun 22, 2025
89b1dd4
Add device_id parameter to init_process_group for CUDA device configu…
BloodAxe Sep 12, 2025
3336e88
Merge remote-tracking branch 'origin/develop' into develop
BloodAxe Sep 12, 2025
71ea86b
change the use of np.ndarray to np.array
liangosc Oct 9, 2025
91dc0fd
Merge pull request #109 from liangosc/ndarray-should-be-array
BloodAxe Oct 9, 2025
2 changes: 1 addition & 1 deletion pytorch_toolbelt/__init__.py
@@ -1,3 +1,3 @@
from __future__ import absolute_import

__version__ = "0.8.0"
__version__ = "0.8.1"
131 changes: 85 additions & 46 deletions pytorch_toolbelt/inference/ensembling.py
@@ -2,9 +2,17 @@

import torch
from torch import nn, Tensor
from typing import List, Union, Iterable, Optional, Dict, Tuple

__all__ = ["ApplySoftmaxTo", "ApplySigmoidTo", "Ensembler", "PickModelOutput", "SelectByIndex", "average_checkpoints"]
from typing import List, Union, Iterable, Optional, Dict, Tuple, Mapping

__all__ = [
"ApplySoftmaxTo",
"ApplySigmoidTo",
"Ensembler",
"PickModelOutput",
"SelectByIndex",
"average_checkpoints",
"average_state_dicts",
]

from pytorch_toolbelt.inference.tta import _deaugment_averaging

@@ -163,53 +171,84 @@ def forward(self, outputs: Dict[str, Tensor]) -> Tensor:
return outputs[self.target_key]


def average_checkpoints(inputs: List[str]) -> collections.OrderedDict:
"""Loads checkpoints from inputs and returns a model with averaged weights. Original implementation taken from:
https://github.com/pytorch/fairseq/blob/a48f235636557b8d3bc4922a6fa90f3a0fa57955/scripts/average_checkpoints.py#L16
def average_state_dicts(state_dicts: List[Mapping[str, Tensor]]) -> Mapping[str, Tensor]:
"""
Averages multiple 'state_dict'

"""

keys = state_dicts[0].keys()
final_state_dict = collections.OrderedDict()

for key in keys:
# Collect the values (tensors) for this key from all checkpoints
values = [sd[key] for sd in state_dicts]

# Use the first tensor as the reference for the shape and dtype checks below (assuming all dtypes match)
first_val = values[0]

if not all(v.shape == first_val.shape for v in values):
raise ValueError(f"Tensor shapes for key '{key}' are not consistent across checkpoints.")

if first_val.dtype == torch.bool:
# For bool, ensure all are identical
for val in values[1:]:
if not torch.equal(val, first_val):
raise ValueError(f"Boolean values for key '{key}' differ between checkpoints.")
final_state_dict[key] = first_val # Use the first if all identical

elif torch.is_floating_point(first_val):
# Average float values
stacked = torch.stack(values, dim=0)
target_dtype = stacked.dtype
accum_dtype = torch.promote_types(target_dtype, torch.float32) # Upcast to float32 if needed
averaged = stacked.to(accum_dtype).mean(dim=0).to(target_dtype)
final_state_dict[key] = averaged

elif first_val.dtype in {
torch.int8,
torch.int16,
torch.int32,
torch.int64,
torch.uint8,
torch.uint16,
torch.uint32,
torch.uint64,
}:
# Average integer values (using integer division)
stacked = torch.stack(values, dim=0)
summed = stacked.sum(dim=0, dtype=torch.int64)
averaged = summed // len(values)
final_state_dict[key] = averaged.to(first_val.dtype)

else:
# If you have other special dtypes to handle, add logic here
# or simply copy the first value if that is your intended behavior.
raise TypeError(f"Unsupported dtype '{first_val.dtype}' encountered for key '{key}'.")

return final_state_dict


def average_checkpoints(inputs: List[str], key=None, map_location="cpu", weights_only=True) -> collections.OrderedDict:
"""Loads checkpoints from inputs and returns a model with averaged weights.

Args:
inputs (List[str]): An iterable of string paths of checkpoints to load from.
key (str): An optional key to select a sub-dictionary from the checkpoint.
map_location (str): A string describing how to remap storage locations when loading the model.
weights_only (bool): If True, will only load the weights of the model.

Returns:
A dict of string keys mapping to various values. The 'model' key
from the returned dict should correspond to an OrderedDict mapping
string parameter names to torch Tensors.
"""
params_dict = collections.OrderedDict()
params_keys = None
new_state = None
num_models = len(inputs)
for fpath in inputs:
with open(fpath, "rb") as f:
state = torch.load(
f,
map_location="cpu",
)
# Copies over the settings from the first checkpoint
if new_state is None:
new_state = state
model_params = state["model_state_dict"]
model_params_keys = list(model_params.keys())
if params_keys is None:
params_keys = model_params_keys
elif params_keys != model_params_keys:
raise KeyError(
"For checkpoint {}, expected list of params: {}, "
"but found: {}".format(f, params_keys, model_params_keys)
)
for k in params_keys:
p = model_params[k]
if isinstance(p, torch.HalfTensor):
p = p.float()
if k not in params_dict:
params_dict[k] = p.clone()
# NOTE: clone() is needed in case of p is a shared parameter
else:
params_dict[k] += p
averaged_params = collections.OrderedDict()
for k, v in params_dict.items():
averaged_params[k] = v
if averaged_params[k].is_floating_point():
averaged_params[k].div_(num_models)
else:
averaged_params[k] //= num_models
new_state["model_state_dict"] = averaged_params
return new_state
state_dicts = [torch.load(path, map_location=map_location, weights_only=weights_only) for path in inputs]
if key is not None:
state_dicts = [sd[key] for sd in state_dicts]

avg_state_dict = average_state_dicts(state_dicts)
if key is not None:
avg_state_dict = {key: avg_state_dict}

return avg_state_dict
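
A minimal usage sketch of the two helpers above. The file paths and the "model_state_dict" key are illustrative (the key mirrors the layout the previous implementation hard-coded); pass key=None when the checkpoints are plain state dicts saved with torch.save(model.state_dict(), path).

import torch
from pytorch_toolbelt.inference.ensembling import average_checkpoints, average_state_dicts

paths = ["fold0.pth", "fold1.pth", "fold2.pth"]

# Average whole checkpoint files, selecting the weights sub-dictionary by key.
averaged = average_checkpoints(paths, key="model_state_dict", map_location="cpu", weights_only=True)
state_dict = averaged["model_state_dict"]  # unwrap; ready for model.load_state_dict(state_dict)

# Or average state dicts that are already loaded in memory.
state_dicts = [torch.load(p, map_location="cpu", weights_only=True)["model_state_dict"] for p in paths]
state_dict = average_state_dicts(state_dicts)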
8 changes: 4 additions & 4 deletions pytorch_toolbelt/inference/functional.py
@@ -247,7 +247,7 @@ def unpad_xyxy_bboxes(bboxes_tensor: torch.Tensor, pad, dim=-1):
return bboxes_tensor - pad


def geometric_mean(x: Tensor, dim: int) -> Tensor:
def geometric_mean(x: Tensor, dim: int, keepdim=False) -> Tensor:
"""
Compute geometric mean along given dimension.
This implementation assumes values are in the range (0...1) (probabilities)
@@ -258,10 +258,10 @@ def geometric_mean(x: Tensor, dim: int) -> Tensor:
Returns:
Tensor
"""
return x.log().mean(dim=dim).exp()
return x.log().mean(dim=dim, keepdim=keepdim).exp()


def harmonic_mean(x: Tensor, dim: int, eps: float = 1e-6) -> Tensor:
def harmonic_mean(x: Tensor, dim: int, eps: float = 1e-6, keepdim=False) -> Tensor:
"""
Compute harmonic mean along given dimension.

@@ -273,7 +273,7 @@ def harmonic_mean(x: Tensor, dim: int, eps: float = 1e-6) -> Tensor:
Tensor
"""
x = torch.reciprocal(x.clamp_min(eps))
x = torch.mean(x, dim=dim)
x = torch.mean(x, dim=dim, keepdim=keepdim)
x = torch.reciprocal(x.clamp_min(eps))
return x
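
A short sketch of the new keepdim flag: averaging the per-model probabilities of a hypothetical three-model ensemble while keeping the reduced dimension, so the result still broadcasts against the stacked input. Shapes are illustrative.

import torch
from pytorch_toolbelt.inference.functional import geometric_mean, harmonic_mean

probs = torch.rand(3, 4, 10).clamp_min(1e-6)  # (num_models, batch, num_classes), probabilities in (0, 1)

gm = geometric_mean(probs, dim=0, keepdim=True)  # shape (1, 4, 10)
hm = harmonic_mean(probs, dim=0, keepdim=True)   # shape (1, 4, 10)
assert gm.shape == hm.shape == (1, 4, 10)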

88 changes: 44 additions & 44 deletions pytorch_toolbelt/losses/functional.py
@@ -16,7 +16,6 @@
]


@torch.cuda.amp.autocast(False)
def focal_loss_with_logits(
output: torch.Tensor,
target: torch.Tensor,
@@ -58,51 +57,52 @@ def focal_loss_with_logits(
output = output.float()
target = target.float()

if activation == "sigmoid":
p = torch.sigmoid(output)
else:
p = torch.softmax(output, dim=softmax_dim)

ce_loss = F.binary_cross_entropy_with_logits(output, target, reduction="none")
pt = p * target + (1 - p) * (1 - target)

# compute the loss
if reduced_threshold is None:
focal_term = (1.0 - pt).pow(gamma)
else:
focal_term = ((1.0 - pt) / (1 - reduced_threshold)).pow(
gamma
) # the focal term continuity breaks when reduced_threshold not equal to 0.5. At pt equal to reduced_threshold, the value of piecewise function of focal term should be 1 from both sides .
focal_term = torch.masked_fill(focal_term, pt < reduced_threshold, 1)

loss = focal_term * ce_loss

if alpha is not None:
loss *= alpha * target + (1 - alpha) * (1 - target)

if class_weights is not None:
# class_weights is of shape [C]
# Loss is of shape [B,C ...]
# Reshape class_weights to [1, C, ...]
class_weights = class_weights.view(1, -1, *(1 for _ in range(loss.dim() - 2)))
loss *= class_weights
with torch.amp.autocast(device_type=output.device.type, enabled=False):
if activation == "sigmoid":
p = torch.sigmoid(output)
else:
p = torch.softmax(output, dim=softmax_dim)

ce_loss = F.binary_cross_entropy_with_logits(output, target, reduction="none")
pt = p * target + (1 - p) * (1 - target)

# compute the loss
if reduced_threshold is None:
focal_term = (1.0 - pt).pow(gamma)
else:
focal_term = ((1.0 - pt) / (1 - reduced_threshold)).pow(
gamma
) # The focal term's continuity breaks when reduced_threshold != 0.5: at pt == reduced_threshold, the piecewise focal term should equal 1 from both sides.
focal_term = torch.masked_fill(focal_term, pt < reduced_threshold, 1)

loss = focal_term * ce_loss

if alpha is not None:
loss *= alpha * target + (1 - alpha) * (1 - target)

if class_weights is not None:
# class_weights is of shape [C]
# Loss is of shape [B,C ...]
# Reshape class_weights to [1, C, ...]
class_weights = class_weights.view(1, -1, *(1 for _ in range(loss.dim() - 2)))
loss *= class_weights

if ignore_index is not None:
ignore_mask = target.eq(ignore_index)
loss = torch.masked_fill(loss, ignore_mask, 0)
if normalized:
focal_term = torch.masked_fill(focal_term, ignore_mask, 0)

if ignore_index is not None:
ignore_mask = target.eq(ignore_index)
loss = torch.masked_fill(loss, ignore_mask, 0)
if normalized:
focal_term = torch.masked_fill(focal_term, ignore_mask, 0)

if normalized:
norm_factor = focal_term.sum(dtype=torch.float32).clamp_min(eps)
loss /= norm_factor

if reduction == "mean":
loss = loss.mean()
if reduction == "sum":
loss = loss.sum()
if reduction == "batchwise_mean":
loss = loss.sum(dim=0)
norm_factor = focal_term.sum(dtype=torch.float32).clamp_min(eps)
loss /= norm_factor

if reduction == "mean":
loss = loss.mean()
if reduction == "sum":
loss = loss.sum()
if reduction == "batchwise_mean":
loss = loss.sum(dim=0)

return loss
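
A hedged sketch of what the autocast change above means for callers: even when the surrounding forward pass runs under mixed precision, the loss body re-enters autocast with enabled=False and is computed in float32. Device, shapes, and reliance on default arguments below are illustrative.

import torch
from pytorch_toolbelt.losses.functional import focal_loss_with_logits

device = "cuda" if torch.cuda.is_available() else "cpu"
logits = torch.randn(8, 3, 64, 64, device=device)
targets = torch.randint(0, 2, (8, 3, 64, 64), device=device).float()

with torch.amp.autocast(device_type=device, enabled=(device == "cuda")):
    # Ops in this region may run in reduced precision, but focal_loss_with_logits
    # disables autocast internally, so the loss math stays in float32.
    loss = focal_loss_with_logits(logits, targets)

print(loss.dtype)  # torch.float32 even when called under autocast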

22 changes: 11 additions & 11 deletions pytorch_toolbelt/losses/quality_focal_loss.py
@@ -20,7 +20,6 @@ def __init__(self, beta: float = 2, reduction="mean"):
self.beta = beta
self.reduction = reduction

@torch.cuda.amp.autocast(False)
def forward(self, predictions: Tensor, targets: Tensor) -> Tensor:
"""
Compute quality focal loss
@@ -32,15 +31,16 @@ def forward(self, predictions: Tensor, targets: Tensor) -> Tensor:
predictions = predictions.float()
targets = targets.float()

bce = torch.nn.functional.binary_cross_entropy_with_logits(predictions, targets, reduction="none")
focal_term = torch.nn.functional.l1_loss(predictions.sigmoid(), targets, reduction="none").pow_(self.beta)
loss = focal_term * bce

if self.reduction == "mean":
return loss.mean()
if self.reduction == "sum":
return loss.sum()
if self.reduction == "normalized":
return loss.sum() / focal_term.sum()
with torch.amp.autocast(device_type=predictions.device.type, enabled=False):
bce = torch.nn.functional.binary_cross_entropy_with_logits(predictions, targets, reduction="none")
focal_term = torch.nn.functional.l1_loss(predictions.sigmoid(), targets, reduction="none").pow_(self.beta)
loss = focal_term * bce

if self.reduction == "mean":
return loss.mean()
if self.reduction == "sum":
return loss.sum()
if self.reduction == "normalized":
return loss.sum() / focal_term.sum()

return loss
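
A usage sketch, assuming the loss module shown above is exported as QualityFocalLoss (only its __init__ and forward bodies are visible in this diff). Shapes and values are illustrative; targets are soft quality scores in [0, 1].

import torch
from pytorch_toolbelt.losses.quality_focal_loss import QualityFocalLoss

criterion = QualityFocalLoss(beta=2, reduction="mean")

predictions = torch.randn(4, 10)  # raw logits
targets = torch.rand(4, 10)       # soft quality targets in [0, 1]

loss = criterion(predictions, targets)  # computed in float32 even under torch.amp.autocast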
14 changes: 4 additions & 10 deletions pytorch_toolbelt/modules/interfaces.py
@@ -4,6 +4,7 @@

import numpy as np
import torch.jit
from torch import nn, Tensor

__all__ = [
"FeatureMapsSpecification",
@@ -13,10 +14,6 @@
"AbstractEncoder",
]

from torch import nn, Tensor

from pytorch_toolbelt.utils import pytorch_toolbelt_deprecated


@dataclasses.dataclass
class FeatureMapsSpecification:
@@ -61,8 +58,7 @@ class HasInputFeaturesSpecification(Protocol):
"""

@torch.jit.unused
def get_input_spec(self) -> FeatureMapsSpecification:
...
def get_input_spec(self) -> FeatureMapsSpecification: ...


class HasOutputFeaturesSpecification(Protocol):
@@ -71,8 +67,7 @@ class HasOutputFeaturesSpecification(Protocol):
"""

@torch.jit.unused
def get_output_spec(self) -> FeatureMapsSpecification:
...
def get_output_spec(self) -> FeatureMapsSpecification: ...


class AbstractEncoder(nn.Module, HasOutputFeaturesSpecification):
@@ -108,8 +103,7 @@ def __init__(self, input_spec: FeatureMapsSpecification):
@abstractmethod
def forward(
self, feature_maps: List[Tensor], output_size: Union[Tuple[int, int], torch.Size, None] = None
) -> Union[Tensor, Tuple[Tensor, ...], List[Tensor], Mapping[str, Tensor]]:
...
) -> Union[Tensor, Tuple[Tensor, ...], List[Tensor], Mapping[str, Tensor]]: ...

@torch.jit.unused
def apply_to_final_layer(self, func: Callable[[nn.Module], None]):
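
A hedged sketch of how the protocols above are meant to be used: because HasOutputFeaturesSpecification is a typing Protocol, any nn.Module that exposes get_output_spec() satisfies it structurally, without inheriting from AbstractEncoder. DummyBackbone and describe are hypothetical names, and the FeatureMapsSpecification instance is taken as a constructor argument because its fields are not shown in this diff.

from typing import List

import torch
from torch import nn, Tensor

from pytorch_toolbelt.modules.interfaces import FeatureMapsSpecification, HasOutputFeaturesSpecification


class DummyBackbone(nn.Module):
    """Hypothetical module that satisfies HasOutputFeaturesSpecification structurally."""

    def __init__(self, output_spec: FeatureMapsSpecification):
        super().__init__()
        self._output_spec = output_spec

    @torch.jit.unused
    def get_output_spec(self) -> FeatureMapsSpecification:
        return self._output_spec

    def forward(self, x: Tensor) -> List[Tensor]:
        # Trivially expose the input as a single "feature map".
        return [x]


def describe(module: HasOutputFeaturesSpecification) -> FeatureMapsSpecification:
    # Protocols are checked structurally, so DummyBackbone is accepted here.
    return module.get_output_spec()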