Skip to content

Commit e9772a0

Browse files
committed
Set dropout's training mode to true in net % train(); add tests
1 parent c984b15 commit e9772a0

File tree

3 files changed

+64
-4
lines changed

3 files changed

+64
-4
lines changed

src/nf/nf_dropout_layer_submodule.f90

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -66,8 +66,13 @@ pure module subroutine backward(self, input, gradient)
6666
real, intent(in) :: input(:)
6767
real, intent(in) :: gradient(:)
6868

69-
! Backpropagate gradient through dropout mask
70-
self % gradient = gradient * self % mask * self % scale
69+
if (self % training) then
70+
! Backpropagate gradient through dropout mask
71+
self % gradient = gradient * self % mask * self % scale
72+
else
73+
! In inference mode, pass through the gradient unchanged
74+
self % gradient = gradient
75+
end if
7176
end subroutine backward
7277

7378
end submodule nf_dropout_layer_submodule

src/nf/nf_network_submodule.f90

Lines changed: 28 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -288,12 +288,21 @@ module function predict_batch_1d(self, input) result(res)
288288
class(network), intent(in out) :: self
289289
real, intent(in) :: input(:,:)
290290
real, allocatable :: res(:,:)
291-
integer :: i, batch_size, num_layers, output_size
291+
integer :: i, n, batch_size, num_layers, output_size
292292

293293
num_layers = size(self % layers)
294294
batch_size = size(input, dim=rank(input))
295295
output_size = product(self % layers(num_layers) % layer_shape)
296296

297+
! predict is run in inference mode only;
298+
! set all dropout layers' training mode to false.
299+
do n = 2, num_layers
300+
select type(this_layer => self % layers(n) % p)
301+
type is(dropout_layer)
302+
this_layer % training = .false.
303+
end select
304+
end do
305+
297306
allocate(res(output_size, batch_size))
298307

299308
batch: do i = 1, size(res, dim=2)
@@ -318,12 +327,21 @@ module function predict_batch_3d(self, input) result(res)
318327
class(network), intent(in out) :: self
319328
real, intent(in) :: input(:,:,:,:)
320329
real, allocatable :: res(:,:)
321-
integer :: i, batch_size, num_layers, output_size
330+
integer :: i, n, batch_size, num_layers, output_size
322331

323332
num_layers = size(self % layers)
324333
batch_size = size(input, dim=rank(input))
325334
output_size = product(self % layers(num_layers) % layer_shape)
326335

336+
! predict is run in inference mode only;
337+
! set all dropout layers' training mode to false.
338+
do n = 2, num_layers
339+
select type(this_layer => self % layers(n) % p)
340+
type is(dropout_layer)
341+
this_layer % training = .false.
342+
end select
343+
end do
344+
327345
allocate(res(output_size, batch_size))
328346

329347
batch: do i = 1, batch_size
@@ -457,6 +475,14 @@ module subroutine train(self, input_data, output_data, batch_size, &
457475
self % loss = quadratic()
458476
end if
459477

478+
! Set all dropout layers' training mode to true.
479+
do n = 2, size(self % layers)
480+
select type(this_layer => self % layers(n) % p)
481+
type is(dropout_layer)
482+
this_layer % training = .true.
483+
end select
484+
end do
485+
460486
dataset_size = size(output_data, dim=2)
461487

462488
epoch_loop: do n = 1, epochs

test/test_dropout_layer.f90

Lines changed: 29 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,16 @@ program test_dropout_layer
1919
select type(layer1_p => layer1 % p)
2020
type is(dropout_layer)
2121

22+
if (layer1_p % dropout_rate /= 0.5) then
23+
ok = .false.
24+
write(stderr, '(a)') 'dropout layer dropout rate should be 0.5.. failed'
25+
end if
26+
27+
if (layer1_p % training) then
28+
ok = .false.
29+
write(stderr, '(a)') 'dropout layer default training mode should be false.. failed'
30+
end if
31+
2232
if (layer1_p % input_size /= 0) then
2333
print *, 'input_size: ', layer1_p % input_size
2434
ok = .false.
@@ -32,6 +42,25 @@ program test_dropout_layer
3242

3343
end select
3444

45+
! Test setting training mode explicitly.
46+
layer1 = dropout(0.5, training=.true.)
47+
select type(layer1_p => layer1 % p)
48+
type is(dropout_layer)
49+
if (.not. layer1_p % training) then
50+
ok = .false.
51+
write(stderr, '(a)') 'dropout layer training mode should be true.. failed'
52+
end if
53+
end select
54+
55+
layer1 = dropout(0.5, training=.false.)
56+
select type(layer1_p => layer1 % p)
57+
type is(dropout_layer)
58+
if (layer1_p % training) then
59+
ok = .false.
60+
write(stderr, '(a)') 'dropout layer training mode should be false.. failed'
61+
end if
62+
end select
63+
3564
! Now we're gonna initialize a minimal network with an input layer and a
3665
! dropout that follows and we'll check that the dropout layer has expected
3766
! state.

0 commit comments

Comments
 (0)