Skip to content

Commit

Permalink
[str_and_dict.get_trajectories_from_input] refactor from get_sorted_t…
Browse files Browse the repository at this point in the history
…rajectories. Don't sort input lists of files. Allow mixed input, tests
  • Loading branch information
gph82 committed Mar 12, 2024
1 parent 0c7d633 commit df67be7
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 45 deletions.
37 changes: 20 additions & 17 deletions mdciao/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,18 +448,18 @@ def _trajsNtop2xtcsNrefgeom(trajectories,topology):
Parameters
----------
trajectories: check get_sorted_trajectories
trajectories: check get_trajectories_from_input
topology : str, top
Returns
-------
xtcs, refgeom
xtcs : whatever get_sorted_trajectories returns
xtcs : whatever get_trajectories_from_input returns
refgeom : :obj:`mdtraj.Trajectory` object
"""
# Inform about trajectories
xtcs = _mdcu.str_and_dict.get_sorted_trajectories(trajectories)
xtcs = _mdcu.str_and_dict.get_trajectories_from_input(trajectories)
if topology is None:
# TODO in case the xtc[0] is a pdb/grofile, it will be read one more time later
refgeom = _load_any_geom(xtcs[0])[0]
Expand Down Expand Up @@ -633,16 +633,17 @@ def residue_neighborhoods(residues,
* residues = '1,10-12,GLU*,GDP*,E30'
Please refer to :obj:`mdciao.utils.residue_and_atom.rangeexpand_residues2residxs`
for more info
trajectories : str, :obj:`mdtraj.Trajectory`, or None
trajectories : str, :obj:`mdtraj.Trajectory` or lists thereof
The MD-trajectories to calculate the frequencies from.
This input is pretty flexible. For more info check
:obj:`mdciao.utils.str_and_dict.get_sorted_trajectories`.
:obj:`mdciao.utils.str_and_dict.get_trajectories_from_input`.
Accepted values are:
* pattern, e.g. "*.ext"
* one string containing a filename
* list of filenames
* one :obj:`mdtraj.Trajectory` object
* list of :obj:`mdtraj.Trajectory` objects
* list mixing filenames and :obj:`mdtraj.Trajectory` objects
topology : str or :obj:`~mdtraj.Trajectory`, default is None
The topology associated with the :obj:`trajectories`
If None, the topology of the first :obj:`trajectory` will
Expand Down Expand Up @@ -1150,16 +1151,17 @@ def interface(
Parameters
----------
trajectories :
The MD-trajectories to calculate the frequencies
from. This input is pretty flexible. For more info check
:obj:`mdciao.utils.str_and_dict.get_sorted_trajectories`.
trajectories : str, :obj:`mdtraj.Trajectory` or lists thereof
The MD-trajectories to calculate the frequencies from.
This input is pretty flexible. For more info check
:obj:`mdciao.utils.str_and_dict.get_trajectories_from_input`.
Accepted values are:
* pattern, e.g. "*.ext"
* one string containing a filename
* list of filenames
* one :obj:`~mdtraj.Trajectory` object
* list of :obj:`~mdtraj.Trajectory` objects
* one :obj:`mdtraj.Trajectory` object
* list of :obj:`mdtraj.Trajectory` objects
* list mixing filenames and :obj:`mdtraj.Trajectory` objects
topology : str or :obj:`~mdtraj.Trajectory`, default is None
The topology associated with the :obj:`trajectories`
If None, the topology of the first :obj:`trajectory` will
Expand Down Expand Up @@ -1670,16 +1672,17 @@ def sites(site_inputs,
found in the topology will be discarded.
See :obj:`mdciao.sites` for more info on
the site format.
trajectories :
The MD-trajectories to calculate the frequencies
from. This input is pretty flexible. For more info check
:obj:`mdciao.utils.str_and_dict.get_sorted_trajectories`.
trajectories : str, :obj:`mdtraj.Trajectory` or lists thereof
The MD-trajectories to calculate the frequencies from.
This input is pretty flexible. For more info check
:obj:`mdciao.utils.str_and_dict.get_trajectories_from_input`.
Accepted values are:
* pattern, e.g. "*.ext"
* one string containing a filename
* list of filenames
* one :obj:`~mdtraj.Trajectory` object
* list of :obj:`~mdtraj.Trajectory` objects
* one :obj:`mdtraj.Trajectory` object
* list of :obj:`mdtraj.Trajectory` objects
* list mixing filenames and :obj:`mdtraj.Trajectory` objects
topology : str or :obj:`~mdtraj.Trajectory`, default is None
The topology associated with the :obj:`trajectories`
If None, the topology of the first :obj:`trajectory` will
Expand Down
40 changes: 20 additions & 20 deletions mdciao/utils/str_and_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,31 +99,33 @@ def _kwargs_subs(funct_or_method, exclude=None):
return _mpldocstring.Substitution(
substitute_kwargs=_kwargs_docstring(funct_or_method, exclude=exclude))

def get_sorted_trajectories(trajectories):
def get_trajectories_from_input(trajectories):
r"""
Common parser for something that can be interpreted as a trajectory
Parameters
----------
trajectories: can be one of these things:
- pattern, e.g. "*.ext"
- one string containing a filename
- list of filenames
- one :obj:`mdtraj.Trajectory` object
- list of :obj:`mdtraj.Trajectory` objects
* pattern, e.g. "*.ext"
* one single string containing a filename
* one single :obj:`mdtraj.Trajectory` object
* one list containing
* just filenames
* just :obj:`mdtraj.Trajectory` objects
* a mix of filenames and :obj:`mdtraj.Trajectory` objects
Returns
-------
- for an input pattern, sorted trajectory filenames that match that pattern
- for filename, one list containing that filename
- for a list of filenames, a sorted list of filenames
- for one :obj:`mdtraj.Trajectory` object, a list containing that object
- list of :obj:`mdtraj.Trajectory` objects (i.e. does nothing)
outtrajs : list
A list of trajectories. This list can be, depending on the input:
* for an input pattern: sorted trajectory filenames that match that pattern
* for filename or an :obj:`mdtraj.Trajectory`:
one list containing that filename or :obj:`mdtraj.Trajectory` object
* for a list, that same list (i.e. nothing happens)
"""
if isinstance(trajectories,str):
_trajectories = _glob(trajectories)
_trajectories = _natsorted(_glob(trajectories))
if len(_trajectories)==0:
raise FileNotFoundError("Couldn't find (or pattern-match) anything to '%s'.\n"
"ls $CWD[%s]:\n%s:"%(trajectories,
Expand All @@ -132,15 +134,13 @@ def get_sorted_trajectories(trajectories):
else:
trajectories=_trajectories

if isinstance(trajectories[0],str):
xtcs = _natsorted(trajectories)
elif isinstance(trajectories, _md.Trajectory):
xtcs = [trajectories]
if type(trajectories) in [_md.Trajectory, str]:
outtrajs = [trajectories]
else:
assert all([isinstance(itraj, _md.Trajectory) for itraj in trajectories])
xtcs = trajectories
assert all([type(itraj) in [_md.Trajectory, str] for itraj in trajectories])
outtrajs = trajectories

return xtcs
return outtrajs

def inform_about_trajectories(trajectories, only_show_first_and_last=False):
r"""
Expand Down
16 changes: 8 additions & 8 deletions tests/test_str_and_dict_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,29 +24,29 @@ def setUp(self):
self.traj_reverse = md.load(test_filenames.traj_xtc_stride_20, top=self.geom.top)[::-1]

