Skip to content

Commit aa58bc2

Browse files
committed
Added lazy loading for torch
1 parent 286498a commit aa58bc2

File tree

2 files changed

+25
-4
lines changed

2 files changed

+25
-4
lines changed

evaluation_function/lazy_load.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import importlib
2+
3+
4+
class LazyModule:
5+
"""
6+
This class is response for lazy loading of heavy imports
7+
"""
8+
def __init__(self, name):
9+
self._name = name
10+
self._module = None
11+
12+
def _load(self):
13+
if self._module is None:
14+
self._module = importlib.import_module(self._name)
15+
return self._module
16+
17+
def __getattr__(self, item):
18+
return getattr(self._load(), item)
19+

evaluation_function/models/basic_nn.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,21 +16,23 @@
1616
1717
"""
1818

19-
import torch
20-
import torch.nn as nn
21-
import torch.optim as optim
22-
2319
from lf_toolkit.evaluation import Result, Params
2420

2521
from pathlib import Path
2622
import os
2723

24+
from evaluation_function.lazy_load import LazyModule
25+
2826
# Setup paths for saving/loading model and data
2927
BASE_DIR = Path(__file__).resolve().parent
3028
MODEL_DIR = Path(os.environ.get("MODEL_DIR", BASE_DIR / "storage"))
3129
MODEL_DIR.mkdir(parents=True, exist_ok=True)
3230
MODEL_PATH = MODEL_DIR / "basic_nn.pt"
3331

32+
torch = LazyModule("torch")
33+
nn = LazyModule("torch.nn")
34+
optim = LazyModule("torch.optim")
35+
3436
def f(x):
3537
"""Target function with noise (sine wave)."""
3638
return torch.sin(x)

0 commit comments

Comments
 (0)