Skip to content

Refining coGN module #109

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Mar 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,4 @@ kgcnn.egg-info
/docs/source/DatasetMol/
/docs/source/DatasetCrystal/
/docs/source/model_energy_force/
__pycache__
94 changes: 92 additions & 2 deletions kgcnn/literature/coGN/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,100 @@
from ._make import make_model
from ._coGN_config import model_default
"""
Using `coGN and coNGN <https://arxiv.org/abs/2302.14102>`__ models in KGCNN.
============================================================================

There are multiple preconfigured configurations for `coGN
model <https://arxiv.org/abs/2302.14102>`__ for different input
representations. The input representations are different for, whether

- the graph represents periodic/symmetric crystal graphs or
non-periodic molecular graphs
- the model must be differentiable with respect to node coordinates to
calculate forces based on predicted energies.

Import all functions and configurations with:

.. code:: python

from kgcnn.literature.coGN import (make_model, make_force_model,
model_default, crystal_asymmetric_unit_graphs, molecular_graphs, crystal_unit_graphs,
crystal_unit_graphs_coord_input, molecular_graphs_coord_input, model_default_nested)

Crystals
--------

- The default coGN model for crystals (takes asymmetric unit graphs as
input)

.. code:: python

model = make_model(**model_default)
# model.inputs

This is equivalent to:

.. code:: python

model = make_model(**crystal_asymmetric_unit_graphs)

Factoring out symmetries via the asymmetric unit graph representation
may accelerate training, since graphs are smaller.

- For unit cells crystal graph representations use:

.. code:: python

model = make_model(**crystal_unit_graphs)
# model.inputs

Precomputing offsets between atoms in a preporcessing step may
accelerate training.

- For unit cell crystal graph representations and force predictions
use:

.. code:: python

model = make_model(**crystal_unit_graphs_coord_input)
force_model = make_force_model(model) # predicts energies and forces
# model.inputs

Molecules
---------

- For simple energy predictions, based on precomputed offsets between
atoms use:

.. code:: python

model = make_model(**molecular_graphs)
# model.inputs

Precomputing offsets between atoms in a preporcessing step may
accelerate training.

- For energy and force predictions based on atom coordinates use:

.. code:: python

model = make_model(**molecular_graphs_coord_input)
force_model = make_force_model(model)
# model.inputs
"""

from ._make import make_model, make_force_model
from ._coGN_config import (model_default, crystal_asymmetric_unit_graphs, molecular_graphs, crystal_unit_graphs,
crystal_unit_graphs_coord_input, molecular_graphs_coord_input)
from ._coNGN_config import model_default_nested


__all__ = [
"make_model",
"make_force_model"
"model_default",
"crystal_asymmetric_unit_graphs",
"crystal_unit_graphs",
"crystal_unit_graphs_coord_input",
"molecular_graphs",
"molecular_graphs_coord_input",
"model_default_nested"
]
92 changes: 79 additions & 13 deletions kgcnn/literature/coGN/_coGN_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,19 +71,85 @@
'update_global_input': [False, True, False],
'multiplicity_readout': True}

model_default = {
"inputs": [
{"shape": (None, 3), "name": "offset", "dtype": "float32", "ragged": True},
{"shape": (None,), "name": "atomic_number", "dtype": "int32", "ragged": True},
{"shape": (None, 2), "name": "edge_indices", "dtype": "int32", "ragged": True},
# {"shape": (None, ), "name": "voronoi_ridge_area", "dtype": "float32", "ragged": True},
{"shape": (None, ), "name": "multiplicity", "dtype": "int32", "ragged": True},
# {"shape": (None, 2), "name": "line_graph_edge_indices", "dtype": "int32", "ragged": True}
],
output_block_cfg_no_multiplicity = deepcopy(output_block_cfg)
output_block_cfg_no_multiplicity['multiplicity_readout'] = False


crystal_asymmetric_unit_graphs = {
"inputs": {
"offset": {"shape": (None, 3), "name": "offset", "dtype": "float32", "ragged": True},
"cell_translation": None,
"affine_matrix": None,
"voronoi_ridge_area": None,
"atomic_number": {"shape": (None,), "name": "atomic_number", "dtype": "int32", "ragged": True},
"frac_coords": None,
"coords": None,
"multiplicity": {"shape": (None, ), "name": "multiplicity", "dtype": "int32", "ragged": True},
"lattice_matrix": None,
"edge_indices": {"shape": (None, 2), "name": "edge_indices", "dtype": "int32", "ragged": True},
"line_graph_edge_indices": None,
},
"input_block_cfg": input_block_cfg,
"processing_blocks_cfg": [deepcopy(processing_block_cfg) for _ in range(depth)],
"output_block_cfg": output_block_cfg,
"multiplicity": True,
"line_graph": False,
"voronoi_ridge_area": False
}
}

