Source code for mrmustard.training.parameter_update
# Copyright 2025 Xanadu Quantum Technologies Inc.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Custom optax ``GradientTransformation``\ s for non-euclidean parameter updates."""
from collections.abc import Callable
import jax
import optax
from mrmustard import math
# Names made available by a star-import of this module.
__all__ = [
    "riemannian_gradient",
    "riemannian_retraction",
    "siegel_retract",
    "siegel_retraction",
    "update_orthogonal",
    "update_siegel",
    "update_symplectic",
    "update_unitary",
]
def _hpsd_eigh(H):
    r"""Eigendecompose a batched Hermitian positive-semidefinite matrix.

    The real part of each eigenvalue is floored at ``1e-12``. This absorbs
    tiny negative values produced by round-off and keeps negative powers of
    the spectrum finite.

    Args:
        H: The (batched) Hermitian positive-semidefinite matrix.

    Returns:
        The pair ``(evals, V)``: the clipped real eigenvalues and the
        eigenvectors as returned by ``math.eigh``.
    """
    raw_evals, eigvecs = math.eigh(H)
    clipped_evals = math.maximum(math.real(raw_evals), 1e-12)
    return clipped_evals, eigvecs
def _hpsd_power_from_eigh(evals, V, power: float):
    r"""Rebuild ``V diag(evals**power) V^H`` from an eigendecomposition.

    The inputs are expected to come from :func:`_hpsd_eigh`, i.e. the
    eigenvalues are already clipped away from zero.

    Args:
        evals: Eigenvalues (last axis indexes the spectrum).
        V: Matching eigenvectors, stored column-wise.
        power: Exponent applied element-wise to ``evals``.

    Returns:
        The matrix ``V diag(evals**power) V^H`` with the dtype of ``V``.
    """
    # Cast the powered spectrum so it can be contracted with ``V``.
    powered = math.cast(evals**power, V.dtype)
    V_conj = math.conj(V)
    return math.einsum("...ij,...j,...kj->...ik", V, powered, V_conj)
def _hpsd_power(H, power: float):
    r"""Raise a batched Hermitian positive-semidefinite matrix to ``power``.

    The power is evaluated spectrally: ``H`` is eigendecomposed with
    :func:`_hpsd_eigh` and reassembled with :func:`_hpsd_power_from_eigh`.

    Args:
        H: The (batched) Hermitian positive-semidefinite matrix.
        power: The exponent.

    Returns:
        The matrix ``H ** power``.
    """
    return _hpsd_power_from_eigh(*_hpsd_eigh(H), power)
def siegel_retract(Z, eta):
    r"""Bergman-geodesic retraction on the Siegel disk :math:`\mathcal{D}_g`.

    Moves from :math:`Z` along the Bergman geodesic with initial direction
    ``eta`` (complex symmetric) for unit time. The computation has three stages:

    1. Pull ``eta`` back to the origin,
       :math:`\xi_0 = A^{-1/2}\,\eta\,B^{-1/2}` with :math:`A = I - ZZ^*` and
       :math:`B = I - Z^*Z`.
    2. Follow the geodesic through the origin,
       :math:`V = U\,\mathrm{diag}(\tanh\sigma)\,U^T`, where
       :math:`\xi_0 = U\,\mathrm{diag}(\sigma)\,U^T` is the Takagi decomposition.
    3. Transvect the result back to :math:`Z` with the Möbius map

       .. math::
           \phi_Z(V) = A^{-1/2}\,(V + Z)\,(I + Z^* V)^{-1}\,B^{1/2}.

    Stage 2 avoids an explicit Takagi factorization by using the identity
    :math:`U\,\tanh(S)\,U^T = h(\xi_0\xi_0^*)\,\xi_0` with
    :math:`h(x) = \tanh(\sqrt{x})/\sqrt{x}`. Because the Hermitian function
    :math:`h(\xi_0\xi_0^*)` does not change under a unitary rotation inside an
    eigenspace of :math:`\xi_0\xi_0^*`, the result remains correct when the
    Takagi singular values of :math:`\xi_0` are degenerate.

    Args:
        Z: A point in :math:`\mathcal{D}_g` (complex symmetric).
        eta: A tangent vector at ``Z`` (complex symmetric).

    Returns:
        A complex symmetric matrix in :math:`\mathcal{D}_g`.
    """

    def _transpose(mat):
        # Plain (non-conjugating) transpose of the last two axes.
        return math.swapaxes(mat, -1, -2)

    def _symmetrize(mat):
        # Project onto complex-symmetric matrices to remove round-off asymmetry.
        return 0.5 * (mat + _transpose(mat))

    Z_dag = math.conj(_transpose(Z))
    identity = math.eye(Z.shape[-1], dtype=Z.dtype)

    # A^{-1/2}; B^{1/2} and B^{-1/2} share a single eigendecomposition of B.
    A_inv_sqrt = _hpsd_power(identity - math.matmul(Z, Z_dag), -0.5)
    b_evals, b_vecs = _hpsd_eigh(identity - math.matmul(Z_dag, Z))
    B_sqrt = _hpsd_power_from_eigh(b_evals, b_vecs, 0.5)
    B_inv_sqrt = _hpsd_power_from_eigh(b_evals, b_vecs, -0.5)

    # Stage 1: tangent vector pulled back to the origin.
    xi_origin = _symmetrize(math.matmul(A_inv_sqrt, math.matmul(eta, B_inv_sqrt)))

    # Stage 2: origin geodesic V = U tanh(S) U^T evaluated as h(xi xi^*) @ xi
    # with h(x) = tanh(sqrt(x)) / sqrt(x). Eigenvectors of xi xi^* belonging
    # to eigenvalue 0 are orthogonal to range(xi), so whatever finite value h
    # takes at sqrt(x) = 0 contributes exactly 0 to h(xi xi^*) @ xi. The
    # denominator is therefore just clamped with a tiny constant instead of
    # branching on the Takagi limit.
    xi_origin_dag = math.conj(_transpose(xi_origin))
    gram_evals, gram_vecs = math.eigh(math.matmul(xi_origin, xi_origin_dag))
    sing_vals = math.sqrt(math.maximum(math.real(gram_evals), 0.0))
    # Keep tanh strictly below 1 so that (I - ZZ*)^{-1/2} at the next iteration
    # stays well-defined in float64 when the optimizer takes a very long step.
    tanh_vals = math.minimum(math.tanh(sing_vals), 1.0 - 1e-7)
    h_spectrum = math.cast(tanh_vals / math.maximum(sing_vals, 1e-30), gram_vecs.dtype)
    h_of_gram = math.einsum(
        "...ij,...j,...kj->...ik", gram_vecs, h_spectrum, math.conj(gram_vecs)
    )
    V = _symmetrize(math.matmul(h_of_gram, xi_origin))

    # Stage 3: Möbius transvection back to Z.
    right_factor = math.matmul(math.inv(identity + math.matmul(Z_dag, V)), B_sqrt)
    moved = math.matmul(math.matmul(A_inv_sqrt, V + Z), right_factor)
    return _symmetrize(moved)
