Source code for mrmustard.parameters.parameters
# Copyright 2023 Xanadu Quantum Technologies Inc.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module contains the classes to describe constant and variable parameters used in Mr Mustard."""
from __future__ import annotations
from typing import Literal, cast
import numpy as np
from numpy.typing import ArrayLike, DTypeLike
from mrmustard import settings
from mrmustard.math.backend_manager import BackendManager
from mrmustard.utils.typing import RealTensor, Tensor
# Shared backend handle; this module uses it for tensor conversion (``astensor``,
# ``asnumpy``), elementwise math (``sqrt``, ``exp``, ``real``, ``imag``) and the
# ``random_*`` matrix generators.
math = BackendManager()
# Names exported by ``from mrmustard.parameters.parameters import *``.
__all__ = ["Constant", "Parameter", "Variable"]
# Closed set of update-function names that may be passed as ``update_fn``.
# The string is stored on ``Variable`` and names the function used to update
# that variable during training.
UpdateFn = Literal[
"update_euclidean",
"update_orthogonal",
"update_siegel",
"update_symplectic",
"update_unitary",
]
# ~~~~~~~~~
# Functions
# ~~~~~~~~~
def format_dtype(param: Constant | Variable) -> str:
    r"""Return the name of the dtype of a parameter's value.

    Args:
        param: The parameter whose value dtype is reported.

    Returns:
        The ``name`` attribute of the dtype of ``param.value`` (e.g. ``"float64"``).
    """
    tensor = param.value
    return tensor.dtype.name
def format_value(param: Constant | Variable) -> tuple[str, str]:
    r"""Format parameter value and shape strings.

    Args:
        param: The parameter to format.

    Returns:
        A tuple ``(value_str, shape_str)``.

        * For a parameter whose value has a non-empty shape, ``shape_str`` is the string
          form of that shape. ``value_str`` is the full nested list when the value has at
          most three elements, otherwise a preview of the first three flattened elements
          followed by ``...``.
        * Otherwise ``shape_str`` is the literal ``"scalar"`` and ``value_str`` is the
          number formatted with up to six significant digits (integers are printed
          exactly).
    """
    # ``np.asarray`` is a no-op on ndarrays; it only guarantees that ``.dtype`` exists even
    # if the backend hands back a bare numpy/Python scalar.
    value = np.asarray(math.asnumpy(param.value))

    # Integer-ness must be read from the dtype: an ndarray (including a 0-d one) is never an
    # instance of ``np.integer``, so an ``isinstance`` check silently sends every integer
    # value down the floating-point formatting path.
    int_like = np.issubdtype(value.dtype, np.integer)

    # Handle arrays
    if hasattr(param.value, "shape") and param.value.shape != ():
        shape_str = str(param.value.shape)
        if value.size <= 3:
            # Small arrays: preserve structure. ``tolist`` already yields Python ints for
            # integer dtypes, so no cast is needed (and none can overflow).
            value_str = str(value.tolist())
        else:
            # Large arrays: show preview with ellipsis
            preview = [str(int(x)) if int_like else f"{x:.3g}" for x in value.flat[:3]]
            value_str = f"[{', '.join(preview)}, ...]"
        return value_str, shape_str

    # Handle scalars
    if int_like:
        value_str = str(int(value))
    elif np.iscomplexobj(value):
        z = complex(value)
        # The ``+`` flag always emits the sign of the imaginary part, which also gives a
        # well-formed result for ``-0.0`` and ``nan`` imaginary parts.
        value_str = f"{z.real:.6g}{z.imag:+.6g}j"
    else:
        value_str = f"{float(value):.6g}"
    return value_str, "scalar"
