Skip to content

Commit d518678

Browse files
authored
Add sparse tensor mappings (#405)
1 parent 1f69192 commit d518678

26 files changed

+1535
-136
lines changed

tensorflow-core/tensorflow-core-api/pom.xml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
<javacpp.parser.skip>${native.build.skip}</javacpp.parser.skip>
2121
<javacpp.compiler.skip>${native.build.skip}</javacpp.compiler.skip>
2222
<java.module.name>org.tensorflow.core.api</java.module.name>
23-
<ndarray.version>0.3.3</ndarray.version>
23+
<ndarray.version>0.4.0-SNAPSHOT</ndarray.version>
2424
<truth.version>1.0.1</truth.version>
2525
</properties>
2626

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/RawTensor.java

Lines changed: 29 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,12 @@
3434
* A tensor whose memory has not been mapped to a data space directly accessible from the JVM.
3535
*
3636
* <p>A raw tensor is a minimalist representation of a tensor allocated in native memory by the
37-
* TensorFlow runtime library and it controls its lifetime within the current process. The data
38-
* is represented by a flat {@link ByteDataBuffer buffer of bytes}, until it is mapped in a
39-
* n-dimensional typed space by a {@link TType typed tensor}.</p>
37+
* TensorFlow runtime library and it controls its lifetime within the current process. The data is
38+
* represented by a flat {@link ByteDataBuffer buffer of bytes}, until it is mapped in a
39+
* n-dimensional typed space by a {@link TType typed tensor}.
4040
*
41-
* <p>Instances of a RawTensor are <b>not</b> thread-safe and their resource must be released
42-
* by calling {@link #close()} explicitly or implicitly via try-with-resources.</p>
41+
* <p>Instances of a RawTensor are <b>not</b> thread-safe and their resource must be released by
42+
* calling {@link #close()} explicitly or implicitly via try-with-resources.
4343
*/
4444
public final class RawTensor implements Tensor {
4545

@@ -81,9 +81,7 @@ public ByteDataBuffer data() {
8181
return buffer;
8282
}
8383

84-
/**
85-
* Returns a string describing the type and shape of the tensor.
86-
*/
84+
/** Returns a string describing the type and shape of the tensor. */
8785
@Override
8886
public String toString() {
8987
return String.format("%s tensor with shape %s", typeInfo.dataType(), shape);
@@ -92,20 +90,20 @@ public String toString() {
9290
/**
9391
* Allocates a new tensor in native memory of the given type, shape and size.
9492
*
95-
* <p>The size of the tensor must be at least large enough to contain all scalars for the
96-
* given type and shape. More memory can also be allocated to store also metadata within the
97-
* tensor itself, e.g. a lookup table in a string tensor.
93+
* <p>The size of the tensor must be at least large enough to contain all scalars for the given
94+
* type and shape. More memory can also be allocated to store metadata within the tensor
95+
* itself, e.g. a lookup table in a string tensor.
9896
*
9997
* @param type tensor type class
10098
* @param shape shape of the tensor
10199
* @param size size in bytes of the tensor, or -1 to compute the size from the shape
102100
* @return allocated tensor
103101
* @throws IllegalArgumentException if {@code size} is smaller than the minimum space required to
104-
* store the tensor data
105-
* @throws IllegalArgumentException if {@code size} is set to -1 but elements of the given
106-
* {@code type} are of variable length (e.g. strings)
107-
* @throws IllegalArgumentException if {@code shape} is totally or partially
108-
* {@link Shape#hasUnknownDimension() unknown}
102+
* store the tensor data
103+
* @throws IllegalArgumentException if {@code size} is set to -1 but elements of the given {@code
104+
* type} are of variable length (e.g. strings)
105+
* @throws IllegalArgumentException if {@code shape} is totally or partially {@link
106+
* Shape#hasUnknownDimension() unknown}
109107
* @throws IllegalStateException if tensor failed to be allocated
110108
*/
111109
static RawTensor allocate(Class<? extends TType> type, Shape shape, long size) {
@@ -123,12 +121,14 @@ static RawTensor allocate(Class<? extends TType> type, Shape shape, long size) {
123121
allocatedSize = shape.size() * typeInfo.byteSize();
124122

125123
} else if (!typeInfo.isVariableLength() && shape.size() * typeInfo.byteSize() > allocatedSize) {
126-
// Minimum requirements for datatypes of variable length cannot be verified in a relevant way so
124+
// Minimum requirements for datatypes of variable length cannot be verified in a relevant way
125+
// so
127126
// we only validate them for fixed length datatypes
128127
throw new IllegalArgumentException(
129128
"Tensor size is not large enough to contain all scalar values");
130129
}
131-
TF_Tensor nativeHandle = allocate(typeInfo.dataType().getNumber(), shape.asArray(), allocatedSize);
130+
TF_Tensor nativeHandle =
131+
allocate(typeInfo.dataType().getNumber(), shape.asArray(), allocatedSize);
132132
try (PointerScope scope = new PointerScope()) {
133133
scope.attach(nativeHandle);
134134
RawTensor t = new RawTensor(typeInfo, shape);
@@ -147,9 +147,9 @@ static RawTensor fromHandle(TF_Tensor handle) {
147147
TensorTypeInfo<?> typeInfo = TensorTypeRegistry.find(DataType.forNumber(dtype(handle)));
148148
RawTensor t = new RawTensor(typeInfo, Shape.of(shape(handle)));
149149
try (PointerScope scope = new PointerScope()) {
150-
scope.attach(handle);
151-
t.tensorHandle = handle;
152-
t.tensorScope = scope.extend();
150+
scope.attach(handle);
151+
t.tensorHandle = handle;
152+
t.tensorScope = scope.extend();
153153
}
154154
return t;
155155
}
@@ -168,6 +168,7 @@ static RawTensor fromHandle(TF_Tensor handle, EagerSession session) {
168168

169169
/**
170170
* Returns the native handle to this tensor
171+
*
171172
* @throws IllegalStateException if tensor has been closed
172173
*/
173174
TF_Tensor nativeHandle() {
@@ -178,14 +179,20 @@ TF_Tensor nativeHandle() {
178179
* Returns a typed reference to this tensor
179180
*
180181
* <p>In some cases, it is more useful to keep a typed reference to a tensor rather than its raw
181-
* nature to prevent mapping its memory on every access (e.g. when calling {@link Operand#asTensor()}).
182+
* nature to prevent mapping its memory on every access (e.g. when calling {@link
183+
* Operand#asTensor()}).
182184
*
183185
* @return typed reference to this tensor
184186
*/
185187
TType asTypedTensor() {
186188
return typeInfo.mapper().mapDense(this);
187189
}
188190

191+
/** @return metadata about the type of this tensor. */
192+
TensorTypeInfo<? extends TType> typeInfo() {
193+
return typeInfo;
194+
}
195+
189196
private static TF_Tensor requireHandle(TF_Tensor handle) {
190197
if (handle == null || handle.isNull()) {
191198
throw new IllegalStateException("close() was called on the Tensor");
Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
package org.tensorflow;
16+
17+
import org.bytedeco.javacpp.PointerScope;
18+
import org.tensorflow.types.TInt64;
19+
import org.tensorflow.types.family.TType;
20+
21+
/**
22+
* A virtual type of {@link Tensor} composed of three dense tensors (indices, values and dimensions)
23+
* used to represent the sparse data in a multi-dimensional dense space.
24+
*
25+
* <p>Any tensor returned by a sparse tensor factory (e.g. {@link TInt64#sparseTensorOf(TInt64,
26+
* TInt64, TInt64)}) can be cast back to this interface to directly access the dense tensors it is
27+
* composed of.
28+
*
29+
* <p>A sparse tensor will keep strong references to its dense tensors to prevent them from being
30+
* released before it is closed itself. Likewise, closing a sparse tensor won't release the memory
31+
* of its dense tensors until they in turn are closed. It is then important to protect not only the
32+
* dense tensors within a <i>try-with-resources</i> block but also the sparse tensor itself.
33+
*
34+
* <p>For example, this code is perfectly safe:
35+
*
36+
* <pre>{@code
37+
* TFloat64 createSparseTensor() {
38+
* try (TInt64 indices = TInt64.tensorOf(...);
39+
* TFloat64 values = TFloat64.vectorOf(...);
40+
* TInt64 denseShape = TInt64.vectorOf(...)) {
41+
* return TFloat64.sparseTensorOf(indices, values, denseShape);
42+
* }
43+
* }
44+
* try (TFloat64 sparseTensor = createSparseTensor()) {
45+
* ...
46+
* }
47+
* }</pre>
48+
*
49+
* @param <T> type of data stored in the tensor
50+
*/
51+
public interface SparseTensor<T extends TType> extends Tensor {
52+
53+
/**
54+
* Creates a sparse tensor from {@code indices}, {@code values} and {@code denseShape} dense
55+
* tensors.
56+
*
57+
* @param indices A 2-D tensor of shape {@code [N, ndims]}, that specifies the indices of the
58+
* elements in the sparse tensor that contain non-default values (elements are zero-indexed).
59+
* For example, {@code indices=[[1,3,1], [2,4,0]]} specifies that the elements with indexes of
60+
* {@code [1,3,1]} and {@code [2,4,0]} have non-default values.
61+
* @param values A 1-D tensor of shape {@code [N]}, which supplies the values for each element in
62+
* indices. For example, given {@code indices=[[1,3,1], [2,4,0]]}, the parameter {@code
63+
* values=[18, 3.8]} specifies that element {@code [1,3,1]} of the sparse tensor has a value
64+
* of {@code 18}, and element {@code [2,4,0]} of the tensor has a value of {@code 3.8}.
65+
* @param denseShape A 1-D tensor of shape {@code [ndims]} where the value at index {@code i}
66+
* represents the size of dimension {@code i} in a dense version of that tensor.
67+
* @return the new sparse tensor
68+
* @throws IllegalArgumentException if shapes of the dense tensors are not compatible
69+
*/
70+
static <T extends TType> SparseTensor<T> of(TInt64 indices, T values, TInt64 denseShape) {
71+
if (indices.rank() != 2) {
72+
throw new IllegalArgumentException("Sparse indices must be a rank-2 tensor");
73+
}
74+
if (values.rank() != 1) {
75+
throw new IllegalArgumentException("Sparse values must be a rank-1 tensor");
76+
}
77+
if (denseShape.rank() != 1) {
78+
throw new IllegalArgumentException("Sparse shape must be a rank-1 tensor");
79+
}
80+
if (indices.shape().get(0) != values.shape().get(0)) {
81+
throw new IllegalArgumentException(
82+
"Number of indices must be equal to the number of values ["
83+
+ indices.shape().get(0)
84+
+ " != "
85+
+ values.shape().get(0)
86+
+ "]");
87+
}
88+
if (indices.shape().get(1) != denseShape.shape().get(0)) {
89+
throw new IllegalArgumentException(
90+
"Indices must have a coordinate for each dimensions of the tensor ["
91+
+ indices.shape().get(1)
92+
+ " != "
93+
+ denseShape.shape().get(0)
94+
+ "]");
95+
}
96+
// Use mapper of the values tensor as this is the one giving the type of the sparse tensor as
97+
// well
98+
TensorMapper<T> mapper = (TensorMapper<T>) values.asRawTensor().typeInfo().mapper();
99+
100+
// Attach all tensors to a new pointer scope (this will increment their reference count) and
101+
// preserve a strong reference to that scope inside the sparse tensor. This is done by
102+
// extending this scope in the sparse tensor constructors, via mapSparse()
103+
try (PointerScope scope = new PointerScope()) {
104+
scope.attach(indices.asRawTensor().nativeHandle());
105+
scope.attach(values.asRawTensor().nativeHandle());
106+
scope.attach(denseShape.asRawTensor().nativeHandle());
107+
return mapper.mapSparse(indices, values, denseShape, scope);
108+
}
109+
}
110+
111+
@Override
112+
default RawTensor asRawTensor() {
113+
throw new UnsupportedOperationException(
114+
"Sparse tensors cannot be converted to a single raw tensor");
115+
}
116+
117+
/**
118+
* Returns this instance as a typed tensor.
119+
*
120+
* <p>This method is equivalent to directly casting the {@code SparseTensor<T>} instance to {@code
121+
* T}.
122+
*
123+
* @return the typed tensor
124+
*/
125+
default T asTypedTensor() {
126+
return (T) this;
127+
}
128+
129+
/**
130+
* Gets the indices of the sparse values.
131+
*
132+
* <p>Indices are a 2-D long array of shape {@code [N, ndims]}, that specifies the indices of the
133+
* elements in the sparse tensor that contain nonzero values (elements are zero-indexed).
134+
*
135+
* <p>For example, {@code indices=[[1,3], [2,4]]} specifies that the elements with indexes of
136+
* coordinates {@code [1,3]} and {@code [2,4]} have nonzero values.
137+
*
138+
* @return the indices
139+
*/
140+
TInt64 indices();
141+
142+
/**
143+
* Gets the sparse values.
144+
*
145+
* <p>Values are a 1-D array of type {@code T} and shape {@code [N]}, that supplies the values for
146+
* each element in indices.
147+
*
148+
* <p>For example, given {@code indices=[[1,3], [2,4]]}, {@code values=[18, 3.6]} specifies
149+
* that element {@code [1,3]} of the sparse tensor has a value of {@code 18}, and element {@code
150+
* [2,4]} of the sparse tensor has a value of {@code 3.6}.
151+
*
152+
* @return the values
153+
*/
154+
T values();
155+
156+
/**
157+
* Gets the sparse tensor dimensions defining the shape of that tensor in a dense space.
158+
*
159+
* <p>Dimensions are a 1-D tensor of shape {@code [ndims]} where the value at index {@code i}
160+
* represents the total number of elements in dimension {@code i} in a dense version of that tensor.
161+
*
162+
* @return the dense shape
163+
*/
164+
TInt64 denseShape();
165+
}

0 commit comments

Comments
 (0)