
Commit 02c5cf2
Jintao-Huang committed Oct 19, 2022
1 parent 821977c commit 02c5cf2
Showing 3 changed files with 7 additions and 8 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -14,7 +14,7 @@
 2. Download the latest version (>=1.12) of Torch (matching your CUDA version) from the [official website](https://pytorch.org/get-started/locally/). Automatically installing Torch (CUDA 10.2 by default) through the Mini-Lightning dependency is not recommended, as it can cause a CUDA version mismatch.
 3. Install mini-lightning
 ```bash
-# from pypi (v0.1.4.1)
+# from pypi
 pip install mini-lightning==0.1.4.1
 
 # Or download the files from the repository to local,
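The CUDA-mismatch warning above is easy to check after installation. This snippet is illustrative only and is not part of the commit:

```python
import torch

print(torch.__version__)          # e.g. "1.12.1+cu116"
print(torch.version.cuda)         # CUDA version Torch was built with; None on CPU-only builds
print(torch.cuda.is_available())  # True if the build can actually reach a GPU
```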
11 changes: 5 additions & 6 deletions mini_lightning/_mini_lightning.py
Expand Up @@ -130,7 +130,8 @@ def trainer_init(self, trainer: "Trainer") -> None:
for metric in self.metrics.values():
metric.to(trainer.device)

def _batch_to_device(self, batch: Any, device: Device) -> Any:
@classmethod
def batch_to_device(cls, batch: Any, device: Device) -> Any:
if isinstance(batch, Tensor):
# Ref: https://pytorch-lightning.readthedocs.io/en/stable/_modules/pytorch_lightning/utilities/apply_func.html?highlight=non_blocking#
# same as pytorch-lightning
@@ -142,20 +143,17 @@ def _batch_to_device(self, batch: Any, device: Device) -> Any:
         if isinstance(batch, Sequence):
             res = []
             for b in batch:
-                res.append(self._batch_to_device(b, device))
+                res.append(cls.batch_to_device(b, device))
             if isinstance(batch, tuple):
                 res = tuple(res)
         elif isinstance(batch, Mapping):
             res = {}
             for k, v in batch.items():
-                res[k] = self._batch_to_device(v, device)
+                res[k] = cls.batch_to_device(v, device)
         else:
             raise TypeError(f"batch: {batch}, {type(batch)}")
         return res
 
-    def batch_to_device(self, batch: Any, device: Device) -> Any:
-        return self._batch_to_device(batch, device)
-
     def optimizer_step(self) -> None:
         # note: skipping the update behavior at the first step may result in a warning in lr_scheduler.
         # Don't worry about that ~.
@@ -424,6 +422,7 @@ def __init__(
         self.save_hparams(hparams)
         #
         if resume_from_ckpt is not None:
+            logger.info(f"Using ckpt: {resume_from_ckpt}")
             self._load_ckpt(resume_from_ckpt, self.device)
         #
         self.lmodel.trainer_init(self)
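Taken together, the first two hunks fold the private `_batch_to_device` helper into a single public `batch_to_device` classmethod, which is why the thin instance wrapper could be deleted: the recursive device transfer no longer needs an instance at all. Below is a minimal self-contained sketch of the resulting behavior; the class name `LModule`, the `Device` alias, and the `Tensor` branch (elided in the diff, but hinted at by the pytorch-lightning reference in the comments) are assumptions, not code confirmed by this commit:

```python
from collections.abc import Mapping, Sequence
from typing import Any

import torch
from torch import Tensor

Device = torch.device  # assumed alias; the diff only shows the name "Device"


class LModule:  # assumed class name; the diff shows only the method bodies
    @classmethod
    def batch_to_device(cls, batch: Any, device: Device) -> Any:
        """Recursively move tensors in nested lists/tuples/dicts to `device`."""
        if isinstance(batch, Tensor):
            # Assumed body: the hunk elides it, and the pytorch-lightning
            # reference suggests a non-blocking copy.
            return batch.to(device=device, non_blocking=True)
        if isinstance(batch, Sequence) and not isinstance(batch, str):
            # str guard added for this sketch: str is itself a Sequence and
            # would otherwise recurse forever.
            res = [cls.batch_to_device(b, device) for b in batch]
            return tuple(res) if isinstance(batch, tuple) else res
        if isinstance(batch, Mapping):
            return {k: cls.batch_to_device(v, device) for k, v in batch.items()}
        raise TypeError(f"batch: {batch}, {type(batch)}")


# As a classmethod, it can now be called without constructing an LModule:
batch = {"x": torch.randn(2, 3), "y": [torch.zeros(2), torch.ones(2)]}
batch = LModule.batch_to_device(batch, torch.device("cpu"))
```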
2 changes: 1 addition & 1 deletion setup.py
@@ -21,7 +21,7 @@ def read_file(path: str) -> str:
 ]
 setup(
     name="mini-lightning",
-    version="0.1.4.1",
+    version="0.1.5",
     description=description,
     long_description=long_description,
     long_description_content_type='text/markdown',
