Skip to content

Commit

Permalink
added sample test file
Browse files Browse the repository at this point in the history
  • Loading branch information
mjk committed Jan 29, 2023
1 parent 6304f8a commit a72ee6c
Showing 1 changed file with 172 additions and 0 deletions.
172 changes: 172 additions & 0 deletions test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
import cirq
import numpy as np
import pickle
import json
import os
import sys
from collections import Counter
from sklearn.metrics import mean_squared_error

if len(sys.argv) > 1:
data_path = sys.argv[1]
else:
data_path = '.'

#define utility functions

def simulate(circuit: cirq.Circuit) -> dict:
"""This function simulates a Cirq circuit (without measurement) and outputs results in the format of histogram.
"""
simulator = cirq.Simulator()
result = simulator.simulate(circuit)

state_vector=result.final_state_vector

histogram = dict()
for i in range(len(state_vector)):
population = abs(state_vector[i]) ** 2
if population > 1e-9:
histogram[i] = population

return histogram


def histogram_to_category(histogram):
"""This function takes a histogram representation of circuit execution results, and processes into labels as described in
the problem description."""
assert abs(sum(histogram.values())-1)<1e-8
positive=0
for key in histogram.keys():
digits = bin(int(key))[2:].zfill(20)
if digits[-1]=='0':
positive+=histogram[key]

return positive

def count_gates(circuit: cirq.Circuit):
"""Returns the number of 1-qubit gates, number of 2-qubit gates, number of 3-qubit gates...."""
counter=Counter([len(op.qubits) for op in circuit.all_operations()])

#feel free to comment out the following two lines. But make sure you don't have k-qubit gates in your circuit
#for k>2
for i in range(2,20):
assert counter[i]==0

return counter

def image_mse(image1,image2):
# Using sklearns mean squared error:
# https://scikit-learn.org/stable/modules/generated/sklearn.metrics.mean_squared_error.html
return mean_squared_error(255*image1,255*image2)

def test():
#load the actual hackthon data (fashion-mnist)
images=np.load(data_path+'/images.npy')
labels=np.load(data_path+'/labels.npy')

#test part 1

n=len(images)
mse=0
gatecount=0

for image in images:
#encode image into circuit
circuit,image_re=run_part1(image)

#count the number of 2qubit gates used
gatecount+=count_gates(circuit)[2]

#calculate mse
mse+=image_mse(image,image_re)

#fidelity of reconstruction
f=1-mse/n
gatecount=gatecount/n

#score for part1
score_part1=f*(0.999**gatecount)

#test part 2

score=0
gatecount=0
n=len(images)

for i in range(n):
#run part 2
circuit,label=run_part2(images[i])

#count the gate used in the circuit for score calculation
gatecount+=count_gates(circuit)[2]

#check label
if label==labels[i]:
score+=1
#score
score=score/n
gatecount=gatecount/n

score_part2=score*(0.999**gatecount)

print(score_part1, ",", score_part2, ",", data_path, sep="")

############################
# YOUR CODE HERE #
############################
def encode(image):
circuit=cirq.Circuit()
if image[0][0]==0:
circuit.append(cirq.rx(np.pi).on(cirq.LineQubit(0)))
return circuit

def decode(histogram):
if 1 in histogram.keys():
image=np.array([[0,0],[0,0]])
else:
image=np.array([[1,1],[1,1]])
return image

def run_part1(image):
#encode image into a circuit
circuit=encode(image)

#simulate circuit
histogram=simulate(circuit)

#reconstruct the image
image_re=decode(histogram)

return circuit,image_re

def run_part2(image):
# load the quantum classifier circuit
with open('quantum_classifier.pickle', 'rb') as f:
classifier=pickle.load(f)

#encode image into circuit
circuit=encode(image)

#append with classifier circuit

circuit.append(classifier)

#simulate circuit
histogram=simulate(circuit)

#convert histogram to category
label=histogram_to_category(histogram)

#thresholding the label, any way you want
if label>0.5:
label=1
else:
label=0

return circuit,label

############################
# END YOUR CODE #
############################

test()

0 comments on commit a72ee6c

Please sign in to comment.