12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
14
from functools import partial
15
+ from typing import Optional
15
16
16
17
import numpy as np
17
18
import pymc as pm
@@ -31,40 +32,83 @@ class LatentApprox(pm.gp.Latent):
31
32
class ProjectedProcess (pm .gp .Latent ):
32
33
## AKA: DTC
33
34
def __init__ (
34
- self , n_inducing , * , mean_func = pm .gp .mean .Zero (), cov_func = pm .gp .cov .Constant (0.0 )
35
+ self ,
36
+ n_inducing : Optional [int ] = None ,
37
+ * ,
38
+ mean_func = pm .gp .mean .Zero (),
39
+ cov_func = pm .gp .cov .Constant (0.0 ),
35
40
):
36
41
self .n_inducing = n_inducing
37
42
super ().__init__ (mean_func = mean_func , cov_func = cov_func )
38
43
39
- def _build_prior (self , name , X , Xu , jitter = JITTER_DEFAULT , ** kwargs ):
44
+ def _build_prior (self , name , X , X_inducing , jitter = JITTER_DEFAULT , ** kwargs ):
40
45
mu = self .mean_func (X )
41
- Kuu = self .cov_func (Xu )
46
+ Kuu = self .cov_func (X_inducing )
42
47
L = cholesky (stabilize (Kuu , jitter ))
43
48
44
- n_inducing_points = np .shape (Xu )[0 ]
49
+ n_inducing_points = np .shape (X_inducing )[0 ]
45
50
v = pm .Normal (name + "_u_rotated_" , mu = 0.0 , sigma = 1.0 , size = n_inducing_points , ** kwargs )
46
51
u = pm .Deterministic (name + "_u" , L @ v )
47
52
48
- Kfu = self .cov_func (X , Xu )
53
+ Kfu = self .cov_func (X , X_inducing )
49
54
Kuuiu = solve_upper (pt .transpose (L ), solve_lower (L , u ))
50
55
51
56
return pm .Deterministic (name , mu + Kfu @ Kuuiu ), Kuuiu , L
52
57
53
- def prior (self , name , X , Xu = None , jitter = JITTER_DEFAULT , ** kwargs ):
54
- if Xu is None and self .n_inducing is None :
55
- raise ValueError
56
- elif Xu is None :
57
- if isinstance (X , np .ndarray ):
58
- Xu = pm .gp .util .kmeans_inducing_points (self .n_inducing , X , ** kwargs )
58
+ def prior (
59
+ self ,
60
+ name : str ,
61
+ X : np .ndarray ,
62
+ X_inducing : Optional [np .ndarray ] = None ,
63
+ jitter : float = JITTER_DEFAULT ,
64
+ ** kwargs ,
65
+ ) -> np .ndarray :
66
+ """
67
+ Builds the GP prior with optional inducing points locations.
68
+
69
+ Parameters:
70
+ - name: Name for the GP variable.
71
+ - X: Input data.
72
+ - X_inducing: Optional. Inducing points for the GP.
73
+ - jitter: Jitter to ensure numerical stability.
74
+
75
+ Returns:
76
+ - GP function
77
+ """
78
+ # Check if X is a numpy array
79
+ if not isinstance (X , np .ndarray ):
80
+ raise ValueError ("'X' must be a numpy array." )
81
+
82
+ # Proceed with provided X_inducing or determine X_inducing based on n_inducing
83
+ if X_inducing is not None :
84
+ pass # X_inducing is directly used
85
+
86
+ elif self .n_inducing is not None :
87
+ # Validate n_inducing
88
+ if not isinstance (self .n_inducing , int ) or self .n_inducing <= 0 :
89
+ raise ValueError (
90
+ "The number of inducing points, 'n_inducing', must be a positive integer."
91
+ )
92
+ if self .n_inducing > len (X ):
93
+ raise ValueError (
94
+ "The number of inducing points, 'n_inducing', cannot be greater than the number of data points in 'X'."
95
+ )
96
+ # Use k-means to select X_inducing from X based on n_inducing
97
+ X_inducing = pm .gp .util .kmeans_inducing_points (self .n_inducing , X , ** kwargs )
98
+ else :
99
+ # Neither X_inducing nor n_inducing provided
100
+ raise ValueError (
101
+ "Either 'X_inducing' (inducing points) or 'n_inducing' (number of inducing points) must be specified."
102
+ )
59
103
60
- f , Kuuiu , L = self ._build_prior (name , X , Xu , jitter , ** kwargs )
61
- self .X , self .Xu = X , Xu
104
+ f , Kuuiu , L = self ._build_prior (name , X , X_inducing , jitter , ** kwargs )
105
+ self .X , self .X_inducing = X , X_inducing
62
106
self .L , self .Kuuiu = L , Kuuiu
63
107
self .f = f
64
108
return f
65
109
66
- def _build_conditional (self , name , Xnew , Xu , L , Kuuiu , jitter , ** kwargs ):
67
- Ksu = self .cov_func (Xnew , Xu )
110
+ def _build_conditional (self , name , Xnew , X_inducing , L , Kuuiu , jitter , ** kwargs ):
111
+ Ksu = self .cov_func (Xnew , X_inducing )
68
112
mu = self .mean_func (Xnew ) + Ksu @ Kuuiu
69
113
tmp = solve_lower (L , pt .transpose (Ksu ))
70
114
Qss = pt .transpose (tmp ) @ tmp # Qss = tt.dot(tt.dot(Ksu, tt.nlinalg.pinv(Kuu)), Ksu.T)
@@ -74,7 +118,7 @@ def _build_conditional(self, name, Xnew, Xu, L, Kuuiu, jitter, **kwargs):
74
118
75
119
def conditional (self , name , Xnew , jitter = 1e-6 , ** kwargs ):
76
120
mu , chol = self ._build_conditional (
77
- name , Xnew , self .Xu , self .L , self .Kuuiu , jitter , ** kwargs
121
+ name , Xnew , self .X_inducing , self .L , self .Kuuiu , jitter , ** kwargs
78
122
)
79
123
return pm .MvNormal (name , mu = mu , chol = chol )
80
124
0 commit comments