Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions egnn_pytorch/egnn_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,11 +164,13 @@ def __init__(
valid_radius = float('inf'),
m_pool_method = 'sum',
soft_edges = False,
coor_weights_clamp_value = None
coor_weights_clamp_value = None,
coors_tanh=False # Only to be used alongside norm_coors - highly recommended for stability
):
super().__init__()
assert m_pool_method in {'sum', 'mean'}, 'pool method must be either sum or mean'
assert update_feats or update_coors, 'you must update either features, coordinates, or both'
assert not (coors_tanh and not norm_coors), 'coors_tanh must be used with norm_coors'

self.fourier_features = fourier_features

Expand Down Expand Up @@ -200,11 +202,14 @@ def __init__(
nn.Linear(dim * 2, dim),
) if update_feats else None

# Tanh layer helps with stability but should only be used in conjuction with
# norm_coors
self.coors_mlp = nn.Sequential(
nn.Linear(m_dim, m_dim * 4),
dropout,
SiLU(),
nn.Linear(m_dim * 4, 1)
nn.Linear(m_dim * 4, 1),
nn.Tanh() if coors_tanh else nn.Identity()
) if update_coors else None

self.num_nearest_neighbors = num_nearest_neighbors
Expand Down