Skip to content

Commit acc7c8d

Browse files
michaelgsharpashrit-ms
authored andcommitted
Adds the new System.Numerics.Tensors as an input/output type when using dotnet 8.0 and up. (#23261)
### Description Adds the new System.Numerics.Tensors as an input/output type when using dotnet 8.0 and up. It does not change/remove any of the existing API, only adds additional ones. ### Motivation and Context Now that C#/Dotnet has an official tensor type built into the language, we want to expand the places that it can be used.
1 parent 3235e70 commit acc7c8d

File tree

4 files changed

+369
-2
lines changed

4 files changed

+369
-2
lines changed

csharp/src/Microsoft.ML.OnnxRuntime/Microsoft.ML.OnnxRuntime.csproj

+5-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
<Project Sdk="MSBuild.Sdk.Extras/3.0.22">
1+
<Project Sdk="Microsoft.NET.Sdk">
22
<PropertyGroup>
33
<!--- packaging properties -->
44
<OrtPackageId Condition="'$(OrtPackageId)' == ''">Microsoft.ML.OnnxRuntime</OrtPackageId>
@@ -189,6 +189,10 @@
189189
<PackageReference Include="Microsoft.SourceLink.GitHub" Version="8.0.0" PrivateAssets="All" />
190190
</ItemGroup>
191191

192+
<ItemGroup Condition="$([MSBuild]::IsTargetFrameworkCompatible('$(TargetFramework)', 'net8.0'))">
193+
<PackageReference Include="System.Numerics.Tensors" Version="9.0.0" />
194+
</ItemGroup>
195+
192196
<!-- debug output - makes finding/fixing any issues with the the conditions easy. -->
193197
<Target Name="DumpValues" BeforeTargets="PreBuildEvent">
194198
<Message Text="SolutionName='$(SolutionName)'" />

csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs

+152
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,14 @@
99
using System.Runtime.InteropServices;
1010
using System.Text;
1111

12+
#if NET8_0_OR_GREATER
13+
using System.Diagnostics.CodeAnalysis;
14+
using System.Reflection;
15+
using System.Runtime.CompilerServices;
16+
using SystemNumericsTensors = System.Numerics.Tensors;
17+
using TensorPrimitives = System.Numerics.Tensors.TensorPrimitives;
18+
#endif
19+
1220
namespace Microsoft.ML.OnnxRuntime
1321
{
1422
/// <summary>
@@ -205,6 +213,33 @@ public ReadOnlySpan<T> GetTensorDataAsSpan<T>() where T : unmanaged
205213
return MemoryMarshal.Cast<byte, T>(byteSpan);
206214
}
207215

216+
#if NET8_0_OR_GREATER
217+
/// <summary>
218+
/// Returns a ReadOnlyTensorSpan<typeparamref name="T"/> over tensor native buffer that
219+
/// provides a read-only view.
220+
///
221+
/// Note, that the memory may be device allocated and, therefore, not accessible from the CPU.
222+
/// To get memory descriptor use GetTensorMemoryInfo().
223+
///
224+
/// OrtValue must contain a non-string tensor.
225+
/// The span is valid as long as the OrtValue instance is alive (not disposed).
226+
/// </summary>
227+
/// <typeparam name="T"></typeparam>
228+
/// <returns>ReadOnlySpan<typeparamref name="T"/></returns>
229+
/// <exception cref="OnnxRuntimeException"></exception>
230+
[Experimental("SYSLIB5001")]
231+
public SystemNumericsTensors.ReadOnlyTensorSpan<T> GetTensorDataAsTensorSpan<T>() where T : unmanaged
232+
{
233+
var byteSpan = GetTensorBufferRawData(typeof(T));
234+
235+
var typeSpan = MemoryMarshal.Cast<byte, T>(byteSpan);
236+
var shape = GetTypeInfo().TensorTypeAndShapeInfo.Shape;
237+
nint[] nArray = Array.ConvertAll(shape, new Converter<long, nint>(x => (nint)x));
238+
239+
return new SystemNumericsTensors.ReadOnlyTensorSpan<T>(typeSpan, nArray, []);
240+
}
241+
#endif
242+
208243
/// <summary>
209244
/// Returns a Span<typeparamref name="T"/> over tensor native buffer.
210245
/// This enables you to safely and efficiently modify the underlying
@@ -225,6 +260,32 @@ public Span<T> GetTensorMutableDataAsSpan<T>() where T : unmanaged
225260
return MemoryMarshal.Cast<byte, T>(byteSpan);
226261
}
227262

263+
#if NET8_0_OR_GREATER
264+
/// <summary>
265+
/// Returns a TensorSpan<typeparamref name="T"/> over tensor native buffer.
266+
///
267+
/// Note, that the memory may be device allocated and, therefore, not accessible from the CPU.
268+
/// To get memory descriptor use GetTensorMemoryInfo().
269+
///
270+
/// OrtValue must contain a non-string tensor.
271+
/// The span is valid as long as the OrtValue instance is alive (not disposed).
272+
/// </summary>
273+
/// <typeparam name="T"></typeparam>
274+
/// <returns>ReadOnlySpan<typeparamref name="T"/></returns>
275+
/// <exception cref="OnnxRuntimeException"></exception>
276+
[Experimental("SYSLIB5001")]
277+
public SystemNumericsTensors.TensorSpan<T> GetTensorMutableDataAsTensorSpan<T>() where T : unmanaged
278+
{
279+
var byteSpan = GetTensorBufferRawData(typeof(T));
280+
281+
var typeSpan = MemoryMarshal.Cast<byte, T>(byteSpan);
282+
var shape = GetTypeInfo().TensorTypeAndShapeInfo.Shape;
283+
nint[] nArray = Array.ConvertAll(shape, new Converter<long, nint>(x => (nint)x));
284+
285+
return new SystemNumericsTensors.TensorSpan<T>(typeSpan, nArray, []);
286+
}
287+
#endif
288+
228289
/// <summary>
229290
/// Provides mutable raw native buffer access.
230291
/// </summary>
@@ -234,6 +295,23 @@ public Span<byte> GetTensorMutableRawData()
234295
return GetTensorBufferRawData(typeof(byte));
235296
}
236297

