Skip to content

Commit 82d8073

Browse files
committed
feat: display method for tensor
1 parent 70e31bc commit 82d8073

File tree

1 file changed

+74
-0
lines changed

1 file changed

+74
-0
lines changed

src/tensor.rs

+74
Original file line numberDiff line numberDiff line change
@@ -427,6 +427,41 @@ impl<T: Num + PartialOrd + Copy> Div<DynamicMatrix<T>> for Tensor<T> {
427427
}
428428
}
429429

430+
impl<T: Num + PartialOrd + Copy + std::fmt::Display> Tensor<T> {
431+
pub fn display(&self) -> String {
432+
fn format_tensor<T: Num + PartialOrd + Copy + std::fmt::Display>(data: &[T], shape: &[usize], level: usize) -> String {
433+
if shape.len() == 1 {
434+
let mut result = String::from("[");
435+
for (i, item) in data.iter().enumerate() {
436+
result.push_str(&format!("{}", item));
437+
if i < data.len() - 1 {
438+
result.push_str(", ");
439+
}
440+
}
441+
result.push(']');
442+
return result;
443+
}
444+
445+
let mut result = String::from("[");
446+
let sub_size = shape[1..].iter().product();
447+
for i in 0..shape[0] {
448+
if i > 0 {
449+
result.push_str(",\n");
450+
for _ in 0..level {
451+
result.push(' ');
452+
}
453+
}
454+
result.push_str(&format_tensor(&data[i * sub_size..(i + 1) * sub_size], &shape[1..], level + 1));
455+
}
456+
result.push(']');
457+
result
458+
}
459+
460+
format_tensor(&self.data, &self.shape.dims(), 1)
461+
}
462+
}
463+
464+
430465
#[cfg(test)]
431466
mod tests {
432467
use super::*;
@@ -1127,5 +1162,44 @@ mod tests {
11271162
assert_eq!(result[coord![1, 1]], 2.0);
11281163
assert_eq!(result.shape(), &shape);
11291164
}
1165+
1166+
#[test]
1167+
fn test_display_1d_tensor() {
1168+
let shape = shape![3].unwrap();
1169+
let data = vec![1.0, 2.0, 3.0];
1170+
let tensor = Tensor::new(&shape, &data).unwrap();
1171+
let display = tensor.display();
1172+
assert_eq!(display, "[1, 2, 3]");
1173+
}
1174+
1175+
#[test]
1176+
fn test_display_2d_tensor() {
1177+
let shape = shape![2, 2].unwrap();
1178+
let data = vec![1.0, 2.0, 3.0, 4.0];
1179+
let tensor = Tensor::new(&shape, &data).unwrap();
1180+
let display = tensor.display();
1181+
assert_eq!(display, "[[1, 2],\n [3, 4]]");
1182+
}
1183+
1184+
#[test]
1185+
fn test_display_3d_tensor() {
1186+
let shape = shape![2, 2, 2].unwrap();
1187+
let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
1188+
let tensor = Tensor::new(&shape, &data).unwrap();
1189+
let display = tensor.display();
1190+
assert_eq!(display, "[[[1, 2],\n [3, 4]],\n\n [[5, 6],\n [7, 8]]]");
1191+
}
1192+
1193+
#[test]
1194+
fn test_display_4d_tensor() {
1195+
let shape = shape![2, 2, 2, 2].unwrap();
1196+
let data = vec![
1197+
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
1198+
9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0
1199+
];
1200+
let tensor = Tensor::new(&shape, &data).unwrap();
1201+
let display = tensor.display();
1202+
assert_eq!(display, "[[[[1, 2],\n [3, 4]],\n\n [[5, 6],\n [7, 8]]],\n\n\n [[[9, 10],\n [11, 12]],\n\n [[13, 14],\n [15, 16]]]]");
1203+
}
11301204
}
11311205

0 commit comments

Comments
 (0)