Source code for mrmustard.physics.utils

# Copyright 2025 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""This module contains the utility functions used by the classes in ``mrmustard.physics``."""

from __future__ import annotations

from collections.abc import Sequence
from types import EllipsisType

import numpy as np
from numpy.typing import ArrayLike

from mrmustard import math, settings
from mrmustard.utils.typing import ComplexMatrix, ComplexScalar, ComplexTensor, ComplexVector

# Names exported by ``from mrmustard.physics.utils import *``; kept in alphabetical order.
__all__ = [
    "IndexerType",
    "batch_indexer_info",
    "generate_batch_str",
    "join_Abc",
    "lin_sup_batch_str",
    "outer_product_batch_str",
    "random_Abc",
    "reshape_args_to_batch_string",
    "verify_triple",
    "zip_batch_strings",
]

#  ~~~~~~~~~~~~~~~~
#  Helper functions
#  ~~~~~~~~~~~~~~~~

# A single element that may appear in a batch index: an integer, a slice,
# ``None`` (new axis) or ``...``.
_IndexerTypes = int | slice | None | EllipsisType
# What a ``__getitem__`` call may receive: one such element or a sequence of them.
IndexerType = _IndexerTypes | Sequence[_IndexerTypes]


def _join_Ab(
    A1: ComplexMatrix,
    b1: ComplexVector,
    A2: ComplexMatrix,
    b2: ComplexVector,
    m1: int,
    m2: int,
) -> tuple[ComplexMatrix, ComplexVector]:
    r"""Combine two ``(A, b)`` pairs into one block-diagonal pair.

    ``A1`` and ``A2`` are placed on the diagonal of a larger matrix, ``b1`` and ``b2``
    are concatenated, and then the same permutation is applied to the rows and columns
    of the matrix and to the entries of the vector so that the variables are ordered as
    ``[core1, core2, derived1, derived2]``.

    Args:
        A1: First A matrix with shape ``(*batch, n1, n1)``
        b1: First b vector with shape ``(*batch, n1)``
        A2: Second A matrix with shape ``(*batch, n2, n2)``
        b2: Second b vector with shape ``(*batch, n2)``
        m1: Number of derived variables in the first pair (its trailing ``m1`` indices)
        m2: Number of derived variables in the second pair (its trailing ``m2`` indices)

    Returns:
        The joined and reordered ``(A, b)`` pair.
    """
    batch_shape = A1.shape[:-2]
    size1 = A1.shape[-1]
    size2 = A2.shape[-1]
    core1 = size1 - m1
    core2 = size2 - m2

    # Off-diagonal zero blocks of [[A1, 0], [0, A2]]
    upper_right = math.zeros((*batch_shape, size1, size2), dtype=math.complex128)
    lower_left = math.zeros((*batch_shape, size2, size1), dtype=math.complex128)
    block_diag = math.concat(
        [
            math.concat([A1, upper_right], axis=-1),
            math.concat([lower_left, A2], axis=-1),
        ],
        axis=-2,
    )

    # Positions (in the plain concatenation) of each group, in the target order
    order = [
        *range(core1),  # core1
        *range(size1, size1 + core2),  # core2
        *range(core1, size1),  # derived1
        *range(size1 + core2, size1 + size2),  # derived2
    ]

    # Permute rows, then columns, of A; permute b identically
    A = math.gather(math.gather(block_diag, order, axis=-2), order, axis=-1)
    b = math.gather(math.concat([b1, b2], axis=-1), order, axis=-1)
    return A, b


