Optimizer

A JAX-based optimizer for any parametrized object.

class mrmustard.training.optimizer.Optimizer(euclidean_lr=0.001, symplectic_lr=0.001, unitary_lr=0.001, orthogonal_lr=0.1, siegel_lr=0.001, euclidean_optimizer=None, symplectic_optimizer=None, unitary_optimizer=None, orthogonal_optimizer=None, siegel_optimizer=None, stable_threshold=1e-06)

Parameters:
  • euclidean_lr (float) – The learning rate for euclidean parameters.

  • symplectic_lr (float) – The learning rate for symplectic parameters.

  • unitary_lr (float) – The learning rate for unitary parameters.

  • orthogonal_lr (float) – The learning rate for orthogonal parameters.

  • siegel_lr (float) – The learning rate for Siegel-disk parameters.

  • euclidean_optimizer (type[GradientTransformation] | None) – The optax optimizer class for euclidean updates (default: optax.adabelief).

  • symplectic_optimizer (type[GradientTransformation] | None) – The optax optimizer class for symplectic updates (default: optax.adabelief).

  • unitary_optimizer (type[GradientTransformation] | None) – The optax optimizer class for unitary updates (default: optax.adabelief).

  • orthogonal_optimizer (type[GradientTransformation] | None) – The optax optimizer class for orthogonal updates (default: optax.adabelief).

  • siegel_optimizer (type[GradientTransformation] | None) – The optax optimizer class for Siegel-disk updates (default: optax.adabelief).

  • stable_threshold (float) – The threshold below which changes in the loss are treated as negligible, so that the loss is considered stable.

Raises:

ValueError – If the active backend is not “jax”.

Example

Using different optimizers for different parameter types:

>>> from mrmustard.training import Optimizer
>>> import optax
>>> opt = Optimizer(
...     euclidean_lr=0.01,
...     symplectic_lr=0.001,
...     euclidean_optimizer=optax.adamw,
...     symplectic_optimizer=optax.adam,
... )
make_step(optim, cost_fn, by_optimizing, opt_state)

Make a step of the optimization.

Parameters:
  • optim (GradientTransformation) – The optimizer to use.

  • cost_fn (Callable) – The cost function to minimize.

  • by_optimizing (Sequence[Variable | CircuitComponent]) – The items to optimize.

  • opt_state (OptState) – The current state of the optimizer.

Returns:

The updated by_optimizing, the updated optimizer state, and the loss value.

Return type:

tuple[Sequence[Variable | CircuitComponent], OptState, float]
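
The "parameters in, (parameters, state, loss) out" shape of a single step can be illustrated with a minimal plain-Python sketch. It is not the library's implementation: it uses a finite-difference gradient and a momentum update in place of JAX and optax, and the names numerical_grad and momentum_step are hypothetical.

```python
def numerical_grad(cost_fn, params, eps=1e-6):
    """Central finite-difference gradient of cost_fn at params (a list of floats)."""
    grads = []
    for i in range(len(params)):
        up, down = list(params), list(params)
        up[i] += eps
        down[i] -= eps
        grads.append((cost_fn(up) - cost_fn(down)) / (2 * eps))
    return grads


def momentum_step(cost_fn, params, state, lr=0.1, momentum=0.9):
    """One step: returns (updated params, updated state, loss before the update)."""
    loss = cost_fn(params)
    grads = numerical_grad(cost_fn, params)
    # The optimizer state here is just the running momentum buffer (an assumption
    # of this sketch; real optimizer states are optimizer-specific).
    new_state = [momentum * v + g for v, g in zip(state, grads)]
    new_params = [p - lr * v for p, v in zip(params, new_state)]
    return new_params, new_state, loss


# Toy cost with its minimum at 3.0.
cost = lambda p: (p[0] - 3.0) ** 2
params, state, loss = momentum_step(cost, [0.0], [0.0])
```

The caller is expected to feed the returned parameters and state back into the next call, which is why both are returned alongside the loss.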

minimize(cost_fn, by_optimizing, max_steps=1000, parameter_history=False)

Minimizes the given cost function by optimizing Variables. If a ParameterDict is provided, it will return a ParameterDict with the optimized variables.

Parameters:
  • cost_fn (Callable) – A function that will be executed in a differentiable context in order to compute gradients as needed.

  • by_optimizing (Sequence[Variable] | ParameterDict) – A sequence of Variables, or a ParameterDict, to optimize.

  • max_steps (int) – The maximum number of steps. The minimization runs until the loss is stable or max_steps is reached; if max_steps=0, it stops only when the loss is stable.

  • parameter_history (bool) – If True, store intermediate variable values at each step in self.parameter_history, a dict keyed by variable name with numpy arrays of shape (n_steps+1, *var_shape).

Returns:

The list of optimized Variables or the ParameterDict with the optimized variables.

Return type:

Sequence[Variable] | ParameterDict
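
To make the documented history layout concrete (one entry for the initial values plus one per step, hence n_steps+1 entries per variable), here is a minimal plain-Python sketch. It is an illustration under stated assumptions, not the library's code: it uses nested lists instead of numpy arrays and plain gradient descent, and the names run_with_history and toy_grads are hypothetical.

```python
def run_with_history(grad_fn, variables, lr, n_steps):
    """Gradient descent on a dict of named variables, recording every value.

    variables maps a name to a list of floats; grad_fn maps such a dict to a
    dict of gradients with the same layout.
    """
    current = {name: list(value) for name, value in variables.items()}
    # Entry 0 of each history is the initial value, before any step is taken.
    history = {name: [list(value)] for name, value in current.items()}
    for _ in range(n_steps):
        grads = grad_fn(current)
        for name in current:
            current[name] = [v - lr * g for v, g in zip(current[name], grads[name])]
            history[name].append(list(current[name]))
    return current, history


def toy_grads(variables):
    """Gradient of sum((v - 1)**2) over all entries of all variables."""
    return {name: [2.0 * (v - 1.0) for v in value] for name, value in variables.items()}


final, history = run_with_history(
    toy_grads, {"r": [0.0, 0.0], "phi": [2.0]}, lr=0.1, n_steps=3
)
```

With three steps, history["r"] holds four rows of two values each, the analogue of an array of shape (n_steps+1, *var_shape) = (4, 2).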

should_stop(max_steps)

Returns a boolean indicating whether the optimization should stop. An optimization should stop either because the loss is stable or because the maximum number of steps is reached.

Parameters:

max_steps (int) – The maximum number of steps to run.

Returns:

A boolean indicating whether the optimization should stop.

Return type:

bool
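
The two stopping conditions can be sketched in plain Python as follows. This is a hedged illustration rather than the library's logic: it assumes that "stable" means the total absolute change of the loss over a recent window falls below the threshold, and the window size, the loss_history argument, and the name should_stop_sketch are all hypothetical.

```python
def should_stop_sketch(loss_history, max_steps, stable_threshold=1e-6, window=20):
    """Return True if the step limit is reached or the loss looks stable.

    loss_history holds one loss value per completed step.
    """
    # max_steps == 0 disables the step limit, so only stability can stop the run.
    if max_steps != 0 and len(loss_history) >= max_steps:
        return True
    # Only judge stability once a full window of losses is available.
    if len(loss_history) > window:
        recent = loss_history[-window:]
        total_change = sum(abs(b - a) for a, b in zip(recent, recent[1:]))
        if total_change < stable_threshold:
            return True
    return False


flat = [1.0] * 25                           # loss no longer changes
falling = [float(25 - i) for i in range(25)]  # loss still decreasing

stop_when_flat = should_stop_sketch(flat, max_steps=0)
stop_when_falling = should_stop_sketch(falling, max_steps=0)
stop_at_limit = should_stop_sketch(falling[:10], max_steps=10)
```

Summing changes over a window, rather than comparing only the last two losses, makes the check less likely to trigger on a single step where the loss happens to barely move.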