298+
#if NET8_0_OR_GREATER
299+
/// <summary>
300+
/// Provides mutable raw native buffer access.
301+
/// </summary>
302+
/// <returns>TensorSpan over the native buffer bytes</returns>
303+
[Experimental("SYSLIB5001")]
304+
public SystemNumericsTensors.TensorSpan<byte> GetTensorSpanMutableRawData<T>() where T : unmanaged
305+
{
306+
var byteSpan = GetTensorBufferRawData(typeof(T));
307+
308+
var shape = GetTypeInfo().TensorTypeAndShapeInfo.Shape;
309+
nint[] nArray = Array.ConvertAll(shape, new Converter<long, nint>(x => (nint)x));
310+
311+
return new SystemNumericsTensors.TensorSpan<byte>(byteSpan, nArray, []);
312+
}
313+
#endif
314+
237315
/// <summary>
238316
/// Fetch string tensor element buffer pointer at the specified index,
239317
/// convert/copy to UTF-16 char[] and return a ReadOnlyMemory{char} instance.
@@ -605,6 +683,80 @@ public static OrtValue CreateTensorValueFromMemory<T>(T[] data, long[] shape) wh
605683
return OrtValue.CreateTensorValueFromMemory(OrtMemoryInfo.DefaultInstance, new Memory<T>(data), shape);
606684
}
607685

686+
#if NET8_0_OR_GREATER
687+
/// <summary>
688+
/// This is a factory method creates a native Onnxruntime OrtValue containing a tensor on top of the existing tensor managed memory.
689+
/// The method will attempt to pin managed memory so no copying occurs when data is passed down
690+
/// to native code.
691+
/// </summary>
692+
/// <param name="value">Tensor object</param>
693+
/// <param name="elementType">discovered tensor element type</param>
694+
/// <returns>And instance of OrtValue constructed on top of the object</returns>
695+
[Experimental("SYSLIB5001")]
696+
public static OrtValue CreateTensorValueFromSystemNumericsTensorObject<T>(SystemNumericsTensors.Tensor<T> tensor) where T : unmanaged
697+
{
698+
if (!IsContiguousAndDense(tensor))
699+
{
700+
var newTensor = SystemNumericsTensors.Tensor.Create<T>(tensor.Lengths);
701+
tensor.CopyTo(newTensor);
702+
tensor = newTensor;
703+
}
704+
unsafe
705+
{
706+
var backingData = (T[])tensor.GetType().GetField("_values", BindingFlags.Instance | BindingFlags.NonPublic).GetValue(tensor);
707+
GCHandle handle = GCHandle.Alloc(backingData, GCHandleType.Pinned);
708+
var memHandle = new MemoryHandle(Unsafe.AsPointer(ref tensor.GetPinnableReference()), handle);
709+
710+
try
711+
{
712+
IntPtr dataBufferPointer = IntPtr.Zero;
713+
unsafe
714+
{
715+
dataBufferPointer = (IntPtr)memHandle.Pointer;
716+
}
717+
718+
var bufferLengthInBytes = tensor.FlattenedLength * sizeof(T);
719+
long[] shape = Array.ConvertAll(tensor.Lengths.ToArray(), new Converter<nint, long>(x => (long)x));
720+
721+
var typeInfo = TensorBase.GetTypeInfo(typeof(T)) ??
722+
throw new OnnxRuntimeException(ErrorCode.InvalidArgument, $"Tensor of type: {typeof(T)} is not supported");
723+
724+
NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateTensorWithDataAsOrtValue(
725+
OrtMemoryInfo.DefaultInstance.Pointer,
726+
dataBufferPointer,
727+
(UIntPtr)(bufferLengthInBytes),
728+
shape,
729+
(UIntPtr)tensor.Rank,
730+
typeInfo.ElementType,
731+
out IntPtr nativeValue));
732+
733+
return new OrtValue(nativeValue, memHandle);
734+
}
735+
catch (Exception)
736+
{
737+
memHandle.Dispose();
738+
throw;
739+
}
740+
}
741+
}
742+
743+
[Experimental("SYSLIB5001")]
744+
private static bool IsContiguousAndDense<T>(SystemNumericsTensors.Tensor<T> tensor) where T : unmanaged
745+
{
746+
// Right most dimension must be 1 for a dense tensor.
747+
if (tensor.Strides[^1] != 1)
748+
return false;
749+
750+
// For other dimensions, the stride must be equal to the product of the dimensions to the right.
751+
for (int i = tensor.Rank - 2; i >= 0; i--)
752+
{
753+
if (tensor.Strides[i] != TensorPrimitives.Product(tensor.Lengths.Slice(i + 1, tensor.Lengths.Length - i - 1)))
754+
return false;
755+
}
756+
return true;
757+
}
758+
#endif
759+
608760
/// <summary>
609761
/// The factory API creates an OrtValue with memory allocated using the given allocator
610762
/// according to the specified shape and element type. The memory will be released when OrtValue

0 commit comments

Comments
 (0)