@@ -427,6 +427,41 @@ impl<T: Num + PartialOrd + Copy> Div<DynamicMatrix<T>> for Tensor<T> {
427
427
}
428
428
}
429
429
430
+ impl < T : Num + PartialOrd + Copy + std:: fmt:: Display > Tensor < T > {
431
+ pub fn display ( & self ) -> String {
432
+ fn format_tensor < T : Num + PartialOrd + Copy + std:: fmt:: Display > ( data : & [ T ] , shape : & [ usize ] , level : usize ) -> String {
433
+ if shape. len ( ) == 1 {
434
+ let mut result = String :: from ( "[" ) ;
435
+ for ( i, item) in data. iter ( ) . enumerate ( ) {
436
+ result. push_str ( & format ! ( "{}" , item) ) ;
437
+ if i < data. len ( ) - 1 {
438
+ result. push_str ( ", " ) ;
439
+ }
440
+ }
441
+ result. push ( ']' ) ;
442
+ return result;
443
+ }
444
+
445
+ let mut result = String :: from ( "[" ) ;
446
+ let sub_size = shape[ 1 ..] . iter ( ) . product ( ) ;
447
+ for i in 0 ..shape[ 0 ] {
448
+ if i > 0 {
449
+ result. push_str ( ",\n " ) ;
450
+ for _ in 0 ..level {
451
+ result. push ( ' ' ) ;
452
+ }
453
+ }
454
+ result. push_str ( & format_tensor ( & data[ i * sub_size..( i + 1 ) * sub_size] , & shape[ 1 ..] , level + 1 ) ) ;
455
+ }
456
+ result. push ( ']' ) ;
457
+ result
458
+ }
459
+
460
+ format_tensor ( & self . data , & self . shape . dims ( ) , 1 )
461
+ }
462
+ }
463
+
464
+
430
465
#[ cfg( test) ]
431
466
mod tests {
432
467
use super :: * ;
@@ -1127,5 +1162,44 @@ mod tests {
1127
1162
assert_eq ! ( result[ coord![ 1 , 1 ] ] , 2.0 ) ;
1128
1163
assert_eq ! ( result. shape( ) , & shape) ;
1129
1164
}
1165
+
1166
+ #[ test]
1167
+ fn test_display_1d_tensor ( ) {
1168
+ let shape = shape ! [ 3 ] . unwrap ( ) ;
1169
+ let data = vec ! [ 1.0 , 2.0 , 3.0 ] ;
1170
+ let tensor = Tensor :: new ( & shape, & data) . unwrap ( ) ;
1171
+ let display = tensor. display ( ) ;
1172
+ assert_eq ! ( display, "[1, 2, 3]" ) ;
1173
+ }
1174
+
1175
+ #[ test]
1176
+ fn test_display_2d_tensor ( ) {
1177
+ let shape = shape ! [ 2 , 2 ] . unwrap ( ) ;
1178
+ let data = vec ! [ 1.0 , 2.0 , 3.0 , 4.0 ] ;
1179
+ let tensor = Tensor :: new ( & shape, & data) . unwrap ( ) ;
1180
+ let display = tensor. display ( ) ;
1181
+ assert_eq ! ( display, "[[1, 2],\n [3, 4]]" ) ;
1182
+ }
1183
+
1184
+ #[ test]
1185
+ fn test_display_3d_tensor ( ) {
1186
+ let shape = shape ! [ 2 , 2 , 2 ] . unwrap ( ) ;
1187
+ let data = vec ! [ 1.0 , 2.0 , 3.0 , 4.0 , 5.0 , 6.0 , 7.0 , 8.0 ] ;
1188
+ let tensor = Tensor :: new ( & shape, & data) . unwrap ( ) ;
1189
+ let display = tensor. display ( ) ;
1190
+ assert_eq ! ( display, "[[[1, 2],\n [3, 4]],\n \n [[5, 6],\n [7, 8]]]" ) ;
1191
+ }
1192
+
1193
+ #[ test]
1194
+ fn test_display_4d_tensor ( ) {
1195
+ let shape = shape ! [ 2 , 2 , 2 , 2 ] . unwrap ( ) ;
1196
+ let data = vec ! [
1197
+ 1.0 , 2.0 , 3.0 , 4.0 , 5.0 , 6.0 , 7.0 , 8.0 ,
1198
+ 9.0 , 10.0 , 11.0 , 12.0 , 13.0 , 14.0 , 15.0 , 16.0
1199
+ ] ;
1200
+ let tensor = Tensor :: new ( & shape, & data) . unwrap ( ) ;
1201
+ let display = tensor. display ( ) ;
1202
+ assert_eq ! ( display, "[[[[1, 2],\n [3, 4]],\n \n [[5, 6],\n [7, 8]]],\n \n \n [[[9, 10],\n [11, 12]],\n \n [[13, 14],\n [15, 16]]]]" ) ;
1203
+ }
1130
1204
}
1131
1205
0 commit comments