forked from drivendataorg/boem-belugas-runtime
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
46 lines (40 loc) · 1.25 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
from __future__ import annotations
from ranzen.hydra import Option
import torch.multiprocessing
from whaledo.algorithms import Moco, SimClr
from whaledo.data.datamodule import WhaledoDataModule
from whaledo.models.artifact import ArtifactLoader
from whaledo.models.backbones import Beit, ConvNeXt, ResNet, Swin, SwinV2, ViT
from whaledo.models.meta.ema import EmaModel
from whaledo.models.meta.ft import BitFit
from whaledo.relay import WhaledoRelay
# Switch torch's inter-process tensor sharing to the "file_system" strategy.
# NOTE(review): presumably chosen to avoid exhausting file descriptors when
# many data-loading workers are running -- confirm against the training setup.
torch.multiprocessing.set_sharing_strategy("file_system")


if __name__ == "__main__":
    # Hand the relay every selectable implementation, grouped by the keyword
    # it is passed under. Each Option pairs a class with the name it is
    # selected by.
    # NOTE(review): `with_hydra` presumably builds the Hydra config schema
    # from these options and then launches the run -- verify in ranzen docs.
    WhaledoRelay.with_hydra(
        root="conf",
        # Data modules.
        dm=[
            Option(WhaledoDataModule, name="whaledo"),
        ],
        # Training algorithms.
        alg=[
            Option(Moco, "moco"),
            Option(SimClr, "simclr"),
        ],
        # Backbones, including the artifact loader.
        backbone=[
            Option(Beit, "beit"),
            Option(ConvNeXt, "convnext"),
            Option(ResNet, "resnet"),
            Option(Swin, "swin"),
            Option(SwinV2, "swinv2"),
            Option(ViT, "vit"),
            Option(ArtifactLoader, "artifact"),
        ],
        # Meta-model wrappers.
        meta_model=[
            Option(BitFit, "bitfit"),
            Option(EmaModel, "ema"),
        ],
        clear_cache=True,
    )