Skip to content

Commit

Permalink
update features
Browse files Browse the repository at this point in the history
  • Loading branch information
Jintao-Huang committed Mar 30, 2023
1 parent 7f8b11f commit aeb5b2d
Show file tree
Hide file tree
Showing 17 changed files with 28 additions and 23 deletions.
2 changes: 1 addition & 1 deletion examples/ae.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def __init__(self, hparams: HParams) -> None:
"loss": MeanMetric(),
}
#
super().__init__([optimizer], metrics, hparams.__dict__)
super().__init__([optimizer], metrics, hparams)
self.encoder = encoder
self.decoder = decoder
self.lr_s = lr_s
Expand Down
2 changes: 1 addition & 1 deletion examples/cl.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def __init__(self, hparams: HParams) -> None:
}
self.temperature = hparams.temperature
#
super().__init__([optimizer], metrics, hparams.__dict__)
super().__init__([optimizer], metrics, hparams)
self.resnet = resnet
self.lr_s = lr_s

Expand Down
2 changes: 1 addition & 1 deletion examples/cl_ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def __init__(self, hparams: HParams) -> None:
}
self.temperature = hparams.temperature
#
super().__init__([optimizer], metrics, hparams.__dict__)
super().__init__([optimizer], metrics, hparams)
self.resnet = resnet
self.lr_s = lr_s

Expand Down
2 changes: 1 addition & 1 deletion examples/cl_ddp_spawn.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def __init__(self, hparams: HParams) -> None:
}
self.temperature = hparams.temperature
#
super().__init__([optimizer], metrics, hparams.__dict__)
super().__init__([optimizer], metrics, hparams)
self.resnet = resnet
self.lr_s = lr_s

Expand Down
2 changes: 1 addition & 1 deletion examples/cv.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def __init__(self, hparams: HParams) -> None:
"acc": Accuracy("multiclass", num_classes=num_classes),
}
#
super().__init__([optimizer], metrics, hparams.__dict__)
super().__init__([optimizer], metrics, hparams)
self.model = model
self.lr_s = lr_s
self.loss_fn = nn.CrossEntropyLoss()
Expand Down
2 changes: 1 addition & 1 deletion examples/cv_ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def __init__(self, hparams: HParams) -> None:
"acc": Accuracy("multiclass", num_classes=num_classes),
}
#
super().__init__([optimizer], metrics, hparams.__dict__)
super().__init__([optimizer], metrics, hparams)
self.model = model
self.lr_s = lr_s
self.loss_fn = nn.CrossEntropyLoss()
Expand Down
2 changes: 1 addition & 1 deletion examples/cv_ddp_spawn.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def __init__(self, hparams: HParams) -> None:
"acc": Accuracy("multiclass", num_classes=num_classes),
}
#
super().__init__([optimizer], metrics, hparams.__dict__)
super().__init__([optimizer], metrics, hparams)
self.model = model
self.lr_s = lr_s
self.loss_fn = nn.CrossEntropyLoss()
Expand Down
10 changes: 5 additions & 5 deletions examples/dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def get_rand_p(global_step: int, T_max: int, eta_min: float, eta_max: float) ->

class MyLModule(ml.LModule):
def __init__(self, memo_pool: MemoryPool, hparams: HParams) -> None:

self.hparams: HParams
env = gym.make(hparams.env_name, render_mode=RENDER_MODE)
#
in_channels: int = env.observation_space.shape[0]
Expand All @@ -158,7 +158,7 @@ def __init__(self, memo_pool: MemoryPool, hparams: HParams) -> None:
agent = Agent(env, memo_pool, model, ml.select_device(device_ids))

optimizer = getattr(optim, hparams.optim_name)(model.parameters(), **hparams.optim_hparams)
super().__init__([optimizer], {}, hparams.__dict__)
super().__init__([optimizer], {}, hparams)
self.model = model
self.old_model = deepcopy(self.model).requires_grad_(False)
# New_model and old_model are used for model training.
Expand All @@ -169,10 +169,10 @@ def __init__(self, memo_pool: MemoryPool, hparams: HParams) -> None:
self.agent = agent
self.get_rand_p = partial(get_rand_p, **hparams.rand_p)
#
self.warmup_memory_steps = self.hparams["warmup_memory_steps"]
self.warmup_memory_steps = self.hparams.warmup_memory_steps
# synchronize the model every sync_steps
self.sync_steps = self.hparams["sync_steps"]
self.gamma = self.hparams["gamma"] # reward decay
self.sync_steps = self.hparams.sync_steps
self.gamma = self.hparams.gamma # reward decay