[docs]
def riemannian_gradient(method: str):
    r"""Transforms Euclidean gradients to Riemannian gradients (Lie Algebra elements).

    The conversion is applied leaf-by-leaf over the ``params`` and ``grads``
    pytrees and is chosen by ``method``:

    * ``"symplectic"``: ``math.euclidean_to_symplectic(param, grad)``
    * ``"unitary"``: ``math.euclidean_to_unitary(param, grad)``
    * ``"orthogonal"``: ``math.euclidean_to_unitary(param, real(grad))``
    * ``"siegel"``: ``math.euclidean_to_siegel(param, grad)``

    Any other value of ``method`` leaves the gradients unchanged.

    Args:
        method: Name of the conversion to apply (see the list above).

    Returns:
        An ``optax.GradientTransformation`` with an empty state. Its update
        function returns the incoming gradients untouched when ``params`` is
        ``None`` or omitted.
    """

    def init_fn(params):
        return optax.EmptyState()

    def update_fn(grads, state, params=None):
        # ``params`` defaults to ``None`` (as in optax's update-function
        # protocol) so that ``update(grads, state)`` reaches the pass-through
        # branch below instead of failing on a missing positional argument.
        if params is None:
            return grads, state

        # The converters are looked up at call time, exactly where the
        # per-method ``tree_map`` calls used to resolve them.
        if method == "symplectic":
            convert = math.euclidean_to_symplectic
        elif method == "unitary":
            convert = math.euclidean_to_unitary
        elif method == "orthogonal":

            def convert(param, grad):
                # Only the real part of the gradient is used for O(N).
                return math.euclidean_to_unitary(param, math.real(grad))

        elif method == "siegel":
            convert = math.euclidean_to_siegel
        else:
            # Unrecognized method: keep the Euclidean gradients as they are.
            return grads, state

        return jax.tree_util.tree_map(convert, params, grads), state

    return optax.GradientTransformation(init_fn, update_fn)
def _retraction(step_fn: Callable):
    r"""Wrap a per-leaf retraction into an optax ``GradientTransformation``.

    ``step_fn(p, u)`` must return the new point on the manifold reached from
    the current point ``p`` with the tangent-space step ``u``. Since optax
    composes additive updates downstream, the transformation emits the
    difference ``step_fn(p, u) - p`` for every leaf.

    Args:
        step_fn: The retraction, called as ``step_fn(p, u)``.

    Returns:
        An ``optax.GradientTransformation`` with an empty state.
    """

    def _additive_update(point, step):
        # Express the retracted point as an increment relative to ``point``.
        return step_fn(point, step) - point

    def init_fn(params):
        return optax.EmptyState()

    def update_fn(updates, state, params):
        return jax.tree_util.tree_map(_additive_update, params, updates), state

    return optax.GradientTransformation(init_fn, update_fn)
