Skip to content
This repository was archived by the owner on Jul 15, 2025. It is now read-only.

Kotlin friendly names #1

Merged
merged 6 commits into from
May 18, 2021
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
*.iml
.idea
target
55 changes: 43 additions & 12 deletions ndarray/src/main/java/org/tensorflow/ndarray/Shape.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@

package org.tensorflow.ndarray;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
* The shape of a Tensor or {@link NdArray}.
Expand Down Expand Up @@ -74,8 +76,8 @@ public static Shape scalar() {
* Shape scalar = Shape.of()
* }</pre>
*
* @param dimensionSizes number of elements in each dimension of this shape, if any, or
* {@link Shape#UNKNOWN_SIZE} if unknown.
* @param dimensionSizes number of elements in each dimension of this shape, if any, or {@link
* Shape#UNKNOWN_SIZE} if unknown.
* @return a new shape
*/
public static Shape of(long... dimensionSizes) {
Expand Down Expand Up @@ -108,13 +110,14 @@ public long size() {
* an unknown size, {@link Shape#UNKNOWN_SIZE} is returned.
*
* @param i the index of the dimension to get the size for. If this Shape has a known number of
* dimensions, it must be &lt; {@link Shape#numDimensions()}. The index may be negative, in which
* case the position is counted from the end of the shape. E.g.: {@code size(-1)} returns the
* size of the last dimension, {@code size(-2)} the size of the second to last dimension etc.
* dimensions, it must be &lt; {@link Shape#numDimensions()}. The index may be negative, in
* which case the position is counted from the end of the shape. E.g.: {@code size(-1)}
* returns the size of the last dimension, {@code size(-2)} the size of the second to last
* dimension etc.
* @return The size of the dimension with the given index if known, {@link Shape#UNKNOWN_SIZE}
* otherwise.
*/
public long size(int i) {
public long get(int i) {
if (dimensionSizes == null) {
return UNKNOWN_SIZE;
} else if (i >= 0) {
Expand Down Expand Up @@ -177,6 +180,24 @@ public long[] asArray() {
}
}

/**
* Returns a defensive copy of the this Shape's axes. Changes to the returned list do not change
* this Shape's state. Returns null if {@link Shape#isUnknown()} is true.
*/
public List<Long> toListOrNull() {
long[] array = asArray();
if (array == null) {
return null;
}

List<Long> list = new ArrayList<>(array.length);
for (long l : array) {
list.add(l);
}

return list;
}

@Override
public int hashCode() {
return dimensionSizes != null ? Arrays.hashCode(dimensionSizes) : super.hashCode();
Expand All @@ -186,6 +207,7 @@ public int hashCode() {
* Equals implementation for Shapes. Two Shapes are considered equal iff:
*
* <p>
*
* <ul>
* <li>the number of dimensions is defined and equal for both
* <li>the size of each dimension is defined and equal for both
Expand Down Expand Up @@ -236,7 +258,8 @@ public Shape head() {
* Returns an n-dimensional Shape with the dimensions matching the first n dimensions of this
* shape
*
* @param n the number of leading dimensions to get, must be &lt;= than {@link Shape#numDimensions()}
* @param n the number of leading dimensions to get, must be &lt;= than {@link
* Shape#numDimensions()}
* @return an n-dimensional Shape with the first n dimensions matching the first n dimensions of
* this Shape
*/
Expand All @@ -252,7 +275,9 @@ public Shape take(int n) {

/** Returns a new Shape, with this Shape's first dimension removed. */
public Shape tail() {
if (dimensionSizes.length < 2) return Shape.of();
if (dimensionSizes.length < 2) {
return Shape.of();
}
return Shape.of(Arrays.copyOfRange(dimensionSizes, 1, dimensionSizes.length));
}

Expand All @@ -276,15 +301,21 @@ public Shape takeLast(int n) {
}

/**
* Return a {@code end - begin} dimensional shape with dimensions matching this Shape from {@code begin} to {@code end}.
* Return a {@code end - begin} dimensional shape with dimensions matching this Shape from {@code
* begin} to {@code end}.
*
* @param begin Where to start the sub-shape.
* @param end Where to end the sub-shape, exclusive.
* @return the sub-shape bounded by begin and end.
*/
public Shape subShape(int begin, int end){
public Shape subShape(int begin, int end) {
if (end > numDimensions()) {
throw new ArrayIndexOutOfBoundsException(
"End index " + end + " out of bounds: shape only has " + numDimensions() + " dimensions.");
"End index "
+ end
+ " out of bounds: shape only has "
+ numDimensions()
+ " dimensions.");
}
if (begin < 0) {
throw new ArrayIndexOutOfBoundsException(
Expand Down Expand Up @@ -423,7 +454,7 @@ public boolean isCompatibleWith(Shape shape) {
return false;
}
for (int i = 0; i < numDimensions(); i++) {
if (!isCompatible(size(i), shape.size(i))) {
if (!isCompatible(get(i), shape.get(i))) {
return false;
}
}
Expand Down
4 changes: 2 additions & 2 deletions ndarray/src/main/java/org/tensorflow/ndarray/StdArrays.java
Original file line number Diff line number Diff line change
Expand Up @@ -3798,9 +3798,9 @@ private static int[] computeArrayDims(NdArray<?> ndArray, int expectedRank) {
}
int[] arrayShape = new int[expectedRank];
for (int i = 0; i < expectedRank; ++i) {
long dimSize = shape.size(i);
long dimSize = shape.get(i);
if (dimSize > Integer.MAX_VALUE) {
throw new IllegalArgumentException("Dimension " + i + " is too large to fit in a standard array (" + shape.size(i) + ")");
throw new IllegalArgumentException("Dimension " + i + " is too large to fit in a standard array (" + shape.get(i) + ")");
}
arrayShape[i] = (int)dimSize;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ public static DimensionalSpace create(Shape shape) {

// Start from the last dimension, where all elements are continuous
for (int i = dimensions.length - 1, elementSize = 1; i >= 0; --i) {
dimensions[i] = new Axis(shape.size(i), elementSize);
dimensions[i] = new Axis(shape.get(i), elementSize);
elementSize *= dimensions[i].numElements();
}
return new DimensionalSpace(dimensions, shape);
Expand Down Expand Up @@ -189,7 +189,9 @@ public long positionOf(long[] coords) {
return position;
}

/** Succinct description of the shape meant for debugging. */
/**
* Succinct description of the shape meant for debugging.
*/
@Override
public String toString() {
return Arrays.toString(dimensions);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@
import static org.tensorflow.ndarray.index.Indices.at;
import static org.tensorflow.ndarray.index.Indices.even;
import static org.tensorflow.ndarray.index.Indices.flip;
import static org.tensorflow.ndarray.index.Indices.sliceFrom;
import static org.tensorflow.ndarray.index.Indices.odd;
import static org.tensorflow.ndarray.index.Indices.range;
import static org.tensorflow.ndarray.index.Indices.seq;
import static org.tensorflow.ndarray.index.Indices.sliceFrom;
import static org.tensorflow.ndarray.index.Indices.sliceTo;

import java.nio.BufferOverflowException;
Expand Down Expand Up @@ -132,15 +132,15 @@ public void iterateElements() {
long value = 0L;
for (NdArray<T> matrix : matrix3d.elements(0)) {
assertEquals(2L, matrix.shape().numDimensions());
assertEquals(4L, matrix.shape().size(0));
assertEquals(5L, matrix.shape().size(1));
assertEquals(4L, matrix.shape().get(0));
assertEquals(5L, matrix.shape().get(1));

for (NdArray<T> vector : matrix.elements(0)) {
assertEquals(1L, vector.shape().numDimensions()) ;
assertEquals(5L, vector.shape().size(0));
assertEquals(1L, vector.shape().numDimensions());
assertEquals(5L, vector.shape().get(0));

for (NdArray<T> scalar : vector.scalars()) {
assertEquals(0L, scalar.shape().numDimensions()) ;
assertEquals(0L, scalar.shape().numDimensions());
scalar.setObject(valueOf(value++));
try {
scalar.elements(0);
Expand All @@ -162,7 +162,7 @@ public void iterateElements() {
@Test
public void slices() {
NdArray<T> matrix3d = allocate(Shape.of(5, 4, 5));

T val100 = valueOf(100L);
matrix3d.setObject(val100, 1, 0, 0);
T val101 = valueOf(101L);
Expand Down Expand Up @@ -318,8 +318,8 @@ public void equalsAndHashCode() {
NdArray<T> array4 = allocate(Shape.of(1, 2, 2));

@SuppressWarnings("unchecked")
T[][][] values = (T[][][])(new Object[][][] {
{ { valueOf(0L), valueOf(1L) }, { valueOf(2L), valueOf(0L) } }
T[][][] values = (T[][][]) (new Object[][][]{
{{valueOf(0L), valueOf(1L)}, {valueOf(2L), valueOf(0L)}}
});

StdArrays.copyTo(values[0], array1);
Expand Down
30 changes: 18 additions & 12 deletions ndarray/src/test/java/org/tensorflow/ndarray/ShapeTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,32 +16,38 @@
*/
package org.tensorflow.ndarray;

import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;

import static org.junit.jupiter.api.Assertions.*;
import org.junit.jupiter.api.Test;

public class ShapeTest {

@Test
public void allKnownDimensions() {
Shape shape = Shape.of(5, 4, 5);
assertEquals(3, shape.numDimensions());
assertEquals(5, shape.size(0));
assertEquals(4, shape.size(1));
assertEquals(5, shape.size(2));
assertEquals(5, shape.get(0));
assertEquals(4, shape.get(1));
assertEquals(5, shape.get(2));
assertEquals(100, shape.size());
assertArrayEquals(new long[] {5, 4, 5}, shape.asArray());
assertArrayEquals(new long[]{5, 4, 5}, shape.asArray());
try {
shape.size(3);
shape.get(3);
fail();
} catch (IndexOutOfBoundsException e) {
// as expected
}
assertEquals(5, shape.size(-1));
assertEquals(4, shape.size(-2));
assertEquals(5, shape.size(-3));
assertEquals(5, shape.get(-1));
assertEquals(4, shape.get(-2));
assertEquals(5, shape.get(-3));
try {
shape.size(-4);
shape.get(-4);
fail();
} catch (IndexOutOfBoundsException e) {
// as expected
Expand Down Expand Up @@ -133,7 +139,7 @@ public void testShapeModification() {
long[] internalShape = one.asArray();
assertNotNull(internalShape);
internalShape[0] = 42L;
assertEquals(2L, one.size(0));
assertEquals(2L, one.get(0));
}

@Test
Expand Down