Skip to content

Commit

Permalink
debugged icomformer/ecomformer
Browse files Browse the repository at this point in the history
  • Loading branch information
alexsoleg committed Oct 6, 2024
1 parent be3c95d commit 3a3262b
Show file tree
Hide file tree
Showing 8 changed files with 40 additions and 25 deletions.
19 changes: 12 additions & 7 deletions dataset/utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import logging
import os
import os.path as osp
from tqdm import tqdm
import numpy as np
import torch
from torch_geometric.data import Data, Batch
from torch_scatter import segment_coo, segment_csr
import roma



Expand Down Expand Up @@ -454,30 +454,35 @@ def optmize_lattice(lattice_vectors):


def compute_knn(max_neigh, radius, path, refcodes):
    """Precompute and cache periodic neighbour graphs for every listed sample.

    Each sample is loaded from ``<path>/data/<name>.pt``, given a fully
    periodic ``pbc`` flag, passed through ``radius_graph_pbc`` and saved to
    ``<path>/data_<max_neigh>_<radius>/data/<name>.pt`` with three extra
    attributes: ``edge_index``, ``cart_dist`` (L2 norm of each edge vector,
    kept as a trailing dimension of size 1) and ``cart_dir`` (L2-normalised
    edge vector).

    Args:
        max_neigh: Neighbour limit forwarded to ``radius_graph_pbc``; also
            part of the cache directory name.
        radius: Cutoff forwarded to ``radius_graph_pbc``; also part of the
            cache directory name.
        path: Dataset root that contains the ``data/`` folder of ``.pt`` files.
        refcodes: Iterable of paths to text files, one sample name per line.

    Returns:
        The cache directory ``<path>/data_<max_neigh>_<radius>/``.
    """
    print(max_neigh)

    final_root = os.path.join(path, "data_"+str(max_neigh)+"_"+str(radius)+"/")
    print(final_root)

    # The cache is keyed on both max_neigh and radius so that changing either
    # hyper-parameter triggers a fresh computation.
    # NOTE(review): only the directory's existence is checked, so a run that
    # was interrupted half-way leaves a partial cache that is treated as
    # complete on the next call.
    if os.path.exists(final_root) and os.path.isdir(final_root):
        logging.info("Already computed PBC for knn "+str(max_neigh) + " and radius "+str(radius))
        return final_root

    os.makedirs(final_root)
    os.makedirs(osp.join(final_root, "data/"))

    for split in refcodes:
        with open(split, 'r') as file:
            # Ignore blank lines so that a trailing newline in a split file
            # does not turn into an attempt to load "data/.pt".
            file_names = [line.strip() for line in file if line.strip()]

        for file_name in tqdm(file_names, ncols=100, desc="Computing PBC"):
            data = torch.load(osp.join(path, "data/"+file_name+".pt"))

            # Periodic along all three lattice directions.
            data.pbc = torch.tensor([[True, True, True]])

            # radius_graph_pbc is called on a batch, so wrap the single sample.
            batch = Batch.from_data_list([data])
            edge_index, _, _, cart_vector = radius_graph_pbc(batch, radius, max_neigh)

            data.edge_index = edge_index
            data.cart_dist = torch.norm(cart_vector, p=2, dim=-1).unsqueeze(-1)
            data.cart_dir = torch.nn.functional.normalize(cart_vector, p=2, dim=-1)

            torch.save(data, osp.join(final_root, "data/"+file_name+".pt"))
    return final_root


Expand Down
2 changes: 1 addition & 1 deletion loader/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def create_loader():
refcodes = [osp.join(cfg.dataset_path,"train_files.csv"), osp.join(cfg.dataset_path,"val_files.csv"), osp.join(cfg.dataset_path,"test_files.csv")]
if cfg.model in ["icomformer", "ecomformer"]:
assert cfg.max_neighbours is not None, "max_neighbours are needed for e/iComformer"
cfg.dataset_path = compute_knn(cfg.max_neighbours, cfg.radius, cfg.path, refcodes)
cfg.dataset_path = compute_knn(cfg.max_neighbours, cfg.radius, cfg.dataset_path, refcodes)

optimize_cell = True if cfg.model == "icomformer" else False
dataset_train, dataset_val, dataset_test = (DatasetADP(root=osp.join(cfg.dataset_path, "data/"), file_names=refcodes[0], hydrogens=cfg.use_H, standarize_temp = cfg.standarize_temp, augment=cfg.augment, optimize_cell=optimize_cell),
Expand Down
1 change: 1 addition & 0 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ def montecarlo(model, loader):
cfg.use_H = args.disable_H
cfg.workers = args.workers


torch.set_num_threads(args.threads)

set_printing()
Expand Down
8 changes: 5 additions & 3 deletions models/comformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@

import torch
from torch import nn
from network.transformer import ComformerConv, ComformerConv_edge, ComformerConvEqui
from models.comformer_conv import ComformerConv, ComformerConv_edge, ComformerConvEqui
from models.cartnet import Cholesky_head
from models.utils import RBFExpansion



def bond_cosine(r1, r2):
Expand Down Expand Up @@ -51,7 +53,7 @@ def __init__(self, dim_in):

self.equi_update = ComformerConvEqui(in_channels=self.dim_in, out_channels=self.dim_in, edge_dim=self.dim_in, use_second_order_repr=True)

self.cholesky = Cholesky_head(self.dim_in, 6)
self.cholesky = Cholesky_head(self.dim_in)

def forward(self, data) -> torch.Tensor:
node_features = self.embedding(data.x) + self.temperature_proj_atom(data.temperature.unsqueeze(-1))[data.batch]
Expand Down Expand Up @@ -108,7 +110,7 @@ def __init__(self, dim_in):

self.edge_update_layer = ComformerConv_edge(in_channels=self.dim_in, out_channels=self.dim_in, heads=1, edge_dim=self.dim_in)

self.cholesky = Cholesky_head(self.dim_in, 6)
self.cholesky = Cholesky_head(self.dim_in)

def forward(self, data) -> torch.Tensor:
node_features = self.embedding(data.x) + self.temperature_proj_atom(data.temperature.unsqueeze(-1))[data.batch]
Expand Down
1 change: 1 addition & 0 deletions models/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import torch
import math
import numpy as np
from torch import nn, Tensor
import torch.nn.functional as F
from typing import Optional
Expand Down
3 changes: 1 addition & 2 deletions train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,7 @@ def train(model, loaders, optimizer, loggers):


run = wandb.init(entity=cfg.wandb_entity, project=cfg.wandb_project,
name=cfg.name)

name=cfg.name, config=cfg)

