# Source code for mrmustard.physics.fock_utils

# Copyright 2021 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""This module contains functions for performing calculations on objects in the Fock representations."""

from __future__ import annotations

from collections.abc import Iterable, Sequence
from functools import lru_cache

import numpy as np
from scipy.special import comb, factorial

from mrmustard import math, settings
from mrmustard.utils.typing import Scalar, Tensor, Vector

# Public API of this module: the names exported by ``from <module> import *``.
__all__ = [
    "c_in_PS",
    "c_ps_matrix",
    "estimate_dx",
    "estimate_quadrature_axis",
    "estimate_xmax",
    "fidelity",
    "fock_state",
    "gamma_matrix",
    "oscillator_eigenstate",
    "quadrature_basis",
    "quadrature_distribution",
]


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# ~~~~~~~~~~~~~~ static functions ~~~~~~~~~~~~~~
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~


def fock_state(n: int | Sequence[int], cutoff: int | None = None) -> Tensor:
    r"""The Fock array of a batchable single-mode ``Number`` state.

    Args:
        n: The photon number of the number state. Can be a single integer or a batch of integers.
        cutoff: The cutoff of the Fock array, i.e. the largest photon number represented;
            the returned array has ``cutoff + 1`` entries along its last axis. If ``None``
            (the default) the largest value in ``n`` is used. An explicit ``0`` is honoured
            rather than being treated as "not given".

    Returns:
        The Fock array of a batchable single-mode ``Number`` state: the row(s) of the
        ``(cutoff + 1) x (cutoff + 1)`` identity selected by ``n``.
    """
    n = math.astensor(n, dtype=math.int64)
    # ``cutoff or ...`` would silently replace an explicit ``cutoff=0``; only fall back to
    # the largest requested photon number when no cutoff was supplied at all.
    if cutoff is None:
        cutoff = int(math.max(n))
    return math.eye(cutoff + 1)[n]
def fidelity(dm_a, dm_b) -> Scalar:
    r"""Computes the fidelity between two states in Fock representation.

    Implements Jozsa's definition
    :math:`F(\rho_a, \rho_b) = \left(\mathrm{tr}\sqrt{\sqrt{\rho_a}\,\rho_b\,\sqrt{\rho_a}}\right)^2`
    (Richard Jozsa (1994) Fidelity for Mixed Quantum States, Journal of Modern Optics,
    41:12, 2315-2323, DOI: 10.1080/09500349414552171).

    Args:
        dm_a: the first density matrix.
        dm_b: the second density matrix.

    Returns:
        The absolute value of the squared trace of
        :math:`\sqrt{\sqrt{\rho_a}\,\rho_b\,\sqrt{\rho_a}}`.
    """
    root_a = math.sqrtm(dm_a)
    sandwiched = math.matmul(root_a, dm_b, root_a)
    trace_of_root = math.trace(math.sqrtm(sandwiched))
    return math.abs(trace_of_root**2)
