Optimizer
- class mrmustard.training.optimizer.Optimizer(euclidean_lr=0.001, symplectic_lr=0.001, unitary_lr=0.001, orthogonal_lr=0.1, siegel_lr=0.001, euclidean_optimizer=None, symplectic_optimizer=None, unitary_optimizer=None, orthogonal_optimizer=None, siegel_optimizer=None, stable_threshold=1e-06)
A JAX-based optimizer for any parametrized object.
- Parameters:
euclidean_lr (float) – The learning rate for euclidean parameters.
symplectic_lr (float) – The learning rate for symplectic parameters.
unitary_lr (float) – The learning rate for unitary parameters.
orthogonal_lr (float) – The learning rate for orthogonal parameters.
siegel_lr (float) – The learning rate for Siegel-disk parameters.
euclidean_optimizer (type[GradientTransformation] | None) – The optax optimizer class for euclidean updates (default: optax.adabelief).
symplectic_optimizer (type[GradientTransformation] | None) – The optax optimizer class for symplectic updates (default: optax.adabelief).
unitary_optimizer (type[GradientTransformation] | None) – The optax optimizer class for unitary updates (default: optax.adabelief).
orthogonal_optimizer (type[GradientTransformation] | None) – The optax optimizer class for orthogonal updates (default: optax.adabelief).
siegel_optimizer (type[GradientTransformation] | None) – The optax optimizer class for Siegel-disk updates (default: optax.adabelief).
stable_threshold (float) – The threshold for the loss to be considered stable.
- Raises:
ValueError – If the set backend is not “jax”.
Example
Using different optimizers for different parameter types:
>>> from mrmustard.training import Optimizer
>>> import optax
>>> opt = Optimizer(
...     euclidean_lr=0.01,
...     symplectic_lr=0.001,
...     euclidean_optimizer=optax.adamw,
...     symplectic_optimizer=optax.adam,
... )
- make_step(optim, cost_fn, by_optimizing, opt_state)
Performs a single optimization step.
- Parameters:
optim (GradientTransformation) – The optimizer to use.
cost_fn (Callable) – The cost function to minimize.
by_optimizing (Sequence[Variable | CircuitComponent]) – The items to optimize.
opt_state (OptState) – The current state of the optimizer.
- Returns:
The updated by_optimizing, the updated optimizer state, and the loss value.
- Return type:
tuple[Sequence[Variable | CircuitComponent], OptState, float]
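The signature above follows a functional pattern: the step receives the current optimizer state and returns a new one together with the updated items and the loss. The following is a minimal pure-Python sketch of that pattern only, not MrMustard's or optax's implementation. The names `quadratic_cost`, `quadratic_grad`, and `sgd_momentum_step` are hypothetical, the gradient is written by hand instead of being derived by JAX, and plain momentum stands in for the optax transformation.

```python
def quadratic_cost(params):
    """Toy cost with its minimum at params == [1.0, -2.0]."""
    return (params[0] - 1.0) ** 2 + (params[1] + 2.0) ** 2


def quadratic_grad(params):
    """Hand-written gradient of quadratic_cost (an assumption; JAX would derive it)."""
    return [2.0 * (params[0] - 1.0), 2.0 * (params[1] + 2.0)]


def sgd_momentum_step(params, opt_state, lr=0.1, momentum=0.5):
    """Return (new_params, new_opt_state, loss) without mutating the inputs."""
    loss = quadratic_cost(params)  # loss is evaluated before the update
    grads = quadratic_grad(params)
    # The optimizer state here is just the momentum buffer, one entry per parameter.
    new_state = [momentum * v + g for v, g in zip(opt_state, grads)]
    new_params = [p - lr * v for p, v in zip(params, new_state)]
    return new_params, new_state, loss


params, opt_state = [0.0, 0.0], [0.0, 0.0]
losses = []
for _ in range(50):
    # Thread the state through each call, as the make_step signature suggests.
    params, opt_state, loss = sgd_momentum_step(params, opt_state)
    losses.append(loss)
```

Returning the state instead of storing it on the optimizer keeps each step a pure function of its inputs, which is the style that JAX transformations generally expect.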
- minimize(cost_fn, by_optimizing, max_steps=1000, parameter_history=False)
Minimizes the given cost function by optimizing Variables. If a ParameterDict is provided, it will return a ParameterDict with the optimized variables.
- Parameters:
cost_fn (Callable) – A function that will be executed in a differentiable context in order to compute gradients as needed.
by_optimizing (Sequence[Variable] | ParameterDict) – The Variables to optimize, given as a sequence or as a ParameterDict.
max_steps (int) – The maximum number of steps. The minimization runs until the loss is stable or max_steps is reached; if max_steps=0, it stops only when the loss is stable.
parameter_history (bool) – If True, store intermediate variable values at each step in self.parameter_history, a dict keyed by variable name with numpy arrays of shape (n_steps+1, *var_shape).
- Returns:
The list of optimized Variables or the ParameterDict with the optimized variables.
- Return type:
Sequence[Variable] | ParameterDict
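To make the layout of the parameter history concrete, here is a small pure-Python sketch, not the library's code. It assumes that the extra entry in the `n_steps+1` leading dimension is the initial value recorded before the first step, which is one plausible reading of the shape given above. The helper names `minimize_with_history` and `toy_grad`, the variable names `"x"` and `"y"`, and the use of lists of scalars in place of numpy arrays are all illustrative assumptions.

```python
def minimize_with_history(grad_fn, variables, lr=0.1, n_steps=5):
    """Gradient descent on a dict of named scalars, recording every value."""
    # One entry for the initial value, then one per step: n_steps + 1 in total.
    history = {name: [value] for name, value in variables.items()}
    current = dict(variables)
    for _ in range(n_steps):
        grads = grad_fn(current)
        current = {name: current[name] - lr * grads[name] for name in current}
        for name, value in current.items():
            history[name].append(value)
    return current, history


def toy_grad(values):
    """Gradient of (x - 3)^2 + y^2 for the hypothetical variables 'x' and 'y'."""
    return {"x": 2.0 * (values["x"] - 3.0), "y": 2.0 * values["y"]}


optimized, history = minimize_with_history(toy_grad, {"x": 0.0, "y": 1.0})
```

With array-valued variables, stacking the per-step entries along a new leading axis would give the `(n_steps+1, *var_shape)` shape described above.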
- should_stop(max_steps)
Returns a boolean indicating whether the optimization should stop. An optimization should stop either because the loss is stable or because the maximum number of steps is reached.
- Parameters:
max_steps (int) – The maximum number of steps to run.
- Returns:
A boolean indicating whether the optimization should stop.
- Return type:
bool
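The two stopping conditions can be sketched as a standalone predicate. This is an illustration under stated assumptions, not the method's actual body: the source does not say how stability is measured, so the sketch assumes the loss counts as stable when the summed absolute change over the last few steps falls below `stable_threshold`. The `window` parameter and the explicit `loss_history` argument are hypothetical.

```python
def should_stop(loss_history, max_steps, stable_threshold=1e-6, window=3):
    """Stop when the step budget is used up or recent losses have stopped changing."""
    # max_steps == 0 disables the step budget, leaving only the stability check.
    if max_steps != 0 and len(loss_history) >= max_steps:
        return True
    if len(loss_history) > window:
        recent = loss_history[-(window + 1):]
        # Assumed stability measure: total absolute change across the window.
        total_change = sum(abs(b - a) for a, b in zip(recent, recent[1:]))
        return total_change < stable_threshold
    return False


still_changing = [1.0, 0.5, 0.25, 0.125, 0.0625]
flat = [1.0, 0.2, 0.2, 0.2, 0.2]

stop_on_budget = should_stop(still_changing, max_steps=5)
stop_on_stability = should_stop(flat, max_steps=0)
keep_going = should_stop(still_changing, max_steps=0)
```

Here `stop_on_budget` and `stop_on_stability` are `True`, while `keep_going` is `False` because the loss is still decreasing and no step budget applies.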