def test_glob_with_pattern(self):
str_and_dict.get_sorted_trajectories(path.join(test_filenames.example_path,"*.xtc"))
str_and_dict.get_trajectories_from_input(path.join(test_filenames.example_path, "*.xtc"))

def test_glob_with_filename(self):
str_and_dict.get_sorted_trajectories(test_filenames.traj_xtc_stride_20)
str_and_dict.get_trajectories_from_input(test_filenames.traj_xtc_stride_20)

def test_with_one_trajectory_object(self):
list_out = str_and_dict.get_sorted_trajectories(self.traj)
list_out = str_and_dict.get_trajectories_from_input(self.traj)
assert len(list_out)==1
assert isinstance(list_out[0], md.Trajectory)

def test_with_trajectory_objects(self):
str_and_dict.get_sorted_trajectories([self.traj,
self.traj_reverse])
str_and_dict.get_trajectories_from_input([self.traj,
self.traj_reverse])


def test_fails_if_not_traj_at_all(self):
with pytest.raises(FileNotFoundError):
str_and_dict.get_sorted_trajectories("bogus.xtc")
str_and_dict.get_trajectories_from_input("bogus.xtc")

def test_fails_if_not_trajs(self):
with pytest.raises(AssertionError):
str_and_dict.get_sorted_trajectories([self.traj,
1])
str_and_dict.get_trajectories_from_input([self.traj,
1])


class Test_inform_about_trajectories(unittest.TestCase):
Expand Down

0 comments on commit df67be7

Please sign in to comment.