Skip to content

Commit

Permalink
[contacts._md_compute_contacts] updated to mdtraj v1.10.0rc1 (WIP)
Browse files Browse the repository at this point in the history
  • Loading branch information
gph82 committed Jun 5, 2024
1 parent 5faad08 commit ac34112
Showing 1 changed file with 129 additions and 90 deletions.
219 changes: 129 additions & 90 deletions mdciao/contacts/_md_compute_contacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@

# The modifications consist in including the indices
# of the closest atom-pairs in the returned values. The
# modified lines are 141-143, 260, 264, 265, and 276
##############################################################################


# modified lines are in the return value documentation for `atom_pairs`
# and otherwise marked with the comment '#mdciao' in the file.
# The modifications were applied on mdtraj release v1.10.0rc1
# commit 4c9bb6e8bc7d6e86890ff0d57814c3c76f7cf792
##############################################################################
# MDTraj: A Python Library for Loading, Saving, and Manipulating
# Molecular Dynamics Trajectories.
Expand All @@ -54,33 +54,25 @@
# License along with MDTraj. If not, see <http://www.gnu.org/licenses/>.
##############################################################################

# opened an issue for this:
# https://github.com/mdtraj/mdtraj/issues/1569
# Further reading
# https://www.oreilly.com/library/view/understanding-open-source/0596005814/ch03.html
# https://tldrlegal.com/license/gnu-lesser-general-public-license-v3-(lgpl-3)
# http://oss-watch.ac.uk/resources/lgpl
#https://www.gnu.org/licenses/gpl-faq.html#AllCompatibility

##############################################################################
# Imports
##############################################################################

from __future__ import print_function, division
import numpy as _np
import mdtraj as _md
from mdtraj.utils import ensure_type
import itertools
from mdtraj.utils.six import string_types
from mdtraj.utils.six.moves import xrange

import numpy as np

import mdtraj as md
from mdtraj.core import element
from mdtraj.utils import ensure_type

##############################################################################
# Code
##############################################################################

def compute_contacts(traj, contacts='all', scheme='closest-heavy', ignore_nonprotein=True, periodic=True,
soft_min=False, soft_min_beta=20):
def compute_contacts(
traj,
contacts="all",
scheme="closest-heavy",
ignore_nonprotein=True,
periodic=True,
soft_min=False,
soft_min_beta=20,
):
"""Compute the distance between pairs of residues in a trajectory.
Parameters
Expand Down Expand Up @@ -112,9 +104,9 @@ def compute_contacts(traj, contacts='all', scheme='closest-heavy', ignore_nonpro
If periodic is True and the trajectory contains unitcell information,
we will compute distances under the minimum image convention.
soft_min : bool, default=False
If soft_min is true, we will use a differentiable version of
If soft_min is true, we will use a diffrentiable version of
the scheme. The exact expression used
is d = \frac{\beta}{log\sum_i{exp(\frac{\beta}{d_i}})} where
is d = \frac{\beta}{log\\sum_i{exp(\frac{\beta}{d_i}})} where
beta is user parameter which defaults to 20nm. The expression
we use is copied from the plumed mindist calculator.
http://plumed.github.io/doc-v2.0/user-doc/html/mindist.html
Expand All @@ -140,8 +132,7 @@ def compute_contacts(traj, contacts='all', scheme='closest-heavy', ignore_nonpro
the indexing of `residue_pairs`
atom_pairs : np.ndarray, shape=(n_pairs, 2), dtype=int
Each row of this return value gives the indices of the atoms
involved in the contact.
involved in the contact
Examples
--------
>>> # To compute the contact distance between residue 0 and 10 and
Expand All @@ -163,117 +154,165 @@ def compute_contacts(traj, contacts='all', scheme='closest-heavy', ignore_nonpro
Topology.residue : Get residues from the topology by index
"""
if traj.topology is None:
raise ValueError('contact calculation requires a topology')
raise ValueError("contact calculation requires a topology")

if isinstance(contacts, string_types):
if contacts.lower() != 'all':
raise ValueError('(%s) is not a valid contacts specifier' % contacts.lower())
if isinstance(contacts, str):
if contacts.lower() != "all":
raise ValueError(
"(%s) is not a valid contacts specifier" % contacts.lower(),
)

