Commit fc1a046: Fixed Ablations
alexsoleg committed Oct 21, 2024 (parent 919cb07)
Showing 10 changed files with 201 additions and 80 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -7,3 +7,4 @@ results/
 wandb/
 dataset/ADP_DATASET/
 dataset/megnet/
+montecarlo/
75 changes: 35 additions & 40 deletions dataset/datasetADP.py
@@ -31,7 +31,6 @@ def processed_file_names(self):
         return self.file_names
 
     def augment_data(self, data):
-
         R = roma.utils.random_rotmat(size=1, device=data.x.device).squeeze(0)
         data.y = R.transpose(-1,-2) @ data.y @ R
         data.cart_dir = data.cart_dir @ R
@@ -40,54 +39,50 @@ def augment_data(self, data):
         return data
 
     def get(self, idx):
-        try:
-            data = torch.load(osp.join(self.original_root,self.file_names[idx]+".pt"))
-            if self.standarize_temp:
-                data.temperature_og = data.temperature
-                data.temperature = ((data.temperature - self.mean_temp) / self.std_temp)
-
-            data.non_H_mask = data.x != 1
-
-            if not self.hydrogens:
-                #Remove hydrogens
-                data.x = data.x[data.non_H_mask]
-                data.pos = data.pos[data.non_H_mask]
-
-                atoms = torch.arange(0,data.non_H_mask.shape[0])[data.non_H_mask]
-                bool_mask_source = torch.isin(data.edge_index[0], atoms)
-                bool_mask_target = torch.isin(data.edge_index[1], atoms)
-                bool_mask_combined = bool_mask_source & bool_mask_target
-                data.edge_index = data.edge_index[:, bool_mask_combined]
-
-                node_mapping = {old: new for new, old in enumerate(atoms.tolist())}
-
-                data.edge_index = torch.tensor([[node_mapping[edge[0].item()], node_mapping[edge[1].item()]] for edge in data.edge_index.t()]).t()
-
-                data.edge_attr = data.edge_attr[bool_mask_combined, :]
-                data.non_H_mask = torch.ones(data.x.shape[0], dtype=torch.bool)
-
-            if self.optimize_cell:
-                data.cell_og = data.cell
-                data.cell, rotation_matrix = optmize_lattice(data.cell.squeeze(0))
-                data.cell = data.cell.unsqueeze(0)
-                data.cart_dir = data.cart_dir @ rotation_matrix
-                data.y = rotation_matrix.transpose(-1,-2) @ data.y @ rotation_matrix
-
-            if self.augment:
-                data = self.augment_data(data)
-        except Exception as e:
-            print(e)
-            raise Exception(f"Error loading file {self.file_names[idx]}")
+        data = torch.load(osp.join(self.original_root,self.file_names[idx]+".pt"))
+        if self.standarize_temp:
+            data.temperature_og = data.temperature
+            data.temperature = ((data.temperature - self.mean_temp) / self.std_temp)
+
+        data.non_H_mask = data.x != 1
+
+        if not self.hydrogens:
+            #Remove hydrogens
+            data.x = data.x[data.non_H_mask]
+            data.y = data.y[data.non_H_mask]
+            data.pos = data.pos[data.non_H_mask]
+
+            atoms = torch.arange(0,data.non_H_mask.shape[0])[data.non_H_mask]
+            bool_mask_source = torch.isin(data.edge_index[0], atoms)
+            bool_mask_target = torch.isin(data.edge_index[1], atoms)
+            bool_mask_combined = bool_mask_source & bool_mask_target
+            data.edge_index = data.edge_index[:, bool_mask_combined]
+
+            node_mapping = {old: new for new, old in enumerate(atoms.tolist())}
+
+            data.edge_index = torch.tensor([[node_mapping[edge[0].item()], node_mapping[edge[1].item()]] for edge in data.edge_index.t()]).t()
+
+            data.edge_attr = data.edge_attr[bool_mask_combined, :]
+            data.cart_dir = data.cart_dir[bool_mask_combined, :]
+            data.cart_dist = data.cart_dist[bool_mask_combined]
+            data.non_H_mask = torch.ones(data.x.shape[0], dtype=torch.bool)
+
+        if self.optimize_cell:
+            data.cell_og = data.cell
+            data.cell, rotation_matrix = optmize_lattice(data.cell.squeeze(0))
+            data.cell = data.cell.unsqueeze(0)
+            data.cart_dir = data.cart_dir @ rotation_matrix
+            data.y = rotation_matrix.transpose(-1,-2) @ data.y @ rotation_matrix
+
+        if self.augment:
+            data = self.augment_data(data)
 
         return data

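The hydrogen-removal branch above now masks data.y, data.cart_dir, and data.cart_dist consistently with the node and edge tensors, then remaps edge_index onto the surviving nodes through a Python dict. A minimal standalone sketch of that filter-and-reindex step on toy tensors (not the repository's Data objects); the vectorized remap at the end is an optional alternative to the per-edge dict lookup:

import torch

x = torch.tensor([6, 1, 8, 1, 7])            # atomic numbers; 1 = hydrogen
edge_index = torch.tensor([[0, 1, 2, 4],
                           [2, 0, 4, 0]])

non_H_mask = x != 1
atoms = torch.arange(x.shape[0])[non_H_mask]  # surviving node ids: 0, 2, 4

# Keep only edges whose endpoints both survive, as in get():
keep = torch.isin(edge_index[0], atoms) & torch.isin(edge_index[1], atoms)
edge_index = edge_index[:, keep]

# Vectorized equivalent of the dict-based node_mapping in the diff:
remap = torch.full((x.shape[0],), -1, dtype=torch.long)
remap[atoms] = torch.arange(atoms.shape[0])
edge_index = remap[edge_index]               # old ids -> compact new ids

print(edge_index)  # tensor([[0, 1, 2], [1, 2, 0]])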
Binary file modified fig/pipeline.png
1 change: 1 addition & 0 deletions loader/loader.py
@@ -6,6 +6,7 @@
 from torch_geometric.loader import DataLoader
 import random
 import os.path as osp
+import numpy as np
 
 def create_loader():
     """
62 changes: 44 additions & 18 deletions main.py
@@ -14,59 +14,81 @@
 
 
 def inference(model, loader):
-    from train.metrics import compute_loss, compute_3D_IoU
+    from train.metrics import compute_loss, compute_3D_IoU, get_similarity_index
     model.eval()
 
     with torch.no_grad():
-        inference_output = {"pred": [], "true": [], "refcode": [], "pos": [], "atoms": [], "iou": [], "mae": []}
+        inference_output = {"pred": [], "true": [], "temp": [], "cell": [], "refcode": [], "pos": [], "atoms": [], "iou": [], "mae": [], "similarity_index": []}
         for iter, batch in tqdm(enumerate(loader), total=len(loader), ncols=50):
            batch.to("cuda:0")
+            inference_output["cell"].append(batch.cell.detach().to("cpu"))
             inference_output["atoms"].append(batch.x[batch.non_H_mask].detach().to("cpu"))
             inference_output["pos"].append(batch.pos[batch.non_H_mask].detach().to("cpu"))
-            inference_output["refcode"].append(batch.refcode)
+            inference_output["refcode"].append(batch.refcode[0])
+            inference_output["temp"].append(batch.temperature_og.detach().to("cpu")[0])
             _pred, _true = model(batch)
             inference_output["pred"].append(_pred.detach().to("cpu"))
             inference_output["true"].append(_true.detach().to("cpu"))
             inference_output["iou"].append(compute_3D_IoU(_pred, _true).detach().to("cpu"))
             inference_output["mae"].append(compute_loss(_pred, _true)[0].detach().to("cpu"))
+            inference_output["similarity_index"].append(get_similarity_index(_pred, _true).detach().to("cpu"))
 
         iou = torch.cat(inference_output["iou"], dim=0)
         mae = torch.cat(inference_output["mae"], dim=0)
+        similarity_index = torch.cat(inference_output["similarity_index"], dim=0)
 
         logging.info(f"Mean IoU: {iou.mean().item()} +/- {iou.std().item()}")
         logging.info(f"Mean MAE: {mae.mean().item()} +/- {mae.std().item()}")
+        logging.info(f"Mean Similarity Index: {similarity_index.mean().item()} +/- {similarity_index.std().item()}")
 
     pickle.dump(inference_output, open(cfg.inference_output, "wb"))
 
 def montecarlo(model, loader):
-    from train.metrics import compute_loss, compute_3D_IoU
+    from train.metrics import compute_loss, compute_3D_IoU, get_similarity_index
     import roma
 
     model.eval()
     iou_montecarlo = []
+    similarity_index_montecarlo = []
     mae_montecarlo = []
     with torch.no_grad():
-        for i in range(100):
-            inference_output = {"pred": [], "true": [], "refcode": [], "pos": [], "atoms": [], "iou": [], "mae": []}
+        for i in tqdm(range(100), ncols=50, desc="Montecarlo"):
+            inference_output = {"pred": [], "true": [], "cell": [], "refcode": [], "pos": [], "atoms": [], "mae": [], "iou": [], "similarity_index": []}
             for iter, batch in tqdm(enumerate(loader), total=len(loader), ncols=50):
                 batch_copy = batch.clone()
                 batch.to("cuda:0")
+                inference_output["cell"].append(batch.cell.detach().to("cpu"))
                 inference_output["atoms"].append(batch.x[batch.non_H_mask].detach().to("cpu"))
                 inference_output["pos"].append(batch.pos[batch.non_H_mask].detach().to("cpu"))
-                inference_output["refcode"].append(batch.refcode)
+                inference_output["refcode"].append(batch.refcode[0])
                 pseudo_true, _ = model(batch)
-                R = roma.utils.random_rotmat(size=1, device=batch.x.device).squeeze(0)
-                pseudo_true = R.transpose(-1,-2) @ pseudo_true @ R
+                R = roma.utils.random_rotmat(size=1, device=pseudo_true.device).squeeze(0)
+                batch_copy.to("cuda:0")
+                batch_copy.cart_dir = batch_copy.cart_dir @ R
+                pseudo_true = R.transpose(-1,-2) @ pseudo_true @ R
                 pred, _ = model(batch_copy)
                 inference_output["pred"].append(pred.detach().to("cpu"))
                 inference_output["true"].append(pseudo_true.detach().to("cpu"))
                 inference_output["iou"].append(compute_3D_IoU(pred, pseudo_true).detach().to("cpu"))
+                inference_output["similarity_index"].append(get_similarity_index(pred, pseudo_true).detach().to("cpu"))
                 inference_output["mae"].append(compute_loss(pred, pseudo_true)[0].detach().to("cpu"))
             pickle.dump(inference_output, open(cfg.inference_output.replace(".pkl", "_montecarlo_"+str(i)+".pkl"), "wb"))
+            logging.info(f"Montecarlo {i}")
+            logging.info(f"IoU: {torch.cat(inference_output['iou'], dim=0).mean().item()}")
+            logging.info(f"MAE: {torch.cat(inference_output['mae'], dim=0).mean().item()}")
+            logging.info(f"Similarity Index: {torch.cat(inference_output['similarity_index'], dim=0).mean().item()}")
             iou_montecarlo += inference_output["iou"]
             mae_montecarlo += inference_output["mae"]
+            similarity_index_montecarlo += inference_output["similarity_index"]
 
     iou_montecarlo = torch.cat(iou_montecarlo, dim=0)
     mae_montecarlo = torch.cat(mae_montecarlo, dim=0)
+    similarity_index_montecarlo = torch.cat(similarity_index_montecarlo, dim=0)
 
-    logging.info(f"Montecarlo IoU: {iou_montecarlo.mean().item()}+/-{iou_montecarlo.std().item()}")
-    logging.info(f"Montecarlo MAE: {mae_montecarlo.mean().item()}+/-{mae_montecarlo.std().item()}")
+    logging.info(f"Montecarlo IoU: {iou_montecarlo.mean().item()} +/- {iou_montecarlo.std().item()}")
+    logging.info(f"Montecarlo MAE: {mae_montecarlo.mean().item()} +/- {mae_montecarlo.std().item()}")
+    logging.info(f"Montecarlo Similarity Index: {similarity_index_montecarlo.mean().item()} +/- {similarity_index_montecarlo.std().item()}")
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
@@ -78,7 +100,7 @@ def montecarlo(model, loader):
     parser.add_argument("--dataset_path", type=str, default="./dataset/ADP_DATASET/")
     parser.add_argument("--inference", action="store_true", help="Inference")
     parser.add_argument("--montecarlo", action="store_true", help="Montecarlo")
-    parser.add_argument("--weighs_path", type=str, default=None, help="Path to the weights of the model")
+    parser.add_argument("--checkpoint_path", type=str, default=None, help="Path to the model checkpoint")
     parser.add_argument("--inference_output", type=str, default="./inference.pkl", help="Path to the inference output")
     parser.add_argument("--figshare_target", type=str, default="formation_energy_peratom", help="Figshare dataset target")
     parser.add_argument("--wandb_project", type=str, default="ADP", help="Wandb project name")
@@ -93,7 +115,7 @@ def montecarlo(model, loader):
     parser.add_argument("--num_layers", type=int, default=4, help="Number of layers")
     parser.add_argument("--dim_in", type=int, default=256, help="Input dimension")
     parser.add_argument("--dim_rbf", type=int, default=64, help="Number of RBF")
-    parser.add_argument('--augment', action='store_true', help='Hydrogens')
+    parser.add_argument('--augment', action='store_true', help='Augment the dataset')
     parser.add_argument("--invariant", action="store_true", help="Rotation Invariant model")
     parser.add_argument("--disable_temp", action="store_false", help="Disable Temperature")
     parser.add_argument("--no_standarize_temp", action="store_false", help="Standardize temperature")
@@ -108,6 +130,7 @@ def montecarlo(model, loader):
     cfg.seed = args.seed
     cfg.name = args.name
     cfg.run_dir = "results/"+cfg.name+"/"+str(cfg.seed)
+    cfg.inference_output = args.inference_output
     cfg.dataset.task_type = "regression"
     cfg.batch = args.batch
     cfg.batch_accumulation = args.batch_accumulation
@@ -121,7 +144,7 @@ def montecarlo(model, loader):
     cfg.lr = args.lr
     cfg.warmup = args.warmup
     cfg.model = args.model
-    cfg.max_neighbours = None if cfg.model == "CartNet" else args.max_neighbours
+    cfg.max_neighbours = -1 if cfg.model == "CartNet" else args.max_neighbours
     cfg.radius = args.radius
     cfg.num_layers = args.num_layers
     cfg.dim_in = args.dim_in
@@ -157,15 +180,18 @@ def montecarlo(model, loader):
     loggers = create_logger()
 
     if args.inference:
-        assert args.weighs_path is not None, "Weights path not provided"
+        assert args.checkpoint_path is not None, "Checkpoint path not provided"
         assert cfg.dataset.name == "ADP", "Inference only for ADP dataset"
-        ckpt = torch.load(args.weighs_path)
+        ckpt = torch.load(args.checkpoint_path)
         model.load_state_dict(ckpt["model_state"])
         cfg.inference_output = args.inference_output
         inference(model, loaders[-1])
     elif args.montecarlo:
-        assert args.weighs_path is not None, "Weights path not provided"
+        assert args.checkpoint_path is not None, "Checkpoint path not provided"
         assert cfg.dataset.name == "ADP", "Montecarlo only for ADP dataset"
+        ckpt = torch.load(args.checkpoint_path)
+        model.load_state_dict(ckpt["model_state"])
+        cfg.inference_output = args.inference_output
         montecarlo(model, loaders[-1])
     else:
         train(model, loaders, optimizer, loggers)
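The montecarlo() routine above amounts to a randomized equivariance check: predict once, draw a random rotation R, rotate the input directions, predict again, and compare the second prediction with the conjugated first one, R^T U R. A self-contained sketch of that test with a stand-in model (the function f below is hypothetical; the real model comes from create_model() and consumes whole batches):

import torch
import roma

def f(cart_dir):
    # Stand-in for the network: any rotation-equivariant map from
    # direction vectors to a symmetric 3x3 tensor works for this sketch.
    return cart_dir.t() @ cart_dir

cart_dir = torch.randn(16, 3)
U = f(cart_dir)                                    # "pseudo ground truth"

R = roma.utils.random_rotmat(size=1).squeeze(0)    # random rotation matrix
U_rot = R.transpose(-1, -2) @ U @ R                # conjugated target
U_pred = f(cart_dir @ R)                           # prediction on rotated input

# An exactly equivariant f reproduces the conjugated tensor up to floating
# point error; montecarlo() measures how far the trained model deviates.
print(torch.allclose(U_rot, U_pred, atol=1e-4))    # True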
1 change: 0 additions & 1 deletion models/comformer_conv.py
@@ -152,7 +152,6 @@ def __init__(
 
         self.bn = nn.BatchNorm1d(out_channels)
         self.sigmoid = nn.Sigmoid()
-        print('I am using the invariant version of EPCNet')
 
     def forward(self, edge: Union[Tensor, PairTensor], edge_nei_len: OptTensor = None, edge_nei_angle: OptTensor = None):
         # preprocess for edge of shape [num_edges, hidden_dim]
35 changes: 24 additions & 11 deletions train/metrics.py
@@ -21,6 +21,17 @@ def get_error_volume(pred, true):
     volume_pred = get_volume(true)
     return (torch.abs(volume_true - volume_pred)/(volume_true+SMOOTH))
 
+def get_KL(pred, true):
+    p = torch.distributions.multivariate_normal.MultivariateNormal(torch.zeros(true.shape[0],3, device=true.device), true)
+    q = torch.distributions.multivariate_normal.MultivariateNormal(torch.zeros(pred.shape[0],3, device=true.device), pred)
+
+    return torch.distributions.kl.kl_divergence(p, q)
+
+def get_similarity_index(pred, true):
+    r12_num = 2 ** (3 / 2) * torch.linalg.det(torch.linalg.inv(true) @ torch.linalg.inv(pred)) ** (1 / 4)
+    r12_den = torch.linalg.det(torch.linalg.inv(true) + torch.linalg.inv(pred)) ** (1 / 2)
+    return 100*(1-r12_num/r12_den)
+
 def iou_pytorch3D(outputs: torch.Tensor, labels: torch.Tensor):
 
     intersection = (outputs & labels).float().sum((1, 2, 3))  # Will be zero if Truth=0 or Prediction=0
@@ -81,37 +92,39 @@ def compute_3D_IoU(pred,true):
 
     return iou
 
-def compute_metrics_and_logging(pred, true, mae, mse, loss, volume_percentage_error, lr, time_used, logger, iou=None):
+def compute_metrics_and_logging(pred, true, mae, mse, loss, lr, time_used, logger, test_metrics=False):
 
     if cfg.dataset.name == "ADP":
-        if iou is not None:
-            logger.update_stats(true = true,
-                                pred = pred,
+        if test_metrics:
+            logger.update_stats(true = true.to("cpu"),
+                                pred = pred.to("cpu"),
                                 loss = loss.mean().item(),
                                 MAE = mae.mean().item(),
                                 MSE = mse.mean().item(),
                                 lr = lr,
                                 time_used = time_used,
                                 params = cfg.params_count,
                                 dataset_name = cfg.dataset.name,
-                                volume_percentage_error = volume_percentage_error.mean().item(),
-                                iou = iou.mean().item()
+                                volume_percentage_error = get_error_volume(pred,true).mean().item(),
+                                iou = compute_3D_IoU(pred,true).mean().item(),
+                                similarity_index = get_similarity_index(pred, true).mean().item(),
                                 )
         else:
-            logger.update_stats(true = true,
-                                pred = pred,
+            logger.update_stats(true = true.to("cpu"),
+                                pred = pred.to("cpu"),
                                 loss = loss.mean().item(),
                                 MAE = mae.mean().item(),
                                 MSE = mse.mean().item(),
                                 lr = lr,
-                                volume_percentage_error = volume_percentage_error.mean().item(),
+                                volume_percentage_error = get_error_volume(pred,true).mean().item(),
+                                similarity_index = get_similarity_index(pred, true).mean().item(),
                                 time_used = time_used,
                                 params = cfg.params_count,
                                 dataset_name = cfg.dataset.name,
                                 )
     else:
-        logger.update_stats(true = true,
-                            pred = pred,
+        logger.update_stats(true = true.to("cpu"),
+                            pred = pred.to("cpu"),
                             loss = loss.mean().item(),
                             MAE = mae.mean().item(),
                             MSE = mse.mean().item(),
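get_similarity_index appears to implement the standard percentage similarity index for anisotropic displacement parameters, S = 100 * (1 - 2^(3/2) * det(U1^-1 U2^-1)^(1/4) / det(U1^-1 + U2^-1)^(1/2)), which is 0 for identical tensors and grows as they diverge; get_KL analogously compares the zero-mean Gaussians whose covariances are the two ADP batches. A quick sanity check under that reading, reusing the same expression on batches of 3x3 SPD matrices:

import torch

def get_similarity_index(pred, true):
    # Same expression as in train/metrics.py above.
    r12_num = 2 ** (3 / 2) * torch.linalg.det(torch.linalg.inv(true) @ torch.linalg.inv(pred)) ** (1 / 4)
    r12_den = torch.linalg.det(torch.linalg.inv(true) + torch.linalg.inv(pred)) ** (1 / 2)
    return 100 * (1 - r12_num / r12_den)

U = torch.eye(3).repeat(4, 1, 1)       # batch of 4 identical ADPs
print(get_similarity_index(U, U))      # ~0 for identical tensors
print(get_similarity_index(2 * U, U))  # ~8.45 each once pred and true differ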
17 changes: 7 additions & 10 deletions train/train.py
@@ -8,7 +8,7 @@
 from tqdm import tqdm
 from torch_geometric.graphgym.config import cfg
 from torch.optim.lr_scheduler import OneCycleLR
-from train.metrics import compute_metrics_and_logging, compute_loss, compute_3D_IoU, get_error_volume
+from train.metrics import compute_metrics_and_logging, compute_loss
 
 
 def flatten_dict(metrics):
@@ -170,13 +170,11 @@ def train_epoch(logger, loader, model, optimizer, batch_accumulation, scheduler):
             optimizer.zero_grad()
 
 
-    compute_metrics_and_logging(pred = pred.detach().to('cpu'),
-                                true = true.detach().to('cpu'),
+    compute_metrics_and_logging(pred = pred.detach(),
+                                true = true.detach(),
                                 mae = MAE,
                                 mse = MSE,
                                 loss = loss,
-                                volume_percentage_error = get_error_volume(pred, true),
-                                iou = None,
                                 lr = optimizer.param_groups[0]['lr'],
                                 time_used = time.time()-time_start,
                                 logger = logger)
@@ -203,15 +201,14 @@ def eval_epoch(logger, loader, model, test_metrics=False):
                 raise Exception("Loss not implemented")
 
 
-            compute_metrics_and_logging(pred = pred.detach().to('cpu'),
-                                        true = true.detach().to('cpu'),
+            compute_metrics_and_logging(pred = pred.detach(),
+                                        true = true.detach(),
                                         mae = MAE,
                                         mse = MSE,
                                         loss = loss,
-                                        volume_percentage_error = get_error_volume(pred, true),
-                                        iou = compute_3D_IoU(pred, true).detach().to('cpu') if test_metrics else None,
                                         lr = 0,
                                         time_used = time.time()-time_start,
-                                        logger = logger)
+                                        logger = logger,
+                                        test_metrics=test_metrics)
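With the new test_metrics flag, compute_metrics_and_logging evaluates the expensive IoU only during test evaluation instead of on every training step (volume error and similarity index are still logged in both modes). A schematic of that gating with placeholder metric bodies, not the repository's implementations:

import torch

def compute_3D_IoU(pred, true):  # placeholder for train.metrics.compute_3D_IoU
    return torch.rand(pred.shape[0])

def log_step(pred, true, test_metrics=False):
    # Schematic of the gating: the costly IoU is skipped during training
    # and only evaluated when test_metrics=True.
    stats = {"mae": (pred - true).abs().mean().item()}
    if test_metrics:
        stats["iou"] = compute_3D_IoU(pred, true).mean().item()
    return stats

pred, true = torch.randn(8, 3, 3), torch.randn(8, 3, 3)
print(log_step(pred, true))                     # training step: no IoU
print(log_step(pred, true, test_metrics=True))  # test step: IoU included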