# ~~~~~~~
# Classes
# ~~~~~~~
[docs]
class Parameter:
    r"""Common base class of :class:`Constant` and :class:`Variable`.

    It stores a backend tensor in ``_value`` together with an optional ``name``, and offers
    a family of classmethod constructors that build randomly initialized parameters.
    """

    def __init__(
        self,
        value: ArrayLike | Parameter | None = None,
        name: str | None = None,
        dtype: str | None = None,
        *,
        update_fn: UpdateFn = "update_euclidean",  # accepted for classmethod API; only Variable uses it
    ) -> None:
        """Store the value as a backend tensor and resolve the parameter name.

        Args:
            value: The raw value, or another parameter to copy the value (and, when ``name``
                is not given, the name) from. ``None`` is wrapped with ``np.asarray``.
            name: The name of the parameter. When no name can be resolved and the concrete
                class defines ``_default_name_prefix``, a name of the form
                ``"<prefix>_<id(self)>"`` is generated.
            dtype: The dtype to convert the value to. When omitted, the dtype already
                carried by ``value`` (if any) is used.
            update_fn: Ignored here; present so that the classmethod constructors can pass
                it to any subclass.
        """
        value = np.asarray(None) if value is None else value  # type: ignore[arg-type]

        if isinstance(value, Parameter):
            source = value.value
            self._value = math.astensor(source, dtype=dtype or source.dtype)
            resolved_name = value.name if name is None else name
        else:
            self._value = math.astensor(value, dtype=dtype or getattr(value, "dtype", None))
            resolved_name = name

        if resolved_name is None:
            # Fall back to a per-class prefix, when the concrete class provides one.
            prefix = getattr(type(self), "_default_name_prefix", None)
            if prefix is not None:
                resolved_name = f"{prefix}_{id(self)}"
        self.name = resolved_name

    @property
    def value(self) -> Tensor:
        r"""The tensor held by this parameter."""
        return self._value

    @value.setter
    def value(self, value: Tensor) -> None:
        self._value = value

    @property
    def dtype(self) -> DTypeLike:
        r"""The dtype of the tensor held by this parameter."""
        return self.value.dtype

    @property
    def shape(self) -> tuple[int, ...]:
        r"""The shape of the tensor held by this parameter."""
        return self.value.shape

    @classmethod
    def orthogonal(
        cls,
        name: str = "orthogonal",
        N: int = 1,
        seed: int | None = None,
        batch_shape: tuple[int, ...] = (),
    ) -> Parameter:
        r"""Build a parameter holding a random matrix from ``math.random_orthogonal``.

        Args:
            name: The name of the returned parameter.
            N: The dimension of the random orthogonal matrix.
            seed: The seed for the random number generator.
            batch_shape: The batch shape for generating multiple random matrices.

        Returns:
            An instance of ``cls`` created with ``update_fn="update_orthogonal"``.
        """
        matrix = math.random_orthogonal(N, seed=seed, batch_shape=batch_shape)
        return cls(value=matrix, name=name, update_fn="update_orthogonal")

    @classmethod
    def symplectic(
        cls,
        name: str = "symplectic",
        N: int = 1,
        max_r: float = 1.0,
        seed: int | None = None,
        batch_shape: tuple[int, ...] = (),
    ) -> Parameter:
        r"""Build a parameter holding a random matrix from ``math.random_symplectic``.

        Args:
            name: The name of the returned parameter.
            N: (half) the dimension of the random symplectic matrix.
            max_r: The maximum squeezing value sampled uniformly.
            seed: The seed for the random number generator.
            batch_shape: The batch shape for generating multiple random matrices.

        Returns:
            An instance of ``cls`` created with ``update_fn="update_symplectic"``.
        """
        matrix = math.random_symplectic(N, max_r, seed=seed, batch_shape=batch_shape)
        return cls(value=matrix, name=name, update_fn="update_symplectic")

    @classmethod
    def unitary(
        cls,
        name: str = "unitary",
        N: int = 1,
        seed: int | None = None,
        batch_shape: tuple[int, ...] = (),
    ) -> Parameter:
        r"""Build a parameter holding a random matrix from ``math.random_unitary``.

        Args:
            name: The name of the returned parameter.
            N: The dimension of the random unitary matrix.
            seed: The seed for the random number generator.
            batch_shape: The batch shape for generating multiple random matrices.

        Returns:
            An instance of ``cls`` created with ``update_fn="update_unitary"``.
        """
        matrix = math.random_unitary(N, seed=seed, batch_shape=batch_shape)
        return cls(value=matrix, name=name, update_fn="update_unitary")

    @classmethod
    def siegel(
        cls,
        name: str = "siegel",
        n: int = 1,
        max_r: float = 0.9,
        seed: int | None = None,
        batch_shape: tuple[int, ...] = (),
    ) -> Parameter:
        r"""Build a parameter in the open Siegel disk
        :math:`\mathcal{D}_n = \{Z \in \mathbb{C}^{n\times n} : Z = Z^T,\ I - Z^* Z \succ 0\}`
        with ``update_fn`` for Siegel-disk Riemannian optimization (Bergman metric).

        The matrix comes from ``math.random_siegel``: it is sampled with a Haar-random
        unitary eigenbasis and Takagi values drawn uniformly in
        :math:`[0, \mathtt{max\_r})`, ensuring :math:`\|Z\|_\mathrm{op} < 1`. All
        eigenvalues of :math:`Z` lie in the open unit disk.

        Args:
            name: The name of the returned parameter.
            n: The dimension of the matrix.
            max_r: The maximum Takagi value of the initial matrix. Must satisfy
                ``0 <= max_r < 1`` to start strictly inside the disk.
            seed: The seed for the random number generator.
            batch_shape: The batch shape for generating multiple random matrices.

        Returns:
            An instance of ``cls`` created with ``update_fn="update_siegel"``.
        """
        matrix = math.random_siegel(n, max_r=max_r, seed=seed, batch_shape=batch_shape)
        return cls(value=matrix, name=name, update_fn="update_siegel")

    @classmethod
    def complex_normal(
        cls,
        name: str = "complex",
        variance: float | RealTensor | Parameter = 1.0,
        mean: complex | Tensor | Parameter = 0.0 + 0.0j,
        seed: int | None = None,
        batch_shape: tuple[int, ...] = (),
    ) -> Parameter:
        r"""Build a parameter sampled from a circular complex normal distribution.

        The real and imaginary parts are drawn independently from normal distributions
        centred on the real and imaginary parts of ``mean``, each with variance
        ``variance/2`` (standard deviation ``sqrt(variance/2)``).

        Args:
            name: The name of the returned parameter.
            variance: The variance of the complex distribution; a parameter is replaced by
                its value.
            mean: The mean of the complex distribution; a parameter is replaced by its
                value.
            seed: The seed for the random number generator.
            batch_shape: The ``size`` passed to the random number generator.

        Returns:
            An instance of ``cls`` created with ``update_fn="update_euclidean"``.
        """
        rng = settings.get_rng(seed)

        # Unwrap parameters so that only raw values reach the backend and the generator.
        if isinstance(variance, Parameter):
            variance = variance.value
        if isinstance(mean, Parameter):
            mean = mean.value

        std = cast(float | RealTensor, math.sqrt(variance / 2.0))
        # Draw order matters for reproducibility with a fixed seed: real part first.
        real_part = rng.normal(loc=math.real(mean), scale=std, size=batch_shape)
        imag_part = rng.normal(loc=math.imag(mean), scale=std, size=batch_shape)
        sample = real_part + 1j * imag_part
        return cls(value=sample, name=name, update_fn="update_euclidean")

    @classmethod
    def complex_uniform(
        cls,
        name: str = "complex",
        max_r: float = 1.0,
        min_r: float = 0.0,
        seed: int | None = None,
        batch_shape: tuple[int, ...] = (),
    ) -> Parameter:
        r"""Build a complex parameter with uniformly drawn radius and angle.

        The radius is drawn uniformly between ``min_r`` and ``max_r`` and the angle
        uniformly between 0 and 2*pi; the value is ``radius * exp(1j * angle)``.

        Args:
            name: The name of the returned parameter.
            max_r: The maximum radius.
            min_r: The minimum radius.
            seed: The seed for the random number generator.
            batch_shape: The ``size`` passed to the random number generator.

        Returns:
            An instance of ``cls`` created with ``update_fn="update_euclidean"``.
        """
        rng = settings.get_rng(seed)
        # Draw order matters for reproducibility with a fixed seed: radius first.
        radius = rng.uniform(low=min_r, high=max_r, size=batch_shape)
        angle = rng.uniform(low=0, high=2 * np.pi, size=batch_shape)
        sample = radius * math.exp(1j * angle)
        return cls(value=sample, name=name, update_fn="update_euclidean")

    @classmethod
    def real_normal(
        cls,
        name: str = "real",
        variance: float = 1.0,
        mean: float = 0.0,
        seed: int | None = None,
        batch_shape: tuple[int, ...] = (),
    ) -> Parameter:
        r"""Build a real parameter drawn from a normal distribution with mean ``mean``
        and variance ``variance``.

        Args:
            name: The name of the returned parameter.
            variance: The variance of the real normal distribution.
            mean: The mean of the real normal distribution.
            seed: The seed for the random number generator.
            batch_shape: The ``size`` passed to the random number generator.

        Returns:
            An instance of ``cls`` created with ``update_fn="update_euclidean"``.
        """
        rng = settings.get_rng(seed)
        std = cast(RealTensor, math.sqrt(variance))
        sample = rng.normal(loc=mean, scale=std, size=batch_shape)
        return cls(value=sample, name=name, update_fn="update_euclidean")

    @classmethod
    def real_uniform(
        cls,
        name: str = "real",
        low: float = 0.0,
        high: float = 1.0,
        seed: int | None = None,
        batch_shape: tuple[int, ...] = (),
    ) -> Parameter:
        r"""Build a real parameter drawn from a uniform distribution over the half-open
        interval ``[low, high)``.

        Args:
            name: The name of the returned parameter.
            low: The minimum of the real uniform distribution.
            high: The maximum of the real uniform distribution.
            seed: The seed for the random number generator.
            batch_shape: The ``size`` passed to the random number generator.

        Returns:
            An instance of ``cls`` created with ``update_fn="update_euclidean"``.
        """
        rng = settings.get_rng(seed)
        sample = rng.uniform(low=low, high=high, size=batch_shape)
        return cls(value=sample, name=name, update_fn="update_euclidean")

    @classmethod
    def from_cc_init(
        cls, value: ArrayLike | Parameter, expected_dtype: str, name: str
    ) -> Constant | Variable:
        r"""Validate the dtype of a parameter, or wrap a raw value in a ``Constant``.

        Args:
            value: The input value. Can be an array, a scalar, a Parameter, or a nested
                list or tuple.
            expected_dtype: The expected dtype string (e.g. "float64", "complex128").
            name: The name used for a newly created constant and in the error message.

        Returns:
            ``value`` itself when it is a ``Constant`` or ``Variable`` whose dtype name
            equals ``expected_dtype``; otherwise a new ``Constant`` built from ``value``
            with ``dtype=expected_dtype``.

        Raises:
            ValueError: If ``value`` is a ``Constant`` or ``Variable`` with a different
                dtype.
        """
        if not isinstance(value, (Constant, Variable)):
            return Constant(value=value, name=name, dtype=expected_dtype)

        actual_dtype = value.value.dtype.name
        if actual_dtype != expected_dtype:
            kind = type(value).__name__
            raise ValueError(
                f"Parameter {name} is a {kind} with dtype {actual_dtype}, expected "
                f"{expected_dtype}."
            )
        return value
