Skip to content

Commit 2724747

Browse files
committed
ENH: add plot for convergence
1 parent 8fd4c83 commit 2724747

File tree

6 files changed

+53
-8
lines changed

6 files changed

+53
-8
lines changed

doc/contributing.rst

+2
Original file line numberDiff line numberDiff line change
@@ -161,9 +161,11 @@ The following should be installed in order to build the documentation.
161161

162162
* `sphinx <https://github.com/sphinx-doc/sphinx/>`_
163163
* `sphinx-gallery <https://github.com/sphinx-gallery/sphinx-gallery/>`_
164+
* `sphinx_bootstrap_theme <https://github.com/ryan-roemer/sphinx-bootstrap-theme>`_
164165
* `pillow <https://github.com/python-pillow/Pillow/>`_
165166
* `numpydoc <https://github.com/numpy/numpydoc/>`_
166167
* `matplotlib <https://github.com/matplotlib/matplotlib/>`_
168+
* `spykes <http://kordinglab.com/spykes/>`_
167169

168170
Shortcut:
169171

doc/install.rst

+4-7
Original file line numberDiff line numberDiff line change
@@ -6,25 +6,22 @@ Check dependencies
66
------------------
77
We currently support ``Python 3.5+``.
88

9-
For the package: ``numpy>=1.11``, ``scipy>=0.17``, ``scikit-learn>=0.18``
9+
For the package: ``numpy>=1.11`` and ``scipy>=0.17``. For the plotting code,
10+
optionally install ``matplotlib``.
1011

11-
Additionally, for examples: ``pandas>=0.20``
12+
Additionally, for examples: ``pandas>=0.20`` and ``scikit-learn>=0.18``.
13+
Our library is ``scikit-learn`` compatible.
1214

1315
Both `Canopy <https://www.enthought.com/products/canopy/>`__
1416
and `Anaconda <https://www.continuum.io/downloads>`__
1517
ship with a recent version of all these packages.
1618

17-
Additionally, for development, tests and coverage: ``pytest``, ``pytest-cov``, ``coverage``, ``flake8``
18-
19-
Additionally, for building documentation: ``sphinx``, ``sphinx-gallery``, ``sphinx_bootstrap_theme``, ``pillow``, ``numpydoc``, ``matplotlib``, ``spykes``
20-
2119
In case you have other distributions of Python, you can install
2220
the dependencies using ``pip``.
2321

2422
.. code-block:: bash
2523
2624
pip install numpy scipy
27-
pip install -U scikit-learn
2825
2926
Get pyglmnet
3027
------------

doc/whats_new.rst

+3
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ Release notes
1111
Current
1212
-------
1313

14+
- Add method :method:`pyglmnet.GLM.plot_convergence` to inspect convergence by `Mainak Jas`_.
15+
1416
Changelog
1517
~~~~~~~~~
1618

@@ -20,6 +22,7 @@ BUG
2022
- Graceful handling of small Hessian term in coordinate descent solver that led to exploding update term by `Pavan Ramkumar`_.
2123
- Ensure full compatibility of `GLM` class with `scikit-learn` by `Titipat Achakulvisut`_.
2224

25+
2326
API
2427
~~~
2528

examples/plot_community_crime.py

+4
Original file line numberDiff line numberDiff line change
@@ -88,3 +88,7 @@
8888
ax.spines['top'].set_visible(False)
8989
ax.spines['right'].set_visible(False)
9090
plt.show()
91+
92+
########################################################
93+
# We can also check if the algorithm converged properly
94+
glmcv.best_estimator_.plot_convergence()

pyglmnet/pyglmnet.py

+35-1
Original file line numberDiff line numberDiff line change
@@ -819,6 +819,7 @@ def fit(self, X, y):
819819
# init active set
820820
ActiveSet = np.ones_like(beta)
821821

822+
self._convergence = list()
822823
# Iterative updates
823824
for t in range(0, self.max_iter):
824825
self.n_iter_ += 1
@@ -855,7 +856,9 @@ def fit(self, X, y):
855856

856857
# Convergence by relative parameter change tolerance
857858
norm_update = np.linalg.norm(beta - beta_old)
858-
if t > 1 and (norm_update / np.linalg.norm(beta)) < tol:
859+
norm_update /= np.linalg.norm(beta)
860+
self._convergence.append(norm_update)
861+
if t > 1 and self._convergence[-1] < tol:
859862
msg = ('\tParameter update tolerance. ' +
860863
'Converged in {0:d} iterations'.format(t))
861864
logger.info(msg)
@@ -880,6 +883,37 @@ def fit(self, X, y):
880883
self.is_fitted_ = True
881884
return self
882885

886+
def plot_convergence(self, ax=None, show=True):
887+
"""Plots the convergence.
888+
889+
Parameters
890+
----------
891+
ax : matplotlib.pyplot.axes object
892+
If not None, plot in this axis.
893+
show : bool
894+
If True, call plt.show()
895+
896+
Returns
897+
-------
898+
fig : matplotlib.Figure
899+
The matplotlib figure handle
900+
"""
901+
import matplotlib.pyplot as plt
902+
903+
if ax is None:
904+
fig, ax = plt.subplots(1, 1)
905+
906+
ax.semilogy(self._convergence)
907+
ax.set_xlim((-20, self.max_iter + 20))
908+
ax.axhline(self.tol, linestyle='--', color='r', label='tol')
909+
ax.set_ylabel(r'$\Vert\beta_{t} - \beta_{t-1}\Vert/\Vert\beta_t\Vert$')
910+
ax.legend()
911+
912+
if show:
913+
plt.show()
914+
915+
return ax.get_figure()
916+
883917
def predict(self, X):
884918
"""Predict targets.
885919

tests/test_pyglmnet.py

+5
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import matplotlib
2+
13
import subprocess
24
import os.path as op
35

@@ -19,6 +21,8 @@
1921
from pyglmnet import (GLM, GLMCV, _grad_L2loss, _L2loss, simulate_glm,
2022
_gradhess_logloss_1d, _loss, datasets, ALLOWED_DISTRS)
2123

24+
matplotlib.use('agg')
25+
2226

2327
def test_glm_estimator():
2428
"""Test GLM class using scikit-learn's check_estimator."""
@@ -482,6 +486,7 @@ def test_api_input():
482486
glm.fit(X, y)
483487
glm.predict(X)
484488
glm.score(X, y)
489+
glm.plot_convergence()
485490
glm = GLM(distr='gaussian', solver='test')
486491

487492
with pytest.raises(ValueError, match="solver must be one of"):

0 commit comments

Comments
 (0)