Skip to content

Commit 28c86a4

Browse files
authored
Refactor noise model for unitary mixtures (#2652)
Signed-off-by: Ben Howe <[email protected]>
1 parent 1be2bf8 commit 28c86a4

File tree

4 files changed

+182
-169
lines changed

4 files changed

+182
-169
lines changed

runtime/common/NoiseModel.cpp

+111-1
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,73 @@
1010
#include "Logger.h"
1111
#include "common/CustomOp.h"
1212
#include "common/EigenDense.h"
13+
#include <numeric>
14+
#include <optional>
15+
1316
namespace cudaq {
1417

18+
// Helper to check whether a matrix is a scaled unitary matrix, i.e., `k * U`
19+
// where U is a unitary matrix. If so, it also returns the `k` factor.
20+
// Otherwise, return a nullopt.
21+
static std::optional<double>
22+
isScaledUnitary(const std::vector<std::complex<double>> &mat, double eps) {
23+
typedef Eigen::Matrix<std::complex<double>, Eigen::Dynamic, Eigen::Dynamic,
24+
Eigen::RowMajor>
25+
RowMajorMatTy;
26+
const int dim = std::log2(mat.size());
27+
Eigen::Map<const RowMajorMatTy> kMat(mat.data(), dim, dim);
28+
if (kMat.isZero(eps))
29+
return 0.0;
30+
// Check that (K_dag * K) is a scaled identity matrix
31+
// i.e., the K matrix is a scaled unitary.
32+
auto kdK = kMat.adjoint() * kMat;
33+
if (!kdK.isDiagonal(eps))
34+
return std::nullopt;
35+
// First element
36+
std::complex<double> val = kdK(0, 0);
37+
if (val.real() > eps && std::abs(val.imag()) < eps) {
38+
auto scaledKdK = (std::complex<double>{1.0} / val) * kdK;
39+
if (scaledKdK.isIdentity())
40+
return std::sqrt(val.real());
41+
}
42+
return std::nullopt;
43+
}
44+
45+
// Helper to determine if a vector of Kraus ops are actually a unitary mixture.
46+
// If so, it returns all the unitaries and the probabilities associated with
47+
// each one of those unitaries.
48+
static std::optional<std::pair<std::vector<double>,
49+
std::vector<std::vector<std::complex<double>>>>>
50+
computeUnitaryMixture(
51+
const std::vector<std::vector<std::complex<double>>> &krausOps,
52+
double tol = 1e-6) {
53+
std::vector<double> probs;
54+
std::vector<std::vector<std::complex<double>>> mats;
55+
const auto scaleMat = [](const std::vector<std::complex<double>> &mat,
56+
double scaleFactor) {
57+
std::vector<std::complex<double>> scaledMat = mat;
58+
// If scaleFactor is 0, then it means the original matrix was likely all
59+
// zeros. In that case, the probability will be 0, so the matrix doesn't
60+
// matter, but we don't want NaNs to trickle in anywhere.
61+
if (scaleFactor != 0.0)
62+
for (auto &x : scaledMat)
63+
x /= scaleFactor;
64+
return scaledMat;
65+
};
66+
for (const auto &op : krausOps) {
67+
const auto scaledFactor = isScaledUnitary(op, tol);
68+
if (!scaledFactor.has_value())
69+
return std::nullopt;
70+
probs.emplace_back(scaledFactor.value() * scaledFactor.value());
71+
mats.emplace_back(scaleMat(op, scaledFactor.value()));
72+
}
73+
74+
if (std::abs(1.0 - std::reduce(probs.begin(), probs.end())) > tol)
75+
return std::nullopt;
76+
77+
return std::make_pair(probs, mats);
78+
}
79+
1580
template <typename EigenMatTy>
1681
bool isIdentity(const EigenMatTy &mat, double threshold = 1e-9) {
1782
EigenMatTy idMat = EigenMatTy::Identity(mat.rows(), mat.cols());
@@ -78,9 +143,52 @@ void validateCompletenessRelation_fp64(const std::vector<kraus_op> &ops) {
78143
"Provided kraus_ops are not completely positive and trace preserving.");
79144
}
80145

146+
void generateUnitaryParameters_fp32(
147+
const std::vector<kraus_op> &ops,
148+
std::vector<std::vector<std::complex<double>>> &unitary_ops,
149+
std::vector<double> &probabilities) {
150+
std::vector<std::vector<std::complex<double>>> double_kraus_ops;
151+
double_kraus_ops.reserve(ops.size());
152+
for (auto &op : ops) {
153+
// WARNING: danger here. We are intentially treating the incoming op as fp32
154+
// type instead of what the compiler thinks it is (fp64). We have to do this
155+
// because this file is compiled with cudaq::real = fp64, but the incoming
156+
// data for this specific routine is actually fp32.
157+
const std::complex<float> *ptr =
158+
reinterpret_cast<const std::complex<float> *>(op.data.data());
159+
// Use 2 * size because pointer arithmetic is on fp32 instead of fp64
160+
double_kraus_ops.emplace_back(
161+
std::vector<std::complex<double>>(ptr, ptr + 2 * op.data.size()));
162+
}
163+
164+
auto asUnitaryMixture = computeUnitaryMixture(double_kraus_ops);
165+
if (asUnitaryMixture.has_value()) {
166+
probabilities = std::move(asUnitaryMixture.value().first);
167+
unitary_ops = std::move(asUnitaryMixture.value().second);
168+
}
169+
}
170+
171+
void generateUnitaryParameters_fp64(
172+
const std::vector<kraus_op> &ops,
173+
std::vector<std::vector<std::complex<double>>> &unitary_ops,
174+
std::vector<double> &probabilities) {
175+
std::vector<std::vector<std::complex<double>>> double_kraus_ops;
176+
double_kraus_ops.reserve(ops.size());
177+
for (auto &op : ops)
178+
double_kraus_ops.emplace_back(
179+
std::vector<std::complex<double>>(op.data.begin(), op.data.end()));
180+
181+
auto asUnitaryMixture = computeUnitaryMixture(double_kraus_ops);
182+
if (asUnitaryMixture.has_value()) {
183+
probabilities = std::move(asUnitaryMixture.value().first);
184+
unitary_ops = std::move(asUnitaryMixture.value().second);
185+
}
186+
}
187+
81188
kraus_channel::kraus_channel(const kraus_channel &other)
82189
: ops(other.ops), noise_type(other.noise_type),
83-
parameters(other.parameters) {}
190+
parameters(other.parameters), unitary_ops(other.unitary_ops),
191+
probabilities(other.probabilities) {}
84192

85193
std::size_t kraus_channel::size() const { return ops.size(); }
86194

@@ -94,6 +202,8 @@ kraus_channel &kraus_channel::operator=(const kraus_channel &other) {
94202
ops = other.ops;
95203
noise_type = other.noise_type;
96204
parameters = other.parameters;
205+
unitary_ops = other.unitary_ops;
206+
probabilities = other.probabilities;
97207
return *this;
98208
}
99209

runtime/common/NoiseModel.h

+41-1
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,12 @@ struct kraus_op {
102102

103103
void validateCompletenessRelation_fp32(const std::vector<kraus_op> &ops);
104104
void validateCompletenessRelation_fp64(const std::vector<kraus_op> &ops);
105+
void generateUnitaryParameters_fp32(
106+
const std::vector<kraus_op> &ops,
107+
std::vector<std::vector<std::complex<double>>> &, std::vector<double> &);
108+
void generateUnitaryParameters_fp64(
109+
const std::vector<kraus_op> &ops,
110+
std::vector<std::vector<std::complex<double>>> &, std::vector<double> &);
105111

106112
/// @brief A kraus_channel represents a quantum noise channel
107113
/// on specific qubits. The action of the noise channel is
@@ -143,7 +149,17 @@ class kraus_channel {
143149
// corruption.
144150
std::vector<double> parameters;
145151

146-
~kraus_channel() = default;
152+
/// @brief If all Kraus ops are - when scaled - unitary, this holds the
153+
/// unitary versions of those ops. These values are always "double" regardless
154+
/// of whether cudaq::real is float or double.
155+
std::vector<std::vector<std::complex<double>>> unitary_ops;
156+
157+
/// @brief If all Kraus ops are - when scaled - unitary, this holds the
158+
/// probabilities of those ops. These values are always "double" regardless
159+
/// of whether cudaq::real is float or double.
160+
std::vector<double> probabilities;
161+
162+
virtual ~kraus_channel() = default;
147163

148164
/// @brief The nullary constructor
149165
kraus_channel() = default;
@@ -158,12 +174,14 @@ class kraus_channel {
158174
kraus_channel(std::initializer_list<T> &&...inputLists) {
159175
(ops.emplace_back(std::move(inputLists)), ...);
160176
validateCompleteness();
177+
generateUnitaryParameters();
161178
}
162179

163180
/// @brief The constructor, take qubits and channel kraus_ops as lvalue
164181
/// reference
165182
kraus_channel(const std::vector<kraus_op> &inOps) : ops(inOps) {
166183
validateCompleteness();
184+
generateUnitaryParameters();
167185
}
168186

169187
/// @brief The constructor, take qubits and channel kraus_ops as rvalue
@@ -189,6 +207,23 @@ class kraus_channel {
189207

190208
/// @brief Add a kraus_op to this channel.
191209
void push_back(kraus_op op);
210+
211+
/// @brief Returns whether or not this is a unitary mixture.
212+
bool is_unitary_mixture() const { return !unitary_ops.empty(); }
213+
214+
/// @brief Checks if Kraus ops have unitary representations and saves them if
215+
/// they do. Users should only need to call this if they have modified the
216+
/// Kraus ops and want to recompute these values.
217+
void generateUnitaryParameters() {
218+
unitary_ops.clear();
219+
probabilities.clear();
220+
if constexpr (std::is_same_v<cudaq::complex::value_type, float>) {
221+
generateUnitaryParameters_fp32(ops, this->unitary_ops,
222+
this->probabilities);
223+
return;
224+
}
225+
generateUnitaryParameters_fp64(ops, this->unitary_ops, this->probabilities);
226+
}
192227
};
193228

194229
/// @brief The noise_model type keeps track of a set of
@@ -381,6 +416,7 @@ class depolarization_channel : public kraus_channel {
381416
this->parameters.push_back(probability);
382417
noise_type = noise_model_type::depolarization_channel;
383418
validateCompleteness();
419+
generateUnitaryParameters();
384420
}
385421
};
386422

@@ -396,6 +432,8 @@ class amplitude_damping_channel : public kraus_channel {
396432
this->parameters.push_back(probability);
397433
noise_type = noise_model_type::amplitude_damping_channel;
398434
validateCompleteness();
435+
// Note: amplitude damping is non-unitary, so there is no value in calling
436+
// generateUnitaryParameters().
399437
}
400438
};
401439

@@ -412,6 +450,7 @@ class bit_flip_channel : public kraus_channel {
412450
this->parameters.push_back(probability);
413451
noise_type = noise_model_type::bit_flip_channel;
414452
validateCompleteness();
453+
generateUnitaryParameters();
415454
}
416455
};
417456

@@ -429,6 +468,7 @@ class phase_flip_channel : public kraus_channel {
429468
this->parameters.push_back(probability);
430469
noise_type = noise_model_type::phase_flip_channel;
431470
validateCompleteness();
471+
generateUnitaryParameters();
432472
}
433473
};
434474
} // namespace cudaq

0 commit comments

Comments
 (0)