- 
                Notifications
    
You must be signed in to change notification settings  - Fork 233
 
Add iterator interface #745
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
b10f27e
              cc594a8
              d6a6d75
              2117364
              b7b63e8
              8a64efb
              0826be0
              b0e5c30
              a22346a
              2310160
              15c8b62
              75ece54
              54a8542
              bc42ad8
              afaa7e9
              cc5ebfc
              56e0795
              dc03af8
              3daa277
              35ffc80
              e027fe9
              ea17c1c
              9faa883
              7620e4f
              File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -1,4 +1,7 @@ | ||
| Base.summary(r::OptimizationResults) = summary(r.method) # might want to do more here than just return summary of the method used | ||
| _method(r::OptimizationResults) = r.method | ||
| 
     | 
||
| Base.summary(r::Union{OptimizationResults, OptimIterator}) = | ||
| summary(_method(r)) # might want to do more here than just return summary of the method used | ||
| minimizer(r::OptimizationResults) = r.minimizer | ||
| minimum(r::OptimizationResults) = r.minimum | ||
| iterations(r::OptimizationResults) = r.iterations | ||
| 
          
            
          
           | 
    @@ -35,24 +38,26 @@ end | |
| x_upper_trace(r::MultivariateOptimizationResults) = | ||
| error("x_upper_trace is not implemented for $(summary(r)).") | ||
| 
     | 
||
| x_trace(ot::OptimIterator, os::IteratorState) = x_trace(OptimizationResults(ot, os)) | ||
| function x_trace(r::MultivariateOptimizationResults) | ||
| tr = trace(r) | ||
| if isa(r.method, NelderMead) | ||
| if isa(_method(r), NelderMead) | ||
| throw( | ||
| ArgumentError( | ||
| "Nelder Mead does not operate with a single x. Please use either centroid_trace(...) or simplex_trace(...) to extract the relevant points from the trace.", | ||
| ), | ||
| ) | ||
| end | ||
| !haskey(tr[1].metadata, "x") && error( | ||
| "Trace does not contain x. To get a trace of x, run optimize() with extended_trace = true", | ||
| "Trace does not contain x. To get a trace of x, run optimize() with extended_trace = true and make sure x is stored in the trace for the method of choice.", | ||
| ) | ||
| [state.metadata["x"] for state in tr] | ||
| end | ||
| 
     | 
||
| centroid_trace(ot::OptimIterator, os::IteratorState) = centroid_trace(OptimizationResults(ot, os)) | ||
| function centroid_trace(r::MultivariateOptimizationResults) | ||
| tr = trace(r) | ||
| 
         
      
  
    Contributor
      
  Author
   
  There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I suppose  I'm including these changes (adding   | 
