1+ from typing import Union
2+ from sqlalchemy .engine .url import URL as URL
13import timm
24import torch
35import torch .utils .data
46import torchvision
57
68from trapdata import constants , logger
9+ from trapdata .common .schemas import FilePath
710from trapdata .db .models .detections import save_classified_objects
811from trapdata .db .models .queue import DetectedObjectQueue , UnclassifiedObjectQueue
912
1013from .base import InferenceBaseClass , imagenet_normalization
1114
15+ import numpy as np
16+ import os
17+ from trapdata .ml .utils import get_or_download_file
18+
1219
1320class ClassificationIterableDatabaseDataset (torch .utils .data .IterableDataset ):
1421 def __init__ (self , queue , image_transforms , batch_size = 4 ):
@@ -318,6 +325,24 @@ def save_results(self, object_ids, batch_output, *args, **kwargs):
318325 save_classified_objects (self .db_path , object_ids , classified_objects_data )
319326
320327
328+
329+ # class SpeciesClassifierWithOOD(SpeciesClassifier):
330+ # def save_results(self, object_ids, batch_output, *args, **kwargs):
331+ # # Here we are saving the specific taxon labels
332+ # classified_objects_data = [
333+ # {
334+ # "specific_label": label,
335+ # "specific_label_score": score,
336+ # "model_name": self.name,
337+ # "in_queue": True, # Put back in queue for the feature extractor & tracking
338+ # }
339+ # for label, score in batch_output
340+ # ]
341+ # save_classified_objects(self.db_path, object_ids, classified_objects_data)
342+
343+
344+
345+
321346class QuebecVermontMothSpeciesClassifierMixedResolution (
322347 SpeciesClassifier , Resnet50ClassifierLowRes
323348):
@@ -456,7 +481,7 @@ class QuebecVermontMothSpeciesClassifier2024(SpeciesClassifier, Resnet50TimmClas
456481 )
457482 weights_path = (
458483 "https://object-arbutus.cloud.computecanada.ca/ami-models/moths/classification/"
459- "quebec -vermont_resnet50_baseline_20240417_950de764.pth"
484+ "= -vermont_resnet50_baseline_20240417_950de764.pth"
460485 )
461486 labels_path = (
462487 "https://object-arbutus.cloud.computecanada.ca/ami-models/moths/classification/"
@@ -505,12 +530,12 @@ class PanamaMothSpeciesClassifier2024(SpeciesClassifier, Resnet50TimmClassifier)
505530 )
506531
507532
508- class PanamaMothSpeciesClassifier2025 (SpeciesClassifier , Resnet50TimmClassifier ):
533+ class PanamaPlusWithOODClassifier2025 (SpeciesClassifier , Resnet50TimmClassifier ):
509534 input_size = 128
510535 normalization = imagenet_normalization
511536 lookup_gbif_names = False
512537
513- name = "Panama Species Classifier - Mar 2025"
538+ name = "Panama Plus Species Classifier with OOD detection - Mar 2025"
514539 description = (
515540 "Trained on March 13th, 2025 for 2360 species. "
516541 "https://wandb.ai/moth-ai/panama_classifier/runs/81f5ssv9/overview"
@@ -523,5 +548,21 @@ class PanamaMothSpeciesClassifier2025(SpeciesClassifier, Resnet50TimmClassifier)
523548
524549 labels_path = (
525550 "https://object-arbutus.cloud.computecanada.ca/ami-models/moths/classification/"
526- "panama_plus_category_map-with_names.json"
551+ "panama_plus_category_map-with_names.json"
527552 )
553+
554+ training_csv_path = "https://object-arbutus.cloud.computecanada.ca/ami-models/moths/classification/panama_plus_train.csv"
555+
556+ def save_results (self , object_ids , batch_output , * args , ** kwargs ):
557+ # Here we are saving the specific taxon labels
558+ classified_objects_data = [
559+ {
560+ "specific_label" : label ,
561+ "specific_label_score" : score ,
562+ "model_name" : self .name ,
563+ "in_queue" : True , # Put back in queue for the feature extractor & tracking
564+ }
565+ for label , score in batch_output
566+ ]
567+ save_classified_objects (self .db_path , object_ids , classified_objects_data )
568+
0 commit comments