41 changes: 34 additions & 7 deletions trapdata/db/base.py
@@ -3,10 +3,12 @@
 import time
 from typing import Generator
 
+import alembic
+import alembic.command
 import sqlalchemy as sa
 import sqlalchemy.exc
-from alembic import command as alembic
 from alembic.config import Config
+from alembic.runtime.migration import MigrationContext
+from alembic.script import ScriptDirectory
 from rich import print
 from sqlalchemy import orm
 
@@ -74,19 +76,44 @@ def create_db(db_path: DatabaseURL) -> None:
 
     Base.metadata.create_all(db, checkfirst=True)
     alembic_cfg = get_alembic_config(db_path)
-    alembic.stamp(alembic_cfg, "head")
+    alembic.command.stamp(alembic_cfg, "head")
 
 
 def migrate(db_path: DatabaseURL) -> None:
     """
-    Run database migrations.
-
-    # @TODO See this post for a more complete implementation
-    # https://pawamoy.github.io/posts/testing-fastapi-ormar-alembic-apps/
+    Run database migrations with better error handling and verification.
     """
     logger.debug("Running any database migrations if necessary")
     alembic_cfg = get_alembic_config(db_path)
-    alembic.upgrade(alembic_cfg, "head")
+
+    try:
+        # Check the current state first.
+        # NOTE: alembic.command.current() only prints the revision and returns
+        # None, so read the stamped revision from the database instead.
+        script_dir = ScriptDirectory.from_config(alembic_cfg)
+        target_head = script_dir.get_current_head()
+        engine = sa.create_engine(str(db_path))
+        with engine.connect() as connection:
+            current_head = MigrationContext.configure(connection).get_current_revision()
+
+        if current_head != target_head:
+            logger.info(f"Upgrading from {current_head} to {target_head}")
+            alembic.command.upgrade(alembic_cfg, "head")
+            logger.info("Migration completed successfully")
+        else:
+            logger.debug("Database already at target revision")
+
+        # Verify the migration actually worked
+        logger.debug("Verifying database schema consistency")
+        alembic.command.check(alembic_cfg)
+        logger.debug("Database schema verification passed")
+
+    except Exception as e:
+        logger.error(f"Migration failed: {e}")
+        # Check whether the database is left in an inconsistent state
+        try:
+            alembic.command.check(alembic_cfg)
+            logger.warning("Migration failed but database schema appears consistent")
+        except Exception as check_error:
+            logger.error(f"Database is in inconsistent state: {check_error}")
+            logger.error("Manual intervention may be required to fix migration state")
+        raise
 
 
 def get_db(db_path, create=False, update=False):
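A quick smoke test of the new migration path (a sketch, not part of the PR): `create_db` stamps the revision table at "head", so a follow-up `migrate` should take the "already at target revision" branch and then pass the schema check. The import path comes from the diff; the SQLite URL is a hypothetical example and assumes a plain URL string is accepted for `DatabaseURL`.

# Sketch: exercise create_db + migrate on a fresh database.
from trapdata.db.base import create_db, migrate

db_url = "sqlite:////tmp/trapdata_test.db"  # hypothetical location

create_db(db_url)  # create all tables, then stamp the Alembic revision at "head"
migrate(db_url)    # no-op upgrade; logs "Database already at target revision"

The new Alembic revision that adds the logits column to the detections table: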
@@ -0,0 +1,27 @@
"""New column for saving logits to detections

Revision ID: 68f8b8fe793a
Revises: 1544478c3031
Create Date: 2025-08-07 15:50:24.447765

"""
import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision = "68f8b8fe793a"
down_revision = "1544478c3031"
branch_labels = None
depends_on = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("detections", sa.Column("logits", sa.JSON(), nullable=True))
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("detections", "logits")
# ### end Alembic commands ###
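Because the column is added as nullable, existing detections rows need no backfill; the downgrade simply drops the column (and any stored logits with it). If the revision ever needs to be applied or reverted by hand, Alembic's Python API can drive it; a sketch, assuming a project alembic.ini (hypothetical path; the app itself builds its config via get_alembic_config):

# Sketch: apply or revert this revision manually (not part of the PR).
from alembic import command
from alembic.config import Config

cfg = Config("alembic.ini")  # hypothetical config path
command.upgrade(cfg, "68f8b8fe793a")    # add the nullable JSON "logits" column
command.downgrade(cfg, "1544478c3031")  # drop it again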
8 changes: 6 additions & 2 deletions trapdata/db/models/detections.py
@@ -23,7 +23,8 @@ class DetectionListItem(BaseModel):
     area_pixels: Optional[float]
     last_detected: Optional[datetime.datetime]
     label: Optional[str]
-    score: Optional[int]
+    score: Optional[float]
+    # logits: Optional[list[float]]
     model_name: Optional[str]
     in_queue: bool
     notes: Optional[str]
@@ -41,8 +42,9 @@ class DetectionDetail(DetectionListItem):
     sequence_cost: Optional[float]
     source_image_path: Optional[pathlib.Path]
     timestamp: Optional[str]
-    bbox_center: Optional[tuple[int, int]]
+    bbox_center: Optional[tuple[float, float]]
     area_pixels: Optional[int]
+    logits: Optional[list[float]]
 
 
 class DetectedObject(db.Base):
@@ -75,6 +77,7 @@ class DetectedObject(db.Base):
     sequence_frame = sa.Column(sa.Integer)
     sequence_previous_id = sa.Column(sa.Integer)
     sequence_previous_cost = sa.Column(sa.Float)
+    logits = sa.Column(sa.JSON)
     cnn_features = sa.Column(sa.JSON)
 
     # @TODO add updated & created timestamps to all db models
@@ -288,6 +291,7 @@ def report_data(self) -> DetectionDetail:
             last_detected=self.last_detected,
             notes=self.notes,
             in_queue=self.in_queue,
+            logits=self.logits,
         )
 
     def report_data_simple(self):
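Persisting raw logits alongside the top score means the full class distribution can be recomputed later (e.g. for calibration or top-k reporting) without re-running inference. A sketch of the recovery step, using the same softmax as the classifier's post_process_batch:

# Sketch: rebuild the probability distribution from a stored logits list.
import torch

def probabilities_from_logits(logits: list[float]) -> list[float]:
    # Softmax over one detection's logits; max() of the result should
    # match the stored specific_label_score.
    return torch.nn.functional.softmax(torch.tensor(logits), dim=0).tolist()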
35 changes: 25 additions & 10 deletions trapdata/ml/models/classification.py
@@ -184,17 +184,18 @@ def get_transforms(self):
             ]
         )
 
-    def post_process_batch(self, output):
+    def post_process_batch(self, output: torch.Tensor) -> list[tuple[str, float, list]]:
         predictions = torch.nn.functional.softmax(output, dim=1)
         predictions = predictions.cpu().numpy()
 
         categories = predictions.argmax(axis=1)
         labels = [self.category_map[cat] for cat in categories]
         scores = predictions.max(axis=1).astype(float)
 
-        result = list(zip(labels, scores))
-        logger.debug(f"Post-processing result batch: {result}")
-        return result
+        logits = output.cpu().detach().numpy().tolist()
+        result_per_image = list(zip(labels, scores, logits))
+        logger.debug(f"Post-processing result batch: {result_per_image}")
+        return result_per_image
 
 
 class Resnet50ClassifierLowRes(Resnet50Classifier):
@@ -249,7 +250,13 @@ def get_dataset(self):
         )
         return dataset
 
-    def save_results(self, object_ids, batch_output, *args, **kwargs):
+    def save_results(
+        self,
+        object_ids,
+        batch_output: list[tuple[str, float, list]],
+        *args,
+        **kwargs,
+    ):
         # Here we are saving the moth/non-moth labels
         classified_objects_data = [
             {
@@ -258,7 +265,7 @@ def save_results(self, object_ids, batch_output, *args, **kwargs):
                 "in_queue": True if label == self.positive_binary_label else False,
                 "model_name": self.name,
             }
-            for label, score in batch_output
+            for label, score, _logits in batch_output
         ]
         save_classified_objects(self.db_path, object_ids, classified_objects_data)

@@ -302,16 +309,24 @@ def get_dataset(self):
         )
         return dataset
 
-    def save_results(self, object_ids, batch_output, *args, **kwargs):
+    def save_results(
+        self,
+        object_ids,
+        batch_output: list[tuple[str, float, list]],
+        *args,
+        **kwargs,
+    ):
         # Here we are saving the specific taxon labels
         classified_objects_data = [
             {
                 "specific_label": label,
-                "specific_label_score": score,
+                "specific_label_score": top_score,
+                "logits": logits,
                 "model_name": self.name,
-                "in_queue": True,  # Put back in queue for the feature extractor & tracking
+                # Put back in queue for the feature extractor & tracking
+                "in_queue": True,
             }
-            for label, score in batch_output
+            for label, top_score, logits in batch_output
         ]
         save_classified_objects(self.db_path, object_ids, classified_objects_data)

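For reference, post_process_batch now yields one (label, score, logits) triple per image, which is why both save_results implementations unpack three values. A toy walk-through with a hypothetical three-class category map (labels and values are illustrative only):

# Toy example of the new per-image result shape.
import torch

output = torch.tensor([[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]])  # raw logits, batch of 2
predictions = torch.nn.functional.softmax(output, dim=1).cpu().numpy()
category_map = {0: "moth", 1: "nonmoth", 2: "background"}  # hypothetical labels

labels = [category_map[cat] for cat in predictions.argmax(axis=1)]
scores = predictions.max(axis=1).astype(float)
logits = output.cpu().detach().numpy().tolist()

batch_output = list(zip(labels, scores, logits))
# -> approximately [("moth", 0.786, [2.0, 0.5, -1.0]), ("background", 0.896, [0.1, 0.2, 3.0])]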