Source code for mrmustard.lab.states.base

# Copyright 2023 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""The base for the ``State`` class."""

from __future__ import annotations

from abc import abstractmethod
from collections.abc import Sequence
from itertools import product
from typing import Self

import numpy as np
import plotly.graph_objects as go
from IPython.display import display
from plotly.subplots import make_subplots

from mrmustard import math, settings
from mrmustard.mathlib.lattice.autoshape import autoshape_numba
from mrmustard.mathlib.lattice.strategies.wormhole import (
    wormhole_1leftover_dm,
    wormhole_1leftover_ket,
)
from mrmustard.physics.ansatz import ArrayAnsatz, PolyExpAnsatz
from mrmustard.physics.bargmann_utils import bargmann_Abc_to_phasespace_cov_means
from mrmustard.physics.fock_utils import quadrature_distribution
from mrmustard.physics.gaussian import von_neumann_entropy
from mrmustard.physics.wigner import wigner_discretized
from mrmustard.physics.wires import Wires
from mrmustard.utils.typing import ComplexTensor, Matrix, RealVector, Scalar, Vector

from ..circuit_components import CircuitComponent
from ..circuit_components_utils import BtoChar, BtoPS, BtoQ
from ..transformations import Transformation

__all__ = ["State"]


# ~~~~~~~
# Classes
# ~~~~~~~


