Skip to content

Commit

Permalink
updated enviroment
Browse files Browse the repository at this point in the history
  • Loading branch information
alexsoleg committed Jan 24, 2025
1 parent 04a6af2 commit 13f2867
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 12 deletions.
11 changes: 9 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,16 +49,19 @@ Instructions to set up the environment:
git clone https://github.com/imatge-upc/CartNet.git
cd CartNet

# Create a Conda environment
# Create a Conda environment (original env)
conda env create -f environment.yml

# or alternatively, if you want to use torch 2.4.0
conda env create -f environment_2.yml

# Activate the environment
conda activate CartNet
```

## Dependencies

The environment relies on these dependencies:
The environment used for the results reported in the paper relies on these dependencies:

```sh
pytorch==1.13.1
Expand All @@ -79,6 +82,10 @@ csd-python-api==3.3.1

These dependencies are automatically installed when you create the Conda environment using the `environment.yml` file.

### Update

We have updated our dependencies to torch 2.4.0 to facilitate further research. This can be installed via the “environment_2.yml” file.


## Dataset

Expand Down
20 changes: 11 additions & 9 deletions dataset/extract_csd_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,13 @@ def refcsd2graph(refcode, output_folder):
entry = csd_reader.entry(refcode)

if entry.pressure is not None:
return None
return refcode

if entry.remarks is not None:
return None
return refcode

if entry.crystal.has_disorder:
return None
return refcode

if entry.temperature is None:
doc = cif.read_string(entry.to_string(format='cif'))
Expand All @@ -67,15 +67,15 @@ def refcsd2graph(refcode, output_folder):
assert(temperature[0] is not None)
temperature = float(temperature[0])
except Exception as e:
return None
return refcode
else:
temperature = entry.temperature

temp = re.findall(r'\d+\.?\d*',string=str(entry.temperature))
try:
assert(len(temp)==1)
except:
return None
return refcode

temperature = float(temp[0])

Expand All @@ -94,7 +94,7 @@ def refcsd2graph(refcode, output_folder):
continue
elif atom.atomic_number != 1:
print("istrotropic")
return
return refcode

if atom.displacement_parameters.type == "Isotropic" and atom.atomic_number == 1:
adp.append(torch.eye(3).unsqueeze(0)*atom.displacement_parameters.isotropic_equivalent)
Expand Down Expand Up @@ -192,6 +192,8 @@ def target(error_flag):
data_df = pd.read_csv('./csv/all_dataset.csv', header=None)

res = [refcsd2graph(refcode, output_folder) for refcode in tqdm(data_df[0].tolist())]
res = [r for r in res if r is None]
with open("errors.txt", "w") as f:
f.write("\n".join(res))
res = [r for r in res if r is not None]

if len(res) > 0:
with open("errors.txt", "w") as f:
f.write("\n".join(res))
2 changes: 1 addition & 1 deletion environment_2.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: CartNet_paper
name: CartNet_paper_2.4.0
channels:
- pytorch
- pyg
Expand Down

0 comments on commit 13f2867

Please sign in to comment.