Skip to content

Commit 38887ec

Browse files
committed
u
1 parent 68b6764 commit 38887ec

File tree

5 files changed

+18
-8
lines changed

5 files changed

+18
-8
lines changed

main_gan.py

+2
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,8 @@ def main(config):
9393
if config.mode == 'train':
9494
trainer.fit(model, datamodule=dm,
9595
ckpt_path=config.resume_ckpt_path)
96+
trainer.test(model, datamodule=dm,
97+
ckpt_path=config.resume_ckpt_path)
9698
elif config.mode == 'test':
9799
trainer.test(model, datamodule=dm,
98100
ckpt_path=config.resume_ckpt_path)

mol_utils.py

-4
Original file line numberDiff line numberDiff line change
@@ -355,9 +355,5 @@ def save_mol_img(mols, f_name='tmp.png', is_test=False):
355355
a_smi = Chem.MolToSmiles(a_mol)
356356
mol_graph = read_smiles(a_smi)
357357

358-
break
359-
360-
# if not is_test:
361-
# break
362358
except:
363359
continue

parse_config.py

+1
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ def get_GAN_config():
3232

3333
# Training or Testing Config
3434
parser.add_argument('--lambda_wgan', type=float, help='weight of the wgan loss')
35+
parser.add_argument('--resume_ckpt_path', type=str, help='resume training from the checkpoint')
3536

3637
# Result Config
3738
parser.add_argument('--result_dir', type=str, help='Directory to save results')

plmodule_gan.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,6 @@ def gp_norm(y, x):
130130
return ((dydx.norm(2, dim=1) - 1) ** 2).mean()
131131

132132
# Random weight term for interpolation between real and fake samples
133-
134133
edge_alpha = torch.rand(real_edges.size(0), 1, 1, 1).type_as(real_edges).requires_grad_(False)
135134
node_alpha = edge_alpha.reshape(-1, 1, 1).requires_grad_(False)
136135
# Get random interpolation between real and fake samples
@@ -330,10 +329,16 @@ def configure_optimizers(self):
330329
self.opt_d = torch.optim.Adam(self.D.parameters(), lr=self.hparams.lr_d)
331330
self.opt_v = torch.optim.Adam(self.V.parameters(), lr=self.hparams.lr_v)
332331
return self.opt_g, self.opt_d, self.opt_v
333-
334-
def on_val_epoch_end(self):
332+
333+
def _shared_on_eval_epoch_end(self):
335334
edges_logits, nodes_logits = self.G(self.sampled_img_z.to(self.device))
336335
mols = self.get_gen_mols(nodes_logits, edges_logits, self.hparams.post_method)
337336
# Saving molecule images.
338337
mol_f_name = os.path.join(self.hparams.img_dir, 'mol-{}.png'.format(self.current_epoch))
339-
save_mol_img(mols, mol_f_name, is_test=self.hparams.mode == 'test')
338+
save_mol_img(mols, mol_f_name, is_test=True)
339+
340+
def on_val_epoch_end(self):
341+
self._shared_on_eval_epoch_end()
342+
343+
def on_test_epoch_end(self):
344+
self._shared_on_eval_epoch_end()

requirements.txt

+6
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
numpy==1.21.5
2+
pysmiles==1.0.1
3+
pytorch_lightning==1.6.3
4+
rdkit==2009.Q1-1
5+
scikit_learn==1.1.1
6+
torch==1.11.0

0 commit comments

Comments
 (0)