callbacks

This module implements callback functionality for optimization.

Callbacks give users finer control over the optimization process by executing predefined routines as optimization progresses. Even though Optimizer.minimize() accepts plain callables directly, the Callback class modularizes the logic and makes it easier for users to inherit from it and write their own custom callbacks.
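A minimal sketch of the inheritance pattern described above. The `Callback` class here is a hypothetical stand-in written only for illustration: it mirrors the documented `tag` argument, but the real mrmustard base class may be structured differently. `CostLogger` is likewise a made-up subclass. The point is that a subclass can bundle per-step logic, state kept across calls, and a default tag in one reusable object.

```python
from typing import Any, Dict, Optional


class Callback:
    """Hypothetical stand-in for a callback base class (not the mrmustard source)."""

    def __init__(self, tag: Optional[str] = None):
        # Default the tag to the class name so results can be stored under it.
        self.tag = tag if tag is not None else type(self).__name__

    def __call__(self, **kwargs: Any) -> Optional[Dict[str, Any]]:
        return self.call(**kwargs)

    def call(self, **kwargs: Any) -> Optional[Dict[str, Any]]:
        # Subclasses override this with their own per-step logic.
        raise NotImplementedError


class CostLogger(Callback):
    """Hypothetical custom callback that remembers every cost it has seen."""

    def __init__(self, tag: Optional[str] = None):
        super().__init__(tag)
        self.seen = []

    def call(self, cost: float, **kwargs: Any) -> Dict[str, Any]:
        self.seen.append(cost)
        return {"cost": cost, "best_so_far": min(self.seen)}


# A plain-function callback can return the same per-step value, but keeping
# state across calls would need a closure or a global variable.
def cost_logger_fn(cost: float, **kwargs: Any) -> Dict[str, Any]:
    return {"cost": cost}


logger = CostLogger()
for c in [3.0, 1.5, 2.0]:
    result = logger(cost=c)
print(logger.tag, result)  # CostLogger {'cost': 2.0, 'best_so_far': 1.5}
```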

Things you can do with callbacks:

  • Logging custom metrics.

  • Tracking parameters and costs with TensorBoard.

  • Scheduling learning rates.

  • Modifying the gradient update that gets applied.

  • Updating cost_fn to alter the optimization landscape in our favour.

  • Adding reinforcement learning (RL) to the optimizer.
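As an illustration of the learning-rate item in the list above, a sketch of a step-based exponential-decay schedule written as a plain callback function. Exponential decay is simply a common schedule chosen for this example; the function name, its `step` keyword argument, and the returned `'scheduled_lr'` key are all hypothetical. The sketch only computes the scheduled value; how such a value would be fed back into `Optimizer` is not specified here.

```python
from typing import Any, Dict


def exp_decay_lr_cb(step: int, initial_lr: float = 0.001, decay: float = 0.96,
                    decay_steps: int = 50, **kwargs: Any) -> Dict[str, float]:
    """Hypothetical schedule callback: report an exponentially decayed learning rate."""
    # lr(step) = initial_lr * decay ** (step / decay_steps)
    lr = initial_lr * decay ** (step / decay_steps)
    return {"scheduled_lr": lr}


for step in (0, 50, 100):
    print(step, exp_decay_lr_cb(step=step)["scheduled_lr"])
```

With these defaults the rate drops by a factor of 0.96 every 50 steps, so it halves roughly every 850 steps.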

Built-in callbacks:

  • Callback: The base class, to be used for building custom callbacks.

  • TensorboardCallback: Tracks costs, parameter values and gradients in TensorBoard.

Examples:

import numpy as np
from mrmustard.training import Optimizer, TensorboardCallback

def cost_fn():
    ...

def as_dB(cost):
    delta = np.sqrt(np.log(1 / (abs(cost) ** 2)) / (2 * np.pi))
    cost_dB = -10 * np.log10(delta**2)
    return cost_dB

tb_cb = TensorboardCallback(cost_converter=as_dB, track_grads=True)

def rolling_cost_cb(optimizer, cost, **kwargs):
    return {
        'rolling_cost': np.mean(optimizer.opt_history[-10:] + [cost]),
    }

opt = Optimizer(euclidean_lr=0.001)
opt.minimize(cost_fn, max_steps=200, by_optimizing=[...], callbacks=[tb_cb, rolling_cost_cb])

# VS Code can be used to open the TensorBoard frontend for live monitoring.

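# Values returned by each callback, keyed by the callback's tag (or, for a plain function, its name).
opt.callback_history['TensorboardCallback']
opt.callback_history['rolling_cost_cb']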
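To make the bookkeeping in the example above concrete, here is a self-contained toy loop that calls each callback once per step and stores its return value under the callback's tag or function name. `toy_minimize` and the `SimpleNamespace` optimizer are hypothetical stand-ins, not mrmustard's `Optimizer`; the real class may pass different arguments and record history differently. `statistics.mean` replaces `np.mean` only to keep the sketch dependency-free.

```python
import statistics
from types import SimpleNamespace
from typing import Callable, List


def toy_minimize(cost_fn: Callable[[float], float], grad_fn: Callable[[float], float],
                 x: float, lr: float, max_steps: int, callbacks: List[Callable]):
    """Hypothetical toy gradient-descent loop (not mrmustard's Optimizer)."""
    # Minimal stand-in for the optimizer object that callbacks receive.
    optimizer = SimpleNamespace(opt_history=[], callback_history={})
    for _ in range(max_steps):
        cost = cost_fn(x)
        for cb in callbacks:
            # Key results by an explicit `tag` attribute if present, else by function name.
            tag = getattr(cb, "tag", None) or cb.__name__
            result = cb(optimizer=optimizer, cost=cost)
            optimizer.callback_history.setdefault(tag, []).append(result)
        optimizer.opt_history.append(cost)
        x -= lr * grad_fn(x)  # plain gradient-descent update
    return x, optimizer


def rolling_cost_cb(optimizer, cost, **kwargs):
    # Mean of (up to) the last 10 recorded costs plus the current one.
    return {"rolling_cost": statistics.mean(optimizer.opt_history[-10:] + [cost])}


# Minimize f(x) = x**2 starting from x = 1.0.
x_final, opt = toy_minimize(lambda x: x**2, lambda x: 2 * x, x=1.0, lr=0.1,
                            max_steps=20, callbacks=[rolling_cost_cb])
print(len(opt.callback_history["rolling_cost_cb"]))  # 20
print(opt.callback_history["rolling_cost_cb"][0])    # {'rolling_cost': 1.0}
```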

Classes

Callback([tag, steps_per_call])

Base callback class for optimizers.


TensorboardCallback([tag, steps_per_call, ...])

Callback for enabling TensorBoard tracking of optimization progress.

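Both `Callback` and `TensorboardCallback` take a `steps_per_call` argument, which suggests a callback can run less often than every optimization step. A minimal sketch of that gating, assuming it means "run on every N-th step"; the class name `EveryNSteps` and its step counter are hypothetical and not taken from the mrmustard source.

```python
from typing import Any, Dict, Optional


class EveryNSteps:
    """Hypothetical sketch of `steps_per_call` gating (not the mrmustard source)."""

    def __init__(self, tag: Optional[str] = None, steps_per_call: int = 1):
        self.tag = tag or type(self).__name__
        self.steps_per_call = steps_per_call
        self.optimizer_step = 0  # counts how many times the callback was invoked

    def __call__(self, cost: float, **kwargs: Any) -> Optional[Dict[str, Any]]:
        self.optimizer_step += 1
        # Only run the callback body on every `steps_per_call`-th step.
        if self.optimizer_step % self.steps_per_call == 0:
            return {"step": self.optimizer_step, "cost": cost}
        return None


cb = EveryNSteps(steps_per_call=3)
outputs = [cb(cost=float(c)) for c in range(1, 8)]
print([o for o in outputs if o is not None])
# [{'step': 3, 'cost': 3.0}, {'step': 6, 'cost': 6.0}]
```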

Class Inheritance Diagram

[Diagram: Callback -> TensorboardCallback. TensorboardCallback inherits from Callback.]