@@ -819,6 +819,7 @@ def fit(self, X, y):
819819 # init active set
820820 ActiveSet = np .ones_like (beta )
821821
822+ self ._convergence = list ()
822823 # Iterative updates
823824 for t in range (0 , self .max_iter ):
824825 self .n_iter_ += 1
@@ -855,7 +856,9 @@ def fit(self, X, y):
855856
856857 # Convergence by relative parameter change tolerance
857858 norm_update = np .linalg .norm (beta - beta_old )
858- if t > 1 and (norm_update / np .linalg .norm (beta )) < tol :
859+ norm_update /= np .linalg .norm (beta )
860+ self ._convergence .append (norm_update )
861+ if t > 1 and self ._convergence [- 1 ] < tol :
859862 msg = ('\t Parameter update tolerance. ' +
860863 'Converged in {0:d} iterations' .format (t ))
861864 logger .info (msg )
@@ -880,6 +883,37 @@ def fit(self, X, y):
880883 self .is_fitted_ = True
881884 return self
882885
886+ def plot_convergence (self , ax = None , show = True ):
887+ """Plots the convergence.
888+
889+ Parameters
890+ ----------
891+ ax : matplotlib.pyplot.axes object
892+ If not None, plot in this axis.
893+ show : bool
894+ If True, call plt.show()
895+
896+ Returns
897+ -------
898+ fig : matplotlib.Figure
899+ The matplotlib figure handle
900+ """
901+ import matplotlib .pyplot as plt
902+
903+ if ax is None :
904+ fig , ax = plt .subplots (1 , 1 )
905+
906+ ax .semilogy (self ._convergence )
907+ ax .set_xlim ((- 20 , self .max_iter + 20 ))
908+ ax .axhline (self .tol , linestyle = '--' , color = 'r' , label = 'tol' )
909+ ax .set_ylabel (r'$\Vert\beta_{t} - \beta_{t-1}\Vert/\Vert\beta_t\Vert$' )
910+ ax .legend ()
911+
912+ if show :
913+ plt .show ()
914+
915+ return ax .get_figure ()
916+
883917 def predict (self , X ):
884918 """Predict targets.
885919
0 commit comments