
Minimal inference script with bioimageio.core.prediction.predict() fails #441

Open
@qin-yu

Description


The goal is for the following minimal script to run without errors:

import numpy as np
from bioimageio.core.prediction import predict
from bioimageio.core.sample import Sample
from bioimageio.core.tensor import Tensor
from bioimageio.spec.model.v0_5 import TensorId

array = np.random.randint(0, 255, (2, 128, 128, 128), dtype=np.uint8)
dims = ('c', 'z', 'y', 'x')
sample = Sample(members={TensorId('a'): Tensor(array=array, dims=dims)}, stat={}, id='try')

temp = predict(
    model='philosophical-panda',
    inputs=sample,  # `predict()` accepts this input but fails
)
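In the script above, the sample's only tensor id is `'a'`, which does not match the model's input id, yet the mismatch is only noticed deep inside the model. A library-independent sketch of the kind of early check the third point below asks for could look like this. `check_tensor_ids` is a hypothetical helper, not part of bioimageio.core; plain NumPy arrays stand in for tensors, `"raw"` is a made-up input id, and the sketch assumes every model input is required.

```python
from collections.abc import Mapping, Sequence

import numpy as np


def check_tensor_ids(members: Mapping[str, np.ndarray], expected_ids: Sequence[str]) -> None:
    """Hypothetical helper: raise immediately if sample tensor ids differ from the model's input ids.

    Assumes every model input is required (optional inputs are not modelled here).
    """
    missing = [tid for tid in expected_ids if tid not in members]
    unexpected = [tid for tid in members if tid not in expected_ids]
    if missing or unexpected:
        raise ValueError(
            f"sample tensor ids {list(members)} do not match model inputs {list(expected_ids)}: "
            f"missing {missing}, unexpected {unexpected}"
        )


# usage: mirrors the script above, where the sample's only tensor id is "a";
# "raw" is a made-up model input id for illustration
try:
    check_tensor_ids({"a": np.zeros((2, 8, 8, 8), dtype=np.uint8)}, expected_ids=["raw"])
except ValueError as err:
    print(err)
```

Failing at sample construction time, with both the expected and the received ids in the message, is what would replace the late failure in the normalisation layers.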

The following needs to be fixed:

  • create_sample_for_model() should accept an iterable of tensor sources
    • Currently it only accepts a dict mapping tensor IDs to tensors, so it does not reduce user effort
  • Model adapters need to rearrange the axis order of samples whose axes are specified
    • Currently both the model and the sample carry axis information, but nothing is done to match them
  • Wrong tensor IDs in a sample should raise a meaningful exception earlier
    • Currently the error only surfaces later, when None is passed to the model's normalisation layers
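For the first point, accepting an iterable of tensor sources essentially means pairing them with the model's input ids in declaration order. The sketch below is a hypothetical helper (`to_members`), not the signature of `create_sample_for_model()`; NumPy arrays stand in for bioimageio tensor sources and `"raw"` is a made-up input id.

```python
from collections.abc import Iterable, Mapping, Sequence
from typing import Union

import numpy as np

# Plain NumPy arrays stand in for bioimageio tensor sources in this sketch.
TensorSources = Union[Mapping[str, np.ndarray], Iterable[np.ndarray], np.ndarray]


def to_members(sources: TensorSources, input_ids: Sequence[str]) -> dict[str, np.ndarray]:
    """Hypothetical helper: turn user input into a {tensor_id: array} dict.

    - a mapping is taken as already keyed by tensor id
    - a single array counts as one source (instead of being iterated row by row)
    - any other iterable is paired with the model's input ids in declaration order
    """
    if isinstance(sources, Mapping):
        return dict(sources)
    if isinstance(sources, np.ndarray):
        sources = [sources]
    arrays = list(sources)
    if len(arrays) != len(input_ids):
        raise ValueError(
            f"model expects {len(input_ids)} input(s) {list(input_ids)}, got {len(arrays)}"
        )
    return dict(zip(input_ids, arrays))


# usage: the user no longer has to know the model's input id ("raw" is made up here)
members = to_members([np.zeros((2, 8, 8, 8), dtype=np.uint8)], input_ids=["raw"])
print(list(members))
```

Treating a bare array as a single source matters because NumPy arrays are themselves iterable; without that check, a single input would silently be split along its first axis.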

The temporary workaround is to be fully explicit:

from typing import assert_never

import numpy as np
from bioimageio.core.axis import AxisId
from bioimageio.core.prediction import predict
from bioimageio.core.sample import Sample
from bioimageio.core.tensor import Tensor
from bioimageio.spec import load_model_description
from bioimageio.spec.model import v0_4, v0_5
from bioimageio.spec.model.v0_5 import TensorId

model = load_model_description("philosophical-panda")
if isinstance(model, v0_4.ModelDescr):
    input_ids = [ipt.name for ipt in model.inputs]
elif isinstance(model, v0_5.ModelDescr):
    input_ids = [ipt.id for ipt in model.inputs]
else:
    assert_never(model)

assert len(input_ids) == 1
tensor_id = input_ids[0]

print("model expects these inputs:", input_ids)

array = np.random.randint(0, 255, (2, 128, 128, 128), dtype=np.uint8)
dims = ("channel", "z", "y", "x")  # FIXME <-- `AxisId` has to be "channel" not "c"
sample = Sample(
    members={
        TensorId(tensor_id): Tensor(array=array, dims=dims).transpose(  # FIXME <-- `TensorId` has to be specified by user
            [
                AxisId(a) if isinstance(a, str) else a.id for a in model.inputs[0].axes
            ]  # FIXME <-- `AxisId` has to be re-ordered by user
        )
    },
    stat={},
    id="try",
)

temp = predict(model=model, inputs=sample)
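The two axis-related FIXMEs in the workaround (renaming "c" to "channel" and re-ordering to the model's axes) are what the second point asks model adapters to do automatically. Below is a NumPy-only sketch of that alignment. `align_axes` and `AXIS_ALIASES` are hypothetical; the model's axes are plain strings rather than bioimageio axis objects, and the sketch assumes that axes the model declares but the data lacks (such as a batch axis) can be filled in as singleton dimensions.

```python
import numpy as np

# Hypothetical alias table. "c" -> "channel" follows the FIXME above;
# the "b" and "i" entries are assumptions for illustration.
AXIS_ALIASES = {"b": "batch", "c": "channel", "i": "index"}


def align_axes(array: np.ndarray, dims: tuple[str, ...], model_axes: tuple[str, ...]) -> np.ndarray:
    """Reorder `array` (axis names `dims`) to `model_axes`, adding singleton axes the model expects."""
    dims = tuple(AXIS_ALIASES.get(d, d) for d in dims)
    unknown = [d for d in dims if d not in model_axes]
    if unknown:
        raise ValueError(f"axes {unknown} are not among the model's axes {list(model_axes)}")
    # append axes the data lacks as trailing singleton dimensions
    missing = [a for a in model_axes if a not in dims]
    array = array.reshape(array.shape + (1,) * len(missing))
    current = list(dims) + missing
    # permute so that the axis order matches the model
    return np.transpose(array, [current.index(a) for a in model_axes])


# usage: same axis layout as the script above, on a smaller array
data = np.zeros((2, 16, 32, 32), dtype=np.uint8)
aligned = align_axes(data, ("c", "z", "y", "x"), ("batch", "channel", "z", "y", "x"))
print(aligned.shape)  # (1, 2, 16, 32, 32)
```

Rejecting unknown axis names up front follows the same idea as the third point: a mismatch is reported where it is introduced rather than inside the model.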

Labels: bug (Something isn't working), enhancement (New feature or request)