#
self._warmup_memo(self.warmup_memory_steps)
Expand Down
2 changes: 1 addition & 1 deletion examples/gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def __init__(self, hparams: HParams) -> None:
G, D = Generator(**hparams.G_hparams), Discriminator(**hparams.D_hparams)
opt_G = getattr(optim, hparams.opt_G_name)(G.parameters(), **hparams.opt_G_hparams)
opt_D = getattr(optim, hparams.opt_D_name)(D.parameters(), **hparams.opt_D_hparams)
super().__init__([opt_G, opt_D], {}, hparams.__dict__)
super().__init__([opt_G, opt_D], {}, hparams)
self.G = G
self.D = D
self.loss_fn = nn.BCEWithLogitsLoss()
Expand Down
2 changes: 1 addition & 1 deletion examples/gnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def __init__(self, hparams: HParams) -> None:
"acc": Accuracy("multiclass", num_classes=num_classes),
}
#
super().__init__([optimizer], metrics, hparams.__dict__)
super().__init__([optimizer], metrics, hparams)
self.model = model
self.lr_s = lr_s
self.loss_fn = nn.CrossEntropyLoss()
Expand Down
2 changes: 1 addition & 1 deletion examples/gnn2.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def __init__(self, hparams: HParams) -> None:
"auc": AUROC("binary")
}
#
super().__init__([optimizer], metrics, hparams.__dict__)
super().__init__([optimizer], metrics, hparams)
self.model = model
self.lr_s = lr_s
self.loss_fn = nn.BCEWithLogitsLoss()
Expand Down
2 changes: 1 addition & 1 deletion examples/gnn3.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def __init__(self, hparams: HParams) -> None:
"acc": Accuracy("binary"),
}
#
super().__init__([optimizer], metrics, hparams.__dict__)
super().__init__([optimizer], metrics, hparams)
self.model = model
self.lr_s = lr_s
self.loss_fn = nn.BCEWithLogitsLoss()
Expand Down
2 changes: 1 addition & 1 deletion examples/meta_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def __init__(self, hparams: HParams) -> None:
"acc": Accuracy("multiclass", num_classes=proto_dim),
}
#
super().__init__([optimizer], metrics, hparams.__dict__)
super().__init__([optimizer], metrics, hparams)
self.model = model
self.lr_s = lr_s
self.loss_fn = nn.CrossEntropyLoss()
Expand Down
2 changes: 1 addition & 1 deletion examples/nlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def __init__(self, hparams: HParams) -> None:
}
lr_s: LRScheduler = lrs.CosineAnnealingLR(optimizer, **hparams.lrs_hparams)
lr_s = ml.warmup_decorator(lr_s, hparams.warmup)
super().__init__([optimizer], metrics, hparams.__dict__)
super().__init__([optimizer], metrics, hparams)
self.model = model
self.lr_s = lr_s
self.loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1)
Expand Down
2 changes: 1 addition & 1 deletion examples/vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def __init__(self, hparams: HParams) -> None:
"loss": MeanMetric(),
}
#
super().__init__([optimizer], metrics, hparams.__dict__)
super().__init__([optimizer], metrics, hparams)
self.mse = nn.MSELoss()
self.encoder = encoder
self.decoder = decoder
Expand Down
11 changes: 8 additions & 3 deletions mini_lightning/_mini_lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,18 @@ def __init__(
self,
optimizers: List[Optimizer],
metrics: Dict[str, Metric],
hparams: Optional[Dict[str, Any]] = None
hparams: Any = None
) -> None:
"""
optimizers: Use List, for supporting GAN
hparams: Hyperparameters to be saved
object or Dict[str, Any] or None
"""
# _models: for trainer_init(device, ddp), _epoch_start(train, eval); print_model_info; save_ckpt
self._models: List[str] = []
self.optimizers = optimizers
self.metrics = metrics
self.hparams: Dict[str, Any] = hparams if hparams is not None else {}
self.hparams = hparams
self.trainer: Optional["Trainer"] = None

@property
Expand Down Expand Up @@ -516,9 +517,13 @@ def _check_hparams(self, hparams: Any) -> Any:
res = repr(hparams) # e.g. function
return res

def save_hparams(self, hparams: Dict[str, Any]) -> None:
def save_hparams(self, hparams: Any) -> None:
if self.rank not in {-1, 0}:
return
if hparams is None:
hparams = {}
elif not isinstance(hparams, dict):
hparams = hparams.__dict__
saved_hparams = self._check_hparams(hparams)
logger.info(f"Saving hparams: {saved_hparams}")
write_to_yaml(saved_hparams, self.hparams_path)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def read_file(path: str) -> str:
]
setup(
name="mini-lightning",
version="0.1.10",
version="0.2.0",
description=description,
long_description=long_description,
long_description_content_type="text/markdown",
Expand Down

0 comments on commit aeb5b2d

Please sign in to comment.