def _join_c(
    c1: ComplexTensor,
    c2: ComplexTensor,
    batch_dim1: int,
    batch_dim2: int,
    return_log: bool = False,
) -> ComplexTensor:
    r"""Outer product of two c tensors over their polynomial axes.

    The batch axes of ``c1`` and ``c2`` are broadcast against each other, while the
    polynomial axes are combined as an outer product. The product is evaluated as
    ``exp(log(c1) + log(c2))`` on ``complex128`` values.

    Args:
        c1: First c tensor with shape ``(*batch1, *poly_shape1)``
        c2: Second c tensor with shape ``(*batch2, *poly_shape2)``
        batch_dim1: Number of leading batch dimensions in ``c1``
        batch_dim2: Number of leading batch dimensions in ``c2``
        return_log: If ``True``, the logarithm of the product is returned instead

    Returns:
        Tensor of shape ``(*batch_out, *poly_shape1, *poly_shape2)`` where ``batch_out``
        is the broadcast of ``batch1`` and ``batch2``.
    """

    def split_shape(c, n_batch):
        # -> (batch shape, polynomial shape, number of polynomial entries)
        batch = c.shape[:n_batch]
        poly = c.shape[n_batch:] if c.ndim > n_batch else ()
        size = int(np.prod(poly)) if poly else 1
        return batch, poly, size

    batch1, poly1, size1 = split_shape(c1, batch_dim1)
    batch2, poly2, size2 = split_shape(c2, batch_dim2)
    out_batch = np.broadcast_shapes(batch1, batch2)

    # Collapse the polynomial axes into a single trailing axis
    flat1 = math.reshape(c1, (*batch1, size1))
    flat2 = math.reshape(c2, (*batch2, size2))

    # Bring both to the common batch shape and orient them as column / row
    column = math.broadcast_to(flat1, (*out_batch, size1))[..., :, None]
    row = math.broadcast_to(flat2, (*out_batch, size2))[..., None, :]

    # Sum of logs == log of the outer product
    log_sum = math.log(math.cast(column, "complex128")) + math.log(math.cast(row, "complex128"))

    # Unflatten: batch + poly_shape1 + poly_shape2
    log_c = math.reshape(log_sum, out_batch + poly1 + poly2)

    return log_c if return_log else math.exp(log_c)


#  ~~~~~~~~~
#  Utilities
#  ~~~~~~~~~


[docs] def batch_indexer_info( indexer: IndexerType, batch_dims: int ) -> tuple[tuple[int | slice, ...], int, int]: r"""Process the indexer passed to a __getitem__ call and return the batch index and the number of removed and inserted batch dimensions. Args: indexer: The indexer passed to a __getitem__ call. batch_dims: The number of batch dimensions of the ansatz (including eventual linear superposition dimension) Returns: A tuple containing the final index, the number of removed batch dimensions, and the number of inserted batch dimensions. Raises: TypeError: If the indexer contains any types other than int, slice, None, or Ellipsis. IndexError: If the indexer contains more than one Ellipsis. IndexError: If the indexer consumes more batch dimensions than the ansatz has. IndexError: If the indexer contains an Ellipsis in the middle of the indexer. """ indexer = (indexer,) if not isinstance(indexer, tuple) else indexer def is_consuming(x: _IndexerTypes) -> bool: return isinstance(x, (int, slice)) if any((x is not None) and (x is not Ellipsis) and not is_consuming(x) for x in indexer): raise TypeError("Only int, slice, None, and Ellipsis are supported for batch indexing.") if indexer.count(Ellipsis) > 1: raise IndexError("Only a single ellipsis is allowed.") if indexer.count(Ellipsis) == 1 and indexer[0] != Ellipsis and indexer[-1] != Ellipsis: raise IndexError("Ellipsis must be at the beginning or end of the indexer.") explicit_consuming = sum(1 for x in indexer if is_consuming(x)) if explicit_consuming > batch_dims: raise IndexError("Too many indices.") expanded: list[_IndexerTypes] = [] for x in indexer: if x is Ellipsis: expanded.extend([slice(None)] * (batch_dims - explicit_consuming)) else: expanded.append(x) return _batch_indexer_info_return_values(expanded)
def _batch_indexer_info_return_values( expanded: list[_IndexerTypes], ) -> tuple[list[_IndexerTypes], int, int]: """Helper function to return the final index, the number of removed batch dimensions, and the number of inserted batch dimensions.""" final_index: list[_IndexerTypes] = [] removed_by_int = 0 inserted_by_none = 0 for x in expanded: if x is None: final_index.append(None) inserted_by_none += 1 else: final_index.append(x) if isinstance(x, int): removed_by_int += 1 final_index.append(Ellipsis) # preserve core axes return tuple(final_index), removed_by_int, inserted_by_none
def generate_batch_str(batch_dim: int, offset: int = 0) -> str:
    r"""Build the label string for a run of consecutive batch dimensions.

    Labels are consecutive characters starting from ``"a"`` shifted by ``offset``.

    Args:
        batch_dim: The number of batch dimensions.
        offset: The offset of the characters.

    Returns:
        A string of characters to represent the batch dimensions.
    """
    first = ord("a") + offset
    return "".join(map(chr, range(first, first + batch_dim)))
