From 9fc07772434bb8e94120fa4538b195a62cf3d383 Mon Sep 17 00:00:00 2001
From: Àlex Solé
Date: Tue, 22 Oct 2024 12:27:32 +0200
Subject: [PATCH] updated README and minor bug fixes

---
 README.md                       | 49 ++++++++-
 dataset/figshare_dataset.py     |  6 +-
 loader/loader.py                | 22 ++--
 main.py                         |  2 +-
 main_irene.py                   | 180 --------------------------------
 scripts/debug_irene.sh          |  4 -
 scripts/train_cartnet_adp.sh    | 33 +----
 scripts/train_cartnet_jarvis.sh |  4 ++
 8 files changed, 66 insertions(+), 234 deletions(-)
 delete mode 100644 main_irene.py
 delete mode 100644 scripts/debug_irene.sh
 create mode 100644 scripts/train_cartnet_jarvis.sh

diff --git a/README.md b/README.md
index e10e99c..4fe766b 100644
--- a/README.md
+++ b/README.md
@@ -67,6 +67,19 @@ These dependencies are automatically installed when you create the Conda environment
 
 ## Dataset
 
+### ADP Dataset
+
+The ADP dataset can be downloaded from the following link.
+
+The dataset can be extracted using:
+
+```sh
+tar -xf adp_dataset.tar.gz
+```
+
+> [!NOTE]
+> The ADP_DATASET/ folder should be placed inside the dataset/ folder; alternatively, specify a different path via the --dataset_path flag of main.py.
+
 ## Training
 
 To recreate the experiments from the paper:
@@ -77,27 +90,55 @@ To train **CartNet** on the **ADP Dataset**:
 
 ```sh
-bash train_scripts/train_cartnet_adp.sh
+cd scripts/
+bash train_cartnet_adp.sh
 ```
 
 To train **eComformer** on the **ADP Dataset**:
 
 ```sh
-bash train_scripts/train_ecomformer_adp.sh
+cd scripts/
+bash train_ecomformer_adp.sh
 ```
 
 To train **iComformer** on the **ADP Dataset**:
 
 ```sh
-bash train_scripts/train_icomformer_adp.sh
+cd scripts/
+bash train_icomformer_adp.sh
 ```
 
+To run the ablation experiments on the **ADP Dataset**:
+
+```sh
+cd scripts/
+bash run_ablations.sh
+```
+
+### Jarvis
+
+To train **CartNet** on the **Jarvis** dataset:
+
+```sh
+cd scripts/
+bash train_cartnet_jarvis.sh
+```
+
+### The Materials Project
+
+To train **CartNet** on **The Materials Project** dataset:
+
+```sh
+cd scripts/
+bash train_cartnet_megnet.sh
+```
+
 ## Evaluation
 
 Instructions to evaluate the model:
 
 ```sh
 # Command to evaluate the model
-python evaluate.py --config configs/eval_config.yaml --checkpoint path/to/checkpoint.pth
+python main.py --inference --weighs_path path/to/checkpoint.pth
 ```
 
 ## Results
diff --git a/dataset/figshare_dataset.py b/dataset/figshare_dataset.py
index 1c0cecb..5a3c839 100644
--- a/dataset/figshare_dataset.py
+++ b/dataset/figshare_dataset.py
@@ -9,13 +9,13 @@
 
 class Figshare_Dataset(InMemoryDataset):
 
-    def __init__(self, root, data, targets, transform=None, pre_transform=None, name="jarvis", radius=5.0, max_neigh=None, augment=False):
+    def __init__(self, root, data, targets, transform=None, pre_transform=None, name="jarvis", radius=5.0, max_neigh=-1, augment=False):
         self.data = data
         self.targets = targets
         self.name = name
         self.radius = radius
-        self.max_neigh = None
+        self.max_neigh = max_neigh if max_neigh > 0 else None
         self.augment = augment
         super(Figshare_Dataset, self).__init__(root, transform, pre_transform)
         self.data, self.slices = torch.load(self.processed_paths[0])
@@ -64,7 +64,7 @@ def process(self):
         batch = Batch.from_data_list([data])
 
         edge_index, _, _, cart_vector = radius_graph_pbc(batch, self.radius, self.max_neigh)
-        data.cart_dist = torch.norm(cart_vector, p=2, dim=-1).unsqueeze(-1)
+        data.cart_dist = torch.norm(cart_vector, p=2, dim=-1)
 
         data.cart_dir = torch.nn.functional.normalize(cart_vector, p=2, dim=-1)
 
diff --git a/loader/loader.py b/loader/loader.py
index 410c31d..6d648b6 100644
--- a/loader/loader.py
+++ b/loader/loader.py
@@ -37,18 +37,18 @@ def 
create_loader(): cfg.dataset.name = "dft_3d_2021" seed = 123 #PotNet uses seed=123 for the comparative table - target = cfg.jarvis_target - if cfg.jarvis_target in ["shear modulus", "bulk modulus"] and cfg.dataset.name == "megnet": + target = cfg.figshare_target + if cfg.figshare_target in ["shear modulus", "bulk modulus"] and cfg.dataset.name == "megnet": import pickle as pk - target = cfg.jarvis_target - if cfg.jarvis_target == "bulk modulus": + target = cfg.figshare_target + if cfg.figshare_target == "bulk modulus": try: data_train = pk.load(open("./dataset/megnet/bulk_megnet_train.pkl", "rb")) data_val = pk.load(open("./dataset/megnet/bulk_megnet_val.pkl", "rb")) data_test = pk.load(open("./dataset/megnet/bulk_megnet_test.pkl", "rb")) except: raise Exception("Bulk modulus dataset not found, please download it from https://figshare.com/projects/Bulk_and_shear_datasets/165430") - elif cfg.jarvis_target == "shear modulus": + elif cfg.figshare_target == "shear modulus": try: data_train = pk.load(open("./dataset/megnet/shear_megnet_train.pkl", "rb")) data_val = pk.load(open("./dataset/megnet/shear_megnet_val.pkl", "rb")) @@ -75,7 +75,7 @@ def create_loader(): targets.append(i) else: - data = jdata(cfg.jarvis_name) + data = jdata(cfg.dataset.name) dat = [] all_targets = [] for i in data: @@ -99,11 +99,11 @@ def create_loader(): targets_val = [all_targets[i] for i in ids_val] targets_test = [all_targets[i] for i in ids_test] - radius = cfg.cutoff - prefix = cfg.jarvis_name+"_"+str(radius)+"_"+str(cfg.max_neighbours)+"_"+target+"_"+str(seed) - dataset_train = Figshare_Dataset(root=cfg.dataset_path, data=dat_train, targets=targets_train, radius=radius, name=prefix+"_train") - dataset_val = Figshare_Dataset(root=cfg.dataset_path, data=dat_val, targets=targets_val, radius=radius, name=prefix+"_val") - dataset_test = Figshare_Dataset(root=cfg.dataset_path, data=dat_test, targets=targets_test, radius=radius, name=prefix+"_test") + radius = cfg.radius + prefix = cfg.dataset.name+"_"+str(radius)+"_"+str(cfg.max_neighbours)+"_"+target+"_"+str(seed) + dataset_train = Figshare_Dataset(root=cfg.dataset_path, data=dat_train, targets=targets_train, radius=radius, max_neigh=cfg.max_neighbours, name=prefix+"_train") + dataset_val = Figshare_Dataset(root=cfg.dataset_path, data=dat_val, targets=targets_val, radius=radius, max_neigh=cfg.max_neighbours, name=prefix+"_val") + dataset_test = Figshare_Dataset(root=cfg.dataset_path, data=dat_test, targets=targets_test, radius=radius, max_neigh=cfg.max_neighbours, name=prefix+"_test") else: raise Exception("Dataset not implemented") diff --git a/main.py b/main.py index fef2480..d3271eb 100644 --- a/main.py +++ b/main.py @@ -96,7 +96,7 @@ def montecarlo(model, loader): parser.add_argument('--name', type=str, default="CartNet", help="name of the Wandb experiment" ) parser.add_argument("--batch", type=int, default=4, help="Batch size") parser.add_argument("--batch_accumulation", type=int, default=16, help="Batch Accumulation") - parser.add_argument("--dataset", type=str, default="ADP", help="Dataset name. Available: ADP, Jarvis, MaterialsProject") + parser.add_argument("--dataset", type=str, default="ADP", help="Dataset name. 
Available: ADP, jarvis, megnet") parser.add_argument("--dataset_path", type=str, default="./dataset/ADP_DATASET/") parser.add_argument("--inference", action="store_true", help="Inference") parser.add_argument("--montecarlo", action="store_true", help="Montecarlo") diff --git a/main_irene.py b/main_irene.py deleted file mode 100644 index f4ce95a..0000000 --- a/main_irene.py +++ /dev/null @@ -1,180 +0,0 @@ -import torch -import logging -import argparse -import pickle -from tqdm import tqdm -from logger.logger import create_logger -from loader.loader import create_loader -from models.master import create_model -from train.train_irene import train -from torch_geometric.graphgym.utils.comp_budget import params_count -from torch_geometric import seed_everything -from torch_geometric.graphgym.config import cfg, set_cfg -from torch_geometric.graphgym.logger import set_printing - - -def inference(model, loader): - from train.metrics import compute_loss, compute_3D_IoU - model.eval() - - with torch.no_grad(): - inference_output = {"pred": [], "true": [], "refcode": [], "pos": [], "atoms": [], "iou": [], "mae": []} - for iter, batch in tqdm(enumerate(loader), total=len(loader), ncols=50): - batch.to("cuda:0") - inference_output["atoms"].append(batch.x[batch.non_H_mask].detach().to("cpu")) - inference_output["pos"].append(batch.pos[batch.non_H_mask].detach().to("cpu")) - inference_output["refcode"].append(batch.refcode) - _pred, _true = model(batch) - inference_output["pred"].append(_pred.detach().to("cpu")) - inference_output["true"].append(_true.detach().to("cpu")) - inference_output["iou"].append(compute_3D_IoU(_pred, _true).detach().to("cpu")) - inference_output["mae"].append(compute_loss(_pred, _true)[0].detach().to("cpu")) - - pickle.dump(inference_output, open(cfg.inference_output, "wb")) - -def montecarlo(model, loader): - from train.metrics import compute_loss, compute_3D_IoU - import roma - - model.eval() - iou_montecarlo = [] - mae_montecarlo = [] - with torch.no_grad(): - for i in range(100): - inference_output = {"pred": [], "true": [], "refcode": [], "pos": [], "atoms": [], "iou": [], "mae": []} - for iter, batch in tqdm(enumerate(loader), total=len(loader), ncols=50): - batch_copy = batch.clone() - batch.to("cuda:0") - inference_output["atoms"].append(batch.x[batch.non_H_mask].detach().to("cpu")) - inference_output["pos"].append(batch.pos[batch.non_H_mask].detach().to("cpu")) - inference_output["refcode"].append(batch.refcode) - pseudo_true, _ = model(batch) - R = roma.utils.random_rotmat(size=1, device=batch.x.device).squeeze(0) - pseudo_true = R.transpose(-1,-2) @ pseudo_true @ R - batch_copy.to("cuda:0") - batch_copy.cart_dir = batch_copy.cart_dir @ R - pred, _ = model(batch_copy) - inference_output["pred"].append(pred.detach().to("cpu")) - inference_output["true"].append(pseudo_true.detach().to("cpu")) - inference_output["iou"].append(compute_3D_IoU(_pred, _true).detach().to("cpu")) - inference_output["mae"].append(compute_loss(_pred, _true)[0].detach().to("cpu")) - pickle.dump(inference_output, open(cfg.inference_output.replace(".pkl", "_montecarlo_"+str(i)+".pkl"), "wb")) - iou_montecarlo+=inference_output["iou"] - mae_montecarlo+=inference_output["mae"] - - iou_montecarlo = torch.cat(iou_montecarlo, dim=0) - mae_montecarlo = torch.cat(mae_montecarlo, dim=0) - - logging.info(f"Montecarlo IoU: {iou_montecarlo.mean().item()}+/-{iou_montecarlo.std().item()}") - logging.info(f"Montecarlo MAE: {mae_montecarlo.mean().item()}+/-{mae_montecarlo.std().item()}") - -if __name__ == 
"__main__": - parser = argparse.ArgumentParser() - parser.add_argument('--seed', type=int, default=0, help='Seed for the experiment') - parser.add_argument('--name', type=str, default="CartNet", help="name of the Wandb experiment" ) - parser.add_argument("--batch", type=int, default=4, help="Batch size") - parser.add_argument("--batch_accumulation", type=int, default=16, help="Batch Accumulation") - parser.add_argument("--dataset", type=str, default="ADP", help="Dataset name. Available: ADP, Jarvis, MaterialsProject") - parser.add_argument("--dataset_path", type=str, default="./dataset/ADP_DATASET/") - parser.add_argument("--inference", action="store_true", help="Inference") - parser.add_argument("--montecarlo", action="store_true", help="Montecarlo") - parser.add_argument("--weighs_path", type=str, default=None, help="Path to the weights of the model") - parser.add_argument("--inference_output", type=str, default="./inference.pkl", help="Path to the inference output") - parser.add_argument("--figshare_target", type=str, default="formation_energy_peratom", help="Figshare dataset target") - parser.add_argument("--wandb_project", type=str, default="ADP", help="Wandb project name") - parser.add_argument("--wandb_entity", type=str, default="aiquaneuro", help="Name of the wandb entity") - parser.add_argument("--loss", type=str, default="MAE", help="Loss function") - parser.add_argument("--epochs", type=int, default=50, help="Number of epochs") - parser.add_argument("--learning_rate", type=float, default=1e-3, help="Learning rate") - parser.add_argument("--warmup", type=float, default=0.01, help="Warmup") - parser.add_argument('--model', type=str, default="CartNet", help="Model Name") - parser.add_argument("--max_neighbours", type=int, default=25, help="Max neighbours (only for iComformer/eComformer)") - parser.add_argument("--radius", type=float, default=5.0, help="Radius for the Radius Graph Neighbourhood") - parser.add_argument("--num_layers", type=int, default=4, help="Number of layers") - parser.add_argument("--dim_in", type=int, default=256, help="Input dimension") - parser.add_argument("--dim_rbf", type=int, default=64, help="Number of RBF") - parser.add_argument('--augment', action='store_true', help='Hydrogens') - parser.add_argument("--invariant", action="store_true", help="Rotation Invariant model") - parser.add_argument("--disable_temp", action="store_false", help="Disable Temperature") - parser.add_argument("--no_standarize_temp", action="store_false", help="Standarize temperature") - parser.add_argument("--disable_envelope", action="store_false", help="Disable envelope") - parser.add_argument('--disable_H', action='store_false', help='Hydrogens') - parser.add_argument("--threads", type=int, default= 8, help="Number of threads") - parser.add_argument("--workers", type=int, default=5, help="Number of workers") - - set_cfg(cfg) - - args, _ = parser.parse_known_args() - cfg.seed = args.seed - cfg.name = args.name - cfg.run_dir = "results/"+cfg.name+"/"+str(cfg.seed) - cfg.dataset.task_type = "regression" - cfg.batch = args.batch - cfg.batch_accumulation = args.batch_accumulation - cfg.dataset.name = args.dataset - cfg.dataset_path = args.dataset_path - cfg.figshare_target = args.figshare_target - cfg.wandb_project = args.wandb_project - cfg.wandb_entity = args.wandb_entity - cfg.loss = args.loss - cfg.optim.max_epoch = args.epochs - cfg.learning_rate = args.learning_rate - cfg.warmup = args.warmup - cfg.model = args.model - cfg.max_neighbours = False if cfg.model== "CartNet" else 
args.max_neighbours - cfg.radius = args.radius - cfg.num_layers = args.num_layers - cfg.dim_in = args.dim_in - cfg.dim_rbf = args.dim_rbf - cfg.augment = False if cfg.model in ["icomformer", "ecomformer"] else args.augment - cfg.invariant = args.invariant - cfg.use_temp = False if cfg.dataset.name != "ADP" else args.disable_temp - cfg.standarize_temp = args.no_standarize_temp - cfg.envelope = args.disable_envelope - cfg.use_H = args.disable_H - cfg.workers = args.workers - - torch.set_num_threads(args.threads) - - set_printing() - - #Seed - seed_everything(cfg.seed) - - logging.info(f"Experiment will be saved at: {cfg.run_dir}") - - loaders = create_loader() - - model = create_model() - - logging.info(model) - cfg.params_count = params_count(model) - logging.info(f"Number of parameters: {cfg.params_count}") - - optimizer = torch.optim.Adam(model.parameters(), lr=cfg.learning_rate) - - loggers = create_logger() - - if args.inference: - assert args.weighs_path is not None, "Weights path not provided" - assert cfg.dataset.name == "ADP", "Inference only for ADP dataset" - ckpt = torch.load(args.weighs_path) - model.load_state_dict(ckpt["model_state"]) - cfg.inference_output = args.inference_output - inference(model, loaders[-1]) - elif args.montecarlo: - assert args.weighs_path is not None, "Weights path not provided" - assert cfg.dataset.name == "ADP", "Montecarlo only for ADP dataset" - montecarlo(model, loaders[-1]) - else: - train(model, loaders, optimizer, loggers) - - - - - - - - - - diff --git a/scripts/debug_irene.sh b/scripts/debug_irene.sh deleted file mode 100644 index 4b081c7..0000000 --- a/scripts/debug_irene.sh +++ /dev/null @@ -1,4 +0,0 @@ - -CUDA_VISIBLE_DEVICES=0 python main_irene.py --seed 0 --name "debug_irene" --model "CartNet" --dataset "ADP" \ - --wandb_project "CartNet Paper" --batch_size 64 --lr 0.001 --epochs 1 \ - --augment --workers 6 \ No newline at end of file diff --git a/scripts/train_cartnet_adp.sh b/scripts/train_cartnet_adp.sh index d3ffb19..2a6dc78 100644 --- a/scripts/train_cartnet_adp.sh +++ b/scripts/train_cartnet_adp.sh @@ -1,36 +1,7 @@ -CUDA_VISIBLE_DEVICES=0 python main.py --seed 0 --name "CartNet" --model "CartNet" --dataset "ADP" --dataset_path "/scratch/g1alexs/ADP_DATASET" \ +CUDA_VISIBLE_DEVICES=0 python ../main.py --seed 0 --name "CartNet" --model "CartNet" --dataset "ADP" --dataset_path "/scratch/g1alexs/ADP_DATASET" \ --wandb_project "CartNet Paper" --batch 4 --batch_accumulation 16 --lr 0.001 --epochs 50 \ - --augment & -CUDA_VISIBLE_DEVICES=1 python main.py --seed 1 --name "CartNet" --model "CartNet" --dataset "ADP" --dataset_path "/scratch/g1alexs/ADP_DATASET" \ - --wandb_project "CartNet Paper" --batch 4 --batch_accumulation 16 --lr 0.001 --epochs 50 \ - --augment & - -CUDA_VISIBLE_DEVICES=2 python main.py --seed 2 --name "CartNet" --model "CartNet" --dataset "ADP" --dataset_path "/scratch/g1alexs/ADP_DATASET" \ - --wandb_project "CartNet Paper" --batch 4 --batch_accumulation 16 --lr 0.001 --epochs 50 \ - --augment & - -CUDA_VISIBLE_DEVICES=3 python main.py --seed 3 --name "CartNet" --model "CartNet" --dataset "ADP" --dataset_path "/scratch/g1alexs/ADP_DATASET" \ - --wandb_project "CartNet Paper" --batch 4 --batch_accumulation 16 --lr 0.001 --epochs 50 \ - --augment & - -CUDA_VISIBLE_DEVICES=4 python main.py --seed 4 --name "CartNet" --model "CartNet" --dataset "ADP" --dataset_path "/scratch/g1alexs/ADP_DATASET" \ - --wandb_project "CartNet Paper" --batch 4 --batch_accumulation 16 --lr 0.001 --epochs 50 \ - --augment & - 
-CUDA_VISIBLE_DEVICES=5 python main.py --seed 5 --name "CartNet" --model "CartNet" --dataset "ADP" --dataset_path "/scratch/g1alexs/ADP_DATASET" \
-    --wandb_project "CartNet Paper" --batch 4 --batch_accumulation 16 --lr 0.001 --epochs 50 \
-    --augment &
-
-CUDA_VISIBLE_DEVICES=6 python main.py --seed 6 --name "CartNet" --model "CartNet" --dataset "ADP" --dataset_path "/scratch/g1alexs/ADP_DATASET" \
-    --wandb_project "CartNet Paper" --batch 4 --batch_accumulation 16 --lr 0.001 --epochs 50 \
-    --augment &
-
-CUDA_VISIBLE_DEVICES=7 python main.py --seed 7 --name "CartNet" --model "CartNet" --dataset "ADP" --dataset_path "/scratch/g1alexs/ADP_DATASET" \
-    --wandb_project "CartNet Paper" --batch 4 --batch_accumulation 16 --lr 0.001 --epochs 50 \
-    --augment &
-
-wait
+    --augment
\ No newline at end of file
diff --git a/scripts/train_cartnet_jarvis.sh b/scripts/train_cartnet_jarvis.sh
new file mode 100644
index 0000000..f82a12d
--- /dev/null
+++ b/scripts/train_cartnet_jarvis.sh
@@ -0,0 +1,4 @@
+cd ..
+CUDA_VISIBLE_DEVICES=0 python main.py --seed 0 --name "jarvis_dft_3D_optb88vdw_bandgap_variantnetv4_variant_cell" --model "CartNet" --dataset "jarvis" --dataset_path "./dataset/jarvis/" \
+    --wandb_project "CartNet Paper" --batch 64 --batch_accumulation 1 --lr 0.001 --epochs 50 --figshare_target "optb88vdw_bandgap" \
+    --augment
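
For reference, the dataset/figshare_dataset.py hunk above changes `max_neigh` from being unconditionally discarded (`self.max_neigh = None`) to an opt-in cap: a non-positive value, including the new default of `-1`, means no neighbour limit, while a positive value is forwarded to `radius_graph_pbc`. A minimal standalone sketch of these semantics (illustrative only, not part of the patch):

```python
# Sketch of the max_neigh handling introduced in dataset/figshare_dataset.py.
# Non-positive values mean "no neighbour cap"; positive values set the cap.

def resolve_max_neigh(max_neigh: int = -1):
    # Mirrors: self.max_neigh = max_neigh if max_neigh > 0 else None
    return max_neigh if max_neigh > 0 else None

assert resolve_max_neigh() is None      # default: uncapped, as before the patch
assert resolve_max_neigh(-1) is None    # explicit "uncapped"
assert resolve_max_neigh(25) == 25      # capped, e.g. for iComformer/eComformer
```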