|
1 | 1 | from __future__ import annotations
|
2 | 2 |
|
3 | 3 | import importlib.util
|
4 |
| -from functools import singledispatch |
5 | 4 | from itertools import chain
|
6 | 5 | from typing import (
|
7 | 6 | Any,
|
|
20 | 19 | from numpy.typing import NDArray
|
21 | 20 | from typing_extensions import Unpack, assert_never
|
22 | 21 |
|
23 |
| -from bioimageio.spec._internal.io import HashKwargs, download |
| 22 | +from bioimageio.spec._internal.io_utils import HashKwargs, download |
24 | 23 | from bioimageio.spec.common import FileSource
|
25 | 24 | from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5
|
26 | 25 | from bioimageio.spec.model.v0_4 import CallableFromDepencency, CallableFromFile
|
|
44 | 43 | from .tensor import Tensor
|
45 | 44 |
|
46 | 45 |
|
47 |
| -@singledispatch |
48 |
| -def import_callable(node: type, /) -> Callable[..., Any]: |
| 46 | +def import_callable( |
| 47 | + node: Union[CallableFromDepencency, ArchitectureFromLibraryDescr], |
| 48 | + /, |
| 49 | + **kwargs: Unpack[HashKwargs], |
| 50 | +) -> Callable[..., Any]: |
49 | 51 | """import a callable (e.g. a torch.nn.Module) from a spec node describing it"""
|
50 |
| - raise TypeError(type(node)) |
51 |
| - |
52 |
| - |
53 |
| -@import_callable.register |
54 |
| -def _(node: CallableFromDepencency, **kwargs: Unpack[HashKwargs]) -> Callable[..., Any]: |
55 |
| - module = importlib.import_module(node.module_name) |
56 |
| - c = getattr(module, str(node.callable_name)) |
57 |
| - if not callable(c): |
58 |
| - raise ValueError(f"{node} (imported: {c}) is not callable") |
59 |
| - |
60 |
| - return c |
| 52 | + if isinstance(node, CallableFromDepencency): |
| 53 | + module = importlib.import_module(node.module_name) |
| 54 | + c = getattr(module, str(node.callable_name)) |
| 55 | + elif isinstance(node, ArchitectureFromLibraryDescr): |
| 56 | + module = importlib.import_module(node.import_from) |
| 57 | + c = getattr(module, str(node.callable)) |
| 58 | + elif isinstance(node, CallableFromFile): |
| 59 | + c = _import_from_file_impl(node.source_file, str(node.callable_name), **kwargs) |
| 60 | + elif isinstance(node, ArchitectureFromFileDescr): |
| 61 | + c = _import_from_file_impl(node.source, str(node.callable), sha256=node.sha256) |
61 | 62 |
|
| 63 | + else: |
| 64 | + assert_never(node) |
62 | 65 |
|
63 |
| -@import_callable.register |
64 |
| -def _( |
65 |
| - node: ArchitectureFromLibraryDescr, **kwargs: Unpack[HashKwargs] |
66 |
| -) -> Callable[..., Any]: |
67 |
| - module = importlib.import_module(node.import_from) |
68 |
| - c = getattr(module, str(node.callable)) |
69 | 66 | if not callable(c):
|
70 | 67 | raise ValueError(f"{node} (imported: {c}) is not callable")
|
71 | 68 |
|
72 | 69 | return c
|
73 | 70 |
|
74 | 71 |
|
75 |
| -@import_callable.register |
76 |
| -def _(node: CallableFromFile, **kwargs: Unpack[HashKwargs]): |
77 |
| - return _import_from_file_impl(node.source_file, str(node.callable_name), **kwargs) |
78 |
| - |
79 |
| - |
80 |
| -@import_callable.register |
81 |
| -def _(node: ArchitectureFromFileDescr, **kwargs: Unpack[HashKwargs]): |
82 |
| - return _import_from_file_impl(node.source, str(node.callable), sha256=node.sha256) |
83 |
| - |
84 |
| - |
85 | 72 | def _import_from_file_impl(
|
86 | 73 | source: FileSource, callable_name: str, **kwargs: Unpack[HashKwargs]
|
87 | 74 | ):
|
|
0 commit comments