def join_Abc(
    A1: ComplexMatrix,
    b1: ComplexVector,
    c1: ComplexTensor,
    A2: ComplexMatrix,
    b2: ComplexVector,
    c2: ComplexTensor,
    return_log_c: bool = False,
) -> tuple[ComplexMatrix, ComplexVector, ComplexScalar]:
    r"""Joins two ``(A,b,c)`` triples into a single ``(A,b,c)`` triple.

    The A matrices are combined block-diagonally, the b vectors are concatenated and
    the c tensors are combined through an outer product. Batch dimensions of the two
    triples are broadcast against each other with the usual NumPy broadcasting rules.

    Every polynomial axis of a c tensor (an axis after the batch axes) counts as one
    "derived variable". Rows and columns of A and entries of b are reordered so that the
    core variables of both triples come first, followed by the derived variables.

    Args:
        A1: First A matrix with shape ``(*batch1, n1, n1)``
        b1: First b vector with shape ``(*batch1, n1)``
        c1: First c array with shape ``(*batch1, *poly_shape1)``
        A2: Second A matrix with shape ``(*batch2, n2, n2)``
        b2: Second b vector with shape ``(*batch2, n2)``
        c2: Second c array with shape ``(*batch2, *poly_shape2)``
        return_log_c: If ``True``, returns ``log(c)`` instead of ``c`` for numerical stability

    Returns:
        Joined ``(A, b, c)`` triple with broadcasted batch dimensions and concatenated
        core dimensions:

        - A: shape ``(*batch_out, n1+n2, n1+n2)``
        - b: shape ``(*batch_out, n1+n2)``
        - c: shape ``(*batch_out, *poly_shape1, *poly_shape2)``

    Examples:
        >>> # Non-batched join
        >>> A1 = math.astensor([[1, 2], [3, 4]])  # shape (2, 2)
        >>> b1 = math.astensor([5, 6])  # shape (2,)
        >>> c1 = math.astensor(7)  # shape ()
        >>> A2 = math.astensor([[8, 9], [10, 11]])  # shape (2, 2)
        >>> b2 = math.astensor([12, 13])  # shape (2,)
        >>> c2 = math.astensor(10)  # shape ()
        >>> A, b, c = join_Abc(A1, b1, c1, A2, b2, c2)
        >>> A.shape  # (4, 4) - block diagonal of two 2x2 matrices
        (4, 4)
        >>> b.shape  # (4,) - concatenation of two length-2 vectors
        (4,)
        >>> c.shape  # () - product of two scalars
        ()

        >>> # Batched join with automatic broadcasting
        >>> A1 = math.astensor([[[1, 2], [3, 4]]])  # shape (1, 2, 2)
        >>> b1 = math.astensor([[5, 6]])  # shape (1, 2)
        >>> c1 = math.astensor([7])  # shape (1,)
        >>> A2 = math.astensor([[[8, 9], [10, 11]], [[12, 13], [14, 15]]])  # shape (2, 2, 2)
        >>> b2 = math.astensor([[12, 13], [14, 15]])  # shape (2, 2)
        >>> c2 = math.astensor([10, 100])  # shape (2,)
        >>> A, b, c = join_Abc(A1, b1, c1, A2, b2, c2)
        >>> A.shape  # (2, 4, 4) - broadcasts (1,) and (2,) to (2,)
        (2, 4, 4)
        >>> b.shape  # (2, 4)
        (2, 4)
        >>> c.shape  # (2,)
        (2,)
    """
    # Each triple must be self-consistent before the two are combined
    verify_triple(A1, b1, c1)
    verify_triple(A2, b2, c2)

    # Batch shape = everything before the trailing two (core) axes of A
    batch1, batch2 = A1.shape[:-2], A2.shape[:-2]
    batch_dim1, batch_dim2 = len(batch1), len(batch2)

    # One derived variable per polynomial axis of c (axes after the batch axes)
    m1 = len(c1.shape[batch_dim1:]) if c1.ndim > batch_dim1 else 0
    m2 = len(c2.shape[batch_dim2:]) if c2.ndim > batch_dim2 else 0

    rows1, cols1 = A1.shape[-2:]
    rows2, cols2 = A2.shape[-2:]

    # Common batch shape of the two triples
    out_batch = np.broadcast_shapes(batch1, batch2)

    # Expand A and b of both triples to the common batch shape
    A1_full = math.broadcast_to(A1, (*out_batch, rows1, cols1))
    A2_full = math.broadcast_to(A2, (*out_batch, rows2, cols2))
    b1_full = math.broadcast_to(b1, (*out_batch, rows1))
    b2_full = math.broadcast_to(b2, (*out_batch, rows2))

    A, b = _join_Ab(A1_full, b1_full, A2_full, b2_full, m1, m2)
    # c is passed un-broadcast; the helper broadcasts the batch axes itself
    c = _join_c(c1, c2, batch_dim1, batch_dim2, return_log=return_log_c)

    return A, b, c