[docs]
def riemannian_retraction():
    r"""Lie-group retraction ``S_new = S @ expm(update)`` (used for U(N), O(N), Sp).

    Returns:
        The ``GradientTransformation`` built by :func:`_retraction` for this step.
    """

    def _group_step(point, tangent):
        # Right-multiply the current point by the exponential of the update.
        return math.matmul(point, math.expm(tangent))

    return _retraction(_group_step)
def siegel_retraction():
    r"""Bergman-geodesic retraction on the Siegel disk.

    Returns:
        The ``GradientTransformation`` built by :func:`_retraction` around
        :func:`siegel_retract`.
    """
    transformation = _retraction(siegel_retract)
    return transformation
def _riemannian_update(
    method: str,
    lr: float | Callable[[int], float],
    optimizer_cls: Callable,
    retraction: Callable[[], optax.GradientTransformation],
) -> optax.GradientTransformation:
    r"""Chain ``riemannian_gradient(method)``, ``optimizer_cls(lr)`` and ``retraction()``.

    Args:
        method: Forwarded to :func:`riemannian_gradient`.
        lr: Passed to ``optimizer_cls`` as ``learning_rate``.
        optimizer_cls: Factory called as ``optimizer_cls(learning_rate=lr)``.
        retraction: Zero-argument factory producing the final transformation.

    Returns:
        The three transformations combined, in that order, with ``optax.chain``.
    """
    gradient_stage = riemannian_gradient(method)
    optimizer_stage = optimizer_cls(learning_rate=lr)
    retraction_stage = retraction()
    return optax.chain(gradient_stage, optimizer_stage, retraction_stage)
[docs]
def update_orthogonal(
    orthogonal_lr: float | Callable[[int], float], optimizer_cls: Callable = optax.adabelief
):
    r"""Build the optax transformation used to update orthogonal parameters.

    The ``"orthogonal"`` Riemannian gradient is fed to the optimizer and the
    result is applied through :func:`riemannian_retraction`.

    Args:
        orthogonal_lr: The learning rate for orthogonal updates.
        optimizer_cls: The optimizer class to use (default: optax.adabelief).

    Returns:
        An optax.GradientTransformation for orthogonal updates.
    """
    return _riemannian_update(
        method="orthogonal",
        lr=orthogonal_lr,
        optimizer_cls=optimizer_cls,
        retraction=riemannian_retraction,
    )
[docs]
def update_symplectic(
    symplectic_lr: float | Callable[[int], float], optimizer_cls: Callable = optax.adabelief
):
    r"""Build the optax transformation used to update symplectic parameters.

    The ``"symplectic"`` Riemannian gradient is fed to the optimizer and the
    result is applied through :func:`riemannian_retraction`.

    Args:
        symplectic_lr: The learning rate for symplectic updates.
        optimizer_cls: The optimizer class to use (default: optax.adabelief).

    Returns:
        An optax.GradientTransformation for symplectic updates.
    """
    return _riemannian_update(
        method="symplectic",
        lr=symplectic_lr,
        optimizer_cls=optimizer_cls,
        retraction=riemannian_retraction,
    )
[docs]
def update_unitary(
    unitary_lr: float | Callable[[int], float], optimizer_cls: Callable = optax.adabelief
):
    r"""Build the optax transformation used to update unitary parameters.

    The ``"unitary"`` Riemannian gradient is fed to the optimizer and the
    result is applied through :func:`riemannian_retraction`.

    Args:
        unitary_lr: The learning rate for unitary updates.
        optimizer_cls: The optimizer class to use (default: optax.adabelief).

    Returns:
        An optax.GradientTransformation for unitary updates.
    """
    return _riemannian_update(
        method="unitary",
        lr=unitary_lr,
        optimizer_cls=optimizer_cls,
        retraction=riemannian_retraction,
    )
def update_siegel(
    siegel_lr: float | Callable[[int], float], optimizer_cls: Callable = optax.adabelief
):
    r"""Build the optax transformation used to update Siegel-disk parameters.

    Riemannian optimization with the Bergman metric: the ``"siegel"``
    Riemannian gradient is fed to the optimizer and the result is applied
    through :func:`siegel_retraction`.

    Args:
        siegel_lr: The learning rate for Siegel-disk updates.
        optimizer_cls: The optimizer class to use (default: optax.adabelief).

    Returns:
        An optax.GradientTransformation for Siegel-disk updates.
    """
    return _riemannian_update(
        method="siegel",
        lr=siegel_lr,
        optimizer_cls=optimizer_cls,
        retraction=siegel_retraction,
    )
_modules/mrmustard/training/parameter_update
Download Python script
Download Notebook
View on GitHub