residue_pairs = []
for i in xrange(traj.n_residues):
for i in range(traj.n_residues):
residue_i = traj.topology.residue(i)
if ignore_nonprotein and not any(a for a in residue_i.atoms if a.name.lower() == 'ca'):
if ignore_nonprotein and not any(a for a in residue_i.atoms if a.name.lower() == "ca"):
continue
for j in xrange(i+3, traj.n_residues):
for j in range(i + 3, traj.n_residues):
residue_j = traj.topology.residue(j)
if ignore_nonprotein and not any(a for a in residue_j.atoms if a.name.lower() == 'ca'):
if ignore_nonprotein and not any(a for a in residue_j.atoms if a.name.lower() == "ca"):
continue
if residue_i.chain == residue_j.chain:
residue_pairs.append((i, j))

residue_pairs = _np.array(residue_pairs)
residue_pairs = np.array(residue_pairs)
if len(residue_pairs) == 0:
raise ValueError('No acceptable residue pairs found')
raise ValueError("No acceptable residue pairs found")

else:
residue_pairs = ensure_type(_np.asarray(contacts), dtype=int, ndim=2, name='contacts',
shape=(None, 2), warn_on_cast=False)
if not _np.all((residue_pairs >= 0) * (residue_pairs < traj.n_residues)):
raise ValueError('contacts requests a residue that is not in the permitted range')
residue_pairs = ensure_type(
np.asarray(contacts),
dtype=int,
ndim=2,
name="contacts",
shape=(None, 2),
warn_on_cast=False,
)
if not np.all((residue_pairs >= 0) * (residue_pairs < traj.n_residues)):
raise ValueError(
"contacts requests a residue that is not in the permitted range",
)

# now the bulk of the function. This will calculate atom distances and then
# re-work them in the required scheme to get residue distances
scheme = scheme.lower()
if scheme not in ['ca', 'closest', 'closest-heavy', 'sidechain', 'sidechain-heavy']:
raise ValueError('scheme must be one of [ca, closest, closest-heavy, sidechain, sidechain-heavy]')
if scheme not in ["ca", "closest", "closest-heavy", "sidechain", "sidechain-heavy"]:
raise ValueError(
"scheme must be one of [ca, closest, closest-heavy, sidechain, sidechain-heavy]",
)

if scheme == 'ca':
if scheme == "ca":
if soft_min:
import warnings
warnings.warn("The soft_min=True option with scheme=ca gives"
"the same results as soft_min=False")

warnings.warn(
"The soft_min=True option with scheme=ca gives" "the same results as soft_min=False",
)
filtered_residue_pairs = []
atom_pairs = []

for r0, r1 in residue_pairs:
ca_atoms_0 = [a.index for a in traj.top.residue(r0).atoms if a.name.lower() == 'ca']
ca_atoms_1 = [a.index for a in traj.top.residue(r1).atoms if a.name.lower() == 'ca']
ca_atoms_0 = [a.index for a in traj.top.residue(r0).atoms if a.name.lower() == "ca"]
ca_atoms_1 = [a.index for a in traj.top.residue(r1).atoms if a.name.lower() == "ca"]
if len(ca_atoms_0) == 1 and len(ca_atoms_1) == 1:
atom_pairs.append((ca_atoms_0[0], ca_atoms_1[0]))
filtered_residue_pairs.append((r0, r1))
elif len(ca_atoms_0) == 0 or len(ca_atoms_1) == 0:
# residue does not contain a CA atom, skip it
if contacts != 'all':
if contacts != "all":
# if the user manually asked for this residue, and didn't use "all"
import warnings
warnings.warn('Ignoring contacts pair %d-%d. No alpha carbon.' % (r0, r1))

warnings.warn(
"Ignoring contacts pair %d-%d. No alpha carbon." % (r0, r1),
)
else:
raise ValueError('More than 1 alpha carbon detected in residue %d or %d' % (r0, r1))
raise ValueError(
"More than 1 alpha carbon detected in residue %d or %d" % (r0, r1),
)

residue_pairs = _np.array(filtered_residue_pairs)
distances = _md.compute_distances(traj, atom_pairs, periodic=periodic)
aa_pairs = [[pair]*traj.n_frames for pair in atom_pairs]
residue_pairs = np.array(filtered_residue_pairs)
distances = md.compute_distances(traj, atom_pairs, periodic=periodic)

