# Source code for mrmustard.lab_dev.states.base

# Copyright 2023 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable=abstract-method, chained-comparison, use-dict-literal, protected-access, inconsistent-return-statements

"""
This module contains the base classes for the available quantum states.

In the docstrings defining the available states, we provide a definition in terms of
the covariance matrix :math:`V` and the vector of means :math:`r`. Additionally, we
provide the ``(A, b, c)`` triples that define the states in the Fock-Bargmann
representation.
"""

from __future__ import annotations

from typing import Optional, Sequence, Union
import os

from enum import Enum
from IPython.display import display, HTML
from mako.template import Template
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import numpy as np

from mrmustard import math, settings
from mrmustard.math.parameters import Variable
from mrmustard.physics.fock import quadrature_distribution
from mrmustard.physics.wigner import wigner_discretized
from mrmustard.utils.typing import (
    ComplexMatrix,
    ComplexTensor,
    ComplexVector,
    RealVector,
)
from mrmustard.physics.bargmann import wigner_to_bargmann_psi, wigner_to_bargmann_rho
from mrmustard.physics.converters import to_fock
from mrmustard.physics.gaussian import purity
from mrmustard.physics.gaussian_integrals import join_Abc_real, real_gaussian_integral
from mrmustard.physics.representations import Bargmann, Fock
from mrmustard.lab_dev.utils import shape_check
from mrmustard.physics.ansatze import (
    bargmann_Abc_to_phasespace_cov_means,
)
from ..circuit_components_utils import DsMap, BtoQMap
from ..circuit_components import CircuitComponent
from ..circuit_components_utils import TraceOut
from ..wires import Wires

# Public names exported by ``from ... import *``: the base state classes of this module.
__all__ = ["State", "DM", "Ket"]

# ~~~~~~~
# Helpers
# ~~~~~~~


class OperatorType(Enum):
    r"""
    Tags describing the wire layout of an operator handed to the ``expectation`` method
    of ``Ket``\s and ``DM``\s, as classified by ``_validate_operator``.
    """

    # One member per supported layout, plus a catch-all for unsupported operators.
    # The members are numbered 1 through 4 in declaration order.
    KET_LIKE, DM_LIKE, UNITARY_LIKE, INVALID_TYPE = range(1, 5)


def _validate_operator(operator: CircuitComponent) -> tuple[OperatorType, str]:
    r"""
    A function used to validate an operator inside the ``expectation`` method of ``Ket`` and
    ``DM``.

    The operator is classified by the wires it exposes:

    * ket-like: output wires on the ket side, and no other wires;
    * density matrix-like: output wires on both the ket and the bra side (on the same
      modes), and no input wires;
    * unitary-like: input and output wires on the ket side (on the same modes), and no
      bra wires.

    Args:
        operator: The circuit component to validate.

    Returns:
        If ``operator`` is ket-like, density matrix-like, or unitary-like, the corresponding
        ``OperatorType`` and an empty string. Otherwise, ``OperatorType.INVALID_TYPE`` and
        an error message.
    """
    w = operator.wires

    # ket-like: only output wires on the ket side
    if w.ket.output and not w.ket.input and not w.bra:
        return OperatorType.KET_LIKE, ""

    # density matrix-like: output wires on both sides, no input wires at all
    if w.ket.output and w.bra.output and not w.ket.input and not w.bra.input:
        if w.ket.output.modes != w.bra.output.modes:
            msg = "Found DM-like operator with different modes for ket and bra wires."
            return OperatorType.INVALID_TYPE, msg
        return OperatorType.DM_LIKE, ""

    # unitary-like: input and output wires on the ket side and no bra wires.
    # Both bra directions must be checked; testing ``w.bra.input`` twice (and never
    # ``w.bra.output``) would let operators with bra output wires through.
    if w.ket.input and w.ket.output and not w.bra.output and not w.bra.input:
        if w.ket.input.modes != w.ket.output.modes:
            msg = "Found unitary-like operator with different modes for input and output wires."
            return OperatorType.INVALID_TYPE, msg
        return OperatorType.UNITARY_LIKE, ""

    msg = "Cannot calculate the expectation value of the given ``operator``."
    return OperatorType.INVALID_TYPE, msg


# ~~~~~~~
# Classes
# ~~~~~~~


