1
+ /* ******************************************************************************
2
+ * Copyright (c) 2022 - 2024 NVIDIA Corporation & Affiliates. *
3
+ * All rights reserved. *
4
+ * *
5
+ * This source code and the accompanying materials are made available under *
6
+ * the terms of the Apache License 2.0 which accompanies this distribution. *
7
+ ******************************************************************************/
8
+
9
+ #include " BraketExecutor.h"
10
+ #include " BraketServerHelper.h"
11
+
12
+ namespace cudaq {
13
+
14
+ details::future
15
+ BraketExecutor::execute (std::vector<KernelExecution> &codesToExecute) {
16
+ auto braketServerHelper = dynamic_cast <BraketServerHelper *>(serverHelper);
17
+ assert (braketServerHelper);
18
+ braketServerHelper->setShots (shots);
19
+
20
+ auto [dummy1, dummy2, messages] =
21
+ braketServerHelper->createJob (codesToExecute);
22
+
23
+ std::string const defaultBucket = defaultBucketFuture.get ();
24
+ std::string const defaultPrefix = " tasks" ;
25
+
26
+ auto config = braketServerHelper->getConfig ();
27
+ cudaq::info (" Backend config: {}, shots {}" , config, shots);
28
+ config.insert ({" shots" , std::to_string (shots)});
29
+
30
+ std::vector<Aws::Braket::Model::CreateQuantumTaskOutcomeCallable>
31
+ createOutcomes;
32
+
33
+ for (const auto &message : messages) {
34
+ Aws::Braket::Model::CreateQuantumTaskRequest req;
35
+ req.SetAction (message[" action" ]);
36
+ req.SetDeviceArn (message[" deviceArn" ]);
37
+ req.SetShots (message[" shots" ]);
38
+ if (jobToken)
39
+ req.SetJobToken (jobToken);
40
+ req.SetOutputS3Bucket (defaultBucket);
41
+ req.SetOutputS3KeyPrefix (defaultPrefix);
42
+
43
+ createOutcomes.push_back (braketClient.CreateQuantumTaskCallable (req));
44
+ }
45
+
46
+ return std::async (
47
+ std::launch::async,
48
+ [this ](std::vector<Aws::Braket::Model::CreateQuantumTaskOutcomeCallable>
49
+ createOutcomes) {
50
+ std::vector<ExecutionResult> results;
51
+ for (auto &outcome : createOutcomes) {
52
+ auto createResponse = outcome.get ();
53
+ if (!createResponse.IsSuccess ()) {
54
+ throw std::runtime_error (createResponse.GetError ().GetMessage ());
55
+ }
56
+ std::string taskArn = createResponse.GetResult ().GetQuantumTaskArn ();
57
+ cudaq::info (" Created Braket quantum task {}" , taskArn);
58
+
59
+ Aws::Braket::Model::GetQuantumTaskRequest req;
60
+ req.SetQuantumTaskArn (taskArn);
61
+ auto getResponse = braketClient.GetQuantumTask (req);
62
+ if (!getResponse.IsSuccess ()) {
63
+ throw std::runtime_error (getResponse.GetError ().GetMessage ());
64
+ }
65
+ auto taskStatus = getResponse.GetResult ().GetStatus ();
66
+ while (
67
+ taskStatus != Aws::Braket::Model::QuantumTaskStatus::COMPLETED &&
68
+ taskStatus != Aws::Braket::Model::QuantumTaskStatus::FAILED &&
69
+ taskStatus != Aws::Braket::Model::QuantumTaskStatus::CANCELLED) {
70
+ std::this_thread::sleep_for (pollingInterval);
71
+
72
+ getResponse = braketClient.GetQuantumTask (req);
73
+ if (!getResponse.IsSuccess ()) {
74
+ throw std::runtime_error (getResponse.GetError ().GetMessage ());
75
+ }
76
+ taskStatus = getResponse.GetResult ().GetStatus ();
77
+ }
78
+
79
+ auto getResult = getResponse.GetResult ();
80
+ if (taskStatus != Aws::Braket::Model::QuantumTaskStatus::COMPLETED) {
81
+ // Task terminated without results
82
+ throw std::runtime_error (
83
+ fmt::format (" Braket task {} terminated without results. {}" ,
84
+ taskArn, getResult.GetFailureReason ()));
85
+ }
86
+
87
+ std::string outBucket = getResult.GetOutputS3Bucket ();
88
+ std::string outPrefix = getResult.GetOutputS3Directory ();
89
+
90
+ cudaq::info (" Fetching braket quantum task {} results from "
91
+ " s3://{}/{}/results.json" ,
92
+ taskArn, outBucket, outPrefix);
93
+
94
+ Aws::S3Crt::Model::GetObjectRequest resultsJsonRequest;
95
+ resultsJsonRequest.SetBucket (outBucket);
96
+ resultsJsonRequest.SetKey (fmt::format (" {}/results.json" , outPrefix));
97
+ auto s3Response = s3Client.GetObject (resultsJsonRequest);
98
+ if (!s3Response.IsSuccess ()) {
99
+ throw std::runtime_error (s3Response.GetError ().GetMessage ());
100
+ }
101
+ auto resultsJson = nlohmann::json::parse (
102
+ s3Response.GetResultWithOwnership ().GetBody ());
103
+ auto c = serverHelper->processResults (resultsJson, taskArn);
104
+
105
+ for (auto ®Name : c.register_names ()) {
106
+ results.emplace_back (c.to_map (regName), regName);
107
+ results.back ().sequentialData = c.sequential_data (regName);
108
+ }
109
+ }
110
+
111
+ return sample_result (results);
112
+ },
113
+ std::move (createOutcomes));
114
+ };
115
+ } // namespace cudaq
116
+
117
+ CUDAQ_REGISTER_TYPE (cudaq::Executor, cudaq::BraketExecutor, braket);
0 commit comments