1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414from functools import partial
15+ from typing import Optional
1516
1617import numpy as np
1718import pymc as pm
@@ -31,40 +32,83 @@ class LatentApprox(pm.gp.Latent):
3132class ProjectedProcess (pm .gp .Latent ):
3233 ## AKA: DTC
3334 def __init__ (
34- self , n_inducing , * , mean_func = pm .gp .mean .Zero (), cov_func = pm .gp .cov .Constant (0.0 )
35+ self ,
36+ n_inducing : Optional [int ] = None ,
37+ * ,
38+ mean_func = pm .gp .mean .Zero (),
39+ cov_func = pm .gp .cov .Constant (0.0 ),
3540 ):
3641 self .n_inducing = n_inducing
3742 super ().__init__ (mean_func = mean_func , cov_func = cov_func )
3843
39- def _build_prior (self , name , X , Xu , jitter = JITTER_DEFAULT , ** kwargs ):
44+ def _build_prior (self , name , X , X_inducing , jitter = JITTER_DEFAULT , ** kwargs ):
4045 mu = self .mean_func (X )
41- Kuu = self .cov_func (Xu )
46+ Kuu = self .cov_func (X_inducing )
4247 L = cholesky (stabilize (Kuu , jitter ))
4348
44- n_inducing_points = np .shape (Xu )[0 ]
49+ n_inducing_points = np .shape (X_inducing )[0 ]
4550 v = pm .Normal (name + "_u_rotated_" , mu = 0.0 , sigma = 1.0 , size = n_inducing_points , ** kwargs )
4651 u = pm .Deterministic (name + "_u" , L @ v )
4752
48- Kfu = self .cov_func (X , Xu )
53+ Kfu = self .cov_func (X , X_inducing )
4954 Kuuiu = solve_upper (pt .transpose (L ), solve_lower (L , u ))
5055
5156 return pm .Deterministic (name , mu + Kfu @ Kuuiu ), Kuuiu , L
5257
53- def prior (self , name , X , Xu = None , jitter = JITTER_DEFAULT , ** kwargs ):
54- if Xu is None and self .n_inducing is None :
55- raise ValueError
56- elif Xu is None :
57- if isinstance (X , np .ndarray ):
58- Xu = pm .gp .util .kmeans_inducing_points (self .n_inducing , X , ** kwargs )
58+ def prior (
59+ self ,
60+ name : str ,
61+ X : np .ndarray ,
62+ X_inducing : Optional [np .ndarray ] = None ,
63+ jitter : float = JITTER_DEFAULT ,
64+ ** kwargs ,
65+ ) -> np .ndarray :
66+ """
67+ Builds the GP prior with optional inducing points locations.
68+
69+ Parameters:
70+ - name: Name for the GP variable.
71+ - X: Input data.
72+ - X_inducing: Optional. Inducing points for the GP.
73+ - jitter: Jitter to ensure numerical stability.
74+
75+ Returns:
76+ - GP function
77+ """
78+ # Check if X is a numpy array
79+ if not isinstance (X , np .ndarray ):
80+ raise ValueError ("'X' must be a numpy array." )
81+
82+ # Proceed with provided X_inducing or determine X_inducing based on n_inducing
83+ if X_inducing is not None :
84+ pass # X_inducing is directly used
85+
86+ elif self .n_inducing is not None :
87+ # Validate n_inducing
88+ if not isinstance (self .n_inducing , int ) or self .n_inducing <= 0 :
89+ raise ValueError (
90+ "The number of inducing points, 'n_inducing', must be a positive integer."
91+ )
92+ if self .n_inducing > len (X ):
93+ raise ValueError (
94+ "The number of inducing points, 'n_inducing', cannot be greater than the number of data points in 'X'."
95+ )
96+ # Use k-means to select X_inducing from X based on n_inducing
97+ X_inducing = pm .gp .util .kmeans_inducing_points (self .n_inducing , X , ** kwargs )
98+ else :
99+ # Neither X_inducing nor n_inducing provided
100+ raise ValueError (
101+ "Either 'X_inducing' (inducing points) or 'n_inducing' (number of inducing points) must be specified."
102+ )
59103
60- f , Kuuiu , L = self ._build_prior (name , X , Xu , jitter , ** kwargs )
61- self .X , self .Xu = X , Xu
104+ f , Kuuiu , L = self ._build_prior (name , X , X_inducing , jitter , ** kwargs )
105+ self .X , self .X_inducing = X , X_inducing
62106 self .L , self .Kuuiu = L , Kuuiu
63107 self .f = f
64108 return f
65109
66- def _build_conditional (self , name , Xnew , Xu , L , Kuuiu , jitter , ** kwargs ):
67- Ksu = self .cov_func (Xnew , Xu )
110+ def _build_conditional (self , name , Xnew , X_inducing , L , Kuuiu , jitter , ** kwargs ):
111+ Ksu = self .cov_func (Xnew , X_inducing )
68112 mu = self .mean_func (Xnew ) + Ksu @ Kuuiu
69113 tmp = solve_lower (L , pt .transpose (Ksu ))
70114 Qss = pt .transpose (tmp ) @ tmp # Qss = tt.dot(tt.dot(Ksu, tt.nlinalg.pinv(Kuu)), Ksu.T)
@@ -74,7 +118,7 @@ def _build_conditional(self, name, Xnew, Xu, L, Kuuiu, jitter, **kwargs):
74118
75119 def conditional (self , name , Xnew , jitter = 1e-6 , ** kwargs ):
76120 mu , chol = self ._build_conditional (
77- name , Xnew , self .Xu , self .L , self .Kuuiu , jitter , ** kwargs
121+ name , Xnew , self .X_inducing , self .L , self .Kuuiu , jitter , ** kwargs
78122 )
79123 return pm .MvNormal (name , mu = mu , chol = chol )
80124
0 commit comments