Skip to content

Commit

Permalink
simplify ArrayVspace._product implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
mattjj committed Aug 27, 2017
1 parent 3c14983 commit 6a216a1
Showing 1 changed file with 11 additions and 12 deletions.
23 changes: 11 additions & 12 deletions autograd/numpy/numpy_vspaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,18 +27,17 @@ def _inner_prod(self, x, y):
return np.dot(np.ravel(x), np.ravel(y))

def _product(self, other_vspace):
product_vspace = self.__new__(self.__class__)
product_vspace.shape = self.shape + other_vspace.shape
product_vspace.dtype = np.promote_types(self.dtype, other_vspace.dtype)
return product_vspace

def _contract(self, other_vspace):
ndim = other_vspace.ndim
if not self.shape[-ndim:] == other_vspace.shape[:ndim]: raise ValueError
contraction_vspace = self.__new__(self.__class__)
contraction_vspace.shape = self.shape[:-ndim] + other_vspace.shape[ndim:]
contraction_vspace.dtype = np.promote_types(self.dtype, other_vspace.dtype)
return contraction_vspace
return self._contract(other_vspace, ndim=0)

def _contract(self, other_vspace, ndim=None):
ndim = other_vspace.ndim if ndim is None else ndim
if not self.shape[-ndim % self.ndim:] == other_vspace.shape[:ndim]:
raise ValueError

result = self.__new__(self.__class__)
result.shape = self.shape[:-ndim % self.ndim] + other_vspace.shape[ndim:]
result.dtype = np.promote_types(self.dtype, other_vspace.dtype)
return result

def _kronecker_tensor(self):
return np.reshape(np.eye(self.size), self.shape + self.shape)
Expand Down

0 comments on commit 6a216a1

Please sign in to comment.