Source code for mrmustard.physics.ansatz.polyexp_ansatz

# Copyright 2023 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""This module contains the PolyExp ansatz."""

from __future__ import annotations

import itertools
from collections.abc import Sequence
from typing import Any

import numpy as np
from IPython.display import display
from numpy.typing import ArrayLike

from mrmustard import math, settings, widgets
from mrmustard.physics.fock_utils import c_in_PS
from mrmustard.physics.utils import (
    batch_indexer_info,
    generate_batch_str,
    join_Abc,
    outer_product_batch_str,
    reshape_args_to_batch_string,
    verify_triple,
)
from mrmustard.utils.argsort import argsort_gen
from mrmustard.utils.typing import (
    ComplexMatrix,
    ComplexScalar,
    ComplexTensor,
    ComplexVector,
    Matrix,
    Scalar,
    Tensor,
    Vector,
)

from .base import Ansatz

__all__ = ["PolyExpAnsatz"]


class PolyExpAnsatz(Ansatz):
    r"""This class represents the ansatz of a polynomial exponential function.

    Namely, :math:`F^{(i)}(z) = \sum_k c^{(i)}_{k} \partial_y^k \textrm{exp}(\frac{1}{2}(z,y)^T A^{(i)} (z,y) + (z,y)^T b^{(i)})|_{y=0}`
    with ``k`` and ``i`` multi-indices. The ``i`` multi-index is a batch index of shape ``L``
    that can be used for linear superposition or batching purposes. Each of the
    ``c^{(i)}_k`` tensors is contracted with the array of derivatives :math:`\partial_y^k`
    to form polynomials of derivatives. The tensors :math:`c^{(i)}_{k}` contain the
    coefficients of the polynomial of derivatives and have shape ``(*L, *derived)``, where
    ``*derived`` is the shape of the derived variables, which implies
    ``len(c.shape[len(L):]) = m``. The matrices :math:`A^{(i)}` and vectors :math:`b^{(i)}`
    are the parameters of the exponential terms in the ansatz, with
    :math:`z\in\mathbb{C}^{n}` and :math:`y\in\mathbb{C}^{m}`. ``A`` and ``b`` have shape
    ``(*L, n+m, n+m)`` and ``(*L, n+m)``, respectively.

    Args:
        A: A batch of quadratic coefficients :math:`A^{(i)}`.
        b: A batch of linear coefficients :math:`b^{(i)}`.
        c: A batch of arrays :math:`c^{(i)}`.
        name: The name of the ansatz.
        lin_sup: Whether to include linear superposition axes in the batch dimensions.

    >>> from mrmustard.physics.ansatz import PolyExpAnsatz
    >>> import numpy as np
    >>> A = np.random.random((3,3))  # no batch
    >>> b = np.random.random((3,))
    >>> c = np.random.random()
    >>> F = PolyExpAnsatz(A, b, c)
    >>> assert F(1.0, 2.0, 3.0).shape == ()
    >>> A = np.random.random((10,3,3))  # batch of 10
    >>> b = np.random.random((10,3))
    >>> c = np.random.random((10,))
    >>> F = PolyExpAnsatz(A, b, c)
    >>> assert F(1.0, 2.0, 3.0).shape == (10,)
    >>> A = np.random.random((10,3,3))  # batch of 10
    >>> b = np.random.random((10,3))
    >>> c = np.random.random((10,7))
    >>> F = PolyExpAnsatz(A, b, c)
    >>> assert F(1.0, 2.0).shape == (10,)  # two CV variables, one derived
    >>> A = np.random.random((10,3,3))  # batch of 10
    >>> b = np.random.random((10,3))
    >>> c = np.random.random((10,7,5))
    >>> F = PolyExpAnsatz(A, b, c)
    >>> assert F(1.0).shape == (10,)  # one CV variable, two derived
    >>> assert F([1.0, 2.0, 3.0]).shape == (3,10)  # batch of 3 inputs
    """

    def __init__(
        self,
        A: Matrix,
        b: Vector,
        c: Tensor,
        name: str = "",
        lin_sup: bool = False,
    ):
        super().__init__()
        self.name = name
        self._simplified = False
        self._lin_sup = lin_sup
        self._A = math.astensor(A, dtype=math.complex128)
        self._b = math.astensor(b, dtype=math.complex128)
        self._c = math.astensor(c, dtype=math.complex128)
        verify_triple(self._A, self._b, self._c)
        # The batch shape is everything in A before the trailing (n+m, n+m) core matrix.
        self._batch_shape = tuple(self._A.shape[:-2])

    @property
    def A(self) -> ComplexMatrix:
        r"""The batch of quadratic coefficient :math:`A^{(i)}`."""
        return self._A

    @property
    def b(self) -> ComplexVector:
        r"""The batch of linear coefficients :math:`b^{(i)}`."""
        return self._b

    @property
    def batch_dims(self) -> int:
        r"""The number of batch dimensions."""
        return len(self.batch_shape)

    @property
    def batch_shape(self) -> tuple[int, ...]:
        r"""The batch shape, i.e. all axes of ``A`` except the last two."""
        return self._batch_shape

    @property
    def batch_size(self) -> int:
        r"""The total number of batch elements, or ``0`` if there are no batch dimensions."""
        return math.prod(self.batch_shape) if self.batch_shape else 0

    @property
    def c(self) -> ComplexTensor:
        r"""The batch of polynomial coefficients :math:`c^{(i)}_{k}`."""
        return self._c

    @property
    def conj(self) -> PolyExpAnsatz:
        r"""The ansatz with ``A``, ``b`` and ``c`` complex-conjugated (``lin_sup`` is preserved)."""
        return PolyExpAnsatz(
            math.conj(self.A),
            math.conj(self.b),
            math.conj(self.c),
            lin_sup=self._lin_sup,
        )

    @property
    def core_dims(self) -> int:
        r"""The number of core variables of the ansatz. Equivalent to ``self.num_CV_vars``."""
        return self.num_CV_vars

    @property
    def data(
        self,
    ) -> tuple[ComplexMatrix, ComplexVector, ComplexTensor]:
        r"""Returns the triple, which is necessary to reinstantiate the ansatz."""
        return self.triple

    @property
    def num_CV_vars(self) -> int:
        r"""The number of continuous variables that remain after the polynomial of derivatives is applied.

        This is the number of continuous variables of the Ansatz function itself.
        """
        return self.num_vars - self.num_derived_vars

    @property
    def num_derived_vars(self) -> int:
        r"""The number of derived variables that are derived by the polynomial of derivatives."""
        return len(self.shape_derived_vars)

    @property
    def num_vars(self):
        r"""The total number of variables (CV plus derived), i.e. the last dimension of ``A``."""
        return self.A.shape[-1]

    @property
    def PS(self) -> PolyExpAnsatz:
        r"""The ansatz defined using real (i.e., phase-space) variables.

        Supported configurations are: no derived variables (with an even total number of
        variables), or exactly 2 CV variables and 2 derived variables.

        Raises:
            ValueError: If the number of variables is odd, or if the ansatz has derived
                variables but is not of the "2 CV + 2 derived" form.
        """
        n = self.A.shape[-1]
        if n % 2:
            raise ValueError(
                f"A phase space ansatz must have even number of indices. (n={n} is odd)",
            )
        if self.num_derived_vars == 0:
            W = math.conj(math.rotmat(n // 2)) / math.sqrt(settings.HBAR, dtype=math.complex128)
            A = math.einsum("ji,...jk,kl->...il", W, self.A, W)
            b = math.einsum("ij,...j->...i", W, self.b)
            c = self.c / (2 * settings.HBAR) ** (n // 2)
            return PolyExpAnsatz(A, b, c, lin_sup=self._lin_sup)
        # The permutation [0, 2, 1, 3] below is hard-coded for exactly 4 variables
        # (2 CV + 2 derived); any other layout would be silently mis-indexed.
        if self.num_derived_vars != 2 or self.num_CV_vars != 2:
            raise ValueError("This transformation supports 2 core and 0 or 2 derived variables")
        A_tmp = self.A
        A_tmp = A_tmp[..., [0, 2, 1, 3], :][..., [0, 2, 1, 3]]
        b = self.b[..., [0, 2, 1, 3]]
        c = c_in_PS(self.c)  # implements PS transformations on ``c``
        W = math.conj(math.rotmat(n // 2)) / math.sqrt(settings.HBAR, dtype=math.complex128)
        A = math.einsum("ji,...jk,kl->...il", W, A_tmp, W)
        b = math.einsum("ij,...j->...i", W, b)
        c = c / (2 * settings.HBAR)
        # Undo the interleaving permutation applied above.
        A_out = A[..., [0, 2, 1, 3], :][..., :, [0, 2, 1, 3]]
        b_out = b[..., [0, 2, 1, 3]]
        return PolyExpAnsatz(A_out, b_out, c, lin_sup=self._lin_sup)

    @property
    def scalar(self) -> Scalar:
        r"""The scalar part of the ansatz, i.e. F(0)."""
        if self.num_CV_vars == 0 and self.num_derived_vars == 0:
            # With a linear-superposition axis, the scalar is the sum over that (last) axis.
            ret = math.einsum("...a->...", self.c) if self._lin_sup else self.c
        elif self.num_CV_vars == 0:
            ret = self()
        else:
            ret = self(*math.zeros(self.num_CV_vars))
        return ret

    @property
    def shape_derived_vars(self) -> tuple[int, ...]:
        r"""The shape of the coefficients of the polynomial of derivatives."""
        return tuple(self.c.shape[self.batch_dims :])

    @property
    def triple(
        self,
    ) -> tuple[ComplexMatrix, ComplexVector, ComplexTensor]:
        r"""Returns the triple of parameters of the exponential part of the ansatz."""
        return self.A, self.b, self.c
[docs] @classmethod def from_dict(cls, data: dict[str, ArrayLike]) -> PolyExpAnsatz: r"""Creates an ansatz from a dictionary. For deserialization purposes.""" return cls(**data)
[docs] def concat(self, other: Ansatz, axis: int = 0) -> PolyExpAnsatz: r"""Concatenates two PolyExpAnsatz objects along the specified batch axis. All batch axes except the concatenation axis must agree in size. The lin_sup status must match. Core dimensions (num_CV_vars and num_derived_vars) must also match. Args: other: The other PolyExpAnsatz to concatenate with. axis: The batch axis along which to concatenate (default 0). Returns: A new PolyExpAnsatz with the concatenated batch dimensions. Raises: TypeError: If ``other`` is not a ``PolyExpAnsatz``. ValueError: If batch dimensions don't match (except at axis), if lin_sup status doesn't match, or if core dimensions don't match. >>> from mrmustard.physics.ansatz import PolyExpAnsatz >>> import numpy as np >>> A1 = np.random.random((2, 3, 3)) >>> b1 = np.random.random((2, 3)) >>> c1 = np.random.random((2,)) >>> ansatz1 = PolyExpAnsatz(A1, b1, c1) >>> A2 = np.random.random((3, 3, 3)) >>> b2 = np.random.random((3, 3)) >>> c2 = np.random.random((3,)) >>> ansatz2 = PolyExpAnsatz(A2, b2, c2) >>> concatenated = ansatz1.concat(ansatz2, axis=0) >>> assert concatenated.batch_shape == (5,) """ if not isinstance(other, PolyExpAnsatz): raise TypeError(f"Cannot concatenate with ansatz of type {type(other)}!") if self._lin_sup != other._lin_sup: raise ValueError( f"Cannot concat ansatze with different lin_sup status: {self._lin_sup} vs {other._lin_sup}" ) if self.num_CV_vars != other.num_CV_vars: raise ValueError( f"Cannot concat ansatze with different num_CV_vars: {self.num_CV_vars} vs {other.num_CV_vars}" ) if self.num_derived_vars != other.num_derived_vars: raise ValueError( f"Cannot concat ansatze with different num_derived_vars: {self.num_derived_vars} vs {other.num_derived_vars}" ) # Normalize negative axis if axis < 0: axis = self.batch_dims + axis if axis < 0 or axis >= self.batch_dims: raise ValueError(f"axis {axis} is out of range for batch_dims {self.batch_dims}") # Check that all other batch axes match for i in 
range(self.batch_dims): if i != axis and self.batch_shape[i] != other.batch_shape[i]: raise ValueError( f"Batch shapes must match except at axis {axis}. " f"Got {self.batch_shape} and {other.batch_shape}" ) # Concatenate the triple components along the batch axis A_concat = math.concat([self.A, other.A], axis=axis) b_concat = math.concat([self.b, other.b], axis=axis) c_concat = math.concat([self.c, other.c], axis=axis) return PolyExpAnsatz(A_concat, b_concat, c_concat, lin_sup=self._lin_sup)
[docs] def contract( self, other: Ansatz, idxs: tuple[Sequence[int], Sequence[int]], ) -> PolyExpAnsatz: r"""Contract along specified core (CV) axes, broadcasting batch dimensions and taking the kronecker product of linear-superposition dimensions. Note: - Batch dims broadcast automatically (no batch string needed). - Negative indices are normalized relative to each operand's CV count. Args: other: The other PolyExpAnsatz to contract with. idxs: Tuple ``(idx_self, idx_other)`` of CV-axis indices to integrate over (0-based relative to the CV variables). The two sequences must have the same length. Negative indices allowed. Returns: The contracted PolyExpAnsatz with kept core axes ordered as ``[self non-contracted] + [other non-contracted]`` and derived axes as ``[self derived] + [other derived]``. Raises: TypeError: If ``other`` is not a ``PolyExpAnsatz``. """ if not isinstance(other, PolyExpAnsatz): raise TypeError(f"Cannot contract with ansatz of type {type(other)}!") # Validate indices idx1, idx2 = idxs n1, n2 = self.num_CV_vars, other.num_CV_vars m1, m2 = self.num_derived_vars, other.num_derived_vars if len(idx1) != len(idx2): raise ValueError( f"idxs must have sequences of equal length, got {len(idx1)} and {len(idx2)}." 
) if self._lin_sup and other._lin_sup: # kron lin sup axes A1 = math.expand_dims(self.A, axis=-3) b1 = math.expand_dims(self.b, axis=-2) c1 = math.expand_dims(self.c, axis=-m1 - 1) A2 = math.expand_dims(other.A, axis=-4) b2 = math.expand_dims(other.b, axis=-3) c2 = math.expand_dims(other.c, axis=-m2 - 2) elif self._lin_sup or other._lin_sup: # zip lin sup axes A1, b1, c1 = self._ensure_lin_sup_axis().triple A2, b2, c2 = other._ensure_lin_sup_axis().triple else: A1 = self.A b1 = self.b c1 = self.c A2 = other.A b2 = other.b c2 = other.c A_post, b_post, log_c_factor = math.complex_gaussian_integral_2(A1, b1, A2, b2, idx1, idx2) # Reorder core indices to [self CV remaining, other CV remaining, self derived, other derived] k = len(idx1) s1, s2, s3, s4 = n1 - k, m1, n2 - k, m2 order = ( list(range(s1)) + list(range(s1 + s2, s1 + s2 + s3)) + list(range(s1, s1 + s2)) + list(range(s1 + s2 + s3, s1 + s2 + s3 + s4)) ) A_post = math.gather(math.gather(A_post, order, axis=-1), order, axis=-2) b_post = math.gather(b_post, order, axis=-1) # Combine polynomial parts: outer product of c's times the scalar (batched) c_factor if m1 or m2: poly1 = "".join(chr(97 + i) for i in range(m1)) poly2 = "".join(chr(97 + m1 + j) for j in range(m2)) log_c12 = math.log(math.einsum(f"...{poly1},...{poly2}->...{poly1}{poly2}", c1, c2)) c_out = math.exp(log_c_factor[..., *((None,) * (m1 + m2))] + log_c12) else: c_out = math.exp(log_c_factor + math.log(c1) + math.log(c2)) # If both have linear-superposition dimensions, collapse them into one if self._lin_sup and other._lin_sup: batch_shape = A_post.shape[:-2] # collapse the last two batch axes new_bs = (*batch_shape[:-2], batch_shape[-2] * batch_shape[-1]) A_post = math.reshape(A_post, new_bs + A_post.shape[-2:]) b_post = math.reshape(b_post, new_bs + b_post.shape[-1:]) c_out = math.reshape(c_out, new_bs + c_out.shape[len(batch_shape) :]) return PolyExpAnsatz(A_post, b_post, c_out, lin_sup=self._lin_sup or other._lin_sup)
[docs] def decompose_ansatz(self) -> PolyExpAnsatz: r"""This method decomposes a PolyExp ansatz to make it more efficient to evaluate. An ansatz with ``n`` CV variables and ``m`` derived variables has parameters with the following shapes: ``A=(batch;n+m,n+m)``, ``b=(batch;n+m)``, ``c = (batch;k_1,k_2,...,k_m;j_1,...,j_d)``, where ``d`` is the number of discrete variables, i.e. axes of the array of values that we get when we evaluate the ansatz at a point. This can be rewritten as an ansatz of dimension ``A=(*batch;2n,2n)``, ``b=(*batch;2n)``, ``c = (*batch;l_1,l_2,...,l_n;j_1,...,j_d)``, with ``l_i = sum_j k_j``. This means that the number of continuous variables remains ``n``, the number of derived variables decreases from ``m`` to ``n``, and the number of discrete variables remains ``d``. The price we pay is that the order of the derivatives is larger (the order of each derivative is the sum of all the orders of the initial derivatives). This decomposition is typically favourable if ``m > n`` and the sum of the elements in ``c.shape[1:]`` is not too large. This method will actually decompose the ansatz only if ``m > n`` and return the original ansatz otherwise. 
""" if self.num_derived_vars < self.num_CV_vars: return self n = self.num_CV_vars A, b, c = self.triple pulled_out_input_shape = ( int(math.sum(self.shape_derived_vars)), ) * n # cast to int for jax poly_shape = pulled_out_input_shape + self.shape_derived_vars batch_shape = A.shape[:-2] A_core = math.block( [ [math.zeros((*batch_shape, n, n), dtype=A.dtype), A[..., :n, n:]], [A[..., n:, :n], A[..., n:, n:]], ], ) b_core = math.concat((math.zeros((*batch_shape, n), dtype=b.dtype), b[..., n:]), axis=-1) poly_core = math.hermite_renormalized( A_core, b_core, math.ones(self.batch_shape, dtype=math.complex128), shape=poly_shape, ) derived_vars_size = int(math.prod(self.shape_derived_vars)) poly_core = math.reshape( poly_core, batch_shape + pulled_out_input_shape + (derived_vars_size,), ) batch_str = generate_batch_str(len(batch_shape)) c_prime = math.einsum( f"{batch_str}...k,{batch_str}...k->{batch_str}...", poly_core, c.reshape((*batch_shape, derived_vars_size)), ) block = A[..., :n, :n] I_matrix = math.broadcast_to(math.eye_like(block), block.shape) A_decomp = math.block([[block, I_matrix], [I_matrix, math.zeros_like(block)]]) b_decomp = math.concat((b[..., :n], math.zeros((*batch_shape, n), dtype=b.dtype)), axis=-1) return PolyExpAnsatz(A_decomp, b_decomp, c_prime, lin_sup=self._lin_sup)
def display(self): if widgets.IN_INTERACTIVE_SHELL: print(repr(self)) return display(widgets.bargmann(self))
[docs] def eval( self, *z: Vector | None, batch_string: str | None = None, ) -> Scalar | ArrayLike | PolyExpAnsatz: r"""Evaluates the ansatz at given points or returns a partially evaluated ansatz. This method supports passing an einsum-style batch string to specify how the batch dimensions of the arguments should be handled. For example, for two arguments, "i,j->ij" means to take the outer product of the batch dimensions. The batch dimensions of the ansatz itself are not part of the batch string and are placed after the output batch dimensions in the output. 1. Partial evaluation: If any argument is None, it returns a new ansatz with those arguments unevaluated. For example, if F(z1, z2, z3) is called as F(1.0, None, 3.0), it returns a new ansatz G(z2) with z1 and z3 fixed at 1.0 and 3.0. 2. Full evaluation: If all arguments are provided, it returns the value of the ansatz at those points. The returned shape depends on the batch shape of the ansatz itself and the batch dimensions of the inputs and the batch string. Args: z: points in C where the function is (partially) evaluated or None if the variable is not evaluated. batch_string: like einsum string for batch dimensions of the inputs, e.g. "i,j->ij" Returns: The value of the ansatz or a new ansatz if partial evaluation is performed. 
""" if len(z) > self.num_CV_vars: raise ValueError( f"The ansatz was called with {len(z)} variables, " f"but it only has {self.num_CV_vars} CV variables.", ) evaluated_indices = [i for i, zi in enumerate(z) if zi is not None] only_z = [math.astensor(zi) for zi in z if zi is not None] if batch_string is None: # Generate default batch string if none provided batch_string = outer_product_batch_str(*[zi.ndim for zi in only_z]) reshaped_z = reshape_args_to_batch_string(only_z, batch_string) broadcasted_z = math.broadcast_arrays(*reshaped_z) if len(evaluated_indices) == self.num_CV_vars: # Full evaluation: all CV vars specified return self(*broadcasted_z) # Partial evaluation: some CV variables are not provided combined_z = math.stack(broadcasted_z, axis=-1) return self._partial_eval(combined_z, tuple(evaluated_indices))
[docs] def reorder(self, order: Sequence[int]) -> PolyExpAnsatz: r"""Reorders the CV indices of an (A,b,c) triple. The length of ``order`` must equal the number of CV variables. This method returns a new ansatz object. """ if len(order) != self.num_CV_vars: raise ValueError(f"order must have length {self.num_CV_vars}, got {len(order)}") # Add derived variable indices after CV indices order = list(order) + list( range(self.num_CV_vars, self.num_CV_vars + self.num_derived_vars), ) A = math.gather(math.gather(self.A, order, axis=-1), order, axis=-2) b = math.gather(self.b, order, axis=-1) return PolyExpAnsatz(A, b, self.c, lin_sup=self._lin_sup)
[docs] def reorder_batch( self, order: Sequence[int] ) -> PolyExpAnsatz: # TODO: omit last batch index if lin_sup if len(order) != self.batch_dims: raise ValueError( f"order must have length {self.batch_dims} (number of batch dimensions), got {len(order)}", ) core_dims_indices_A = range(self.batch_dims, self.batch_dims + 2) core_dims_indices_b = range(self.batch_dims, self.batch_dims + 1) core_dims_indices_c = range(self.batch_dims, self.batch_dims + self.num_derived_vars) new_A = math.transpose(self.A, list(order) + list(core_dims_indices_A)) new_b = math.transpose(self.b, list(order) + list(core_dims_indices_b)) new_c = math.transpose(self.c, list(order) + list(core_dims_indices_c)) return PolyExpAnsatz(new_A, new_b, new_c, lin_sup=self._lin_sup)
# TODO: this should be moved to classes responsible for interpreting a batch dimension as a sum
[docs] def simplify(self) -> PolyExpAnsatz: r"""Returns a new ansatz simplified by combining terms that have the same exponential part, i.e. two components of the batch are considered equal if their ``A`` and ``b`` are equal. In this case only one is kept and the corresponding ``c`` are added. Will return immediately if the ansatz has already been simplified, so it is safe to re-call. Raises: NotImplementedError: If the ``PolyExpAnsatz`` is batched (w/o linear superposition). """ if self._simplified or not self._lin_sup: return self batch_shape = self.batch_shape[:-1] if self._lin_sup else self.batch_shape if batch_shape: raise NotImplementedError("Batched simplify is not implemented.") (A, b, c), to_keep = self._find_unique_terms_sorted() A = math.gather(A, to_keep, axis=0) b = math.gather(b, to_keep, axis=0) c = math.gather(c, to_keep, axis=0) # already added A = math.reshape(A, (len(to_keep), self.num_vars, self.num_vars)) b = math.reshape(b, (len(to_keep), self.num_vars)) c = math.reshape(c, (len(to_keep), *self.shape_derived_vars)) new_ansatz = PolyExpAnsatz(A, b, c, lin_sup=self._lin_sup) new_ansatz._simplified = True return new_ansatz
[docs] def to_dict(self) -> dict[str, ArrayLike]: r"""Returns a dictionary representation of the ansatz. For serialization purposes.""" return {"A": self.A, "b": self.b, "c": self.c}
[docs] def trace(self, idx_z: tuple[int, ...], idx_zconj: tuple[int, ...], measure: float = -1.0): r"""Computes the trace of the ansatz across the specified pairs of CV variables. Args: idx_z: The indices indicating which CV variables to integrate over. idx_zconj: The indices indicating which conjugate CV variables to integrate over. measure: The measure to use in the complex Gaussian integral. Returns: A new ansatz with the specified indices traced out. """ if len(idx_z) != len(idx_zconj): raise ValueError("idx_z and idx_zconj must have the same length.") if len(set(idx_z + idx_zconj)) != len(idx_z) + len(idx_zconj): raise ValueError( f"Indices must be unique: {set(idx_z).intersection(idx_zconj)} are repeated.", ) if any(i >= self.num_CV_vars for i in idx_z) or any( i >= self.num_CV_vars for i in idx_zconj ): raise ValueError( f"All indices must be between 0 and {self.num_CV_vars - 1}. Got {idx_z} and {idx_zconj}.", ) A_in, b_in, c_in = self.triple log_c_in = math.log(math.cast(c_in, "complex128")) A, b, log_c = math.complex_gaussian_integral_1(A_in, b_in, idx_z + idx_zconj) # broadcast log_c and log_c_in to the same shape if log_c_in.shape != log_c.shape: log_c = math.reshape(log_c, log_c.shape + (1,) * (log_c_in.ndim - log_c.ndim)) return PolyExpAnsatz(A, b, math.exp(log_c_in + log_c), lin_sup=self._lin_sup)
def _combine_exp_and_poly(
    self,
    exp_sum: ComplexTensor,
    poly: ComplexTensor,
    c: ComplexTensor,
) -> ComplexTensor:
    r"""Combines exponential and polynomial parts using einsum. Needed in ``__call__``.

    ``c`` and ``poly`` are contracted over their trailing derived-variable axes (one
    einsum letter per derived variable), and the result is multiplied elementwise by
    ``exp_sum`` over the shared leading axes.
    """
    poly_string = "".join(chr(i) for i in range(97, 97 + len(self.shape_derived_vars)))
    return math.einsum(f"...,...{poly_string},...{poly_string}->...", exp_sum, c, poly)


def _compute_exp_part(
    self,
    z: Vector,
    A: ComplexMatrix,
    b: ComplexVector,
    log_c: ComplexScalar,
) -> Scalar:
    r"""Computes the exponential part of the ansatz evaluation. Needed in ``__call__``.

    The exponential part is given by:

    .. math::
        \exp\left(\frac{1}{2} z^T A z + b^T z + \log c\right)

    where :math:`A` is the matrix of the quadratic part of the ansatz, :math:`b` is the
    vector of the linear part that correspond to the vector of given CV variables, and
    :math:`\log c` is a scalar folded into the exponent to prevent overflow when :math:`c`
    is very small and the quadratic/linear terms are very large.
    """
    n = self.num_CV_vars
    # Only the CV block of A and b enters; the derived block is handled by the polynomial part.
    A_part = math.einsum("...a,...b,...ab->...", z, z, A[..., :n, :n])
    b_part = math.einsum("...a,...a->...", z, b[..., :n])
    return math.exp(1 / 2 * A_part + b_part + log_c)


def _compute_polynomial_part(
    self,
    z: Vector,
    A: ComplexMatrix,
    b: ComplexVector,
) -> Scalar:
    r"""Computes the polynomial part of the ansatz evaluation. Needed in ``__call__``.

    The derived-variable block ``A[n:, n:]`` and the effective linear term
    ``A[:n, n:]^T z + b[n:]`` are passed to ``math.hermite_renormalized`` with the
    cutoff shape ``self.shape_derived_vars``.
    """
    n = self.num_CV_vars
    batch_shape = z.shape[:-1]
    b_poly = math.einsum("...ab,...a->...b", A[..., :n, n:], z) + b[..., n:]
    return math.hermite_renormalized(
        A[..., n:, n:],
        b_poly,
        math.ones(batch_shape, dtype=math.complex128),
        shape=self.shape_derived_vars,
    )


def _ensure_lin_sup_axis(self) -> PolyExpAnsatz:
    r"""Returns a copy of this ansatz that is guaranteed to carry a linear-superposition axis.

    If the ansatz has no such axis, a length-1 axis is inserted right before the core
    axes of ``A``, ``b`` and ``c``.
    """
    if self._lin_sup:
        return PolyExpAnsatz(self._A, self._b, self._c, lin_sup=True)
    return PolyExpAnsatz(
        math.expand_dims(self._A, axis=-3),
        math.expand_dims(self._b, axis=-2),
        math.expand_dims(self._c, axis=-self.num_derived_vars - 1),
        lin_sup=True,
    )


def _find_unique_terms_sorted(
    self,
) -> tuple[tuple[ComplexMatrix, ComplexVector, ComplexTensor], list[int]]:
    r"""Finds unique terms by first sorting the batch dimension and adds the corresponding c values.

    After sorting, terms with equal ``A`` and ``b`` (within ``settings.ATOL``) are
    adjacent. Each run of equal terms is collapsed onto its first element, whose ``c``
    accumulates the ``c`` of the others.

    Returns:
        The updated vectorized (A,b,c) triple and a list of indices to keep after simplification.
    """
    A, b, c = self._order_batch()
    to_keep = [0]
    ref = 0  # index of the representative of the current run of equal terms
    for d in range(1, self.batch_size):
        same_exponent = math.allclose(A[ref], A[d], atol=settings.ATOL) and math.allclose(
            b[ref], b[d], atol=settings.ATOL
        )
        if same_exponent:
            # ``c`` is vectorized (leading axis of size batch_size), so index it by the flat
            # position directly rather than unravelling against the original batch shape.
            c = math.update_add_tensor(c, [(ref,)], [c[d]])
        else:
            to_keep.append(d)
            ref = d
    return (A, b, c), to_keep


def _order_batch(
    self,
) -> tuple[ComplexMatrix, ComplexVector, ComplexTensor]:
    r"""Orders the batch dimension lexicographically.

    The sort key of each batch element is its flattened ``b`` followed by its flattened
    ``A`` and then its flattened ``c``. This is a cheap way to enforce a canonical ordering
    of the batch dimension, which is useful for simplification and for determining
    (in)equality between two PolyExp ansatz.

    Returns:
        The ordered vectorized (A, b, c) triple, with a single flattened leading batch axis.
        If the ansatz has no batch dimensions, ``(self.A, self.b, self.c)`` unchanged.
    """
    if not self.batch_shape:
        return self.A, self.b, self.c
    A_vectorized = math.reshape(self.A, (self.batch_size, self.num_vars, self.num_vars))
    b_vectorized = math.reshape(self.b, (self.batch_size, self.num_vars))
    c_vectorized = math.reshape(self.c, (self.batch_size, *self.shape_derived_vars))
    generators = [
        itertools.chain(
            math.asnumpy(b_vectorized[i]).flat,
            math.asnumpy(A_vectorized[i]).flat,
            math.asnumpy(c_vectorized[i]).flat,
        )
        for i in range(self.batch_size)
    ]
    sorted_indices = argsort_gen(generators)
    A = math.gather(A_vectorized, sorted_indices, axis=0)
    b = math.gather(b_vectorized, sorted_indices, axis=0)
    c = math.gather(c_vectorized, sorted_indices, axis=0)
    return A, b, c


def _partial_eval(self, z: ArrayLike, indices: tuple[int, ...]) -> PolyExpAnsatz:
    r"""Partially evaluates the ansatz by fixing some of its variables to specific values.

    This method creates a new ansatz with fewer variables by substituting the specified
    variables with their given values. The remaining variables keep their original order
    in the function signature.

    Example:
        If this ansatz represents F(z0,z1,z2), then:
        ```
        new_ansatz = self._partial_eval([2.0,3.0], indices=(0,2))
        ```
        returns a new ansatz G(z1) equal to F(2.0, z1, 3.0).

    Args:
        z: Values for the variables being fixed. Can be:
            - Shape (r,): A single vector of values for r variables
            - Shape (*b, r): Batch of r values of shape *b
            Where r is the number of indices in `indices`
        indices: Indices of the variables to be fixed to the values in z

    Returns:
        A new PolyExpAnsatz with fewer variables. If the original ansatz has batch
        dimensions *L and z has batch dimensions *b, the resulting ansatz will have batch
        dimensions (*b, *L).

    Raises:
        ValueError: If ``len(indices)`` is not smaller than the number of CV variables.
    """
    if len(indices) >= self.num_CV_vars:
        raise ValueError(
            "The number of variables and indices must not exceed the number of CV "
            f"variables {self.num_CV_vars}. Use the eval() or __call__() method instead.",
        )
    z_batch_shape = z.shape[:-1]

    # evaluated, remaining and derived indices
    e = indices
    r = [i for i in range(self.num_CV_vars) if i not in indices]
    d = list(range(self.num_CV_vars, self.num_vars))

    # Broadcast z to (*L, *b, r) and move its own batch axes first: (*b, *L, r).
    ansatz_batch_idxs = tuple(range(self.batch_dims))
    z_batch_idxs = tuple(range(self.batch_dims, self.batch_dims + len(z_batch_shape)))
    z = math.transpose(
        math.broadcast_to(z, self.batch_shape + z.shape),
        z_batch_idxs + ansatz_batch_idxs + (len(z_batch_idxs + ansatz_batch_idxs),),
    )
    A = math.broadcast_to(self.A, z_batch_shape + self.A.shape)
    b = math.broadcast_to(self.b, z_batch_shape + self.b.shape)
    c = math.broadcast_to(self.c, z_batch_shape + self.c.shape)

    # The new A keeps the (remaining + derived) block.
    new_A = math.gather(math.gather(A, r + d, axis=-1), r + d, axis=-2)

    # Cross terms between fixed and free variables shift the linear coefficients.
    A_er = math.gather(math.gather(A, e, axis=-2), r, axis=-1)
    b_r = math.einsum("...er,...e->...r", A_er, z)
    if len(d) > 0:
        A_ed = math.gather(math.gather(A, e, axis=-2), d, axis=-1)
        b_d = math.einsum("...ed,...e->...d", A_ed, z)
        new_b = math.gather(b, r + d, axis=-1) + math.concat((b_r, b_d), axis=-1)
    else:
        new_b = math.gather(b, r, axis=-1) + b_r

    # Terms involving only the fixed variables become a multiplicative factor on c.
    A_ee = math.gather(math.gather(A, e, axis=-2), e, axis=-1)
    A_part = math.einsum("...e,...f,...ef->...", z, z, A_ee)
    b_part = math.einsum("...e,...e->...", z, math.gather(b, e, axis=-1))
    exp_sum = math.exp(1 / 2 * A_part + b_part)

    poly_string = "".join(chr(i) for i in range(97, 97 + len(self.shape_derived_vars)))
    new_c2 = math.einsum(f"...,...{poly_string}->...{poly_string}", exp_sum, c)

    return PolyExpAnsatz(
        new_A,
        new_b,
        new_c2,
        lin_sup=self._lin_sup,
    )


def _squeeze_lin_sup_axis(self) -> PolyExpAnsatz:
    r"""Returns a copy of this ansatz with a length-1 linear-superposition axis removed.

    Raises:
        ValueError: If the linear-superposition axis has length other than 1.
    """
    if not self._lin_sup:
        return PolyExpAnsatz(self._A, self._b, self._c, lin_sup=False)
    if self._A.shape[-3] == 1:
        return PolyExpAnsatz(
            math.squeeze(self._A, axis=-3),
            math.squeeze(self._b, axis=-2),
            math.squeeze(self._c, axis=-self.num_derived_vars - 1),
            lin_sup=False,
        )
    raise ValueError("Cannot squeeze lin sup axis if it is not of length 1")


def __add__(self, other: Ansatz) -> PolyExpAnsatz:
    r"""Adds two ``PolyExpAnsatz`` together.

    This is equivalent to stacking their respective triples along a batch dimension,
    which is to be interpreted to mean a linear superposition. In order to use the
    ``__add__`` method, the ansatze must have the same number of CV variables, and zero
    batch dimensions or exactly one batch dimension that is the linear-superposition axis.
    The reason for this restriction on the number of batch dimensions is that if there
    are multiple batch dimensions, it is not clear which one is used as meaning "linear
    superposition". In that case, the stacking of the triples should be done by the user.

    If the two ansatze have different numbers of derived variables, the smaller ``A``,
    ``b`` and ``c`` are zero-padded so that both terms share the same core shape.

    Args:
        other: The other ansatz to add.

    Raises:
        TypeError: If ``other`` is not a ``PolyExpAnsatz``.
        ValueError: If the number of CV variables doesn't match.
        ValueError: If the batch dimensions are incompatible.
    """
    if not isinstance(other, PolyExpAnsatz):
        raise TypeError(f"Cannot add ansatz of type {type(other)}!")
    if self.num_CV_vars != other.num_CV_vars:
        raise ValueError(
            f"The number of CV variables must match. Got {self.num_CV_vars} and {other.num_CV_vars}.",
        )
    # Only an unbatched ansatz, or one whose single batch axis is the lin-sup axis, can be
    # stacked unambiguously.
    if any(ans.batch_dims > 1 or (ans.batch_shape and not ans._lin_sup) for ans in (self, other)):
        raise ValueError(
            f"Cannot add PolyExpAnsatz with batch dimensions {self.batch_shape} and {other.batch_shape}.",
        )

    n_derived_vars = max(self.num_derived_vars, other.num_derived_vars)

    def as_stackable(ans):
        # Give the triple exactly one leading (lin-sup) axis.
        A, b, c = ans.A, ans.b, ans.c
        if ans.batch_dims == 0:
            A = math.expand_dims(A, axis=0)
            b = math.expand_dims(b, axis=0)
            c = math.expand_dims(c, axis=0)
        # Append trailing singleton derived axes so the batch axis stays first.
        # (``math.atleast_nd`` may prepend axes, as numpy's ``ndmin`` does.)
        missing = n_derived_vars + 1 - len(c.shape)
        if missing > 0:
            c = math.reshape(c, tuple(c.shape) + (1,) * missing)
        return A, b, c

    def pad_and_concat(array1, array2):
        # Zero-pad all non-leading axes to their elementwise maximum, then stack.
        shape1 = array1.shape[1:]
        shape2 = array2.shape[1:]
        max_shapes = tuple(map(max, zip(shape1, shape2)))
        pad_widths1 = [(0, 0)] + [(0, m - s) for m, s in zip(max_shapes, shape1)]
        pad_widths2 = [(0, 0)] + [(0, m - s) for m, s in zip(max_shapes, shape2)]
        padded_array1 = math.pad(array1, pad_widths1, "constant")
        padded_array2 = math.pad(array2, pad_widths2, "constant")
        return math.concat([padded_array1, padded_array2], axis=0)

    A_self, b_self, c_self = as_stackable(self)
    A_other, b_other, c_other = as_stackable(other)
    return PolyExpAnsatz(
        pad_and_concat(A_self, A_other),
        pad_and_concat(b_self, b_other),
        pad_and_concat(c_self, c_other),
        lin_sup=True,
    )


def __and__(self, other: Ansatz) -> PolyExpAnsatz:
    r"""Tensor product of this PolyExpAnsatz with another.

    Equivalent to :math:`H(a,b) = F(a) * G(b)`. As it distributes over addition on both
    self and other, the batch shape of the result is the outer product of the batch
    shapes of this ansatz and the other one. If both carry a linear-superposition axis,
    the two are merged into a single one of length ``lin1 * lin2``. Use with moderation.

    Args:
        other: Another PolyExpAnsatz.

    Returns:
        The tensor product of this PolyExpAnsatz and other.

    Raises:
        TypeError: If ``other`` is not a ``PolyExpAnsatz``.
    """
    if not isinstance(other, PolyExpAnsatz):
        raise TypeError(f"Cannot tensor product ansatz of type {type(other)}!")

    A1, b1, c1 = self.triple
    A2, b2, c2 = other.triple

    # Split batch dimensions into regular and linear superposition
    num_reg1 = self.batch_dims - (1 if self._lin_sup else 0)
    num_reg2 = other.batch_dims - (1 if other._lin_sup else 0)

    # For outer product: (r1, [l1]) & (r2, [l2]) -> (r1, r2, l1, l2)
    # Singleton axes are inserted so that broadcasting in join_Abc forms the outer product.

    # Self: insert num_reg2 singletons after its regular batch axes and, if other has a
    # lin-sup axis, one more after self's own lin-sup axis (if any).
    for _ in range(num_reg2):
        A1 = math.expand_dims(A1, axis=num_reg1)
        b1 = math.expand_dims(b1, axis=num_reg1)
        c1 = math.expand_dims(c1, axis=num_reg1)
    if other._lin_sup:
        pos = num_reg1 + num_reg2 + (1 if self._lin_sup else 0)
        A1 = math.expand_dims(A1, axis=pos)
        b1 = math.expand_dims(b1, axis=pos)
        c1 = math.expand_dims(c1, axis=pos)

    # Other: insert num_reg1 leading singletons and, if self has a lin-sup axis, one more
    # before other's lin-sup axis.
    for _ in range(num_reg1):
        A2 = math.expand_dims(A2, axis=0)
        b2 = math.expand_dims(b2, axis=0)
        c2 = math.expand_dims(c2, axis=0)
    if self._lin_sup:
        A2 = math.expand_dims(A2, axis=num_reg1 + num_reg2)
        b2 = math.expand_dims(b2, axis=num_reg1 + num_reg2)
        c2 = math.expand_dims(c2, axis=num_reg1 + num_reg2)

    # Join the triples (broadcasting will handle the outer product)
    As, bs, cs = join_Abc(A1, b1, c1, A2, b2, c2)

    # If both have lin_sup, merge the two lin_sup axes into one.
    if self._lin_sup and other._lin_sup:
        total_reg = num_reg1 + num_reg2
        # (...reg_batch, lin1, lin2, *core) -> (...reg_batch, lin1*lin2, *core)
        As = math.reshape(As, (*As.shape[:total_reg], -1, *As.shape[total_reg + 2 :]))
        bs = math.reshape(bs, (*bs.shape[:total_reg], -1, *bs.shape[total_reg + 2 :]))
        # Keep the trailing derived-variable axes of c; flattening them into the lin-sup
        # axis would break the correspondence with A and b.
        cs = math.reshape(cs, (*cs.shape[:total_reg], -1, *cs.shape[total_reg + 2 :]))

    return PolyExpAnsatz(As, bs, cs, lin_sup=self._lin_sup or other._lin_sup)


def __call__(self, *z_inputs: ArrayLike | None) -> ComplexTensor:
    r"""Evaluates the ansatz at the given batch of points.

    Each point can have arbitray batch dimensions, as long as they are broadcastable.
    If some of the points are not specified (None), or fewer points than CV variables
    are given, the result will be a partially evaluated ansatz.

    If the combined shape of the inputs is ``(*b, n)`` where ``n`` is the number of CV
    variables in the ansatz and ``*b`` is the batch dimensions of the combined inputs,
    then the output will have shape ``(*b, *L)`` where ``*L`` is the batch shape of the
    ansatz itself (with the linear-superposition axis, if any, summed over).

    Args:
        *z_inputs: A batch of points where the function is evaluated (or None).
            The shape of each point can be arbitrary, as long as they are broadcastable.

    Returns:
        The evaluated function with shape ``(*b, *L)`` where:
            - ``*b`` are the batch dimensions of the combined inputs.
            - ``*L`` is the batch shape of the ansatz.

    Raises:
        ValueError: If more inputs than CV variables are given, or if the stacked inputs
            do not match the number of CV variables.
    """
    # Extra inputs (e.g. trailing Nones) would silently shift the positions of the others.
    if len(z_inputs) > self.num_CV_vars:
        raise ValueError(
            f"Expected at most {self.num_CV_vars} inputs (one per CV variable), got {len(z_inputs)}.",
        )
    z_only = [math.cast(arr, dtype=math.complex128) for arr in z_inputs if arr is not None]
    broadcasted_z = math.broadcast_arrays(*z_only)
    z = (
        math.stack(broadcasted_z, axis=-1)
        if broadcasted_z
        else math.astensor([], dtype=math.complex128)
    )
    if len(z_only) < self.num_CV_vars:
        indices = tuple(i for i, arr in enumerate(z_inputs) if arr is not None)
        return self._partial_eval(z, indices)

    z_batch_shape, z_dim = z.shape[:-1], z.shape[-1]
    if z_dim != self.num_CV_vars:
        raise ValueError(
            f"The last dimension of `z` must equal the number of CV variables {self.num_CV_vars}, got {z_dim}.",
        )

    # Broadcast z to (*L, *b, n) and move its own batch axes first: (*b, *L, n).
    ansatz_batch_idxs = tuple(range(self.batch_dims))
    z_batch_idxs = tuple(range(self.batch_dims, self.batch_dims + len(z_batch_shape)))
    z = math.transpose(
        math.broadcast_to(z, self.batch_shape + z.shape),
        z_batch_idxs + ansatz_batch_idxs + (len(z_batch_idxs + ansatz_batch_idxs),),
    )
    A = math.broadcast_to(self.A, z_batch_shape + self.A.shape)
    b = math.broadcast_to(self.b, z_batch_shape + self.b.shape)
    c = math.broadcast_to(self.c, z_batch_shape + self.c.shape)

    if self.num_derived_vars == 0:  # purely gaussian
        # Fold log(c) into the exponent to prevent overflow:
        # c * exp(exponent) = exp(exponent + log(c)).
        # Cast first (as ``trace`` does) so negative real coefficients do not give NaN.
        log_c = math.log(math.cast(c, dtype=math.complex128))
        ret = self._compute_exp_part(z, A, b, log_c)
    else:
        # Polynomial case: factor out one global scale s = max|c| (over all entries and
        # batch elements) and fold log(s) into the exponent to stabilise the exp:
        # c * exp(exponent) = (c / s) * exp(exponent + log(s)).
        c_scale = math.maximum(math.max(math.abs(c)), 1e-300)  # avoid log(0) when c == 0
        exp_sum = self._compute_exp_part(z, A, b, math.log(c_scale))
        poly = self._compute_polynomial_part(z, A, b)
        ret = self._combine_exp_and_poly(exp_sum, poly, c / c_scale)
    # The linear-superposition axis is the last batch axis; summing it adds the terms.
    return math.sum(ret, axis=-1) if self._lin_sup else ret


def __eq__(self, other: object) -> bool:
    r"""Checks equality up to a reordering of the batch and ``settings.ATOL`` tolerance.

    Ansatze with different numbers of batch elements or different derived-variable shapes
    compare unequal.
    """
    if not isinstance(other, PolyExpAnsatz):
        return False
    if self.num_CV_vars != other.num_CV_vars or self.num_derived_vars != other.num_derived_vars:
        return False
    # Without this guard ``allclose`` either raises a broadcasting error or, when one side
    # has size 1, broadcasts it and can report a false positive.
    if int(np.prod(self.batch_shape)) != int(np.prod(other.batch_shape)) or tuple(
        self.shape_derived_vars
    ) != tuple(other.shape_derived_vars):
        return False
    self_A, self_b, self_c = self._order_batch()
    other_A, other_b, other_c = other._order_batch()
    return (
        math.allclose(self_A, other_A, atol=settings.ATOL)
        and math.allclose(self_b, other_b, atol=settings.ATOL)
        and math.allclose(self_c, other_c, atol=settings.ATOL)
    )


def __getitem__(self, indexer: Any) -> PolyExpAnsatz:
    r"""Batch-only indexing.

    Supports integers, slices, None (newaxis), and Ellipsis. Indexing is restricted to
    the first ``batch_dims`` axes. The linear-superposition axis (if present) sits
    immediately before core axes and is not directly indexable. The returned object
    preserves the lin-sup axis position; integers remove batch axes, None inserts
    batch axes.
    """
    batch_index, _, _ = batch_indexer_info(indexer, self.batch_dims)
    A = self.A[batch_index]
    b = self.b[batch_index]
    c = self.c[batch_index]
    return PolyExpAnsatz(A, b, c, lin_sup=self._lin_sup)


def __mul__(self, other: Scalar) -> PolyExpAnsatz:
    r"""Multiplies the ansatz by a scalar by scaling its coefficients ``c``."""
    return PolyExpAnsatz(self.A, self.b, self.c * other, lin_sup=self._lin_sup)


def __neg__(self) -> PolyExpAnsatz:
    r"""Negates the ansatz by flipping the sign of its coefficients ``c``."""
    return PolyExpAnsatz(self.A, self.b, -self.c, lin_sup=self._lin_sup)


def __repr__(self) -> str:
    r"""Returns a string representation of the PolyExpAnsatz object."""
    # Create a descriptive name
    display_name = f'"{self.name}"' if self.name else "unnamed"

    # Build the representation string
    repr_str = [
        f"PolyExpAnsatz({display_name})",
        f"  Batch shape: {self.batch_shape}",
        f"  Linear superposition: {self._lin_sup}",
        f"  Variables: {self.num_CV_vars} CV + {self.num_derived_vars} derived = {self.num_vars} total",
        "  Parameter shapes:",
        f"    A: {self.A.shape}",
        f"    b: {self.b.shape}",
        f"    c: {self.c.shape}",
    ]

    # Add information about simplification status
    if self._simplified:
        repr_str.append("  Status: simplified")

    return "\n".join(repr_str)


def __str__(self) -> str:
    r"""Returns a one-line summary of the batch shape and variable counts."""
    return f"PolyExpAnsatz(batch_shape={self.batch_shape}, lin_sup={self._lin_sup}, num_CV_vars={self.num_CV_vars}, num_derived_vars={self.num_derived_vars})"


def __truediv__(self, other: Scalar) -> PolyExpAnsatz:
    r"""Divides the ansatz by a scalar, or by a batch of scalars aligned with the batch axes.

    A non-scalar ``other`` is reshaped with trailing singleton axes so that it aligns with
    the leading (batch) axes of ``c`` rather than its trailing derived-variable axes.
    """
    shape = math.shape(other)
    if shape != ():
        delta = len(self.c.shape) - len(shape)
        other = math.reshape(other, shape + (1,) * delta)
    return PolyExpAnsatz(self.A, self.b, self.c / other, lin_sup=self._lin_sup)