class State(CircuitComponent):
    r"""
    Base class for all states.
    """

    @classmethod
    def from_bargmann(
        cls,
        modes: Sequence[int],
        triple: tuple[ComplexMatrix, ComplexVector, complex],
        name: Optional[str] = None,
    ) -> State:
        r"""
        Initializes a state from an ``(A, b, c)`` triple defining a Bargmann representation.

        .. code-block::

            >>> from mrmustard.physics.representations import Bargmann
            >>> from mrmustard.physics.triples import coherent_state_Abc
            >>> from mrmustard.lab_dev import Ket

            >>> modes = [0, 1]
            >>> triple = coherent_state_Abc(x=[0.1, 0.2])

            >>> coh = Ket.from_bargmann(modes, triple)
            >>> assert coh.modes == modes
            >>> assert coh.representation == Bargmann(*triple)
            >>> assert isinstance(coh, Ket)

        Args:
            modes: The modes of this state.
            triple: The ``(A, b, c)`` triple.
            name: The name of this state.

        Returns:
            A state.

        Raises:
            ValueError: If the ``A`` or ``b`` have a shape that is inconsistent with
                the number of modes.
        """
        raise NotImplementedError

    @classmethod
    def from_fock(
        cls,
        modes: Sequence[int],
        array: ComplexTensor,
        name: Optional[str] = None,
        batched: bool = False,
    ) -> State:
        r"""
        Initializes a state from an array describing the state in the Fock representation.

        .. code-block::

            >>> from mrmustard.physics.representations import Fock
            >>> from mrmustard.physics.triples import coherent_state_Abc
            >>> from mrmustard.lab_dev import Coherent, Ket

            >>> modes = [0]
            >>> array = Coherent(modes, x=0.1).to_fock_component().representation.array
            >>> coh = Ket.from_fock(modes, array, batched=True)

            >>> assert coh.modes == modes
            >>> assert coh.representation == Fock(array)
            >>> assert isinstance(coh, Ket)

        Args:
            modes: The modes of this state.
            array: The Fock array.
            name: The name of this state.
            batched: Whether the given array is batched.

        Returns:
            A state.

        Raises:
            ValueError: If the given array has a shape that is inconsistent with the number of
                modes.
        """
        raise NotImplementedError

    @classmethod
    def from_phase_space(
        cls,
        modes: Sequence[int],
        cov: ComplexMatrix,
        means: ComplexMatrix,
        name: Optional[str] = None,
        atol_purity: Optional[float] = 1e-3,
    ) -> State:  # pylint: disable=abstract-method
        r"""
        Initializes a state from the covariance matrix and the vector of means of a state in
        phase space.

        Args:
            modes: The modes of this state.
            cov: The covariance matrix.
            means: The vector of means.
            name: The name of this state.
            atol_purity: If ``atol_purity`` is given, the purity of the state is computed, and
                an error is raised if its value is smaller than ``1-atol_purity``. A falsy
                value (``None`` or ``0``) skips this check.

        Returns:
            A state.

        Raises:
            ValueError: If the given ``cov`` and ``means`` have shapes that are inconsistent
                with the number of modes.
            ValueError: If ``atol_purity`` is truthy and the purity of the returned state is
                smaller than ``1-atol_purity``.
        """
        raise NotImplementedError

    @classmethod
    def from_quadrature(
        cls,
        modes: Sequence[int],
        triple: tuple[ComplexMatrix, ComplexVector, complex],
        name: Optional[str] = None,
    ) -> State:
        r"""
        Initializes a state from an ``(A, b, c)`` triple describing it in the quadrature basis
        as a Gaussian exponential (ABC Ansatz).

        Args:
            modes: The modes of this state.
            triple: The ``(A, b, c)`` triple.
            name: The name of this state.

        Returns:
            A state.

        Raises:
            ValueError: If the given triple has shapes that are inconsistent with the number of
                modes.
        """
        raise NotImplementedError

    @property
    def _L2_norms(self) -> RealVector:
        r"""
        The `L2` norm (squared) of a ``Ket``, or the Hilbert-Schmidt norm of a ``DM``,
        element-wise along the batch dimension.
        """
        # The contraction runs with ``settings.UNSAFE_ZIP_BATCH`` switched on, which (as the
        # name suggests) pairs up batch elements so the norms come out element-wise. The
        # previous value is restored in ``finally`` so that neither an exception nor a
        # user-enabled setting is clobbered by this property.
        previous_zip_batch = settings.UNSAFE_ZIP_BATCH
        settings.UNSAFE_ZIP_BATCH = True
        try:
            rep = (self >> self.dual).representation
        finally:
            settings.UNSAFE_ZIP_BATCH = previous_zip_batch
        return math.real(rep.c if isinstance(rep, Bargmann) else rep.array)

    @property
    def L2_norm(self) -> float:
        r"""
        The `L2` norm (squared) of a ``Ket``, or the Hilbert-Schmidt norm of a ``DM``.

        Unlike ``_L2_norms``, the contraction is done with the current batching settings and
        the real part of the result is summed over axis 0.
        """
        rep = (self >> self.dual).representation
        return math.sum(math.real(rep.c if isinstance(rep, Bargmann) else rep.array), axes=[0])

    @property
    def probability(self) -> float:
        r"""
        Returns :math:`\langle\psi|\psi\rangle` for ``Ket`` states :math:`|\psi\rangle` and
        :math:`\text{Tr}(\rho)` for ``DM`` states :math:`\rho`.
        """
        raise NotImplementedError

    @property
    def purity(self) -> float:
        r"""
        The purity of this state.
        """
        raise NotImplementedError

    @property
    def is_pure(self):
        r"""
        Whether this state is pure, i.e. whether ``purity`` is close to ``1.0``.
        """
        return math.allclose(self.purity, 1.0)

    def fock_array(self, shape: Optional[Union[int, Sequence[int]]] = None) -> ComplexTensor:
        r"""
        The array that describes this state in the Fock representation.

        Uses the :meth:`mrmustard.physics.converters.to_fock` method to convert the internal
        representation into a ``Fock`` object.

        Args:
            shape: The shape of the returned array. If ``shape`` is given as an ``int``, it is
                broadcasted to all the dimensions. If ``None``, it defaults to the value of
                ``AUTOCUTOFF_MAX_CUTOFF`` in the settings.

        Returns:
            The array that describes this state in the Fock representation.
        """
        return to_fock(self.representation, shape).array

    def phase_space(self, s: float) -> tuple[ComplexMatrix, ComplexVector, complex]:
        r"""
        Returns the phase space parametrization of a state, consisting of a covariance matrix,
        a vector of means and a scaling coefficient. When a state is a linear superposition of
        Gaussians, each of cov, means, coeff are arranged in a batch.

        Phase space representations are labelled by an ``s`` parameter (float) which modifies
        the exponent of :math:`D_s(\gamma) = e^{\frac{s}{2}|\gamma|^2}D(\gamma)`, which is the
        operator basis used to expand phase space density matrices. The ``s`` parameter
        typically takes the values of -1, 0, 1 to indicate Glauber/Wigner/Husimi functions.

        Note that the same ``(cov, means, coeff)`` triple can be used to parametrize the
        characteristic functions as well.

        Args:
            s: The phase space parameter.

        Returns:
            The covariance matrix, the mean vector and the coefficient of the state in
            s-parametrized phase space.

        Raises:
            ValueError: If the representation of this state is not ``Bargmann``.
        """
        if not isinstance(self.representation, Bargmann):
            raise ValueError(f"Can not calculate phase space for ``{self.name}`` object.")

        new_state = self >> DsMap(self.modes, s=s)  # pylint: disable=protected-access
        return bargmann_Abc_to_phasespace_cov_means(
            new_state.representation.ansatz.A,
            new_state.representation.ansatz.b,
            new_state.representation.ansatz.c,
        )

    def visualize_2d(
        self,
        xbounds: tuple[int] = (-6, 6),
        pbounds: tuple[int] = (-6, 6),
        resolution: int = 200,
        colorscale: str = "viridis",
        return_fig: bool = False,
    ) -> Union[go.Figure, None]:
        r"""
        2D visualization of the Wigner function of this state.

        Plots the Wigner function on a heatmap, alongside the probability distributions on the
        two quadrature axes.

        .. code-block::

            >>> from mrmustard.lab_dev import Coherent

            >>> state = Coherent([0], x=1) / 2**0.5 + Coherent([0], x=-1) / 2**0.5
            >>> # state.visualize_2d()

        Args:
            xbounds: The range of the `x` axis.
            pbounds: The range of the `p` axis.
            resolution: The number of bins on each axes.
            colorscale: A colorscale. Must be one of ``Plotly``\'s built-in continuous color
                scales.
            return_fig: Whether to return the ``Plotly`` figure.

        Returns:
            A ``Plotly`` figure representing the state in 2D if ``return_fig`` is ``True``;
            otherwise the figure is displayed and ``None`` is returned.

        Raises:
            ValueError: If this state is not a single-mode state.
        """
        # same guard as ``visualize_3d`` and ``visualize_dm``
        if self.n_modes != 1:
            raise ValueError("2D visualization not available for multi-mode states.")

        state = self.to_fock_component(settings.AUTOCUTOFF_MAX_CUTOFF)
        state = state if isinstance(state, DM) else state.dm()
        # collapse the leading axis of the Fock array (presumably the batch axis)
        dm = math.sum(state.representation.array, axes=[0])

        x, prob_x = quadrature_distribution(dm)  # TODO: replace with new MM methods
        p, prob_p = quadrature_distribution(dm, np.pi / 2)

        # keep only the quadrature points that fall inside the requested bounds
        mask_x = math.asnumpy([xi >= xbounds[0] and xi <= xbounds[1] for xi in x])
        x = x[mask_x]
        prob_x = prob_x[mask_x]

        mask_p = math.asnumpy([pi >= pbounds[0] and pi <= pbounds[1] for pi in p])
        p = p[mask_p]
        prob_p = prob_p[mask_p]

        xvec = np.linspace(*xbounds, resolution)
        pvec = np.linspace(*pbounds, resolution)
        z, xs, ps = wigner_discretized(dm, xvec, pvec)
        xs = xs[:, 0]
        ps = ps[0, :]

        fig = make_subplots(
            rows=2,
            cols=2,
            column_widths=[2, 1],
            row_heights=[1, 2],
            vertical_spacing=0.05,
            horizontal_spacing=0.05,
            shared_xaxes="columns",
            shared_yaxes="rows",
        )

        # X-P plot
        # note: heatmaps revert the y axes, which is why the minus in `y=-ps` is required
        fig_21 = go.Heatmap(
            x=xs,
            y=-ps,
            z=math.transpose(z),
            colorscale=colorscale,
            name="Wigner function",
        )
        fig.add_trace(fig_21, row=2, col=1)
        fig.update_traces(row=2, col=1, showscale=False)
        fig.update_xaxes(range=xbounds, title_text="x", row=2, col=1)
        fig.update_yaxes(range=pbounds, title_text="p", row=2, col=1)

        # X quadrature probability distribution
        fig_11 = go.Scatter(x=x, y=prob_x, line=dict(color="steelblue", width=2), name="Prob(x)")
        fig.add_trace(fig_11, row=1, col=1)
        fig.update_xaxes(range=xbounds, row=1, col=1, showticklabels=False)
        fig.update_yaxes(title_text="Prob(x)", range=(0, max(prob_x)), row=1, col=1)

        # P quadrature probability distribution
        fig_22 = go.Scatter(x=prob_p, y=-p, line=dict(color="steelblue", width=2), name="Prob(p)")
        fig.add_trace(fig_22, row=2, col=2)
        fig.update_xaxes(title_text="Prob(p)", range=(0, max(prob_p)), row=2, col=2)
        fig.update_yaxes(range=pbounds, row=2, col=2, showticklabels=False)

        fig.update_layout(
            height=500,
            width=500,
            plot_bgcolor="aliceblue",
            margin=dict(l=20, r=20, t=30, b=20),
            showlegend=False,
        )
        fig.update_xaxes(
            showline=True,
            linewidth=1,
            linecolor="black",
            mirror=True,
            tickfont_family="Arial Black",
        )
        fig.update_yaxes(
            showline=True,
            linewidth=1,
            linecolor="black",
            mirror=True,
            tickfont_family="Arial Black",
        )

        if return_fig:
            return fig
        html = fig.to_html(full_html=False, include_plotlyjs="cdn")  # pragma: no cover
        display(HTML(html))  # pragma: no cover

    def visualize_3d(
        self,
        xbounds: tuple[int] = (-6, 6),
        pbounds: tuple[int] = (-6, 6),
        resolution: int = 200,
        colorscale: str = "viridis",
        return_fig: bool = False,
    ) -> Union[go.Figure, None]:
        r"""
        3D visualization of the Wigner function of this state on a surface plot.

        Args:
            xbounds: The range of the `x` axis.
            pbounds: The range of the `p` axis.
            resolution: The number of bins on each axes.
            colorscale: A colorscale. Must be one of ``Plotly``\'s built-in continuous color
                scales.
            return_fig: Whether to return the ``Plotly`` figure.

        Returns:
            A ``Plotly`` figure representing the state in 3D if ``return_fig`` is ``True``;
            otherwise the figure is displayed and ``None`` is returned.

        Raises:
            ValueError: If this state is not a single-mode state.
        """
        if self.n_modes != 1:
            raise ValueError("3D visualization not available for multi-mode states.")

        state = self.to_fock_component(settings.AUTOCUTOFF_MAX_CUTOFF)
        state = state if isinstance(state, DM) else state.dm()
        dm = math.sum(state.representation.array, axes=[0])

        xvec = np.linspace(*xbounds, resolution)
        pvec = np.linspace(*pbounds, resolution)
        z, xs, ps = wigner_discretized(dm, xvec, pvec)
        xs = xs[:, 0]
        ps = ps[0, :]

        fig = go.Figure(
            data=go.Surface(
                x=xs,
                y=ps,
                z=z,
                colorscale=colorscale,
                hovertemplate="x: %{x:.3f}"
                + "<br>p: %{y:.3f}"
                + "<br>W(x, p): %{z:.3f}<extra></extra>",
            )
        )

        fig.update_layout(
            autosize=False,
            width=500,
            height=500,
            margin=dict(l=0, r=0, b=0, t=0),
            scene_camera_eye=dict(x=-2.1, y=0.88, z=0.64),
        )
        fig.update_traces(
            contours_z=dict(
                show=True, usecolormap=True, highlightcolor="limegreen", project_z=False
            )
        )
        fig.update_traces(
            contours_y=dict(show=True, usecolormap=True, highlightcolor="red", project_y=False)
        )
        fig.update_traces(
            contours_x=dict(show=True, usecolormap=True, highlightcolor="yellow", project_x=False)
        )
        fig.update_scenes(
            xaxis_title_text="x",
            yaxis_title_text="p",
            zaxis_title_text="Wigner function",
        )
        fig.update_xaxes(title_text="x")
        fig.update_yaxes(title="p")

        if return_fig:
            return fig
        html = fig.to_html(full_html=False, include_plotlyjs="cdn")  # pragma: no cover
        display(HTML(html))  # pragma: no cover

    def visualize_dm(
        self,
        cutoff: Optional[int] = None,
        return_fig: bool = False,
    ) -> Union[go.Figure, None]:
        r"""
        Plots the absolute value :math:`abs(\rho)` of the density matrix :math:`\rho` of this
        state on a heatmap.

        Args:
            cutoff: The desired cutoff. Defaults to the value of ``AUTOCUTOFF_MAX_CUTOFF`` in
                the settings.
            return_fig: Whether to return the ``Plotly`` figure.

        Returns:
            A ``Plotly`` figure representing absolute value of the density matrix of this
            state if ``return_fig`` is ``True``; otherwise the figure is displayed and ``None``
            is returned.

        Raises:
            ValueError: If this state is not a single-mode state.
        """
        if self.n_modes != 1:
            raise ValueError("DM visualization not available for multi-mode states.")
        state = self.to_fock_component(cutoff or settings.AUTOCUTOFF_MAX_CUTOFF)
        state = state if isinstance(state, DM) else state.dm()
        dm = math.sum(state.representation.array, axes=[0])

        fig = go.Figure(
            data=go.Heatmap(z=abs(dm), colorscale="viridis", name="abs(ρ)", showscale=False)
        )
        fig.update_yaxes(autorange="reversed")
        fig.update_layout(
            height=257,
            width=257,
            margin=dict(l=20, r=20, t=30, b=20),
        )
        fig.update_xaxes(title_text=f"abs(ρ), cutoff={dm.shape[0]}")

        if return_fig:
            return fig
        html = fig.to_html(full_html=False, include_plotlyjs="cdn")  # pragma: no cover
        display(HTML(html))  # pragma: no cover

    def _repr_html_(self):  # pragma: no cover
        # Renders the ``assets/states.txt`` Mako template located next to this module.
        template = Template(filename=os.path.dirname(__file__) + "/assets/states.txt")
        display(HTML(template.render(state=self)))

    def _getitem_builtin_state(self, modes: set[int]):
        r"""
        A convenience function to slice built-in states.

        Built-in states come with a parameter set. To slice them, we simply slice the parameter
        set, and then use the sliced parameter set to re-initialize them.

        This approach avoids computing the representation, which may be expensive. Additionally,
        it allows returning trainable states.
        """
        # slice the parameter set
        items = [i for i, m in enumerate(self.modes) if m in modes]
        kwargs = {}
        for name, param in self._parameter_set[items].all_parameters.items():
            kwargs[name] = param.value
            if isinstance(param, Variable):
                # re-create trainable parameters as trainable, with the same bounds
                kwargs[name + "_trainable"] = True
                kwargs[name + "_bounds"] = param.bounds

        # use `mro` to return the correct state
        return self.__class__(modes, **kwargs)

    def quadrature(self) -> tuple[ComplexMatrix, ComplexVector, complex]:
        r"""
        The A matrix, b vector and c scalar that describe this state in the quadrature basis
        for all modes.

        Raises:
            ValueError: If the representation of this state is not ``Bargmann``.
        """
        if not isinstance(self.representation, Bargmann):
            raise ValueError(
                f"``{self.representation}`` is not available to calculate the quadrature representation."
            )
        ret = self >> BtoQMap(self.modes)
        return ret.bargmann
