Skip to content

Commit 18b132e

Browse files
committed
feat: add ood score
1 parent 71545ca commit 18b132e

File tree

11 files changed

+178
-21
lines changed

11 files changed

+178
-21
lines changed

trapdata/api/api.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
MothClassifierGlobal,
1919
MothClassifierPanama,
2020
MothClassifierPanama2024,
21+
MothClassifierPanamaPlus2025,
2122
MothClassifierQuebecVermont,
2223
MothClassifierTuringAnguilla,
2324
MothClassifierTuringCostaRica,
@@ -39,6 +40,7 @@
3940

4041

4142
CLASSIFIER_CHOICES = {
43+
"panama_plus_moths_2025": MothClassifierPanamaPlus2025,
4244
"panama_moths_2023": MothClassifierPanama,
4345
"panama_moths_2024": MothClassifierPanama2024,
4446
"quebec_vermont_moths_2023": MothClassifierQuebecVermont,

trapdata/api/models/classification.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
TuringAnguillaSpeciesClassifier,
1616
TuringCostaRicaSpeciesClassifier,
1717
UKDenmarkMothSpeciesClassifier2024,
18+
PanamaPlusWithOODClassifier2025,
1819
)
1920

2021
from ..datasets import ClassificationImageDataset
@@ -25,6 +26,7 @@
2526
SourceImage,
2627
)
2728
from .base import APIInferenceBaseClass
29+
from trapdata.ml.models.base import ClassifierResult
2830

2931

