diff --git a/einops/_backends.py b/einops/_backends.py
index 40e0502a..2f930bc1 100644
--- a/einops/_backends.py
+++ b/einops/_backends.py
@@ -660,3 +660,60 @@ def einsum(self, pattern, *x):
 
     def shape(self, x):
         return tuple(x.shape)
+
+
+class TinygradBackend(AbstractBackend):
+    framework_name = "tinygrad"
+
+    def __init__(self):
+        import tinygrad
+
+        self.tinygrad = tinygrad
+
+    def is_appropriate_type(self, tensor):
+        return isinstance(tensor, self.tinygrad.Tensor)
+
+    def from_numpy(self, x):
+        return self.tinygrad.Tensor(x)
+
+    def to_numpy(self, x):
+        return x.numpy()
+
+    def arange(self, start, stop):
+        return self.tinygrad.Tensor.arange(start, stop)
+
+    def shape(self, x):
+        return x.shape
+
+    def reshape(self, x, shape):
+        return x.reshape(shape)
+
+    def transpose(self, x, axes):
+        return x.permute(axes)
+
+    def reduce(self, x, operation, axes):
+        # tinygrad reductions take a single axis; reduce the innermost
+        # axis first so the remaining axis indices stay valid
+        for axis in sorted(axes, reverse=True):
+            x = getattr(x, operation)(axis=axis)
+        return x
+
+    def stack_on_zeroth_dimension(self, tensors: list):
+        # Tensor.stack takes the tensors as *args, not as a list
+        return self.tinygrad.Tensor.stack(*tensors)
+
+    def add_axis(self, x, new_position):
+        return x.unsqueeze(new_position)
+
+    def tile(self, x, repeats):
+        return x.repeat(repeats)
+
+    def concat(self, tensors, axis: int):
+        # cat is a method of the first tensor; dim is keyword-only after *args
+        return tensors[0].cat(*tensors[1:], dim=axis) if len(tensors) > 1 else tensors[0]
+
+    def is_float_type(self, x):
+        return self.tinygrad.dtypes.is_float(x.dtype)
+
+    def einsum(self, pattern, *x):
+        return self.tinygrad.Tensor.einsum(pattern, *x)
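
For reviewers, a quick smoke test of the new backend (a minimal sketch, assuming tinygrad is installed; einops selects the backend automatically from the tensor type via is_appropriate_type, as with the other backends):

import numpy as np
from tinygrad import Tensor
from einops import rearrange, reduce, repeat

x = Tensor(np.arange(24, dtype=np.float32).reshape(2, 3, 4))

y = rearrange(x, "b h w -> b (h w)")    # exercises reshape: (2, 3, 4) -> (2, 12)
z = reduce(x, "b h w -> b", "max")      # exercises multi-axis reduce via Tensor.max
r = repeat(x, "b h w -> b h w c", c=2)  # exercises add_axis + tile

print(y.shape, z.shape, r.shape)        # (2, 12) (2,) (2, 3, 4, 2)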