class DM(State):
    r"""
    Base class for density matrices.

    Args:
        name: The name of this state.
        modes: The modes of this state.
    """

    def __init__(self, name: Optional[str] = None, modes: tuple[int, ...] = ()):
        # default name is "DM" followed by the sorted mode indices, e.g. "DM01"
        super().__init__(
            name or "DM" + "".join(str(m) for m in sorted(modes)),
            modes_out_bra=modes,
            modes_out_ket=modes,
        )

    @classmethod
    def from_bargmann(
        cls,
        modes: Sequence[int],
        triple: tuple[ComplexMatrix, ComplexVector, complex],
        name: Optional[str] = None,
    ) -> DM:
        A = math.astensor(triple[0])
        b = math.astensor(triple[1])
        c = math.astensor(triple[2])
        # a DM has one ket and one bra wire per mode
        shape_check(A, b, 2 * len(modes), "Bargmann")
        ret = DM(name, modes)
        ret._representation = Bargmann(A, b, c)
        return ret

    @classmethod
    def from_fock(
        cls,
        modes: Sequence[int],
        array: ComplexTensor,
        name: Optional[str] = None,
        batched: bool = False,
    ) -> DM:
        array = math.astensor(array)
        n_modes = len(modes)
        # one axis per ket and bra wire, plus a leading batch axis if ``batched``
        if len(array.shape) != 2 * n_modes + (1 if batched else 0):
            msg = f"Given array is inconsistent with modes=``{modes}``."
            raise ValueError(msg)
        ret = DM(name, modes)
        ret._representation = Fock(array, batched)
        return ret

    @classmethod
    def from_phase_space(
        cls,
        modes: Sequence[int],
        cov: ComplexMatrix,
        means: ComplexMatrix,
        name: Optional[str] = None,
        atol_purity: Optional[float] = 1e-3,
    ) -> DM:
        cov = math.astensor(cov)
        means = math.astensor(means)
        shape_check(cov, means, 2 * len(modes), "Phase space")
        # NOTE(review): mixed states are valid DMs, yet this rejects purity < 1 - atol_purity
        # by default; kept for backward compatibility — confirm whether it is intended.
        if atol_purity:
            p = purity(cov)
            if p < 1.0 - atol_purity:
                msg = f"Cannot initialize a DM: purity is {p:.3f} (must be at least 1.0-atol)."
                raise ValueError(msg)
        ret = DM(name, modes)
        ret._representation = Bargmann(*wigner_to_bargmann_rho(cov, means))
        return ret

    @classmethod
    def from_quadrature(
        cls,
        modes: Sequence[int],
        triple: tuple[ComplexMatrix, ComplexVector, complex],
        name: Optional[str] = None,
    ) -> DM:
        # The change of representation from quadrature to Bargmann uses ``BtoQMap.dual``.
        # That map only covers a single side, so for a DM it is combined with its adjoint to
        # cover the other side as well.
        QtoBMap_CC = BtoQMap(modes).dual.adjoint @ BtoQMap(modes).dual
        QtoBMap_A, QtoBMap_b, QtoBMap_c = (
            QtoBMap_CC.representation.A[0],
            QtoBMap_CC.representation.b[0],
            QtoBMap_CC.representation.c[0],
        )

        full_order_list = math.arange(4 * len(modes))
        bargmann_A, bargmann_b, bargmann_c = real_gaussian_integral(
            join_Abc_real(
                triple,
                (QtoBMap_A, QtoBMap_b, QtoBMap_c),
                idx1=list(math.cast(full_order_list[: 2 * len(modes)], math.int32)),
                idx2=list(
                    math.cast(
                        math.concat(
                            [
                                full_order_list[len(modes) : 2 * len(modes)],
                                full_order_list[3 * len(modes) :],
                            ],
                            axis=0,
                        ),
                        math.int32,
                    )
                ),
            ),
            idx=list(math.cast(full_order_list[: 2 * len(modes)], math.int32)),
        )
        ret = DM(name, modes)
        ret._representation = Bargmann(bargmann_A, bargmann_b, bargmann_c)
        return ret

    @property
    def _probabilities(self) -> RealVector:
        r"""Element-wise probabilities along the batch dimension of this DM.
        Useful for cases where the batch dimension does not mean a convex combination of states.
        """
        idx_ket = self.wires.output.ket.indices
        idx_bra = self.wires.output.bra.indices
        rep = self.representation.trace(idx_ket, idx_bra)
        if isinstance(rep, Bargmann):
            return math.real(rep.c)
        return math.real(rep.array)

    @property
    def probability(self) -> float:
        r"""Probability of this DM, using the batch dimension of the Ansatz
        as a convex combination of states."""
        return math.sum(self._probabilities)

    @property
    def _purities(self) -> RealVector:
        r"""Element-wise purities along the batch dimension of this DM.
        Useful for cases where the batch dimension does not mean a convex combination of states.
        """
        # NOTE(review): this divides by the probability, not its square; for unnormalized
        # states Tr(rho^2) / Tr(rho)^2 may be intended — confirm.
        return self._L2_norms / self._probabilities

    @property
    def purity(self) -> float:
        r"""The purity of this DM, returned as its Hilbert-Schmidt norm ``L2_norm``
        (no normalization by the probability is applied)."""
        return self.L2_norm

    def expectation(self, operator: CircuitComponent):
        r"""
        The expectation value of an operator calculated over this DM.

        Given the operator `O`, this function returns :math:`Tr\big(\rho O)`\, where :math:`\rho`
        is the density matrix of this state.

        The ``operator`` is expected to be a component with ket-like wires (i.e., output wires on
        the ket side), density matrix-like wires (output wires on both ket and bra sides), or
        unitary-like wires (input and output wires on the ket side).

        Args:
            operator: A ket-like, density-matrix like, or unitary-like circuit component.

        Raises:
            ValueError: If ``operator`` is not a ket-like, density-matrix like, or unitary-like
                component.
            ValueError: If ``operator`` is defined over a set of modes that is not a subset of the
                modes of this state.
        """
        op_type, msg = _validate_operator(operator)
        if op_type is OperatorType.INVALID_TYPE:
            raise ValueError(msg)

        if not operator.wires.modes.issubset(self.wires.modes):
            msg = f"Expected an operator defined on a subset of modes `{self.modes}`, "
            msg += f"found one defined on `{operator.modes}.`"
            raise ValueError(msg)

        # modes of this state that the operator does not act on get traced out
        leftover_modes = self.wires.modes - operator.wires.modes
        if op_type is OperatorType.KET_LIKE:
            result = self @ operator.dual @ operator.dual.adjoint
            if leftover_modes:
                result >>= TraceOut(leftover_modes)
        elif op_type is OperatorType.DM_LIKE:
            result = self @ operator.dual
            if leftover_modes:
                result >>= TraceOut(leftover_modes)
        else:
            result = (self @ operator) >> TraceOut(self.modes)

        rep = result.representation
        return rep.array if isinstance(rep, Fock) else rep.c

    def __rshift__(self, other: CircuitComponent) -> CircuitComponent:
        r"""
        Contracts ``self`` and ``other`` as it would in a circuit, adding the adjoints when
        they are missing.

        Returns a ``DM`` when the wires of the resulting component are compatible with those
        of a ``DM`` (no input wires, same modes on the ket and bra sides), a
        ``CircuitComponent`` otherwise.
        """
        ret = super().__rshift__(other)

        if not ret.wires.input and ret.wires.bra.modes == ret.wires.ket.modes:
            return DM._from_attributes("", ret.representation, ret.wires)
        return ret

    def __repr__(self) -> str:
        # NOTE(review): intentionally empty, presumably so notebooks rely on ``_repr_html_``.
        return ""

    def __getitem__(self, modes: Union[int, Sequence[int]]) -> State:
        r"""
        Traces out all the modes, except those in the given ``modes``.

        Raises:
            ValueError: If ``modes`` is not a subset of the modes of this state.
        """
        if isinstance(modes, int):
            modes = [modes]
        modes = set(modes)

        if not modes.issubset(self.modes):
            msg = f"Expected a subset of `{self.modes}`, found `{list(modes)}`."
            raise ValueError(msg)

        if self._parameter_set:
            # if ``self`` has a parameter set, it is a built-in state, and we slice the
            # parameters
            return self._getitem_builtin_state(modes)

        # if ``self`` has no parameter set, it is not a built-in state, and we must slice the
        # representation
        wires = Wires(modes_out_bra=modes, modes_out_ket=modes)

        # pair each discarded mode's first-half index with its second-half index
        idxz = [i for i, m in enumerate(self.modes) if m not in modes]
        idxz_conj = [i + len(self.modes) for i, m in enumerate(self.modes) if m not in modes]
        representation = self.representation.trace(idxz, idxz_conj)

        return self.__class__._from_attributes(  # pylint: disable=protected-access
            self.name, representation, wires
        )
