Skip to content

Commit b93aaa5

Browse files
committed
Add support for dynamic rank, see #4
1 parent f0a3ac0 commit b93aaa5

22 files changed

+1195
-837
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@ Here are the main features of mdarray:
1414
- Standard Rust mechanisms are used for e.g. indexing and iteration.
1515
- Generic expressions for multidimensional iteration.
1616

17-
The design is inspired from other Rust crates (ndarray, nalgebra, bitvec
18-
and dfdx), the proposed C++ mdarray and mdspan types, and multidimensional
17+
The design is inspired from other Rust crates (ndarray, nalgebra, bitvec, dfdx
18+
and candle), the proposed C++ mdarray and mdspan types, and multidimensional
1919
arrays in other languages.
2020

2121
## License

src/array.rs

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ use crate::expression::Expression;
1212
use crate::index::SliceIndex;
1313
use crate::iter::Iter;
1414
use crate::layout::{Dense, Layout};
15-
use crate::shape::{ConstShape, IntoShape, Shape};
15+
use crate::shape::{ConstShape, Shape};
1616
use crate::slice::Slice;
1717
use crate::tensor::Tensor;
1818
use crate::traits::{Apply, FromExpression, IntoExpression};
@@ -25,16 +25,16 @@ pub struct Array<T, S: ConstShape>(pub S::Inner<T>);
2525

