Skip to content

Commit

Permalink
added dataset creation code
Browse files Browse the repository at this point in the history
  • Loading branch information
alexsoleg committed Jan 22, 2025
1 parent 24c69a0 commit 8094e57
Show file tree
Hide file tree
Showing 3 changed files with 278 additions and 16 deletions.
12 changes: 9 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ jarvis-tools==2024.8.30
lightning==2.2.5
roma==1.5.0
e3nn==0.5.1
csd-python-api==3.3.1
```

These dependencies are automatically installed when you create the Conda environment using the `environment.yml` file.
Expand All @@ -85,11 +86,16 @@ These dependencies are automatically installed when you create the Conda environ

The ADP (Anisotropic Displacement Parameters) dataset is curated from over 200,000 experimental crystal structures from the Cambridge Structural Database (CSD). This dataset is used to study atomic thermal vibrations represented through thermal ellipsoids. The dataset was curated to ensure high-quality and reliable ADPs. The dataset spans a wide temperature range (0K to 600K) and features a variety of atomic environments, with an average of 194.2 atoms per crystal structure. The dataset is split into 162,270 structures for training, 22,219 for validation, and 23,553 for testing.

Code to create the dataset comming soon
The dataset can be generated using the following code:

```sh
cd dataset/
python extract_csd_data.py --output "/path/to/data/"
```

> [!NOTE]
>
> The ADP_DATASET/ folder should be placed inside the dataset/ folder or specify the new path via --dataset_path flag in main.py
>
> Dataset generation requires a valid license for the [Cambridge Structural Database (CSD) Python API](https://downloads.ccdc.cam.ac.uk/documentation/API/index.html#).
### Jarvis
For tasks derived from Jarvis dataset, we followed the methodology of [Choudhary et al.](https://www.nature.com/articles/s41524-021-00650-1) in ALIGNN, utilizing the same training, validation, and test datasets. The dataset is automatically downloaded and processed by the code.
Expand Down
196 changes: 196 additions & 0 deletions dataset/extract_csd_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
import re
import queue
import argparse
import torch
import torch.multiprocessing as mp
from torch_geometric.data import Data, Batch
from tqdm import tqdm
import os.path as osp
from ccdc import io
from gemmi import cif
import pandas as pd
from utils import radius_graph_pbc


def frac_to_cart_matrix(abc, angles):
a, b, c = abc
alpha, beta, gamma = torch.tensor(angles) * (torch.pi / 180.0) # Convert to radians
volume = 1 - torch.cos(alpha)**2 - torch.cos(beta)**2 - torch.cos(gamma)**2 + 2 * torch.cos(alpha) * torch.cos(beta) * torch.cos(gamma)
volume = c * torch.sqrt(volume) * b * a
M = torch.tensor([
[a, b * torch.cos(gamma), c * torch.cos(beta)],
[0, b * torch.sin(gamma), c * (torch.cos(alpha) - torch.cos(beta)*torch.cos(gamma)) / torch.sin(gamma)],
[0, 0, volume / (a * b * torch.sin(gamma))]
])
return M.t()


def delete_repeated(coord, threshold=1e-4):
coord = torch.where(coord < 0 , coord+1, coord)
coord = torch.where(coord > 1 , coord-1, coord)
coord = torch.where(torch.isclose(coord, torch.ones(coord.shape, dtype=torch.float32), atol=1e-4), torch.zeros_like(coord, dtype=torch.float32), coord)
distance = coord.unsqueeze(0) - coord.unsqueeze(1)
distance = torch.linalg.norm(distance, dim=-1)

duplicates = distance < threshold
first_true_indices = torch.argmax(duplicates.float(), dim=1)

mask_to_keep = torch.arange(coord.shape[0]) <= first_true_indices

return mask_to_keep

possible_atomic_num_list = list(range(1, 119))

def refcsd2graph(refcode, output_folder):
try:
csd_reader = io.EntryReader("CSD")
entry = csd_reader.entry(refcode)

if entry.pressure is not None:
return None

if entry.remarks is not None:
return None

if entry.crystal.has_disorder:
return None

if entry.temperature is None:
doc = cif.read_string(entry.to_string(format='cif'))


try: # copy all the data from mmCIF file
block = doc.sole_block() # mmCIF has exactly one block
temperature = block.find_pair("_diffrn_ambient_temperature")[1]
temperature = re.findall(r'\d+\.?\d*',string=str(temperature))
assert(len(temperature)==1)
assert(temperature[0] is not None)
temperature = float(temperature[0])
except Exception as e:
return None
else:
temperature = entry.temperature

temp = re.findall(r'\d+\.?\d*',string=str(entry.temperature))
try:
assert(len(temp)==1)
except:
return None

temperature = float(temp[0])

data = Data()
entry = csd_reader.entry(refcode)
packing = entry.crystal.packing(inclusion="OnlyAtomsIncluded")
keep_mask = delete_repeated(torch.tensor([[atom.fractional_coordinates.x, atom.fractional_coordinates.y, atom.fractional_coordinates.z] for atom in packing.atoms]))
adp = []
data.x = torch.tensor([possible_atomic_num_list.index(atom.atomic_number)+1 for atom in packing.atoms])[keep_mask]
data.pos = torch.tensor([[atom.coordinates.x, atom.coordinates.y, atom.coordinates.z] for atom in packing.atoms])[keep_mask]

for atom in packing.atoms:
if atom.displacement_parameters is None:
if atom.atomic_number == 1:
adp.append(torch.eye(3).unsqueeze(0)*0.01)
continue
elif atom.atomic_number != 1:
print("istrotropic")
return

if atom.displacement_parameters.type == "Isotropic" and atom.atomic_number == 1:
adp.append(torch.eye(3).unsqueeze(0)*atom.displacement_parameters.isotropic_equivalent)

elif atom.displacement_parameters.type == "Anisotropic":
adp.append(torch.tensor(atom.displacement_parameters.values).unsqueeze(0))

else:
raise NotImplementedError

data.y = torch.cat(adp, dim=0)[keep_mask]

data.y = data.y[data.x != 1]

abc = entry.crystal.cell_lengths.a, entry.crystal.cell_lengths.b, entry.crystal.cell_lengths.c
angles = entry.crystal.cell_angles.alpha, entry.crystal.cell_angles.beta, entry.crystal.cell_angles.gamma

M = frac_to_cart_matrix(abc, angles)

N = torch.diag(torch.linalg.norm(torch.linalg.inv(M.transpose(-1,-2)).squeeze(0), dim=-1))

data.cell = M.unsqueeze(0)


data.y = N.transpose(-1,-2)@data.y@N
data.y = data.cell.transpose(-1,-2)@data.y@data.cell

assert torch.allclose(torch.tensor(entry.crystal.orthogonal_to_fractional.translation), torch.zeros(3)), refcode


data.pbc = torch.tensor([[True, True, True]])
data.natoms = torch.tensor([data.x.shape[0]])
data.temperature = torch.tensor([temperature])
data.refcode = refcode

batch = Batch.from_data_list([data])

edge_index, _, _, edge_attr = radius_graph_pbc(batch, 5.0, None)

data.edge_index = edge_index
direct_norm = torch.norm(edge_attr, dim=-1).unsqueeze(-1)
data.edge_attr = torch.cat([edge_attr/direct_norm, direct_norm], dim=-1)



torch.save(data, osp.join(output_folder,str(refcode)+".pt"))
except Exception as e:

raise Exception(f"Error occurred for refcode: {refcode} Error message: {str(e)}")


def worker_process(task_queue, results_queue, counter, error_event):
while True:
if error_event.is_set(): # Check if the error event is set at the start of each iteration
break
try:
refcode, output_folder = task_queue.get_nowait()

error_occurred = mp.Value('i', 0) # 0 means no error, 1 means error occurred

def target(error_flag):
try:
refcsd2graph(refcode, output_folder)
except Exception as e:
error_flag.value = 1
results_queue.put(str(e))

process = mp.Process(target=target, args=(error_occurred,))
process.start()
process.join(timeout=600) # 10 minutes timeout

if error_occurred.value == 1: # If an error occurred in the target function
error_event.set()
break

if process.is_alive():
process.terminate()
process.join()
results_queue.put(f"Timeout occurred for refcode: {refcode}!")

# Increment the shared counter
with counter.get_lock():
counter.value += 1

except queue.Empty:
break


if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Process CSD data into graphs')
parser.add_argument('--output', type=str, default="ADP_DATASET/data/", help='Output folder path')
args = parser.parse_args()
output_folder = args.output
data_df = pd.read_csv('./csv/all_dataset.csv', header=None)

res = [refcsd2graph(refcode, output_folder) for refcode in tqdm(data_df[0].tolist())]
res = [r for r in res if r is None]
with open("errors.txt", "w") as f:
f.write("\n".join(res))
Loading

0 comments on commit 8094e57

Please sign in to comment.