# Source code for mrmustard.lab.transformations.base
# Copyright 2023 Xanadu Quantum Technologies Inc.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module contains the base classes for the available unitaries and channels on quantum states.
In the docstrings defining the available unitaries we provide a definition in terms of
the symplectic matrix :math:`S` and the real vector :math:`d`. For deterministic Gaussian channels,
we use the two matrices :math:`X` and :math:`Y` and the vector :math:`d`. Additionally, we
provide the ``(A, b, c)`` triples that define the transformation in the Fock Bargmann
representation.
"""
from __future__ import annotations
from abc import abstractmethod
from collections.abc import Sequence
from typing import Self
from mrmustard import math, settings
from mrmustard.physics.ansatz import ArrayAnsatz, PolyExpAnsatz
from mrmustard.physics.ansatz_factory import AnsatzFactory
from mrmustard.physics.bargmann_utils import XY_of_channel, au2Symplectic, symplectic2Au
from mrmustard.physics.triples import XY_to_channel_Abc
from mrmustard.physics.wires import ReprEnum, Wires
from mrmustard.utils.typing import ComplexMatrix, ComplexTensor, RealMatrix, Scalar, Vector
from ..circuit_components import CircuitComponent
__all__ = ["Channel", "Map", "Operation", "Transformation", "Unitary"]
[docs]
class Transformation(CircuitComponent):
    r"""Base class for all transformations.

    The alternate constructors ``from_bargmann``, ``from_fock`` and ``from_quadrature``
    all delegate to the abstract ``from_ansatz``, which subclasses must implement.
    """

    @classmethod
    @abstractmethod
    def from_ansatz(
        cls,
        modes_out: Sequence[int],
        modes_in: Sequence[int],
        ansatz: PolyExpAnsatz | ArrayAnsatz | None = None,
        name: str | None = None,
    ) -> Self:
        r"""Initializes a transformation of type ``cls`` given modes and an ansatz.

        Args:
            modes_out: The output modes of this transformation.
            modes_in: The input modes of this transformation.
            ansatz: The ansatz of this transformation.
            name: The name of this transformation.

        Returns:
            A transformation.
        """

    @classmethod
    def from_bargmann(
        cls,
        modes_out: Sequence[int],
        modes_in: Sequence[int],
        triple: tuple,
        name: str | None = None,
    ) -> Self:
        r"""Initialize a Transformation from the given Bargmann triple (A,b,c)
        which parametrizes the Bargmann function of the transformation as
        :math:`c * exp(0.5*z^T A z + b^T z)`.

        Args:
            modes_out: The output modes of this transformation.
            modes_in: The input modes of this transformation.
            triple: The ``(A, b, c)`` triple, unpacked into a ``PolyExpAnsatz``.
            name: The name of this transformation.

        Returns:
            A transformation built through ``cls.from_ansatz``.
        """
        return cls.from_ansatz(modes_out, modes_in, PolyExpAnsatz(*triple), name)

    @classmethod
    def from_fock(
        cls,
        modes_out: Sequence[int],
        modes_in: Sequence[int],
        array: ComplexTensor,
        batch_dims: int = 0,
        name: str | None = None,
    ) -> Self:
        r"""Initializes a transformation of type ``cls`` given modes and a fock array.

        Args:
            modes_out: The output modes of this transformation.
            modes_in: The input modes of this transformation.
            array: The fock array of this transformation.
            batch_dims: The number of batch dimensions in the given array.
            name: The name of this transformation.

        Returns:
            A transformation in the Fock representation.
        """
        # ``from_ansatz`` expects the output modes first and the input modes second,
        # exactly as ``from_bargmann`` and ``from_quadrature`` pass them.
        return cls.from_ansatz(modes_out, modes_in, ArrayAnsatz(array, batch_dims), name)

    @classmethod
    def from_quadrature(
        cls,
        modes_out: Sequence[int],
        modes_in: Sequence[int],
        triple: tuple,
        phi: float = 0,
        name: str | None = None,
    ) -> Self:
        r"""Initialize a Transformation from the given quadrature triple (A, b, c).
        The triple parametrizes the quadrature representation of the transformation as
        :math:`c * exp(0.5*x^T A x + b^T x)`.

        Args:
            modes_out: The output modes of this transformation.
            modes_in: The input modes of this transformation.
            triple: The quadrature ``(A, b, c)`` triple.
            phi: The angle forwarded to ``BtoQ`` for both the input and output maps.
            name: The name of this transformation.

        Returns:
            A transformation whose ansatz is the result of sandwiching the quadrature
            object between the inverted ``BtoQ`` maps.
        """
        # Imported lazily inside the method (see the lint suppression below).
        from ..circuit_components_utils.b_to_q import BtoQ  # noqa: PLC0415

        QtoB_out = BtoQ(modes_out, phi).inverse()
        QtoB_in = BtoQ(modes_in, phi).inverse().dual
        # Wrap the quadrature triple as-is, then convert both sides.
        QQ = cls.from_ansatz(modes_out, modes_in, PolyExpAnsatz(*triple))
        BB = QtoB_in >> QQ >> QtoB_out
        return cls.from_ansatz(modes_out, modes_in, BB.ansatz, name)

    def inverse(self) -> Transformation:
        r"""Returns the mathematical inverse of the transformation, if it exists.
        Note that it can be unphysical, for example when the original is not unitary.

        >>> from mrmustard.lab import GaussianDM, Identity, Operation
        >>> rho = GaussianDM.random(modes=0, seed=1)
        >>> rho_as_operator = Operation.from_bargmann([0], [0], rho.ansatz.triple)
        >>> assert rho_as_operator >> rho_as_operator.inverse() == Identity([0])

        Returns:
            The inverse of the transformation.

        Raises:
            NotImplementedError: If the input and output wires have different lengths.
            NotImplementedError: If the transformation is not in the Bargmann representation.
        """
        if len(self.wires.input) != len(self.wires.output):
            raise NotImplementedError(
                "Only Transformations with the same number of input and output wires are supported.",
            )
        if not isinstance(self.ansatz, PolyExpAnsatz):  # pragma: no cover
            raise NotImplementedError("Only Bargmann representation is supported.")

        A_orig, b_orig, c_orig = self.ansatz.triple
        # ``A`` and ``b`` of the inverse come from the conjugated dual in standard wire order.
        A, b, _ = self.dual.to_standard_order().ansatz.conj.triple
        A_inv = math.inv(A)
        b_of_inverse = math.einsum("...ij,...j->...i", -A_inv, b)

        # Select the output part of the original and the input part of the inverse.
        in_idx = self.wires.input.indices
        out_idx = self.wires.output.indices
        A_orig_out = A_orig[..., out_idx, :][..., :, out_idx]
        A_inv_in = A_inv[..., in_idx, :][..., :, in_idx]
        b_orig_out = b_orig[..., out_idx]
        b_of_inverse_in = b_of_inverse[..., in_idx]

        m = A.shape[-1] // 2
        # Identity broadcast over the batch dimensions of ``A``.
        Im = math.broadcast_to(math.eye(m, dtype=math.complex128), (*A.shape[:-2], m, m))
        M = math.block([[A_orig_out, -Im], [-Im, A_inv_in]])
        combined_b = math.concat([b_orig_out, b_of_inverse_in], axis=-1)
        # NOTE(review): this looks like the Gaussian-integral normalization obtained when
        # contracting the transformation with its inverse; confirm against the derivation.
        c_of_inverse = (
            1
            / c_orig
            * math.sqrt(math.cast(math.det(1j * M), dtype=math.complex128))
            * math.exp(
                0.5 * math.einsum("...i,...ij,...j->...", combined_b, math.inv(M), combined_b),
            )
        )

        # Inverting twice restores the original name instead of stacking suffixes.
        name = self.name.removesuffix("_inv") if self.name.endswith("_inv") else self.name + "_inv"
        return self._from_attributes(
            PolyExpAnsatz(
                A_inv,
                b_of_inverse,
                c_of_inverse,
            ),
            self.wires.copy(new_ids=True),
            name,
        )
[docs]
class Operation(Transformation):
    r"""A CircuitComponent with input and output wires on the ket side. Operations are
    allowed to have a different number of input and output wires.
    """

    short_name = "Op"

    @classmethod
    def from_ansatz(
        cls,
        modes_out: Sequence[int],
        modes_in: Sequence[int],
        ansatz: PolyExpAnsatz | ArrayAnsatz | None = None,
        name: str | None = None,
    ) -> Operation:
        r"""Initializes an ``Operation`` given modes and an ansatz.

        Args:
            modes_out: The output modes. Unless given as a ``set``, they must be sorted.
            modes_in: The input modes. Unless given as a ``set``, they must be sorted.
            ansatz: The ansatz of the operation, or ``None`` for no ansatz.
            name: The name of the operation.

        Returns:
            An ``Operation``.

        Raises:
            ValueError: If the output or input modes are not sorted.
        """
        # Output modes are validated before input modes.
        for label, modes in (("Output", modes_out), ("Input", modes_in)):
            if isinstance(modes, set):
                continue
            if sorted(modes) != list(modes):
                raise ValueError(f"{label} modes must be sorted. got {modes}")

        wires = Wires(set(), set(), set(modes_out), set(modes_in))
        if ansatz is None:
            return Operation(ansatz_factory=None, wires=wires, name=name)

        ansatz_factory, representation = AnsatzFactory.from_ansatz(ansatz)
        if representation == ReprEnum.FOCK:
            # Tag every quantum wire as Fock and record its shape from the ansatz.
            for wire in wires.quantum:
                wire.repr = ReprEnum.FOCK
                wire.fock_shape = ansatz.core_shape[wire.index]
        return Operation(ansatz_factory=ansatz_factory, wires=wires, name=name)
[docs]
class Unitary(Operation):
    r"""Base class for all unitary transformations."""

    short_name = "U"

    @property
    def symplectic(self):
        r"""Returns the symplectic matrix that corresponds to this unitary."""
        A = self.ansatz.A
        return au2Symplectic(A)

    @classmethod
    def from_ansatz(
        cls,
        modes_out: Sequence[int],
        modes_in: Sequence[int],
        ansatz: PolyExpAnsatz | ArrayAnsatz | None = None,
        name: str | None = None,
    ) -> Unitary:
        r"""Initializes a ``Unitary`` given modes and an ansatz.

        Args:
            modes_out: The output modes. Unless given as a ``set``, they must be sorted.
            modes_in: The input modes. Unless given as a ``set``, they must be sorted.
            ansatz: The ansatz of the unitary, or ``None`` for no ansatz.
            name: The name of the unitary.

        Returns:
            A ``Unitary``.

        Raises:
            ValueError: If the output or input modes are not sorted.
        """
        # Output modes are validated before input modes.
        for label, modes in (("Output", modes_out), ("Input", modes_in)):
            if isinstance(modes, set):
                continue
            if sorted(modes) != list(modes):
                raise ValueError(f"{label} modes must be sorted. got {modes}")

        wires = Wires(set(), set(), set(modes_out), set(modes_in))
        if ansatz is None:
            return Unitary(ansatz_factory=None, wires=wires, name=name)

        ansatz_factory, representation = AnsatzFactory.from_ansatz(ansatz)
        if representation == ReprEnum.FOCK:
            # Tag every quantum wire as Fock and record its shape from the ansatz.
            for wire in wires.quantum:
                wire.repr = ReprEnum.FOCK
                wire.fock_shape = ansatz.core_shape[wire.index]
        return Unitary(ansatz_factory=ansatz_factory, wires=wires, name=name)

    @classmethod
    def from_symplectic(
        cls, modes: Sequence[int], S: RealMatrix, name: str | None = None
    ) -> Unitary:
        r"""A method for constructing a ``Unitary`` from its symplectic representation.

        >>> from mrmustard import math
        >>> from mrmustard.lab import Unitary, Identity
        >>> S = math.eye(2)
        >>> U = Unitary.from_symplectic([0], S)
        >>> assert U == Identity([0])

        Args:
            modes: the modes that we want the unitary to act on (should be a list of int)
            S: the symplectic representation (in XXPP order)
            name: A name for the unitary. If not provided, the class name will be used.

        Returns:
            The ``Unitary`` built from the Bargmann triple derived from ``S``.
        """
        num_modes = len(modes)
        leading_shape = S.shape[:-2]
        A = symplectic2Au(S)
        # The linear term is zero, with the same batch shape as ``S``.
        b = math.zeros((*leading_shape, 2 * num_modes), dtype="complex128")

        # ``c`` is computed from the lower-right ``m x m`` block of ``A``.
        A_inin = A[..., num_modes:, num_modes:]
        shifted = A_inin @ math.conj(A_inin) - math.eye_like(A_inin)
        c = ((-1) ** num_modes * math.det(shifted)) ** 0.25
        return Unitary.from_bargmann(modes, modes, (A, b, c), name=name)

    @classmethod
    def random(
        cls,
        modes: int | tuple[int, ...],
        max_r: float = 1.0,
        seed: int | None = None,
        name: str | None = None,
    ) -> Unitary:
        r"""Returns a random unitary.

        >>> from mrmustard.lab import Unitary
        >>> U = Unitary.random((0, 1, 2), max_r=1.2)
        >>> assert U.modes == (0,1,2)

        Args:
            modes: The modes of the unitary.
            max_r: The maximum squeezing parameter.
            seed: The random seed. If ``None``, the global seed is used.
            name: A name for the unitary. If not provided, the class name will be used.

        Returns:
            The random Unitary.

        Raises:
            ValueError: if ``modes`` is an empty tuple.
        """
        # A bare integer is treated as a single mode.
        if isinstance(modes, int):
            modes = (modes,)
        if len(modes) == 0:
            raise ValueError("Cannot create a random unitary with no modes.")
        random_S = math.random_symplectic(len(modes), max_r, seed=seed)
        return Unitary.from_symplectic(modes, random_S, name=name)

    def inverse(self) -> Unitary:
        r"""Returns the inverse of the unitary.

        The inverse reuses the ansatz factory and wires of ``self.dual`` and carries
        over the parameters of this unitary.

        >>> from mrmustard.lab import Unitary, Identity
        >>> u = Unitary.random((0, 1, 2))
        >>> assert u >> u.inverse() == Identity(u.modes)
        """
        dual = self.dual
        # Inverting twice restores the original name instead of stacking suffixes.
        if self.name.endswith("_inv"):
            inverse_name = self.name.removesuffix("_inv")
        else:
            inverse_name = self.name + "_inv"
        inverted = Unitary(
            ansatz_factory=dual.ansatz_factory,
            wires=dual.wires,
            name=inverse_name,
        )
        for key in self.parameters:
            inverted.parameters[key] = self.parameters[key]
        return inverted

    def __rshift__(self, other: Scalar | CircuitComponent) -> Scalar | CircuitComponent:
        r"""Contracts with ``other``, re-wrapping the result as a ``Unitary`` or a
        ``Channel`` when ``other`` is an instance of that class.
        """
        result = super().__rshift__(other)
        # ``Unitary`` is checked first, then ``Channel``; anything else is returned untouched.
        for promoted in (Unitary, Channel):
            if isinstance(other, promoted):
                return promoted(ansatz_factory=result.ansatz_factory, wires=result.wires)
        return result
[docs]
class Map(Transformation):
    r"""A ``CircuitComponent`` more general than a ``Channel`` (channels are CPTP maps)."""

    short_name = "Map"

    @classmethod
    def from_ansatz(
        cls,
        modes_out: Sequence[int],
        modes_in: Sequence[int],
        ansatz: PolyExpAnsatz | ArrayAnsatz | None = None,
        name: str | None = None,
    ) -> Map:
        r"""Initializes a ``Map`` given modes and an ansatz.

        Args:
            modes_out: The output modes. Unless given as a ``set``, they must be sorted.
            modes_in: The input modes. Unless given as a ``set``, they must be sorted.
            ansatz: The ansatz of the map, or ``None`` for no ansatz.
            name: The name of the map.

        Returns:
            A ``Map``.

        Raises:
            ValueError: If the output or input modes are not sorted.
        """
        # Output modes are validated before input modes.
        for label, modes in (("Output", modes_out), ("Input", modes_in)):
            if isinstance(modes, set):
                continue
            if sorted(modes) != list(modes):
                raise ValueError(f"{label} modes must be sorted. got {modes}")

        # The same output/input modes are supplied twice to ``Wires``
        # (presumably once for the bra side and once for the ket side — confirm).
        wires = Wires(set(modes_out), set(modes_in), set(modes_out), set(modes_in))
        if ansatz is None:
            return Map(ansatz_factory=None, wires=wires, name=name)

        ansatz_factory, representation = AnsatzFactory.from_ansatz(ansatz)
        if representation == ReprEnum.FOCK:
            # Tag every quantum wire as Fock and record its shape from the ansatz.
            for wire in wires.quantum:
                wire.repr = ReprEnum.FOCK
                wire.fock_shape = ansatz.core_shape[wire.index]
        return Map(ansatz_factory=ansatz_factory, wires=wires, name=name)
[docs]
class Channel(Map):
    r"""Base class for all CPTP channels."""

    short_name = "Ch"

    def _ensure_physicality_checks_supported(self) -> None:
        r"""Raises ``NotImplementedError`` for ansatz kinds the physicality checks
        do not handle (linear superpositions, derived variables, ``ArrayAnsatz``).
        """
        if self.ansatz._lin_sup:
            raise NotImplementedError(
                "Physicality conditions are not implemented for a mixture of states.",
            )
        if self.ansatz.num_derived_vars > 0:
            raise NotImplementedError(
                "Physicality conditions are not implemented for derived variables.",
            )
        if isinstance(self.ansatz, ArrayAnsatz):
            raise NotImplementedError(
                "Physicality conditions are not implemented for states with ArrayAnsatz.",
            )

    @property
    def is_CP(self) -> bool:
        r"""Whether this channel is completely positive (CP).

        The check requires the upper-right block of ``A`` to be Hermitian with
        eigenvalues above ``-settings.ATOL``.

        >>> from mrmustard.lab import Channel
        >>> channel = Channel.random((0, 1, 2))
        >>> assert channel.is_CP

        Raises:
            NotImplementedError: If the ansatz is a linear superposition, has derived
                variables, or is an ``ArrayAnsatz``.
        """
        self._ensure_physicality_checks_supported()
        A = self.ansatz.A
        half = A.shape[-1] // 2
        gamma_A = A[..., :half, half:]

        # Hermiticity: gamma_A must equal its conjugate transpose within tolerance.
        gamma_dagger = math.conj(math.einsum("...ij->...ji", gamma_A))
        hermiticity_gap = math.real(math.norm(gamma_A - gamma_dagger))
        if hermiticity_gap > settings.ATOL:
            return False

        eigenvalues = math.real(math.eigvals(gamma_A))
        return math.all(eigenvalues > -settings.ATOL)

    @property
    def is_TP(self) -> bool:
        r"""Whether this channel is trace preserving (TP).

        >>> from mrmustard.lab import Channel
        >>> channel = Channel.random((0, 1, 2))
        >>> assert channel.is_TP

        Raises:
            NotImplementedError: If the ansatz is a linear superposition, has derived
                variables, or is an ``ArrayAnsatz``.
        """
        self._ensure_physicality_checks_supported()
        A = self.ansatz.A
        half = A.shape[-1] // 2
        gamma_A = A[..., :half, half:]
        lambda_A = A[..., half:, half:]

        lambda_dagger = math.conj(math.einsum("...ij->...ji", lambda_A))
        middle = math.inv(math.eye(half) - math.einsum("...ij->...ji", gamma_A))
        # The channel is TP when this combination equals the identity within tolerance.
        candidate_identity = gamma_A + lambda_dagger @ middle @ lambda_A
        deviation = math.real(math.norm(candidate_identity - math.eye(half)))
        return math.all(deviation < settings.ATOL)

    @property
    def is_physical(self) -> bool:
        r"""Whether this channel is physical (i.e. CPTP).

        >>> from mrmustard.lab import Channel
        >>> channel = Channel.random((0, 1, 2))
        >>> assert channel.is_physical
        """
        # ``is_TP`` is evaluated only when ``is_CP`` holds.
        completely_positive = self.is_CP
        return completely_positive and self.is_TP

    @property
    def XY(self) -> tuple[ComplexMatrix, ComplexMatrix]:
        r"""Returns the X and Y matrix corresponding to the channel.

        >>> from mrmustard.lab import Channel
        >>> channel = Channel.random((0, 1))
        >>> X, Y = channel.XY
        >>> assert X.shape == (4, 4)
        >>> assert Y.shape == (4, 4)
        """
        A = self.ansatz.A
        return XY_of_channel(A)

    @classmethod
    def from_ansatz(
        cls,
        modes_out: Sequence[int],
        modes_in: Sequence[int],
        ansatz: PolyExpAnsatz | ArrayAnsatz | None = None,
        name: str | None = None,
    ) -> Channel:
        r"""Initializes a ``Channel`` given modes and an ansatz.

        Args:
            modes_out: The output modes. Unless given as a ``set``, they must be sorted.
            modes_in: The input modes. Unless given as a ``set``, they must be sorted.
            ansatz: The ansatz of the channel, or ``None`` for no ansatz.
            name: The name of the channel.

        Returns:
            A ``Channel``.

        Raises:
            ValueError: If the output or input modes are not sorted.
        """
        # Output modes are validated before input modes.
        for label, modes in (("Output", modes_out), ("Input", modes_in)):
            if isinstance(modes, set):
                continue
            if sorted(modes) != list(modes):
                raise ValueError(f"{label} modes must be sorted. got {modes}")

        # The same output/input modes are supplied twice to ``Wires``
        # (presumably once for the bra side and once for the ket side — confirm).
        wires = Wires(set(modes_out), set(modes_in), set(modes_out), set(modes_in))
        if ansatz is None:
            return Channel(ansatz_factory=None, wires=wires, name=name)

        ansatz_factory, representation = AnsatzFactory.from_ansatz(ansatz)
        if representation == ReprEnum.FOCK:
            # Tag every quantum wire as Fock and record its shape from the ansatz.
            for wire in wires.quantum:
                wire.repr = ReprEnum.FOCK
                wire.fock_shape = ansatz.core_shape[wire.index]
        return Channel(ansatz_factory=ansatz_factory, wires=wires, name=name)

    @classmethod
    def from_XY(
        cls,
        modes_out: Sequence[int],
        modes_in: Sequence[int],
        X: RealMatrix,
        Y: RealMatrix,
        d: Vector | None = None,
        name: str | None = None,
    ) -> Channel:
        r"""Initialize a Channel from its XY representation.

        >>> from mrmustard import math
        >>> from mrmustard.lab import Attenuator, Channel
        >>> X = math.eye(2)
        >>> Y = math.zeros((2,2))
        >>> channel = Channel.from_XY([0], [0], X,Y)
        >>> assert channel == Attenuator(0, transmissivity=1)

        Args:
            modes_out: The output modes of the channel.
            modes_in: The input modes of the channel.
            X: The X matrix of the channel.
            Y: The Y matrix of the channel.
            d: The displacement vector of the channel.
            name: A name for the channel. If not provided, the class name will be used.

        Raises:
            ValueError: If the dimensions of the X,Y matrices and the number of modes don't match.

        .. details::

            Each Gaussian channel transforms a state with covariance matrix :math:`\Sigma` and mean :math:`\mu`
            into a state with covariance matrix :math:`X \Sigma X^T + Y` and vector of means :math:`X\mu + d`.
            This channel has a Bargmann triple that is computed in https://arxiv.org/pdf/2209.06069. We borrow
            the formulas from the paper to implement the corresponding channel.
        """
        # X must be (2*out, 2*in) and Y must be (2*out, 2*out) in their last two axes.
        dim_out = 2 * len(modes_out)
        dim_in = 2 * len(modes_in)
        if X.shape[-2:] != (dim_out, dim_in) or Y.shape[-2:] != (dim_out, dim_out):
            raise ValueError(
                f"The dimension of XY matrices ({X.shape}, {Y.shape}) and number of modes ({len(modes_in), len(modes_out)}) don't match.",
            )
        bargmann_triple = XY_to_channel_Abc(X, Y, d)
        return Channel.from_bargmann(modes_out, modes_in, bargmann_triple, name=name)

    @classmethod
    def random(
        cls,
        modes: int | tuple[int, ...],
        max_r: float = 1.0,
        seed: int | None = None,
        name: str | None = None,
    ) -> Channel:
        r"""A random channel without displacement.

        >>> from mrmustard.lab import Channel
        >>> channel = Channel.random((0, 1, 2), max_r=1.2)
        >>> assert channel.modes == (0, 1, 2)

        Args:
            modes: The modes of the channel.
            max_r: The maximum squeezing parameter.
            seed: The random seed. If ``None``, the global seed is used.
            name: A name for the channel. If not provided, the class name will be used.

        Returns:
            The random channel.

        Raises:
            ValueError: if ``modes`` is an empty tuple.
        """
        # A bare integer is treated as a single mode.
        if isinstance(modes, int):
            modes = (modes,)
        if len(modes) == 0:
            raise ValueError("Cannot create a random channel with no modes.")

        # Imported lazily inside the method (see the lint suppression below).
        from mrmustard.lab.states import Vacuum  # noqa: PLC0415

        num_modes = len(modes)
        # Random unitary on 3m modes, applied to vacuum on the first 2m of them.
        big_unitary = Unitary.random(range(3 * num_modes), max_r, seed=seed)
        partial_state = Vacuum(range(2 * num_modes)) >> big_unitary
        ansatz = partial_state.ansatz
        # Contract the first 2m indices of the conjugate ansatz with those of the ansatz.
        first_idxs = list(range(2 * num_modes))
        second_idxs = list(range(2 * num_modes))
        kraus = ansatz.conj.contract(ansatz, idxs=(first_idxs, second_idxs))
        return Channel.from_bargmann(modes, modes, kraus.triple, name=name)

    def __rshift__(self, other: Scalar | CircuitComponent) -> Scalar | CircuitComponent:
        r"""Contracts with ``other``, re-wrapping the result as a ``Channel`` when
        ``other`` is a ``Channel`` or a ``Unitary``.
        """
        result = super().__rshift__(other)
        if not isinstance(other, (Channel, Unitary)):
            return result
        return Channel(ansatz_factory=result.ansatz_factory, wires=result.wires)
# _modules/mrmustard/lab/transformations/base
# Download Python script
# Download Notebook
# View on GitHub