forked from AMLab-Amsterdam/CEVAE
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path evaluation.py
43 lines (34 loc) · 1.48 KB
/
evaluation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import numpy as np
class Evaluator(object):
    """Compute evaluation metrics for treatment-effect predictions.

    Holds the observed outcomes ``y`` and treatment indicators ``t`` and,
    optionally, the counterfactual outcomes ``y_cf`` and the potential
    outcomes ``mu0`` / ``mu1``. When both ``mu0`` and ``mu1`` are supplied,
    the true individual treatment effect ``mu1 - mu0`` is cached in
    ``true_ite``; otherwise ``true_ite`` is ``None`` and the metrics that
    need it raise ``ValueError``.

    NOTE(review): inputs are assumed to be NumPy arrays of matching shape
    (presumably one entry per unit) and ``t`` is assumed to hold 0/1
    indicators -- confirm against callers.
    """

    def __init__(self, y, t, y_cf=None, mu0=None, mu1=None):
        self.y = y
        self.t = t
        self.y_cf = y_cf
        self.mu0 = mu0
        self.mu1 = mu1
        # Always define the attribute so metrics can report a clear error
        # instead of an AttributeError when the ground truth is missing.
        self.true_ite = None
        if mu0 is not None and mu1 is not None:
            self.true_ite = mu1 - mu0

    def _require_true_ite(self):
        """Return ``true_ite``, raising ValueError if mu0/mu1 were not given."""
        if self.true_ite is None:
            raise ValueError(
                "this metric requires the true potential outcomes; "
                "construct Evaluator with both mu0 and mu1")
        return self.true_ite

    def rmse_ite(self, ypred1, ypred0):
        """Return the RMSE of the estimated individual treatment effect.

        The estimate combines each unit's observed outcome with the
        predicted outcome of the other arm:

        - for units with ``t == 1``: ``y - ypred0``;
        - for units with ``t == 0``: ``ypred1 - y``.

        Units whose ``t`` is neither 0 nor 1 keep an estimate of 0.

        Args:
            ypred1: predicted outcomes under treatment, same shape as ``y``.
            ypred0: predicted outcomes under control, same shape as ``y``.

        Returns:
            ``sqrt(mean((true_ite - estimated_ite) ** 2))``.

        Raises:
            ValueError: if ``mu0``/``mu1`` were not supplied.
        """
        true_ite = self._require_true_ite()
        # Promote the buffer dtype so that an integer-valued true_ite does not
        # silently truncate the (usually floating-point) estimates below.
        dtype = np.result_type(true_ite, self.y, ypred0, ypred1)
        pred_ite = np.zeros_like(true_ite, dtype=dtype)
        treated = np.where(self.t == 1)
        control = np.where(self.t == 0)
        pred_ite[treated] = self.y[treated] - ypred0[treated]
        pred_ite[control] = ypred1[control] - self.y[control]
        return np.sqrt(np.mean(np.square(true_ite - pred_ite)))

    def abs_ate(self, ypred1, ypred0):
        """Return |mean predicted effect - mean true effect|.

        Raises:
            ValueError: if ``mu0``/``mu1`` were not supplied.
        """
        true_ite = self._require_true_ite()
        return np.abs(np.mean(ypred1 - ypred0) - np.mean(true_ite))

    def pehe(self, ypred1, ypred0):
        """Return the root of the precision in estimating heterogeneous effects.

        Computed as ``sqrt(mean((true_ite - (ypred1 - ypred0)) ** 2))``,
        i.e. both potential outcomes come from the model.

        Raises:
            ValueError: if ``mu0``/``mu1`` were not supplied.
        """
        true_ite = self._require_true_ite()
        return np.sqrt(np.mean(np.square(true_ite - (ypred1 - ypred0))))

    def y_errors(self, y0, y1):
        """Return (factual RMSE, counterfactual RMSE) for per-arm predictions.

        The factual prediction picks ``y1`` where ``t`` is 1 and ``y0`` where
        ``t`` is 0; the counterfactual prediction picks the opposite arm.
        """
        ypred = (1 - self.t) * y0 + self.t * y1
        ypred_cf = self.t * y0 + (1 - self.t) * y1
        return self.y_errors_pcf(ypred, ypred_cf)

    def y_errors_pcf(self, ypred, ypred_cf):
        """Return (RMSE vs ``y``, RMSE vs ``y_cf``) for given predictions.

        Raises:
            ValueError: if ``y_cf`` was not supplied to the constructor.
        """
        if self.y_cf is None:
            raise ValueError(
                "counterfactual error requires y_cf; "
                "construct Evaluator with y_cf")
        rmse_factual = np.sqrt(np.mean(np.square(ypred - self.y)))
        rmse_cfactual = np.sqrt(np.mean(np.square(ypred_cf - self.y_cf)))
        return rmse_factual, rmse_cfactual

    def calc_stats(self, ypred1, ypred0):
        """Return ``(rmse_ite, abs_ate, pehe)`` for the given predictions."""
        ite = self.rmse_ite(ypred1, ypred0)
        ate = self.abs_ate(ypred1, ypred0)
        pehe = self.pehe(ypred1, ypred0)
        return ite, ate, pehe