Commit 05522e7

[plots.compare_violins] return the representative frames, tests
1 parent a1d36c6 commit 05522e7

2 files changed: +130 -88 lines changed

mdciao/plots/plots.py (+101 -63)
@@ -1366,7 +1366,7 @@ def compare_violins(groups,
13661366
matching their contact labels, since the residue indices
13671367
might differ across :obj:`groups`. To achieve this:
13681368
* "K30-D40" is considered equivalent to "D40-D30",
1369-
use :obj:`key_separator` to change this.
1369+
use `key_separator` to change this.
13701370
* "K30-D40" is considered equivalent to "K30-E40"
13711371
if a :obj:`mutations_dict={"E40":"D40"}` is passed
13721372
* "[email protected]" is considered equivalent to "K30-D40"
@@ -1448,46 +1448,51 @@ def compare_violins(groups,
14481448
to least "formed" on the right of the plot.
14491449
However, for each residue pair, this mean is
14501450
an average over the distance in all
1451-
the different :obj:`groups`, so some
1451+
the different `groups`, so some
14521452
heterogeneity is expected. Alternatively,
14531453
you can sort using the contact labels,
14541454
regardless of the distance values. Note
14551455
that for this, string comparisons between
14561456
contact-labels will take place, and that
1457-
contact-labels are altered by :obj:`key_separator`
1458-
to unify across different :obj:`groups`
1459-
Try setting :obj:`key_separator` to None
1457+
contact-labels are altered by `key_separator`
1458+
to unify across different `groups`
1459+
Try setting `key_separator` to None
14601460
if you see unexpected behavior, although
14611461
this might have other side effects,
1462-
(see obj:~`mdciao.utils.str_and_dict.unify_freq_dicts`)
1463-
:obj:`sort_by` can be a:
1464-
* str : 'residue'
1465-
Sort by ascending residue sequence index (resSeq),
1466-
which will be inferred from each contact label,
1467-
e.g. 30 for "[email protected]". See :obj:`~mdciao.contacts.ContactGroup.gen_ctc_labels`
1468-
for more info on how they are generated.
1469-
Internally, the order is generated via
1470-
:obj:`~mdciao.utils.str_and_dict.lexsort_ctc_labels`.
1471-
If you want to reverse or alter this
1472-
ascending default order, we recommend using
1473-
:obj:`~mdciao.utils.str_and_dict.lexsort_ctc_labels`
1474-
**before** calling :obj:`compare_violins` and use
1475-
its output (sorted_ctc_labels) as a list
1476-
argument for :obj:`sort_by`. Also note that
1477-
residue indices as contained in
1478-
:obj:`~mdciao.contacts.ContactGroup.res_idx_pairs`
1462+
(see :obj:`~mdciao.utils.str_and_dict.unify_freq_dicts`)
1463+
`sort_by` can be a:
1464+
* str : 'residue' or 'numeric'
1465+
Sort by ascending residue sequence index (resSeq),
1466+
which will be inferred from each contact label,
1467+
e.g. 30 for "[email protected]". See :obj:`~mdciao.contacts.ContactGroup.gen_ctc_labels`
1468+
for more info on how they are generated.
1469+
Internally, the order is generated via
1470+
:obj:`~mdciao.utils.str_and_dict.lexsort_ctc_labels`.
1471+
If you want to reverse or alter this
1472+
ascending default order, we recommend using
1473+
:obj:`~mdciao.utils.str_and_dict.lexsort_ctc_labels`
1474+
**before** calling :obj:`compare_violins` and use
1475+
its output (`labels`) as a list
1476+
argument for `sort_by`. Also note that
1477+
residue indices as contained in
1478+
:obj:`~mdciao.contacts.ContactGroup.res_idx_pairs`
1479+
* str : 'keep'
1480+
Sort using the same order of the labels as in
1481+
the first contact group
1482+
* str : 'consensus'
1483+
Sort following consensus nomenclature (GPCR, CGN or KLIFS)
14791484
* list : a list of contact labels,
1480-
eg. ["GLU30-ALA30", "[email protected]"].
1481-
Only these residue pairs (in this order)
1482-
will be shown, regardless of what other
1483-
pairs are contained in the :obj:`groups`. It
1484-
assumes the user knows what contacts
1485-
are present and can come up with a meaningful
1486-
list. Not all labels need to be in all
1487-
:obj:`groups` nor do all :obj:`groups`
1488-
have to contain all labels, but at least
1489-
one label needs to match, otherwise the
1490-
method will fail
1485+
eg. ["GLU30-ALA30", "[email protected]"].
1486+
Only these residue pairs (in this order)
1487+
will be shown, regardless of what other
1488+
pairs are contained in the `groups`. It
1489+
assumes the user knows what contacts
1490+
are present and can come up with a meaningful
1491+
list. Not all labels need to be in all
1492+
`groups` nor do all `groups`
1493+
have to contain all labels, but at least
1494+
one label needs to match, otherwise the
1495+
method will fail
14911496
zero_freq : float, default is 1e-2
14921497
Frequencies below this number will
14931498
be considered zero and not shown if they are
@@ -1506,37 +1511,41 @@ def compare_violins(groups,
15061511
can also be removed.
15071512
Only has an effect if `ctc_cutoff_Ang` is not None.
15081513
representatives : anything (bool, int, dict, list) default is None
1509-
Plot, with a small dot on top of the violins,
1510-
the values of the residue-residue distances of representative
1511-
geometries. The representative geometries can be parsed
1512-
directly as a dict of :obj:`~mdtraj.Trajectory` objects,
1513-
or extracted on-the-fly by calling the :obj:`mdciao.contacts.ContactGroup.repframes`
1514-
method of each of the `groups`. Check the docs of
1515-
:obj:`mdciao.contacts.ContactGroup.repframes` to find out what is meant
1516-
with "representative".
1514+
Include information about representative values in the
1515+
plot. This can be done in several ways. Easiest
1516+
is to let this method call :obj:`mdciao.contacts.ContactGroup.repframes`
1517+
internally. This will locate representative frames, extract
1518+
their residue-residue distance values and plot them as small dots
1519+
on top of the violins. When possible, also the geometries corresponding
1520+
to these frames will be returned. Alternatively, the user
1521+
can directly input a dictionary of :obj:`~mdtraj.Trajectory` objects
1522+
(representative or not) for which the residue-residue distance values
1523+
will be computed and plotted. Check the docs of
1524+
:obj:`mdciao.contacts.ContactGroup.repframes` to find out
1525+
what is meant with "representative".
15171526
This is what each type of input does:
15181527
15191528
* boolean True:
1520-
Calls :obj:`mdciao.ContactGroup.repframes` with the
1521-
method's default parameters and plots the result
1529+
Calls :obj:`mdciao.ContactGroup.repframes` with the
1530+
method's default parameters.
15221531
* int > 0:
1523-
Calls :obj:`mdciao.ContactGroup.repframes` with the
1524-
parameter `n_frames` set to this integer. This parameter
1525-
controls how many representatives are extracted and
1526-
subsequently plotted.
1532+
Calls :obj:`mdciao.ContactGroup.repframes` with the
1533+
parameter `n_frames` set to this integer. This parameter
1534+
controls how many representatives are extracted and
1535+
subsequently plotted.
15271536
* dict of parameters:
1528-
A dictionary with explicit values for the optional
1529-
parameters of :obj:`mdciao.contacts.ContactGroup.repframes`,
1530-
usually `n_frames` (an int) and `scheme`, ("mean" or "mode"),
1531-
depending what you mean with "representative". Check the method's
1532-
documentation for more info.
1537+
A dictionary with explicit values for the optional
1538+
parameters of :obj:`mdciao.contacts.ContactGroup.repframes`,
1539+
usually `n_frames` (an int) and `scheme`, ("mean" or "mode"),
1540+
depending what you mean with "representative". Check the method's
1541+
documentation for more info.
15331542
* dict of :obj:`~mdtraj.Trajectory` objects:
1534-
Has to have the same keys as `groups`. No checks are done
1535-
whether these objects match the actual molecular topologies
1536-
of `groups`, so beware of potential mismatches here.
1537-
Typically, these frames come from having used
1538-
:obj:`mdciao.contacts.ContactGroup.repframes` with
1539-
`return_traj`=True.
1543+
Has to have the same keys as `groups`. No checks are done
1544+
whether these objects match the actual molecular topologies
1545+
of `groups`, so beware of potential mismatches here.
1546+
Typically, these frames come from having used
1547+
:obj:`mdciao.contacts.ContactGroup.repframes` with
1548+
`return_traj`=True.
15401549
* dict of dicts containing values
15411550
#TODO not implemented yet
15421551
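Review note: the docstring above lists several input types for `representatives`. The sketch below shows one way such input could be normalized into keyword arguments for a `repframes`-like call. It is an illustration only, not the commit's code: `build_repframes_kwargs` is a hypothetical helper, the dict-of-`Trajectory` case is left out (it needs mdtraj), and letting user-supplied options override the defaults is an assumption. Two points it illustrates: `bool` is a subclass of `int` in Python, so the bool check has to come first, and user options are merged into a new dict instead of mutating the caller's dict. As far as the flattened hunk further down shows, the kwargs-dict branch updates `representatives` while the later call passes `repframes_kwargs`, so this merge direction may be worth double-checking there.

```python
from typing import Any, Dict, Optional


def build_repframes_kwargs(representatives: Any,
                           ctc_cutoff_Ang: float) -> Optional[Dict[str, Any]]:
    """Hypothetical helper: translate `representatives` into call kwargs.

    Returns None when no representative frames should be computed.
    """
    if representatives is None or representatives is False:
        return None

    # Defaults taken from the diff: always ask for the geometries as well.
    kwargs: Dict[str, Any] = {"ctc_cutoff_Ang": ctc_cutoff_Ang,
                              "return_traj": True}

    # bool is a subclass of int, so it must be tested before int.
    if isinstance(representatives, bool):
        return kwargs

    if isinstance(representatives, int):
        if representatives <= 0:
            return None
        kwargs["n_frames"] = representatives
        return kwargs

    if isinstance(representatives, dict):
        if len(representatives) == 0:
            return None
        # Merge into a new dict so the caller's dict is left untouched.
        # Assumption: user-supplied options override the defaults.
        merged = {**kwargs, **representatives}
        merged["return_traj"] = True
        return merged

    raise TypeError(f"unsupported type for representatives: {type(representatives)!r}")


user_options = {"n_frames": 3, "scheme": "mean"}
print(build_repframes_kwargs(user_options, ctc_cutoff_Ang=4))
print(user_options)  # unchanged by the call above
```

With this shape, `representatives=True` does not accidentally set `n_frames=True`, and the same user dict can be reused safely across all groups in the loop.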
@@ -1547,6 +1556,12 @@ def compare_violins(groups,
15471556
labels : list
15481557
The list of plotted labels,
15491558
in the order they are plotted
1559+
repframes : dict
1560+
Will only be returned if
1561+
`representatives` was not None.
1562+
The representative frames for
1563+
each `group` according to the
1564+
parameters of `representatives`
15501565
"""
15511566
_fontsize=_rcParams["font.size"]
15521567
_rcParams["font.size"] = fontsize
@@ -1559,6 +1574,7 @@ def compare_violins(groups,
15591574
else:
15601575
_groups = groups
15611576
repframes_per_sys_per_ctc = {}
1577+
reptraj_per_sys_per_ctc = {}
15621578
for syskey, group in _groups.items():
15631579
labels = group.gen_ctc_labels(AA_format=AA_format,
15641580
fragments=[True if defrag is None else False][0],
@@ -1570,22 +1586,41 @@ def compare_violins(groups,
15701586
freqs_per_sys_per_ctc[syskey] = {key:freq for key, freq in zip(labels, group.frequency_per_contact(ctc_cutoff_Ang))}
15711587

15721588
if bool(representatives):
1589+
#Tune the kwargs on a per-case basis then call repframes only once,
1590+
# wrapped in the try block for when there's no files
1591+
repframes_kwargs = {"ctc_cutoff_Ang": ctc_cutoff_Ang,
1592+
"return_traj": True}
15731593
# Do we have representatives?
15741594
if isinstance(representatives, bool):
1575-
d = group.repframes(ctc_cutoff_Ang=ctc_cutoff_Ang)[2]
1595+
pass
15761596
if isinstance(representatives, int) and representatives>0:
1577-
d = group.repframes(ctc_cutoff_Ang=ctc_cutoff_Ang, n_frames=representatives)[2].T
1597+
repframes_kwargs.update({"n_frames" : representatives,
1598+
"verbose" : False})
15781599
if isinstance(representatives, dict) and len(representatives)>0:
15791600
if syskey not in representatives.keys() :
1601+
representatives.update(repframes_kwargs)
15801602
representatives.pop("ctc_cutoff_ang", None)
15811603
representatives.pop("show_violins", None)
1582-
d = group.repframes(**representatives)[2].T
1604+
representatives["return_traj"] = True
15831605
else:
15841606
assert isinstance(representatives[syskey], _md.Trajectory)
15851607
d = _md.compute_contacts(representatives[syskey], contacts=group.res_idxs_pairs)[0].T
1608+
traj = representatives[syskey]
1609+
repframes_kwargs = None
1610+
1611+
if repframes_kwargs is not None:
1612+
try:
1613+
__, __, d, traj = group.repframes(**repframes_kwargs)
1614+
except FileNotFoundError as e:
1615+
print(e)
1616+
repframes_kwargs["return_traj"] = False
1617+
__, __, d = group.repframes(**repframes_kwargs)
1618+
traj = None
1619+
d = d.T.squeeze()
1620+
15861621
repframes_per_sys_per_ctc[syskey] = {key: val * 10 for key, val in
15871622
zip(labels, d)}
1588-
1623+
reptraj_per_sys_per_ctc[syskey]=traj
15891624
representatives = bool(representatives)
15901625
# Unify data
15911626
data4violins_per_sys_per_ctc = _mdcu.str_and_dict.unify_freq_dicts(data4violins_per_sys_per_ctc,
@@ -1701,7 +1736,10 @@ def compare_violins(groups,
17011736
myfig.tight_layout()
17021737

17031738
_rcParams["font.size"] = _fontsize
1704-
return myfig, iax, list(key2ii.keys())
1739+
if repframes_per_sys_per_ctc != {}:
1740+
return myfig, iax, list(key2ii.keys()), reptraj_per_sys_per_ctc
1741+
else:
1742+
return myfig, iax, list(key2ii.keys())
17051743

17061744

17071745
def _sorter_by_key_or_val(sort_by, indict):
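Review note: after this change `compare_violins` returns either three or four values, depending on whether representatives were computed. Callers that do not know in advance which case applies could normalize the result as sketched below. `unpack_compare_violins` is a hypothetical helper, not part of mdciao, and the placeholder strings only stand in for the figure, axes and label list.

```python
from typing import Any, Optional, Sequence, Tuple


def unpack_compare_violins(result: Sequence[Any]) -> Tuple[Any, Any, Any, Optional[dict]]:
    """Hypothetical helper: always return (fig, ax, labels, repframes).

    `repframes` is None when the call returned only three values.
    """
    if len(result) == 4:
        fig, ax, labels, repframes = result
        return fig, ax, labels, repframes
    if len(result) == 3:
        fig, ax, labels = result
        return fig, ax, labels, None
    raise ValueError(f"expected 3 or 4 return values, got {len(result)}")


# Placeholder values standing in for the real return values.
print(unpack_compare_violins(("fig", "ax", ["K30-D40"])))
print(unpack_compare_violins(("fig", "ax", ["K30-D40"], {"small": None})))
```

A design alternative would be to always return four values, with `None` in the last slot; that keeps the call signature stable but would change behavior for existing callers that unpack three values.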

tests/test_plots.py (+29 -25)
@@ -741,11 +741,10 @@ def setUpClass(cls):
741741
cls.CGL394_larger = ContactGroupL394(ctc_control=.99, ctc_cutoff_Ang=5)
742742

743743
def test_works(self):
744-
fig, ax, sorted_keys = plots.compare_violins({"small":self.CGL394, "big":self.CGL394_larger},
745-
anchor="L394",
746-
ymax=10, ctc_cutoff_Ang=4)
747-
748-
#fig.savefig("test.pdf")
744+
fig, ax, sorted_keys = plots.compare_violins({"small": self.CGL394, "big": self.CGL394_larger},
745+
anchor="L394",
746+
ymax=10, ctc_cutoff_Ang=4)
747+
# fig.savefig("test.pdf")
749748
_plt.close("all")
750749

751750
def test_works_no_defrag_and_list_zero_freq_remove_identities(self):
@@ -772,46 +771,51 @@ def test_works_no_defrag_and_list_zero_freq_remove_identities(self):
772771

773772
def test_works_no_defrag_and_list(self):
774773
fig, ax, sorted_keys = plots.compare_violins([self.CGL394, self.CGL394_larger],
775-
anchor="L394",
776-
ymax=10, ctc_cutoff_Ang=4,
777-
defrag=None)
774+
anchor="L394",
775+
ymax=10, ctc_cutoff_Ang=4,
776+
defrag=None)
778777

779778
#fig.savefig("test.pdf")
780779
_plt.close("all")
781780

782781
def test_repframes_True(self):
783-
fig, ax, sorted_keys = plots.compare_violins({"small": self.CGL394, "big": self.CGL394_larger},
784-
anchor="L394",
785-
ymax=10, ctc_cutoff_Ang=4,
786-
representatives=True)
787-
782+
fig, ax, sorted_keys, repframes = plots.compare_violins({"small": self.CGL394, "big": self.CGL394_larger},
783+
anchor="L394",
784+
ymax=10, ctc_cutoff_Ang=4,
785+
representatives=True)
788786

787+
assert repframes["small"] is None
788+
assert repframes["big"] is None
789789
#fig.savefig("test.pdf")
790790
_plt.close("all")
791791

792792
def test_repframes_int(self):
793-
fig, ax, sorted_keys = plots.compare_violins({"small": self.CGL394, "big": self.CGL394_larger},
794-
anchor="L394",
795-
ymax=10, ctc_cutoff_Ang=4,
796-
representatives=2)
793+
fig, ax, sorted_keys, repframes = plots.compare_violins({"small": self.CGL394, "big": self.CGL394_larger},
794+
anchor="L394",
795+
ymax=10, ctc_cutoff_Ang=4,
796+
representatives=2)
797+
assert repframes["small"] is None
798+
assert repframes["big"] is None
797799
# fig.savefig("test.pdf")
798800
_plt.close("all")
799801

800802
def test_repframes_dict_kwargs(self):
801-
fig, ax, sorted_keys = plots.compare_violins({"small": self.CGL394, "big": self.CGL394_larger},
802-
anchor="L394",
803-
ymax=10, ctc_cutoff_Ang=4,
804-
representatives={"n_frames":3, "scheme":"mean"})
803+
fig, ax, sorted_keys, repframes_out = plots.compare_violins({"small": self.CGL394, "big": self.CGL394_larger},
804+
anchor="L394",
805+
ymax=10, ctc_cutoff_Ang=4,
806+
representatives={"n_frames": 3, "scheme": "mean"})
805807
#fig.savefig("test.pdf")
806808
_plt.close("all")
807809

808810
def test_repframes_dict_geoms(self):
809811
traj = _md.load(test_filenames.traj_xtc_stride_20, top=test_filenames.top_pdb)
810812
repframes = {"small" : traj[:3],
811813
"big" : traj[:5]}
812-
fig, ax, sorted_keys = plots.compare_violins({"small": self.CGL394, "big": self.CGL394_larger},
813-
anchor="L394",
814-
ymax=10, ctc_cutoff_Ang=4,
815-
representatives=repframes)
814+
fig, ax, sorted_keys, repframes_out = plots.compare_violins({"small": self.CGL394, "big": self.CGL394_larger},
815+
anchor="L394",
816+
ymax=10, ctc_cutoff_Ang=4,
817+
representatives=repframes)
818+
for key, val in repframes_out.items():
819+
assert val is repframes[key]
816820
#fig.savefig("test.pdf")
817821
_plt.close("all")
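Review note: the assertions `repframes["small"] is None` in the tests above are consistent with the `FileNotFoundError` fallback in the `plots.py` hunk, where the geometries are dropped when the trajectory files cannot be read. A minimal, self-contained sketch of that fallback pattern follows. `FakeGroup` is a hypothetical stand-in for a contact group whose files are missing; the only thing taken from the diff is the return shape (three values, or four when `return_traj=True`).

```python
class FakeGroup:
    """Hypothetical stand-in for a group whose trajectory files are not on disk."""

    def repframes(self, return_traj=False, **kwargs):
        if return_traj:
            # Loading geometries would need the files, which are missing here.
            raise FileNotFoundError("trajectory file not found")
        # Three placeholder return values; the third plays the role of `d`.
        return [0], [0.0], [[0.3, 0.5]]


def repframes_with_fallback(group, **kwargs):
    """Try to get values and geometries; fall back to values only."""
    kwargs.pop("return_traj", None)  # this function controls return_traj itself
    try:
        __, __, d, traj = group.repframes(return_traj=True, **kwargs)
    except FileNotFoundError as e:
        print(e)
        __, __, d = group.repframes(return_traj=False, **kwargs)
        traj = None
    return d, traj


d, traj = repframes_with_fallback(FakeGroup(), n_frames=1)
print(d, traj)
```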
