Skip to content

Commit ee07caa

Browse files
AlexAndorraricardoV94
authored andcommitted
Improve API for ProjectedProcess GP
1 parent 47c4b48 commit ee07caa

File tree

1 file changed

+60
-16
lines changed

1 file changed

+60
-16
lines changed

pymc_experimental/gp/latent_approx.py

+60-16
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
from functools import partial
15+
from typing import Optional
1516

1617
import numpy as np
1718
import pymc as pm
@@ -31,40 +32,83 @@ class LatentApprox(pm.gp.Latent):
3132
class ProjectedProcess(pm.gp.Latent):
3233
## AKA: DTC
3334
def __init__(
34-
self, n_inducing, *, mean_func=pm.gp.mean.Zero(), cov_func=pm.gp.cov.Constant(0.0)
35+
self,
36+
n_inducing: Optional[int] = None,
37+
*,
38+
mean_func=pm.gp.mean.Zero(),
39+
cov_func=pm.gp.cov.Constant(0.0),
3540
):
3641
self.n_inducing = n_inducing
3742
super().__init__(mean_func=mean_func, cov_func=cov_func)
3843

39-
def _build_prior(self, name, X, Xu, jitter=JITTER_DEFAULT, **kwargs):
44+
def _build_prior(self, name, X, X_inducing, jitter=JITTER_DEFAULT, **kwargs):
4045
mu = self.mean_func(X)
41-
Kuu = self.cov_func(Xu)
46+
Kuu = self.cov_func(X_inducing)
4247
L = cholesky(stabilize(Kuu, jitter))
4348

44-
n_inducing_points = np.shape(Xu)[0]
49+
n_inducing_points = np.shape(X_inducing)[0]
4550
v = pm.Normal(name + "_u_rotated_", mu=0.0, sigma=1.0, size=n_inducing_points, **kwargs)
4651
u = pm.Deterministic(name + "_u", L @ v)
4752

48-
Kfu = self.cov_func(X, Xu)
53+
Kfu = self.cov_func(X, X_inducing)
4954
Kuuiu = solve_upper(pt.transpose(L), solve_lower(L, u))
5055

5156
return pm.Deterministic(name, mu + Kfu @ Kuuiu), Kuuiu, L
5257

53-
def prior(self, name, X, Xu=None, jitter=JITTER_DEFAULT, **kwargs):
54-
if Xu is None and self.n_inducing is None:
55-
raise ValueError
56-
elif Xu is None:
57-
if isinstance(X, np.ndarray):
58-
Xu = pm.gp.util.kmeans_inducing_points(self.n_inducing, X, **kwargs)
58+
def prior(
59+
self,
60+
name: str,
61+
X: np.ndarray,
62+
X_inducing: Optional[np.ndarray] = None,
63+
jitter: float = JITTER_DEFAULT,
64+
**kwargs,
65+
) -> np.ndarray:
66+
"""
67+
Builds the GP prior with optional inducing points locations.
68+
69+
Parameters:
70+
- name: Name for the GP variable.
71+
- X: Input data.
72+
- X_inducing: Optional. Inducing points for the GP.
73+
- jitter: Jitter to ensure numerical stability.
74+
75+
Returns:
76+
- GP function
77+
"""
78+
# Check if X is a numpy array
79+
if not isinstance(X, np.ndarray):
80+
raise ValueError("'X' must be a numpy array.")
81+
82+
# Proceed with provided X_inducing or determine X_inducing based on n_inducing
83+
if X_inducing is not None:
84+
pass # X_inducing is directly used
85+
86+
elif self.n_inducing is not None:
87+
# Validate n_inducing
88+
if not isinstance(self.n_inducing, int) or self.n_inducing <= 0:
89+
raise ValueError(
90+
"The number of inducing points, 'n_inducing', must be a positive integer."
91+
)
92+
if self.n_inducing > len(X):
93+
raise ValueError(
94+
"The number of inducing points, 'n_inducing', cannot be greater than the number of data points in 'X'."
95+
)
96+
# Use k-means to select X_inducing from X based on n_inducing
97+
X_inducing = pm.gp.util.kmeans_inducing_points(self.n_inducing, X, **kwargs)
98+
else:
99+
# Neither X_inducing nor n_inducing provided
100+
raise ValueError(
101+
"Either 'X_inducing' (inducing points) or 'n_inducing' (number of inducing points) must be specified."
102+
)
59103

60-
f, Kuuiu, L = self._build_prior(name, X, Xu, jitter, **kwargs)
61-
self.X, self.Xu = X, Xu
104+
f, Kuuiu, L = self._build_prior(name, X, X_inducing, jitter, **kwargs)
105+
self.X, self.X_inducing = X, X_inducing
62106
self.L, self.Kuuiu = L, Kuuiu
63107
self.f = f
64108
return f
65109

66-
def _build_conditional(self, name, Xnew, Xu, L, Kuuiu, jitter, **kwargs):
67-
Ksu = self.cov_func(Xnew, Xu)
110+
def _build_conditional(self, name, Xnew, X_inducing, L, Kuuiu, jitter, **kwargs):
111+
Ksu = self.cov_func(Xnew, X_inducing)
68112
mu = self.mean_func(Xnew) + Ksu @ Kuuiu
69113
tmp = solve_lower(L, pt.transpose(Ksu))
70114
Qss = pt.transpose(tmp) @ tmp # Qss = tt.dot(tt.dot(Ksu, tt.nlinalg.pinv(Kuu)), Ksu.T)
@@ -74,7 +118,7 @@ def _build_conditional(self, name, Xnew, Xu, L, Kuuiu, jitter, **kwargs):
74118

75119
def conditional(self, name, Xnew, jitter=1e-6, **kwargs):
76120
mu, chol = self._build_conditional(
77-
name, Xnew, self.Xu, self.L, self.Kuuiu, jitter, **kwargs
121+
name, Xnew, self.X_inducing, self.L, self.Kuuiu, jitter, **kwargs
78122
)
79123
return pm.MvNormal(name, mu=mu, chol=chol)
80124

0 commit comments

Comments
 (0)