def oscillator_eigenstate(q: Vector, cutoff: int) -> Tensor:
    r"""Harmonic oscillator eigenstate wavefunction `\psi_n(q) = <q|n>`.

    Args:
        q: array of q points at which the function is evaluated (units of \sqrt{\hbar}).
            Can have any shape; the result has shape ``(cutoff, *q.shape)``.
        cutoff (int): Fock space dimension (shape). Note: despite the parameter name,
            this is the shape (cutoff + 1), not the max photon number. Callers pass
            shape values here.

    Returns:
        Tensor: shape ``(cutoff, *q.shape)``. Entry ``[n, ...]`` is :math:`\psi_n`
        evaluated at the corresponding points of *q*.

    .. details::

        .. admonition:: Definition
            :class: defn

            The q-quadrature eigenstates are defined as

            .. math::

                \psi_n(x) = \frac{1}{\sqrt{2^n n!}}
                    \left(\frac{\omega}{\pi \hbar}\right)^{1/4}
                    \exp\left(-\frac{\omega}{2\hbar} x^2\right)
                    H_n\left(\sqrt{\frac{\omega}{\hbar}}\, x\right)

            where :math:`H_n(x)` is the (physicists) `n`-th Hermite polynomial.
            This implementation uses :math:`\omega = 1` and ``settings.HBAR`` for
            :math:`\hbar`.
    """
    hbar = settings.HBAR
    q = np.asarray(q)
    batch_shape = q.shape
    # flatten q so the points form a single batch axis; the shape is restored at the end
    x = math.cast(q.ravel() / np.sqrt(hbar), math.complex128)  # unit-less vector

    # prefactor term (\Omega/\hbar \pi)**(1/4) * 1 / sqrt(2**n)
    prefactor = math.cast(
        (np.pi * hbar) ** (-0.25) * math.pow(0.5, math.arange(0, cutoff) / 2),
        math.complex128,
    )

    # Renormalized physicist hermite polys: Hn / sqrt(n!)
    R = -np.array([[2 + 0j]])  # to get the physicist polys
    hermite_polys = math.hermite_renormalized(R, 2 * x[..., None], 1 + 0j, (cutoff,))

    # (real) wavefunction — shape (cutoff, *batch_shape)
    result = math.exp(-(x**2 / 2)) * math.transpose(prefactor * hermite_polys)
    return math.reshape(result, (cutoff, *batch_shape))
@lru_cache
def estimate_dx(cutoff, period_resolution=20):
    r"""Estimates a suitable quadrature discretization interval `dx`.

    Uses the fact that Fock state `n` oscillates with angular frequency
    :math:`\sqrt{2(n + 1)}`, which follows from the relation

    .. math::

        \psi^{[n]}'(q) = q\,\psi^{[n]}(q) - \sqrt{2(n + 1)}\,\psi^{[n+1]}(q)

    by setting q = 0, and approximating the oscillation amplitude by `\psi^{[n+1]}(0)`.

    Ref: https://en.wikipedia.org/wiki/Hermite_polynomials#Hermite_functions

    Args:
        cutoff (int): Fock cutoff
        period_resolution (int): Number of points used to sample one Fock
            wavefunction oscillation. Larger values yield better approximations
            and thus smaller `dx`.

    Returns:
        (float): discretization value of quadrature
    """
    # the frequency sqrt(2 (n + 1)) grows with n, so the cutoff level oscillates fastest
    fock_cutoff_frequency = np.sqrt(2 * (cutoff + 1))
    fock_cutoff_period = 2 * np.pi / fock_cutoff_frequency
    return fock_cutoff_period / period_resolution
@lru_cache
def estimate_xmax(cutoff, minimum=5):
    r"""Estimates a suitable quadrature axis length.

    Starts from the classical turning point of a particle whose energy corresponds
    to ``cutoff`` photons and widens it by an empirical margin that accounts for the
    probability of finding the particle outside the classical region.

    Args:
        cutoff (int): Fock cutoff
        minimum (float): Minimum value of the returned xmax

    Returns:
        (float): maximum quadrature value, never smaller than ``minimum``
    """
    if cutoff == 0:
        # fixed width when there are no excited levels to accommodate
        return max(minimum, 3)

    # maximum q for a classical particle with energy n=cutoff
    turning_point = np.sqrt(2 * cutoff)
    # approximate probability of finding the particle outside the classical region
    tail_probability = 1 / (7.464 * cutoff ** (1 / 3))
    empirical_factor = 5  # empirical factor that yields reasonable results
    return max(minimum, turning_point * (1 + empirical_factor * tail_probability))
