@@ -51,7 +51,7 @@ public class LossesHelper {
    * @param tf the TensorFlow Ops
    * @param predictions Predicted values, a <code>Operand</code> of arbitrary dimensions.
    * @param labels Optional label <code>Operand</code> whose dimensions match <code>prediction
-   * </code>.
+   * </code> .
    * @param <T> the data type for the labels, predictions and result
    * @return LossTuple of <code>prediction</code>, <code>label</code>,<code>sampleWeight</code> will
    *     be null. Each of them possibly has the last dimension squeezed, <code>sampleWeight</code>
@@ -77,7 +77,7 @@ public static <T extends TNumber> LossTuple<T> squeezeOrExpandDimensions(
    * @param tf the TensorFlow Ops
    * @param predictions Predicted values, a <code>Operand</code> of arbitrary dimensions.
    * @param labels Optional label <code>Operand</code> whose dimensions match <code>prediction
-   * </code>.
+   * </code> .
    * @param sampleWeights Optional sample weight(s) <code>Operand</code> whose dimensions match
    *     <code>
    *     prediction</code>.
@@ -179,7 +179,7 @@ private static <T extends TNumber> Operand<T> maybeExpandWeights(
    *
    * @param tf the TensorFlowOps
    * @param labels Label values, a <code>Tensor</code> whose dimensions match <code>predictions
-   * </code>.
+   * </code> .
    * @param predictions Predicted values, a <code>Tensor</code> of arbitrary dimensions.
    * @param <T> the data type for the labels, predictions and result
    * @return <code>labels</code> and <code>predictions</code>, possibly with last dim squeezed.
@@ -194,7 +194,7 @@ public static <T extends TNumber> LossTuple<T> removeSqueezableDimensions(
    *
    * @param tf the TensorFlowOps
    * @param labels Label values, a <code>Operand</code> whose dimensions match <code>predictions
-   * </code>.
+   * </code> .
    * @param predictions Predicted values, a <code>Tensor</code> of arbitrary dimensions.
    * @param expectedRankDiff Expected result of <code>rank(predictions) - rank(labels)</code>.
    * @param <T> the data type for the labels, predictions and result
@@ -222,11 +222,13 @@ public static <T extends TNumber> LossTuple<T> removeSqueezableDimensions(
     // Use dynamic rank.

     // TODO: hold for lazy select feature,
-    // Operand<TInt32> rankDiff = tf.math.sub(tf.rank(predictions), tf.rank(labels));
+    // Operand<TInt32> rankDiff = tf.math.sub(tf.rank(predictions),
+    // tf.rank(labels));
     if (predictionsRank == Shape.UNKNOWN_SIZE && Shape.isCompatible(predictionsShape.size(-1), 1)) {
       /*
-       * TODO, if we ever get a select that does lazy evaluation, but for now do the tf.squeeze
-       * predictions = tf.select( tf.math.equal(tf.constant(expectedRankDiff+1),rankDiff ),
+       * TODO, if we ever get a select that does lazy evaluation, but for now do the
+       * tf.squeeze predictions = tf.select(
+       * tf.math.equal(tf.constant(expectedRankDiff+1),rankDiff ),
        * tf.squeeze(predictions, Squeeze.axis(Arrays.asList(-1L))), predictions ); *
        */
       predictions = tf.squeeze(predictions, Squeeze.axis(Collections.singletonList(-1L)));
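The surviving line above drops a trailing dimension of size 1 when the static rank is unknown. A minimal eager-mode sketch of that same `tf.squeeze` call, for reference (the class name and example values are illustrative, not from the patch):

```java
import java.util.Collections;
import org.tensorflow.EagerSession;
import org.tensorflow.Operand;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Squeeze;
import org.tensorflow.types.TFloat32;

public class SqueezeSketch {
  public static void main(String[] args) {
    try (EagerSession env = EagerSession.create()) {
      Ops tf = Ops.create(env);
      // [2, 1] predictions whose trailing size-1 dimension gets dropped -> [2]
      Operand<TFloat32> predictions = tf.constant(new float[][] {{0.3f}, {0.7f}});
      Operand<TFloat32> squeezed =
          tf.squeeze(predictions, Squeeze.axis(Collections.singletonList(-1L)));
      System.out.println(squeezed.shape()); // [2]
    }
  }
}
```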
@@ -282,11 +284,12 @@ private static <T extends TNumber> Operand<T> reduceWeightedLoss(
     if (reduction == Reduction.NONE) {
       loss = weightedLoss;
     } else {
-      loss =
-          tf.reduceSum(weightedLoss, allAxes(tf, weightedLoss), ReduceSum.keepDims(Boolean.FALSE));
       if (reduction == Reduction.AUTO || reduction == Reduction.SUM_OVER_BATCH_SIZE) {
-        loss = safeMean(tf, loss, weightedLoss.shape().size());
-      }
+        loss = safeMean(tf, weightedLoss);
+      } else
+        loss =
+            tf.reduceSum(
+                weightedLoss, allAxes(tf, weightedLoss), ReduceSum.keepDims(Boolean.FALSE));
     }
     return loss;
   }
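Net effect of this hunk: for `AUTO` and `SUM_OVER_BATCH_SIZE` the *unreduced* `weightedLoss` now flows into `safeMean`, which derives its divisor from the tensor itself, instead of first reduce-summing and then dividing by a statically computed `weightedLoss.shape().size()`. A rough eager-mode sketch of the two branches (values are illustrative, and since `allAxes` is a private helper the sketch hardcodes the axis list):

```java
import org.tensorflow.EagerSession;
import org.tensorflow.Operand;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.ReduceSum;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.TInt32;

public class ReductionSketch {
  public static void main(String[] args) {
    try (EagerSession env = EagerSession.create()) {
      Ops tf = Ops.create(env);
      Operand<TFloat32> weightedLoss = tf.constant(new float[][] {{1f, 2f}, {3f, 4f}});
      // Stand-in for the private allAxes(tf, weightedLoss) helper.
      Operand<TInt32> axes = tf.constant(new int[] {0, 1});

      // SUM branch: plain reduce-sum over all axes -> 10.0
      Operand<TFloat32> sum =
          tf.reduceSum(weightedLoss, axes, ReduceSum.keepDims(Boolean.FALSE));

      // AUTO / SUM_OVER_BATCH_SIZE branch: what safeMean(tf, weightedLoss)
      // computes, dividing by the dynamic element count (4) -> 2.5
      Operand<TFloat32> mean =
          tf.math.divNoNan(
              sum, tf.dtypes.cast(tf.shape.size(tf.shape(weightedLoss)), TFloat32.class));

      System.out.println(sum.asTensor().getFloat());  // 10.0
      System.out.println(mean.asTensor().getFloat()); // 2.5
    }
  }
}
```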
@@ -301,10 +304,10 @@ private static <T extends TNumber> Operand<T> reduceWeightedLoss(
    * @return A scalar representing the mean of <code>losses</code>. If <code>numElements</code> is
    *     zero, then zero is returned.
    */
-  public static <T extends TNumber> Operand<T> safeMean(
-      Ops tf, Operand<T> losses, long numElements) {
-    Operand<T> totalLoss = tf.reduceSum(losses, allAxes(tf, losses));
-    return tf.math.divNoNan(totalLoss, cast(tf, tf.constant(numElements), losses.type()));
+  public static <T extends TNumber> Operand<T> safeMean(Ops tf, Operand<T> losses) {
+    Operand<T> totalLoss =
+        tf.reduceSum(losses, allAxes(tf, losses), ReduceSum.keepDims(Boolean.FALSE));
+    return tf.math.divNoNan(totalLoss, cast(tf, tf.shape.size(tf.shape(losses)), losses.type()));
   }

   /**
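The revised `safeMean` takes its divisor from the runtime shape, `tf.shape.size(tf.shape(losses))`, rather than from a caller-supplied `long numElements`, which matters when the batch dimension is only known at execution time; `divNoNan` still maps a zero count to zero instead of NaN. A graph-mode sketch of the same computation against an unknown batch dimension (class name, placeholder shape, and feed values are illustrative):

```java
import org.tensorflow.Graph;
import org.tensorflow.Operand;
import org.tensorflow.Session;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.ndarray.StdArrays;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Placeholder;
import org.tensorflow.types.TFloat32;

public class SafeMeanSketch {
  public static void main(String[] args) {
    try (Graph g = new Graph()) {
      Ops tf = Ops.create(g);
      // Batch size unknown at graph-build time, so a static element count
      // (the removed long numElements argument) is not available here.
      Operand<TFloat32> losses =
          tf.placeholder(TFloat32.class, Placeholder.shape(Shape.of(Shape.UNKNOWN_SIZE, 2)));

      // Mirror of the new safeMean body: sum everything, then divide by the
      // element count taken from the runtime shape.
      Operand<TFloat32> total = tf.reduceSum(losses, tf.constant(new int[] {0, 1}));
      Operand<TFloat32> mean =
          tf.math.divNoNan(
              total, tf.dtypes.cast(tf.shape.size(tf.shape(losses)), TFloat32.class));

      try (Session session = new Session(g);
          TFloat32 batch =
              TFloat32.tensorOf(StdArrays.ndCopyOf(new float[][] {{1f, 2f}, {3f, 4f}, {5f, 6f}}));
          TFloat32 result =
              (TFloat32) session.runner().feed(losses, batch).fetch(mean).run().get(0)) {
        System.out.println(result.getFloat()); // 3.5 (21 over 6 elements)
      }
    }
  }
}
```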
@@ -348,7 +351,8 @@ public static <T extends TNumber> Operand<T> rangeCheck(
         tf.math.logicalAnd(
             tf.reduceAll(tf.math.greaterEqual(values, minValue), allDims),
             tf.reduceAll(tf.math.lessEqual(values, maxValue), allDims));
-    // Graph and Eager mode need to be handled differently, control dependencies are not allowed in
+    // Graph and Eager mode need to be handled differently, control dependencies are
+    // not allowed in
     // Eager mode
     if (tf.scope().env().isGraph()) {
       AssertThat assertThat =
@@ -398,7 +402,8 @@ public static <T extends TNumber> Operand<T> valueCheck(
       } else return values;
     } else { // use dynamic shape
       Operand<TBool> cond = tf.math.equal(tf.shape.size(tf.shape(diff.out())), tf.constant(0));
-      // Graph and Eager mode need to be handled differently, control dependencies are not allowed
+      // Graph and Eager mode need to be handled differently, control dependencies are
+      // not allowed
       // in Eager mode
       if (tf.scope().env().isGraph()) {
         AssertThat assertThat =
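Both rewrapped comments describe the same constraint: control dependencies only exist while building a `Graph`, so in Eager mode the check has to be evaluated on the spot. A compilable sketch of the pattern these two call sites follow (the helper name and error message are made up for illustration):

```java
import java.util.Collections;
import org.tensorflow.Operand;
import org.tensorflow.op.Op;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.AssertThat;
import org.tensorflow.types.TBool;
import org.tensorflow.types.family.TNumber;

public class ControlDepSketch {
  /** Returns values, failing (now or at run time) if cond is false. */
  static <T extends TNumber> Operand<T> checked(Ops tf, Operand<TBool> cond, Operand<T> values) {
    if (tf.scope().env().isGraph()) {
      // Graph mode: attach the assertion as a control dependency so it is
      // guaranteed to run before values is consumed.
      AssertThat assertThat =
          tf.assertThat(cond, Collections.<Operand<?>>singletonList(tf.constant("check failed")));
      return tf.withControlDependencies(Collections.<Op>singletonList(assertThat))
          .identity(values);
    }
    // Eager mode: ops execute immediately, so the condition already holds a
    // value; inspect it directly instead of wiring a control dependency.
    if (!cond.asTensor().getBoolean()) {
      throw new IllegalArgumentException("check failed");
    }
    return values;
  }
}
```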