
Commit

updated readme and minor bug fixes
alexsoleg committed Oct 22, 2024
1 parent 1fb2b8b commit 9fc0777
Showing 8 changed files with 70 additions and 234 deletions.
49 changes: 45 additions & 4 deletions README.md
@@ -67,6 +67,20 @@ These dependencies are automatically installed when you create the Conda environment

## Dataset

### ADP Dataset:

The ADP dataset can be downloaded from the following link.

The dataset can be extracted using:

```sh
tar -xf adp_dataset.tar.gz
```

> [!NOTE]
> The ADP_DATASET/ folder should be placed inside the dataset/ folder, or a new path can be specified via the --dataset_path flag in main.py.

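For instance, a minimal sketch (the /data/ADP_DATASET/ path below is only a placeholder; --dataset and --dataset_path are the flags defined in main.py):

```sh
# Extract into the default location expected by main.py
tar -xf adp_dataset.tar.gz -C dataset/

# Or extract elsewhere and point main.py at it explicitly
python main.py --dataset ADP --dataset_path /data/ADP_DATASET/
```
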
## Training

To recreate the experiments from the paper:
@@ -77,27 +91,54 @@
To train **ADP Dataset** using **CartNet**:

```sh
-bash train_scripts/train_cartnet_adp.sh
+cd scripts/
+bash train_cartnet_adp.sh
```

To train **ADP Dataset** using **eComformer**:

```sh
-bash train_scripts/train_ecomformer_adp.sh
+cd scripts/
+bash train_ecomformer_adp.sh
```

To train **ADP Dataset** using **iComformer**:

```sh
-bash train_scripts/train_icomformer_adp.sh
+cd scripts/
+bash train_icomformer_adp.sh
```

To run the ablation experiments in the **ADP Dataset**:

```sh
cd scripts/
bash run_ablations.sh
```

### Jarvis:

```sh
cd scripts/
bash train_cartnet_jarvis.sh
```
### The Materials Project
```sh
cd scripts/
bash train_cartnet_megnet.sh
```
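
Each of the scripts above boils down to one or more invocations of main.py, so a single run can also be launched directly. This sketch reuses only flags that appear in scripts/train_cartnet_adp.sh, with the dataset path swapped for the repository default:

```sh
# One CartNet training run on the ADP dataset (flags taken from the training script)
CUDA_VISIBLE_DEVICES=0 python main.py --seed 0 --name "CartNet" --model "CartNet" \
    --dataset "ADP" --dataset_path "./dataset/ADP_DATASET" \
    --batch 4 --batch_accumulation 16 --lr 0.001 --epochs 50 --augment
```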

## Evaluation

Instructions to evaluate the model:

```sh
 # Command to evaluate the model
-python evaluate.py --config configs/eval_config.yaml --checkpoint path/to/checkpoint.pth
+python main.py --inference --checkpoint_path path/to/checkpoint.pth
```
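
main.py also defines a --montecarlo flag alongside --inference (see its argument parser further down in this commit). Assuming it consumes a checkpoint the same way --inference does, a Monte Carlo evaluation would look like:

```sh
# Assumption: --montecarlo pairs with --checkpoint_path like --inference does
python main.py --montecarlo --checkpoint_path path/to/checkpoint.pth
```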

## Results
6 changes: 3 additions & 3 deletions dataset/figshare_dataset.py
@@ -9,13 +9,13 @@


class Figshare_Dataset(InMemoryDataset):
def __init__(self, root, data, targets, transform=None, pre_transform=None, name="jarvis", radius=5.0, max_neigh=None, augment=False):
def __init__(self, root, data, targets, transform=None, pre_transform=None, name="jarvis", radius=5.0, max_neigh=-1, augment=False):

self.data = data
self.targets = targets
self.name = name
self.radius = radius
-self.max_neigh = None
+self.max_neigh = max_neigh if max_neigh > 0 else None
self.augment = augment
super(Figshare_Dataset, self).__init__(root, transform, pre_transform)
self.data, self.slices = torch.load(self.processed_paths[0])
@@ -64,7 +64,7 @@ def process(self):
batch = Batch.from_data_list([data])
edge_index, _, _, cart_vector = radius_graph_pbc(batch, self.radius, self.max_neigh)

-data.cart_dist = torch.norm(cart_vector, p=2, dim=-1).unsqueeze(-1)
+data.cart_dist = torch.norm(cart_vector, p=2, dim=-1)
data.cart_dir = torch.nn.functional.normalize(cart_vector, p=2, dim=-1)


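
The two fixes above change behavior: a max_neigh value passed by the caller is no longer silently discarded, and data.cart_dist is now stored as a 1-D tensor of shape [num_edges] rather than [num_edges, 1]. A standalone check of the shape difference (plain PyTorch, not project code):

```sh
python - <<'EOF'
import torch
cart_vector = torch.randn(8, 3)  # e.g. 8 edges with 3-D offset vectors
old = torch.norm(cart_vector, p=2, dim=-1).unsqueeze(-1)
new = torch.norm(cart_vector, p=2, dim=-1)
print(old.shape, new.shape)  # torch.Size([8, 1]) torch.Size([8])
EOF
```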
22 changes: 11 additions & 11 deletions loader/loader.py
@@ -37,18 +37,18 @@ def create_loader():
cfg.dataset.name = "dft_3d_2021"

seed = 123 #PotNet uses seed=123 for the comparative table
-target = cfg.jarvis_target
-if cfg.jarvis_target in ["shear modulus", "bulk modulus"] and cfg.dataset.name == "megnet":
+target = cfg.figshare_target
+if cfg.figshare_target in ["shear modulus", "bulk modulus"] and cfg.dataset.name == "megnet":
import pickle as pk
-target = cfg.jarvis_target
-if cfg.jarvis_target == "bulk modulus":
+target = cfg.figshare_target
+if cfg.figshare_target == "bulk modulus":
try:
data_train = pk.load(open("./dataset/megnet/bulk_megnet_train.pkl", "rb"))
data_val = pk.load(open("./dataset/megnet/bulk_megnet_val.pkl", "rb"))
data_test = pk.load(open("./dataset/megnet/bulk_megnet_test.pkl", "rb"))
except:
raise Exception("Bulk modulus dataset not found, please download it from https://figshare.com/projects/Bulk_and_shear_datasets/165430")
elif cfg.jarvis_target == "shear modulus":
elif cfg.figshare_target == "shear modulus":
try:
data_train = pk.load(open("./dataset/megnet/shear_megnet_train.pkl", "rb"))
data_val = pk.load(open("./dataset/megnet/shear_megnet_val.pkl", "rb"))
@@ -75,7 +75,7 @@ def create_loader():
targets.append(i)

else:
-data = jdata(cfg.jarvis_name)
+data = jdata(cfg.dataset.name)
dat = []
all_targets = []
for i in data:
@@ -99,11 +99,11 @@
targets_val = [all_targets[i] for i in ids_val]
targets_test = [all_targets[i] for i in ids_test]

-radius = cfg.cutoff
-prefix = cfg.jarvis_name+"_"+str(radius)+"_"+str(cfg.max_neighbours)+"_"+target+"_"+str(seed)
-dataset_train = Figshare_Dataset(root=cfg.dataset_path, data=dat_train, targets=targets_train, radius=radius, name=prefix+"_train")
-dataset_val = Figshare_Dataset(root=cfg.dataset_path, data=dat_val, targets=targets_val, radius=radius, name=prefix+"_val")
-dataset_test = Figshare_Dataset(root=cfg.dataset_path, data=dat_test, targets=targets_test, radius=radius, name=prefix+"_test")
+radius = cfg.radius
+prefix = cfg.dataset.name+"_"+str(radius)+"_"+str(cfg.max_neighbours)+"_"+target+"_"+str(seed)
+dataset_train = Figshare_Dataset(root=cfg.dataset_path, data=dat_train, targets=targets_train, radius=radius, max_neigh=cfg.max_neighbours, name=prefix+"_train")
+dataset_val = Figshare_Dataset(root=cfg.dataset_path, data=dat_val, targets=targets_val, radius=radius, max_neigh=cfg.max_neighbours, name=prefix+"_val")
+dataset_test = Figshare_Dataset(root=cfg.dataset_path, data=dat_test, targets=targets_test, radius=radius, max_neigh=cfg.max_neighbours, name=prefix+"_test")
else:
raise Exception("Dataset not implemented")

2 changes: 1 addition & 1 deletion main.py
@@ -96,7 +96,7 @@ def montecarlo(model, loader):
parser.add_argument('--name', type=str, default="CartNet", help="name of the Wandb experiment" )
parser.add_argument("--batch", type=int, default=4, help="Batch size")
parser.add_argument("--batch_accumulation", type=int, default=16, help="Batch Accumulation")
parser.add_argument("--dataset", type=str, default="ADP", help="Dataset name. Available: ADP, Jarvis, MaterialsProject")
parser.add_argument("--dataset", type=str, default="ADP", help="Dataset name. Available: ADP, jarvis, megnet")
parser.add_argument("--dataset_path", type=str, default="./dataset/ADP_DATASET/")
parser.add_argument("--inference", action="store_true", help="Inference")
parser.add_argument("--montecarlo", action="store_true", help="Montecarlo")
180 changes: 0 additions & 180 deletions main_irene.py

This file was deleted.

4 changes: 0 additions & 4 deletions scripts/debug_irene.sh

This file was deleted.

33 changes: 2 additions & 31 deletions scripts/train_cartnet_adp.sh
@@ -1,36 +1,7 @@

-CUDA_VISIBLE_DEVICES=0 python main.py --seed 0 --name "CartNet" --model "CartNet" --dataset "ADP" --dataset_path "/scratch/g1alexs/ADP_DATASET" \
+CUDA_VISIBLE_DEVICES=0 python ../main.py --seed 0 --name "CartNet" --model "CartNet" --dataset "ADP" --dataset_path "/scratch/g1alexs/ADP_DATASET" \
 --wandb_project "CartNet Paper" --batch 4 --batch_accumulation 16 --lr 0.001 --epochs 50 \
---augment &
-CUDA_VISIBLE_DEVICES=1 python main.py --seed 1 --name "CartNet" --model "CartNet" --dataset "ADP" --dataset_path "/scratch/g1alexs/ADP_DATASET" \
---wandb_project "CartNet Paper" --batch 4 --batch_accumulation 16 --lr 0.001 --epochs 50 \
---augment &
-
-CUDA_VISIBLE_DEVICES=2 python main.py --seed 2 --name "CartNet" --model "CartNet" --dataset "ADP" --dataset_path "/scratch/g1alexs/ADP_DATASET" \
---wandb_project "CartNet Paper" --batch 4 --batch_accumulation 16 --lr 0.001 --epochs 50 \
---augment &
-
-CUDA_VISIBLE_DEVICES=3 python main.py --seed 3 --name "CartNet" --model "CartNet" --dataset "ADP" --dataset_path "/scratch/g1alexs/ADP_DATASET" \
---wandb_project "CartNet Paper" --batch 4 --batch_accumulation 16 --lr 0.001 --epochs 50 \
---augment &
-
-CUDA_VISIBLE_DEVICES=4 python main.py --seed 4 --name "CartNet" --model "CartNet" --dataset "ADP" --dataset_path "/scratch/g1alexs/ADP_DATASET" \
---wandb_project "CartNet Paper" --batch 4 --batch_accumulation 16 --lr 0.001 --epochs 50 \
---augment &
-
-CUDA_VISIBLE_DEVICES=5 python main.py --seed 5 --name "CartNet" --model "CartNet" --dataset "ADP" --dataset_path "/scratch/g1alexs/ADP_DATASET" \
---wandb_project "CartNet Paper" --batch 4 --batch_accumulation 16 --lr 0.001 --epochs 50 \
---augment &
-
-CUDA_VISIBLE_DEVICES=6 python main.py --seed 6 --name "CartNet" --model "CartNet" --dataset "ADP" --dataset_path "/scratch/g1alexs/ADP_DATASET" \
---wandb_project "CartNet Paper" --batch 4 --batch_accumulation 16 --lr 0.001 --epochs 50 \
---augment &
-
-CUDA_VISIBLE_DEVICES=7 python main.py --seed 7 --name "CartNet" --model "CartNet" --dataset "ADP" --dataset_path "/scratch/g1alexs/ADP_DATASET" \
---wandb_project "CartNet Paper" --batch 4 --batch_accumulation 16 --lr 0.001 --epochs 50 \
---augment &
-
-wait
+--augment
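
The rewritten script thus launches a single foreground run from scripts/ instead of eight background runs. If the original multi-seed sweep is still wanted, it can be reproduced with a loop over the removed commands, sketched here with the same flags and paths:

```sh
# Re-create the removed sweep: seeds 0-7, one GPU each, run in parallel
for SEED in 0 1 2 3 4 5 6 7; do
    CUDA_VISIBLE_DEVICES=$SEED python ../main.py --seed $SEED --name "CartNet" --model "CartNet" \
        --dataset "ADP" --dataset_path "/scratch/g1alexs/ADP_DATASET" \
        --wandb_project "CartNet Paper" --batch 4 --batch_accumulation 16 --lr 0.001 --epochs 50 \
        --augment &
done
wait
```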


