Skip to content

Commit 20ace9e

Browse files
update kliep
1 parent 9aad8c9 commit 20ace9e

File tree

2 files changed

+34
-14
lines changed

2 files changed

+34
-14
lines changed

adapt/instance_based/_kliep.py

+31-11
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ class KLIEP:
8080
If get_estimator is ``None``, a ``LinearRegression`` object will be
8181
used by default as estimator.
8282
83-
sigmas : float or list of float, optional (default=0.1)
83+
sigmas : float or list of float, optional (default=1/nb_features)
8484
Kernel bandwidths.
8585
If ``sigmas`` is a list of multiple values, the
8686
kernel bandwidth is selected with the LCV procedure.
@@ -89,10 +89,20 @@ class KLIEP:
8989
Cross-validation split parameter.
9090
Used only if sigmas has more than one value.
9191
92-
max_points : int, optional (default=100)
92+
max_centers : int, optional (default=100)
9393
Maximal number of target instances use to
9494
compute kernels.
9595
96+
lr: float, optional (default=1e-4)
97+
Learning rate of the gradient ascent.
98+
99+
tol: float, optional (default=1e-6)
100+
Optimization threshold.
101+
102+
max_iter: int, optional (default=5000)
103+
Maximal iteration of the gradient ascent
104+
optimization.
105+
96106
kwargs : key, value arguments, optional
97107
Additional arguments for the constructor.
98108
@@ -128,11 +138,11 @@ class KLIEP:
128138
"Direct importance estimation with model selection and its application \
129139
to covariateshift adaptation". In NIPS 2007
130140
"""
131-
def __init__(self, estimator=None,
141+
def __init__(self, get_estimator=None,
132142
sigmas=None, max_centers=100,
133143
cv=5, lr=1e-4, tol=1e-6, max_iter=5000,
134144
verbose=1, **kwargs):
135-
self.estimator = estimator
145+
self.get_estimator = get_estimator
136146
self.sigmas = sigmas
137147
self.cv = cv
138148
self.max_centers = max_centers
@@ -142,8 +152,8 @@ def __init__(self, estimator=None,
142152
self.verbose = verbose
143153
self.kwargs = kwargs
144154

145-
if self.estimator is None:
146-
self.estimator = LinearRegression()
155+
if self.get_estimator is None:
156+
self.get_estimator = LinearRegression
147157

148158

149159
def fit_weights(self, Xs, Xt=None):
@@ -187,24 +197,26 @@ def fit_weights(self, Xs, Xt=None):
187197

188198

189199
def fit_estimator(self, X, y, **fit_params):
200+
self.estimator_ = self.get_estimator(**self.kwargs)
190201
if hasattr(self, "weights_"):
191-
if "sample_weight" in inspect.signature(self.estimator.fit).parameters:
192-
self.estimator.fit(X, y,
202+
if "sample_weight" in inspect.signature(self.estimator_.fit).parameters:
203+
self.estimator_.fit(X, y,
193204
sample_weight=self.weights_,
194205
**fit_params)
195206
else:
196207
bootstrap_index = np.random.choice(
197208
len(X), size=len(X), replace=True,
198209
p=self.weights_ / self.weights_.sum())
199-
self.estimator.fit(X[bootstrap_index], y[bootstrap_index],
210+
self.estimator_.fit(X[bootstrap_index], y[bootstrap_index],
200211
**fit_params)
201212
else:
202213
raise NotFittedError("Weights are not fitted yet, please "
203214
"call 'fit_weights' first.")
204215
return self
205216

206217

207-
def fit(self, Xs, ys, Xt=None, **fit_params):
218+
def fit(self, X, y, src_index, tgt_index,
219+
tgt_index_labeled=None, **fit_params):
208220
"""
209221
Fit KLIEP.
210222
@@ -233,6 +245,14 @@ def fit(self, Xs, ys, Xt=None, **fit_params):
233245
-------
234246
self : returns an instance of self
235247
"""
248+
if tgt_index_labeled is None:
249+
Xs = X[src_index]
250+
ys = y[src_index]
251+
else:
252+
Xs = X[np.concatenate((src_index, tgt_index_labeled))]
253+
ys = y[np.concatenate((src_index, tgt_index_labeled))]
254+
Xt = X[tgt_index]
255+
236256
if self.verbose:
237257
print("Fitting weights...")
238258
self.fit_weights(Xs, Xt)
@@ -318,7 +338,7 @@ def predict(self, X):
318338
y_pred: array
319339
prediction of estimator.
320340
"""
321-
return self.estimator.predict(X)
341+
return self.estimator_.predict(X)
322342

323343

324344
def predict_weights(self, X=None):

tests/test_kliep.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,10 @@ def test_setup():
2525

2626
def test_fit():
2727
np.random.seed(0)
28-
model = KLIEP(LinearRegression(), sigmas=[10, 100],
28+
model = KLIEP(LinearRegression, sigmas=[10, 100],
2929
fit_intercept=False)
30-
model.fit(Xs, y[:100], Xt)#range(100), range(100, 200))
31-
assert np.abs(model.estimator.coef_[0] - 0.2) < 10
30+
model.fit(X, y, range(100), range(100, 200))
31+
assert np.abs(model.estimator_.coef_[0] - 0.2) < 10
3232
assert model.weights_[:50].sum() > 90
3333
assert model.weights_[50:].sum() < 0.5
3434
assert np.abs(model.predict(Xt) - y[100:]).sum() < 20

0 commit comments

Comments
 (0)