mm.training.callbacks.Callback

class mrmustard.training.callbacks.Callback(tag=None, steps_per_call=1)

Bases: object

Base callback class for optimizers. Users can inherit from this class and define the following custom logic:

  • .trigger:

    Custom triggering condition, in addition to the regular schedule set by steps_per_call.

  • .call:

    The main routine to be customized.

  • .update_cost_fn:

    The custom cost_fn updater, which is expected to return a new cost_fn callable to replace the original one passed to the optimizer.

  • .update_grads:

    The custom gradient modifier, which is expected to return the list of modified parameter gradients to be applied to the parameters.

  • .update_optimizer:

    The custom optimizer updater, which is expected to modify the optimizer in place, for example to schedule learning rates.
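
The subclassing pattern described above might look like the following minimal sketch. To keep it runnable without the library installed, `Callback` here is a hypothetical stand-in that mirrors only the documented constructor arguments and two of the hook names; in real use one would inherit from `mrmustard.training.callbacks.Callback`. The `cost` keyword argument and the `CostLogger` class are assumptions made for illustration, since this page does not list what the optimizer passes to the hooks.

```python
from typing import Mapping, Optional


class Callback:
    """Hypothetical stand-in mirroring the documented hook names only."""

    def __init__(self, tag=None, steps_per_call=1):
        # The documented default for the tag is the class name.
        self.tag = tag or type(self).__name__
        self.steps_per_call = steps_per_call

    def trigger(self, **kwargs) -> bool:
        return False

    def call(self, **kwargs) -> Optional[Mapping]:
        return None


class CostLogger(Callback):
    """Records the cost and asks to fire early once the cost is low."""

    def __init__(self, threshold=1e-3, **base_kwargs):
        super().__init__(**base_kwargs)
        self.threshold = threshold

    def trigger(self, **kwargs) -> bool:
        # Custom condition, independent of the regular schedule.
        return kwargs.get("cost", float("inf")) < self.threshold

    def call(self, **kwargs) -> Optional[Mapping]:
        # Return a mapping of whatever should be recorded.
        return {"cost": kwargs.get("cost")}


logger = CostLogger(threshold=0.1, steps_per_call=10)
print(logger.tag)                 # CostLogger
print(logger.trigger(cost=0.05))  # True
print(logger.call(cost=0.05))     # {'cost': 0.05}
```

Only the hooks that matter for a given callback need to be overridden; the others can keep their default behaviour.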

steps_per_call

Sets calling frequency of this callback.

tag

Custom tag for a callback instance, used as its key in Optimizer.callback_history.

steps_per_call: int = 1

Sets calling frequency of this callback. Defaults to once per optimization step. Use higher values to reduce overhead.
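
One plausible reading of this schedule, sketched as a free function: the callback fires when the optimizer step is a multiple of steps_per_call, or when the custom trigger asks for it. Both the modulo rule and the helper name `should_call` are assumptions for illustration, not the library's confirmed logic.

```python
def should_call(opt_step: int, steps_per_call: int, triggered: bool = False) -> bool:
    """Assumed rule: fire on every steps_per_call-th step, or when triggered."""
    return opt_step % steps_per_call == 0 or triggered


# With steps_per_call=5, only steps 5 and 10 fire within the first ten steps.
fired = [step for step in range(1, 11) if should_call(step, steps_per_call=5)]
print(fired)  # [5, 10]

# A custom trigger can force a call on an off-schedule step.
print(should_call(3, steps_per_call=5, triggered=True))  # True
```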

tag: str = None

Custom tag for a callback instance, used as its key in Optimizer.callback_history. Defaults to the class name.

__call__(**kwargs)

Call self as a function.

call(**kwargs)

User implemented main callback logic.

get_opt_step(optimizer, **kwargs)

Gets current step from optimizer.

trigger(**kwargs)

User implemented custom trigger conditions.

update_cost_fn(**kwargs)

User implemented cost_fn modifier.

update_grads(**kwargs)

User implemented gradient modifier.

update_optimizer(optimizer, **kwargs)

User implemented optimizer update scheduler.

__call__(**kwargs)

Call self as a function.

call(**kwargs)

User implemented main callback logic.

Return type:

Optional[Mapping]

get_opt_step(optimizer, **kwargs)

Gets current step from optimizer.

trigger(**kwargs)

User implemented custom trigger conditions.

Return type:

bool

update_cost_fn(**kwargs)

User implemented cost_fn modifier.

Return type:

Optional[Callable]
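
As an illustration of returning a replacement callable, the sketch below wraps an existing cost function with an extra penalty term. It assumes the optimizer passes the current cost function as a `cost_fn` keyword argument, which this page does not confirm; `AddPenalty` and `penalty_fn` are hypothetical names, and the class is written standalone here, whereas in real use it would inherit from `mrmustard.training.callbacks.Callback`.

```python
from typing import Callable, Optional


class AddPenalty:
    """Hypothetical callback that adds a weighted penalty to the cost."""

    def __init__(self, penalty_fn: Callable[[], float], weight: float = 0.1):
        self.penalty_fn = penalty_fn
        self.weight = weight

    def update_cost_fn(self, cost_fn: Optional[Callable[[], float]] = None,
                       **kwargs) -> Optional[Callable]:
        if cost_fn is None:
            # Returning None signals that there is nothing to replace.
            return None

        def new_cost_fn() -> float:
            return cost_fn() + self.weight * self.penalty_fn()

        return new_cost_fn


callback = AddPenalty(penalty_fn=lambda: 2.0, weight=0.5)
new_cost_fn = callback.update_cost_fn(cost_fn=lambda: 1.0)
print(new_cost_fn())  # 2.0
```

Because the returned callable replaces the original, a callback written this way would wrap the cost function again on every call; a real implementation would probably guard against wrapping more than once.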

update_grads(**kwargs)

User implemented gradient modifier.

Return type:

Optional[Sequence]
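
A gradient modifier could, for example, clip each gradient to a fixed range. The sketch below uses plain floats and assumes the gradients arrive as a `grads` keyword argument; in practice they would be tensors from the numerical backend, and the keyword name is not confirmed by this page. `ClipGrads` is a hypothetical name and is written standalone rather than as a subclass of the library's `Callback`.

```python
from typing import Optional, Sequence


class ClipGrads:
    """Hypothetical callback that clips every gradient to [-max_abs, max_abs]."""

    def __init__(self, max_abs: float = 1.0):
        self.max_abs = max_abs

    def update_grads(self, grads: Optional[Sequence[float]] = None,
                     **kwargs) -> Optional[Sequence]:
        if grads is None:
            # Nothing to modify, so leave the gradients untouched.
            return None
        return [max(-self.max_abs, min(self.max_abs, g)) for g in grads]


clipped = ClipGrads(max_abs=1.0).update_grads(grads=[0.5, -3.0, 2.0])
print(clipped)  # [0.5, -1.0, 1.0]
```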

update_optimizer(optimizer, **kwargs)

User implemented optimizer update scheduler.
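
A learning-rate schedule is the typical use of this hook. In the sketch below, `FakeOptimizer` and its scalar `learning_rate` attribute are hypothetical stand-ins so that the example runs on its own; the real optimizer may store its learning rates differently. The point is only that the hook mutates the optimizer it receives and returns nothing.

```python
class FakeOptimizer:
    """Hypothetical stand-in exposing a single scalar learning rate."""

    def __init__(self, learning_rate: float):
        self.learning_rate = learning_rate


class HalveLearningRate:
    """Hypothetical callback that halves the learning rate each time it fires."""

    def update_optimizer(self, optimizer, **kwargs) -> None:
        # Modify the optimizer in place; there is no return value.
        optimizer.learning_rate *= 0.5


optimizer = FakeOptimizer(learning_rate=0.1)
HalveLearningRate().update_optimizer(optimizer)
print(optimizer.learning_rate)  # 0.05
```

Combined with a larger steps_per_call, such a callback would lower the learning rate only every few optimization steps rather than on each one.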