Skip to content

Commit bbec3ed

Browse files
committed
* Renaming cudm_mat_state -> cudm_state
* Adding create_initial_state method to create initial state based on the passed InitialStateArgT Signed-off-by: Sachin Pisal <[email protected]>
1 parent fbfa91a commit bbec3ed

File tree

2 files changed

+93
-17
lines changed

2 files changed

+93
-17
lines changed

runtime/cudaq/cudm_state.h

+20-5
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,28 @@
1616
#include <vector>
1717

1818
namespace cudaq {
19-
class cudm_mat_state {
19+
// Enum to specify the initial quantum state.
20+
enum class InitialState { ZERO, UNIFORM };
21+
22+
using InitialStateArgT = std::variant<void *, InitialState>;
23+
24+
class cudm_state {
2025
public:
2126
/// @brief To initialize state with raw data.
22-
explicit cudm_mat_state(std::vector<std::complex<double>> rawData);
27+
explicit cudm_state(std::vector<std::complex<double>> rawData);
2328

2429
/// @brief Destructor to clean up resources
25-
~cudm_mat_state();
30+
~cudm_state();
31+
32+
/// @brief Factory method to create an initial state.
33+
/// @param InitialStateArgT The type or representation of the initial state.
34+
/// @param Dimensions of the Hilbert space.
35+
/// @param hasCollapseOps Whether collapse operators are present.
36+
/// @return A new 'cudm_state' initialized to the specified state.
37+
static cudm_state
38+
create_initial_state(const InitialStateArgT &initialStateArg,
39+
const std::vector<int64_t> &hilbertSpaceDims,
40+
bool hasCollapseOps);
2641

2742
/// @brief Initialize the state as a density matrix or state vector based on
2843
/// dimensions.
@@ -42,8 +57,8 @@ class cudm_mat_state {
4257
std::string dump() const;
4358

4459
/// @brief Convert the state vector to a density matrix.
45-
/// @return A new cudm_mat_state representing the density matrix.
46-
cudm_mat_state to_density_matrix() const;
60+
/// @return A new cudm_state representing the density matrix.
61+
cudm_state to_density_matrix() const;
4762

4863
/// @brief Get the underlying implementation (if any).
4964
/// @return The underlying state implementation.

runtime/cudaq/dynamics/cudm_state.cpp

+73-12
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,22 @@
66
* the terms of the Apache License 2.0 which accompanies this distribution. *
77
******************************************************************************/
88

9+
#include <cmath>
910
#include <cudaq/cudm_state.h>
1011
#include <iostream>
12+
#include <numeric>
1113
#include <sstream>
14+
#include <stdexcept>
1215

1316
namespace cudaq {
1417

15-
cudm_mat_state::cudm_mat_state(std::vector<std::complex<double>> rawData)
18+
cudm_state::cudm_state(std::vector<std::complex<double>> rawData)
1619
: rawData_(rawData), state_(nullptr), handle_(nullptr),
1720
hilbertSpaceDims_() {
1821
HANDLE_CUDM_ERROR(cudensitymatCreate(&handle_));
1922
}
2023

21-
cudm_mat_state::~cudm_mat_state() {
24+
cudm_state::~cudm_state() {
2225
if (state_) {
2326
cudensitymatDestroyState(state_);
2427
}
@@ -27,7 +30,7 @@ cudm_mat_state::~cudm_mat_state() {
2730
}
2831
}
2932

30-
void cudm_mat_state::init_state(const std::vector<int64_t> &hilbertSpaceDims) {
33+
void cudm_state::init_state(const std::vector<int64_t> &hilbertSpaceDims) {
3134
if (state_) {
3235
throw std::runtime_error("State is already initialized.");
3336
}
@@ -61,17 +64,17 @@ void cudm_mat_state::init_state(const std::vector<int64_t> &hilbertSpaceDims) {
6164
attach_storage();
6265
}
6366

64-
bool cudm_mat_state::is_initialized() const { return state_ != nullptr; }
67+
bool cudm_state::is_initialized() const { return state_ != nullptr; }
6568

66-
bool cudm_mat_state::is_density_matrix() const {
69+
bool cudm_state::is_density_matrix() const {
6770
if (!is_initialized()) {
6871
return false;
6972
}
7073

7174
return rawData_.size() == calculate_density_matrix_size(hilbertSpaceDims_);
7275
}
7376

74-
std::string cudm_mat_state::dump() const {
77+
std::string cudm_state::dump() const {
7578
if (!is_initialized()) {
7679
throw std::runtime_error("State is not initialized.");
7780
}
@@ -88,7 +91,7 @@ std::string cudm_mat_state::dump() const {
8891
return oss.str();
8992
}
9093

91-
cudm_mat_state cudm_mat_state::to_density_matrix() const {
94+
cudm_state cudm_state::to_density_matrix() const {
9295
if (!is_initialized()) {
9396
throw std::runtime_error("State is not initialized.");
9497
}
@@ -108,19 +111,19 @@ cudm_mat_state cudm_mat_state::to_density_matrix() const {
108111
}
109112
}
110113

111-
cudm_mat_state densityMatrixState(densityMatrix);
114+
cudm_state densityMatrixState(densityMatrix);
112115
densityMatrixState.init_state(hilbertSpaceDims_);
113116
return densityMatrixState;
114117
}
115118

116-
cudensitymatState_t cudm_mat_state::get_impl() const {
119+
cudensitymatState_t cudm_state::get_impl() const {
117120
if (!is_initialized()) {
118121
throw std::runtime_error("State is not initialized.");
119122
}
120123
return state_;
121124
}
122125

123-
void cudm_mat_state::attach_storage() {
126+
void cudm_state::attach_storage() {
124127
if (!state_) {
125128
throw std::runtime_error("State is not initialized.");
126129
}
@@ -166,7 +169,7 @@ void cudm_mat_state::attach_storage() {
166169
componentBufferSizes.data()));
167170
}
168171

169-
size_t cudm_mat_state::calculate_state_vector_size(
172+
size_t cudm_state::calculate_state_vector_size(
170173
const std::vector<int64_t> &hilbertSpaceDims) const {
171174
size_t size = 1;
172175
for (auto dim : hilbertSpaceDims) {
@@ -175,9 +178,67 @@ size_t cudm_mat_state::calculate_state_vector_size(
175178
return size;
176179
}
177180

178-
size_t cudm_mat_state::calculate_density_matrix_size(
181+
size_t cudm_state::calculate_density_matrix_size(
179182
const std::vector<int64_t> &hilbertSpaceDims) const {
180183
size_t vectorSize = calculate_state_vector_size(hilbertSpaceDims);
181184
return vectorSize * vectorSize;
182185
}
186+
187+
// Initialize state based on InitialStateArgT
188+
cudm_state
189+
cudm_state::create_initial_state(const InitialStateArgT &initialStateArg,
190+
const std::vector<int64_t> &hilbertSpaceDims,
191+
bool hasCollapseOps) {
192+
size_t stateVectorSize =
193+
std::accumulate(hilbertSpaceDims.begin(), hilbertSpaceDims.end(),
194+
static_cast<size_t>(1), std::multiplies<>{});
195+
196+
std::vector<std::complex<double>> rawData;
197+
198+
if (std::holds_alternative<InitialState>(initialStateArg)) {
199+
InitialState initialState = std::get<InitialState>(initialStateArg);
200+
201+
if (initialState == InitialState::ZERO) {
202+
rawData.resize(stateVectorSize, {0.0, 0.0});
203+
// |0> state
204+
rawData[0] = {1.0, 0.0};
205+
} else if (initialState == InitialState::UNIFORM) {
206+
rawData.resize(stateVectorSize, {1.0 / std::sqrt(stateVectorSize), 0.0});
207+
} else {
208+
throw std::invalid_argument("Unsupported InitialState type.");
209+
}
210+
} else if (std::holds_alternative<void *>(initialStateArg)) {
211+
void *runtimeState = std::get<void *>(initialStateArg);
212+
if (!runtimeState) {
213+
throw std::invalid_argument("Runtime state pointer is null.");
214+
}
215+
216+
try {
217+
auto *externalData =
218+
reinterpret_cast<std::vector<std::complex<double>> *>(runtimeState);
219+
220+
if (!externalData || externalData->empty()) {
221+
throw std::invalid_argument(
222+
"Runtime state contains invalid or empty data.");
223+
}
224+
225+
rawData = *externalData;
226+
} catch (const std::exception &e) {
227+
throw std::runtime_error("Failed to interpret runtime state: " +
228+
std::string(e.what()));
229+
}
230+
} else {
231+
throw std::invalid_argument("Unsupported InitialStateArgT type.");
232+
}
233+
234+
cudm_state state(rawData);
235+
state.init_state(hilbertSpaceDims);
236+
237+
// Convert to a density matrix if collapse operators are present.
238+
if (hasCollapseOps && !state.is_density_matrix()) {
239+
state = state.to_density_matrix();
240+
}
241+
242+
return state;
243+
}
183244
} // namespace cudaq

0 commit comments

Comments
 (0)