Commit
Remove sklearn (recreated PR) (#277)
* change solver
* remove dependency on dataset
* add ridge tests
* swap to ridge
* add down sampling
* change to coef
* change to torch.solve
* black
* fix correlated columns
* fix sqrt error
* black
* black
* black
* add normalization
* black
* flake8
* change explanation
* add debug log
* Update tests/conftest.py (Co-authored-by: Alby M. <[email protected]>)
* Update tests/conftest.py (Co-authored-by: Alby M. <[email protected]>)
* Fix typo in README (#270)
* update change log
* Update tests/conftest.py (Co-authored-by: Alby M. <[email protected]>)
* Update tests/conftest.py (Co-authored-by: Alby M. <[email protected]>)

Co-authored-by: Lixin Sun <[email protected]>
Co-authored-by: Lixin Sun <[email protected]>
Co-authored-by: Simon Batzner <[email protected]>
1 parent 4a7fb10 · commit 332947f
Showing 9 changed files with 132 additions and 235 deletions.
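For context before the diff: this commit replaces an iterative sklearn GaussianProcessRegressor fit with a closed-form ridge regression solved by torch.linalg.solve. A minimal sketch of that idea follows; it is illustrative only and not the code in this diff (names X, y, lam are made up, and the actual change additionally normalizes X and y, regularizes with per-feature RMS values rather than the identity, and down-samples by composition first):

import torch

# Ridge regression in closed form: solve the regularized normal equations
# (X^T X + lam * I) w = X^T y instead of fitting a sklearn GP with a
# DotProduct kernel. Stock PyTorch only; no sklearn dependency.
torch.manual_seed(0)
X = torch.randn(100, 4, dtype=torch.float64)   # e.g. per-frame species counts
w_true = torch.tensor([1.0, -2.0, 0.5, 3.0], dtype=torch.float64)
y = X @ w_true + 0.01 * torch.randn(100, dtype=torch.float64)

lam = 1e-3
A = X.T @ X + lam * torch.eye(4, dtype=torch.float64)  # regularized normal matrix
w = torch.linalg.solve(A, X.T @ y)                     # ridge coefficients

print(torch.allclose(w, w_true, atol=1e-2))            # True: recovers w_true

Unlike the GP route it replaces, there is no retry loop that doubles alpha on failure; rank deficiency is handled up front by the regularizer and the composition-based down-sampling.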
@@ -1,181 +1,72 @@
 import logging
 import torch
-import numpy as np
-from typing import Optional
-from sklearn.gaussian_process import GaussianProcessRegressor
-from sklearn.gaussian_process.kernels import DotProduct, Kernel, Hyperparameter
+
+from torch import matmul
+from torch.linalg import solve, inv
+from typing import Optional, Sequence
+from opt_einsum import contract
 
 
-def solver(X, y, regressor: Optional[str] = "NormalizedGaussianProcess", **kwargs):
-    if regressor == "GaussianProcess":
-        return gp(X, y, **kwargs)
-    elif regressor == "NormalizedGaussianProcess":
-        return normalized_gp(X, y, **kwargs)
-    else:
-        raise NotImplementedError(f"{regressor} is not implemented")
-
-
-def normalized_gp(X, y, **kwargs):
-    feature_rms = 1.0 / np.sqrt(np.average(X**2, axis=0))
-    feature_rms = np.nan_to_num(feature_rms, 1)
-    y_mean = torch.sum(y) / torch.sum(X)
-    mean, std = base_gp(
-        X,
-        y - (torch.sum(X, axis=1) * y_mean).reshape(y.shape),
-        NormalizedDotProduct,
-        {"diagonal_elements": feature_rms},
-        **kwargs,
-    )
-    return mean + y_mean, std
-
-
-def gp(X, y, **kwargs):
-    return base_gp(
-        X, y, DotProduct, {"sigma_0": 0, "sigma_0_bounds": "fixed"}, **kwargs
-    )
-
-
-def base_gp(
-    X,
-    y,
-    kernel,
-    kernel_kwargs,
-    alpha: Optional[float] = 0.1,
-    max_iteration: int = 20,
-    stride: Optional[int] = None,
-):
+def solver(X, y, alpha: Optional[float] = 0.001, stride: Optional[int] = 1, **kwargs):
+
+    dtype = torch.get_default_dtype()
+    X = X[::stride].to(dtype)
+    y = y[::stride].to(dtype)
+
+    X, y = down_sampling_by_composition(X, y)
+
+    X_norm = torch.sum(X)
+
+    X = X / X_norm
+    y = y / X_norm
+
+    y_mean = torch.sum(y) / torch.sum(X)
+
+    feature_rms = torch.sqrt(torch.mean(X**2, axis=0))
+
+    alpha_mat = torch.diag(feature_rms) * alpha * alpha
+
+    A = matmul(X.T, X) + alpha_mat
+    dy = y - (torch.sum(X, axis=1, keepdim=True) * y_mean).reshape(y.shape)
+    Xy = matmul(X.T, dy)
+
+    mean = solve(A, Xy)
+
+    sigma2 = torch.var(matmul(X, mean) - dy)
+    Ainv = inv(A)
+    cov = torch.sqrt(sigma2 * contract("ij,kj,kl,li->i", Ainv, X, X, Ainv))
+
+    mean = mean + y_mean.reshape([-1])
+
+    logging.debug(f"Ridge Regression, residue {sigma2}")
+
+    return mean, cov
+
+
+def down_sampling_by_composition(
+    X: torch.Tensor, y: torch.Tensor, percentage: Sequence = [0.25, 0.5, 0.75]
+):
 
     if len(y.shape) == 1:
         y = y.reshape([-1, 1])
 
-    if stride is not None:
-        X = X[::stride]
-        y = y[::stride]
-
-    not_fit = True
-    iteration = 0
-    mean = None
-    std = None
-    while not_fit:
-        logging.debug(f"GP fitting iteration {iteration} {alpha}")
-        try:
-            _kernel = kernel(**kernel_kwargs)
-            gpr = GaussianProcessRegressor(kernel=_kernel, random_state=0, alpha=alpha)
-            gpr = gpr.fit(X, y)
-
-            vec = torch.diag(torch.ones(X.shape[1]))
-            mean, std = gpr.predict(vec, return_std=True)
-
-            mean = torch.as_tensor(mean, dtype=torch.get_default_dtype()).reshape([-1])
-            # ignore all the off-diagonal terms
-            std = torch.as_tensor(std, dtype=torch.get_default_dtype()).reshape([-1])
-            likelihood = gpr.log_marginal_likelihood()
-
-            res = torch.sqrt(
-                torch.square(torch.matmul(X, mean.reshape([-1, 1])) - y).mean()
-            )
-
-            logging.debug(
-                f"GP fitting: alpha {alpha}:\n"
-                f"  residue {res}\n"
-                f"  mean {mean} std {std}\n"
-                f"  log marginal likelihood {likelihood}"
-            )
-            not_fit = False
-
-        except Exception as e:
-            logging.info(f"GP fitting failed for alpha={alpha} and {e.args}")
-            if alpha == 0 or alpha is None:
-                logging.info("try a non-zero alpha")
-                not_fit = False
-                raise ValueError(
-                    f"Please set the {alpha} to non-zero value. \n"
-                    "The dataset energy is rank deficient to be solved with GP"
-                )
-            else:
-                alpha = alpha * 2
-                iteration += 1
-                logging.debug(f"  increase alpha to {alpha}")
-
-            if iteration >= max_iteration or not_fit is False:
-                raise ValueError(
-                    "Please set the per species shift and scale to zeros and ones. \n"
-                    "The dataset energy is to diverge to be solved with GP"
-                )
-
-    return mean, std
-
-
-class NormalizedDotProduct(Kernel):
-    r"""Dot-Product kernel.
-    .. math::
-        k(x_i, x_j) = x_i \cdot A \cdot x_j
-    """
-
-    def __init__(self, diagonal_elements):
-        # TO DO: check shape
-        self.diagonal_elements = diagonal_elements
-        self.A = np.diag(diagonal_elements)
-
-    def __call__(self, X, Y=None, eval_gradient=False):
-        """Return the kernel k(X, Y) and optionally its gradient.
-        Parameters
-        ----------
-        X : ndarray of shape (n_samples_X, n_features)
-            Left argument of the returned kernel k(X, Y)
-        Y : ndarray of shape (n_samples_Y, n_features), default=None
-            Right argument of the returned kernel k(X, Y). If None, k(X, X)
-            is evaluated instead.
-        eval_gradient : bool, default=False
-            Determines whether the gradient with respect to the log of
-            the kernel hyperparameter is computed.
-            Only supported when Y is None.
-        Returns
-        -------
-        K : ndarray of shape (n_samples_X, n_samples_Y)
-            Kernel k(X, Y)
-        K_gradient : ndarray of shape (n_samples_X, n_samples_X, n_dims), optional
-            The gradient of the kernel k(X, X) with respect to the log of the
-            hyperparameter of the kernel. Only returned when `eval_gradient`
-            is True.
-        """
-        X = np.atleast_2d(X)
-        if Y is None:
-            K = (X.dot(self.A)).dot(X.T)
-        else:
-            if eval_gradient:
-                raise ValueError("Gradient can only be evaluated when Y is None.")
-            K = (X.dot(self.A)).dot(Y.T)
-
-        if eval_gradient:
-            return K, np.empty((X.shape[0], X.shape[0], 0))
-        else:
-            return K
-
-    def diag(self, X):
-        """Returns the diagonal of the kernel k(X, X).
-        The result of this method is identical to np.diag(self(X)); however,
-        it can be evaluated more efficiently since only the diagonal is
-        evaluated.
-        Parameters
-        ----------
-        X : ndarray of shape (n_samples_X, n_features)
-            Left argument of the returned kernel k(X, Y).
-        Returns
-        -------
-        K_diag : ndarray of shape (n_samples_X,)
-            Diagonal of kernel k(X, X).
-        """
-        return np.einsum("ij,ij,jj->i", X, X, self.A)
-
-    def __repr__(self):
-        return ""
-
-    def is_stationary(self):
-        """Returns whether the kernel is stationary."""
-        return False
-
-    @property
-    def hyperparameter_diagonal_elements(self):
-        return Hyperparameter("diagonal_elements", "numeric", "fixed")
+    unique_comps, comp_ids = torch.unique(X, dim=0, return_inverse=True)
+
+    n_types = torch.max(comp_ids) + 1
+
+    sort_by = torch.argsort(comp_ids)
+
+    # find out the block for each composition
+    d_icomp = comp_ids[sort_by]
+    d_icomp = d_icomp[:-1] - d_icomp[1:]
+    node_icomp = torch.where(d_icomp != 0)[0]
+    id_start = torch.cat((torch.as_tensor([0]), node_icomp + 1))
+    id_end = torch.cat((node_icomp + 1, torch.as_tensor([len(sort_by)])))
+
+    n_points = len(percentage)
+    new_X = torch.zeros((n_types * n_points, X.shape[1]))
+    new_y = torch.zeros((n_types * n_points))
+    for i in range(n_types):
+        ids = sort_by[id_start[i] : id_end[i]]
+        for j, p in enumerate(percentage):
+            new_y[i * n_points + j] = torch.quantile(y[ids], p, interpolation="linear")
+            new_X[i * n_points + j] = unique_comps[i]
+
+    return new_X, new_y
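A hypothetical usage sketch of the new solver above. The file path is not shown on this page, so the import is an assumption; the input shapes and values are made up for illustration. Rows of X are per-frame species counts and y holds the matching total energies; the returned mean is one fitted energy per species, and cov is the per-species uncertainty sqrt(sigma2 * diag(A^-1 X^T X A^-1)) produced by the opt_einsum contraction:

import torch
# from nequip.utils.regressor import solver  # hypothetical path; use wherever
# the new module above lives in your package.

# 40 frames of a two-species system: each column is a species count per frame.
X = torch.tensor([[2.0, 1.0], [1.0, 2.0], [3.0, 0.0], [2.0, 2.0]] * 10)
y = X @ torch.tensor([-1.5, -3.0]) + 0.01 * torch.randn(40)  # energies + noise

mean, cov = solver(X, y, alpha=0.001)
print(mean)  # approximately [-1.5, -3.0]: recovered per-species energies
print(cov)   # per-species uncertainty estimates

Internally, down_sampling_by_composition first collapses the 40 frames to three quantile points (25/50/75 percent by default) per unique composition, so the ridge solve runs on a small, composition-balanced design matrix.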