class Ket(State):
    r"""
    Base class for all pure states, potentially unnormalized.

    Arguments:
        name: The name of this state.
        modes: The modes of this state.
    """

    def __init__(self, name: Optional[str] = None, modes: tuple[int, ...] = ()):
        # default name is "Ket" followed by the sorted mode indices, e.g. "Ket01"
        super().__init__(
            name or "Ket" + "".join(str(m) for m in sorted(modes)), modes_out_ket=modes
        )

    @classmethod
    def from_bargmann(
        cls,
        modes: Sequence[int],
        triple: tuple[ComplexMatrix, ComplexVector, complex],
        name: Optional[str] = None,
    ) -> Ket:
        A = math.astensor(triple[0])
        b = math.astensor(triple[1])
        c = math.astensor(triple[2])
        # a Ket has one (ket) wire per mode
        shape_check(A, b, len(modes), "Bargmann")
        ret = Ket(name, modes)
        ret._representation = Bargmann(A, b, c)
        return ret

    @classmethod
    def from_fock(
        cls,
        modes: Sequence[int],
        array: ComplexTensor,
        name: Optional[str] = None,
        batched: bool = False,
    ) -> Ket:
        array = math.astensor(array)
        n_modes = len(modes)
        # one axis per mode, plus a leading batch axis if ``batched``
        if len(array.shape) != n_modes + (1 if batched else 0):
            msg = f"Given array is inconsistent with modes=``{modes}``."
            raise ValueError(msg)
        ret = Ket(name, modes)
        ret._representation = Fock(array, batched)
        return ret

    @classmethod
    def from_phase_space(
        cls,
        modes: Sequence[int],
        cov: ComplexMatrix,
        means: ComplexMatrix,
        name: Optional[str] = None,
        atol_purity: Optional[float] = 1e-3,
    ) -> Ket:
        cov = math.astensor(cov)
        means = math.astensor(means)
        shape_check(cov, means, 2 * len(modes), "Phase space")
        # a falsy ``atol_purity`` (``None`` or ``0``) skips the purity check
        if atol_purity:
            p = purity(cov)
            if p < 1.0 - atol_purity:
                msg = f"Cannot initialize a ket: purity is {p:.3f} (must be at least 1.0-atol)."
                raise ValueError(msg)
        ret = Ket(name, modes)
        ret._representation = Bargmann(*wigner_to_bargmann_psi(cov, means))
        return ret

    @classmethod
    def from_quadrature(
        cls,
        modes: Sequence[int],
        triple: tuple[ComplexMatrix, ComplexVector, complex],
        name: Optional[str] = None,
    ) -> Ket:
        # The change of representation from quadrature to Bargmann uses ``BtoQMap.dual``.
        QtoBMap_CC = BtoQMap(modes).dual
        QtoBMap_A, QtoBMap_b, QtoBMap_c = (
            QtoBMap_CC.representation.A[0],
            QtoBMap_CC.representation.b[0],
            QtoBMap_CC.representation.c[0],
        )
        joinedA, joinedb, joinedc = join_Abc_real(
            triple,
            (QtoBMap_A, QtoBMap_b, QtoBMap_c),
            idx1=list(np.arange(len(modes))),
            idx2=list(np.arange(len(modes), 2 * len(modes))),
        )
        bargmann_A, bargmann_b, bargmann_c = real_gaussian_integral(
            (joinedA, joinedb, joinedc),
            idx=list(np.arange(len(modes))),
        )
        ret = Ket(name, modes)
        ret._representation = Bargmann(bargmann_A, bargmann_b, bargmann_c)
        return ret

    @property
    def _probabilities(self) -> RealVector:
        r"""Element-wise probabilities along the batch dimension of this Ket.
        Useful for cases where the batch dimension does not mean a linear combination of states.
        """
        return self._L2_norms

    @property
    def probability(self) -> float:
        r"""Probability of this state, where the batch dimension of the Ansatz
        means a linear combination of states."""
        return self.L2_norm

    @property
    def _purities(self) -> RealVector:
        r"""Purity of each state in the batch (always one, as kets are pure)."""
        return math.ones((self.representation.ansatz.batch_size,), math.float64)

    @property
    def purity(self) -> float:
        r"""The purity of this state, which is always ``1.0`` for a ``Ket``."""
        return 1.0

    def dm(self) -> DM:
        r"""
        The ``DM`` object obtained from this ``Ket``.
        """
        dm = self @ self.adjoint
        return DM._from_attributes(self.name, dm.representation, dm.wires)

    def expectation(self, operator: CircuitComponent):
        r"""
        The expectation value of an operator calculated over this Ket.

        Given the operator `O`, this function returns :math:`Tr\big(|\psi\rangle\langle\psi| O)`\,
        where :math:`|\psi\rangle` is the vector representing this state.

        The ``operator`` is expected to be a component with ket-like wires (i.e., output wires on
        the ket side), density matrix-like wires (output wires on both ket and bra sides), or
        unitary-like wires (input and output wires on the ket side).

        Args:
            operator: A ket-like, density-matrix like, or unitary-like circuit component.

        Raises:
            ValueError: If ``operator`` is not a ket-like, density-matrix like, or unitary-like
                component.
            ValueError: If ``operator`` is defined over a set of modes that is not a subset of the
                modes of this state.
        """
        op_type, msg = _validate_operator(operator)
        if op_type is OperatorType.INVALID_TYPE:
            raise ValueError(msg)

        if not operator.wires.modes.issubset(self.wires.modes):
            msg = f"Expected an operator defined on a subset of modes `{self.modes}`, "
            msg += f"found one defined on `{operator.modes}.`"
            raise ValueError(msg)

        # modes of this state that the operator does not act on get traced out
        leftover_modes = self.wires.modes - operator.wires.modes
        if op_type is OperatorType.KET_LIKE:
            result = self @ operator.dual
            result = result >> TraceOut(leftover_modes) if leftover_modes else result @ result.dual
        elif op_type is OperatorType.DM_LIKE:
            result = self @ (self.adjoint @ operator.dual)
            if leftover_modes:
                result >>= TraceOut(leftover_modes)
        else:
            result = self @ operator @ self.dual

        rep = result.representation
        return rep.array if isinstance(rep, Fock) else rep.c

    def __getitem__(self, modes: Union[int, Sequence[int]]) -> State:
        r"""
        Traces out all the modes, except those in the given ``modes``.

        Raises:
            ValueError: If ``modes`` is not a subset of the modes of this state.
        """
        if isinstance(modes, int):
            modes = [modes]
        modes = set(modes)

        if not modes.issubset(self.modes):
            msg = f"Expected a subset of `{self.modes}`, found `{list(modes)}`."
            raise ValueError(msg)

        if self._parameter_set:
            # if ``self`` has a parameter set, it is a built-in state, and we slice the
            # parameters
            return self._getitem_builtin_state(modes)

        # if ``self`` has no parameter set, it is not a built-in state.
        # we must turn it into a density matrix and slice the representation
        return self.dm()[modes]

    def __rshift__(self, other: CircuitComponent) -> CircuitComponent:
        r"""
        Contracts ``self`` and ``other`` as it would in a circuit, adding the adjoints when
        they are missing.

        Returns a ``DM`` or a ``Ket`` when the wires of the resulting components are compatible
        with those of a ``DM`` or of a ``Ket``, a ``CircuitComponent`` otherwise.
        """
        ret = super().__rshift__(other)

        if not ret.wires.input:
            if not ret.wires.bra:
                return Ket._from_attributes("", ret.representation, ret.wires)
            if ret.wires.bra.modes == ret.wires.ket.modes:
                return DM._from_attributes("", ret.representation, ret.wires)
        return ret

    def __repr__(self) -> str:
        # NOTE(review): intentionally empty, presumably so notebooks rely on ``_repr_html_``.
        return ""