  !> Abstract base type for all optimizers.
  !> Concrete optimizers (sgd, rmsprop, adagrad, adam, ...) extend this type
  !> and implement the deferred `init` and `minimize` procedures.
  type, abstract :: optimizer_base_type
    !> Step size used by the update rule; concrete types may scale or adapt it.
    real :: learning_rate = 0.01
  contains
    !> Returns the lowercase name of the concrete optimizer (e.g. 'sgd').
    procedure :: get_name
    !> Deferred: one-time setup given the number of trainable parameters.
    procedure (init), deferred :: init
    !> Deferred: apply one parameter update from the supplied gradient.
    procedure (minimize), deferred :: minimize
  end type optimizer_base_type
@@ -312,4 +313,52 @@ pure subroutine minimize_adagrad(self, param, gradient)
312
313
313
314
end subroutine minimize_adagrad
314
315
315
- end module nf_optimizers
316
+
317
+ ! Utility Functions
318
+ ! ! Returns the default optimizer corresponding to the provided name
319
+ pure function get_optimizer_by_name (optimizer_name ) result(res)
320
+ character (len=* ), intent (in ) :: optimizer_name
321
+ class(optimizer_base_type), allocatable :: res
322
+
323
+ select case (trim (optimizer_name))
324
+ case (' adagrad' )
325
+ allocate ( res, source = adagrad() )
326
+
327
+ case (' adam' )
328
+ allocate ( res, source = adam() )
329
+
330
+ case (' rmsprop' )
331
+ allocate ( res, source = rmsprop() )
332
+
333
+ case (' sgd' )
334
+ allocate ( res, source = sgd() )
335
+
336
+ case default
337
+ error stop ' optimizer_name must be one of: ' // &
338
+ ' "adagrad", "adam", "rmsprop", "sgd".'
339
+ end select
340
+
341
+ end function get_optimizer_by_name
342
+
343
+
344
+ ! ! Returns the name of the optimizer
345
+ pure function get_name (self ) result(name)
346
+ class(optimizer_base_type), intent (in ) :: self
347
+ character (:), allocatable :: name
348
+
349
+ select type (self)
350
+ class is (adagrad)
351
+ name = ' adagrad'
352
+ class is (adam)
353
+ name = ' adam'
354
+ class is (rmsprop)
355
+ name = ' rmsprop'
356
+ class is (sgd)
357
+ name = ' sgd'
358
+ class default
359
+ error stop ' Unknown optimizer type.'
360
+ end select
361
+
362
+ end function get_name
363
+
364
+ end module nf_optimizers
0 commit comments