@lru_cache
def estimate_quadrature_axis(cutoff, minimum=5, period_resolution=20):
    r"""Generates a suitable quadrature axis.

    Args:
        cutoff (int): Fock cutoff
        minimum (float): Minimum value of the returned xmax
        period_resolution (int): Number of points used to sample one Fock
            wavefunction oscillation. Larger values yield better approximations
            and thus smaller dx.

    Returns:
        (array): quadrature axis, centred around 0 with spacing ``estimate_dx(cutoff)``.
        The array is read-only because ``lru_cache`` hands the same object to every
        caller with the same arguments; use ``.copy()`` (or any arithmetic, which
        allocates a new array) to obtain a writable version.
    """
    xmax = estimate_xmax(cutoff, minimum=minimum)
    dx = estimate_dx(cutoff, period_resolution=period_resolution)
    # np.arange excludes its stop value; append one more step so the grid
    # extends to at least +xmax before centring
    xaxis = np.arange(-xmax, xmax, dx)
    xaxis = np.append(xaxis, xaxis[-1] + dx)
    xaxis = xaxis - np.mean(xaxis)  # center around 0
    # Freeze the cached array: an in-place edit by one caller (e.g. ``x *= c``)
    # would otherwise silently corrupt the value returned to every later caller.
    xaxis.setflags(write=False)
    return xaxis
def quadrature_basis(
    fock_array: Tensor,
    quad: Vector,
    conjugates: bool | list[bool] = False,
    phi: Scalar = 0.0,
):
    r"""Given the Fock basis representation return the quadrature basis representation.

    Args:
        fock_array (Tensor): fock tensor amplitudes, one axis per dimension.
        quad (Tensor): points at which the quadrature basis is evaluated, with shape
            ``(num_points, len(fock_array.shape))``: column ``i`` holds the quadrature
            values used for axis ``i`` of ``fock_array``.
        conjugates (bool | Iterable[bool]): which dimensions of the array to conjugate
            based on whether it is a bra or a ket. A single bool applies to every dimension.
        phi (float): angle of the quadrature basis vector

    Returns:
        Tensor: quadrature basis representation at the points in ``quad``, with shape
        ``(num_points,)``.

    Raises:
        ValueError: if the last axis of ``quad`` or the number of ``conjugates`` does not
            match the number of dimensions of ``fock_array``.
    """
    dims = len(fock_array.shape)
    if quad.shape[-1] != dims:
        raise ValueError(
            f"Input fock array has dimension {dims} whereas ``quad`` has {quad.shape[-1]}.",
        )
    # accept any iterable (list, tuple, generator, ...) so it can be indexed per dimension
    conjugates = list(conjugates) if isinstance(conjugates, Iterable) else [conjugates] * dims
    if len(conjugates) != dims:
        raise ValueError(
            f"Input fock array has dimension {dims} whereas ``conjugates`` has "
            f"{len(conjugates)} entries.",
        )

    # construct quadrature basis vectors, one (size, num_points) matrix per dimension
    quad_basis_vecs = []
    for dim, (size, conjugate) in enumerate(zip(fock_array.shape, conjugates)):
        q_to_n = oscillator_eigenstate(quad[..., dim], size)
        if not np.isclose(phi, 0.0):
            # rotate to the phi-quadrature: row n is multiplied by exp(-i n phi)
            theta = -math.arange(size) * phi
            Ur = math.make_complex(math.cos(theta), math.sin(theta))
            q_to_n = math.einsum("n,nq->nq", Ur, q_to_n)
        if conjugate:
            q_to_n = math.conj(q_to_n)
        quad_basis_vecs.append(math.cast(q_to_n, "complex128"))

    # Contract each Fock index with its basis matrix, all sharing the point index 'a',
    # e.g. for two dimensions: "bc,ba,ca->a".
    fock_string = "".join(chr(i) for i in range(98, 98 + dims))  # 'bcd....'
    q_string = ",".join(f"{index}a" for index in fock_string)
    return math.einsum(f"{fock_string},{q_string}->a", fock_array, *quad_basis_vecs)
