Skip to content

Commit 2b6d83f

Browse files
authored
fix #526 (#527)
* Interface should be public for external usage * Fix #523 * Fix google format * fix #526 * Add test to CategoricalCrossentropyTest.java
1 parent 3f89f60 commit 2b6d83f

File tree

2 files changed

+54
-75
lines changed

2 files changed

+54
-75
lines changed

tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SoftmaxCrossEntropyWithLogits.java

+7-10
Original file line numberDiff line numberDiff line change
@@ -42,15 +42,12 @@ public class SoftmaxCrossEntropyWithLogits {
4242
* <p>Usage:
4343
*
4444
* <pre>
45-
* Operand&lt;TFloat32&gt; logits =
46-
* tf.constant(new float[][] {{4.0F, 2.0F, 1.0F}, {0.0F, 5.0F, 1.0F}} );
47-
* Operand&lt;TFloat32&gt; labels =
48-
* tf.constant(new float[][] {{1.0F, 0.0F, 0.0F}, {0.0F, 0.8F, 0.2F}} );
49-
* Operand&lt;TFloat32&gt; output =
50-
* tf.nn.softmaxCrossEntropyWithLogits(labels, logits, -1);
51-
* // output Shape = [2]
52-
* // dataType = FLOAT (1)
53-
* // values { 0.169846, 0.824745 }
45+
* Operand&lt;TFloat32&gt; logits = tf.constant(new float[][] { { 4.0F, 2.0F, 1.0F }, { 0.0F, 5.0F, 1.0F } });
46+
* Operand&lt;TFloat32&gt; labels = tf.constant(new float[][] { { 1.0F, 0.0F, 0.0F }, { 0.0F, 0.8F, 0.2F } });
47+
* Operand&lt;TFloat32&gt; output = tf.nn.softmaxCrossEntropyWithLogits(labels, logits, -1);
48+
* // output Shape = [2]
49+
* // dataType = FLOAT (1)
50+
* // values { 0.169846, 0.824745 }
5451
* </pre>
5552
*
5653
* <p>Backpropagation will happen into both <code>logits</code> and <code>labels</code>. To
@@ -157,7 +154,7 @@ public static <T extends TNumber, U extends TNumber> Operand<T> softmaxCrossEntr
157154
* @return the flattened logits
158155
*/
159156
private static <T extends TNumber> Operand<T> flattenOuterDims(Scope scope, Operand<T> logits) {
160-
Operand<TInt64> one = Constant.scalarOf(scope, 1L);
157+
Operand<TInt64> one = Constant.arrayOf(scope, 1L);
161158

162159
Shape shape = logits.shape();
163160
int ndims = shape.numDimensions();

tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CategoricalCrossentropyTest.java

+47-65
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,14 @@
1717
import static org.junit.jupiter.api.Assertions.assertThrows;
1818

1919
import org.junit.jupiter.api.Test;
20+
import org.tensorflow.Graph;
2021
import org.tensorflow.Operand;
22+
import org.tensorflow.Session;
2123
import org.tensorflow.framework.utils.TestSession;
2224
import org.tensorflow.ndarray.Shape;
25+
import org.tensorflow.ndarray.buffer.DataBuffers;
2326
import org.tensorflow.op.Ops;
27+
import org.tensorflow.op.core.Placeholder;
2428
import org.tensorflow.types.TFloat32;
2529
import org.tensorflow.types.TInt32;
2630
import org.tensorflow.types.TInt64;
@@ -36,16 +40,8 @@ public void testAllCorrectUnweighted() {
3640
try (TestSession testSession = TestSession.createTestSession(tfMode)) {
3741
Ops tf = testSession.getTF();
3842

39-
long[] trueArray = {
40-
1L, 0L, 0L,
41-
0L, 1L, 0L,
42-
0L, 0L, 1L
43-
};
44-
float[] predArray = {
45-
1.F, 0.F, 0.F,
46-
0.F, 1.F, 0.F,
47-
0.F, 0.F, 1.F
48-
};
43+
long[] trueArray = {1L, 0L, 0L, 0L, 1L, 0L, 0L, 0L, 1L};
44+
float[] predArray = {1.F, 0.F, 0.F, 0.F, 1.F, 0.F, 0.F, 0.F, 1.F};
4945
Operand<TInt64> yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 3)));
5046
Operand<TFloat32> yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(3, 3)));
5147
CategoricalCrossentropy instance = new CategoricalCrossentropy();
@@ -55,11 +51,7 @@ public void testAllCorrectUnweighted() {
5551
testSession.evaluate(expected, loss);
5652

5753
// Test with logits.
58-
float[] logitsArray = {
59-
10.F, 0.F, 0.F,
60-
0.F, 10.F, 0.F,
61-
0.F, 0.F, 10.F
62-
};
54+
float[] logitsArray = {10.F, 0.F, 0.F, 0.F, 10.F, 0.F, 0.F, 0.F, 10.F};
6355
yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 3)));
6456
Operand<TFloat32> logits =
6557
tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(3, 3)));
@@ -85,11 +77,7 @@ public void testInvalidPredictionsRange() {
8577
Ops tf = testSession.getTF();
8678
CategoricalCrossentropy instance = new CategoricalCrossentropy();
8779

88-
float[] trueArray = {
89-
1L, 0L, 0L,
90-
0L, 1L, 0L,
91-
0L, 0L, 1L
92-
};
80+
float[] trueArray = {1L, 0L, 0L, 0L, 1L, 0L, 0L, 0L, 1L};
9381
float[] predArray = {-1.F, 0.F, 0.F, 0.F, 1.F, 0.F, 0.F, 0.F, 1.F};
9482
Operand<TFloat32> yTrue =
9583
tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 3)));
@@ -111,23 +99,15 @@ public void testUnweighted() {
11199
CategoricalCrossentropy instance = new CategoricalCrossentropy();
112100

113101
int[] trueArray = {1, 0, 0, 0, 1, 0, 0, 0, 1};
114-
float[] predArray = {
115-
.9F, .05F, .05F,
116-
.5F, .89F, .6F,
117-
.05F, .01F, .94F
118-
};
102+
float[] predArray = {.9F, .05F, .05F, .5F, .89F, .6F, .05F, .01F, .94F};
119103
Operand<TInt32> yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 3)));
120104
Operand<TFloat32> yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(3, 3)));
121105
Operand<TFloat32> loss = instance.call(tf, yTrue, yPred);
122106
float expected = 0.32396814F;
123107
testSession.evaluate(expected, loss);
124108

