Parameter Updates

riemannian_gradient

Transforms Euclidean gradients into Riemannian gradients (Lie-algebra elements).

riemannian_retraction

Lie-group retraction S_new = S @ expm(update), used for the unitary group U(N), the orthogonal group O(N) and the symplectic group Sp.

update_orthogonal

Creates an optax GradientTransformation for orthogonal parameter updates using Riemannian optimization.

update_symplectic

Creates an optax GradientTransformation for symplectic parameter updates using Riemannian optimization.

update_unitary

Creates an optax GradientTransformation for unitary parameter updates using Riemannian optimization.

riemannian_gradient

mrmustard.training.parameter_update.riemannian_gradient(method)

Transforms Euclidean gradients into Riemannian gradients (Lie-algebra elements).

Parameters:
  • method (str)
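
As a rough illustration of what "Lie-algebra element" means here, the sketch below projects a Euclidean gradient G at a unitary (or real orthogonal) point U onto the anti-Hermitian matrices, the Lie algebra of U(N). It assumes the real inner-product convention df = Re tr(G^H dU) and the retraction U @ expm(X); the helper name lie_algebra_gradient is hypothetical, the method argument is not modelled, and this is not the library's implementation. The finite-difference check confirms that the projected element reproduces the directional derivative along any anti-Hermitian direction.

```python
import numpy as np
from scipy.linalg import expm


def lie_algebra_gradient(U, G):
    """Hypothetical helper: project the Euclidean gradient G at U onto u(N).

    Assumes G is defined by df = Re tr(G^H dU), so that d/dt f(U @ expm(t X))
    at t = 0 equals Re tr(A^H X) for every anti-Hermitian X.
    For real orthogonal U and real G this reduces to skew(U^T G).
    """
    M = U.conj().T @ G
    return 0.5 * (M - M.conj().T)


rng = np.random.default_rng(0)
n = 4
# Random unitary point from the QR decomposition of a complex Gaussian matrix.
U, _ = np.linalg.qr(rng.normal(size=(n, n)) + 1j * rng.normal(size=(n, n)))
C = rng.normal(size=(n, n)) + 1j * rng.normal(size=(n, n))


def f(V):
    # Linear test objective whose Euclidean gradient (in the convention above) is C.
    return np.real(np.trace(C.conj().T @ V))


A = lie_algebra_gradient(U, C)

# Compare <A, X> with a central finite difference along a random anti-Hermitian X.
Y = rng.normal(size=(n, n)) + 1j * rng.normal(size=(n, n))
X = 0.5 * (Y - Y.conj().T)
t = 1e-6
finite_diff = (f(U @ expm(t * X)) - f(U @ expm(-t * X))) / (2 * t)
inner = np.real(np.trace(A.conj().T @ X))
print(np.isclose(finite_diff, inner, atol=1e-6))  # True
```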

riemannian_retraction

mrmustard.training.parameter_update.riemannian_retraction()

Lie-group retraction S_new = S @ expm(update), used for the unitary group U(N), the orthogonal group O(N) and the symplectic group Sp.
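
The retraction keeps each iterate on its group as long as the update lies in the corresponding Lie algebra: anti-Hermitian matrices for U(N), real skew-symmetric matrices for O(N), and matrices of the form J @ H with H symmetric for the real symplectic group, where J = [[0, I], [-I, 0]]. The NumPy/SciPy sketch below checks this for all three cases. The function name retract and the choice of J are assumptions for illustration, not the library's code.

```python
import numpy as np
from scipy.linalg import expm


def retract(S, X):
    """Hypothetical stand-in for the Lie-group retraction S_new = S @ expm(X)."""
    return S @ expm(X)


rng = np.random.default_rng(1)
n = 3

# U(N): the Lie algebra is the anti-Hermitian matrices.
U, _ = np.linalg.qr(rng.normal(size=(n, n)) + 1j * rng.normal(size=(n, n)))
Y = rng.normal(size=(n, n)) + 1j * rng.normal(size=(n, n))
U_new = retract(U, 0.1 * (Y - Y.conj().T))

# O(N): the Lie algebra is the real skew-symmetric matrices.
O, _ = np.linalg.qr(rng.normal(size=(n, n)))
K = rng.normal(size=(n, n))
O_new = retract(O, 0.1 * (K - K.T))

# Sp(2N, R): assuming J = [[0, I], [-I, 0]], the Lie algebra is {J @ H : H symmetric}.
I_n, Z_n = np.eye(n), np.zeros((n, n))
J = np.block([[Z_n, I_n], [-I_n, Z_n]])
H0, H1 = rng.normal(size=(2, 2 * n, 2 * n))
S = expm(0.1 * J @ (H0 + H0.T))  # a symplectic starting point
S_new = retract(S, 0.1 * J @ (H1 + H1.T))

print(np.allclose(U_new.conj().T @ U_new, np.eye(n)))  # True
print(np.allclose(O_new.T @ O_new, np.eye(n)))  # True
print(np.allclose(S_new.T @ J @ S_new, J))  # True
```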

update_orthogonal

mrmustard.training.parameter_update.update_orthogonal(orthogonal_lr, optimizer_cls=optax.adabelief)

Creates an optax GradientTransformation for orthogonal parameter updates using Riemannian optimization.

Parameters:
  • orthogonal_lr (float | Callable[[int], float]) – The learning rate for orthogonal updates.

  • optimizer_cls (Callable) – The optimizer constructor to use (default: optax.adabelief).

Returns:

An optax.GradientTransformation for orthogonal updates.
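
The ingredients on this page (Riemannian gradient, Lie-algebra update, retraction) can be chained into a plain Riemannian gradient-descent loop on O(N). The NumPy sketch below is not the library's code: a fixed float learning rate lr stands in for orthogonal_lr, plain gradient descent replaces an optax optimizer such as optax.adabelief, and the objective f(O) = -tr(T^T O), which is minimised at O = T, is a made-up example.

```python
import numpy as np
from scipy.linalg import expm

rng = np.random.default_rng(2)
n = 4
T = np.eye(n)  # illustrative target orthogonal matrix
K = rng.normal(size=(n, n))
O = expm(0.3 * (K - K.T))  # orthogonal starting point, rotated away from T
lr = 0.5  # hypothetical stand-in for a float orthogonal_lr

losses = []
for step in range(200):
    losses.append(-np.trace(T.T @ O))
    G = -T  # Euclidean gradient of f(O) = -tr(T^T O)
    M = O.T @ G
    A = 0.5 * (M - M.T)  # Riemannian gradient as a skew-symmetric Lie-algebra element
    O = O @ expm(-lr * A)  # plain descent step followed by the Lie-group retraction

print(losses[0], losses[-1])
```

Because each update is a skew-symmetric matrix, every iterate stays exactly orthogonal (up to floating-point error), so no re-orthogonalisation step is needed.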

update_symplectic

mrmustard.training.parameter_update.update_symplectic(symplectic_lr, optimizer_cls=optax.adabelief)

Creates an optax GradientTransformation for symplectic parameter updates using Riemannian optimization.

Parameters:
  • symplectic_lr (float | Callable[[int], float]) – The learning rate for symplectic updates.

  • optimizer_cls (Callable) – The optimizer constructor to use (default: optax.adabelief).

Returns:

An optax.GradientTransformation for symplectic updates.
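
For the symplectic case the Lie algebra is sp(2N) = {J H : H symmetric}, with J = [[0, I], [-I, 0]]. The sketch below assumes the Frobenius inner product on the algebra, under which the projection of M = S^T G onto sp(2N) is J sym(J^T M) = (M + J M^T J) / 2. The metric actually used by update_symplectic may differ, and the helper name symplectic_lie_gradient is hypothetical. The script checks that the projected gradient lies in sp(2N), reproduces the directional derivative along algebra directions, and yields a symplectic point after one retraction step.

```python
import numpy as np
from scipy.linalg import expm


def symplectic_lie_gradient(S, G, J):
    """Hypothetical helper: project M = S^T G onto sp(2N) = {J @ H : H symmetric}.

    Assumes the Frobenius inner product, under which the closest element is
    J @ sym(J^T M), which simplifies to (M + J M^T J) / 2.
    """
    M = S.T @ G
    return 0.5 * (M + J @ M.T @ J)


rng = np.random.default_rng(3)
n = 2
I_n, Z_n = np.eye(n), np.zeros((n, n))
J = np.block([[Z_n, I_n], [-I_n, Z_n]])
H0 = rng.normal(size=(2 * n, 2 * n))
S = expm(0.2 * J @ (H0 + H0.T))  # symplectic point
C = rng.normal(size=(2 * n, 2 * n))  # Euclidean gradient of f(S) = tr(C^T S)

X = symplectic_lie_gradient(S, C, J)
in_algebra = np.allclose(X.T @ J + J @ X, 0)  # X is an element of sp(2N)

# <X, Y> equals the directional derivative tr(C^T S Y) for any Y in sp(2N).
H1 = rng.normal(size=(2 * n, 2 * n))
Y = J @ (H1 + H1.T)
matches = np.isclose(np.trace(X.T @ Y), np.trace(C.T @ S @ Y))

S_new = S @ expm(-0.05 * X)  # one descent step with the Lie-group retraction
print(in_algebra, matches, np.allclose(S_new.T @ J @ S_new, J))  # True True True
```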

update_unitary

mrmustard.training.parameter_update.update_unitary(unitary_lr, optimizer_cls=optax.adabelief)

Creates an optax GradientTransformation for unitary parameter updates using Riemannian optimization.

Parameters:
  • unitary_lr (float | Callable[[int], float]) – The learning rate for unitary updates.

  • optimizer_cls (Callable) – The optimizer constructor to use (default: optax.adabelief).

Returns:

An optax.GradientTransformation for unitary updates.
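
The learning-rate parameters also accept a Callable[[int], float]. The NumPy sketch below, which is not the library's implementation, runs Riemannian gradient descent on U(N) with such a callable: the hypothetical exponential-decay schedule unitary_schedule maps the step index to a learning rate. The objective f(U) = -Re tr(V^H U), minimised at the target unitary V, is made up for illustration, and plain gradient descent stands in for the optax optimizer.

```python
import numpy as np
from scipy.linalg import expm


def unitary_schedule(step: int) -> float:
    """Hypothetical exponential-decay learning-rate schedule (Callable[[int], float])."""
    return 0.5 * 0.99**step


rng = np.random.default_rng(4)
n = 4
V, _ = np.linalg.qr(rng.normal(size=(n, n)) + 1j * rng.normal(size=(n, n)))  # target unitary
Y = rng.normal(size=(n, n)) + 1j * rng.normal(size=(n, n))
U = V @ expm(0.1 * (Y - Y.conj().T))  # unitary starting point near V

losses = []
for step in range(200):
    losses.append(-np.real(np.trace(V.conj().T @ U)))
    G = -V  # Euclidean gradient of f(U) = -Re tr(V^H U)
    M = U.conj().T @ G
    A = 0.5 * (M - M.conj().T)  # anti-Hermitian Lie-algebra gradient
    U = U @ expm(-unitary_schedule(step) * A)  # scheduled step followed by the retraction

print(losses[0] > losses[-1], np.allclose(U.conj().T @ U, np.eye(n)))  # True True
```

Under the signature documented above, the same callable could be passed as update_unitary(unitary_lr=unitary_schedule).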