Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Quick attempt at model interpretation #7

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions cyto_ml/data/intake.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Utilities for expressing our dataset as an intake catalog"""

from intake_xarray import ImageSource


def intake_yaml(
test_url: str,
Expand Down Expand Up @@ -29,3 +31,8 @@ def intake_yaml(
"""
# coerce_shape: [256, 256]
return template


def image_source(image_url):
    """Read an image the same way intake_xarray does and return it.

    Wraps ``image_url`` in an ``ImageSource`` and hands back whatever its
    ``to_dask()`` call produces.
    """
    source = ImageSource(image_url)
    return source.to_dask()
15 changes: 11 additions & 4 deletions cyto_ml/models/scivision.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,20 @@ def load_model(url: str):
return model


def raw_model(model: PretrainedModel):
    """Return the pytorch model held inside a scivision wrapper.

    The wrapper's `predict` interface assumes EXIF metadata, so this
    reaches past it to the wrapped model itself.
    """
    plumbing = model._plumbing
    return plumbing.model.pretrained_model


def truncate_model(model: PretrainedModel):
    """
    Accepts a scivision model wrapper and returns the pytorch model,
    with its last fully-connected layer removed so that it returns
    2048 features rather than a handful of label predictions
    """
    # Keep every child module except the final one, preserving their order.
    layers = list(raw_model(model).children())[:-1]
    network = torch.nn.Sequential(*layers)
    return network


Expand All @@ -37,7 +42,9 @@ def prepare_image(image: DataArray):
c) Uses a CUDA device if available
"""
# Convert the image data to a PyTorch tensor
tensor_image = torchvision.transforms.ToTensor()(image.to_numpy())
if hasattr(image, "to_numpy"):
image = image.to_numpy()
tensor_image = torchvision.transforms.ToTensor()(image)

# Check if the input is a single image or a batch
if len(tensor_image.shape) == 3:
Expand Down
8 changes: 7 additions & 1 deletion cyto_ml/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from cyto_ml.models.scivision import (
load_model,
truncate_model,
raw_model,
SCIVISION_URL,
)

Expand All @@ -29,5 +30,10 @@ def image_batch(image_dir):


@pytest.fixture
def truncated_model():
    """Fixture: the model loaded from SCIVISION_URL, passed through truncate_model."""
    return truncate_model(load_model(SCIVISION_URL))


@pytest.fixture
def original_model():
    """Fixture: the model loaded from SCIVISION_URL, unwrapped via raw_model."""
    wrapper = load_model(SCIVISION_URL)
    return raw_model(wrapper)
Binary file added cyto_ml/tests/fixtures/cefas_images/copepod_1.tif
Binary file not shown.
Binary file added cyto_ml/tests/fixtures/wasp.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
11 changes: 9 additions & 2 deletions cyto_ml/tests/test_image_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,18 @@
from cyto_ml.models.scivision import prepare_image, flat_embeddings


def test_embeddings(truncated_model, single_image):
    """The truncated model returns a Tensor whose flattened form has one entry per feature."""
    image = ImageSource(single_image).to_dask()
    features = truncated_model(prepare_image(image))

    assert isinstance(features, Tensor)

    embeddings = flat_embeddings(features)

    assert len(embeddings) > 0
    # Flattened length must match the second dimension of the model output.
    assert len(embeddings) == features.size()[1]


def test_predictions(original_model, single_image):
    """The unwrapped model emits three values for a single prepared image."""
    image = ImageSource(single_image).to_dask()
    predictions = original_model(prepare_image(image))
    # A probably not very illuminating three output classes
    first_row = predictions.detach().cpu().numpy()[0]
    assert len(first_row) == 3
7 changes: 7 additions & 0 deletions cyto_ml/tests/test_intake_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from cyto_ml.data.intake import image_source
from xarray import DataArray


def test_image_source(single_image):
    """image_source should hand back an xarray DataArray for a single image."""
    loaded = image_source(single_image)
    assert isinstance(loaded, DataArray)
1 change: 1 addition & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ dependencies:
- dask
- pip:
- pytest
- captum
- imagecodecs
- intake # for reading scivision
- torch==1.10.0 # install before cefas_scivision; it needs this version
Expand Down
1,426 changes: 1,426 additions & 0 deletions notebooks/ModelExplainability.ipynb

Large diffs are not rendered by default.

1 change: 0 additions & 1 deletion scripts/image_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ def store_embeddings(row):
return

embeddings = flat_embeddings(model(prepare_image(image_data)))

collection.add(
documents=[row.Filename],
embeddings=[embeddings],
Expand Down
Loading