Skip to content

Commit 0ae9d2c

Browse files
committed
Relax tensorflow version limit
1 parent f73adc8 commit 0ae9d2c

File tree

3 files changed

+33
-9
lines changed

3 files changed

+33
-9
lines changed

econml/iv/nnet/_deepiv.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -89,9 +89,13 @@ def mog_loss_model(n_components, d_t):
8989
# LL = C - log(sum(pi_i/sig^d * exp(-d2/(2*sig^2))))
9090
# Use logsumexp for numeric stability:
9191
# LL = C - log(sum(exp(-d2/(2*sig^2) + log(pi_i/sig^d))))
92-
# TODO: does the numeric stability actually make any difference?
9392
def make_logloss(d2, sig, pi):
94-
return -K.logsumexp(-d2 / (2 * K.square(sig)) + K.log(pi / K.pow(sig, d_t)), axis=-1)
93+
# logsumexp doesn't exist in keras 2.4; simulate it
94+
values = - d2 / (2 * K.square(sig)) + K.log(pi / K.pow(sig, d_t))
95+
# logsumexp(a,b,c) = log(exp(a)+exp(b)+exp(c)) = log((exp(a-k)+exp(b-k)+exp(c-k))*exp(k))
96+
# = log((exp(a-k)+exp(b-k)+exp(c-k))) + k
97+
mx = K.max(values, axis=-1)
98+
return -K.log(K.sum(K.exp(values - L.Reshape((-1, 1))(mx)), axis=-1)) - mx
9599

96100
ll = L.Lambda(lambda dsp: make_logloss(*dsp), output_shape=(1,))([d2, sig, pi])
97101

@@ -346,7 +350,7 @@ def fit(self, Y, T, X, Z, *, inference=None):
346350
model.add_loss(L.Lambda(K.mean)(ll))
347351
model.compile(self._optimizer)
348352
# TODO: do we need to give the user more control over other arguments to fit?
349-
model.fit([Z, X, T], [], **self._first_stage_options)
353+
model.fit([Z, X, T], **self._first_stage_options)
350354

351355
lm = response_loss_model(lambda t, x: self._h(t, x),
352356
lambda z, x: Model([z_in, x_in],
@@ -361,7 +365,7 @@ def fit(self, Y, T, X, Z, *, inference=None):
361365
response_model.add_loss(L.Lambda(K.mean)(rl))
362366
response_model.compile(self._optimizer)
363367
# TODO: do we need to give the user more control over other arguments to fit?
364-
response_model.fit([Z, X, Y], [], **self._second_stage_options)
368+
response_model.fit([Z, X, Y], **self._second_stage_options)
365369

366370
self._effect_model = Model([t_in, x_in], [self._h(t_in, x_in)])
367371

econml/tests/test_deepiv.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,27 @@ def test_stop_grad(self):
3232
model = keras.Model([x_input, y_input, z_input], [loss])
3333
model.add_loss(K.mean(loss))
3434
model.compile('nadam')
35-
model.fit([np.array([[1]]), np.array([[2]]), np.array([[0]])], [])
35+
model.fit([np.array([[1]]), np.array([[2]]), np.array([[0]])])
36+
37+
def test_mog_loss(self):
38+
inputs = [keras.layers.Input(shape=s) for s in [(3,), (3, 2), (3,), (2,)]]
39+
ll_model = keras.engine.Model(inputs, mog_loss_model(3, 2)(inputs))
40+
41+
for n in range(10):
42+
ps = -np.log(np.random.uniform(size=(3,)))
43+
pi = ps / np.sum(ps)
44+
mu = np.random.normal(size=(3, 2))
45+
sig = np.exp(np.random.normal(size=3,))
46+
t = np.random.normal(size=(2,))
47+
48+
pred = ll_model.predict([pi.reshape(1, 3), mu.reshape(1, 3, 2), sig.reshape(1, 3), t.reshape(1, 2)])
49+
50+
# LL = C - log(sum(pi_i/sig^d * exp(-d2/(2*sig^2))))
51+
d = mu - t.reshape(-1, 2)
52+
d2 = np.sum(d * d, axis=-1)
53+
ll = -np.log(np.sum(pi / (sig * sig) * np.exp(-d2 / (2 * sig * sig)), axis=0))
54+
55+
assert np.allclose(ll, pred[0])
3656

3757
@pytest.mark.slow
3858
def test_deepiv_shape(self):
@@ -500,7 +520,7 @@ def norm(lr):
500520
model = keras.engine.Model([x_input, t_input], [ll])
501521
model.add_loss(K.mean(ll))
502522
model.compile('nadam')
503-
model.fit([x, t], [], epochs=5)
523+
model.fit([x, t], epochs=5)
504524

505525
# For some reason this doesn't work at all when run against the CNTK backend...
506526
# model.compile('nadam', loss=lambda _,l:l)
@@ -559,7 +579,7 @@ def sample(n):
559579
model = keras.engine.Model([x_input, t_input], [ll])
560580
model.add_loss(K.mean(ll))
561581
model.compile('nadam')
562-
model.fit([x, t], [], epochs=100)
582+
model.fit([x, t], epochs=100)
563583

564584
model2 = keras.engine.Model([x_input], [pi, mu, sig])
565585
import matplotlib

setup.cfg

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,9 @@ install_requires =
4343
numpy
4444
scipy > 1.4.0
4545
scikit-learn >= 0.24
46-
keras < 2.4
46+
keras
4747
sparse
48-
tensorflow > 1.10, < 2.3
48+
tensorflow > 1.10
4949
joblib >= 0.13.0
5050
numba != 0.42.1
5151
statsmodels >= 0.9

0 commit comments

Comments
 (0)