num_splits = len(loggers)
full_epoch_times = []
Expand Down
11 changes: 11 additions & 0 deletions train_scripts/train_ecomformer_adp.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Launch four eComformer training runs on the ADP dataset, one per seed (0-3).
# Each command is sent to the background with `&`, so all four start at once
# and the script returns without waiting for them.
# CUDA_VISIBLE_DEVICES restricts each run to a single GPU id.
# NOTE(review): seeds 0 and 2 both use GPU 2 - confirm this sharing is intended.
# NOTE(review): this script passes `--batch`, while train_icomformer_adp.sh
# passes `--batch_size` - confirm which flag main.py accepts.

# Seed 0 on GPU 2.
CUDA_VISIBLE_DEVICES=2 python main.py --seed 0 --name "ecomformer" --model "ecomformer" --dataset "ADP" --dataset_path "/scratch/g1alexs/ADP_DATASET" \
--wandb_project "CartNet Paper" --batch 4 --batch_accumulation 16 --lr 0.001 --epochs 50 &

# Seed 1 on GPU 4.
CUDA_VISIBLE_DEVICES=4 python main.py --seed 1 --name "ecomformer" --model "ecomformer" --dataset "ADP" --dataset_path "/scratch/g1alexs/ADP_DATASET" \
--wandb_project "CartNet Paper" --batch 4 --batch_accumulation 16 --lr 0.001 --epochs 50 &

# Seed 2 on GPU 2.
CUDA_VISIBLE_DEVICES=2 python main.py --seed 2 --name "ecomformer" --model "ecomformer" --dataset "ADP" --dataset_path "/scratch/g1alexs/ADP_DATASET" \
--wandb_project "CartNet Paper" --batch 4 --batch_accumulation 16 --lr 0.001 --epochs 50 &

# Seed 3 on GPU 3.
CUDA_VISIBLE_DEVICES=3 python main.py --seed 3 --name "ecomformer" --model "ecomformer" --dataset "ADP" --dataset_path "/scratch/g1alexs/ADP_DATASET" \
--wandb_project "CartNet Paper" --batch 4 --batch_accumulation 16 --lr 0.001 --epochs 50 &
20 changes: 8 additions & 12 deletions train_scripts/train_icomformer_adp.sh
Original file line number Diff line number Diff line change
@@ -1,14 +1,10 @@
# Launch four iComformer training runs on the ADP dataset, one per seed (0-3).
# Each command is sent to the background with `&`, so all four start at once
# and the script returns without waiting for them.
# CUDA_VISIBLE_DEVICES restricts each run to a single GPU id (0, 4, 2, 3).
# The earlier CartNet commands (--model "CartNet" ... --augment) are dropped:
# this script trains iComformer only.

# Seed 0 on GPU 0.
CUDA_VISIBLE_DEVICES=0 python main.py --seed 0 --name "icomformer" --model "icomformer" --dataset "ADP" --dataset_path "/scratch/g1alexs/ADP_DATASET" \
    --wandb_project "CartNet Paper" --batch_size 64 --lr 0.001 --epochs 50 &

# Seed 1 on GPU 4.
CUDA_VISIBLE_DEVICES=4 python main.py --seed 1 --name "icomformer" --model "icomformer" --dataset "ADP" --dataset_path "/scratch/g1alexs/ADP_DATASET" \
    --wandb_project "CartNet Paper" --batch_size 64 --lr 0.001 --epochs 50 &

# Seed 2 on GPU 2.
CUDA_VISIBLE_DEVICES=2 python main.py --seed 2 --name "icomformer" --model "icomformer" --dataset "ADP" --dataset_path "/scratch/g1alexs/ADP_DATASET" \
    --wandb_project "CartNet Paper" --batch_size 64 --lr 0.001 --epochs 50 &

# Seed 3 on GPU 3.
CUDA_VISIBLE_DEVICES=3 python main.py --seed 3 --name "icomformer" --model "icomformer" --dataset "ADP" --dataset_path "/scratch/g1alexs/ADP_DATASET" \
    --wandb_project "CartNet Paper" --batch_size 64 --lr 0.001 --epochs 50 &

0 comments on commit 3a3262b

Please sign in to comment.