elif scheme in ['closest', 'closest-heavy', 'sidechain', 'sidechain-heavy']:
if scheme == 'closest':
residue_membership = [[atom.index for atom in residue.atoms]
for residue in traj.topology.residues]
elif scheme == 'closest-heavy':
elif scheme in ["closest", "closest-heavy", "sidechain", "sidechain-heavy"]:
if scheme == "closest":
residue_membership = [[atom.index for atom in residue.atoms] for residue in traj.topology.residues]
elif scheme == "closest-heavy":
# then remove the hydrogens from the above list
residue_membership = [[atom.index for atom in residue.atoms if not (atom.element == element.hydrogen)]
for residue in traj.topology.residues]
elif scheme == 'sidechain':
residue_membership = [[atom.index for atom in residue.atoms if atom.is_sidechain]
for residue in traj.topology.residues]
elif scheme == 'sidechain-heavy':
residue_membership = [
[atom.index for atom in residue.atoms if not (atom.element == element.hydrogen)]
for residue in traj.topology.residues
]
elif scheme == "sidechain":
residue_membership = [
[atom.index for atom in residue.atoms if atom.is_sidechain] for residue in traj.topology.residues
]
elif scheme == "sidechain-heavy":
# then remove the hydrogens from the above list
if 'GLY' in [residue.name for residue in traj.topology.residues]:
import warnings
warnings.warn('selected topology includes at least one glycine residue, which has no heavy atoms in its sidechain. The distances involving glycine residues '
'will be computed using the sidechain hydrogen instead.')
residue_membership = [[atom.index for atom in residue.atoms if atom.is_sidechain and not (atom.element == element.hydrogen)] if not residue.name == 'GLY'
else [atom.index for atom in residue.atoms if atom.is_sidechain]
for residue in traj.topology.residues]
if "GLY" in [residue.name for residue in traj.topology.residues]:
import warnings

warnings.warn(
"selected topology includes at least one glycine residue, which has no heavy "
"atoms in its sidechain. The distances involving glycine residues will be "
"computed using the sidechain hydrogen instead.",
)

residue_membership = [
(
[
atom.index
for atom in residue.atoms
if atom.is_sidechain and not (atom.element == element.hydrogen)
]
if not residue.name == "GLY"
else [atom.index for atom in residue.atoms if atom.is_sidechain]
)
for residue in traj.topology.residues
]

residue_lens = [len(ainds) for ainds in residue_membership]

atom_pairs = []
n_atom_pairs_per_residue_pair = []
for pair in residue_pairs:
atom_pairs.extend(list(itertools.product(residue_membership[pair[0]], residue_membership[pair[1]])))
n_atom_pairs_per_residue_pair.append(residue_lens[pair[0]] * residue_lens[pair[1]])
atom_pairs.extend(
list(
itertools.product(
residue_membership[pair[0]],
residue_membership[pair[1]],
),
),
)
n_atom_pairs_per_residue_pair.append(
residue_lens[pair[0]] * residue_lens[pair[1]],
)

atom_distances = _md.compute_distances(traj, atom_pairs, periodic=periodic)
atom_distances = md.compute_distances(traj, atom_pairs, periodic=periodic)

# now squash the results based on residue membership
n_residue_pairs = len(residue_pairs)
distances = _np.zeros((len(traj), n_residue_pairs), dtype=_np.float32)
n_atom_pairs_per_residue_pair = _np.asarray(n_atom_pairs_per_residue_pair)
distances = np.zeros((len(traj), n_residue_pairs), dtype=np.float32)
n_atom_pairs_per_residue_pair = np.asarray(n_atom_pairs_per_residue_pair)

aa_pairs = []
for i in xrange(n_residue_pairs):
index = int(_np.sum(n_atom_pairs_per_residue_pair[:i]))
aa_pairs = [] #mdciao
for i in range(n_residue_pairs):
index = int(np.sum(n_atom_pairs_per_residue_pair[:i]))
n = n_atom_pairs_per_residue_pair[i]
idx_min = atom_distances[:, index: index + n].argmin(axis=1)
aa_pairs.append(_np.array(atom_pairs[index: index + n])[idx_min])
idx_min = atom_distances[:, index: index + n].argmin(axis=1) #mdciao
aa_pairs.append(np.array(atom_pairs[index: index + n])[idx_min]) #mdciao
if not soft_min:
distances[:, i] = atom_distances[:, index : index + n].min(axis=1)
else:
distances[:, i] = soft_min_beta / \
_np.log(_np.sum(_np.exp(soft_min_beta /
atom_distances[:, index : index + n]), axis=1))
distances[:, i] = soft_min_beta / np.log(
np.sum(
np.exp(
soft_min_beta / atom_distances[:, index : index + n],
),
axis=1,
),
)

else:
raise ValueError('This is not supposed to happen!')
raise ValueError("This is not supposed to happen!")

return distances, residue_pairs, _np.hstack(aa_pairs)
return distances, residue_pairs, np.hstack(aa_pairs) #mdciao

0 comments on commit ac34112

Please sign in to comment.