Skip to content

Commit 8cf5cb5

Browse files
committed
linear2d_layer: temporarily remove api
1 parent 3d793f3 commit 8cf5cb5

File tree

4 files changed

+32
-32
lines changed

4 files changed

+32
-32
lines changed

src/nf/nf_layer_constructors.f90

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -166,8 +166,8 @@ module function reshape(output_shape) result(res)
166166
!! Resulting layer instance
167167
end function reshape
168168

169-
module function linear2d(sequence_length, in_features, out_features, batch_size) result(res)
170-
integer, intent(in) :: batch_size, sequence_length, in_features, out_features
169+
module function linear2d(sequence_length, in_features, out_features) result(res)
170+
integer, intent(in) :: sequence_length, in_features, out_features
171171
type(layer) :: res
172172
end function linear2d
173173

src/nf/nf_layer_constructors_submodule.f90

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -135,13 +135,13 @@ module function reshape(output_shape) result(res)
135135

136136
end function reshape
137137

138-
module function linear2d(sequence_length, in_features, out_features, batch_size) result(res)
139-
integer, intent(in) :: batch_size, sequence_length, in_features, out_features
138+
module function linear2d(sequence_length, in_features, out_features) result(res)
139+
integer, intent(in) :: sequence_length, in_features, out_features
140140
type(layer) :: res
141141

142142
res % name = 'linear2d'
143-
res % layer_shape = [sequence_length, out_features, batch_size]
144-
allocate(res % p, source=linear2d_layer(sequence_length, in_features, out_features, batch_size))
143+
res % layer_shape = [sequence_length, out_features]
144+
allocate(res % p, source=linear2d_layer(sequence_length, in_features, out_features))
145145
end function linear2d
146146

147147
end submodule nf_layer_constructors_submodule

src/nf/nf_layer_submodule.f90

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,8 @@ pure module subroutine backward_1d(self, previous, gradient)
4545
call this_layer % backward(prev_layer % output, gradient)
4646
type is(maxpool2d_layer)
4747
call this_layer % backward(prev_layer % output, gradient)
48-
type is(linear2d_layer)
49-
call this_layer % backward(prev_layer % output, gradient)
48+
! type is(linear2d_layer)
49+
! call this_layer % backward(prev_layer % output, gradient)
5050
end select
5151

5252
end select
@@ -104,11 +104,11 @@ pure module subroutine backward_3d(self, previous, gradient)
104104
call this_layer % backward(prev_layer % output, gradient)
105105
end select
106106

107-
type is(linear2d_layer)
108-
select type(prev_layer => previous % p)
109-
type is(input3d_layer)
110-
call this_layer % backward(prev_layer % output, gradient)
111-
end select
107+
! type is(linear2d_layer)
108+
! select type(prev_layer => previous % p)
109+
! type is(input3d_layer)
110+
! call this_layer % backward(prev_layer % output, gradient)
111+
! end select
112112

113113
end select
114114

@@ -174,8 +174,8 @@ pure module subroutine forward(self, input)
174174
call this_layer % forward(prev_layer % output)
175175
type is(reshape3d_layer)
176176
call this_layer % forward(prev_layer % output)
177-
type is(linear2d_layer)
178-
call this_layer % forward(prev_layer % output)
177+
! type is(linear2d_layer)
178+
! call this_layer % forward(prev_layer % output)
179179
end select
180180

181181
type is(reshape3d_layer)
@@ -190,13 +190,13 @@ pure module subroutine forward(self, input)
190190
call this_layer % forward(prev_layer % output)
191191
end select
192192

193-
type is(linear2d_layer)
194-
select type(prev_layer => input % p)
195-
type is(input3d_layer)
196-
call this_layer % forward(prev_layer % output)
197-
type is(linear2d_layer)
198-
call this_layer % forward(prev_layer % output)
199-
end select
193+
! type is(linear2d_layer)
194+
! select type(prev_layer => input % p)
195+
! type is(input3d_layer)
196+
! call this_layer % forward(prev_layer % output)
197+
! type is(linear2d_layer)
198+
! call this_layer % forward(prev_layer % output)
199+
! end select
200200

201201
end select
202202

@@ -311,8 +311,8 @@ elemental module function get_num_params(self) result(num_params)
311311
num_params = 0
312312
type is (reshape3d_layer)
313313
num_params = 0
314-
type is (linear2d_layer)
315-
num_params = this_layer % get_num_params()
314+
! type is (linear2d_layer)
315+
! num_params = this_layer % get_num_params()
316316
class default
317317
error stop 'Unknown layer type.'
318318
end select
@@ -338,8 +338,8 @@ module function get_params(self) result(params)
338338
! No parameters to get.
339339
type is (reshape3d_layer)
340340
! No parameters to get.
341-
type is (linear2d_layer)
342-
params = this_layer % get_params()
341+
! type is (linear2d_layer)
342+
! params = this_layer % get_params()
343343
class default
344344
error stop 'Unknown layer type.'
345345
end select
@@ -365,8 +365,8 @@ module function get_gradients(self) result(gradients)
365365
! No gradients to get.
366366
type is (reshape3d_layer)
367367
! No gradients to get.
368-
type is (linear2d_layer)
369-
gradients = this_layer % get_gradients()
368+
! type is (linear2d_layer)
369+
! gradients = this_layer % get_gradients()
370370
class default
371371
error stop 'Unknown layer type.'
372372
end select
@@ -427,8 +427,8 @@ module subroutine set_params(self, params)
427427
class default
428428
error stop 'Unknown layer type.'
429429

430-
type is (linear2d_layer)
431-
call this_layer % set_params(params)
430+
! type is (linear2d_layer)
431+
! call this_layer % set_params(params)
432432
end select
433433

434434
end subroutine set_params

src/nf/nf_network_submodule.f90

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,8 +148,8 @@ module subroutine backward(self, output, loss)
148148
call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
149149
type is(reshape3d_layer)
150150
call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
151-
type is(linear2d_layer)
152-
call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
151+
! type is(linear2d_layer)
152+
! call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
153153
end select
154154
end if
155155

0 commit comments

Comments
 (0)