125109
// Test with logits.
126-
float[] logitsArray = {
127-
8.F, 1.F, 1.F,
128-
0.F, 9.F, 1.F,
129-
2.F, 3.F, 5.F
130-
};
110+
float[] logitsArray = {8.F, 1.F, 1.F, 0.F, 9.F, 1.F, 2.F, 3.F, 5.F};
131111
Operand<TFloat32> logits =
132112
tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(3, 3)));
133113
instance = new CategoricalCrossentropy(true);
@@ -145,16 +125,8 @@ public void testScalarWeighted() {
145125
try (TestSession testSession = TestSession.createTestSession(tfMode)) {
146126
Ops tf = testSession.getTF();
147127

148-
int[] trueArray = {
149-
1, 0, 0,
150-
0, 1, 0,
151-
0, 0, 1
152-
};
153-
float[] predArray = {
154-
.9F, .05F, .05F,
155-
.5F, .89F, .6F,
156-
.05F, .01F, .94F
157-
};
128+
int[] trueArray = {1, 0, 0, 0, 1, 0, 0, 0, 1};
129+
float[] predArray = {.9F, .05F, .05F, .5F, .89F, .6F, .05F, .01F, .94F};
158130
Operand<TInt32> yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 3)));
159131
Operand<TFloat32> yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(3, 3)));
160132
Operand<TFloat32> sampleWeight = tf.constant(2.3F);
@@ -166,11 +138,7 @@ public void testScalarWeighted() {
166138
testSession.evaluate(expected, loss);
167139

168140
// Test with logits.
169-
float[] logitsArray = {
170-
8.F, 1.F, 1.F,
171-
0.F, 9.F, 1.F,
172-
2.F, 3.F, 5.F
173-
};
141+
float[] logitsArray = {8.F, 1.F, 1.F, 0.F, 9.F, 1.F, 2.F, 3.F, 5.F};
174142
Operand<TFloat32> logits =
175143
tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(3, 3)));
176144
instance = new CategoricalCrossentropy(true);
@@ -189,16 +157,8 @@ public void testSsampleWeighted() {
189157
CategoricalCrossentropy instance = new CategoricalCrossentropy();
190158

191159
float[] sampeWeightArray = {1.2F, 3.4F, 5.6F};
192-
int[] trueArray = {
193-
1, 0, 0,
194-
0, 1, 0,
195-
0, 0, 1
196-
};
197-
float[] predArray = {
198-
.9F, .05F, .05F,
199-
.5F, .89F, .6F,
200-
.05F, .01F, .94F
201-
};
160+
int[] trueArray = {1, 0, 0, 0, 1, 0, 0, 0, 1};
161+
float[] predArray = {.9F, .05F, .05F, .5F, .89F, .6F, .05F, .01F, .94F};
202162
Operand<TInt32> yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 3)));
203163
Operand<TFloat32> yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(3, 3)));
204164
Operand<TFloat32> sampleWeight =
@@ -208,11 +168,7 @@ public void testSsampleWeighted() {
208168
testSession.evaluate(expected, loss);
209169

210170
// Test with logits.
211-
float[] logitsArray = {
212-
8.F, 1.F, 1.F,
213-
0.F, 9.F, 1.F,
214-
2.F, 3.F, 5.F
215-
};
171+
float[] logitsArray = {8.F, 1.F, 1.F, 0.F, 9.F, 1.F, 2.F, 3.F, 5.F};
216172
Operand<TFloat32> logits =
217173
tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(3, 3)));
218174
instance = new CategoricalCrossentropy(true);
@@ -231,11 +187,7 @@ public void testNoReduction() {
231187

232188
// Test with logits.
233189
int[] trueArray = {1, 0, 0, 0, 1, 0, 0, 0, 1};
234-
float[] logitsArray = {
235-
8.F, 1.F, 1.F,
236-
0.F, 9.F, 1.F,
237-
2.F, 3.F, 5.F
238-
};
190+
float[] logitsArray = {8.F, 1.F, 1.F, 0.F, 9.F, 1.F, 2.F, 3.F, 5.F};
239191
Operand<TInt32> yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 3)));
240192
Operand<TFloat32> logits =
241193
tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(3, 3)));
@@ -266,4 +218,34 @@ public void testLabelSmoothing() {
266218
testSession.evaluate(expected, loss);
267219
}
268220
}
221+
222+
@Test
223+
public void testCategoricalCrossEntopyWithDynamicBatchSize() {
224+
try (Graph graph = new Graph()) {
225+
Ops tf = Ops.create(graph);
226+
Operand yPred = tf.placeholder(TFloat32.class, Placeholder.shape(Shape.of(-1, 3)));
227+
Operand yTrue =
228+
tf.reshape(tf.constant(new float[] {1f, 0f, 0f, 0f, 1f, 0f, 0f, 0f, 1f}), tf.array(3, 3));
229+
CategoricalCrossentropy instance = new CategoricalCrossentropy(true);
230+
Operand loss =
231+
instance.call(tf, yTrue, yPred); // Throw TFInvalidArgument Exception without fix
232+
try (Session session = new Session(graph);
233+
TFloat32 result =
234+
(TFloat32)
235+
session
236+
.runner()
237+
.feed(
238+
yPred,
239+
TFloat32.tensorOf(
240+
Shape.of(3, 3),
241+
DataBuffers.of(
242+
new float[] {1.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 1.f})))
243+
.fetch(loss)
244+
.run()
245+
.get(0)) {
246+
if (Math.abs(0.5514477f - result.getFloat()) > 0.01)
247+
throw new IllegalStateException("Invalid result :" + result.getFloat());
248+
}
249+
}
250+
}
269251
}

0 commit comments

Comments
 (0)