Skip to content

Commit

Permalink
Added logic to handle modules and user-defined functions
Browse files Browse the repository at this point in the history
  • Loading branch information
thequackdaddy committed Nov 3, 2018
1 parent 5f662a9 commit 807cc93
Showing 1 changed file with 24 additions and 2 deletions.
26 changes: 24 additions & 2 deletions patsy/design_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -685,7 +685,7 @@ def var_names(self, eval_env=0):
else:
return {}

def partial(self, columns, product=False):
def partial(self, columns, product=False, eval_env=0):
"""Returns a partial prediction array where only the variables in the
dict ``columns`` are tranformed per the :class:`DesignInfo`
transformations. The terms that are not influenced by ``columns``
Expand All @@ -703,6 +703,18 @@ def partial(self, columns, product=False):
:returns: A numpy array of the partial design matrix.
"""
from .highlevel import dmatrix
from types import ModuleType

if not eval_env:
from patsy.eval import EvalEnvironment
eval_env = EvalEnvironment.capture(eval_env, reference=1)

# We need to get rid of the non-callable items from the eval_env
namespaces = [{key: value} for ns in eval_env._namespaces
for key, value in six.iteritems(ns)
if callable(value) or isinstance(value, ModuleType)]
eval_env._namespaces = namespaces

if product:
columns = _column_product(columns)
rows = None
Expand All @@ -712,7 +724,7 @@ def partial(self, columns, product=False):
rows = len(columns[col])
parts = []
for term, subterm in six.iteritems(self.term_codings):
term_vars = term.var_names()
term_vars = term.var_names(eval_env)
present = True
for term_var in term_vars:
if term_var not in columns:
Expand Down Expand Up @@ -1312,6 +1324,16 @@ def test_DesignInfo_partial():
assert_raises(ValueError, dm.design_info.partial, {'a': ['a', 'b'],
'b': [1, 2, 3]})

def some_function(x):
return np.where(x > 2, 1, 2)

dm = dmatrix('1 + some_function(c)')
x = np.array([[0, 2],
[0, 2],
[0, 1]])
y = dm.design_info.partial({'c': np.array([1, 2, 3])})
assert_allclose(x, y)


def _column_product(columns):
from itertools import product
Expand Down

0 comments on commit 807cc93

Please sign in to comment.