forked from GPflow/GPflow
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix covariance overparameterisation (GPflow#150)
GPflow often optimizes positive-definite matrices. To maintain positive-definiteness without constrained optimization, a lower-triangular matrix is optimized. Sigma + L L ^T The previous approach to optimizing L was to ignore the upper half. The mean that there were some extra variables in the optimization vector, which did nothing. This PR implements a tensorflow op which transforms back-and-forth between triangular matrix L and a 'packed' vector representation. The result is that there are no redundant parameters in the optimization vector.
- Loading branch information
1 parent
9c2faf6
commit 5726dd4
Showing
11 changed files
with
437 additions
and
24 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
// Copyright 2016 Mark van der Wilk | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include <cmath> | ||
#include "tensorflow/core/framework/op.h" | ||
#include "tensorflow/core/framework/op_kernel.h" | ||
#include "tensorflow/core/framework/register_types.h" | ||
|
||
|
||
REGISTER_OP("TriToVec") | ||
.Attr("T: realnumbertype") | ||
.Input("trimat: T") | ||
.Output("vec: T") | ||
.Doc(R"doc( | ||
Converts a series of triangular matrices to a series of vectors (i.e. a matrix). | ||
If the input is D x N x N, then the output is D x M, where the lower | ||
triangle of each N x N matrix has been packed into an M-vector. | ||
)doc"); | ||
|
||
using namespace tensorflow; | ||
|
||
template <typename T> | ||
class TriToVecOp : public OpKernel { | ||
public: | ||
explicit TriToVecOp(OpKernelConstruction* context) : OpKernel(context) {} | ||
|
||
void Compute(OpKernelContext* context) override { | ||
// Grab the input tensor | ||
const Tensor& input_tensor = context->input(0); | ||
|
||
const TensorShape& input_shape = input_tensor.shape(); | ||
const int rank = input_shape.dims(); | ||
|
||
// For now, keep it as just a matrix | ||
OP_REQUIRES(context, rank == 3, | ||
errors::InvalidArgument("TriToVec expects a rank-3 tensor, received shape: ", | ||
input_shape.DebugString())); | ||
|
||
const int k = input_shape.dim_size(rank - 1); // Matrix size | ||
OP_REQUIRES(context, k == input_shape.dim_size(rank - 2), | ||
errors::InvalidArgument("input's last two dimensions must be equal, received shape: ", | ||
input_shape.DebugString())); | ||
|
||
auto f = input_tensor.flat_inner_dims<T, 3>(); | ||
|
||
// Create an output tensor | ||
TensorShape out_shape({input_shape.dim_size(rank - 3), k * (k+1) / 2}); | ||
Tensor* output_tensor = NULL; | ||
OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, | ||
&output_tensor)); | ||
|
||
|
||
auto output = output_tensor->template flat<T>(); | ||
int i = 0; | ||
for (int z = 0; z != f.dimension(0); z++) { | ||
for (int y = 0; y != f.dimension(1); y++) { | ||
for (int x = 0; x != f.dimension(2); x++) { | ||
if (y >= x) { | ||
output(i) = f(z, y, x); | ||
i++; | ||
} | ||
} | ||
} | ||
} | ||
} | ||
}; | ||
|
||
#define REGISTER_KERNEL(type) \ | ||
REGISTER_KERNEL_BUILDER( \ | ||
Name("TriToVec") \ | ||
.Device(DEVICE_CPU) \ | ||
.TypeConstraint<type>("T"), \ | ||
TriToVecOp<type> \ | ||
); | ||
|
||
|
||
TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNEL); | ||
|
||
#undef REGISTER_KERNEL |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
// Copyright 2016 Mark van der Wilk | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include <cmath> | ||
#include "tensorflow/core/framework/op.h" | ||
#include "tensorflow/core/framework/op_kernel.h" | ||
#include "tensorflow/core/framework/register_types.h" | ||
|
||
|
||
REGISTER_OP("VecToTri") | ||
.Attr("T: realnumbertype") | ||
.Input("vec: T") | ||
.Output("matrix: T") | ||
.Doc(R"doc( | ||
Converts a matrix into a series of triangular matrices. | ||
If the input is D x M, then the output is D x N x N, where the lower | ||
triangle of each N x N matrix is constructed by unpacking each M-vector. | ||
See also: TriToVec. | ||
)doc"); | ||
|
||
|
||
|
||
using namespace tensorflow; | ||
|
||
template <typename T> | ||
class VecToTriOp : public OpKernel { | ||
public: | ||
explicit VecToTriOp(OpKernelConstruction* context) : OpKernel(context) {} | ||
|
||
void Compute(OpKernelContext* context) override { | ||
// Grab the input tensor | ||
const Tensor& input_tensor = context->input(0); | ||
auto input = input_tensor.flat<T>(); | ||
|
||
OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_tensor.shape()), | ||
errors::InvalidArgument("VecToTri expects a 2-D matrix.")); | ||
|
||
auto ds = input_tensor.shape().dim_sizes(); | ||
|
||
int matsize = (int)std::floor(std::sqrt(ds[1] * 8 + 1) / 2.0 - 0.5); // Deduce square matrix size | ||
int recvecsize = (int)std::round(0.5*matsize*(matsize+1)); // Reconstruct number of required vector elements | ||
|
||
OP_REQUIRES(context, recvecsize == ds[1], | ||
errors::InvalidArgument("Must have triangle number of elements in the input vector.") | ||
); | ||
|
||
// Create an output tensor | ||
TensorShape out_shape({ds[0], matsize, matsize}); | ||
Tensor* output_tensor = NULL; | ||
OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, | ||
&output_tensor)); | ||
|
||
auto output = output_tensor->template flat<T>(); | ||
const int N = output.size(); | ||
for (int i = 0; i != N; i++) { | ||
int output_mat = i / (matsize * matsize); | ||
int x = i % matsize; | ||
int y = (i / matsize) % matsize; | ||
if (x > y) { | ||
output(i) = (T)0; | ||
} else { | ||
int idx = (i % (matsize*matsize)) - (int)std::round(matsize*y-0.5*y*y-0.5*y); | ||
output(i) = input(idx + ds[1] * output_mat); | ||
} | ||
} | ||
} | ||
}; | ||
|
||
#define REGISTER_KERNEL(type) \ | ||
REGISTER_KERNEL_BUILDER( \ | ||
Name("VecToTri") \ | ||
.Device(DEVICE_CPU) \ | ||
.TypeConstraint<type>("T"), \ | ||
VecToTriOp<type> \ | ||
); | ||
|
||
|
||
TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNEL); | ||
|
||
#undef REGISTER_KERNEL |
Oops, something went wrong.