@@ -61,24 +61,32 @@ pure module function get_num_params(self) result(num_params)
61
61
end function get_num_params
62
62
63
63
64
- pure module function get_params(self) result(params)
65
- class(dense_layer), intent (in ) :: self
64
+ module function get_params (self ) result(params)
65
+ class(dense_layer), intent (in ), target :: self
66
66
real , allocatable :: params(:)
67
67
68
+ real , pointer :: w_(:) = > null ()
69
+
70
+ w_(1 :size (self % weights)) = > self % weights
71
+
68
72
params = [ &
69
- pack (self % weights, .true. ) , &
73
+ w_ , &
70
74
self % biases &
71
75
]
72
76
73
77
end function get_params
74
78
75
79
76
- pure module function get_gradients(self) result(gradients)
77
- class(dense_layer), intent (in ) :: self
80
+ module function get_gradients (self ) result(gradients)
81
+ class(dense_layer), intent (in ), target :: self
78
82
real , allocatable :: gradients(:)
79
83
84
+ real , pointer :: dw_(:) = > null ()
85
+
86
+ dw_(1 :size (self % dw)) = > self % dw
87
+
80
88
gradients = [ &
81
- pack (self % dw, .true. ) , &
89
+ dw_ , &
82
90
self % db &
83
91
]
84
92
@@ -87,24 +95,23 @@ end function get_gradients
87
95
88
96
module subroutine set_params (self , params )
89
97
class(dense_layer), intent (in out ) :: self
90
- real , intent (in ) :: params(:)
98
+ real , intent (in ), target :: params(:)
99
+
100
+ real , pointer :: p_(:,:) = > null ()
91
101
92
102
! check if the number of parameters is correct
93
103
if (size (params) /= self % get_num_params()) then
94
104
error stop ' Error: number of parameters does not match'
95
105
end if
96
106
97
- ! reshape the weights
98
- self % weights = reshape ( &
99
- params(:self % input_size * self % output_size), &
100
- [self % input_size, self % output_size] &
101
- )
102
-
103
- ! reshape the biases
104
- self % biases = reshape ( &
105
- params(self % input_size * self % output_size + 1 :), &
106
- [self % output_size] &
107
- )
107
+ associate(n = > self % input_size * self % output_size)
108
+ ! reshape the weights
109
+ p_(1 :self % input_size, 1 :self % output_size) = > params(1 : n)
110
+ self % weights = p_
111
+
112
+ ! reshape the biases
113
+ self % biases = params(n + 1 : n + self % output_size)
114
+ end associate
108
115
109
116
end subroutine set_params
110
117
0 commit comments