def lin_sup_batch_str(batch_str: str) -> str:
    r"""Given a batch string, appends the linear superposition batch dimension to the end.

    Every input term receives its own fresh label, appended after its existing labels;
    the fresh labels are also appended, in input order, to the output term. Fresh labels
    are the characters following the highest label already present in ``batch_str``
    (starting from ``"a"`` when ``batch_str`` contains no labels at all, e.g. ``",->"``).

    Args:
        batch_str: The batch string to append the linear superposition batch dimension to.

    Returns:
        The batch string with the linear superposition batch dimension appended to the end.

    Examples:
        >>> lin_sup_batch_str("ab,c->abc")
        'abd,ce->abcde'
        >>> lin_sup_batch_str(",->")
        'a,b->ab'
    """
    input_str, output_str = batch_str.split("->")
    inputs = input_str.split(",")

    # Only labels determine where the fresh labels start. The separators must be
    # ignored, otherwise a label-free string would start after ">" and yield "?", "@".
    last_label = max(
        (ord(ch) for ch in batch_str if ch not in ",->"),
        default=ord("a") - 1,
    )
    lin_sups = [chr(last_label + shift) for shift in range(1, len(inputs) + 1)]

    new_input = ",".join(term + label for term, label in zip(inputs, lin_sups))
    new_output = output_str + "".join(lin_sups)
    return f"{new_input}->{new_output}"
