17
17
import static org .junit .jupiter .api .Assertions .assertThrows ;
18
18
19
19
import org .junit .jupiter .api .Test ;
20
+ import org .tensorflow .Graph ;
20
21
import org .tensorflow .Operand ;
22
+ import org .tensorflow .Session ;
21
23
import org .tensorflow .framework .utils .TestSession ;
22
24
import org .tensorflow .ndarray .Shape ;
25
+ import org .tensorflow .ndarray .buffer .DataBuffers ;
23
26
import org .tensorflow .op .Ops ;
27
+ import org .tensorflow .op .core .Placeholder ;
24
28
import org .tensorflow .types .TFloat32 ;
25
29
import org .tensorflow .types .TInt32 ;
26
30
import org .tensorflow .types .TInt64 ;
@@ -36,16 +40,8 @@ public void testAllCorrectUnweighted() {
36
40
try (TestSession testSession = TestSession .createTestSession (tfMode )) {
37
41
Ops tf = testSession .getTF ();
38
42
39
- long [] trueArray = {
40
- 1L , 0L , 0L ,
41
- 0L , 1L , 0L ,
42
- 0L , 0L , 1L
43
- };
44
- float [] predArray = {
45
- 1.F , 0.F , 0.F ,
46
- 0.F , 1.F , 0.F ,
47
- 0.F , 0.F , 1.F
48
- };
43
+ long [] trueArray = {1L , 0L , 0L , 0L , 1L , 0L , 0L , 0L , 1L };
44
+ float [] predArray = {1.F , 0.F , 0.F , 0.F , 1.F , 0.F , 0.F , 0.F , 1.F };
49
45
Operand <TInt64 > yTrue = tf .reshape (tf .constant (trueArray ), tf .constant (Shape .of (3 , 3 )));
50
46
Operand <TFloat32 > yPred = tf .reshape (tf .constant (predArray ), tf .constant (Shape .of (3 , 3 )));
51
47
CategoricalCrossentropy instance = new CategoricalCrossentropy ();
@@ -55,11 +51,7 @@ public void testAllCorrectUnweighted() {
55
51
testSession .evaluate (expected , loss );
56
52
57
53
// Test with logits.
58
- float [] logitsArray = {
59
- 10.F , 0.F , 0.F ,
60
- 0.F , 10.F , 0.F ,
61
- 0.F , 0.F , 10.F
62
- };
54
+ float [] logitsArray = {10.F , 0.F , 0.F , 0.F , 10.F , 0.F , 0.F , 0.F , 10.F };
63
55
yTrue = tf .reshape (tf .constant (trueArray ), tf .constant (Shape .of (3 , 3 )));
64
56
Operand <TFloat32 > logits =
65
57
tf .reshape (tf .constant (logitsArray ), tf .constant (Shape .of (3 , 3 )));
@@ -85,11 +77,7 @@ public void testInvalidPredictionsRange() {
85
77
Ops tf = testSession .getTF ();
86
78
CategoricalCrossentropy instance = new CategoricalCrossentropy ();
87
79
88
- float [] trueArray = {
89
- 1L , 0L , 0L ,
90
- 0L , 1L , 0L ,
91
- 0L , 0L , 1L
92
- };
80
+ float [] trueArray = {1L , 0L , 0L , 0L , 1L , 0L , 0L , 0L , 1L };
93
81
float [] predArray = {-1.F , 0.F , 0.F , 0.F , 1.F , 0.F , 0.F , 0.F , 1.F };
94
82
Operand <TFloat32 > yTrue =
95
83
tf .reshape (tf .constant (trueArray ), tf .constant (Shape .of (3 , 3 )));
@@ -111,23 +99,15 @@ public void testUnweighted() {
111
99
CategoricalCrossentropy instance = new CategoricalCrossentropy ();
112
100
113
101
int [] trueArray = {1 , 0 , 0 , 0 , 1 , 0 , 0 , 0 , 1 };
114
- float [] predArray = {
115
- .9F , .05F , .05F ,
116
- .5F , .89F , .6F ,
117
- .05F , .01F , .94F
118
- };
102
+ float [] predArray = {.9F , .05F , .05F , .5F , .89F , .6F , .05F , .01F , .94F };
119
103
Operand <TInt32 > yTrue = tf .reshape (tf .constant (trueArray ), tf .constant (Shape .of (3 , 3 )));
120
104
Operand <TFloat32 > yPred = tf .reshape (tf .constant (predArray ), tf .constant (Shape .of (3 , 3 )));
121
105
Operand <TFloat32 > loss = instance .call (tf , yTrue , yPred );
122
106
float expected = 0.32396814F ;
123
107
testSession .evaluate (expected , loss );
124
108
125
109
// Test with logits.
126
- float [] logitsArray = {
127
- 8.F , 1.F , 1.F ,
128
- 0.F , 9.F , 1.F ,
129
- 2.F , 3.F , 5.F
130
- };
110
+ float [] logitsArray = {8.F , 1.F , 1.F , 0.F , 9.F , 1.F , 2.F , 3.F , 5.F };
131
111
Operand <TFloat32 > logits =
132
112
tf .reshape (tf .constant (logitsArray ), tf .constant (Shape .of (3 , 3 )));
133
113
instance = new CategoricalCrossentropy (true );
@@ -145,16 +125,8 @@ public void testScalarWeighted() {
145
125
try (TestSession testSession = TestSession .createTestSession (tfMode )) {
146
126
Ops tf = testSession .getTF ();
147
127
148
- int [] trueArray = {
149
- 1 , 0 , 0 ,
150
- 0 , 1 , 0 ,
151
- 0 , 0 , 1
152
- };
153
- float [] predArray = {
154
- .9F , .05F , .05F ,
155
- .5F , .89F , .6F ,
156
- .05F , .01F , .94F
157
- };
128
+ int [] trueArray = {1 , 0 , 0 , 0 , 1 , 0 , 0 , 0 , 1 };
129
+ float [] predArray = {.9F , .05F , .05F , .5F , .89F , .6F , .05F , .01F , .94F };
158
130
Operand <TInt32 > yTrue = tf .reshape (tf .constant (trueArray ), tf .constant (Shape .of (3 , 3 )));
159
131
Operand <TFloat32 > yPred = tf .reshape (tf .constant (predArray ), tf .constant (Shape .of (3 , 3 )));
160
132
Operand <TFloat32 > sampleWeight = tf .constant (2.3F );
@@ -166,11 +138,7 @@ public void testScalarWeighted() {
166
138
testSession .evaluate (expected , loss );
167
139
168
140
// Test with logits.
169
- float [] logitsArray = {
170
- 8.F , 1.F , 1.F ,
171
- 0.F , 9.F , 1.F ,
172
- 2.F , 3.F , 5.F
173
- };
141
+ float [] logitsArray = {8.F , 1.F , 1.F , 0.F , 9.F , 1.F , 2.F , 3.F , 5.F };
174
142
Operand <TFloat32 > logits =
175
143
tf .reshape (tf .constant (logitsArray ), tf .constant (Shape .of (3 , 3 )));
176
144
instance = new CategoricalCrossentropy (true );
@@ -189,16 +157,8 @@ public void testSsampleWeighted() {
189
157
CategoricalCrossentropy instance = new CategoricalCrossentropy ();
190
158
191
159
float [] sampeWeightArray = {1.2F , 3.4F , 5.6F };
192
- int [] trueArray = {
193
- 1 , 0 , 0 ,
194
- 0 , 1 , 0 ,
195
- 0 , 0 , 1
196
- };
197
- float [] predArray = {
198
- .9F , .05F , .05F ,
199
- .5F , .89F , .6F ,
200
- .05F , .01F , .94F
201
- };
160
+ int [] trueArray = {1 , 0 , 0 , 0 , 1 , 0 , 0 , 0 , 1 };
161
+ float [] predArray = {.9F , .05F , .05F , .5F , .89F , .6F , .05F , .01F , .94F };
202
162
Operand <TInt32 > yTrue = tf .reshape (tf .constant (trueArray ), tf .constant (Shape .of (3 , 3 )));
203
163
Operand <TFloat32 > yPred = tf .reshape (tf .constant (predArray ), tf .constant (Shape .of (3 , 3 )));
204
164
Operand <TFloat32 > sampleWeight =
@@ -208,11 +168,7 @@ public void testSsampleWeighted() {
208
168
testSession .evaluate (expected , loss );
209
169
210
170
// Test with logits.
211
- float [] logitsArray = {
212
- 8.F , 1.F , 1.F ,
213
- 0.F , 9.F , 1.F ,
214
- 2.F , 3.F , 5.F
215
- };
171
+ float [] logitsArray = {8.F , 1.F , 1.F , 0.F , 9.F , 1.F , 2.F , 3.F , 5.F };
216
172
Operand <TFloat32 > logits =
217
173
tf .reshape (tf .constant (logitsArray ), tf .constant (Shape .of (3 , 3 )));
218
174
instance = new CategoricalCrossentropy (true );
@@ -231,11 +187,7 @@ public void testNoReduction() {
231
187
232
188
// Test with logits.
233
189
int [] trueArray = {1 , 0 , 0 , 0 , 1 , 0 , 0 , 0 , 1 };
234
- float [] logitsArray = {
235
- 8.F , 1.F , 1.F ,
236
- 0.F , 9.F , 1.F ,
237
- 2.F , 3.F , 5.F
238
- };
190
+ float [] logitsArray = {8.F , 1.F , 1.F , 0.F , 9.F , 1.F , 2.F , 3.F , 5.F };
239
191
Operand <TInt32 > yTrue = tf .reshape (tf .constant (trueArray ), tf .constant (Shape .of (3 , 3 )));
240
192
Operand <TFloat32 > logits =
241
193
tf .reshape (tf .constant (logitsArray ), tf .constant (Shape .of (3 , 3 )));
@@ -266,4 +218,34 @@ public void testLabelSmoothing() {
266
218
testSession .evaluate (expected , loss );
267
219
}
268
220
}
221
+
222
+ @ Test
223
+ public void testCategoricalCrossEntopyWithDynamicBatchSize () {
224
+ try (Graph graph = new Graph ()) {
225
+ Ops tf = Ops .create (graph );
226
+ Operand yPred = tf .placeholder (TFloat32 .class , Placeholder .shape (Shape .of (-1 , 3 )));
227
+ Operand yTrue =
228
+ tf .reshape (tf .constant (new float [] {1f , 0f , 0f , 0f , 1f , 0f , 0f , 0f , 1f }), tf .array (3 , 3 ));
229
+ CategoricalCrossentropy instance = new CategoricalCrossentropy (true );
230
+ Operand loss =
231
+ instance .call (tf , yTrue , yPred ); // Throw TFInvalidArgument Exception without fix
232
+ try (Session session = new Session (graph );
233
+ TFloat32 result =
234
+ (TFloat32 )
235
+ session
236
+ .runner ()
237
+ .feed (
238
+ yPred ,
239
+ TFloat32 .tensorOf (
240
+ Shape .of (3 , 3 ),
241
+ DataBuffers .of (
242
+ new float [] {1.f , 0.f , 0.f , 0.f , 1.f , 0.f , 0.f , 0.f , 1.f })))
243
+ .fetch (loss )
244
+ .run ()
245
+ .get (0 )) {
246
+ if (Math .abs (0.5514477f - result .getFloat ()) > 0.01 )
247
+ throw new IllegalStateException ("Invalid result :" + result .getFloat ());
248
+ }
249
+ }
250
+ }
269
251
}
0 commit comments