Commit b41c8ed

Enable more config inputs to gqe. (#11)
Signed-off-by: Alex McCaskey <[email protected]>
1 parent c59e951 commit b41c8ed

File tree

4 files changed, +32 -16 lines changed

cudaqlib/operators/chemistry/drivers/pyscf/ExternalPySCFDriver.cpp

Lines changed: 1 addition & 1 deletion

@@ -88,7 +88,7 @@ class external_pyscf
   // Import all the data we need from the execution.
   std::ifstream f(metadataFile);
   auto metadata = nlohmann::json::parse(f);
-  printf("TEST\n%s\n", metadata.dump(4).c_str());
+
   // Get the energy, num orbitals, and num qubits
   std::unordered_map<std::string, double> energies;
   for (auto &[energyName, E] : metadata["energies"].items())

python/cudaqlib/algorithms/gqe.py

Lines changed: 8 additions & 3 deletions

@@ -223,11 +223,16 @@ def __internal_run_gqe(temperature_scheduler: TemperatureScheduler,
         fabric.log('circuit', json.dumps(min_indices))
     return min_energy, min_indices

-
 def gqe(cost, pool, config=None, **kwargs):
-    cfg = get_default_config() if config == None else config
+    cfg = get_default_config()
+    if config == None:
+        [setattr(cfg, a, kwargs[a]) for a in dir(cfg) if not a.startswith('_') and a in kwargs]
+    else:
+        cfg = config
+
+    # Don't let someone override the vocab_size
     cfg.vocab_size = len(pool)
-    if 'max_iters' in kwargs: cfg.max_iters = kwargs['max_iters']
+
     model = Transformer(
         cfg, cost, loss='exp') if 'model' not in kwargs else kwargs['model']
     optimizer = torch.optim.AdamW(
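
With this change, any keyword argument that names an attribute of the default config is copied onto it whenever no explicit config object is passed; an explicit config= argument takes precedence, and kwargs are then ignored. A minimal usage sketch, assuming the cudaqlib.algorithms import path suggested by the file layout and the max_iters attribute named in the line this commit removes:

from cudaqlib.algorithms.gqe import gqe

pool = [...]        # candidate operators; len(pool) becomes cfg.vocab_size

def cost(ops):
    ...             # return the energy of the circuit built from ops

# No explicit config: any kwarg matching a config attribute is applied.
min_energy, min_indices = gqe(cost, pool, max_iters=50)

# With an explicit config, kwargs are ignored and vocab_size is still reset:
# min_energy, min_indices = gqe(cost, pool, config=my_cfg)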

python/cudaqlib/algorithms/transformer.py

Lines changed: 23 additions & 11 deletions

@@ -4,6 +4,7 @@
 from lightning import LightningModule
 from .loss import ExpLogitMatching, GFlowLogitMatching

+
 def get_device():
     if torch.cuda.is_available():
         return 'cuda'
@@ -20,16 +21,18 @@ def __init__(self, **kwargs):

 class Transformer(LightningModule):

-    def __init__(self, cfg, cost, loss="exp"):
+    def __init__(self, cfg, cost, loss="exp"):
         super().__init__()
         self._label = 'label_stand_in'
         self.cfg = cfg
         gpt2cfg = GPT2Config(
-            **{k: cfg[k] for k in GPT2Config().to_dict().keys() & cfg.keys()})
+            **{k: cfg[k]
+               for k in GPT2Config().to_dict().keys() & cfg.keys()})
         if cfg.small:
-            gpt2cfg = SmallConfig(
-                **
-                {k: cfg[k] for k in GPT2Config().to_dict().keys() & cfg.keys()})
+            gpt2cfg = SmallConfig(**{
+                k: cfg[k]
+                for k in GPT2Config().to_dict().keys() & cfg.keys()
+            })
         self.transformer = GPT2LMHeadModel(gpt2cfg).to(get_device())
         self.ngates = cfg.ngates
         self.num_samples = cfg.num_samples
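
The reformatted comprehension makes the filtering intent clearer: intersecting GPT2Config().to_dict().keys() with cfg.keys() forwards only the fields GPT2Config actually accepts. A standalone sketch of that pattern, using a plain dict in place of the project's config object (values are illustrative; requires the transformers package):

from transformers import GPT2Config

# Illustrative config mixing GPT2Config fields with unrelated keys.
cfg = {'n_layer': 4, 'n_head': 4, 'ngates': 20, 'small': True}

# Forward only the keys GPT2Config itself defines; 'ngates' and
# 'small' are silently dropped by the key intersection.
gpt2cfg = GPT2Config(
    **{k: cfg[k]
       for k in GPT2Config().to_dict().keys() & cfg.keys()})
assert gpt2cfg.n_layer == 4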
@@ -57,22 +60,31 @@ def set_cost(self, cost):

     def gather(self, idx, logits_base):
         b_size = idx.shape[0]
-        return torch.gather(logits_base, 2, idx.reshape(b_size, -1,
-                                                        1)).reshape(b_size, -1)
+        return torch.gather(logits_base, 2,
+                            idx.reshape(b_size, -1, 1)).reshape(b_size, -1)

     def computeCost(self, idx_output, pool, **kwargs):
-        return torch.tensor([self._cost([pool[i] for i in row]) for row in idx_output],
-                            dtype=torch.float)
+        return torch.tensor(
+            [self._cost([pool[i] for i in row]) for row in idx_output],
+            dtype=torch.float)

-    def train_step(self, pool, indices=None, energies=None, numQPUs=None, comm=None):
+    def train_step(self,
+                   pool,
+                   indices=None,
+                   energies=None,
+                   numQPUs=None,
+                   comm=None):
         log_values = {}
         if energies is not None:
             assert indices is not None
             idx_output = indices[:, 1:]
             logits_base = self.generate_logits(idx_output)
         else:
             idx_output, logits_base = self.generate()
-            energies = self.computeCost(idx_output, pool, numQPUs=numQPUs, comm=comm)
+            energies = self.computeCost(idx_output,
+                                        pool,
+                                        numQPUs=numQPUs,
+                                        comm=comm)
         logits_tensor = self.gather(idx_output, logits_base)
         allLogits = logits_tensor

tools/chemistry/cudaq-pyscf.py

Lines changed: 0 additions & 1 deletion

@@ -94,7 +94,6 @@ def iter_namespace(ns_pkg):
         for (k, v) in vars(args).items() if k not in filterArgs
     }
     res = hamiltonianGenerator.generate(args.xyz, args.basis, **filteredArgs)
-    print(res)

     exit(0)