-
Notifications
You must be signed in to change notification settings - Fork 239
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
1 changed file
with
172 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,172 @@ | ||
import cirq | ||
import numpy as np | ||
import pickle | ||
import json | ||
import os | ||
import sys | ||
from collections import Counter | ||
from sklearn.metrics import mean_squared_error | ||
|
||
if len(sys.argv) > 1: | ||
data_path = sys.argv[1] | ||
else: | ||
data_path = '.' | ||
|
||
#define utility functions | ||
|
||
def simulate(circuit: cirq.Circuit) -> dict: | ||
"""This function simulates a Cirq circuit (without measurement) and outputs results in the format of histogram. | ||
""" | ||
simulator = cirq.Simulator() | ||
result = simulator.simulate(circuit) | ||
|
||
state_vector=result.final_state_vector | ||
|
||
histogram = dict() | ||
for i in range(len(state_vector)): | ||
population = abs(state_vector[i]) ** 2 | ||
if population > 1e-9: | ||
histogram[i] = population | ||
|
||
return histogram | ||
|
||
|
||
def histogram_to_category(histogram): | ||
"""This function takes a histogram representation of circuit execution results, and processes into labels as described in | ||
the problem description.""" | ||
assert abs(sum(histogram.values())-1)<1e-8 | ||
positive=0 | ||
for key in histogram.keys(): | ||
digits = bin(int(key))[2:].zfill(20) | ||
if digits[-1]=='0': | ||
positive+=histogram[key] | ||
|
||
return positive | ||
|
||
def count_gates(circuit: cirq.Circuit): | ||
"""Returns the number of 1-qubit gates, number of 2-qubit gates, number of 3-qubit gates....""" | ||
counter=Counter([len(op.qubits) for op in circuit.all_operations()]) | ||
|
||
#feel free to comment out the following two lines. But make sure you don't have k-qubit gates in your circuit | ||
#for k>2 | ||
for i in range(2,20): | ||
assert counter[i]==0 | ||
|
||
return counter | ||
|
||
def image_mse(image1,image2): | ||
# Using sklearns mean squared error: | ||
# https://scikit-learn.org/stable/modules/generated/sklearn.metrics.mean_squared_error.html | ||
return mean_squared_error(255*image1,255*image2) | ||
|
||
def test(): | ||
#load the actual hackthon data (fashion-mnist) | ||
images=np.load(data_path+'/images.npy') | ||
labels=np.load(data_path+'/labels.npy') | ||
|
||
#test part 1 | ||
|
||
n=len(images) | ||
mse=0 | ||
gatecount=0 | ||
|
||
for image in images: | ||
#encode image into circuit | ||
circuit,image_re=run_part1(image) | ||
|
||
#count the number of 2qubit gates used | ||
gatecount+=count_gates(circuit)[2] | ||
|
||
#calculate mse | ||
mse+=image_mse(image,image_re) | ||
|
||
#fidelity of reconstruction | ||
f=1-mse/n | ||
gatecount=gatecount/n | ||
|
||
#score for part1 | ||
score_part1=f*(0.999**gatecount) | ||
|
||
#test part 2 | ||
|
||
score=0 | ||
gatecount=0 | ||
n=len(images) | ||
|
||
for i in range(n): | ||
#run part 2 | ||
circuit,label=run_part2(images[i]) | ||
|
||
#count the gate used in the circuit for score calculation | ||
gatecount+=count_gates(circuit)[2] | ||
|
||
#check label | ||
if label==labels[i]: | ||
score+=1 | ||
#score | ||
score=score/n | ||
gatecount=gatecount/n | ||
|
||
score_part2=score*(0.999**gatecount) | ||
|
||
print(score_part1, ",", score_part2, ",", data_path, sep="") | ||
|
||
############################ | ||
# YOUR CODE HERE # | ||
############################ | ||
def encode(image): | ||
circuit=cirq.Circuit() | ||
if image[0][0]==0: | ||
circuit.append(cirq.rx(np.pi).on(cirq.LineQubit(0))) | ||
return circuit | ||
|
||
def decode(histogram): | ||
if 1 in histogram.keys(): | ||
image=np.array([[0,0],[0,0]]) | ||
else: | ||
image=np.array([[1,1],[1,1]]) | ||
return image | ||
|
||
def run_part1(image): | ||
#encode image into a circuit | ||
circuit=encode(image) | ||
|
||
#simulate circuit | ||
histogram=simulate(circuit) | ||
|
||
#reconstruct the image | ||
image_re=decode(histogram) | ||
|
||
return circuit,image_re | ||
|
||
def run_part2(image): | ||
# load the quantum classifier circuit | ||
with open('quantum_classifier.pickle', 'rb') as f: | ||
classifier=pickle.load(f) | ||
|
||
#encode image into circuit | ||
circuit=encode(image) | ||
|
||
#append with classifier circuit | ||
|
||
circuit.append(classifier) | ||
|
||
#simulate circuit | ||
histogram=simulate(circuit) | ||
|
||
#convert histogram to category | ||
label=histogram_to_category(histogram) | ||
|
||
#thresholding the label, any way you want | ||
if label>0.5: | ||
label=1 | ||
else: | ||
label=0 | ||
|
||
return circuit,label | ||
|
||
############################ | ||
# END YOUR CODE # | ||
############################ | ||
|
||
test() |