From 13f2867b56c5ceddab4b777fe50f6f00495177d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=80lex=20Sol=C3=A9?= Date: Fri, 24 Jan 2025 15:53:57 +0100 Subject: [PATCH] updated enviroment --- README.md | 11 +++++++++-- dataset/extract_csd_data.py | 20 +++++++++++--------- environment_2.yml | 2 +- 3 files changed, 21 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index 062ab80..9c03b03 100644 --- a/README.md +++ b/README.md @@ -49,16 +49,19 @@ Instructions to set up the environment: git clone https://github.com/imatge-upc/CartNet.git cd CartNet -# Create a Conda environment +# Create a Conda environment (original env) conda env create -f environment.yml +# or alternatively, if you want to use torch 2.4.0 +conda env create -f environment_2.yml + # Activate the environment conda activate CartNet ``` ## Dependencies -The environment relies on these dependencies: +The environment used for the results reported in the paper relies on these dependencies: ```sh pytorch==1.13.1 @@ -79,6 +82,10 @@ csd-python-api==3.3.1 These dependencies are automatically installed when you create the Conda environment using the `environment.yml` file. +### Update + +We have updated our dependencies to torch 2.4.0 to facilitate further research. This can be installed via the “environment_2.yml” file. + ## Dataset diff --git a/dataset/extract_csd_data.py b/dataset/extract_csd_data.py index cfd88a7..01b7c67 100644 --- a/dataset/extract_csd_data.py +++ b/dataset/extract_csd_data.py @@ -47,13 +47,13 @@ def refcsd2graph(refcode, output_folder): entry = csd_reader.entry(refcode) if entry.pressure is not None: - return None + return refcode if entry.remarks is not None: - return None + return refcode if entry.crystal.has_disorder: - return None + return refcode if entry.temperature is None: doc = cif.read_string(entry.to_string(format='cif')) @@ -67,7 +67,7 @@ def refcsd2graph(refcode, output_folder): assert(temperature[0] is not None) temperature = float(temperature[0]) except Exception as e: - return None + return refcode else: temperature = entry.temperature @@ -75,7 +75,7 @@ def refcsd2graph(refcode, output_folder): try: assert(len(temp)==1) except: - return None + return refcode temperature = float(temp[0]) @@ -94,7 +94,7 @@ def refcsd2graph(refcode, output_folder): continue elif atom.atomic_number != 1: print("istrotropic") - return + return refcode if atom.displacement_parameters.type == "Isotropic" and atom.atomic_number == 1: adp.append(torch.eye(3).unsqueeze(0)*atom.displacement_parameters.isotropic_equivalent) @@ -192,6 +192,8 @@ def target(error_flag): data_df = pd.read_csv('./csv/all_dataset.csv', header=None) res = [refcsd2graph(refcode, output_folder) for refcode in tqdm(data_df[0].tolist())] - res = [r for r in res if r is None] - with open("errors.txt", "w") as f: - f.write("\n".join(res)) \ No newline at end of file + res = [r for r in res if r is not None] + + if len(res) > 0: + with open("errors.txt", "w") as f: + f.write("\n".join(res)) \ No newline at end of file diff --git a/environment_2.yml b/environment_2.yml index 3fa961b..38d7f2e 100644 --- a/environment_2.yml +++ b/environment_2.yml @@ -1,4 +1,4 @@ -name: CartNet_paper +name: CartNet_paper_2.4.0 channels: - pytorch - pyg