3032
class APIMothClassifier(
@@ -188,3 +190,47 @@ class MothClassifierTuringAnguilla(APIMothClassifier, TuringAnguillaSpeciesClass
188190

189191
class MothClassifierGlobal(APIMothClassifier, GlobalMothSpeciesClassifier):
190192
pass
193+
194+
195+
class MothClassifierPanamaPlus2025(APIMothClassifier, PanamaPlusWithOODClassifier2025):
196+
def post_process_batch(self, logits: torch.Tensor):
197+
"""
198+
Return the labels, softmax/calibrated scores, and the original logits for
199+
each image in the batch.
200+
201+
Almost like the base class method, but we need to return the logits as well.
202+
"""
203+
predictions = torch.nn.functional.softmax(logits, dim=1)
204+
predictions = predictions.cpu().numpy()
205+
206+
ood_scores = None
207+
if self.class_prior:
208+
_, ood_scores = torch.max(predictions - self.class_prior, dim=-1)
209+
else:
210+
_, ood_scores = torch.max(predictions, dim=-1)
211+
212+
batch_results = []
213+
for softmax_scores in predictions:
214+
# Get all class indices and their corresponding scores
215+
class_indices = np.arange(len(softmax_scores))
216+
labels = [self.category_map[i] for i in class_indices]
217+
218+
print("labels type", type(labels))
219+
print("logits type", type(logits))
220+
print("label type", type(softmax_scores))
221+
print("label type", type(ood_scores))
222+
223+
exit()
224+
225+
# TODO: Change batch_results
226+
result = ClassifierResult(
227+
labels=labels,
228+
logits=logits,
229+
softmax_scores=softmax_scores,
230+
ood_scores=ood_scores,
231+
)
232+
batch_results.append(result)
233+
234+
logger.debug(f"Post-processing result batch: {batch_results}")
235+
236+
return batch_results

trapdata/common/logs.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,16 @@
22

33
import structlog
44

5+
# structlog.configure(
6+
# wrapper_class=structlog.make_filtering_bound_logger(logging.INFO),
7+
# )
8+
59
structlog.configure(
6-
wrapper_class=structlog.make_filtering_bound_logger(logging.INFO),
10+
wrapper_class=structlog.make_filtering_bound_logger(logging.CRITICAL),
711
)
812

9-
1013
logger = structlog.get_logger()
14+
logging.disable(logging.CRITICAL)
1115

1216
# import logging
1317
# from rich.logging import RichHandler

trapdata/db/models/detections.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ class DetectionListItem(BaseModel):
2727
model_name: Optional[str]
2828
in_queue: bool
2929
notes: Optional[str]
30+
ood_score: Optional[str]
3031

3132
# PyDantic complains because we have an attribute called `model_name`
3233
model_config = ConfigDict(protected_namespaces=[]) # type:ignore
@@ -43,6 +44,7 @@ class DetectionDetail(DetectionListItem):
4344
timestamp: Optional[str]
4445
bbox_center: Optional[tuple[int, int]]
4546
area_pixels: Optional[int]
47+
ood_score: Optional[float]
4648

4749

4850
class DetectedObject(db.Base):
@@ -76,6 +78,7 @@ class DetectedObject(db.Base):
7678
sequence_previous_id = sa.Column(sa.Integer)
7779
sequence_previous_cost = sa.Column(sa.Float)
7880
cnn_features = sa.Column(sa.JSON)
81+
ood_score = sa.Column(sa.Float)
7982

8083
# @TODO add updated & created timestamps to all db models
8184

@@ -288,6 +291,7 @@ def report_data(self) -> DetectionDetail:
288291
last_detected=self.last_detected,
289292
notes=self.notes,
290293
in_queue=self.in_queue,
294+
ood_score=self.ood_score
291295
)
292296

293297
def report_data_simple(self):
@@ -510,7 +514,9 @@ def get_species_for_image(db_path, image_id):
510514
def num_species_for_event(
511515
db_path, monitoring_session, classification_threshold: float = 0.6
512516
) -> int:
513-
query = sa.select(sa.func.count(DetectedObject.specific_label.distinct()),).where(
517+
query = sa.select(
518+
sa.func.count(DetectedObject.specific_label.distinct()),
519+
).where(
514520
(DetectedObject.specific_label_score >= classification_threshold)
515521
& (DetectedObject.monitoring_session == monitoring_session)
516522
)
@@ -522,7 +528,9 @@ def num_species_for_event(
522528
def num_occurrences_for_event(
523529
db_path, monitoring_session, classification_threshold: float = 0.6
524530
) -> int:
525-
query = sa.select(sa.func.count(DetectedObject.sequence_id.distinct()),).where(
531+
query = sa.select(
532+
sa.func.count(DetectedObject.sequence_id.distinct()),
533+
).where(
526534
(DetectedObject.specific_label_score >= classification_threshold)
527535
& (DetectedObject.monitoring_session == monitoring_session)
528536
)

trapdata/ml/models/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ def get_default_model(choices: EnumMeta) -> str:
2929
)
3030
DEFAULT_OBJECT_DETECTOR = get_default_model(ObjectDetectorChoice)
3131

32-
3332
binary_classifiers = {Model.name: Model for Model in BinaryClassifier.__subclasses__()}
3433
BinaryClassifierChoice = ModelChoiceEnum(
3534
"BinaryClassifierChoice",

trapdata/ml/models/base.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import json
2+
import pandas as pd
23
from typing import Union
34

45
import numpy as np
@@ -14,6 +15,8 @@
1415
from trapdata.db.models.queue import QueueManager
1516
from trapdata.ml.utils import StopWatch, get_device, get_or_download_file
1617

18+
from dataclasses import dataclass
19+
1720

1821
class BatchEmptyException(Exception):
1922
pass
@@ -88,6 +91,7 @@ class InferenceBaseClass:
8891
queue: QueueManager
8992
dataset: torch.utils.data.Dataset
9093
dataloader: torch.utils.data.DataLoader
94+
training_csv_path: str | None = None
9195

9296
def __init__(
9397
self,
@@ -105,6 +109,7 @@ def __init__(
105109

106110
self.device = self.device or get_device()
107111
self.category_map = self.get_labels(self.labels_path)
112+
self.class_prior = self.get_class_prior(self.training_csv_path)
108113
self.num_classes = self.num_classes or len(self.category_map)
109114
self.weights = self.get_weights(self.weights_path)
110115
self.transforms = self.get_transforms()
@@ -183,6 +188,25 @@ def fetch_gbif_ids(labels):
183188
else:
184189
return {}
185190

191+
def get_class_prior(self, training_csv_path):
192+
if training_csv_path:
193+
local_path = get_or_download_file(
194+
training_csv_path,
195+
self.user_data_path or torch.hub.get_dir(),
196+
prefix="models",
197+
)
198+
df_train = pd.read_csv(local_path)
199+
categories = sorted(list(df_train["speciesKey"].unique()))
200+
categories_map = {categ: id for id, categ in enumerate(categories)}
201+
df_train["label"] = df_train["speciesKey"].map(categories_map)
202+
cls_idx = df_train["label"].astype(int).values
203+
num_classes = df_train["label"].nunique()
204+
cls_num = np.bincount(cls_idx, minlength=num_classes)
205+
targets = cls_num / cls_num.sum()
206+
return targets
207+
else:
208+
return None
209+
186210
def get_model(self) -> torch.nn.Module:
187211
"""
188212
This method must be implemented by a subclass.
@@ -330,3 +354,12 @@ def run(self):
330354
logger.info(f"{self.name} Batch -- Done")
331355

332356
logger.info(f"{self.name} -- Done")
357+
358+
359+
@dataclass
360+
class ClassifierResult:
361+
# TODO: add types
362+
labels = None
363+
logits = None
364+
softmax_scores = None
365+
ood_scores = None

trapdata/ml/models/classification.py

Lines changed: 45 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,21 @@
1+
from typing import Union
2+
from sqlalchemy.engine.url import URL as URL
13
import timm
24
import torch
35
import torch.utils.data
46
import torchvision
57

68
from trapdata import constants, logger
9+
from trapdata.common.schemas import FilePath
710
from trapdata.db.models.detections import save_classified_objects
811
from trapdata.db.models.queue import DetectedObjectQueue, UnclassifiedObjectQueue
912

1013
from .base import InferenceBaseClass, imagenet_normalization
1114

15+
import numpy as np
16+
import os
17+
from trapdata.ml.utils import get_or_download_file
18+
1219

1320
class ClassificationIterableDatabaseDataset(torch.utils.data.IterableDataset):
1421
def __init__(self, queue, image_transforms, batch_size=4):
@@ -318,6 +325,24 @@ def save_results(self, object_ids, batch_output, *args, **kwargs):
318325
save_classified_objects(self.db_path, object_ids, classified_objects_data)
319326

320327

328+
329+
# class SpeciesClassifierWithOOD(SpeciesClassifier):
330+
# def save_results(self, object_ids, batch_output, *args, **kwargs):
331+
# # Here we are saving the specific taxon labels
332+
# classified_objects_data = [
333+
# {
334+
# "specific_label": label,
335+
# "specific_label_score": score,
336+
# "model_name": self.name,
337+
# "in_queue": True, # Put back in queue for the feature extractor & tracking
338+
# }
339+
# for label, score in batch_output
340+
# ]
341+
# save_classified_objects(self.db_path, object_ids, classified_objects_data)
342+
343+
344+
345+
321346
class QuebecVermontMothSpeciesClassifierMixedResolution(
322347
SpeciesClassifier, Resnet50ClassifierLowRes
323348
):
@@ -456,7 +481,7 @@ class QuebecVermontMothSpeciesClassifier2024(SpeciesClassifier, Resnet50TimmClas
456481
)
457482
weights_path = (
458483
"https://object-arbutus.cloud.computecanada.ca/ami-models/moths/classification/"
459-
"quebec-vermont_resnet50_baseline_20240417_950de764.pth"
484+
"=-vermont_resnet50_baseline_20240417_950de764.pth"
460485
)
461486
labels_path = (
462487
"https://object-arbutus.cloud.computecanada.ca/ami-models/moths/classification/"
@@ -505,12 +530,12 @@ class PanamaMothSpeciesClassifier2024(SpeciesClassifier, Resnet50TimmClassifier)
505530
)
506531

507532

508-
class PanamaMothSpeciesClassifier2025(SpeciesClassifier, Resnet50TimmClassifier):
533+
class PanamaPlusWithOODClassifier2025(SpeciesClassifier, Resnet50TimmClassifier):
509534
input_size = 128
510535
normalization = imagenet_normalization
511536
lookup_gbif_names = False
512537

513-
name = "Panama Species Classifier - Mar 2025"
538+
name = "Panama Plus Species Classifier with OOD detection - Mar 2025"
514539
description = (
515540
"Trained on March 13th, 2025 for 2360 species. "
516541
"https://wandb.ai/moth-ai/panama_classifier/runs/81f5ssv9/overview"
@@ -523,5 +548,21 @@ class PanamaMothSpeciesClassifier2025(SpeciesClassifier, Resnet50TimmClassifier)
523548

524549
labels_path = (
525550
"https://object-arbutus.cloud.computecanada.ca/ami-models/moths/classification/"
526-
"panama_plus_category_map-with_names.json"
551+
"panama_plus_category_map-with_names.json"
527552
)
553+
554+
training_csv_path = "https://object-arbutus.cloud.computecanada.ca/ami-models/moths/classification/panama_plus_train.csv"
555+
556+
def save_results(self, object_ids, batch_output, *args, **kwargs):
557+
# Here we are saving the specific taxon labels
558+
classified_objects_data = [
559+
{
560+
"specific_label": label,
561+
"specific_label_score": score,
562+
"model_name": self.name,
563+
"in_queue": True, # Put back in queue for the feature extractor & tracking
564+
}
565+
for label, score in batch_output
566+
]
567+
save_classified_objects(self.db_path, object_ids, classified_objects_data)
568+

trapdata/ml/models/localization.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,14 @@ def get_model(self):
214214
state_dict = checkpoint.get("model_state_dict") or checkpoint
215215
model.load_state_dict(state_dict)
216216
model = model.to(self.device)
217+
218+
# Get the state dictionary
219+
state_dict = model.state_dict()
220+
221+
# Print the shape of each tensor in the state_dict
222+
for name, param in state_dict.items():
223+
print(f"{name}: {param.shape}")
224+
217225
model.eval()
218226
self.model = model
219227
return self.model
@@ -267,6 +275,14 @@ def get_model(self):
267275
checkpoint = torch.load(self.weights, map_location=self.device)
268276
state_dict = checkpoint.get("model_state_dict") or checkpoint
269277
model.load_state_dict(state_dict)
278+
279+
# Get the state dictionary
280+
state_dict = model.state_dict()
281+
282+
# Print the shape of each tensor in the state_dict
283+
for name, param in state_dict.items():
284+
print(f"{name}: {param.shape}")
285+
270286
model = model.to(self.device)
271287
model.eval()
272288
self.model = model

trapdata/ml/models/tracking.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
PanamaMothSpeciesClassifierMixedResolution2023,
2424
QuebecVermontMothSpeciesClassifierMixedResolution,
2525
UKDenmarkMothSpeciesClassifierMixedResolution,
26+
PanamaPlusWithOODClassifier2025
2627
)
2728
from trapdata.ml.utils import get_device
2829

@@ -501,6 +502,12 @@ class PanamaFeatureExtractor(
501502
):
502503
name = "Features from Panama species model"
503504

505+
class PanamaPlusFeatureExtractor(
506+
FeatureExtractor,
507+
PanamaPlusWithOODClassifier2025
508+
):
509+
name = "Features from Panama Plus species model"
510+
504511

505512
def clear_sequences(monitoring_session: MonitoringSession, session: orm.Session):
506513
logger.info(f"Clearing existing sequences for {monitoring_session.day}")

trapdata/ml/pipeline.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def start_pipeline(
2323
num_workers=settings.num_workers,
2424
single=single,
2525
)
26+
2627
if object_detector.queue.queue_count() > 0:
2728
object_detector.run()
2829
logger.info("Localization complete")

0 commit comments

Comments
 (0)