2626
impl<T, S: ConstShape> Array<T, S> {
2727
/// Creates an array from the given element.
28-
pub fn from_elem<I: IntoShape<IntoShape = S>>(shape: I, elem: T) -> Self
28+
pub fn from_elem(elem: T) -> Self
2929
where
3030
T: Clone,
3131
{
32-
Self::from_expr(expr::from_elem(shape, elem))
32+
Self::from_expr(expr::from_elem(S::default(), elem))
3333
}
3434

3535
/// Creates an array with the results from the given function.
36-
pub fn from_fn<I: IntoShape<IntoShape = S>, F: FnMut(S::Dims) -> T>(shape: I, f: F) -> Self {
37-
Self::from_expr(expr::from_fn(shape, f))
36+
pub fn from_fn<F: FnMut(&[usize]) -> T>(f: F) -> Self {
37+
Self::from_expr(expr::from_fn(S::default(), f))
3838
}
3939

4040
/// Converts an array with a single element into the contained value.
@@ -45,19 +45,16 @@ impl<T, S: ConstShape> Array<T, S> {
4545
pub fn into_scalar(self) -> T {
4646
assert!(self.len() == 1, "invalid length");
4747

48-
self.into_shape(()).0
48+
self.into_shape::<()>().0
4949
}
5050

5151
/// Converts the array into a reshaped array, which must have the same length.
5252
///
5353
/// # Panics
5454
///
5555
/// Panics if the array length is changed.
56-
pub fn into_shape<I>(self, shape: I) -> Array<T, I::IntoShape>
57-
where
58-
I: IntoShape<IntoShape: ConstShape>,
59-
{
60-
assert!(shape.into_shape().len() == self.len(), "length must not change");
56+
pub fn into_shape<I: ConstShape>(self) -> Array<T, I> {
57+
assert!(I::default().len() == self.len(), "length must not change");
6158

6259
let me = ManuallyDrop::new(self);
6360

@@ -75,7 +72,7 @@ impl<T, S: ConstShape> Array<T, S> {
7572
index: usize,
7673
}
7774

78-
impl<'a, T, S: ConstShape> Drop for DropGuard<'a, T, S> {
75+
impl<T, S: ConstShape> Drop for DropGuard<'_, T, S> {
7976
fn drop(&mut self) {
8077
let ptr = self.array.as_mut_ptr() as *mut T;
8178

@@ -85,7 +82,8 @@ impl<T, S: ConstShape> Array<T, S> {
8582
}
8683
}
8784

88-
assert!(expr.dims()[..] == S::default().dims()[..], "invalid shape");
85+
// Ensure that the shape is valid.
86+
_ = expr.shape().with_dims(|dims| S::from_dims(dims));
8987

9088
let mut array = MaybeUninit::uninit();
9189
let mut guard = DropGuard { array: &mut array, index: 0 };

src/dim.rs

Lines changed: 28 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,8 @@
11
use std::fmt::{Debug, Formatter, Result};
2-
3-
use std::ops::{
4-
Bound, IndexMut, Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive,
5-
};
2+
use std::hash::Hash;
63

74
/// Array dimension trait.
8-
pub trait Dim: Copy + Debug + Default + Send + Sync {
5+
pub trait Dim: Copy + Debug + Default + Eq + Hash + Send + Sync {
96
/// Merge dimensions, where constant size is preferred over dynamic.
107
type Merge<D: Dim>: Dim;
118

@@ -23,50 +20,28 @@ pub trait Dim: Copy + Debug + Default + Send + Sync {
2320
fn size(self) -> usize;
2421
}
2522

26-
/// Array dimensions trait.
27-
pub trait Dims:
28-
Copy
29-
+ Debug
30-
+ Default
31-
+ IndexMut<(Bound<usize>, Bound<usize>), Output = [usize]>
32-
+ IndexMut<usize, Output = usize>
33-
+ IndexMut<Range<usize>, Output = [usize]>
34-
+ IndexMut<RangeFrom<usize>, Output = [usize]>
35-
+ IndexMut<RangeFull, Output = [usize]>
36-
+ IndexMut<RangeInclusive<usize>, Output = [usize]>
37-
+ IndexMut<RangeTo<usize>, Output = [usize]>
38-
+ IndexMut<RangeToInclusive<usize>, Output = [usize]>
39-
+ Send
40-
+ Sync
41-
+ for<'a> TryFrom<&'a [usize], Error: Debug>
42-
{
43-
}
44-
45-
/// Array strides trait.
46-
pub trait Strides:
47-
Copy
23+
#[allow(unreachable_pub)]
24+
pub trait Dims<T: Copy + Debug + Default + Eq + Hash + Send + Sync>:
25+
AsMut<[T]>
26+
+ AsRef<[T]>
27+
+ Clone
4828
+ Debug
4929
+ Default
50-
+ IndexMut<(Bound<usize>, Bound<usize>), Output = [isize]>
51-
+ IndexMut<usize, Output = isize>
52-
+ IndexMut<Range<usize>, Output = [isize]>
53-
+ IndexMut<RangeFrom<usize>, Output = [isize]>
54-
+ IndexMut<RangeFull, Output = [isize]>
55-
+ IndexMut<RangeInclusive<usize>, Output = [isize]>
56-
+ IndexMut<RangeTo<usize>, Output = [isize]>
57-
+ IndexMut<RangeToInclusive<usize>, Output = [isize]>
30+
+ Eq
31+
+ Hash
5832
+ Send
5933
+ Sync
60-
+ for<'a> TryFrom<&'a [isize], Error: Debug>
34+
+ for<'a> TryFrom<&'a [T], Error: Debug>
6135
{
36+
fn new(len: usize) -> Self;
6237
}
6338

6439
/// Type-level constant.
65-
#[derive(Clone, Copy, Default)]
40+
#[derive(Clone, Copy, Default, Eq, Hash, PartialEq)]
6641
pub struct Const<const N: usize>;
6742

6843
/// Dynamically-sized dimension type.
69-
#[derive(Clone, Copy, Debug, Default)]
44+
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
7045
pub struct Dyn(pub usize);
7146

7247
impl<const N: usize> Debug for Const<N> {
@@ -105,13 +80,24 @@ impl Dim for Dyn {
10580
}
10681
}
10782

108-
macro_rules! impl_dims_strides {
83+
macro_rules! impl_dims {
10984
($($n:tt),+) => {
11085
$(
111-
impl Dims for [usize; $n] {}
112-
impl Strides for [isize; $n] {}
86+
impl<T: Copy + Debug + Default + Eq + Hash + Send + Sync> Dims<T> for [T; $n] {
87+
fn new(len: usize) -> Self {
88+
assert!(len == $n, "invalid length");
89+
90+
Self::default()
91+
}
92+
}
11393
)+
11494
};
11595
}
11696

117-
impl_dims_strides!(0, 1, 2, 3, 4, 5, 6);
97+
impl_dims!(0, 1, 2, 3, 4, 5, 6);
98+
99+
impl<T: Copy + Debug + Default + Eq + Hash + Send + Sync> Dims<T> for Box<[T]> {
100+
fn new(len: usize) -> Self {
101+
vec![T::default(); len].into()
102+
}
103+
}

0 commit comments

Comments
 (0)