Skip to content

Commit

Permalink
add tinygrad, or at least the straightforward part
Browse files Browse the repository at this point in the history
  • Loading branch information
blueridanus committed Dec 7, 2023
1 parent a6e9353 commit 789d1fd
Showing 1 changed file with 52 additions and 0 deletions.
52 changes: 52 additions & 0 deletions einops/_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -660,3 +660,55 @@ def einsum(self, pattern, *x):

def shape(self, x):
return tuple(x.shape)

class TinygradBackend(AbstractBackend):
framework_name = "tinygrad"

def __init__(self):
import tinygrad

self.tinygrad = tinygrad

def is_appropriate_type(self, tensor):
return isinstance(tensor, self.tinygrad.Tensor)

def from_numpy(self, x):
return self.tinygrad.Tensor(x)

def to_numpy(self, x):
return x.numpy()

def arange(self, start, stop):
return self.tinygrad.Tensor.arange(start, stop)

def shape(self, x):
return x.shape

def reshape(self, x, shape):
return x.reshape(shape)

def transpose(self, x, axes):
return x.permute(axes)

def reduce(self, x, operation, axes):
for axis in sorted(axes, reverse=True):
x = getattr(x, operation)(axis=axis)
return x

def stack_on_zeroth_dimension(self, tensors: list):
return self.tinygrad.stack(tensors)

def add_axis(self, x, new_position):
return x.unsqueeze(new_position)

def tile(self, x, repeats):
return x.repeat(repeats)

def concat(self, tensors, axis: int):
return tensors[0].cat(tensors[1:], axis) if len(tensors) > 1 else tensors[0]

def is_float_type(self, x):
return self.dtypes.is_float(x.dtype)

def __repr__(self):
return "<einops backend for {}>".format(self.framework_name)

0 comments on commit 789d1fd

Please sign in to comment.