Skip to content

Commit a69bcb9

Browse files
committed
fix: display method
1 parent 82d8073 commit a69bcb9

File tree

2 files changed

+23
-6
lines changed

2 files changed

+23
-6
lines changed

src/storage.rs

+13
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,10 @@ impl<T> DynamicStorage<T> {
3535
Ok(index)
3636
}
3737

38+
pub fn size(&self) -> usize {
39+
self.data.len()
40+
}
41+
3842
pub fn iter(&self) -> std::slice::Iter<'_, T> {
3943
self.data.iter()
4044
}
@@ -50,6 +54,15 @@ impl<T> Index<usize> for DynamicStorage<T> {
5054
fn index(&self, index: usize) -> &Self::Output {
5155
&self.data[index]
5256
}
57+
58+
}
59+
60+
impl<T> Index<std::ops::Range<usize>> for DynamicStorage<T> {
61+
type Output = [T];
62+
63+
fn index(&self, range: std::ops::Range<usize>) -> &Self::Output {
64+
&self.data[range]
65+
}
5366
}
5467

5568
impl<T> IndexMut<usize> for DynamicStorage<T> {

src/tensor.rs

+10-6
Original file line numberDiff line numberDiff line change
@@ -429,12 +429,12 @@ impl<T: Num + PartialOrd + Copy> Div<DynamicMatrix<T>> for Tensor<T> {
429429

430430
impl<T: Num + PartialOrd + Copy + std::fmt::Display> Tensor<T> {
431431
pub fn display(&self) -> String {
432-
fn format_tensor<T: Num + PartialOrd + Copy + std::fmt::Display>(data: &[T], shape: &[usize], level: usize) -> String {
433-
if shape.len() == 1 {
432+
fn format_tensor<T: Num + PartialOrd + Copy + std::fmt::Display>(data: &DynamicStorage<T>, shape: &Shape, level: usize) -> String {
433+
if shape.order() == 1 {
434434
let mut result = String::from("[");
435435
for (i, item) in data.iter().enumerate() {
436436
result.push_str(&format!("{}", item));
437-
if i < data.len() - 1 {
437+
if i < data.size() - 1 {
438438
result.push_str(", ");
439439
}
440440
}
@@ -443,21 +443,25 @@ impl<T: Num + PartialOrd + Copy + std::fmt::Display> Tensor<T> {
443443
}
444444

445445
let mut result = String::from("[");
446-
let sub_size = shape[1..].iter().product();
446+
let sub_size = Shape::new(shape[1..].to_vec()).unwrap().size();
447447
for i in 0..shape[0] {
448448
if i > 0 {
449449
result.push_str(",\n");
450+
for _ in 0..shape.order() - 2 {
451+
result.push('\n');
452+
}
450453
for _ in 0..level {
451454
result.push(' ');
452455
}
453456
}
454-
result.push_str(&format_tensor(&data[i * sub_size..(i + 1) * sub_size], &shape[1..], level + 1));
457+
let sub_data = DynamicStorage::new(data[i * sub_size..(i + 1) * sub_size].to_vec());
458+
result.push_str(&format_tensor(&sub_data, &Shape::new(shape[1..].to_vec()).unwrap(), level + 1));
455459
}
456460
result.push(']');
457461
result
458462
}
459463

460-
format_tensor(&self.data, &self.shape.dims(), 1)
464+
format_tensor(&self.data, &self.shape, 1)
461465
}
462466
}
463467

0 commit comments

Comments
 (0)