def outer_product_batch_str(*batch_dims: int, lin_sup: tuple[int, ...] | None = None) -> str:
    r"""Creates the einsum string for the outer product of the given tuple of dimensions.

    Each tensor gets its own run of consecutive labels, so no label is shared between
    inputs. E.g. for (2,1,3) it returns ab,c,def->abcdef.

    If ``lin_sup`` is provided, the last label of every tensor listed in it is treated
    as that tensor's linear superposition dimension and moved to the end of the output.
    E.g. for (2,1,3) and lin_sup=(0,1) it returns ab,c,def->adefbc, as b and c are the
    linear superposition dimensions of the 0th and 1st tensors.

    Args:
        batch_dims: The number of batch dimensions of each tensor.
        lin_sup: Positions (into ``batch_dims``) of the tensors that carry a linear
            superposition dimension, or ``None`` if there are none.

    Returns:
        The einsum-style batch string.
    """
    # Assign disjoint, consecutive labels to each input
    labels = []
    start = 0
    for dim in batch_dims:
        labels.append(generate_batch_str(dim, start))
        start += dim
    inputs = ",".join(labels)

    if lin_sup is None:
        return inputs + "->" + "".join(labels)

    # Strip the trailing label of each listed tensor and collect it for the tail
    remaining = list(labels)
    moved = []
    for idx in lin_sup:
        moved.append(remaining[idx][-1])
        remaining[idx] = remaining[idx][:-1]
    return inputs + "->" + "".join(remaining) + "".join(moved)
def random_Abc(
    core_vars: int,
    batch: tuple[int, ...] = (),
    derived: tuple[int, ...] = (),
    seed: int | None = None,
) -> tuple[ComplexMatrix, ComplexVector, ComplexScalar]:
    r"""Generate a random ``(A, b, c)`` triple for testing purposes.

    All real and imaginary parts are drawn uniformly between ``1e-9`` and ``1`` from the
    generator returned by ``settings.get_rng(seed)``. ``A`` is symmetrized over its last
    two axes after drawing.

    Args:
        core_vars: Number of core variables (determines the size of A and b before
            derived variables)
        batch: Batch shape for the triple
        derived: Derived variable shape (polynomial dimensions in c)
        seed: Random seed

    Returns:
        Tuple of (A, b, c) where:

        - A: Complex symmetric matrix of shape ``(*batch, n+m, n+m)``
        - b: Complex vector of shape ``(*batch, n+m)``
        - c: Complex tensor of shape ``(*batch, *derived)``

        where n = core_vars and m = len(derived)

    Examples:
        >>> # Simple non-batched triple with 3 core variables
        >>> A, b, c = random_Abc(3)
        >>> A.shape, b.shape, c.shape
        ((3, 3), (3,), ())

        >>> # Batched triple with derived variables
        >>> A, b, c = random_Abc(2, batch=(5,), derived=(4,))
        >>> A.shape, b.shape, c.shape
        ((5, 3, 3), (5, 3), (5, 4))
    """
    size = core_vars + len(derived)  # one row/column per core or derived variable
    low, high = 1e-9, 1
    rng = settings.get_rng(seed)

    def complex_uniform(shape):
        # Real part is drawn before the imaginary part
        real = rng.uniform(low, high, shape)
        imag = rng.uniform(low, high, shape)
        return real + 1.0j * imag

    # Draw order (A, then b, then c) fixes the values obtained for a given seed
    A = complex_uniform((*batch, size, size))
    A = 0.5 * (A + np.swapaxes(A, -2, -1))  # make it symmetric
    b = complex_uniform((*batch, size))
    c = complex_uniform((*batch, *derived))

    return A, b, c
