@@ -819,6 +819,7 @@ def fit(self, X, y):
819
819
# init active set
820
820
ActiveSet = np .ones_like (beta )
821
821
822
+ self ._convergence = list ()
822
823
# Iterative updates
823
824
for t in range (0 , self .max_iter ):
824
825
self .n_iter_ += 1
@@ -855,7 +856,9 @@ def fit(self, X, y):
855
856
856
857
# Convergence by relative parameter change tolerance
857
858
norm_update = np .linalg .norm (beta - beta_old )
858
- if t > 1 and (norm_update / np .linalg .norm (beta )) < tol :
859
+ norm_update /= np .linalg .norm (beta )
860
+ self ._convergence .append (norm_update )
861
+ if t > 1 and self ._convergence [- 1 ] < tol :
859
862
msg = ('\t Parameter update tolerance. ' +
860
863
'Converged in {0:d} iterations' .format (t ))
861
864
logger .info (msg )
@@ -880,6 +883,37 @@ def fit(self, X, y):
880
883
self .is_fitted_ = True
881
884
return self
882
885
886
+ def plot_convergence (self , ax = None , show = True ):
887
+ """Plots the convergence.
888
+
889
+ Parameters
890
+ ----------
891
+ ax : matplotlib.pyplot.axes object
892
+ If not None, plot in this axis.
893
+ show : bool
894
+ If True, call plt.show()
895
+
896
+ Returns
897
+ -------
898
+ fig : matplotlib.Figure
899
+ The matplotlib figure handle
900
+ """
901
+ import matplotlib .pyplot as plt
902
+
903
+ if ax is None :
904
+ fig , ax = plt .subplots (1 , 1 )
905
+
906
+ ax .semilogy (self ._convergence )
907
+ ax .set_xlim ((- 20 , self .max_iter + 20 ))
908
+ ax .axhline (self .tol , linestyle = '--' , color = 'r' , label = 'tol' )
909
+ ax .set_ylabel (r'$\Vert\beta_{t} - \beta_{t-1}\Vert/\Vert\beta_t\Vert$' )
910
+ ax .legend ()
911
+
912
+ if show :
913
+ plt .show ()
914
+
915
+ return ax .get_figure ()
916
+
883
917
def predict (self , X ):
884
918
"""Predict targets.
885
919
0 commit comments