Skip to content

Commit 7f9fcee

Browse files
authored
Merge pull request #90 from tnguyen-ornl/fix-u-gate
Fixes for #89
2 parents 347f9d3 + fc1f0e1 commit 7f9fcee

File tree

6 files changed

+61
-13
lines changed

6 files changed

+61
-13
lines changed

tnqvm/base/Gates.hpp

+13
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,19 @@ namespace tnqvm {
221221
};
222222
}
223223

224+
template <>
225+
std::vector<std::vector<std::complex<double>>>
226+
GetGateMatrix<CommonGates::U>(double in_theta, double in_phi,
227+
double in_lambda) {
228+
return {
229+
{std::cos(in_theta / 2.0),
230+
-std::exp(std::complex<double>(0, in_lambda)) *
231+
std::sin(in_theta / 2.0)},
232+
{std::exp(std::complex<double>(0, in_phi)) * std::sin(in_theta / 2.0),
233+
std::exp(std::complex<double>(0, in_phi + in_lambda)) *
234+
std::cos(in_theta / 2.0)}};
235+
}
236+
224237
template <>
225238
std::vector<std::vector<std::complex<double>>> GetGateMatrix<CommonGates::CNOT>() {
226239
return

tnqvm/visitors/exatn-dm/ExaTnDmVisitor.cpp

+5
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,11 @@ getGateMatrix(const xacc::Instruction &in_gate, bool in_dagger = false) {
143143
case CommonGates::Rz:
144144
return GetGateMatrix<CommonGates::Rz>(
145145
in_gate.getParameter(0).as<double>());
146+
case CommonGates::U:
147+
return GetGateMatrix<CommonGates::U>(
148+
in_gate.getParameter(0).as<double>(),
149+
in_gate.getParameter(1).as<double>(),
150+
in_gate.getParameter(2).as<double>());
146151
case CommonGates::I:
147152
return GetGateMatrix<CommonGates::I>();
148153
case CommonGates::H:

tnqvm/visitors/exatn-mpo/ExaTnPmpsVisitor.cpp

+9-2
Original file line numberDiff line numberDiff line change
@@ -154,8 +154,15 @@ std::vector<std::complex<double>> getGateMatrix(const xacc::Instruction& in_gate
154154
case CommonGates::X: return GetGateMatrix<CommonGates::X>();
155155
case CommonGates::Y: return GetGateMatrix<CommonGates::Y>();
156156
case CommonGates::Z: return GetGateMatrix<CommonGates::Z>();
157-
case CommonGates::T: return GetGateMatrix<CommonGates::T>();
158-
case CommonGates::Tdg: return GetGateMatrix<CommonGates::Tdg>();
157+
case CommonGates::T:
158+
return GetGateMatrix<CommonGates::T>();
159+
case CommonGates::U:
160+
return GetGateMatrix<CommonGates::U>(
161+
in_gate.getParameter(0).as<double>(),
162+
in_gate.getParameter(1).as<double>(),
163+
in_gate.getParameter(2).as<double>());
164+
case CommonGates::Tdg:
165+
return GetGateMatrix<CommonGates::Tdg>();
159166
case CommonGates::CNOT: return GetGateMatrix<CommonGates::CNOT>();
160167
case CommonGates::Swap: return GetGateMatrix<CommonGates::Swap>();
161168
case CommonGates::iSwap: return GetGateMatrix<CommonGates::iSwap>();

tnqvm/visitors/exatn-mps/ExatnUtils.cpp

+9-2
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,15 @@ GateTensor GateTensorConstructor::getGateTensor(xacc::Instruction& in_gate)
5656
case CommonGates::Y: return GetGateMatrix<CommonGates::Y>();
5757
case CommonGates::Z: return GetGateMatrix<CommonGates::Z>();
5858
case CommonGates::T: return GetGateMatrix<CommonGates::T>();
59-
case CommonGates::Tdg: return GetGateMatrix<CommonGates::Tdg>();
60-
case CommonGates::CNOT: return GetGateMatrix<CommonGates::CNOT>();
59+
case CommonGates::Tdg:
60+
return GetGateMatrix<CommonGates::Tdg>();
61+
case CommonGates::U:
62+
return GetGateMatrix<CommonGates::U>(
63+
in_gate.getParameter(0).as<double>(),
64+
in_gate.getParameter(1).as<double>(),
65+
in_gate.getParameter(2).as<double>());
66+
case CommonGates::CNOT:
67+
return GetGateMatrix<CommonGates::CNOT>();
6168
case CommonGates::Swap: return GetGateMatrix<CommonGates::Swap>();
6269
case CommonGates::iSwap: return GetGateMatrix<CommonGates::iSwap>();
6370
case CommonGates::fSim: return GetGateMatrix<CommonGates::fSim>(in_gate.getParameter(0).as<double>(), in_gate.getParameter(1).as<double>());

tnqvm/visitors/exatn/ExatnVisitor.cpp

+5-1
Original file line numberDiff line numberDiff line change
@@ -924,7 +924,11 @@ void ExatnVisitor<TNQVM_COMPLEX_TYPE>::visit(CPhase &in_CPhaseGate) {
924924
template<typename TNQVM_COMPLEX_TYPE>
925925
void ExatnVisitor<TNQVM_COMPLEX_TYPE>::visit(U &in_UGate) {
926926
TNQVM_TELEMETRY_ZONE(__FUNCTION__, __FILE__, __LINE__);
927-
appendGateTensor<CommonGates::U>(in_UGate);
927+
assert(in_UGate.nParameters() == 3);
928+
const double theta = in_UGate.getParameter(0).as<double>();
929+
const double phi = in_UGate.getParameter(1).as<double>();
930+
const double lambda = in_UGate.getParameter(2).as<double>();
931+
appendGateTensor<CommonGates::U>(in_UGate, theta, phi, lambda);
928932
}
929933

930934
template<typename TNQVM_COMPLEX_TYPE>

tnqvm/visitors/itensor/mps/ITensorMPSVisitor.cpp

+20-8
Original file line numberDiff line numberDiff line change
@@ -490,15 +490,27 @@ void ITensorMPSVisitor::visit(Rz &gate) {
490490
execTime += singleQubitTime;
491491
}
492492

493-
void ITensorMPSVisitor::visit(U& u) {
494-
Rz z1(u.bits()[0], ipToDouble(u.getParameter(0)));
495-
Ry y(u.bits()[0], ipToDouble(u.getParameter(1)));
496-
Rz z2(u.bits()[0], ipToDouble(u.getParameter(2)));
497-
498-
visit(z1);
499-
visit(y);
500-
visit(z2);
493+
void ITensorMPSVisitor::visit(U &u) {
494+
auto iqbit_in = u.bits()[0];
495+
if (verbose) {
496+
std::cout << "applying " << u.name() << " @ " << iqbit_in << std::endl;
501497
}
498+
const double theta = ipToDouble(u.getParameter(0));
499+
const double phi = ipToDouble(u.getParameter(1));
500+
const double lambda = ipToDouble(u.getParameter(2));
501+
auto ind_in = ind_for_qbit(iqbit_in);
502+
auto ind_out = itensor::Index(u.name(), 2);
503+
auto tGate = itensor::ITensor(ind_in, ind_out);
504+
tGate.set(ind_in(1), ind_out(1), std::cos(theta / 2.0));
505+
tGate.set(ind_in(1), ind_out(2), -std::exp(std::complex<double>(0, lambda)) *
506+
std::sin(theta / 2.0));
507+
tGate.set(ind_in(2), ind_out(1), std::exp(std::complex<double>(0, phi)) * std::sin(theta / 2.0));
508+
tGate.set(ind_in(2), ind_out(2), std::exp(std::complex<double>(0, phi + lambda)) *
509+
std::cos(theta / 2.0));
510+
legMats[iqbit_in] = tGate * legMats[iqbit_in];
511+
printWavefunc();
512+
execTime += singleQubitTime;
513+
}
502514

503515
void ITensorMPSVisitor::visit(CPhase &cp) {
504516
xacc::error("ITensorMPS Visitor CPhase visit unimplemented.");

0 commit comments

Comments
 (0)