def reshape_args_to_batch_string(
    args: list[ArrayLike],
    batch_string: str,
) -> list[ArrayLike]:
    r"""Reshapes arguments to match the batch string by inserting singleton dimensions where
    needed so that they are broadcastable.

    E.g. given two arrays of shape (2,7) and (3,7) and string ab,cb->abc, it reshapes them
    to shape (2,7,1) and (1,7,3).

    Args:
        args: The arrays to reshape, one per input term of ``batch_string``.
        batch_string: An einsum-like string ``in1,in2,...->out`` with one label per axis.

    Returns:
        The list of reshaped arrays, in the same order as ``args``. Each has one axis per
        output label: the labelled size where the input carries that label, ``1`` otherwise.

    Raises:
        ValueError: If the number of input terms differs from the number of arguments.
        ValueError: If the same label is paired with two different sizes.
    """
    # Parse the batch string
    input_specs, output_spec = batch_string.split("->")
    input_specs = input_specs.split(",")

    if len(input_specs) != len(args):
        raise ValueError(
            f"Number of input specifications ({len(input_specs)}) does not match number of arguments ({len(args)})",
        )

    args = [math.astensor(arg) for arg in args]

    # Determine the size of each dimension in the output
    dim_sizes = {}
    for arg, spec in zip(args, input_specs):
        # zip stops at the shorter of (shape, spec): surplus axes or labels are ignored here
        for dim, label in zip(arg.shape, spec):
            if label in dim_sizes and dim_sizes[label] != dim:
                raise ValueError(
                    f"Dimension {label} has inconsistent sizes: got {dim_sizes[label]} and {dim}",
                )
            dim_sizes[label] = dim

    # NOTE(review): this is a plain reshape, not a transpose. If an input lists its
    # labels in a different order than the output does (like ``cb`` in the example above),
    # the elements are presumably reinterpreted in storage order rather than having their
    # axes permuted -- confirm callers only pass consistently ordered labels.
    # NOTE(review): a label present in an input but absent from the output is dropped from
    # the target shape, so the reshape looks like it can only succeed if that axis has
    # size 1 -- confirm against callers.
    reshaped = []
    for arg, spec in zip(args, input_specs):
        new_shape = [dim_sizes[label] if label in spec else 1 for label in output_spec]
        reshaped.append(math.reshape(arg, new_shape))

    return reshaped
def verify_triple(
    A: ComplexMatrix,
    b: ComplexVector,
    c: ComplexScalar,
) -> None:
    r"""Verify that both the batch and core dimensions of the ``(A, b, c)`` triple are
    consistent.

    The batch shape is taken to be every axis of ``A`` before its last two. ``A`` and ``b``
    must have core dimensions ``(N, N)`` and ``(N,)``, respectively. A ``c`` with shape
    ``()`` is accepted regardless of the batch shape.

    Args:
        A: The matrix of the quadratic form.
        b: The vector of the linear form.
        c: The scalar of the quadratic form.

    Raises:
        ValueError: If any (batch / core) dimensions of the ``(A, b, c)`` triple are
            inconsistent, including an ``A`` with fewer than two axes or a ``b`` without
            a core axis.
    """
    batch = A.shape[:-2]
    batch_dim = len(batch)

    if batch != b.shape[:batch_dim] or (len(c.shape) != 0 and batch != c.shape[:batch_dim]):
        raise ValueError(
            f"Batch dimensions of the triple ({batch}, {b.shape[:batch_dim]}, {c.shape[:batch_dim]}) are inconsistent.",
        )

    A_core_shape = A.shape[batch_dim:]
    b_core_shape = b.shape[batch_dim:]

    # The length guards turn a missing core axis into the documented ValueError
    # instead of letting an IndexError escape from the subscripts below.
    if len(A_core_shape) < 2 or A_core_shape[0] != A_core_shape[1]:
        raise ValueError(f"Core dimensions of A are inconsistent: {A_core_shape}.")

    if len(b_core_shape) == 0 or A_core_shape[0] != b_core_shape[0]:
        raise ValueError(
            f"Core dimensions of A and b are inconsistent: {A_core_shape} and {b_core_shape}, "
            "respectively."
        )
def zip_batch_strings(*batch_dims: int) -> str:
    r"""Creates a batch string for zipping over the batch dimensions.

    Every input starts its labels from the same first character, so equal positions share
    a label across inputs; the output carries as many labels as the longest input.

    Args:
        batch_dims: The number of batch dimensions of each input.

    Returns:
        The einsum-style batch string ``in1,in2,...->out``.
    """
    inputs = [generate_batch_str(dim) for dim in batch_dims]
    output = generate_batch_str(max(batch_dims))
    return ",".join(inputs) + "->" + output