9
9
using System . Runtime . InteropServices ;
10
10
using System . Text ;
11
11
12
+ #if NET8_0_OR_GREATER
13
+ using System . Diagnostics . CodeAnalysis ;
14
+ using System . Reflection ;
15
+ using System . Runtime . CompilerServices ;
16
+ using SystemNumericsTensors = System . Numerics . Tensors ;
17
+ using TensorPrimitives = System . Numerics . Tensors . TensorPrimitives ;
18
+ #endif
19
+
12
20
namespace Microsoft . ML . OnnxRuntime
13
21
{
14
22
/// <summary>
@@ -205,6 +213,33 @@ public ReadOnlySpan<T> GetTensorDataAsSpan<T>() where T : unmanaged
205
213
return MemoryMarshal . Cast < byte , T > ( byteSpan ) ;
206
214
}
207
215
216
+ #if NET8_0_OR_GREATER
217
+ /// <summary>
218
+ /// Returns a ReadOnlyTensorSpan<typeparamref name="T"/> over tensor native buffer that
219
+ /// provides a read-only view.
220
+ ///
221
+ /// Note, that the memory may be device allocated and, therefore, not accessible from the CPU.
222
+ /// To get memory descriptor use GetTensorMemoryInfo().
223
+ ///
224
+ /// OrtValue must contain a non-string tensor.
225
+ /// The span is valid as long as the OrtValue instance is alive (not disposed).
226
+ /// </summary>
227
+ /// <typeparam name="T"></typeparam>
228
+ /// <returns>ReadOnlySpan<typeparamref name="T"/></returns>
229
+ /// <exception cref="OnnxRuntimeException"></exception>
230
+ [ Experimental ( "SYSLIB5001" ) ]
231
+ public SystemNumericsTensors . ReadOnlyTensorSpan < T > GetTensorDataAsTensorSpan < T > ( ) where T : unmanaged
232
+ {
233
+ var byteSpan = GetTensorBufferRawData ( typeof ( T ) ) ;
234
+
235
+ var typeSpan = MemoryMarshal . Cast < byte , T > ( byteSpan ) ;
236
+ var shape = GetTypeInfo ( ) . TensorTypeAndShapeInfo . Shape ;
237
+ nint [ ] nArray = Array . ConvertAll ( shape , new Converter < long , nint > ( x => ( nint ) x ) ) ;
238
+
239
+ return new SystemNumericsTensors . ReadOnlyTensorSpan < T > ( typeSpan , nArray , [ ] ) ;
240
+ }
241
+ #endif
242
+
208
243
/// <summary>
209
244
/// Returns a Span<typeparamref name="T"/> over tensor native buffer.
210
245
/// This enables you to safely and efficiently modify the underlying
@@ -225,6 +260,32 @@ public Span<T> GetTensorMutableDataAsSpan<T>() where T : unmanaged
225
260
return MemoryMarshal . Cast < byte , T > ( byteSpan ) ;
226
261
}
227
262
263
+ #if NET8_0_OR_GREATER
264
+ /// <summary>
265
+ /// Returns a TensorSpan<typeparamref name="T"/> over tensor native buffer.
266
+ ///
267
+ /// Note, that the memory may be device allocated and, therefore, not accessible from the CPU.
268
+ /// To get memory descriptor use GetTensorMemoryInfo().
269
+ ///
270
+ /// OrtValue must contain a non-string tensor.
271
+ /// The span is valid as long as the OrtValue instance is alive (not disposed).
272
+ /// </summary>
273
+ /// <typeparam name="T"></typeparam>
274
+ /// <returns>ReadOnlySpan<typeparamref name="T"/></returns>
275
+ /// <exception cref="OnnxRuntimeException"></exception>
276
+ [ Experimental ( "SYSLIB5001" ) ]
277
+ public SystemNumericsTensors . TensorSpan < T > GetTensorMutableDataAsTensorSpan < T > ( ) where T : unmanaged
278
+ {
279
+ var byteSpan = GetTensorBufferRawData ( typeof ( T ) ) ;
280
+
281
+ var typeSpan = MemoryMarshal . Cast < byte , T > ( byteSpan ) ;
282
+ var shape = GetTypeInfo ( ) . TensorTypeAndShapeInfo . Shape ;
283
+ nint [ ] nArray = Array . ConvertAll ( shape , new Converter < long , nint > ( x => ( nint ) x ) ) ;
284
+
285
+ return new SystemNumericsTensors . TensorSpan < T > ( typeSpan , nArray , [ ] ) ;
286
+ }
287
+ #endif
288
+
228
289
/// <summary>
229
290
/// Provides mutable raw native buffer access.
230
291
/// </summary>
@@ -234,6 +295,23 @@ public Span<byte> GetTensorMutableRawData()
234
295
return GetTensorBufferRawData ( typeof ( byte ) ) ;
235
296
}
236
297
298
+ #if NET8_0_OR_GREATER
299
+ /// <summary>
300
+ /// Provides mutable raw native buffer access.
301
+ /// </summary>
302
+ /// <returns>TensorSpan over the native buffer bytes</returns>
303
+ [ Experimental ( "SYSLIB5001" ) ]
304
+ public SystemNumericsTensors . TensorSpan < byte > GetTensorSpanMutableRawData < T > ( ) where T : unmanaged
305
+ {
306
+ var byteSpan = GetTensorBufferRawData ( typeof ( T ) ) ;
307
+
308
+ var shape = GetTypeInfo ( ) . TensorTypeAndShapeInfo . Shape ;
309
+ nint [ ] nArray = Array . ConvertAll ( shape , new Converter < long , nint > ( x => ( nint ) x ) ) ;
310
+
311
+ return new SystemNumericsTensors . TensorSpan < byte > ( byteSpan , nArray , [ ] ) ;
312
+ }
313
+ #endif
314
+
237
315
/// <summary>
238
316
/// Fetch string tensor element buffer pointer at the specified index,
239
317
/// convert/copy to UTF-16 char[] and return a ReadOnlyMemory{char} instance.
@@ -605,6 +683,80 @@ public static OrtValue CreateTensorValueFromMemory<T>(T[] data, long[] shape) wh
605
683
return OrtValue . CreateTensorValueFromMemory ( OrtMemoryInfo . DefaultInstance , new Memory < T > ( data ) , shape ) ;
606
684
}
607
685
686
+ #if NET8_0_OR_GREATER
687
+ /// <summary>
688
+ /// This is a factory method creates a native Onnxruntime OrtValue containing a tensor on top of the existing tensor managed memory.
689
+ /// The method will attempt to pin managed memory so no copying occurs when data is passed down
690
+ /// to native code.
691
+ /// </summary>
692
+ /// <param name="value">Tensor object</param>
693
+ /// <param name="elementType">discovered tensor element type</param>
694
+ /// <returns>And instance of OrtValue constructed on top of the object</returns>
695
+ [ Experimental ( "SYSLIB5001" ) ]
696
+ public static OrtValue CreateTensorValueFromSystemNumericsTensorObject < T > ( SystemNumericsTensors . Tensor < T > tensor ) where T : unmanaged
697
+ {
698
+ if ( ! IsContiguousAndDense ( tensor ) )
699
+ {
700
+ var newTensor = SystemNumericsTensors . Tensor . Create < T > ( tensor . Lengths ) ;
701
+ tensor . CopyTo ( newTensor ) ;
702
+ tensor = newTensor ;
703
+ }
704
+ unsafe
705
+ {
706
+ var backingData = ( T [ ] ) tensor . GetType ( ) . GetField ( "_values" , BindingFlags . Instance | BindingFlags . NonPublic ) . GetValue ( tensor ) ;
707
+ GCHandle handle = GCHandle . Alloc ( backingData , GCHandleType . Pinned ) ;
708
+ var memHandle = new MemoryHandle ( Unsafe . AsPointer ( ref tensor . GetPinnableReference ( ) ) , handle ) ;
709
+
710
+ try
711
+ {
712
+ IntPtr dataBufferPointer = IntPtr . Zero ;
713
+ unsafe
714
+ {
715
+ dataBufferPointer = ( IntPtr ) memHandle . Pointer ;
716
+ }
717
+
718
+ var bufferLengthInBytes = tensor . FlattenedLength * sizeof ( T ) ;
719
+ long [ ] shape = Array . ConvertAll ( tensor . Lengths . ToArray ( ) , new Converter < nint , long > ( x => ( long ) x ) ) ;
720
+
721
+ var typeInfo = TensorBase . GetTypeInfo ( typeof ( T ) ) ??
722
+ throw new OnnxRuntimeException ( ErrorCode . InvalidArgument , $ "Tensor of type: { typeof ( T ) } is not supported") ;
723
+
724
+ NativeApiStatus . VerifySuccess ( NativeMethods . OrtCreateTensorWithDataAsOrtValue (
725
+ OrtMemoryInfo . DefaultInstance . Pointer ,
726
+ dataBufferPointer ,
727
+ ( UIntPtr ) ( bufferLengthInBytes ) ,
728
+ shape ,
729
+ ( UIntPtr ) tensor . Rank ,
730
+ typeInfo . ElementType ,
731
+ out IntPtr nativeValue ) ) ;
732
+
733
+ return new OrtValue ( nativeValue , memHandle ) ;
734
+ }
735
+ catch ( Exception )
736
+ {
737
+ memHandle . Dispose ( ) ;
738
+ throw ;
739
+ }
740
+ }
741
+ }
742
+
743
+ [ Experimental ( "SYSLIB5001" ) ]
744
+ private static bool IsContiguousAndDense < T > ( SystemNumericsTensors . Tensor < T > tensor ) where T : unmanaged
745
+ {
746
+ // Right most dimension must be 1 for a dense tensor.
747
+ if ( tensor . Strides [ ^ 1 ] != 1 )
748
+ return false ;
749
+
750
+ // For other dimensions, the stride must be equal to the product of the dimensions to the right.
751
+ for ( int i = tensor . Rank - 2 ; i >= 0 ; i -- )
752
+ {
753
+ if ( tensor . Strides [ i ] != TensorPrimitives . Product ( tensor . Lengths . Slice ( i + 1 , tensor . Lengths . Length - i - 1 ) ) )
754
+ return false ;
755
+ }
756
+ return true ;
757
+ }
758
+ #endif
759
+
608
760
/// <summary>
609
761
/// The factory API creates an OrtValue with memory allocated using the given allocator
610
762
/// according to the specified shape and element type. The memory will be released when OrtValue
0 commit comments