[docs] class State(CircuitComponent): r"""Base class for all states.""" @property def is_pure(self): r"""Whether this state is pure.""" return math.allclose(self.purity, 1.0) @property def is_separable(self): r"""Check if a multi-mode quantum state is separable based on the von Neumann entropy. Returns: Whether the state is separable. Raises: NotImplementedError: If the state is a linear superposition. """ if self.ansatz._lin_sup: raise NotImplementedError("Separation of linear superpositions is not implemented.") if self.n_modes == 1: return True rho = self.dm() cov_full, _, _ = rho.phase_space(s=0) S_total = von_neumann_entropy(cov_full) entropy_diff = -S_total for mode in self.modes: rho_reduced = rho.get_modes(mode) cov_reduced, _, _ = rho_reduced.phase_space(s=0) entropy = von_neumann_entropy(cov_reduced) entropy_diff += entropy return math.allclose(entropy_diff, 0) @property def L2_norm(self) -> float: r"""The `L2` norm squared of a ``Ket``, or the Hilbert-Schmidt norm of a ``DM``. >>> from mrmustard import math >>> from mrmustard.lab import GaussianKet >>> state = GaussianKet.random([0]) >>> assert math.allclose(state.L2_norm, 1.0) """ state = self if isinstance(state.ansatz, PolyExpAnsatz) and state.ansatz.num_derived_vars > 0: state = state.to_fock() return math.real(state.contract(state.dual).ansatz.scalar) @property @abstractmethod def probability(self) -> float: r"""Returns :math:`\langle\psi|\psi\rangle` for ``Ket`` states :math:`|\psi\rangle` and :math:`\text{Tr}(\rho)` for ``DM`` states :math:`\rho`. """ @property @abstractmethod def purity(self) -> float: r"""The purity of this state.""" @property def wigner(self): r"""Returns the Wigner function of this state in phase space as an ``Ansatz``. >>> import numpy as np >>> from mrmustard.lab import GaussianKet >>> state = GaussianKet.random([0]) >>> x = np.linspace(-5, 5, 100) >>> assert np.all(state.wigner(x,0).real >= 0) """ if isinstance(self.ansatz, PolyExpAnsatz): return (self >> BtoPS(self.modes, s=0)).ansatz.PS raise ValueError( "Wigner ansatz not implemented for Fock states. Consider calling ``.to_bargmann()`` first.", )
[docs] @classmethod def from_bargmann( cls, modes: Sequence[int], triple: tuple[Matrix, Vector, Scalar], name: str | None = None, lin_sup: bool = False, ) -> Self: r"""Initializes a state of type ``cls`` from an ``(A, b, c)`` triple parametrizing the Ansatz in Bargmann representation. >>> from mrmustard.physics.ansatz import PolyExpAnsatz >>> from mrmustard.physics.triples import coherent_state_Abc >>> from mrmustard.lab.states.ket import Ket >>> modes = (0,) >>> triple = coherent_state_Abc(alpha=0.1) >>> coh = Ket.from_bargmann(modes, triple) >>> assert coh.modes == modes >>> assert coh.ansatz == PolyExpAnsatz(*triple) >>> assert isinstance(coh, Ket) Args: modes: The modes of this state. triple: The ``(A, b, c)`` triple. name: The name of this state. lin_sup: Whether to include linear superposition axes in the batch dimensions. Returns: A ``State``. Raises: ValueError: If the ``A`` or ``b`` have a shape that is inconsistent with the number of modes. """ return cls.from_ansatz(modes, PolyExpAnsatz(*triple, lin_sup=lin_sup), name)
[docs] @classmethod def from_fock( cls, modes: Sequence[int], array: ComplexTensor, name: str | None = None, batch_dims: int = 0, ) -> Self: r"""Initializes a state of type ``cls`` from an array parametrizing the state in Fock representation. >>> from mrmustard.physics.ansatz import ArrayAnsatz >>> from mrmustard.lab import Coherent, Ket >>> array = Coherent(mode=0, alpha=0.1).to_fock().ansatz.array >>> coh = Ket.from_fock((0,), array, batch_dims=0) >>> assert coh.modes == (0,) >>> assert coh.ansatz == ArrayAnsatz(array, batch_dims=0) >>> assert isinstance(coh, Ket) Args: modes: The modes of this state. array: The Fock array. name: The name of this state. batch_dims: The number of batch dimensions in the given array. Returns: A ``State``. Raises: ValueError: If the given array has a shape that is inconsistent with the number of modes. """ return cls.from_ansatz(modes, ArrayAnsatz(array, batch_dims=batch_dims), name)
[docs] @classmethod @abstractmethod def from_ansatz( cls, modes: Sequence[int], ansatz: PolyExpAnsatz | ArrayAnsatz | None = None, name: str | None = None, ) -> Self: r"""Initializes a state of type ``cls`` given modes and an ansatz. >>> from mrmustard import math >>> from mrmustard.lab import Ket >>> from mrmustard.physics.ansatz import PolyExpAnsatz >>> A = math.astensor([[0,.5], [.5,0]]) >>> b = math.astensor([2-1j,2+1j]) >>> c = 1 >>> psi = Ket.from_ansatz([0,1], PolyExpAnsatz(A,b,c)) >>> assert isinstance(psi, Ket) Args: modes: The modes of this state. ansatz: The ansatz of this state. name: The name of this state. Returns: A state. """
[docs] @classmethod @abstractmethod def from_phase_space( cls, modes: Sequence[int], triple: tuple[Matrix, Vector, Scalar], name: str | None = None, atol_purity: float | None = None, ) -> Self: r"""Initializes a state from the covariance matrix and the vector of means of a state in phase space. >>> from mrmustard import math >>> from mrmustard.lab import Ket, Vacuum >>> assert Ket.from_phase_space([0], (math.eye(2)/2, [0,0], 1)) == Vacuum([0]) Note: If the given covariance matrix and vector of means are consistent with a pure state, a ``Ket`` is returned. Otherwise, a ``DM`` is returned. One can skip this check by setting ``atol_purity`` to ``None`` (``atol_purity`` defaults to ``None``). Args: modes: The modes of this states. triple: A covariance matrix, vector of means, and constant multiple triple. name: The name of this state. atol_purity: If ``atol_purity`` is given, the purity of the state is computed, and an error is raised if its value is smaller than ``1-atol_purity`` or larger than ``1+atol_purity``. If ``None``, this check is skipped. Returns: A ``State``. Raises: ValueError: If the given ``cov`` and ``means`` have shapes that are inconsistent with the number of modes. ValueError: If ``atol_purity`` is not ``None`` and the purity of the returned state is smaller than ``1-atol_purity`` or larger than ``1+atol_purity``. """
[docs] @classmethod def from_quadrature( cls, modes: Sequence[int], triple: tuple[Matrix, Vector, Scalar], phi: float = 0.0, name: str | None = None, ) -> Self: r"""Initializes a state from a triple (A,b,c) that parametrizes the wavefunction as `c * exp(0.5 z^T A z + b^T z)` in the quadrature representation. Args: modes: The modes of this state. triple: The ``(A, b, c)`` triple. phi: The angle of the quadrature. 0 corresponds to the x quadrature (default). name: The name of this state. Returns: A state of type ``cls``. Raises: ValueError: If the given triple has shapes that are inconsistent with the number of modes. """ QtoB = BtoQ(modes, phi).inverse() Q = cls.from_ansatz(modes, PolyExpAnsatz(*triple)) return cls.from_ansatz(modes, (Q >> QtoB).ansatz, name)
[docs] def auto_shape( self, max_prob=None, max_shape=None, min_shape=None, respect_manual_shape=True, ) -> tuple[int, ...]: r"""Generates an estimate for the Fock shape. If the state is in Fock the core shape is used. If in Bargmann, the shape is computed as the shape that captures at least ``settings.AUTOSHAPE_PROBABILITY`` of the probability mass of each single-mode marginal (default 99.9%) so long as the state has no derived variables and is unbatched. Otherwise, defaults to ``settings.DEFAULT_FOCK_SIZE``. If ``respect_manual_shape`` is ``True``, the non-None values in ``self.manual_shape`` are used to override the shape. >>> from mrmustard import math >>> from mrmustard.lab import Vacuum >>> assert math.allclose(Vacuum([0]).fock_array(), 1) Note: If jitted, the shape will default to ``settings.DEFAULT_FOCK_SIZE``. Args: max_prob: The maximum probability mass to capture in the shape. Default is ``settings.AUTOSHAPE_PROBABILITY``. max_shape: The maximum shape cutoff. Default is ``settings.AUTOSHAPE_MAX``. min_shape: The minimum shape cutoff. Default is ``settings.AUTOSHAPE_MIN``. respect_manual_shape: Whether to respect the non-None values in ``manual_shape``. Default is ``True``. Returns: The Fock shape of this component. """ try: shape = self.ansatz.core_shape except AttributeError: if self.ansatz.num_derived_vars == 0 and self.ansatz.batch_dims == 0: if not self.wires.ket or not self.wires.bra: ansatz = self.ansatz.conj & self.ansatz else: ansatz = self.ansatz A, b, c = ansatz.triple try: shape = autoshape_numba( math.asnumpy(A), math.asnumpy(b), math.asnumpy(c), max_prob or settings.AUTOSHAPE_PROBABILITY, max_shape or settings.AUTOSHAPE_MAX, min_shape or settings.AUTOSHAPE_MIN, ) # covers the case where auto_shape is jitted except math.BackendError: # pragma: no cover shape = super().auto_shape() if self.wires.ket and self.wires.bra: shape = tuple(shape) + tuple(shape) else: shape = super().auto_shape() if respect_manual_shape: return tuple(c or s for c, s in zip(self.manual_shape, shape)) return tuple(shape)
[docs] def fock_distribution(self, cutoff: int) -> ComplexTensor: r"""Returns the Fock distribution of the state up to some cutoff. Args: cutoff: The photon cutoff (maximum photon number). Returns: The Fock distribution including states from :math:`|0\rangle` to :math:`|\text{cutoff}\rangle`. """ batch_shape = self.ansatz.batch_shape batch_dim = self.ansatz.batch_dims fock_array = self.fock_array(cutoff + 1) if not self.wires.ket or not self.wires.bra: return math.reshape(math.abs(fock_array) ** 2, (*batch_shape, -1)) n_modes = self.n_modes if self.is_separable: for i in range(n_modes): fock_array = math.diagonal(fock_array, axis1=-n_modes - 1, axis2=-(1 + i)) return math.reshape(math.abs(fock_array), (*batch_shape, -1)) indices_list = [(...,) + ns * 2 for ns in product(list(range(cutoff + 1)), repeat=n_modes)] return math.stack([fock_array[indices] for indices in indices_list], axis=batch_dim)
[docs] def get_modes(self, modes: int | Sequence[int]) -> State: r"""Reduced density matrix obtained by tracing out all the modes except those in ``modes``. Note that the result is returned with modes in increasing order. Args: modes: The modes to keep. Returns: A ``State`` object with the remaining modes. Raises: ValueError: If the modes to keep are not a subset of the modes of the state. """ if not self.wires.ket or not self.wires.bra: return self.dm().get_modes(modes) keep = {modes} if isinstance(modes, int) else set(modes) if not keep.issubset(self.modes): raise ValueError(f"Expected a subset of ``{self.modes}``, found ``{keep}``.") idxz = [i for i, m in enumerate(self.modes) if m not in keep] idxz_conj = [i + len(self.modes) for i, m in enumerate(self.modes) if m not in keep] ansatz = self.ansatz.trace(idxz, idxz_conj) return self._from_attributes( ansatz, Wires(modes_out_bra=keep, modes_out_ket=keep), self.name )
[docs] def wormhole_1mode( self, pnr_outcomes: tuple[int, ...] | list[tuple[int, ...]], output_cutoff: int, leftover_mode: int, ) -> dict[tuple[int, ...], State]: r"""Compute conditional single-mode states given PNR measurements. Uses the wormhole algorithm to efficiently compute the conditional state of one "leftover" mode given photon number resolving (PNR) measurements on all other modes. This is much more efficient than computing the full Fock tensor and slicing for high photon counts. Supports both single and batched states. For batched states, the computation is parallelized across batch elements for improved performance. Args: pnr_outcomes: Either a single tuple or a list of tuples specifying photon counts for the measured modes. Each tuple has length n_modes - 1, with photon counts in mode order (excluding the leftover mode). output_cutoff: Maximum photon number for the output state. leftover_mode: The mode to keep unmeasured. Must be one of this state's modes. Returns: Dict mapping PNR tuples to State objects (same type as self). For batched input states, the returned states will also be batched. Raises: ValueError: If this state has fewer than 2 modes. ValueError: If leftover_mode is not one of this state's modes. ValueError: If pnr_outcomes tuples have wrong length. Example: >>> import numpy as np >>> from mrmustard.lab import GaussianKet >>> ket = GaussianKet.random([0, 1], seed=42) >>> pnr = (2,) >>> cutoff = 5 >>> results = ket.wormhole_1mode(pnr, output_cutoff=cutoff, leftover_mode=0) >>> cond_ket = results[pnr] >>> cond_ket.fock_array().shape (6,) Note: The wormhole algorithm is particularly efficient for high photon counts where computing the full Fock tensor would be prohibitively expensive. Memory scaling depends on state type: O(2^M × cutoff) for Ket, O(4^M × cutoff²) for DM, where M is the number of measured modes. """ if self.n_modes < 2: raise ValueError("wormhole_1mode requires at least 2 modes") if isinstance(pnr_outcomes, tuple): pnr_outcomes = [pnr_outcomes] modes = sorted(self.modes) if leftover_mode not in self.modes: raise ValueError(f"leftover_mode {leftover_mode} is not in this state's modes {modes}") expected_len = self.n_modes - 1 for pnr in pnr_outcomes: if len(pnr) != expected_len: raise ValueError( f"Each PNR tuple must have length {expected_len} (n_modes - 1), got {len(pnr)}" ) A, b, c = self.bargmann_triple() leftover_mode_idx = modes.index(leftover_mode) if not self.wires.ket or not self.wires.bra: # Ket case results_arrays = wormhole_1leftover_ket( A, b, c, output_cutoff=output_cutoff, pnr_outcomes=pnr_outcomes, leftover_mode=leftover_mode_idx, ) else: # DM case results_arrays = wormhole_1leftover_dm( A, b, c, output_cutoff=output_cutoff, pnr_outcomes=pnr_outcomes, leftover_mode=leftover_mode_idx, ) # Determine batch dimensions from input state batch_dims = self.ansatz.batch_dims results: dict[tuple[int, ...], State] = {} for pnr_tuple, state_array in results_arrays.items(): state_obj = self.__class__.from_ansatz( modes=[leftover_mode], ansatz=ArrayAnsatz(state_array, batch_dims=batch_dims), ) results[pnr_tuple] = state_obj return results
[docs] @abstractmethod def formal_stellar_decomposition( self, core_modes: Sequence[int], ) -> tuple[State, Transformation]: r"""Applies the formal stellar decomposition. Args: core_modes: The set of modes defining core variables. Returns: A tuple containing the core state and the Gaussian transformation performing the stellar decomposition. """
[docs] def normalize(self) -> State: r"""Returns a rescaled version of the state such that its probability is 1.""" probability = self.probability if not self.wires.ket or not self.wires.bra: return self / math.sqrt(probability) return self / probability
[docs] def phase_space(self, s: float) -> tuple: r"""Returns the phase space parametrization of a state, consisting in a covariance matrix, a vector of means and a scaling coefficient. When a state is a linear superposition of Gaussians, each of cov, means, coeff are arranged in a batch. Phase space representations are labelled by an ``s`` parameter (float) which modifies the exponent of :math:`D_s(\gamma) = e^{\frac{s}{2}|\gamma|^2}D(\gamma)`, which is the operator basis used to expand phase space density matrices. The ``s`` parameter typically takes the values of -1, 0, 1 to indicate Glauber/Wigner/Husimi functions. Args: s: The phase space parameter Returns: The covariance matrix, the mean vector and the coefficient of the state in s-parametrized phase space. """ if not isinstance(self.ansatz, PolyExpAnsatz): raise ValueError("Can calculate phase space only for Bargmann states.") if not self.wires.ket or not self.wires.bra: new_state = self.adjoint.contract(self.contract(BtoChar(self.modes, s=s))) else: new_state = self.contract(BtoChar(self.modes, s=s)) return bargmann_Abc_to_phasespace_cov_means(*new_state.bargmann_triple())
[docs] @abstractmethod def physical_stellar_decomposition( self, core_modes: Sequence[int], ) -> tuple[State, Transformation]: r"""Applies the physical stellar decomposition. Args: core_modes: The set of modes defining core variables. Returns: A tuple containing the core state and the Gaussian transformation performing the stellar decomposition. """
[docs] def quadrature_distribution(self, *quad: RealVector, phi: float = 0.0) -> ComplexTensor: r"""The (discretized) quadrature distribution of the ``State``. Args: quad: the discretized quadrature axis over which the distribution is computed. phi: The quadrature angle. ``0`` corresponds to the x quadrature, ``pi/2`` to the p quadrature. Returns: The quadrature distribution. """ if len(quad) != 1 and len(quad) != self.n_modes: raise ValueError( f"Expected {self.n_modes} or ``1`` quadrature vectors, got {len(quad)}.", ) if len(quad) == 1: quad = quad * self.n_modes if not self.wires.ket or not self.wires.bra: return math.abs(self.quadrature(*quad, phi=phi)) ** 2 return math.abs(self.quadrature(*(quad * 2), phi=phi))
[docs] def visualize_2d( self, xbounds: tuple[int, int] = (-6, 6), pbounds: tuple[int, int] = (-6, 6), resolution: int = 200, colorscale: str = "RdBu", return_fig: bool = False, min_shape: int = 50, ) -> go.Figure | None: r"""2D visualization of the Wigner function of this state. Plots the Wigner function on a heatmap, alongside the probability distributions on the two quadrature axis. >>> from mrmustard.lab import Coherent >>> state = Coherent(0, alpha=1) / 2**0.5 + Coherent(0, alpha=-1) / 2**0.5 >>> # state.visualize_2d() Args: xbounds: The range of the `x` axis. pbounds: The range of the `p` axis. resolution: The number of bins on each axes. colorscale: A colorscale. Must be one of ``Plotly``\'s built-in continuous color scales. return_fig: Whether to return the ``Plotly`` figure. min_shape: The minimum fock shape to use for the Wigner function plot. Returns: A ``Plotly`` figure representing the state in 2D. Raises: ValueError: If this state is a multi-mode state. """ if self.n_modes > 1: raise ValueError("2D visualization not available for multi-mode states.") if self.ansatz.batch_dims > 1: raise NotImplementedError("2D visualization not implemented for batched states.") shape = [max(min_shape, d) for d in self.auto_shape()] dm = self.to_fock(tuple(shape)).dm().ansatz.array x, prob_x = quadrature_distribution(dm) p, prob_p = quadrature_distribution(dm, np.pi / 2) mask_x = math.asnumpy([xi >= xbounds[0] and xi <= xbounds[1] for xi in x]) x = x[mask_x] prob_x = prob_x[mask_x] mask_p = math.asnumpy([pi >= pbounds[0] and pi <= pbounds[1] for pi in p]) p = p[mask_p] prob_p = prob_p[mask_p] xvec = np.linspace(*xbounds, resolution) pvec = np.linspace(*pbounds, resolution) z, xs, ps = wigner_discretized(dm, xvec, pvec) xs = xs[:, 0] ps = ps[0, :] fig = make_subplots( rows=2, cols=2, column_widths=[5, 3], row_heights=[1, 2], vertical_spacing=0.05, horizontal_spacing=0.05, shared_xaxes="columns", shared_yaxes="rows", ) # X-P plot # note: heatmaps revert the y axes, which is why the minus in `y=-ps` is required fig_21 = go.Heatmap( x=xs, y=-ps, z=math.transpose(z), coloraxis="coloraxis", name="Wigner function", autocolorscale=False, ) fig.add_trace(fig_21, row=2, col=1) fig.update_traces(row=2, col=1) fig.update_xaxes(range=xbounds, title_text="x", row=2, col=1) fig.update_yaxes(range=pbounds, title_text="p", row=2, col=1) # X quadrature probability distribution fig_11 = go.Scatter(x=x, y=prob_x, line={"color": "steelblue", "width": 2}, name="Prob(x)") fig.add_trace(fig_11, row=1, col=1) fig.update_xaxes(range=xbounds, row=1, col=1, showticklabels=False) fig.update_yaxes(title_text="Prob(x)", range=(0, max(prob_x)), row=1, col=1) # P quadrature probability distribution fig_22 = go.Scatter(x=prob_p, y=p, line={"color": "steelblue", "width": 2}, name="Prob(p)") fig.add_trace(fig_22, row=2, col=2) fig.update_xaxes(title_text="Prob(p)", range=(0, max(prob_p)), row=2, col=2) fig.update_yaxes(range=pbounds, row=2, col=2, showticklabels=False) fig.update_layout( height=500, width=580, plot_bgcolor="aliceblue", margin={"l": 20, "r": 20, "t": 30, "b": 20}, showlegend=False, coloraxis={"colorscale": colorscale, "cmid": 0}, ) fig.update_xaxes( showline=True, linewidth=1, linecolor="black", mirror=True, tickfont_family="Arial Black", ) fig.update_yaxes( showline=True, linewidth=1, linecolor="black", mirror=True, tickfont_family="Arial Black", ) if return_fig: return fig display(fig) return None
[docs] def visualize_2d_with_arrows( self, arrows: np.ndarray, xbounds: tuple[int, int] = (-6, 6), pbounds: tuple[int, int] = (-6, 6), resolution: int = 200, colorscale: str = "RdBu", return_fig: bool = False, min_shape: int = 50, ) -> go.Figure | None: r"""Plot the state Wigner function and q/p marginals along with arrows from the origin of the Wigner function. Useful for, e.g., visualizing the stabilizer arguments of a GKP state. Args: arrows: 1D numpy array of complex numbers representing arrow end-points. xbounds: The range of the `x` axis. pbounds: The range of the `p` axis. resolution: The number of bins on each axes. colorscale: A colorscale. Must be one of ``Plotly``'s built-in continuous color scales. return_fig: Whether to return the ``Plotly`` figure. min_shape: The minimum fock shape to use for the Wigner function plot. Returns: A ``Plotly`` figure. """ def _plot_arrow(x: float, y: float) -> go.Scatter: return go.Scatter( x=[0, x], y=[0, y], line={"color": "black", "width": 2}, marker={"symbol": "arrow", "angleref": "previous", "size": 15}, ) fig = self.visualize_2d( xbounds=xbounds, pbounds=pbounds, resolution=resolution, colorscale=colorscale, return_fig=True, min_shape=min_shape, ) assert fig is not None stabilizer_arrows = [ _plot_arrow( alpha.real * np.sqrt(2 * settings.HBAR), alpha.imag * np.sqrt(2 * settings.HBAR), ) for alpha in arrows ] for arrow in stabilizer_arrows: fig.add_trace(arrow, row=2, col=1) # row=2 col=1 is the Wigner plot of the figure if return_fig: return fig display(fig) return None
[docs] def visualize_3d( self, xbounds: tuple[int, int] = (-6, 6), pbounds: tuple[int, int] = (-6, 6), resolution: int = 200, colorscale: str = "RdBu", return_fig: bool = False, min_shape: int = 50, ) -> go.Figure | None: r"""3D visualization of the Wigner function of this state on a surface plot. Args: xbounds: The range of the `x` axis. pbounds: The range of the `p` axis. resolution: The number of bins on each axes. colorscale: A colorscale. Must be one of ``Plotly``\'s built-in continuous color scales. return_fig: Whether to return the ``Plotly`` figure. min_shape: The minimum fock shape to use for the Wigner function plot. Returns: A ``Plotly`` figure representing the state in 3D. Raises: ValueError: If this state is a multi-mode state. """ if self.n_modes != 1: raise ValueError("3D visualization not available for multi-mode states.") if self.ansatz.batch_dims > 1: raise NotImplementedError("3D visualization not implemented for batched states.") shape = [max(min_shape, d) for d in self.auto_shape()] dm = self.to_fock(tuple(shape)).dm().ansatz.array xvec = np.linspace(*xbounds, resolution) pvec = np.linspace(*pbounds, resolution) z, xs, ps = wigner_discretized(dm, xvec, pvec) xs = xs[:, 0] ps = ps[0, :] fig = go.Figure( data=go.Surface( x=xs, y=ps, z=z, coloraxis="coloraxis", hovertemplate="x: %{x:.3f}<br>p: %{y:.3f}<br>W(x, p): %{z:.3f}<extra></extra>", ), ) fig.update_layout( autosize=False, width=500, height=500, margin={"l": 0, "r": 0, "b": 0, "t": 0}, scene_camera_eye={"x": -2.1, "y": 0.88, "z": 0.64}, coloraxis={"colorscale": colorscale, "cmid": 0}, ) fig.update_traces( contours_z={ "show": True, "usecolormap": True, "highlightcolor": "limegreen", "project_z": False, }, ) fig.update_traces( contours_y={ "show": True, "usecolormap": True, "highlightcolor": "red", "project_y": False, }, ) fig.update_traces( contours_x={ "show": True, "usecolormap": True, "highlightcolor": "yellow", "project_x": False, }, ) fig.update_scenes( xaxis_title_text="x", yaxis_title_text="p", zaxis_title_text="Wigner function", ) fig.update_xaxes(title_text="x") fig.update_yaxes(title="p") if return_fig: return fig display(fig) return None
[docs] def visualize_dm( self, cutoff: int | None = None, return_fig: bool = False, ) -> go.Figure | None: r"""Plots the absolute value :math:`abs(\rho)` of the density matrix :math:`\rho` of this state on a heatmap. Args: cutoff: The desired cutoff. Defaults to the value of auto_shape. return_fig: Whether to return the ``Plotly`` figure. Returns: A ``Plotly`` figure representing absolute value of the density matrix of this state. Raises: ValueError: If this state is a multi-mode state. """ if self.n_modes != 1: raise ValueError("DM visualization not available for multi-mode states.") if self.ansatz.batch_dims > 1: raise NotImplementedError("DM visualization not implemented for batched states.") dm = self.to_fock(cutoff).dm().ansatz.array fig = go.Figure( data=go.Heatmap(z=abs(dm), colorscale="viridis", name="abs(ρ)", showscale=False), ) fig.update_yaxes(autorange="reversed") fig.update_layout( height=257, width=257, margin={"l": 30, "r": 30, "t": 30, "b": 20}, ) fig.update_xaxes(title_text=f"abs(ρ), cutoff={dm.shape[0]}") if return_fig: return fig display(fig) return None