Parameter Updates
riemannian_gradient | Transforms Euclidean gradients into Riemannian gradients (Lie-algebra elements).
riemannian_retraction | Lie-group retraction.
update_orthogonal | Creates an optax GradientTransformation for orthogonal parameter updates using Riemannian optimization.
update_symplectic | Creates an optax GradientTransformation for symplectic parameter updates using Riemannian optimization.
update_unitary | Creates an optax GradientTransformation for unitary parameter updates using Riemannian optimization.
riemannian_gradient
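riemannian_gradient turns Euclidean gradients into Lie-algebra elements. Below is a minimal NumPy sketch of that step for the three matrix groups on this page. It assumes the gradient G at a point X is first pulled back to the identity (Z = Xᵀ G, or X^H G for unitaries) and then orthogonally projected onto the group's Lie algebra. The helper names and the symplectic form J = [[0, I], [-I, 0]] (xxpp ordering) are assumptions for illustration, not mrmustard's API or its exact formulas.

```python
import numpy as np


def symplectic_form(n: int) -> np.ndarray:
    """Standard 2n x 2n symplectic form J = [[0, I], [-I, 0]] (xxpp ordering assumed)."""
    eye, zero = np.eye(n), np.zeros((n, n))
    return np.block([[zero, eye], [-eye, zero]])


def orthogonal_riemannian_gradient(O: np.ndarray, G: np.ndarray) -> np.ndarray:
    """Project O^T G onto the skew-symmetric matrices, the Lie algebra of O(n)."""
    Z = O.T @ G
    return 0.5 * (Z - Z.T)


def unitary_riemannian_gradient(U: np.ndarray, G: np.ndarray) -> np.ndarray:
    """Project U^H G onto the anti-Hermitian matrices, the Lie algebra of U(n)."""
    Z = U.conj().T @ G
    return 0.5 * (Z - Z.conj().T)


def symplectic_riemannian_gradient(S: np.ndarray, G: np.ndarray) -> np.ndarray:
    """Project S^T G onto sp(2n) = {X : X^T J + J X = 0}."""
    J = symplectic_form(S.shape[0] // 2)
    Z = S.T @ G
    return 0.5 * (Z + J @ Z.T @ J)
```

Each projection is orthogonal with respect to the real Frobenius inner product. The result is therefore the closest Lie-algebra element to the pulled-back gradient, and stepping against it is a descent direction.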
riemannian_retraction
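A Lie-group retraction maps a Lie-algebra element back onto the group. One common choice is the exponential-map retraction X ↦ X exp(-lr · A). The exponential of a skew-symmetric (or Hamiltonian) matrix is orthogonal (or symplectic), so the product stays on the group. The sketch below assumes a retraction of this kind; mrmustard's exact form may differ. `expm` and `lie_group_retraction` are hypothetical helpers, and in practice `scipy.linalg.expm` or `jax.scipy.linalg.expm` would replace the hand-rolled exponential.

```python
import numpy as np


def expm(A: np.ndarray, terms: int = 30) -> np.ndarray:
    """Matrix exponential by scaling and squaring with a truncated Taylor series."""
    norm = np.linalg.norm(A, ord=1)
    # Choose s so that ||A / 2**s||_1 <= 0.5, where the Taylor series converges fast.
    s = max(0, int(np.ceil(np.log2(norm / 0.5)))) if norm > 0 else 0
    A_scaled = A / (2**s)
    result = np.eye(A.shape[0], dtype=np.result_type(A, float))
    term = result.copy()
    for k in range(1, terms):
        term = term @ A_scaled / k  # k-th Taylor term A^k / k!
        result = result + term
    for _ in range(s):  # undo the scaling: exp(A) = exp(A / 2**s) ** (2**s)
        result = result @ result
    return result


def lie_group_retraction(X: np.ndarray, A: np.ndarray, lr: float) -> np.ndarray:
    """Move X along the one-parameter subgroup generated by -A: X @ exp(-lr * A)."""
    return X @ expm(-lr * A)
```

The retraction is exact on the group up to floating-point error, unlike a first-order step X - lr · X A followed by re-orthogonalization.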
update_orthogonal
- mrmustard.training.parameter_update.update_orthogonal(orthogonal_lr, optimizer_cls=optax.adabelief)
Creates an optax GradientTransformation for orthogonal parameter updates using Riemannian optimization.
- Parameters:
orthogonal_lr (float | Callable[[int], float]) – The learning rate for orthogonal updates: either a constant or a schedule that maps the step count to a learning rate.
optimizer_cls (Callable) – The optax optimizer constructor to use (default: optax.adabelief).
- Returns:
An optax.GradientTransformation for orthogonal updates.
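An optax GradientTransformation is an (init, update) pair. The sketch below mimics that interface with a local NamedTuple so it runs without optax. It combines the two steps described above: it projects the Euclidean gradient to a skew-symmetric matrix, then retracts with the matrix exponential. It is hypothetical and is not mrmustard's implementation; `sketch_update_orthogonal` is an illustrative name. It makes three assumptions:

- It uses plain gradient descent rather than adabelief.
- It returns the retracted step as an additive update (params + updates), following optax's convention.
- It passes orthogonal_lr as a Callable[[int], float] schedule.

```python
from typing import Callable, NamedTuple, Union

import numpy as np

# Mirrors optax's scalar-or-schedule convention: a constant or a function of the step.
ScalarOrSchedule = Union[float, Callable[[int], float]]


class GradientTransformation(NamedTuple):
    """Minimal local stand-in for optax.GradientTransformation (an init/update pair)."""

    init: Callable
    update: Callable


def skew_expm(A: np.ndarray) -> np.ndarray:
    """exp(A) for real skew-symmetric A, using that 1j * A is Hermitian."""
    w, V = np.linalg.eigh(1j * A)  # A = -1j * V diag(w) V^H with real w
    return (V @ np.diag(np.exp(-1j * w)) @ V.conj().T).real


def sketch_update_orthogonal(orthogonal_lr: ScalarOrSchedule) -> GradientTransformation:
    """Hypothetical Riemannian gradient descent on O(n); not mrmustard's code."""
    schedule = orthogonal_lr if callable(orthogonal_lr) else (lambda step: orthogonal_lr)

    def init(params):
        return {"step": 0}

    def update(grads, state, params):
        lr = schedule(state["step"])
        A = 0.5 * (params.T @ grads - grads.T @ params)  # skew-symmetric Riemannian gradient
        new_params = params @ skew_expm(-lr * A)  # exponential retraction keeps O orthogonal
        # optax convention: return an additive update that the caller adds to params.
        return new_params - params, {"step": state["step"] + 1}

    return GradientTransformation(init, update)


# Usage: maximise trace(M^T O) over orthogonal O by minimising its negative.
rng = np.random.default_rng(0)
M = rng.normal(size=(3, 3))


def loss(O):
    return -np.trace(M.T @ O)


O = np.eye(3)
opt = sketch_update_orthogonal(lambda step: 0.1 * 0.99**step)  # decaying schedule
state = opt.init(O)
for _ in range(200):
    grads = -M  # Euclidean gradient of loss(O)
    updates, state = opt.update(grads, state, O)
    O = O + updates
```

Passing a float such as 0.1 instead of the lambda gives a constant learning rate.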
update_symplectic
- mrmustard.training.parameter_update.update_symplectic(symplectic_lr, optimizer_cls=optax.adabelief)
Creates an optax GradientTransformation for symplectic parameter updates using Riemannian optimization.
- Parameters:
symplectic_lr (float | Callable[[int], float]) – The learning rate for symplectic updates: either a constant or a schedule that maps the step count to a learning rate.
optimizer_cls (Callable) – The optax optimizer constructor to use (default: optax.adabelief).
- Returns:
An optax.GradientTransformation for symplectic updates.
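The same pattern for symplectic updates, as a hypothetical sketch. The Euclidean gradient is pulled back with Sᵀ, projected onto sp(2n) = {X : XᵀJ + JX = 0}, and retracted with the matrix exponential, which maps sp(2n) into the symplectic group. Several pieces are assumptions made to keep the example self-contained: the form J = [[0, I], [-I, 0]], the Taylor-series `expm`, the plain gradient-descent rule and the diagonal target. mrmustard's actual symplectic update may use a different retraction.

```python
from typing import Callable, NamedTuple, Union

import numpy as np

ScalarOrSchedule = Union[float, Callable[[int], float]]


class GradientTransformation(NamedTuple):
    """Minimal local stand-in for optax.GradientTransformation (an init/update pair)."""

    init: Callable
    update: Callable


def symplectic_form(n: int) -> np.ndarray:
    """J = [[0, I], [-I, 0]] for n modes (xxpp ordering assumed)."""
    eye, zero = np.eye(n), np.zeros((n, n))
    return np.block([[zero, eye], [-eye, zero]])


def expm(A: np.ndarray, terms: int = 30) -> np.ndarray:
    """Real matrix exponential via scaling and squaring with a truncated Taylor series."""
    norm = np.linalg.norm(A, ord=1)
    s = max(0, int(np.ceil(np.log2(norm / 0.5)))) if norm > 0 else 0
    A_scaled, result = A / 2**s, np.eye(A.shape[0])
    term = result.copy()
    for k in range(1, terms):
        term = term @ A_scaled / k
        result = result + term
    for _ in range(s):
        result = result @ result
    return result


def sketch_update_symplectic(symplectic_lr: ScalarOrSchedule) -> GradientTransformation:
    """Hypothetical Riemannian gradient descent on Sp(2n); not mrmustard's code."""
    schedule = symplectic_lr if callable(symplectic_lr) else (lambda step: symplectic_lr)

    def init(params):
        return {"step": 0}

    def update(grads, state, params):
        lr = schedule(state["step"])
        J = symplectic_form(params.shape[0] // 2)
        Z = params.T @ grads
        A = 0.5 * (Z + J @ Z.T @ J)  # orthogonal projection onto the Lie algebra sp(2n)
        new_params = params @ expm(-lr * A)  # exp maps sp(2n) into Sp(2n)
        return new_params - params, {"step": state["step"] + 1}

    return GradientTransformation(init, update)


# Usage: fit the diagonal symplectic target T = diag(e^r, e^-r) by minimising ||S - T||_F^2.
r = np.array([0.5, 0.3])
T = np.diag(np.concatenate([np.exp(r), np.exp(-r)]))
J = symplectic_form(2)


def loss(S):
    return np.sum((S - T) ** 2)


S = np.eye(4)
opt = sketch_update_symplectic(0.05)  # constant learning rate
state = opt.init(S)
for _ in range(300):
    updates, state = opt.update(2 * (S - T), state, S)  # 2 (S - T) is the Euclidean gradient
    S = S + updates
```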
update_unitary
- mrmustard.training.parameter_update.update_unitary(unitary_lr, optimizer_cls=optax.adabelief)
Creates an optax GradientTransformation for unitary parameter updates using Riemannian optimization.
- Parameters:
unitary_lr (float | Callable[[int], float]) – The learning rate for unitary updates: either a constant or a schedule that maps the step count to a learning rate.
optimizer_cls (Callable) – The optax optimizer constructor to use (default: optax.adabelief).
- Returns:
An optax.GradientTransformation for unitary updates.
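For unitary updates, here is a hypothetical sketch under the same assumptions as the orthogonal case. The Riemannian gradient is the anti-Hermitian part of U^H G. The retraction U exp(-lr · A) stays unitary because the exponential of an anti-Hermitian matrix is unitary, and it is computed exactly from the eigendecomposition of the Hermitian matrix iA. The gradient convention, d loss = Re tr(G^H dU), is itself an assumption; other automatic-differentiation frameworks may return G up to a complex conjugate.

```python
from typing import Callable, NamedTuple, Union

import numpy as np

ScalarOrSchedule = Union[float, Callable[[int], float]]


class GradientTransformation(NamedTuple):
    """Minimal local stand-in for optax.GradientTransformation (an init/update pair)."""

    init: Callable
    update: Callable


def anti_hermitian_expm(A: np.ndarray) -> np.ndarray:
    """exp(A) for anti-Hermitian A, using that 1j * A is Hermitian."""
    w, V = np.linalg.eigh(1j * A)
    return V @ np.diag(np.exp(-1j * w)) @ V.conj().T


def sketch_update_unitary(unitary_lr: ScalarOrSchedule) -> GradientTransformation:
    """Hypothetical Riemannian gradient descent on U(n); not mrmustard's code."""
    schedule = unitary_lr if callable(unitary_lr) else (lambda step: unitary_lr)

    def init(params):
        return {"step": 0}

    def update(grads, state, params):
        lr = schedule(state["step"])
        Z = params.conj().T @ grads
        A = 0.5 * (Z - Z.conj().T)  # anti-Hermitian Riemannian gradient
        new_params = params @ anti_hermitian_expm(-lr * A)  # stays unitary
        return new_params - params, {"step": state["step"] + 1}

    return GradientTransformation(init, update)


# Usage: move U towards a random unitary target T by minimising ||U - T||_F^2.
rng = np.random.default_rng(0)
T, _ = np.linalg.qr(rng.normal(size=(3, 3)) + 1j * rng.normal(size=(3, 3)))


def loss(U):
    return np.sum(np.abs(U - T) ** 2)


U = np.eye(3, dtype=complex)
opt = sketch_update_unitary(0.1)
state = opt.init(U)
for _ in range(200):
    # Under the assumed convention, the Euclidean gradient of loss is G = 2 (U - T).
    updates, state = opt.update(2 * (U - T), state, U)
    U = U + updates
```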