mm.training.callbacks.Callback
- class mrmustard.training.callbacks.Callback(tag=None, steps_per_call=1)
Bases: object
Base callback class for optimizers. Users can inherit from this class and define the following custom logic:
- .trigger: Custom triggering condition, in addition to the regular schedule set by steps_per_call.
- .call: The main routine to be customized.
- .update_cost_fn: The custom cost_fn updater, which is expected to return a new cost_fn callable that replaces the original one passed to the optimizer.
- .update_grads: The custom gradient modifier, which is expected to return a list of modified parameter gradients to be applied to the parameters.
- .update_optimizer: The custom optimizer updater, which is expected to modify the optimizer in place, for example to schedule learning rates.
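The sketch below illustrates how the hooks listed above might fit together. MiniCallback and ClipGrads are hypothetical names written for this example, not part of mrmustard. The dispatch order (regular schedule or trigger, then call, then the three updaters), the step keyword, and the use of plain float lists as gradients are all assumptions made for illustration. In real code you would subclass mrmustard.training.callbacks.Callback and override only the hooks you need.

```python
from typing import Any, Callable, Dict, List, Optional


class MiniCallback:
    """Hypothetical stand-in mirroring the documented hooks (not mrmustard's code)."""

    def __init__(self, tag: Optional[str] = None, steps_per_call: int = 1):
        # As documented, the tag falls back to the class name.
        self.tag = tag or type(self).__name__
        self.steps_per_call = steps_per_call

    # Hooks that subclasses may override; the defaults do nothing.
    def trigger(self, **kwargs) -> bool:
        return False

    def call(self, **kwargs) -> Optional[Dict[str, Any]]:
        return None

    def update_cost_fn(self, **kwargs) -> Optional[Callable]:
        return None

    def update_grads(self, **kwargs) -> Optional[List[float]]:
        return None

    def update_optimizer(self, optimizer, **kwargs) -> None:
        return None

    def __call__(self, step: int, optimizer=None, **kwargs) -> Dict[str, Any]:
        # Assumed gating: run on the regular schedule or when trigger fires.
        if step % self.steps_per_call != 0 and not self.trigger(step=step, **kwargs):
            return {}
        result: Dict[str, Any] = dict(self.call(step=step, **kwargs) or {})
        new_cost_fn = self.update_cost_fn(step=step, **kwargs)
        if callable(new_cost_fn):
            result["cost_fn"] = new_cost_fn
        new_grads = self.update_grads(step=step, **kwargs)
        if new_grads is not None:
            result["grads"] = new_grads
        # Optimizer updates happen in place and return nothing.
        self.update_optimizer(optimizer, step=step, **kwargs)
        return result


class ClipGrads(MiniCallback):
    """Hypothetical subclass: record the largest gradient, then clip all gradients."""

    def __init__(self, max_abs: float, **kwargs):
        super().__init__(**kwargs)
        self.max_abs = max_abs

    def call(self, grads, **kwargs):
        return {"max_abs_grad": max(abs(g) for g in grads)}

    def update_grads(self, grads, **kwargs):
        return [min(max(g, -self.max_abs), self.max_abs) for g in grads]


cb = ClipGrads(max_abs=1.0)
out = cb(step=0, grads=[0.5, -3.0, 2.0])
print(out)  # {'max_abs_grad': 3.0, 'grads': [0.5, -1.0, 1.0]}
```

Only call and update_grads are overridden here; every other hook keeps its do-nothing default, which is the pattern the base class is meant to support.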
Attributes
- steps_per_call: int = 1
Sets how often the callback runs on its regular schedule: once every steps_per_call optimization steps. Defaults to 1 (every step); use higher values to reduce overhead.
- tag: str = None
Custom tag for a callback instance, used as its key in Optimizer.callback_history. Defaults to the class name when left as None.
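A small sketch of how the two attributes above could behave together. ScheduledCallback, LossLogger and run are hypothetical stand-ins, and the history dictionary only imitates Optimizer.callback_history. Counting steps from 0 and firing on steps divisible by steps_per_call are assumptions of this sketch, not confirmed details of mrmustard.

```python
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class ScheduledCallback:
    """Hypothetical stand-in mirroring the documented tag and steps_per_call."""

    tag: Optional[str] = None
    steps_per_call: int = 1

    def __post_init__(self):
        # Fall back to the class name when no tag is given.
        self.tag = self.tag or type(self).__name__

    def due(self, step: int) -> bool:
        # Assumed schedule: fire on steps that are multiples of steps_per_call.
        return step % self.steps_per_call == 0


class LossLogger(ScheduledCallback):
    pass


def run(callbacks: List[ScheduledCallback], num_steps: int) -> Dict[str, List[int]]:
    # Hypothetical imitation of Optimizer.callback_history: tag -> steps fired.
    history: Dict[str, List[int]] = defaultdict(list)
    for step in range(num_steps):
        for cb in callbacks:
            if cb.due(step):
                history[cb.tag].append(step)
    return dict(history)


history = run([LossLogger(), LossLogger(tag="sparse", steps_per_call=3)], num_steps=7)
print(history)  # {'LossLogger': [0, 1, 2, 3, 4, 5, 6], 'sparse': [0, 3, 6]}
```

Because an untagged instance falls back to its class name, two untagged instances of the same class would share one key in this sketch, which is one reason to pass an explicit tag.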
Methods
- __call__(**kwargs): Call self as a function.
- call(**kwargs): User-implemented main callback logic.
- get_opt_step(optimizer, **kwargs): Gets the current step from the optimizer.
- trigger(**kwargs): User-implemented custom trigger conditions.
- update_cost_fn(**kwargs): User-implemented cost_fn modifier.
- update_grads(**kwargs): User-implemented gradient modifier.
- update_optimizer(optimizer, **kwargs): User-implemented optimizer update scheduler.
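As an example of the two optimizer-facing methods, the sketch below halves a learning rate in place every steps_per_call steps. ToyOptimizer (with its step and learning_rate attributes) and DecayLR are hypothetical; DecayLR only mirrors the documented signatures get_opt_step(optimizer, **kwargs) and update_optimizer(optimizer, **kwargs) and does not subclass mrmustard's Callback, whose real optimizer interface may differ.

```python
from dataclasses import dataclass


@dataclass
class ToyOptimizer:
    """Hypothetical optimizer exposing a step counter and a learning rate."""

    learning_rate: float = 0.1
    step: int = 0


@dataclass
class DecayLR:
    """Hypothetical callback: multiply the learning rate by decay every steps_per_call steps."""

    decay: float = 0.5
    steps_per_call: int = 10
    tag: str = "DecayLR"

    def get_opt_step(self, optimizer, **kwargs) -> int:
        # Assumption: the optimizer stores its current step as an attribute.
        return optimizer.step

    def update_optimizer(self, optimizer, **kwargs) -> None:
        step = self.get_opt_step(optimizer)
        # Skip step 0 so the initial learning rate is used for the first block.
        if step > 0 and step % self.steps_per_call == 0:
            optimizer.learning_rate *= self.decay  # modified in place


opt = ToyOptimizer(learning_rate=0.1)
schedule = DecayLR(decay=0.5, steps_per_call=10)
for _ in range(25):
    opt.step += 1
    schedule.update_optimizer(opt)
print(opt.learning_rate)  # 0.025 (decayed at steps 10 and 20)
```

Here the schedule check lives inside update_optimizer for brevity; with the real base class, steps_per_call and trigger would typically decide when the hooks run.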