This repository was archived by the owner on Nov 11, 2021. It is now read-only.

Commit 8627dec

feat(playground): remove regularization from link internal state + apply regularization from state, on the fly
Inspired by this PR: tensorflow#139
1 parent 894f04c

2 files changed: +27 -19 lines

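In short, regularization is no longer stored on each Link when the network is built; it is read from the playground state and passed to nn.updateWeights on every call. A before/after sketch of the call shapes, condensed from the diff below (identifiers come from the diff; surrounding code is omitted):

// Before: the regularization function was baked into every Link at build time.
network = nn.buildNetwork (shape, activation, outputActivation,
  regularization, inputIds, initZero)
nn.updateWeights (network, learningRate, regularizationRate)

// After: links carry no regularization; updateWeights receives it per call,
// so the regularizer currently chosen in state applies on the fly, without
// rebuilding the links.
network = nn.buildNetwork (shape, activation, outputActivation, inputIds, initZero)
nn.updateWeights ({ network, learningRate, regularization, regularizationRate })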

src/nn.ts

Lines changed: 20 additions & 16 deletions
@@ -166,23 +166,18 @@ export class Link {
   accErrorDer = 0
   /** Number of accumulated derivatives since the last update. */
   numAccumulatedDers = 0
-  regularization: RegularizationFunction

   /**
    * Constructs a link in the neural network initialized with random weight.
    *
    * @param source The source node.
    * @param dest The destination node.
-   * @param regularization The regularization function that computes the
-   *     penalty for this weight. If null, there will be no regularization.
    * @param initZero
    */
-  constructor (source: Node, dest: Node,
-    regularization: RegularizationFunction, initZero?: boolean) {
+  constructor (source: Node, dest: Node, initZero?: boolean) {
     this.id = source.id + '-' + dest.id
     this.source = source
     this.dest = dest
-    this.regularization = regularization
     if (initZero) {
       this.weight = 0
     }
@@ -197,16 +192,12 @@ export class Link {
  * 3 nodes in second hidden layer and 1 output node.
  * @param activation The activation function of every hidden node.
  * @param outputActivation The activation function for the output nodes.
- * @param regularization The regularization function that computes a penalty
- *     for a given weight (parameter) in the network. If null, there will be
- *     no regularization.
  * @param inputIds List of ids for the input nodes.
  * @param initZero
  */
 export function buildNetwork (
   networkShape: number[], activation: ActivationFunction,
   outputActivation: ActivationFunction,
-  regularization: RegularizationFunction,
   inputIds: string[], initZero?: boolean): Node[][] {
   let numLayers = networkShape.length
   let id = 1
@@ -232,7 +223,7 @@ export function buildNetwork (
       // Add links from nodes in the previous layer to this node.
       for (let j = 0; j < network[layerIdx - 1].length; j++) {
         let prevNode = network[layerIdx - 1][j]
-        let link = new Link (prevNode, node, regularization, initZero)
+        let link = new Link (prevNode, node, initZero)
         prevNode.outputs.push (link)
         node.inputLinks.push (link)
       }
@@ -330,12 +321,25 @@ export function backProp (network: Node[][], target: number,
   }
 }

+type UpdateWeights = {
+  network: Node[][],
+  learningRate: number,
+  regularization: RegularizationFunction,
+  regularizationRate: number,
+}
+
 /**
  * Updates the weights of the network using the previously accumulated error
  * derivatives.
  */
-export function updateWeights (network: Node[][], learningRate: number,
-  regularizationRate: number) {
+export function updateWeights (
+  {
+    network,
+    learningRate,
+    regularization,
+    regularizationRate,
+  }: UpdateWeights,
+) {
   for (let layerIdx = 1; layerIdx < network.length; layerIdx++) {
     let currentLayer = network[layerIdx]
     for (let i = 0; i < currentLayer.length; i++) {
@@ -352,16 +356,16 @@ export function updateWeights (network: Node[][], learningRate: number,
         if (link.isDead) {
           continue
         }
-        let regulDer = link.regularization ?
-          link.regularization.der (link.weight) : 0
+        let regulDer = regularization ?
+          regularization.der (link.weight) : 0
         if (link.numAccumulatedDers > 0) {
           // Update the weight based on dE/dw.
           link.weight = link.weight -
             (learningRate / link.numAccumulatedDers) * link.accErrorDer
           // Further update the weight based on regularization.
           let newLinkWeight = link.weight -
             (learningRate * regularizationRate) * regulDer
-          if (link.regularization === RegularizationFunction.L1 &&
+          if (regularization === RegularizationFunction.L1 &&
             link.weight * newLinkWeight < 0) {
             // The weight crossed 0 due to the regularization term. Set it to 0.
             link.weight = 0

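For context on the regularization.der (link.weight) call and the RegularizationFunction.L1 comparison above: updateWeights only needs the regularizer's derivative, plus its identity for the L1 zero-crossing check. Below is a minimal sketch of that shape, assuming the output/der members used by the upstream playground; the exact class layout in this fork is not shown in this commit and may differ.

// Sketch only: the regularizer shape assumed by updateWeights above.
// output (w) is the penalty term, der (w) its derivative w.r.t. the weight.
export class RegularizationFunction {
  constructor (
    public output: (weight: number) => number,
    public der: (weight: number) => number) {}

  // L1: |w|, derivative sign (w).
  static L1 = new RegularizationFunction (
    w => Math.abs (w),
    w => (w < 0 ? -1 : w > 0 ? 1 : 0))

  // L2: w^2 / 2, derivative w.
  static L2 = new RegularizationFunction (
    w => 0.5 * w * w,
    w => w)
}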
src/playground.ts

Lines changed: 7 additions & 3 deletions
@@ -950,7 +950,12 @@ function oneStep (): void {
     nn.forwardProp (network, input)
     nn.backProp (network, point.label, nn.Errors.SQUARE)
     if ((i + 1) % state.batchSize === 0) {
-      nn.updateWeights (network, state.learningRate, state.regularizationRate)
+      nn.updateWeights ({
+        network,
+        learningRate: state.learningRate,
+        regularization: state.regularization,
+        regularizationRate: state.regularizationRate,
+      })
     }
   })
   // Compute the loss.
@@ -992,8 +997,7 @@ function reset (onStartup = false) {
   let shape = [numInputs].concat (state.networkShape).concat ([1])
   let outputActivation = (state.problem === Problem.REGRESSION) ?
     nn.Activations.LINEAR : nn.Activations.TANH
-  network = nn.buildNetwork (shape, state.activation, outputActivation,
-    state.regularization, constructInputIds (), state.initZero)
+  network = nn.buildNetwork (shape, state.activation, outputActivation, constructInputIds (), state.initZero)
   lossTrain = getLoss (network, trainData)
   lossTest = getLoss (network, testData)
   drawNetwork (network)
