6
6
* the terms of the Apache License 2.0 which accompanies this distribution. *
7
7
******************************************************************************/
8
8
9
+ #include < cmath>
9
10
#include < cudaq/cudm_state.h>
10
11
#include < iostream>
12
+ #include < numeric>
11
13
#include < sstream>
14
+ #include < stdexcept>
12
15
13
16
namespace cudaq {
14
17
15
- cudm_mat_state::cudm_mat_state (std::vector<std::complex<double >> rawData)
18
+ cudm_state::cudm_state (std::vector<std::complex<double >> rawData)
16
19
: rawData_(rawData), state_(nullptr ), handle_(nullptr ),
17
20
hilbertSpaceDims_ () {
18
21
HANDLE_CUDM_ERROR (cudensitymatCreate (&handle_));
19
22
}
20
23
21
- cudm_mat_state ::~cudm_mat_state () {
24
+ cudm_state ::~cudm_state () {
22
25
if (state_) {
23
26
cudensitymatDestroyState (state_);
24
27
}
@@ -27,7 +30,7 @@ cudm_mat_state::~cudm_mat_state() {
27
30
}
28
31
}
29
32
30
- void cudm_mat_state ::init_state (const std::vector<int64_t > &hilbertSpaceDims) {
33
+ void cudm_state ::init_state (const std::vector<int64_t > &hilbertSpaceDims) {
31
34
if (state_) {
32
35
throw std::runtime_error (" State is already initialized." );
33
36
}
@@ -61,17 +64,17 @@ void cudm_mat_state::init_state(const std::vector<int64_t> &hilbertSpaceDims) {
61
64
attach_storage ();
62
65
}
63
66
64
- bool cudm_mat_state ::is_initialized () const { return state_ != nullptr ; }
67
+ bool cudm_state ::is_initialized () const { return state_ != nullptr ; }
65
68
66
- bool cudm_mat_state ::is_density_matrix () const {
69
+ bool cudm_state ::is_density_matrix () const {
67
70
if (!is_initialized ()) {
68
71
return false ;
69
72
}
70
73
71
74
return rawData_.size () == calculate_density_matrix_size (hilbertSpaceDims_);
72
75
}
73
76
74
- std::string cudm_mat_state ::dump () const {
77
+ std::string cudm_state ::dump () const {
75
78
if (!is_initialized ()) {
76
79
throw std::runtime_error (" State is not initialized." );
77
80
}
@@ -88,7 +91,7 @@ std::string cudm_mat_state::dump() const {
88
91
return oss.str ();
89
92
}
90
93
91
- cudm_mat_state cudm_mat_state ::to_density_matrix () const {
94
+ cudm_state cudm_state ::to_density_matrix () const {
92
95
if (!is_initialized ()) {
93
96
throw std::runtime_error (" State is not initialized." );
94
97
}
@@ -108,19 +111,19 @@ cudm_mat_state cudm_mat_state::to_density_matrix() const {
108
111
}
109
112
}
110
113
111
- cudm_mat_state densityMatrixState (densityMatrix);
114
+ cudm_state densityMatrixState (densityMatrix);
112
115
densityMatrixState.init_state (hilbertSpaceDims_);
113
116
return densityMatrixState;
114
117
}
115
118
116
- cudensitymatState_t cudm_mat_state ::get_impl () const {
119
+ cudensitymatState_t cudm_state ::get_impl () const {
117
120
if (!is_initialized ()) {
118
121
throw std::runtime_error (" State is not initialized." );
119
122
}
120
123
return state_;
121
124
}
122
125
123
- void cudm_mat_state ::attach_storage () {
126
+ void cudm_state ::attach_storage () {
124
127
if (!state_) {
125
128
throw std::runtime_error (" State is not initialized." );
126
129
}
@@ -166,7 +169,7 @@ void cudm_mat_state::attach_storage() {
166
169
componentBufferSizes.data ()));
167
170
}
168
171
169
- size_t cudm_mat_state ::calculate_state_vector_size (
172
+ size_t cudm_state ::calculate_state_vector_size (
170
173
const std::vector<int64_t > &hilbertSpaceDims) const {
171
174
size_t size = 1 ;
172
175
for (auto dim : hilbertSpaceDims) {
@@ -175,9 +178,67 @@ size_t cudm_mat_state::calculate_state_vector_size(
175
178
return size;
176
179
}
177
180
178
- size_t cudm_mat_state ::calculate_density_matrix_size (
181
+ size_t cudm_state ::calculate_density_matrix_size (
179
182
const std::vector<int64_t > &hilbertSpaceDims) const {
180
183
size_t vectorSize = calculate_state_vector_size (hilbertSpaceDims);
181
184
return vectorSize * vectorSize;
182
185
}
186
+
187
+ // Initialize state based on InitialStateArgT
188
+ cudm_state
189
+ cudm_state::create_initial_state (const InitialStateArgT &initialStateArg,
190
+ const std::vector<int64_t > &hilbertSpaceDims,
191
+ bool hasCollapseOps) {
192
+ size_t stateVectorSize =
193
+ std::accumulate (hilbertSpaceDims.begin (), hilbertSpaceDims.end (),
194
+ static_cast <size_t >(1 ), std::multiplies<>{});
195
+
196
+ std::vector<std::complex<double >> rawData;
197
+
198
+ if (std::holds_alternative<InitialState>(initialStateArg)) {
199
+ InitialState initialState = std::get<InitialState>(initialStateArg);
200
+
201
+ if (initialState == InitialState::ZERO) {
202
+ rawData.resize (stateVectorSize, {0.0 , 0.0 });
203
+ // |0> state
204
+ rawData[0 ] = {1.0 , 0.0 };
205
+ } else if (initialState == InitialState::UNIFORM) {
206
+ rawData.resize (stateVectorSize, {1.0 / std::sqrt (stateVectorSize), 0.0 });
207
+ } else {
208
+ throw std::invalid_argument (" Unsupported InitialState type." );
209
+ }
210
+ } else if (std::holds_alternative<void *>(initialStateArg)) {
211
+ void *runtimeState = std::get<void *>(initialStateArg);
212
+ if (!runtimeState) {
213
+ throw std::invalid_argument (" Runtime state pointer is null." );
214
+ }
215
+
216
+ try {
217
+ auto *externalData =
218
+ reinterpret_cast <std::vector<std::complex<double >> *>(runtimeState);
219
+
220
+ if (!externalData || externalData->empty ()) {
221
+ throw std::invalid_argument (
222
+ " Runtime state contains invalid or empty data." );
223
+ }
224
+
225
+ rawData = *externalData;
226
+ } catch (const std::exception &e) {
227
+ throw std::runtime_error (" Failed to interpret runtime state: " +
228
+ std::string (e.what ()));
229
+ }
230
+ } else {
231
+ throw std::invalid_argument (" Unsupported InitialStateArgT type." );
232
+ }
233
+
234
+ cudm_state state (rawData);
235
+ state.init_state (hilbertSpaceDims);
236
+
237
+ // Convert to a density matrix if collapse operators are present.
238
+ if (hasCollapseOps && !state.is_density_matrix ()) {
239
+ state = state.to_density_matrix ();
240
+ }
241
+
242
+ return state;
243
+ }
183
244
} // namespace cudaq
0 commit comments