# Source code for mrmustard.training.optimizer

# Copyright 2025 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A Jax based optimizer for any parametrized object."""

from __future__ import annotations

from collections.abc import Callable, Sequence

import numpy as np

from mrmustard import math, settings
from mrmustard.lab import CircuitComponent
from mrmustard.parameters import ParameterDict, Variable
from mrmustard.training.progress_bar import ProgressBar
from mrmustard.utils.logger import create_logger

try:
    import equinox as eqx
    import jax
    import optax
    from optax import GradientTransformation, OptState, multi_transform

    from mrmustard.training.parameter_update import (
        update_orthogonal,
        update_siegel,
        update_symplectic,
        update_unitary,
    )

except ImportError:
    raise ImportError(
        "Optimizer only supports the Jax backend. Please install the `jax_backend` group using `uv pip install -g jax_backend` and set the backend to Jax using `math.change_backend('jax')`."
    ) from None

__all__ = ["Optimizer"]


class Optimizer:
    r"""A Jax based optimizer for any parametrized object.

    Each trainable ``Variable`` is routed to an optimizer according to the string
    form of its ``update_fn`` attribute (``"update_euclidean"``, ``"update_symplectic"``,
    ``"update_unitary"``, ``"update_orthogonal"`` or ``"update_siegel"``).

    Args:
        euclidean_lr: The learning rate for euclidean parameters.
        symplectic_lr: The learning rate for symplectic parameters.
        unitary_lr: The learning rate for unitary parameters.
        orthogonal_lr: The learning rate for orthogonal parameters.
        siegel_lr: The learning rate for Siegel-disk parameters.
        euclidean_optimizer: The optax optimizer factory for euclidean updates; it is
            called as ``euclidean_optimizer(learning_rate=euclidean_lr)``
            (default: ``optax.adabelief``).
        symplectic_optimizer: The optax optimizer factory for symplectic updates
            (default: ``optax.adabelief``).
        unitary_optimizer: The optax optimizer factory for unitary updates
            (default: ``optax.adabelief``).
        orthogonal_optimizer: The optax optimizer factory for orthogonal updates
            (default: ``optax.adabelief``).
        siegel_optimizer: The optax optimizer factory for Siegel-disk updates
            (default: ``optax.adabelief``).
        stable_threshold: The threshold for the loss to be considered stable. The
            optimization stops once the summed absolute step-to-step change of the
            loss over the most recent recorded values falls below this number.

    Raises:
        ValueError: If the set backend is not "jax".

    Example:
        Using different optimizers for different parameter types:

        >>> from mrmustard.training import Optimizer
        >>> import optax
        >>> opt = Optimizer(
        ...     euclidean_lr=0.01,
        ...     symplectic_lr=0.001,
        ...     euclidean_optimizer=optax.adamw,
        ...     symplectic_optimizer=optax.adam,
        ... )
    """

    # Number of most recent loss values inspected by ``should_stop`` when deciding
    # whether the loss has become stable.
    _STABLE_WINDOW: int = 20

    def __init__(
        self,
        euclidean_lr: float = 0.001,
        symplectic_lr: float = 0.001,
        unitary_lr: float = 0.001,
        orthogonal_lr: float = 0.1,
        siegel_lr: float = 0.001,
        euclidean_optimizer: Callable[..., GradientTransformation] | None = None,
        symplectic_optimizer: Callable[..., GradientTransformation] | None = None,
        unitary_optimizer: Callable[..., GradientTransformation] | None = None,
        orthogonal_optimizer: Callable[..., GradientTransformation] | None = None,
        siegel_optimizer: Callable[..., GradientTransformation] | None = None,
        stable_threshold: float = 1e-6,
    ):
        if math.backend_name != "jax":
            raise ValueError(
                "Optimizer only supports the Jax backend. "
                "Please set the backend to Jax using `math.change_backend('jax')`.",
            )

        self.euclidean_lr = euclidean_lr
        self.symplectic_lr = symplectic_lr
        self.unitary_lr = unitary_lr
        self.orthogonal_lr = orthogonal_lr
        self.siegel_lr = siegel_lr

        # Any optimizer factory left unset falls back to ``optax.adabelief``.
        self.euclidean_optimizer = euclidean_optimizer or optax.adabelief
        self.symplectic_optimizer = symplectic_optimizer or optax.adabelief
        self.unitary_optimizer = unitary_optimizer or optax.adabelief
        self.orthogonal_optimizer = orthogonal_optimizer or optax.adabelief
        self.siegel_optimizer = siegel_optimizer or optax.adabelief

        # ``opt_history`` is seeded with a placeholder entry, so its length is always
        # (number of completed steps + 1).
        self.opt_history = [0]
        self.parameter_history: dict[str, np.ndarray] = {}
        self.log = create_logger(__name__)
        self.stable_threshold = stable_threshold

    @eqx.filter_jit
    def make_step(
        self,
        optim: GradientTransformation,
        cost_fn: Callable,
        by_optimizing: Sequence[Variable | CircuitComponent],
        opt_state: OptState,
    ) -> tuple[Sequence[Variable | CircuitComponent], OptState, float]:
        r"""Make a step of the optimization.

        Args:
            optim: The optimizer to use.
            cost_fn: The cost function to minimize. It is called with the items of
                ``by_optimizing`` unpacked as positional arguments.
            by_optimizing: The items to optimize.
            opt_state: The current state of the optimizer.

        Returns:
            The updated by_optimizing, the updated optimizer state, and the loss value.
        """
        # Differentiate with respect to every positional argument of ``cost_fn``.
        loss_value, grads = jax.value_and_grad(cost_fn, argnums=tuple(range(len(by_optimizing))))(
            *by_optimizing,
        )
        # NOTE(review): gradients are conjugated before being handed to optax,
        # presumably to account for JAX's convention for complex-valued inputs.
        conj_grads = jax.tree.map(jax.numpy.conj, grads)
        updates, opt_state = optim.update(conj_grads, opt_state, by_optimizing)
        by_optimizing = eqx.apply_updates(by_optimizing, updates)
        return by_optimizing, opt_state, loss_value

    def minimize(
        self,
        cost_fn: Callable,
        by_optimizing: Sequence[Variable] | ParameterDict,
        max_steps: int = 1000,
        parameter_history: bool = False,
    ) -> Sequence[Variable] | ParameterDict:
        r"""Minimizes the given cost function by optimizing ``Variable``\ s.

        If a ``ParameterDict`` is provided, it will return a ``ParameterDict`` with the
        optimized variables.

        Calling this method resets ``self.opt_history`` and ``self.parameter_history``.

        Args:
            cost_fn: A function that will be executed in a differentiable context in
                order to compute gradients as needed.
            by_optimizing: A sequence of ``Variable``\ s to optimize, or a
                ``ParameterDict`` whose variables are optimized.
            max_steps: The minimization keeps going until the loss is stable or
                max_steps are reached (if ``max_steps=0``\ , it will only stop when
                the loss is stable).
            parameter_history: If True, store intermediate variable values at each
                step in ``self.parameter_history``, a dict keyed by variable name with
                numpy arrays of shape ``(n_steps+1, *var_shape)``.

        Returns:
            The optimized ``Variable``\ s or the ``ParameterDict`` with the optimized
            variables.
        """
        if isinstance(by_optimizing, ParameterDict):
            return_param_dict = True
            by_optimizing = by_optimizing.variables.values()
        else:
            return_param_dict = False
            by_optimizing = tuple(by_optimizing)

        # Start every run from a clean history.
        self.opt_history = [0]
        self.parameter_history = {}

        if settings.PROGRESSBAR:
            progress_bar = ProgressBar(max_steps)
            with progress_bar:
                by_optimizing = self._optimization_loop(
                    cost_fn,
                    by_optimizing,
                    max_steps=max_steps,
                    parameter_history=parameter_history,
                    progress_bar=progress_bar,
                )
        else:
            by_optimizing = self._optimization_loop(
                cost_fn,
                by_optimizing,
                max_steps=max_steps,
                parameter_history=parameter_history,
            )

        return ParameterDict(*by_optimizing) if return_param_dict else by_optimizing

    def should_stop(self, max_steps: int) -> bool:
        r"""Returns a boolean indicating whether the optimization should stop.

        An optimization should stop either because the loss is stable or because the
        maximum number of steps is reached.

        Args:
            max_steps: The maximum number of steps to run. A value of ``0`` disables
                the step limit.

        Returns:
            A boolean indicating whether the optimization should stop.
        """
        history = self.opt_history
        window = self._STABLE_WINDOW

        if max_steps != 0 and len(history) > max_steps:
            return True

        # Stability check: requires more than ``window`` recorded entries, which keeps
        # the placeholder at index 0 out of the comparison. The sum covers the
        # ``window - 1`` consecutive differences among the last ``window`` values.
        if len(history) > window:
            total_variation = sum(
                abs(history[-i - 1] - history[-i]) for i in range(1, window)
            )
            if total_variation < self.stable_threshold:
                self.log.info("Loss looks stable, stopping here.")
                return True

        return False

    def _optimization_loop(
        self,
        cost_fn: Callable,
        by_optimizing: Sequence[Variable],
        max_steps: int,
        parameter_history: bool = False,
        progress_bar: ProgressBar | None = None,
    ) -> Sequence[Variable]:
        r"""The core optimization loop.

        Builds a ``multi_transform`` optimizer that dispatches each ``Variable`` to the
        transformation matching its ``update_fn`` label, then repeatedly calls
        ``make_step`` until ``should_stop`` returns True or the user interrupts.

        Args:
            cost_fn: The cost function to minimize.
            by_optimizing: The ``Variable``\ s to optimize.
            max_steps: The step limit forwarded to ``should_stop``.
            parameter_history: If True, record the value of every variable before the
                first step and after each step into ``self.parameter_history``.
            progress_bar: An optional progress bar that is advanced with the loss
                value after every step.

        Returns:
            The ``Variable``\ s as they are after the last completed step.
        """
        by_optimizing = tuple(by_optimizing)

        # One label per Variable (Variables are treated as leaves); the label selects
        # the entry of the transformation dict below.
        labels_pytree = jax.tree_util.tree_map(
            lambda node: str(node.update_fn),
            by_optimizing,
            is_leaf=lambda n: isinstance(n, Variable),
        )

        optim = multi_transform(
            {
                "update_euclidean": self.euclidean_optimizer(learning_rate=self.euclidean_lr),
                "update_symplectic": update_symplectic(
                    self.symplectic_lr, self.symplectic_optimizer
                ),
                "update_unitary": update_unitary(self.unitary_lr, self.unitary_optimizer),
                "update_orthogonal": update_orthogonal(
                    self.orthogonal_lr, self.orthogonal_optimizer
                ),
                "update_siegel": update_siegel(self.siegel_lr, self.siegel_optimizer),
            },
            labels_pytree,
        )
        opt_state = optim.init(by_optimizing)

        # NOTE(review): snapshots are keyed by ``v.name``; variables sharing a name
        # would write into the same list -- confirm names are unique among the inputs.
        snapshots: dict[str, list[np.ndarray]] | None = (
            {v.name: [math.asnumpy(v.value)] for v in by_optimizing}
            if parameter_history
            else None
        )

        # optimize
        try:
            while not self.should_stop(max_steps):
                by_optimizing, opt_state, loss_value = self.make_step(
                    optim,
                    cost_fn,
                    by_optimizing,
                    opt_state,
                )
                if snapshots is not None:
                    for v in by_optimizing:
                        snapshots[v.name].append(math.asnumpy(v.value))
                self.opt_history.append(loss_value)
                if progress_bar is not None:
                    progress_bar.step(math.asnumpy(loss_value))
        except KeyboardInterrupt:
            # A manual interrupt is not an error: keep whatever progress was made.
            self.log.info(
                "Optimization interrupted at step %d. Returning current values.",
                len(self.opt_history) - 1,
            )

        if snapshots is not None:
            self.parameter_history = {name: np.stack(vals) for name, vals in snapshots.items()}

        return by_optimizing