crystal_unit_graphs = {
"inputs": {
"offset": {"shape": (None, 3), "name": "offset", "dtype": "float32", "ragged": True},
"cell_translation": None,
"affine_matrix": None,
"voronoi_ridge_area": None,
"atomic_number": {"shape": (None,), "name": "atomic_number", "dtype": "int32", "ragged": True},
"frac_coords": None,
"coords": None,
"multiplicity": None,
"lattice_matrix": None,
"edge_indices": {"shape": (None, 2), "name": "edge_indices", "dtype": "int32", "ragged": True},
"line_graph_edge_indices": None,
},
"input_block_cfg": input_block_cfg,
"processing_blocks_cfg": [deepcopy(processing_block_cfg) for _ in range(depth)],
"output_block_cfg": output_block_cfg_no_multiplicity,
}
molecular_graphs = crystal_unit_graphs

crystal_unit_graphs_coord_input = {
"inputs": {
"offset": None,
"cell_translation": {"shape": (None,3), "dtype": "float32", "name": "cell_translation", "ragged": True},
"affine_matrix": None,
"voronoi_ridge_area": None,
"atomic_number": {"shape": (None,), "name": "atomic_number", "dtype": "int32", "ragged": True},
"frac_coords": {"shape": (None,3), "dtype": "float32", "name": "frac_coords", "ragged": True},
"coords": None,
"multiplicity": None,
"lattice_matrix": {"shape": (3,3), "dtype": "float32", "name": "lattice_matrix"},
"edge_indices": {"shape": (None, 2), "name": "edge_indices", "dtype": "int32", "ragged": True},
"line_graph_edge_indices": None,
},
"input_block_cfg": input_block_cfg,
"processing_blocks_cfg": [deepcopy(processing_block_cfg) for _ in range(depth)],
"output_block_cfg": output_block_cfg_no_multiplicity,
}

molecular_graphs_coord_input = {
"inputs": {
"offset": None,
"cell_translation": None,
"affine_matrix": None,
"voronoi_ridge_area": None,
"atomic_number": {"shape": (None,), "name": "atomic_number", "dtype": "int32", "ragged": True},
"frac_coords": None,
"coords": {"shape": (None,3), "dtype": "float32", "name": "coords", "ragged": True},
"multiplicity": None,
"lattice_matrix": None,
"edge_indices": {"shape": (None, 2), "name": "edge_indices", "dtype": "int32", "ragged": True},
"line_graph_edge_indices": None,
},
"input_block_cfg": input_block_cfg,
"processing_blocks_cfg": [deepcopy(processing_block_cfg) for _ in range(depth)],
"output_block_cfg": output_block_cfg_no_multiplicity,
}

model_default = crystal_asymmetric_unit_graphs
26 changes: 14 additions & 12 deletions kgcnn/literature/coGN/_coNGN_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,18 +102,20 @@


model_default_nested = {
"inputs": [
{"shape": (None, 3), "name": "offset", "dtype": "float32", "ragged": True},
{"shape": (None,), "name": "atomic_number", "dtype": "int32", "ragged": True},
{"shape": (None, 2), "name": "edge_indices", "dtype": "int32", "ragged": True},
# {"shape": (None,), "name": "voronoi_ridge_area", "dtype": "float32", "ragged": True},
{"shape": (None,), "name": "multiplicity", "dtype": "int32", "ragged": True},
{"shape": (None, 2), "name": "line_graph_edge_indices", "dtype": "int32", "ragged": True}
],
"inputs": {
"offset": {"shape": (None, 3), "name": "offset", "dtype": "float32", "ragged": True},
"cell_translation": None,
"affine_matrix": None,
"voronoi_ridge_area": {"shape": (None,), "name": "offset", "dtype": "float32", "ragged": True},
"atomic_number": {"shape": (None,), "name": "atomic_number", "dtype": "int32", "ragged": True},
"frac_coords": None,
"coords": None,
"multiplicity": {"shape": (None, ), "name": "multiplicity", "dtype": "int32", "ragged": True},
"lattice_matrix": None,
"line_graph_edge_indices": {"shape": (None, 2), "name": "line_graph_edge_indices", "dtype": "int32", "ragged": True},
"edge_indices": {"shape": (None, 2), "name": "edge_indices", "dtype": "int32", "ragged": True},
},
"input_block_cfg": input_block_cfg,
"processing_blocks_cfg": [deepcopy(processing_block_cfg) for _ in range(depth)],
"output_block_cfg": output_block_cfg,
"multiplicity": True,
"line_graph": True,
"voronoi_ridge_area": False
}
}
Loading