Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Numpy-like arrays #397

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ python:
- "2.7"
- "3.6"
env:
- DEPS="pip nose future numpy scipy"
- DEPS="pip nose future numpy"
- DEPS="pip nose future numpy scipy dask[array]"
- DEPS="pip nose future numpy dask[array]"
before_install:
- if [[ "$TRAVIS_PYTHON_VERSION" == "2.7" ]]; then
wget https://repo.continuum.io/miniconda/Miniconda2-latest-Linux-x86_64.sh -O miniconda.sh;
Expand All @@ -21,6 +21,6 @@ before_install:
install:
- conda install --yes python=$TRAVIS_PYTHON_VERSION $DEPS
- pip install -v .
script:
script:
- cd tests # Run from inside tests directory to make sure Autograd has
- nosetests # fully installed.
4 changes: 2 additions & 2 deletions autograd/core.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from itertools import count
from functools import reduce
from .tracer import trace, primitive, toposort, Node, Box, isbox, getval
from .util import func, subval
from .util import func, subval, typeof

# -------------------- reverse mode --------------------

Expand Down Expand Up @@ -230,7 +230,7 @@ def register(cls, value_type, vspace_maker=None):

def vspace(value):
try:
return VSpace.mappings[type(value)](value)
return VSpace.mappings[typeof(value)](value)
except KeyError:
if isbox(value):
return vspace(getval(value))
Expand Down
3 changes: 2 additions & 1 deletion autograd/numpy/numpy_vspaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@

class ArrayVSpace(VSpace):
def __init__(self, value):
value = np.array(value, copy=False)
if not hasattr(value, 'shape') or not hasattr(value, 'dtype'):
value = np.array(value, copy=False)
self.shape = value.shape
self.dtype = value.dtype

Expand Down
6 changes: 4 additions & 2 deletions autograd/tracer.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import warnings
from contextlib import contextmanager
from collections import defaultdict
from .util import subvals, toposort
from .util import subvals, toposort, typeof
from .wrap_util import wraps

import numpy

def trace(start_node, fun, x):
with trace_stack.new_trace() as t:
start_box = new_box(x, t, start_node)
Expand Down Expand Up @@ -115,7 +117,7 @@ def register(cls, value_type):
box_type_mappings = Box.type_mappings
def new_box(value, trace, node):
try:
return box_type_mappings[type(value)](value, trace, node)
return box_type_mappings[typeof(value)](value, trace, node)
except KeyError:
raise TypeError("Can't differentiate w.r.t. type {}".format(type(value)))

Expand Down
14 changes: 14 additions & 0 deletions autograd/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,20 @@ def toposort(end_node, parents=operator.attrgetter('parents')):
else:
child_counts[parent] -= 1


def typeof(x):
"""
A Modified type function that returns np.ndarray for any array-like

This improves portability of autograd to other projects that might support
the numpy API, despite not being exactly numpy.
"""
if all(hasattr(x, attr) for attr in ['__array_ufunc__', 'shape', 'dtype']):
import numpy
return numpy.ndarray
else:
return type(x)

# -------------------- deprecation warnings -----------------------

import warnings
Expand Down
28 changes: 28 additions & 0 deletions tests/test_numpy_like.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from __future__ import absolute_import
import warnings

import autograd.numpy as np
import autograd.numpy.random as npr
from autograd.test_util import check_grads
from autograd import grad
from numpy_utils import combo_check

from dask.array.utils import assert_eq
import dask.array as da

npr.seed(1)

def test_dask():
x = np.arange(10)
xx = da.arange(10, chunks=(5,))

assert_eq(x, xx)

def f(x):
return np.sin(x).sum()

f_prime = grad(f)

assert isinstance(f_prime(xx), type(xx))

assert_eq(f_prime(x), f_prime(xx))