Skip to content

Commit 81b4269

Browse files
committed
Added gmm.score python function
1 parent 0f157b1 commit 81b4269

File tree

2 files changed

+81
-0
lines changed

2 files changed

+81
-0
lines changed

python/gmmmodule.c

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,85 @@ GMM_fit(GMMObject *self, PyObject *args, PyObject *keywds)
134134
return Py_BuildValue("");
135135
}
136136

137+
static PyObject *
138+
GMM_score(GMMObject *self, PyObject *args, PyObject *keywds)
139+
{
140+
// Parse the arguments
141+
static char *kwlist[] = {"X",
142+
NULL};
143+
PyObject *X_obj;
144+
if (!PyArg_ParseTupleAndKeywords(args, keywds, "O", kwlist, &X_obj))
145+
{
146+
return NULL;
147+
}
148+
149+
// Get data matrix from numpy array
150+
PyArrayObject *X_array = (PyArrayObject *) PyArray_ContiguousFromObject(X_obj, PyArray_DOUBLE, 2, 2);
151+
if (X_array == NULL)
152+
{
153+
printf("Data matrix (X) in bad format.\n");
154+
return NULL;
155+
}
156+
if (PyArray_NDIM(X_array) != 2)
157+
{
158+
printf("Data matrix (X) must be a 2D matrix.\n");
159+
return NULL;
160+
}
161+
int N = (int) PyArray_DIM(X_array, 0);
162+
int D = (int) PyArray_DIM(X_array, 1);
163+
if (D != PyArray_DIM(self->means, 1))
164+
{
165+
printf("Invalid dimensions of data matrix X.\n");
166+
return NULL;
167+
}
168+
double **X = malloc(N*sizeof(double *));
169+
for (int t=0; t<N; t++)
170+
X[t] = (double *) X_array->data + D*t;
171+
172+
// Initialize GMM from parameters
173+
GMM *gmm = malloc(sizeof(GMM));
174+
gmm->M = self->k;
175+
gmm->D = D;
176+
int covar_len = 0;
177+
const char *cov_type = PyString_AsString(self->cov_type);
178+
if (strcmp(cov_type, "spherical") == 0)
179+
{
180+
covar_len = 1;
181+
gmm->cov_type = SPHERICAL;
182+
}
183+
else if (strcmp(cov_type, "diagonal") == 0)
184+
{
185+
covar_len = D;
186+
gmm->cov_type = DIAGONAL;
187+
}
188+
double **means = malloc(self->k*sizeof(double *));
189+
double **covars = malloc(self->k*sizeof(double *));
190+
for (int k=0; k<self->k; k++)
191+
{
192+
means[k] = (double *) self->means->data + k*D;
193+
covars[k] = (double *) self->covars->data + k*covar_len;
194+
}
195+
gmm->weights = (double *) self->weights->data;
196+
gmm->means = means;
197+
gmm->covars = covars;
198+
199+
// Score the data points
200+
double llh = gmm_score(gmm, X, N);
201+
202+
// Free the GMM object
203+
free(gmm);
204+
205+
// Free the parameter arrays
206+
free(means);
207+
free(covars);
208+
209+
// Free the data matrix and pyobjects
210+
free(X);
211+
Py_DECREF(X_array);
212+
213+
return Py_BuildValue("");
214+
}
215+
137216
static PyMemberDef GMM_members[] = {
138217
{"weights", T_OBJECT, offsetof(GMMObject, weights), 0, "Component weights"},
139218
{"means", T_OBJECT, offsetof(GMMObject, means), 0, "Component means"},
@@ -149,6 +228,7 @@ static PyMemberDef GMM_members[] = {
149228

150229
static PyMethodDef GMM_methods[] = {
151230
{"fit", (PyCFunction) GMM_fit, METH_VARARGS | METH_KEYWORDS, "Fit the GMM on the data."},
231+
{"score", (PyCFunction) GMM_score, METH_VARARGS | METH_KEYWORDS, "Scores the data using the GMM."},
152232
{NULL}
153233
};
154234

python/test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
st = time.time()
1111
g1 = gmm.GMM(k=3, CovType='diagonal', InitMethod='kmeans')
1212
g1.fit(X)
13+
g1.score(X)
1314
en = time.time()
1415
print 'time1 = ' + str(en-st) + ' s'
1516

0 commit comments

Comments
 (0)