Skip to content

Commit 3f89f60

Browse files
authored
Interface should be public for external usage (#522)
* Interface should be public for external usage * Fix #523 * Fix google format
1 parent de1d6f0 commit 3f89f60

File tree

2 files changed

+23
-18
lines changed

2 files changed

+23
-18
lines changed

tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ public class LossesHelper {
5151
* @param tf the TensorFlow Ops
5252
* @param predictions Predicted values, a <code>Operand</code> of arbitrary dimensions.
5353
* @param labels Optional label <code>Operand</code> whose dimensions match <code>prediction
54-
* </code>.
54+
* </code> .
5555
* @param <T> the data type for the labels, predictions and result
5656
* @return LossTuple of <code>prediction</code>, <code>label</code>,<code>sampleWeight</code> will
5757
* be null. Each of them possibly has the last dimension squeezed, <code>sampleWeight</code>
@@ -77,7 +77,7 @@ public static <T extends TNumber> LossTuple<T> squeezeOrExpandDimensions(
7777
* @param tf the TensorFlow Ops
7878
* @param predictions Predicted values, a <code>Operand</code> of arbitrary dimensions.
7979
* @param labels Optional label <code>Operand</code> whose dimensions match <code>prediction
80-
* </code>.
80+
* </code> .
8181
* @param sampleWeights Optional sample weight(s) <code>Operand</code> whose dimensions match
8282
* <code>
8383
* prediction</code>.
@@ -179,7 +179,7 @@ private static <T extends TNumber> Operand<T> maybeExpandWeights(
179179
*
180180
* @param tf the TensorFlowOps
181181
* @param labels Label values, a <code>Tensor</code> whose dimensions match <code>predictions
182-
* </code>.
182+
* </code> .
183183
* @param predictions Predicted values, a <code>Tensor</code> of arbitrary dimensions.
184184
* @param <T> the data type for the labels, predictions and result
185185
* @return <code>labels</code> and <code>predictions</code>, possibly with last dim squeezed.
@@ -194,7 +194,7 @@ public static <T extends TNumber> LossTuple<T> removeSqueezableDimensions(
194194
*
195195
* @param tf the TensorFlowOps
196196
* @param labels Label values, a <code>Operand</code> whose dimensions match <code>predictions
197-
* </code>.
197+
* </code> .
198198
* @param predictions Predicted values, a <code>Tensor</code> of arbitrary dimensions.
199199
* @param expectedRankDiff Expected result of <code>rank(predictions) - rank(labels)</code>.
200200
* @param <T> the data type for the labels, predictions and result
@@ -222,11 +222,13 @@ public static <T extends TNumber> LossTuple<T> removeSqueezableDimensions(
222222
// Use dynamic rank.
223223

224224
// TODO: hold for lazy select feature,
225-
// Operand<TInt32> rankDiff = tf.math.sub(tf.rank(predictions), tf.rank(labels));
225+
// Operand<TInt32> rankDiff = tf.math.sub(tf.rank(predictions),
226+
// tf.rank(labels));
226227
if (predictionsRank == Shape.UNKNOWN_SIZE && Shape.isCompatible(predictionsShape.size(-1), 1)) {
227228
/*
228-
* TODO, if we ever get a select that does lazy evaluation, but for now do the tf.squeeze
229-
* predictions = tf.select( tf.math.equal(tf.constant(expectedRankDiff+1),rankDiff ),
229+
* TODO, if we ever get a select that does lazy evaluation, but for now do the
230+
* tf.squeeze predictions = tf.select(
231+
* tf.math.equal(tf.constant(expectedRankDiff+1),rankDiff ),
230232
* tf.squeeze(predictions, Squeeze.axis(Arrays.asList(-1L))), predictions ); *
231233
*/
232234
predictions = tf.squeeze(predictions, Squeeze.axis(Collections.singletonList(-1L)));
@@ -282,11 +284,12 @@ private static <T extends TNumber> Operand<T> reduceWeightedLoss(
282284
if (reduction == Reduction.NONE) {
283285
loss = weightedLoss;
284286
} else {
285-
loss =
286-
tf.reduceSum(weightedLoss, allAxes(tf, weightedLoss), ReduceSum.keepDims(Boolean.FALSE));
287287
if (reduction == Reduction.AUTO || reduction == Reduction.SUM_OVER_BATCH_SIZE) {
288-
loss = safeMean(tf, loss, weightedLoss.shape().size());
289-
}
288+
loss = safeMean(tf, weightedLoss);
289+
} else
290+
loss =
291+
tf.reduceSum(
292+
weightedLoss, allAxes(tf, weightedLoss), ReduceSum.keepDims(Boolean.FALSE));
290293
}
291294
return loss;
292295
}
@@ -301,10 +304,10 @@ private static <T extends TNumber> Operand<T> reduceWeightedLoss(
301304
* @return A scalar representing the mean of <code>losses</code>. If <code>numElements</code> is
302305
* zero, then zero is returned.
303306
*/
304-
public static <T extends TNumber> Operand<T> safeMean(
305-
Ops tf, Operand<T> losses, long numElements) {
306-
Operand<T> totalLoss = tf.reduceSum(losses, allAxes(tf, losses));
307-
return tf.math.divNoNan(totalLoss, cast(tf, tf.constant(numElements), losses.type()));
307+
public static <T extends TNumber> Operand<T> safeMean(Ops tf, Operand<T> losses) {
308+
Operand<T> totalLoss =
309+
tf.reduceSum(losses, allAxes(tf, losses), ReduceSum.keepDims(Boolean.FALSE));
310+
return tf.math.divNoNan(totalLoss, cast(tf, tf.shape.size(tf.shape(losses)), losses.type()));
308311
}
309312

310313
/**
@@ -348,7 +351,8 @@ public static <T extends TNumber> Operand<T> rangeCheck(
348351
tf.math.logicalAnd(
349352
tf.reduceAll(tf.math.greaterEqual(values, minValue), allDims),
350353
tf.reduceAll(tf.math.lessEqual(values, maxValue), allDims));
351-
// Graph and Eager mode need to be handled differently, control dependencies are not allowed in
354+
// Graph and Eager mode need to be handled differently, control dependencies are
355+
// not allowed in
352356
// Eager mode
353357
if (tf.scope().env().isGraph()) {
354358
AssertThat assertThat =
@@ -398,7 +402,8 @@ public static <T extends TNumber> Operand<T> valueCheck(
398402
} else return values;
399403
} else { // use dynamic shape
400404
Operand<TBool> cond = tf.math.equal(tf.shape.size(tf.shape(diff.out())), tf.constant(0));
401-
// Graph and Eager mode need to be handled differently, control dependencies are not allowed
405+
// Graph and Eager mode need to be handled differently, control dependencies are
406+
// not allowed
402407
// in Eager mode
403408
if (tf.scope().env().isGraph()) {
404409
AssertThat assertThat =

tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import org.tensorflow.types.family.TNumber;
2323

2424
/** Interface for metrics */
25-
interface Metric {
25+
public interface Metric {
2626

2727
/**
2828
* Creates a List of Operations to update the metric state based on input values.

0 commit comments

Comments
 (0)