@@ -166,23 +166,18 @@ export class Link {
166
166
accErrorDer = 0
167
167
/** Number of accumulated derivatives since the last update. */
168
168
numAccumulatedDers = 0
169
- regularization : RegularizationFunction
170
169
171
170
/**
172
171
* Constructs a link in the neural network initialized with random weight.
173
172
*
174
173
* @param source The source node.
175
174
* @param dest The destination node.
176
- * @param regularization The regularization function that computes the
177
- * penalty for this weight. If null, there will be no regularization.
178
175
* @param initZero
179
176
*/
180
- constructor ( source : Node , dest : Node ,
181
- regularization : RegularizationFunction , initZero ?: boolean ) {
177
+ constructor ( source : Node , dest : Node , initZero ?: boolean ) {
182
178
this . id = source . id + '-' + dest . id
183
179
this . source = source
184
180
this . dest = dest
185
- this . regularization = regularization
186
181
if ( initZero ) {
187
182
this . weight = 0
188
183
}
@@ -197,16 +192,12 @@ export class Link {
197
192
* 3 nodes in second hidden layer and 1 output node.
198
193
* @param activation The activation function of every hidden node.
199
194
* @param outputActivation The activation function for the output nodes.
200
- * @param regularization The regularization function that computes a penalty
201
- * for a given weight (parameter) in the network. If null, there will be
202
- * no regularization.
203
195
* @param inputIds List of ids for the input nodes.
204
196
* @param initZero
205
197
*/
206
198
export function buildNetwork (
207
199
networkShape : number [ ] , activation : ActivationFunction ,
208
200
outputActivation : ActivationFunction ,
209
- regularization : RegularizationFunction ,
210
201
inputIds : string [ ] , initZero ?: boolean ) : Node [ ] [ ] {
211
202
let numLayers = networkShape . length
212
203
let id = 1
@@ -232,7 +223,7 @@ export function buildNetwork (
232
223
// Add links from nodes in the previous layer to this node.
233
224
for ( let j = 0 ; j < network [ layerIdx - 1 ] . length ; j ++ ) {
234
225
let prevNode = network [ layerIdx - 1 ] [ j ]
235
- let link = new Link ( prevNode , node , regularization , initZero )
226
+ let link = new Link ( prevNode , node , initZero )
236
227
prevNode . outputs . push ( link )
237
228
node . inputLinks . push ( link )
238
229
}
@@ -330,12 +321,25 @@ export function backProp (network: Node[][], target: number,
330
321
}
331
322
}
332
323
324
+ type UpdateWeights = {
325
+ network : Node [ ] [ ] ,
326
+ learningRate : number ,
327
+ regularization : RegularizationFunction ,
328
+ regularizationRate : number ,
329
+ }
330
+
333
331
/**
334
332
* Updates the weights of the network using the previously accumulated error
335
333
* derivatives.
336
334
*/
337
- export function updateWeights ( network : Node [ ] [ ] , learningRate : number ,
338
- regularizationRate : number ) {
335
+ export function updateWeights (
336
+ {
337
+ network,
338
+ learningRate,
339
+ regularization,
340
+ regularizationRate,
341
+ } : UpdateWeights ,
342
+ ) {
339
343
for ( let layerIdx = 1 ; layerIdx < network . length ; layerIdx ++ ) {
340
344
let currentLayer = network [ layerIdx ]
341
345
for ( let i = 0 ; i < currentLayer . length ; i ++ ) {
@@ -352,16 +356,16 @@ export function updateWeights (network: Node[][], learningRate: number,
352
356
if ( link . isDead ) {
353
357
continue
354
358
}
355
- let regulDer = link . regularization ?
356
- link . regularization . der ( link . weight ) : 0
359
+ let regulDer = regularization ?
360
+ regularization . der ( link . weight ) : 0
357
361
if ( link . numAccumulatedDers > 0 ) {
358
362
// Update the weight based on dE/dw.
359
363
link . weight = link . weight -
360
364
( learningRate / link . numAccumulatedDers ) * link . accErrorDer
361
365
// Further update the weight based on regularization.
362
366
let newLinkWeight = link . weight -
363
367
( learningRate * regularizationRate ) * regulDer
364
- if ( link . regularization === RegularizationFunction . L1 &&
368
+ if ( regularization === RegularizationFunction . L1 &&
365
369
link . weight * newLinkWeight < 0 ) {
366
370
// The weight crossed 0 due to the regularization term. Set it to 0.
367
371
link . weight = 0
0 commit comments