[docs]
class Constant(Parameter):
    r"""A parameter whose ``value`` cannot be reassigned.

    Multiplying a constant by a tensor or number (from either side) yields a new
    ``Constant`` carrying the same name.

    Args:
        value: The value of this constant.
        name: The name of this constant.
        dtype: The dtype of this constant.

    Example:
        .. code::

            my_const = Constant(1, "my_const")
    """

    # Prefix used by ``Parameter.__init__`` to generate a name when none is given.
    _default_name_prefix = "const"

    @property
    def value(self) -> Tensor:
        return self._value

    @value.setter
    def value(self, value: Tensor) -> None:
        # Reassignment through the public attribute is rejected.
        raise AttributeError("Constant.value is immutable")

    def __repr__(self) -> str:
        value_str, _ = format_value(self)
        return f"Constant(name={self.name}, value={value_str})"

    def __mul__(self, value: Tensor | complex) -> Constant:
        # The other operand is deliberately kept on the left of the product.
        product = value * self.value
        return Constant(value=product, name=self.name)

    def __rmul__(self, value: Tensor | complex) -> Constant:
        # The other operand is deliberately kept on the right of the product.
        product = self.value * value
        return Constant(value=product, name=self.name)
[docs]
class Variable(Parameter):
    r"""A parameter whose value may be reassigned.

    Multiplying a variable by a tensor or number (from either side) yields a new
    ``Variable`` carrying the same name and ``update_fn``.

    Args:
        value: The value of this variable.
        name: The name of this variable.
        dtype: The dtype of this variable.
        update_fn: The name of the function used to update this variable during training.

    Example:
        .. code::

            my_var = Variable(1, "my_var")
    """

    # Prefix used by ``Parameter.__init__`` to generate a name when none is given.
    _default_name_prefix = "var"
    _update_fn: UpdateFn
    name: str  # always set by Parameter.__init__ (via _default_name_prefix when None)

    def __init__(
        self,
        value: Tensor | Parameter | None = None,
        name: str | None = None,
        dtype: str | None = None,
        update_fn: UpdateFn = "update_euclidean",
    ) -> None:
        # The base initializer handles value and name; the update function is kept here.
        super().__init__(value=value, name=name, dtype=dtype)
        self._update_fn = update_fn

    @property
    def update_fn(self) -> UpdateFn:
        r"""The name of the function used to update this variable during training."""
        return self._update_fn

    @update_fn.setter
    def update_fn(self, value: UpdateFn) -> None:
        self._update_fn = value

    def __repr__(self) -> str:
        value_str, _ = format_value(self)
        return f"Variable(name={self.name}, value={value_str}, update_fn={self.update_fn})"

    def __mul__(self, value: Tensor | complex) -> Variable:
        # The other operand is deliberately kept on the left of the product.
        product = value * self.value
        return Variable(value=product, name=self.name, update_fn=self.update_fn)

    def __rmul__(self, value: Tensor | complex) -> Variable:
        # The other operand is deliberately kept on the right of the product.
        product = self.value * value
        return Variable(value=product, name=self.name, update_fn=self.update_fn)

    # ~~~~~~
    # PyTree
    # ~~~~~~

    @classmethod
    def _tree_unflatten(
        cls,
        aux_data: tuple[str, UpdateFn],
        children: tuple[Tensor, ...],
    ) -> Variable:  # pragma: no cover
        # Rebuild without running ``__init__``: the first child is the value and the
        # auxiliary data carries ``(name, update_fn)``.
        rebuilt = object.__new__(cls)
        rebuilt.value = children[0]
        restored_name, restored_update_fn = aux_data
        rebuilt.name = restored_name
        rebuilt.update_fn = restored_update_fn
        return rebuilt

    def _tree_flatten(self) -> tuple[tuple[Tensor, ...], tuple[str, UpdateFn]]:  # pragma: no cover
        # Children: the value tensor. Auxiliary data: ``(name, update_fn)``.
        return (self.value,), (self.name, self._update_fn)
_modules/mrmustard/parameters/parameters
Download Python script
Download Notebook
View on GitHub