||
| if !isa(r.method, NelderMead) | ||
| if !isa(_method(r), NelderMead) | ||
| throw( | ||
| ArgumentError( | ||
| "There is no centroid involved in optimization using $(r.method). Please use x_trace(...) to grab the points from the trace.", | ||
| 
        
          
        
         | 
    @@ -64,9 +69,10 @@ function centroid_trace(r::MultivariateOptimizationResults) | |
| ) | ||
| [state.metadata["centroid"] for state in tr] | ||
| end | ||
| simplex_trace(ot::OptimIterator, os::IteratorState) = simplex_trace(OptimizationResults(ot, os)) | ||
| function simplex_trace(r::MultivariateOptimizationResults) | ||
| tr = trace(r) | ||
| if !isa(r.method, NelderMead) | ||
| if !isa(_method(r), NelderMead) | ||
| throw( | ||
| ArgumentError( | ||
| "There is no simplex involved in optimization using $(r.method). Please use x_trace(...) to grab the points from the trace.", | ||
| 
        
          
        
         | 
    @@ -78,9 +84,10 @@ function simplex_trace(r::MultivariateOptimizationResults) | |
| ) | ||
| [state.metadata["simplex"] for state in tr] | ||
| end | ||
| simplex_value_trace(ot::OptimIterator, os::IteratorState) = simplex_value_trace(OptimizationResults(ot, os)) | ||
| function simplex_value_trace(r::MultivariateOptimizationResults) | ||
| tr = trace(r) | ||
| if !isa(r.method, NelderMead) | ||
| if !isa(_method(r), NelderMead) | ||
| throw( | ||
| ArgumentError( | ||
| "There are no simplex values involved in optimization using $(r.method). Please use f_trace(...) to grab the objective values from the trace.", | ||
| 
        
          
        
         | 
    @@ -94,9 +101,11 @@ function simplex_value_trace(r::MultivariateOptimizationResults) | |
| end | ||
| 
     | 
||
| 
     | 
||
| f_trace(ot::OptimIterator, os::IteratorState) = f_trace(OptimizationResults(ot, os)) | ||
| f_trace(r::OptimizationResults) = [state.value for state in trace(r)] | ||
| g_norm_trace(r::OptimizationResults) = | ||
| error("g_norm_trace is not implemented for $(summary(r)).") | ||
| g_norm_trace(ot::OptimIterator, os::IteratorState) = g_norm_trace(OptimizationResults(ot, os)) | ||
| g_norm_trace(r::MultivariateOptimizationResults) = [state.g_norm for state in trace(r)] | ||
| 
     | 
||
| f_calls(r::OptimizationResults) = r.f_calls | ||
| 
        
          
        
         | 
    @@ -114,7 +123,8 @@ h_calls(d) = first(d.h_calls) | |
| h_calls(d::TwiceDifferentiableHV) = first(d.hv_calls) | ||
| 
     | 
||
| converged(r::UnivariateOptimizationResults) = r.stopped_by.converged | ||
| function converged(r::MultivariateOptimizationResults) | ||
| converged(ot::OptimIterator, os::IteratorState) = converged(OptimizationResults(ot, os)) | ||
| function converged(r::Union{MultivariateOptimizationResults, IteratorState}) | ||
| conv_flags = r.stopped_by.x_converged || r.stopped_by.f_converged || r.stopped_by.g_converged | ||
| x_isfinite = isfinite(x_abschange(r)) || isnan(x_relchange(r)) | ||
| f_isfinite = if r.stopped_by.iterations > 0 | ||
| 
          
            
          
           | 
    ||
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -1,105 +0,0 @@ | ||
| Base.@deprecate method(x) summary(x) | ||
| 
     | 
||
| const has_deprecated_fminbox = Ref(false) | ||
| function optimize( | ||
| df::OnceDifferentiable, | ||
| initial_x::Array{T}, | ||
| l::Array{T}, | ||
| u::Array{T}, | ||
| ::Type{Fminbox}; | ||
| x_tol::T = eps(T), | ||
| f_tol::T = sqrt(eps(T)), | ||
| g_tol::T = sqrt(eps(T)), | ||
| allow_f_increases::Bool = true, | ||
| iterations::Integer = 1_000, | ||
| store_trace::Bool = false, | ||
| show_trace::Bool = false, | ||
| extended_trace::Bool = false, | ||
| show_warnings::Bool = true, | ||
| callback = nothing, | ||
| show_every::Integer = 1, | ||
| linesearch = LineSearches.HagerZhang{T}(), | ||
| eta::Real = convert(T, 0.4), | ||
| mu0::T = convert(T, NaN), | ||
| mufactor::T = convert(T, 0.001), | ||
| precondprep = (P, x, l, u, mu) -> precondprepbox!(P, x, l, u, mu), | ||
| optimizer = ConjugateGradient, | ||
| optimizer_o = Options( | ||
| store_trace = store_trace, | ||
| show_trace = show_trace, | ||
| extended_trace = extended_trace, | ||
| show_warnings = show_warnings, | ||
| ), | ||
| nargs..., | ||
| ) where {T<:AbstractFloat} | ||
| if !has_deprecated_fminbox[] | ||
| @warn( | ||
| "Fminbox with the optimizer keyword is deprecated, construct Fminbox{optimizer}() and pass it to optimize(...) instead." | ||
| ) | ||
| has_deprecated_fminbox[] = true | ||
| end | ||
| optimize( | ||
| df, | ||
| initial_x, | ||
| l, | ||
| u, | ||
| Fminbox{optimizer}(); | ||
| allow_f_increases = allow_f_increases, | ||
| iterations = iterations, | ||
| store_trace = store_trace, | ||
| show_trace = show_trace, | ||
| extended_trace = extended_trace, | ||
| show_warnings = show_warnings, | ||
| show_every = show_every, | ||
| callback = callback, | ||
| linesearch = linesearch, | ||
| eta = eta, | ||
| mu0 = mu0, | ||
| mufactor = mufactor, | ||
| precondprep = precondprep, | ||
| optimizer_o = optimizer_o, | ||
| ) | ||
| end | ||
| 
     | 
||
| function optimize(::AbstractObjective) | ||
| throw( | ||
| ErrorException( | ||
| "Optimizing an objective `obj` without providing an initial `x` has been deprecated without backwards compatability. Please explicitly provide an `x`: `optimize(obj, x)``", | ||
| ), | ||
| ) | ||
| end | ||
| function optimize(::AbstractObjective, ::Method) | ||
| throw( | ||
| ErrorException( | ||
| "Optimizing an objective `obj` without providing an initial `x` has been deprecated without backwards compatability. Please explicitly provide an `x`: `optimize(obj, x, method)``", | ||
| ), | ||
| ) | ||
| end | ||
| function optimize(::AbstractObjective, ::Method, ::Options) | ||
| throw( | ||
| ErrorException( | ||
| "Optimizing an objective `obj` without providing an initial `x` has been deprecated without backwards compatability. Please explicitly provide an `x`: `optimize(obj, x, method, options)``", | ||
| ), | ||
| ) | ||
| end | ||
| function optimize(::AbstractObjective, ::Options) | ||
| throw( | ||
| ErrorException( | ||
| "Optimizing an objective `obj` without providing an initial `x` has been deprecated without backwards compatability. Please explicitly provide an `x`: `optimize(obj, x, options)``", | ||
| ), | ||
| ) | ||
| end | ||
| 
     | 
||
| function optimize( | ||
| df::OnceDifferentiable, | ||
| l::Array{T}, | ||
| u::Array{T}, | ||
| F::Fminbox{O}; | ||
| kwargs..., | ||
| ) where {T<:AbstractFloat,O<:AbstractOptimizer} | ||
| throw( | ||
| ErrorException( | ||
| "Optimizing an objective `obj` without providing an initial `x` has been deprecated without backwards compatability. Please explicitly provide an `x`: `optimize(obj, x, l, u, method, options)``", | ||
| ), | ||
| ) | ||
| end | ||
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
| 
          
            
          
           | 
    @@ -159,7 +159,7 @@ promote_objtype( | |
| ) = td | ||
| 
     | 
||
| # if no method or options are present | ||
| function optimize( | ||
| function optimizing( | ||
| f, | ||
| initial_x::AbstractArray; | ||
| inplace = true, | ||
| 
        
          
        
         | 
    @@ -173,9 +173,9 @@ function optimize( | |
| add_default_opts!(checked_kwargs, method) | ||
| 
     | 
||
| options = Options(; checked_kwargs...) | ||
| optimize(d, initial_x, method, options) | ||
| optimizing(d, initial_x, method, options) | ||
| end | ||
| function optimize( | ||
| function optimizing( | ||
| f, | ||
| g, | ||
| initial_x::AbstractArray; | ||
| 
        
          
        
         | 
    @@ -190,9 +190,9 @@ function optimize( | |
| add_default_opts!(checked_kwargs, method) | ||
| 
     | 
||
| options = Options(; checked_kwargs...) | ||
| optimize(d, initial_x, method, options) | ||
| optimizing(d, initial_x, method, options) | ||
| end | ||
| function optimize( | ||
| function optimizing( | ||
| f, | ||
| g, | ||
| h, | ||
| 
        
          
        
         | 
    @@ -208,19 +208,15 @@ function optimize( | |
| add_default_opts!(checked_kwargs, method) | ||
| 
     | 
||
| options = Options(; checked_kwargs...) | ||
| optimize(d, initial_x, method, options) | ||
| optimizing(d, initial_x, method, options) | ||
| end | ||
| 
     | 
||
| # no method supplied with objective | ||
| function optimize( | ||
| d::T, | ||
| initial_x::AbstractArray, | ||
| options::Options, | ||
| ) where {T<:AbstractObjective} | ||
| optimize(d, initial_x, fallback_method(d), options) | ||
| function optimizing(d::AbstractObjective, initial_x::AbstractArray, options::Options) | ||
| optimizing(d, initial_x, fallback_method(d), options) | ||
| end | ||
| # no method supplied with inplace and autodiff keywords becauase objective is not supplied | ||
| function optimize( | ||
| function optimizing( | ||
| f, | ||
| initial_x::AbstractArray, | ||
| options::Options; | ||
| 
        
          
        
         | 
    @@ -229,9 +225,9 @@ function optimize( | |
| ) | ||
| method = fallback_method(f) | ||
| d = promote_objtype(method, initial_x, autodiff, inplace, f) | ||
| optimize(d, initial_x, method, options) | ||
| optimizing(d, initial_x, method, options) | ||
| end | ||
| function optimize( | ||
| function optimizing( | ||
| f, | ||
| g, | ||
| initial_x::AbstractArray, | ||
| 
        
          
        
         | 
    @@ -242,9 +238,9 @@ function optimize( | |
| 
     | 
||
| method = fallback_method(f, g) | ||
| d = promote_objtype(method, initial_x, autodiff, inplace, f, g) | ||
| optimize(d, initial_x, method, options) | ||
| optimizing(d, initial_x, method, options) | ||
| end | ||
| function optimize( | ||
| function optimizing( | ||
| f, | ||
| g, | ||
| h, | ||
| 
        
          
        
         | 
    @@ -257,11 +253,11 @@ function optimize( | |
| method = fallback_method(f, g, h) | ||
| d = promote_objtype(method, initial_x, autodiff, inplace, f, g, h) | ||
| 
     | 
||
| optimize(d, initial_x, method, options) | ||
| optimizing(d, initial_x, method, options) | ||
| end | ||
| 
     | 
||
| # potentially everything is supplied (besides caches) | ||
| function optimize( | ||
| function optimizing( | ||
| f, | ||
| initial_x::AbstractArray, | ||
| method::AbstractOptimizer, | ||
| 
        
          
        
         | 
    @@ -271,7 +267,7 @@ function optimize( | |
| ) | ||
| 
     | 
||
| d = promote_objtype(method, initial_x, autodiff, inplace, f) | ||
| optimize(d, initial_x, method, options) | ||
| optimizing(d, initial_x, method, options) | ||
| end | ||
| function optimize( | ||
| f, | ||
| 
        
          
        
         | 
    @@ -286,7 +282,7 @@ function optimize( | |
| d = promote_objtype(method, initial_x, autodiff, inplace, f) | ||
| optimize(d, c, initial_x, method, options) | ||
| end | ||
| function optimize( | ||
| function optimizing( | ||
| f, | ||
| g, | ||
| initial_x::AbstractArray, | ||
| 
        
          
        
         | 
    @@ -298,9 +294,9 @@ function optimize( | |
| 
     | 
||
| d = promote_objtype(method, initial_x, autodiff, inplace, f, g) | ||
| 
     | 
||
| optimize(d, initial_x, method, options) | ||
| optimizing(d, initial_x, method, options) | ||
| end | ||
| function optimize( | ||
| function optimizing( | ||
| f, | ||
| g, | ||
| h, | ||
| 
        
          
        
         | 
    @@ -313,17 +309,29 @@ function optimize( | |
| 
     | 
||
| d = promote_objtype(method, initial_x, autodiff, inplace, f, g, h) | ||
| 
     | 
||
| optimize(d, initial_x, method, options) | ||
| optimizing(d, initial_x, method, options) | ||
| end | ||
| 
     | 
||
| function optimize( | ||
| function optimizing( | ||
| d::D, | ||
| initial_x::AbstractArray, | ||
| method::SecondOrderOptimizer, | ||
| options::Options = Options(; default_options(method)...); | ||
| autodiff = :finite, | ||
| inplace = true, | ||
| ) where {D<:Union{NonDifferentiable,OnceDifferentiable}} | ||
| 
     | 
||
| d = promote_objtype(method, initial_x, autodiff, inplace, d) | ||
| optimize(d, initial_x, method, options) | ||
| optimizing(d, initial_x, method, options) | ||
| end | ||
| 
     | 
||
| function optimize(args...; kwargs...) | ||
| local istate | ||
| iter = optimizing(args...; kwargs...) | ||
| for istate′ in iter | ||
| istate = istate′ | ||
| end | ||
| # We can safely assume that `istate` is defined at this point. That is to say, | ||
| # `OptimIterator` guarantees that `iterate(::OptimIterator) !== nothing`. | ||
| 
         
      Comment on lines
    
      +334
     to 
      +335
    
   
  There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think JET won't agree with this comment 😄 Generally, the code above seems a bit unfortunate... Maybe  I also wonder, is there no utility in Julia for directly obtaining the last state of an iterator?  | 
||
| return OptimizationResults(iter, istate) | ||
| end | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Base.summaryshould be overloaded asBase.summary(::IO, x)