diff --git a/CHANGELOG.md b/CHANGELOG.md
index 50338467..67bd0c81 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,6 +9,7 @@ Most recent change on the bottom.
 
 ## [Unreleased] - 0.5.6
 ### Added
+- sklearn dependency removed
 - `nequip-benchmark` and `nequip-train` report number of weights and number of trainable weights
 - `nequip-benchmark --no-compile` and `--verbose` and `--memory-summary`
 - `nequip-benchmark --pdb` for debugging model (builder) errors
diff --git a/README.md b/README.md
index f70840b8..da741c09 100644
--- a/README.md
+++ b/README.md
@@ -157,6 +157,6 @@ under the guidance of [Boris Kozinsky at Harvard](https://bkoz.seas.harvard.edu/
 If you have questions, please don't hesitate to reach out at batzner[at]g[dot]harvard[dot]edu.
 
 If you find a bug or have a proposal for a feature, please post it in the [Issues](https://github.com/mir-group/nequip/issues).
-If you have a question, topic, or issue that isn't obviously one of those, try our [GitHub Disucssions](https://github.com/mir-group/nequip/discussions).
+If you have a question, topic, or issue that isn't obviously one of those, try our [GitHub Discussions](https://github.com/mir-group/nequip/discussions).
 
 If you want to contribute to the code, please read [`CONTRIBUTING.md`](CONTRIBUTING.md).
diff --git a/configs/full.yaml b/configs/full.yaml
index daefc143..2f98164e 100644
--- a/configs/full.yaml
+++ b/configs/full.yaml
@@ -317,10 +317,10 @@ per_species_rescale_scales: dataset_forces_rms
 # If not provided, defaults to dataset_per_species_force_rms or dataset_per_atom_total_energy_std, depending on whether forces are being trained.
 # per_species_rescale_kwargs:
 #   total_energy:
-#     alpha: 0.1
+#     alpha: 0.001
 #     max_iteration: 20
 #     stride: 100
-#   keywords for GP decomposition of per specie energy. Optional. Defaults to 0.1
+#   keywords for the ridge regression decomposition of per-species energy. Optional. Defaults to 0.001. The value should be in the range of 1e-3 to 1e-2.
 # per_species_rescale_arguments_in_dataset_units: True
 # if explicit numbers are given for the shifts/scales, this parameter must specify whether the given numbers are unitless shifts/scales or are in the units of the dataset. If ``True``, any global rescalings will correctly be applied to the per-species values.
 
diff --git a/nequip/utils/regressor.py b/nequip/utils/regressor.py
index 3d23cf84..30c8f9ab 100644
--- a/nequip/utils/regressor.py
+++ b/nequip/utils/regressor.py
@@ -1,181 +1,72 @@
 import logging
 import torch
-import numpy as np
-from typing import Optional
-from sklearn.gaussian_process import GaussianProcessRegressor
-from sklearn.gaussian_process.kernels import DotProduct, Kernel, Hyperparameter
+from torch import matmul
+from torch.linalg import solve, inv
+from typing import Optional, Sequence
+from opt_einsum import contract
 
 
-def solver(X, y, regressor: Optional[str] = "NormalizedGaussianProcess", **kwargs):
-    if regressor == "GaussianProcess":
-        return gp(X, y, **kwargs)
-    elif regressor == "NormalizedGaussianProcess":
-        return normalized_gp(X, y, **kwargs)
-    else:
-        raise NotImplementedError(f"{regressor} is not implemented")
+def solver(X, y, alpha: Optional[float] = 0.001, stride: Optional[int] = 1, **kwargs):
+
+    dtype = torch.get_default_dtype()
+    X = X[::stride].to(dtype)
+    y = y[::stride].to(dtype)
+
+    X, y = down_sampling_by_composition(X, y)
+
+    X_norm = torch.sum(X)
+
+    X = X / X_norm
+    y = y / X_norm
 
-
-def normalized_gp(X, y, **kwargs):
-    feature_rms = 1.0 / np.sqrt(np.average(X**2, axis=0))
-    feature_rms = np.nan_to_num(feature_rms, 1)
     y_mean = torch.sum(y) / torch.sum(X)
-    mean, std = base_gp(
-        X,
-        y - (torch.sum(X, axis=1) * y_mean).reshape(y.shape),
-        NormalizedDotProduct,
-        {"diagonal_elements": feature_rms},
-        **kwargs,
-    )
-    return mean + y_mean, std
-
-
-def gp(X, y, **kwargs):
-    return base_gp(
-        X, y, DotProduct, {"sigma_0": 0, "sigma_0_bounds": "fixed"}, **kwargs
-    )
-
-
-def base_gp(
-    X,
-    y,
-    kernel,
-    kernel_kwargs,
-    alpha: Optional[float] = 0.1,
-    max_iteration: int = 20,
-    stride: Optional[int] = None,
+
+    feature_rms = torch.sqrt(torch.mean(X**2, axis=0))
+
+    alpha_mat = torch.diag(feature_rms) * alpha * alpha
+
+    A = matmul(X.T, X) + alpha_mat
+    dy = y - (torch.sum(X, axis=1, keepdim=True) * y_mean).reshape(y.shape)
+    Xy = matmul(X.T, dy)
+
+    mean = solve(A, Xy)
+
+    sigma2 = torch.var(matmul(X, mean) - dy)
+    Ainv = inv(A)
+    cov = torch.sqrt(sigma2 * contract("ij,kj,kl,li->i", Ainv, X, X, Ainv))
+
+    mean = mean + y_mean.reshape([-1])
+
+    logging.debug(f"Ridge Regression, residue {sigma2}")
+
+    return mean, cov
+
+
+def down_sampling_by_composition(
+    X: torch.Tensor, y: torch.Tensor, percentage: Sequence = [0.25, 0.5, 0.75]
 ):
 
-    if len(y.shape) == 1:
-        y = y.reshape([-1, 1])
-
-    if stride is not None:
-        X = X[::stride]
-        y = y[::stride]
-
-    not_fit = True
-    iteration = 0
-    mean = None
-    std = None
-    while not_fit:
-        logging.debug(f"GP fitting iteration {iteration} {alpha}")
-        try:
-            _kernel = kernel(**kernel_kwargs)
-            gpr = GaussianProcessRegressor(kernel=_kernel, random_state=0, alpha=alpha)
-            gpr = gpr.fit(X, y)
-
-            vec = torch.diag(torch.ones(X.shape[1]))
-            mean, std = gpr.predict(vec, return_std=True)
-
-            mean = torch.as_tensor(mean, dtype=torch.get_default_dtype()).reshape([-1])
-            # ignore all the off-diagonal terms
-            std = torch.as_tensor(std, dtype=torch.get_default_dtype()).reshape([-1])
-            likelihood = gpr.log_marginal_likelihood()
-
-            res = torch.sqrt(
-                torch.square(torch.matmul(X, mean.reshape([-1, 1])) - y).mean()
-            )
-
-            logging.debug(
-                f"GP fitting: alpha {alpha}:\n"
-                f" residue {res}\n"
-                f" mean {mean} std {std}\n"
-                f" log marginal likelihood {likelihood}"
-            )
-            not_fit = False
-
-        except Exception as e:
-            logging.info(f"GP fitting failed for alpha={alpha} and {e.args}")
-            if alpha == 0 or alpha is None:
-                logging.info("try a non-zero alpha")
-                not_fit = False
-                raise ValueError(
-                    f"Please set the {alpha} to non-zero value. \n"
-                    "The dataset energy is rank deficient to be solved with GP"
-                )
-            else:
-                alpha = alpha * 2
-                iteration += 1
-                logging.debug(f" increase alpha to {alpha}")
-
-            if iteration >= max_iteration or not_fit is False:
-                raise ValueError(
-                    "Please set the per species shift and scale to zeros and ones. \n"
-                    "The dataset energy is to diverge to be solved with GP"
-                )
-
-    return mean, std
-
-
-class NormalizedDotProduct(Kernel):
-    r"""Dot-Product kernel.
-    .. math::
-        k(x_i, x_j) = x_i \cdot A \cdot x_j
-    """
-
-    def __init__(self, diagonal_elements):
-        # TO DO: check shape
-        self.diagonal_elements = diagonal_elements
-        self.A = np.diag(diagonal_elements)
-
-    def __call__(self, X, Y=None, eval_gradient=False):
-        """Return the kernel k(X, Y) and optionally its gradient.
-        Parameters
-        ----------
-        X : ndarray of shape (n_samples_X, n_features)
-            Left argument of the returned kernel k(X, Y)
-        Y : ndarray of shape (n_samples_Y, n_features), default=None
-            Right argument of the returned kernel k(X, Y). If None, k(X, X)
-            if evaluated instead.
-        eval_gradient : bool, default=False
-            Determines whether the gradient with respect to the log of
-            the kernel hyperparameter is computed.
-            Only supported when Y is None.
-        Returns
-        -------
-        K : ndarray of shape (n_samples_X, n_samples_Y)
-            Kernel k(X, Y)
-        K_gradient : ndarray of shape (n_samples_X, n_samples_X, n_dims),\
-            optional
-            The gradient of the kernel k(X, X) with respect to the log of the
-            hyperparameter of the kernel. Only returned when `eval_gradient`
-            is True.
-        """
-        X = np.atleast_2d(X)
-        if Y is None:
-            K = (X.dot(self.A)).dot(X.T)
-        else:
-            if eval_gradient:
-                raise ValueError("Gradient can only be evaluated when Y is None.")
-            K = (X.dot(self.A)).dot(Y.T)
-
-        if eval_gradient:
-            return K, np.empty((X.shape[0], X.shape[0], 0))
-        else:
-            return K
-
-    def diag(self, X):
-        """Returns the diagonal of the kernel k(X, X).
-        The result of this method is identical to np.diag(self(X)); however,
-        it can be evaluated more efficiently since only the diagonal is
-        evaluated.
-        Parameters
-        ----------
-        X : ndarray of shape (n_samples_X, n_features)
-            Left argument of the returned kernel k(X, Y).
-        Returns
-        -------
-        K_diag : ndarray of shape (n_samples_X,)
-            Diagonal of kernel k(X, X).
-        """
-        return np.einsum("ij,ij,jj->i", X, X, self.A)
-
-    def __repr__(self):
-        return ""
-
-    def is_stationary(self):
-        """Returns whether the kernel is stationary."""
-        return False
-
-    @property
-    def hyperparameter_diagonal_elements(self):
-        return Hyperparameter("diagonal_elements", "numeric", "fixed")
+
+    unique_comps, comp_ids = torch.unique(X, dim=0, return_inverse=True)
+
+    n_types = torch.max(comp_ids) + 1
+
+    sort_by = torch.argsort(comp_ids)
+
+    # find out the block for each composition
+    d_icomp = comp_ids[sort_by]
+    d_icomp = d_icomp[:-1] - d_icomp[1:]
+    node_icomp = torch.where(d_icomp != 0)[0]
+    id_start = torch.cat((torch.as_tensor([0]), node_icomp + 1))
+    id_end = torch.cat((node_icomp + 1, torch.as_tensor([len(sort_by)])))
+
+    n_points = len(percentage)
+    new_X = torch.zeros((n_types * n_points, X.shape[1]))
+    new_y = torch.zeros((n_types * n_points))
+    for i in range(n_types):
+        ids = sort_by[id_start[i] : id_end[i]]
+        for j, p in enumerate(percentage):
+            new_y[i * n_points + j] = torch.quantile(y[ids], p, interpolation="linear")
+            new_X[i * n_points + j] = unique_comps[i]
+
+    return new_X, new_y
diff --git a/nequip/utils/unittests/conftest.py b/nequip/utils/unittests/conftest.py
index 060e5e7b..77a91930 100644
--- a/nequip/utils/unittests/conftest.py
+++ b/nequip/utils/unittests/conftest.py
@@ -133,5 +133,22 @@ def atomic_batch(nequip_dataset):
     return Batch.from_data_list([nequip_dataset[0], nequip_dataset[1]])
 
 
+@pytest.fixture(scope="function")
+def per_species_set():
+    dtype = torch.get_default_dtype()
+    torch.manual_seed(0)
+    mean_min = 1
+    mean_max = 100
+    std = 20
+    n_sample = 1000
+    n_species = 9
+    ref_mean = torch.rand((n_species)) * (mean_max - mean_min) + mean_min
+    t_mean = torch.ones((n_sample, 1)) * ref_mean.reshape([1, -1])
+    ref_std = torch.rand((n_species)) * std
+    t_std = torch.ones((n_sample, 1)) * ref_std.reshape([1, -1])
+    E = torch.normal(t_mean, t_std)
+    return ref_mean.to(dtype), ref_std.to(dtype), E.to(dtype), n_sample, n_species
+
+
 # Use debug mode
 set_irreps_debug(True)
diff --git a/setup.py b/setup.py
index 8c977e0a..cba6b51f 100644
--- a/setup.py
+++ b/setup.py
@@ -37,7 +37,6 @@
         "typing_extensions;python_version<'3.8'",  # backport of Final
         "torch-runstats>=0.2.0",
         "torch-ema>=0.3.0",
-        "scikit_learn<=1.0.1",  # for GaussianProcess for per-species statistics; 1.0.2 has a bug!
     ],
     zip_safe=True,
 )
diff --git a/tests/unit/data/test_dataset.py b/tests/unit/data/test_dataset.py
index f45e0ca8..bad796c3 100644
--- a/tests/unit/data/test_dataset.py
+++ b/tests/unit/data/test_dataset.py
@@ -31,7 +31,7 @@ def ase_file(molecules):
 
 
 MAX_ATOMIC_NUMBER: int = 5
-NATOMS = 3
+NATOMS = 10
 
 
 @pytest.fixture(scope="function")
@@ -277,16 +277,11 @@ def test_per_node_field(self, npz_dataset, fixed_field, mode, subset):
         )
         print(result)
 
-    @pytest.mark.parametrize("alpha", [1e-5, 1e-3, 0.1, 0.5])
+    @pytest.mark.parametrize("alpha", [0, 1e-3, 0.01])
     @pytest.mark.parametrize("fixed_field", [True, False])
     @pytest.mark.parametrize("full_rank", [True, False])
    @pytest.mark.parametrize("subset", [True, False])
-    @pytest.mark.parametrize(
-        "regressor", ["NormalizedGaussianProcess", "GaussianProcess"]
-    )
-    def test_per_graph_field(
-        self, npz_dataset, alpha, fixed_field, full_rank, regressor, subset
-    ):
+    def test_per_graph_field(self, npz_dataset, alpha, fixed_field, full_rank, subset):
         if alpha <= 1e-4 and not full_rank:
             return
 
@@ -308,10 +303,7 @@ def test_per_graph_field(
         del n_spec
         del Ns
 
-        if alpha == 1e-5:
-            ref_mean, ref_std, E = generate_E(N, 100, 1000, 0.0)
-        else:
-            ref_mean, ref_std, E = generate_E(N, 100, 1000, 0.5)
+        ref_mean, ref_std, E = generate_E(N, 100, 1000, 10)
 
         if subset:
             E_orig_order = torch.zeros_like(
@@ -333,7 +325,6 @@ def test_per_graph_field(
                 AtomicDataDict.TOTAL_ENERGY_KEY
                 + "per_species_mean_std": {
                     "alpha": alpha,
-                    "regressor": regressor,
                     "stride": 1,
                 }
             },
@@ -341,21 +332,18 @@ def test_per_graph_field(
 
         res = torch.matmul(N, mean.reshape([-1, 1])) - E.reshape([-1, 1])
         res2 = torch.sum(torch.square(res))
-        print("residue", alpha, res2 - ref_res2)
+        print("alpha, residue, actual residue", alpha, res2, ref_res2)
         print("mean", mean, ref_mean)
         print("diff in mean", mean - ref_mean)
         print("std", std, ref_std)
 
+        tolerance = torch.max(ref_std) * 4
         if full_rank:
-            if alpha == 1e-5:
-                assert torch.allclose(mean, ref_mean, rtol=1e-1)
-            else:
-                assert torch.allclose(mean, ref_mean, rtol=1)
-            assert torch.allclose(std, torch.zeros_like(ref_mean), atol=alpha * 100)
-        elif regressor == "NormalizedGaussianProcess":
-            assert torch.std(mean).numpy() == 0
+            assert torch.allclose(mean, ref_mean, atol=tolerance)
+            # assert torch.allclose(std, torch.zeros_like(ref_mean), atol=alpha * 100)
         else:
-            assert mean[0] == mean[1] * 2
+            assert torch.allclose(mean, mean[0], atol=tolerance)
+            # assert torch.std(mean).numpy() == 0
 
 
 class TestReload:
diff --git a/tests/unit/utils/test_gp.py b/tests/unit/utils/test_gp.py
deleted file mode 100644
index 4792b9d2..00000000
--- a/tests/unit/utils/test_gp.py
+++ /dev/null
@@ -1,37 +0,0 @@
-import torch
-import pytest
-
-from nequip.utils.regressor import base_gp
-from sklearn.gaussian_process.kernels import DotProduct
-
-
-# @pytest.mark.parametrize("full_rank", [True, False])
-@pytest.mark.parametrize("full_rank", [False])
-@pytest.mark.parametrize("alpha", [0, 1e-3, 0.1, 1])
-def test_random(full_rank, alpha):
-
-    if alpha == 0 and not full_rank:
-        return
-
-    torch.manual_seed(0)
-    n_samples = 10
-    n_dim = 3
-
-    if full_rank:
-        X = torch.randint(low=1, high=10, size=(n_samples, n_dim))
-    else:
-        X = torch.randint(low=1, high=10, size=(n_samples, 1)) * torch.ones(
-            (n_samples, n_dim)
-        )
-
-    ref_mean = torch.rand((n_dim, 1))
-    y = torch.matmul(X, ref_mean)
-
-    mean, std = base_gp(
-        X, y, DotProduct, {"sigma_0": 0, "sigma_0_bounds": "fixed"}, alpha=0.1
-    )
-
-    if full_rank:
-        assert torch.allclose(ref_mean, mean, rtol=0.5)
-    else:
-        assert torch.allclose(mean, mean[0], rtol=1e-3)
diff --git a/tests/unit/utils/test_solver.py b/tests/unit/utils/test_solver.py
new file mode 100644
index 00000000..049c897d
--- /dev/null
+++ b/tests/unit/utils/test_solver.py
@@ -0,0 +1,38 @@
+import torch
+import pytest
+
+from nequip.utils.regressor import solver
+
+
+@pytest.mark.parametrize("full_rank", [True, False])
+@pytest.mark.parametrize("alpha", [0, 1e-3, 1e-2])
+def test_random(full_rank, alpha, per_species_set):
+
+    if alpha == 0 and not full_rank:
+        return
+
+    torch.manual_seed(0)
+
+    ref_mean, ref_std, E, n_samples, n_dim = per_species_set
+
+    dtype = torch.get_default_dtype()
+
+    X = torch.randint(low=1, high=10, size=(n_samples, n_dim)).to(dtype)
+    if not full_rank:
+        X[:, n_dim - 2] = X[:, n_dim - 1] * 2
+    y = (X * E).sum(axis=-1)
+
+    mean, std = solver(X, y, alpha=alpha)
+
+    tolerance = torch.max(ref_std)
+
+    print("tolerance", tolerance)
+    print("solution", mean, std)
+    print("diff", mean - ref_mean)
+
+    if full_rank:
+        assert torch.allclose(ref_mean, mean, atol=tolerance)
+    else:
+        assert torch.allclose(mean[n_dim - 1], mean[n_dim - 2], atol=tolerance)
+
+    assert torch.max(std) < tolerance
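
For reference, a minimal sketch of how the new ridge-regression `solver` introduced above is called. The composition matrix `N`, the reference shifts, and the energies below are illustrative values invented for this sketch; they are not taken from this diff or from the test suite.

import torch
from nequip.utils.regressor import solver

torch.manual_seed(0)
dtype = torch.get_default_dtype()

# hypothetical composition matrix: N[i, s] = number of atoms of species s in frame i
N = torch.randint(low=1, high=10, size=(1000, 3)).to(dtype)
ref_shifts = torch.tensor([1.0, 5.0, 10.0], dtype=dtype)  # assumed per-species energies
E = (N * ref_shifts).sum(dim=-1)  # noiseless per-frame total energies

# mean: per-species energy estimates; std: their estimated uncertainties
mean, std = solver(N, E, alpha=1e-3)
print(mean)  # expected to lie close to ref_shifts
print(std)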