def quadrature_distribution(
    state: Tensor,
    quadrature_angle: float = 0.0,
    x: Vector | None = None,
):
    r"""Probability density of a quadrature measurement on a single-mode state.

    Given the ket or density matrix of a single-mode state, evaluates
    :math:`\mathrm{tr}[\rho |x_\phi\rangle\langle x_\phi|]`, where :math:`\rho` is the
    density matrix of the state and :math:`|x_\phi\rangle` the quadrature eigenvector
    whose angle :math:`\phi` equals ``quadrature_angle``.

    Args:
        state: A single mode state ket (1 axis) or density matrix (2 axes).
        quadrature_angle: The angle of the quadrature basis vector.
        x: The points at which the quadrature distribution is evaluated. If ``None``,
            ``estimate_quadrature_axis`` scaled by ``sqrt(settings.HBAR)`` is used.

    Returns:
        A tuple ``(x, pdf)``: the coordinates at which the pdf is evaluated and the
        (real part of the) probability distribution at those coordinates.
    """
    # the first axis has length cutoff + 1
    max_photons = state.shape[0] - 1
    if x is None:
        x = np.sqrt(settings.HBAR) * estimate_quadrature_axis(max_photons)

    num_axes = len(state.shape)
    is_dm = num_axes == 2
    # every axis of the state is evaluated on the same points: shape (len(x), num_axes)
    points = math.transpose(math.astensor([x] * num_axes))
    conjugates = [True, False] if is_dm else [False]

    amplitudes = quadrature_basis(state, points, conjugates, quadrature_angle)
    if is_dm:
        pdf = amplitudes
    else:
        pdf = math.abs(amplitudes) ** 2
    return x, math.real(pdf)
def c_ps_matrix(m, n, alpha):
    """Helper function for ``c_in_PS``.

    Returns the sum over ``mu`` of
    ``comb(m, mu) * comb(n, alpha - mu) * 1j ** (m - n - 2 * mu + alpha)``, where ``mu``
    runs over ``max(0, alpha - n) <= mu <= min(m, alpha)`` (the range in which both
    binomial coefficients have in-range arguments).
    """
    lowest = max(0, alpha - n)
    highest = min(m, alpha)
    terms = []
    for mu in range(lowest, highest + 1):
        phase = (1j) ** (m - n - 2 * mu + alpha)
        terms.append(comb(m, mu) * comb(n, alpha - mu) * phase)
    return np.sum(terms)
def gamma_matrix(c):
    """Helper function for ``c_in_PS``.

    Constructs the matrix transformation that helps transforming ``c``.
    ``c`` here must be 2-dimensional.

    Args:
        c: a 2-dimensional array of shape ``(A, B)``.

    Returns:
        A complex array of shape ``(M**2, A * B)`` with ``M = A + B - 1``. Column
        ``m * B + n`` corresponds to ``c[m, n]`` in the row-major flattening of ``c``
        (the flattening ``c_in_PS`` applies), and row ``alpha * M + (m + n - alpha)``
        to entry ``[alpha, m + n - alpha]`` of the ``(M, M)`` output grid.
    """
    rows_c, cols_c = c.shape[0], c.shape[1]
    M = rows_c + cols_c - 1
    Gamma = np.zeros((M**2, rows_c * cols_c), dtype=np.complex128)
    # loop-invariant factor sqrt(hbar / 2), raised to the total degree m + n below
    scale = np.sqrt(settings.HBAR / 2)
    for m in range(rows_c):
        for n in range(cols_c):
            # Row-major index of c[m, n]: the stride of the first axis is the number of
            # columns. Using the number of rows would make columns collide (or fall out
            # of range) whenever c is not square.
            col = m * cols_c + n
            for alpha in range(m + n + 1):
                factor = np.sqrt(
                    factorial(m) * factorial(n) / (factorial(alpha) * factorial(m + n - alpha)),
                )
                value = c_ps_matrix(m, n, alpha) * scale ** (m + n)
                row = alpha * M + (m + n - alpha)
                Gamma[row, col] = value / factor
    return Gamma
def c_in_PS(c):
    """Transforms the ``c`` matrix of a ``DM`` object from bargmann to phase-space.

    Args:
        c (Tensor): the 2-dimensional ``c`` matrix of the ``DM`` object

    Returns:
        The transformed matrix, of shape ``(M, M)`` with
        ``M = c.shape[0] + c.shape[1] - 1``.
    """
    rows, cols = c.shape[0], c.shape[1]
    out_dim = rows + cols - 1
    # flatten c row-major into a column vector and apply the linear map
    c_column = np.reshape(c, (rows * cols, 1))
    transformed = gamma_matrix(c) @ c_column
    return np.reshape(transformed, (out_dim, out_dim))