Skip to content

Commit

Permalink
Add an example of running circuit on density device.
Browse files Browse the repository at this point in the history
  • Loading branch information
Zhuoyang Ye committed Jan 20, 2024
1 parent 045a289 commit 3ac389f
Show file tree
Hide file tree
Showing 4 changed files with 212 additions and 23 deletions.
Empty file added examples/density/dexample.py
Empty file.
5 changes: 2 additions & 3 deletions torchquantum/density/density_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import torch
import numpy as np
import torchquantum as tq

import functools
from typing import Callable, Union, Optional, List, Dict
from ..macro import C_DTYPE, ABC, ABC_ARRAY, INV_SQRT2
from ..util.utils import pauli_eigs, diag
Expand Down Expand Up @@ -227,10 +227,9 @@ def apply_unitary_density_bmm(density, mat, wires):
bsz = permuted_dag.shape[0]
expand_shape = [bsz] + list(matdag.shape)
new_density = permuted_dag.bmm(matdag.expand(expand_shape))
_matrix = torch.reshape(new_density[0], [2 ** n_qubit] * 2)
new_density = new_density.view(original_shape).permute(permute_back_dag)
return new_density


def gate_wrapper(
name,
mat,
Expand Down
5 changes: 2 additions & 3 deletions torchquantum/device/noisedevices.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@

from torchquantum.macro import C_DTYPE
from torchquantum.functional import func_name_dict, func_name_dict_collect
from torchquantum.density import density_mat, density_func
from typing import Union

__all__ = ["NoiseDevice"]
Expand All @@ -38,7 +37,7 @@ class NoiseDevice(nn.Module):
def __init__(
self,
n_wires: int,
device_name: str = "default",
device_name: str = "noisedevice",
bsz: int = 1,
device: Union[torch.device, str] = "cpu",
record_op: bool = False,
Expand Down Expand Up @@ -80,7 +79,7 @@ def name(self):
return self.__class__.__name__

def __repr__(self):
return f" class: {self.name} \n device name: {self.device_name} \n number of qubits: {self.n_wires} \n batch size: {self.bsz} \n current computing device: {self.state.device} \n recording op history: {self.record_op} \n current states: {repr(self.get_states_1d().cpu().detach().numpy())}"
return f" class: {self.name} \n device name: {self.device_name} \n number of qubits: {self.n_wires} \n batch size: {self.bsz} \n current computing device: {self.density.device} \n recording op history: {self.record_op} \n current states: {repr(self.get_probs_1d().cpu().detach().numpy())}"


'''
Expand Down
225 changes: 208 additions & 17 deletions torchquantum/functional/gate_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
from torchpack.utils.logging import logger
from torchquantum.util import normalize_statevector


if TYPE_CHECKING:
from torchquantum.device import QuantumDevice
from torchquantum.device import QuantumDevice, NoiseDevice
else:
QuantumDevice = None

Expand Down Expand Up @@ -58,7 +59,7 @@ def apply_unitary_einsum(state, mat, wires):

# All affected indices will be summed over, so we need the same number
# of new indices
new_indices = ABC[total_wires : total_wires + len(device_wires)]
new_indices = ABC[total_wires: total_wires + len(device_wires)]

# The new indices of the state are given by the old ones with the
# affected indices replaced by the new_indices
Expand Down Expand Up @@ -139,17 +140,198 @@ def apply_unitary_bmm(state, mat, wires):
return new_state


def apply_unitary_density_einsum(density, mat, wires):
"""Apply the unitary to the densitymatrix using torch.einsum method.
Args:
density (torch.Tensor): The densitymatrix.
mat (torch.Tensor): The unitary matrix of the operation.
wires (int or List[int]): Which qubit the operation is applied to.
Returns:
torch.Tensor: The new statevector.
"""

device_wires = wires
n_qubit = int((density.dim() - 1) / 2)

# minus one because of batch
total_wires = len(density.shape) - 1

if len(mat.shape) > 2:
is_batch_unitary = True
bsz = mat.shape[0]
shape_extension = [bsz]
else:
is_batch_unitary = False
shape_extension = []

"""
Compute U \rho
"""
mat = mat.view(shape_extension + [2] * len(device_wires) * 2)
mat = mat.type(C_DTYPE).to(density.device)
if len(mat.shape) > 2:
# both matrix and state are in batch mode
# matdag is the dagger of mat
matdag = torch.conj(mat.permute([0, 2, 1]))
else:
# matrix no batch, state in batch mode
matdag = torch.conj(mat.permute([1, 0]))

# Tensor indices of the quantum state
density_indices = ABC[:total_wires]
print("density_indices", density_indices)

# Indices of the quantum state affected by this operation
affected_indices = "".join(ABC_ARRAY[list(device_wires)].tolist())
print("affected_indices", affected_indices)

# All affected indices will be summed over, so we need the same number
# of new indices
new_indices = ABC[total_wires: total_wires + len(device_wires)]
print("new_indices", new_indices)

# The new indices of the state are given by the old ones with the
# affected indices replaced by the new_indices
new_density_indices = functools.reduce(
lambda old_string, idx_pair: old_string.replace(idx_pair[0], idx_pair[1]),
zip(affected_indices, new_indices),
density_indices,
)
print("new_density_indices", new_density_indices)

# Use the last literal as the indice of batch
density_indices = ABC[-1] + density_indices
new_density_indices = ABC[-1] + new_density_indices
if is_batch_unitary:
new_indices = ABC[-1] + new_indices

# We now put together the indices in the notation numpy einsum
# requires
einsum_indices = (
f"{new_indices}{affected_indices}," f"{density_indices}->{new_density_indices}"
)
print("einsum_indices", einsum_indices)

new_density = torch.einsum(einsum_indices, mat, density)

"""
Compute U \rho U^\dagger
"""
print("dagger")

# Tensor indices of the quantum state
density_indices = ABC[:total_wires]
print("density_indices", density_indices)

# Indices of the quantum state affected by this operation
affected_indices = "".join(
ABC_ARRAY[[x + n_qubit for x in list(device_wires)]].tolist()
)
print("affected_indices", affected_indices)

# All affected indices will be summed over, so we need the same number
# of new indices
new_indices = ABC[total_wires: total_wires + len(device_wires)]
print("new_indices", new_indices)

# The new indices of the state are given by the old ones with the
# affected indices replaced by the new_indices
new_density_indices = functools.reduce(
lambda old_string, idx_pair: old_string.replace(idx_pair[0], idx_pair[1]),
zip(affected_indices, new_indices),
density_indices,
)
print("new_density_indices", new_density_indices)

density_indices = ABC[-1] + density_indices
new_density_indices = ABC[-1] + new_density_indices
if is_batch_unitary:
new_indices = ABC[-1] + new_indices

# We now put together the indices in the notation numpy einsum
# requires
einsum_indices = (
f"{density_indices}," f"{affected_indices}{new_indices}->{new_density_indices}"
)
print("einsum_indices", einsum_indices)

new_density = torch.einsum(einsum_indices, density, matdag)

return new_density


def apply_unitary_density_bmm(density, mat, wires):
"""Apply the unitary to the DensityMatrix using torch.bmm method.
Args:
state (torch.Tensor): The statevector.
mat (torch.Tensor): The unitary matrix of the operation.
wires (int or List[int]): Which qubit the operation is applied to.
Returns:
torch.Tensor: The new statevector.
"""
device_wires = wires
n_qubit = density.dim() // 2
mat = mat.type(C_DTYPE).to(density.device)
"""
Compute U \rho
"""
devices_dims = [w + 1 for w in device_wires]
permute_to = list(range(density.dim()))
for d in sorted(devices_dims, reverse=True):
del permute_to[d]
permute_to = permute_to[:1] + devices_dims + permute_to[1:]
permute_back = list(np.argsort(permute_to))
original_shape = density.shape
permuted = density.permute(permute_to).reshape([original_shape[0], mat.shape[-1], -1])

if len(mat.shape) > 2:
# both matrix and state are in batch mode
new_density = mat.bmm(permuted)
else:
# matrix no batch, state in batch mode
bsz = permuted.shape[0]
expand_shape = [bsz] + list(mat.shape)
new_density = mat.expand(expand_shape).bmm(permuted)
new_density = new_density.view(original_shape).permute(permute_back)
"""
Compute \rho U^\dagger
"""
matdag = torch.conj(mat)
matdag = matdag.type(C_DTYPE).to(density.device)

devices_dims_dag = [n_qubit + w + 1 for w in device_wires]
permute_to_dag = list(range(density.dim()))
for d in sorted(devices_dims_dag, reverse=True):
del permute_to_dag[d]
permute_to_dag = permute_to_dag + devices_dims_dag
permute_back_dag = list(np.argsort(permute_to_dag))
permuted_dag = new_density.permute(permute_to_dag).reshape([original_shape[0], -1, matdag.shape[0]])

if len(matdag.shape) > 2:
# both matrix and state are in batch mode
new_density = permuted_dag.bmm(matdag)
else:
# matrix no batch, state in batch mode
bsz = permuted_dag.shape[0]
expand_shape = [bsz] + list(matdag.shape)
new_density = permuted_dag.bmm(matdag.expand(expand_shape))
new_density = new_density.view(original_shape).permute(permute_back_dag)
return new_density


def gate_wrapper(
name,
mat,
method,
q_device: QuantumDevice,
wires,
params=None,
n_wires=None,
static=False,
parent_graph=None,
inverse=False,
name,
mat,
method,
q_device: QuantumDevice,
wires,
params=None,
n_wires=None,
static=False,
parent_graph=None,
inverse=False,
):
"""Perform the phaseshift gate.
Expand Down Expand Up @@ -249,8 +431,17 @@ def gate_wrapper(
else:
matrix = matrix.permute(1, 0)
assert np.log2(matrix.shape[-1]) == len(wires)
state = q_device.states
if method == "einsum":
q_device.states = apply_unitary_einsum(state, matrix, wires)
elif method == "bmm":
q_device.states = apply_unitary_bmm(state, matrix, wires)
if q_device.device_name=="noisedevice":
density = q_device.densities
print(density.shape)
if method == "einsum":
return
elif method == "bmm":
q_device.densities = apply_unitary_density_bmm(density, matrix, wires)
else:
state = q_device.states
if method == "einsum":
q_device.states = apply_unitary_einsum(state, matrix, wires)
elif method == "bmm":
q_device.states = apply_unitary_bmm(state, matrix, wires)

0 comments on commit 3ac389f

Please sign in to comment.