# Source code for mrmustard.physics.representations

# Copyright 2023 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
This module contains the classes for the available representations.
"""

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Iterable, Union
import os
from matplotlib import colors
import matplotlib.pyplot as plt
import numpy as np

from IPython.display import display, HTML
from mako.template import Template

from mrmustard import math, settings
from mrmustard.physics.gaussian_integrals import (
    contract_two_Abc,
    reorder_abc,
    complex_gaussian_integral,
)
from mrmustard.physics.ansatze import Ansatz, PolyExpAnsatz, ArrayAnsatz
from mrmustard.utils.typing import (
    Batch,
    ComplexMatrix,
    ComplexTensor,
    ComplexVector,
    Scalar,
    Tensor,
)

# Public names exported by ``from mrmustard.physics.representations import *``.
__all__ = ["Representation", "Bargmann", "Fock"]


class Representation(ABC):
    r"""
    A base class for representations.

    Representations can be initialized using the ``from_ansatz`` method, which automatically equips
    them with all the functionality required to perform mathematical operations, such as equality,
    multiplication, subtraction, etc.
    """

    @property
    @abstractmethod
    def ansatz(self) -> Ansatz:
        r"""
        The ansatz of the representation.
        """

    @abstractmethod
    def reorder(self, order: tuple[int, ...] | list[int]) -> Representation:
        r"""
        Reorders the representation indices.
        """

    @classmethod
    @abstractmethod
    def from_ansatz(cls, ansatz: Ansatz) -> Representation:  # pragma: no cover
        r"""
        Returns a representation from an ansatz.

        Declared as a classmethod so that the abstract signature matches the
        ``@classmethod`` implementations provided by the subclasses.
        """

    def __eq__(self, other: object) -> bool:
        r"""
        Whether this representation is equal to another.

        Returns ``NotImplemented`` when ``other`` has no ``ansatz`` attribute, so that
        Python falls back to its default comparison instead of raising.
        """
        try:
            other_ansatz = other.ansatz
        except AttributeError:
            return NotImplemented
        return self.ansatz == other_ansatz

    def __add__(self, other: Representation) -> Representation:
        r"""
        Adds this representation to another.

        Raises:
            ValueError: If the two representations are of different classes.
        """
        if self.__class__.__name__ != other.__class__.__name__:
            msg = f"Cannot add ``{self.__class__.__name__}`` representation to "
            msg += f"``{other.__class__.__name__}`` representation."
            raise ValueError(msg)
        return self.from_ansatz(self.ansatz + other.ansatz)

    def __sub__(self, other: Representation) -> Representation:
        r"""
        Subtracts another representation from this one.

        Raises:
            ValueError: If the two representations are of different classes
                (same rule as ``__add__``).
        """
        if self.__class__.__name__ != other.__class__.__name__:
            msg = f"Cannot subtract ``{other.__class__.__name__}`` representation from "
            msg += f"``{self.__class__.__name__}`` representation."
            raise ValueError(msg)
        return self.from_ansatz(self.ansatz - other.ansatz)

    def __mul__(self, other: Representation | Scalar) -> Representation:
        r"""
        Multiplies this representation by another or by a scalar.

        ``other`` is treated as a representation if it exposes an ``ansatz`` attribute,
        and as a scalar otherwise.
        """
        # keep the ``try`` to the attribute lookup only: errors raised by the ansatz
        # product itself must propagate instead of being retried as a scalar product
        try:
            other_ansatz = other.ansatz
        except AttributeError:
            return self.from_ansatz(self.ansatz * other)
        return self.from_ansatz(self.ansatz * other_ansatz)

    def __rmul__(self, other: Representation | Scalar) -> Representation:
        r"""
        Multiplies this representation by another or by a scalar on the right.
        """
        return self.__mul__(other)

    def __truediv__(self, other: Representation | Scalar) -> Representation:
        r"""
        Divides this representation by another or by a scalar.

        ``other`` is treated as a representation if it exposes an ``ansatz`` attribute,
        and as a scalar otherwise.
        """
        # same narrow ``try`` as in ``__mul__``
        try:
            other_ansatz = other.ansatz
        except AttributeError:
            return self.from_ansatz(self.ansatz / other)
        return self.from_ansatz(self.ansatz / other_ansatz)

    def __rtruediv__(self, other: Representation | Scalar) -> Representation:
        r"""
        Divides this representation by another or by a scalar on the right.
        """
        return self.from_ansatz(other / self.ansatz)

    def __and__(self, other: Representation) -> Representation:
        r"""
        Takes the outer product of this representation with another.
        """
        return self.from_ansatz(self.ansatz & other.ansatz)

    def __getitem__(self, idx: int | tuple[int, ...]) -> Representation:
        r"""
        Stores the indices for contraction.

        Raises:
            NotImplementedError: Always; subclasses provide the implementation.
        """
        raise NotImplementedError
class Bargmann(Representation):
    r"""
    The Fock-Bargmann representation of a broad class of quantum states, transformations,
    measurements, channels, etc.

    The ansatz available in this representation is a linear combination of exponentials
    of bilinear forms with a polynomial part:

    .. math::
        F(z) = \sum_i \textrm{poly}_i(z) \textrm{exp}(z^T A_i z / 2 + z^T b_i)

    This function allows for vector space operations on Bargmann objects including
    linear combinations (``+``), outer product (``&``), and inner product (``@``).

    .. code-block ::

        >>> from mrmustard.physics.representations import Bargmann
        >>> from mrmustard.physics.triples import displacement_gate_Abc, vacuum_state_Abc

        >>> # bargmann representation of one-mode vacuum
        >>> rep_vac = Bargmann(*vacuum_state_Abc(1))

        >>> # bargmann representation of one-mode dgate with gamma=1+0j
        >>> rep_dgate = Bargmann(*displacement_gate_Abc(1))

    The inner product is defined as the contraction of two Bargmann objects across marked indices.
    Indices are marked using ``__getitem__``. Once the indices are marked for contraction, they are
    used the next time the inner product (``@``) is called. For example:

    .. code-block ::

        >>> import numpy as np

        >>> # mark indices for contraction
        >>> idx_vac = [0]
        >>> idx_rep = [1]

        >>> # bargmann representation of coh = vacuum >> dgate
        >>> rep_coh = rep_vac[idx_vac] @ rep_dgate[idx_rep]
        >>> assert np.allclose(rep_coh.A, [[0,],])
        >>> assert np.allclose(rep_coh.b, [1,])
        >>> assert np.allclose(rep_coh.c, 0.6065306597126334)

    This can also be used to contract existing indices in a single Bargmann object, e.g.
    to implement the partial trace.

    .. code-block ::

        >>> trace = (rep_coh @ rep_coh.conj()).trace([0], [1])
        >>> assert np.allclose(trace.A, 0)
        >>> assert np.allclose(trace.b, 0)
        >>> assert trace.c == 1

    The ``A``, ``b``, and ``c`` parameters can be batched to represent superpositions.

    .. code-block ::

        >>> # bargmann representation of one-mode coherent state with gamma=1+0j
        >>> A_plus = [[0,],]
        >>> b_plus = [1,]
        >>> c_plus = 0.6065306597126334

        >>> # bargmann representation of one-mode coherent state with gamma=-1+0j
        >>> A_minus = [[0,],]
        >>> b_minus = [-1,]
        >>> c_minus = 0.6065306597126334

        >>> # bargmann representation of a superposition of coherent states
        >>> A = [A_plus, A_minus]
        >>> b = [b_plus, b_minus]
        >>> c = [c_plus, c_minus]
        >>> rep_coh_sup = Bargmann(A, b, c)

    Note that the operations that change the shape of the ansatz (outer product and inner
    product) do not automatically modify the ordering of the combined or leftover indices.
    However, the ``reorder`` method allows reordering the representation after the products
    have been carried out.

    Args:
        A: A batch of quadratic coefficient :math:`A_i`.
        b: A batch of linear coefficients :math:`b_i`.
        c: A batch of arrays :math:`c_i`.

    Note:
        The args can be passed non-batched, as they will be automatically broadcasted to the
        correct batch shape.
    """

    def __init__(
        self,
        A: Batch[ComplexMatrix],
        b: Batch[ComplexVector],
        c: Batch[ComplexTensor] = 1.0,
    ):
        # indices marked by ``__getitem__`` and consumed by the next ``@``
        self._contract_idxs: tuple[int, ...] = ()
        self._ansatz = PolyExpAnsatz(A, b, c)

    @property
    def ansatz(self) -> PolyExpAnsatz:
        r"""
        The ansatz of the representation.
        """
        return self._ansatz

    @classmethod
    def from_ansatz(cls, ansatz: PolyExpAnsatz) -> Bargmann:  # pylint: disable=arguments-differ
        r"""
        Returns a Bargmann object from an ansatz object.
        """
        return cls(ansatz.A, ansatz.b, ansatz.c)

    @property
    def A(self) -> Batch[ComplexMatrix]:
        r"""
        The batch of quadratic coefficient :math:`A_i`.
        """
        return self.ansatz.A

    @property
    def b(self) -> Batch[ComplexVector]:
        r"""
        The batch of linear coefficients :math:`b_i`
        """
        return self.ansatz.b

    @property
    def c(self) -> Batch[ComplexTensor]:
        r"""
        The batch of arrays :math:`c_i`.
        """
        return self.ansatz.c

    @property
    def triple(
        self,
    ) -> tuple[Batch[ComplexMatrix], Batch[ComplexVector], Batch[ComplexTensor]]:
        r"""
        The batch of triples :math:`(A_i, b_i, c_i)`.
        """
        return self.A, self.b, self.c

    def conj(self):
        r"""
        The conjugate of this Bargmann object.

        The indices currently marked for contraction are carried over to the result.
        """
        new = self.__class__(math.conj(self.A), math.conj(self.b), math.conj(self.c))
        new._contract_idxs = self._contract_idxs  # pylint: disable=protected-access
        return new

    def trace(self, idx_z: tuple[int, ...], idx_zconj: tuple[int, ...]) -> Bargmann:
        r"""
        The partial trace over the given index pairs.

        Args:
            idx_z: The first part of the pairs of indices to trace over.
            idx_zconj: The second part.

        Returns:
            Bargmann: the ansatz with the given indices traced over

        Raises:
            NotImplementedError: If the ansatz has a polynomial part of degree greater than ``0``.
        """
        if self.ansatz.degree > 0:
            raise NotImplementedError(
                "Partial trace is only supported for ansatze with polynomial of degree ``0``."
            )
        A, b, c = [], [], []
        # integrate each batch element separately
        for Abci in zip(self.A, self.b, self.c):
            Aij, bij, cij = complex_gaussian_integral(Abci, idx_z, idx_zconj, measure=-1.0)
            A.append(Aij)
            b.append(bij)
            c.append(cij)
        return Bargmann(A, b, c)

    def reorder(self, order: tuple[int, ...] | list[int]) -> Bargmann:
        r"""
        Reorders the indices of the ``A`` matrix and ``b`` vector of the ``(A, b, c)`` triple in
        this Bargmann object.

        .. code-block::

            >>> from mrmustard.physics.representations import Bargmann
            >>> from mrmustard.physics.triples import displacement_gate_Abc

            >>> rep_dgate1 = Bargmann(*displacement_gate_Abc([0.1, 0.2, 0.3]))
            >>> rep_dgate2 = Bargmann(*displacement_gate_Abc([0.2, 0.3, 0.1]))

            >>> assert rep_dgate1.reorder([1, 2, 0, 4, 5, 3]) == rep_dgate2

        Args:
            order: The new order.

        Returns:
            The reordered Bargmann object.
        """
        A, b, c = reorder_abc((self.A, self.b, self.c), order)
        return self.__class__(A, b, c)

    def plot(
        self,
        just_phase: bool = False,
        with_measure: bool = False,
        log_scale: bool = False,
        xlim=(-2 * np.pi, 2 * np.pi),
        ylim=(-2 * np.pi, 2 * np.pi),
    ) -> tuple[plt.Figure, plt.Axes]:  # pragma: no cover
        r"""
        Plots the Bargmann function :math:`F(z)` on the complex plane. Phase is represented by
        color, magnitude by brightness. The function can be multiplied by :math:`exp(-|z|^2)`
        to represent the Bargmann function times the measure function (for integration).

        Args:
            just_phase: Whether to plot only the phase of the Bargmann function.
            with_measure: Whether to plot the bargmann function times the measure function
                :math:`exp(-|z|^2)`.
            log_scale: Whether to plot the log of the Bargmann function.
            xlim: The `x` limits of the plot.
            ylim: The `y` limits of the plot.

        Returns:
            The figure and axes of the plot
        """
        # eval F(z) on a 400x400 grid of complex numbers
        X, Y = np.mgrid[xlim[0] : xlim[1] : 400j, ylim[0] : ylim[1] : 400j]
        Z = (X + 1j * Y).T

        f_values = self(Z[..., None])

        if log_scale:
            f_values = np.log(np.abs(f_values)) * np.exp(1j * np.angle(f_values))
        if with_measure:
            f_values = f_values * np.exp(-(np.abs(Z) ** 2))

        # Get phase (mapped to [0, 1) for the hue channel) and magnitude of F(z)
        phases = np.angle(f_values) / (2 * np.pi) % 1
        magnitudes = np.abs(f_values)
        max_magnitude = np.max(magnitudes)
        # avoid 0/0 (NaN pixels) when F(z) vanishes on the whole grid
        magnitudes_scaled = magnitudes / max_magnitude if max_magnitude > 0 else magnitudes

        # Convert to RGB
        hsv_values = np.zeros(f_values.shape + (3,))
        hsv_values[..., 0] = phases
        hsv_values[..., 1] = 1
        hsv_values[..., 2] = 1 if just_phase else magnitudes_scaled
        rgb_values = colors.hsv_to_rgb(hsv_values)

        # Plot the image
        fig, ax = plt.subplots()
        ax.imshow(rgb_values, origin="lower", extent=[xlim[0], xlim[1], ylim[0], ylim[1]])
        ax.set_xlabel("$Re(z)$")
        ax.set_ylabel("$Im(z)$")

        name = "F_{" + self.ansatz.name + "}(z)"
        name = f"\\arg({name})\\log|{name}|" if log_scale else name
        title = name + "e^{-|z|^2}" if with_measure else name
        title = f"\\arg({name})" if just_phase else title
        ax.set_title(f"${title}$")

        plt.show(block=False)
        return fig, ax

    def __call__(self, z: ComplexTensor) -> ComplexTensor:
        r"""
        Evaluates the Bargmann function at the given array of points.

        Args:
            z: The array of points.

        Returns:
            The value of the Bargmann function at ``z``.
        """
        return self.ansatz(z)

    def __getitem__(self, idx: int | tuple[int, ...]) -> Bargmann:
        r"""
        A copy of self with the given indices marked for contraction.

        Raises:
            IndexError: If an index is not smaller than the dimension of the ansatz.
        """
        idx = (idx,) if isinstance(idx, int) else idx
        for i in idx:
            if i >= self.ansatz.dim:
                raise IndexError(
                    f"Index {i} out of bounds for ansatz {self.ansatz.__class__.__qualname__} of dimension {self.ansatz.dim}."
                )
        new = self.__class__(self.A, self.b, self.c)
        new._contract_idxs = idx
        return new

    def __matmul__(self, other: Union[Bargmann, Fock]) -> Union[Bargmann, Fock]:
        r"""
        The inner product of ansatze across the marked indices.

        If ``other`` is ``Fock``, then ``self`` is converted to ``Fock`` before the contraction.

        Args:
            other: Another representation.

        Returns:
            A ``Bargmann`` representation if ``other`` is ``Bargmann``, or a ``Fock`` representation
            if ``other`` is ``Fock``.

        Raises:
            NotImplementedError: If either ansatz has a polynomial part of degree greater than ``0``.
            ValueError: If ``settings.UNSAFE_ZIP_BATCH`` is set and the batch sizes differ.
        """
        idx_s = self._contract_idxs
        idx_o = other._contract_idxs

        # if ``other`` is ``Fock``, convert ``self`` to ``Fock``
        if isinstance(other, Fock):
            from .converters import to_fock  # pylint: disable=import-outside-toplevel

            # set same shape along the contracted axes, and default shape along the
            # axes that are not being contracted
            shape = [settings.AUTOCUTOFF_MAX_CUTOFF for _ in range(len(self.b[0]))]
            for i, j in zip(idx_s, idx_o):
                shape[i] = other.array.shape[1:][j]
            return to_fock(self, shape=shape)[idx_s] @ other[idx_o]

        if self.ansatz.degree > 0 or other.ansatz.degree > 0:
            raise NotImplementedError(
                "Inner product of ansatze is only supported for ansatze with polynomial of degree 0."
            )
        Abc = []
        if settings.UNSAFE_ZIP_BATCH:
            # pair batch elements one-to-one instead of taking all combinations
            if self.ansatz.batch_size != other.ansatz.batch_size:
                raise ValueError(
                    f"Batch size of the two ansatze must match since the settings.UNSAFE_ZIP_BATCH is {settings.UNSAFE_ZIP_BATCH}."
                )
            for (A1, b1, c1), (A2, b2, c2) in zip(
                zip(self.A, self.b, self.c), zip(other.A, other.b, other.c)
            ):
                Abc.append(contract_two_Abc((A1, b1, c1), (A2, b2, c2), idx_s, idx_o))
        else:
            # contract every batch element of ``self`` with every batch element of ``other``
            for A1, b1, c1 in zip(self.A, self.b, self.c):
                for A2, b2, c2 in zip(other.A, other.b, other.c):
                    Abc.append(contract_two_Abc((A1, b1, c1), (A2, b2, c2), idx_s, idx_o))

        A, b, c = zip(*Abc)
        return Bargmann(A, b, c)

    def _repr_html_(self):  # pragma: no cover
        template = Template(filename=os.path.dirname(__file__) + "/assets/bargmann.txt")
        display(HTML(template.render(rep=self)))
class Fock(Representation):
    r"""
    The Fock representation of a broad class of quantum states, transformations, measurements,
    channels, etc.

    The ansatz available in this representation is ``ArrayAnsatz``.

    This function allows for vector space operations on Fock objects including
    linear combinations, outer product (``&``), and inner product (``@``).

    .. code-block::

        >>> import numpy as np
        >>> from mrmustard.physics.representations import Fock

        >>> # initialize Fock objects
        >>> array1 = np.random.random((5,7,8))
        >>> array2 = np.random.random((5,7,8))
        >>> array3 = np.random.random((3,5,7,8)) # where 3 is the batch.
        >>> fock1 = Fock(array1)
        >>> fock2 = Fock(array2)
        >>> fock3 = Fock(array3, batched=True)

        >>> # linear combination can be done with the same batch dimension
        >>> fock4 = 1.3 * fock1 - fock2 * 2.1

        >>> # division by a scalar
        >>> fock5 = fock1 / 1.3

        >>> # inner product by contracting on marked indices
        >>> fock6 = fock1[2] @ fock3[2]

        >>> # outer product (tensor product)
        >>> fock7 = fock1 & fock3

        >>> # conjugation
        >>> fock8 = fock1.conj()

    Args:
        array: the (batched) array in Fock representation.
        batched: whether the array input has a batch dimension.

    Note:
        The args can be passed non-batched, as they will be automatically broadcasted to the
        correct batch shape.
    """

    def __init__(self, array: Batch[Tensor], batched=False):
        # indices marked by ``__getitem__`` and consumed by the next ``@``
        self._contract_idxs: tuple[int, ...] = ()
        array = math.astensor(array)
        if not batched:
            # prepend a batch axis of size 1
            array = array[None, ...]
        self._ansatz = ArrayAnsatz(array=array)

    @property
    def ansatz(self) -> ArrayAnsatz:
        r"""
        The ansatz of the representation.
        """
        return self._ansatz

    @classmethod
    def from_ansatz(cls, ansatz: ArrayAnsatz) -> Fock:  # pylint: disable=arguments-differ
        r"""
        Returns a Fock object from an ansatz object.
        """
        return cls(ansatz.array, batched=True)

    @property
    def array(self) -> Batch[Tensor]:
        r"""
        The array from the ansatz.
        """
        return self.ansatz.array

    def conj(self):
        r"""
        The conjugate of this Fock object.

        The indices currently marked for contraction are carried over to the result.
        """
        new = self.from_ansatz(self.ansatz.conj)
        new._contract_idxs = self._contract_idxs  # pylint: disable=protected-access
        return new

    def __getitem__(self, idx: int | tuple[int, ...]) -> Fock:
        r"""
        Returns a copy of self with the given indices marked for contraction.

        Indices refer to the non-batch axes of ``self.array``.

        Raises:
            IndexError: If an index does not refer to a non-batch axis.
        """
        idx = (idx,) if isinstance(idx, int) else idx
        # axis 0 of ``self.array`` is the batch axis and cannot be contracted
        # (``__matmul__`` indexes ``self.array.shape[1:]``), so only
        # ``len(self.array.shape) - 1`` axes are valid targets
        n_axes = len(self.array.shape) - 1
        for i in idx:
            if i >= n_axes:
                raise IndexError(
                    f"Index {i} out of bounds for ansatz {self.ansatz.__class__.__qualname__} of dimension {self.ansatz.dim}."
                )
        new = self.from_ansatz(self.ansatz)
        new._contract_idxs = idx
        return new

    def __matmul__(self, other: Union[Bargmann, Fock]) -> Fock:
        r"""
        Implements the inner product of ansatze across the marked indices.

        If ``other`` is ``Fock``, the two representations are automatically reduced before being
        contracted. The array of the returned representation has the largest possible dimension
        along every axis.

        If ``other`` is ``Bargmann``, it is converted to ``Fock`` before the contraction.

        Args:
            other: Another representation.

        Returns:
            A ``Fock`` representation.
        """
        idx_s = list(self._contract_idxs)
        idx_o = list(other._contract_idxs)

        # if ``other`` is ``Bargmann``, convert it to ``Fock``
        if isinstance(other, Bargmann):
            from .converters import to_fock  # pylint: disable=import-outside-toplevel

            # set same shape along the contracted axes, and default shape along the
            # axes that are not being contracted
            shape = [settings.AUTOCUTOFF_MAX_CUTOFF for _ in range(len(other.b[0]))]
            for i, j in zip(idx_s, idx_o):
                shape[j] = self.array.shape[1:][i]
            return self[idx_s] @ to_fock(other, shape=shape)[idx_o]

        # the number of batches in self and other
        n_batches_s = self.array.shape[0]
        n_batches_o = other.array.shape[0]

        # the shapes each batch in self and other
        shape_s = self.array.shape[1:]
        shape_o = other.array.shape[1:]

        # the shapes of the axes being contracted
        shape_s_contr = [shape_s[i] for i in idx_s]
        shape_o_contr = [shape_o[i] for i in idx_o]

        # compare the shapes along the axes being contracted
        if shape_o_contr != shape_s_contr:
            # calculate new shapes that maintain the largest possible dimension
            # along each of the contracted axes
            shape = [min(s, o) for s, o in zip(shape_s_contr, shape_o_contr)]

            new_shape_s = [n_batches_s]
            new_shape_s += [
                shape[idx_s.index(i)] if i in idx_s else dim for i, dim in enumerate(shape_s)
            ]

            new_shape_o = [n_batches_o]
            new_shape_o += [
                shape[idx_o.index(i)] if i in idx_o else dim for i, dim in enumerate(shape_o)
            ]

            return self.reduce(new_shape_s)[idx_s] @ other.reduce(new_shape_o)[idx_o]

        # contract every batch element of ``self`` with every batch element of ``other``
        axes = [list(idx_s), list(idx_o)]
        new_array = []
        for i in range(n_batches_s):
            for j in range(n_batches_o):
                new_array.append(math.tensordot(self.array[i], other.array[j], axes))
        return self.from_ansatz(ArrayAnsatz(new_array))

    def trace(self, idxs1: tuple[int, ...], idxs2: tuple[int, ...]) -> Fock:
        r"""Implements the partial trace over the given index pairs.

        Args:
            idxs1: The first part of the pairs of indices to trace over.
            idxs2: The second part.

        Returns:
            The traced-over Fock object. If no indices are given, an unchanged copy is returned.

        Raises:
            ValueError: If ``idxs1`` and ``idxs2`` differ in length or share an index.
        """
        if len(idxs1) != len(idxs2) or not set(idxs1).isdisjoint(idxs2):
            raise ValueError("idxs must be of equal length and disjoint")
        if len(idxs1) == 0:
            # nothing to trace over; also avoids the ``shape[-0:]`` slice below,
            # which would select the whole shape instead of none of it
            return self.from_ansatz(self.ansatz)

        # a set avoids ``list + tuple`` errors when the two index containers differ in type
        traced = set(idxs1) | set(idxs2)
        # move the batch axis first, then the untouched axes, then the idxs1 axes,
        # then the idxs2 axes (all shifted by one to skip the batch axis)
        order = (
            [0]
            + [i + 1 for i in range(len(self.array.shape) - 1) if i not in traced]
            + [i + 1 for i in idxs1]
            + [i + 1 for i in idxs2]
        )
        new_array = math.transpose(self.array, order)
        # flatten each group of traced axes into a single axis of size n, so the
        # partial trace becomes an ordinary trace over the last two axes
        n = np.prod(new_array.shape[-len(idxs2) :])
        new_array = math.reshape(new_array, new_array.shape[: -2 * len(idxs1)] + (n, n))
        trace = math.trace(new_array)
        return self.from_ansatz(ArrayAnsatz([trace] if trace.shape == () else trace))

    def reorder(self, order: tuple[int, ...] | list[int]) -> Fock:
        r"""
        Reorders the indices of the array with the given order.

        Args:
            order: The order. Does not need to refer to the batch dimension.

        Returns:
            The reordered Fock.
        """
        # the batch axis always stays first
        return self.from_ansatz(
            ArrayAnsatz(math.transpose(self.array, [0] + [i + 1 for i in order]))
        )

    def reduce(self, shape: Union[int, Iterable[int]]) -> Fock:
        r"""
        Returns a new ``Fock`` with a sliced array.

        .. code-block::

            >>> from mrmustard import math
            >>> from mrmustard.physics.representations import Fock

            >>> array1 = math.arange(27).reshape((3, 3, 3))
            >>> fock1 = Fock(array1)

            >>> fock2 = fock1.reduce(3)
            >>> assert fock1 == fock2

            >>> fock3 = fock1.reduce(2)
            >>> array3 = [[[0, 1], [3, 4]], [[9, 10], [12, 13]]]
            >>> assert fock3 == Fock(array3)

            >>> fock4 = fock1.reduce((2, 1, 3, 1))
            >>> array4 = [[[0], [3], [6]]]
            >>> assert fock4 == Fock(array4)

        Args:
            shape: The shape of the array of the returned ``Fock``. An ``int`` is applied to
                every axis (batch axis included); an iterable must give one size per axis of
                ``self.array`` (batch axis included).

        Raises:
            ValueError: If ``shape`` does not have one entry per axis of ``self.array``.
        """
        length = len(self.array.shape)
        # ``tuple`` so that any iterable (e.g. a generator) is accepted, as annotated
        shape = (shape,) * length if isinstance(shape, int) else tuple(shape)
        if len(shape) != length:
            msg = f"Expected ``shape`` of length {length}, "
            msg += f"found shape of length {len(shape)}."
            raise ValueError(msg)

        # keep the first ``s`` entries along every axis in a single indexing operation
        ret = self.array[tuple(slice(0, s) for s in shape)]
        return Fock(array=ret, batched=True)

    def _repr_html_(self):  # pragma: no cover
        template = Template(filename=os.path.dirname(__file__) + "/assets/fock.txt")
        display(HTML(template.render(rep=self)))

    def sum_batch(self) -> Fock:
        r"""
        Sums over the batch dimension of the array. Turns an object with any batch size to a batch size of 1.

        Returns:
            The collapsed Fock object.
        """
        return self.from_ansatz(ArrayAnsatz(math.expand_dims(math.sum(self.array, axes=[0]), 0)))