trainer

This module implements distributed training utilities for the parallelized optimization of MrMustard circuits/devices through the function map_trainer().

This module requires extra dependencies; to install them:

git clone https://github.com/XanaduAI/MrMustard
cd MrMustard
pip install -e .[ray]

User-provided Wrapper Functions

To distribute your optimization workflow, two user-defined functions are needed to wrap up your logic:

  • A device_factory that wraps the logic for building the circuits/states to be optimized; it is expected to return a single :class:`Circuit` or a list of them.

  • A cost_fn that takes the circuits produced by device_factory, along with additional keyword arguments, and returns a backpropagatable scalar cost.

Separating the circuit-making logic from the cost-calculation logic has the benefit that the optimized circuit is returned in the result dict for further inspection. One can also pass extra metric_fns to extract information directly from the optimized circuit (demonstrated in use case 3 below).

Examples:

from mrmustard.lab import Vacuum, Dgate, Ggate, Gaussian
from mrmustard.physics import fidelity
from mrmustard.training.trainer import map_trainer

def make_circ(x=0.):
    return Ggate(num_modes=1, symplectic_trainable=True) >> Dgate(x=x, x_trainable=True, y_trainable=True)

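# Note: the default `circ` below makes cost_fn usable on its own,
# without a device_factory (see use case 0 below).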
def cost_fn(circ=make_circ(0.1), y_targ=0.):
    target = Gaussian(1) >> Dgate(-1.5, y_targ)
    s = Vacuum(1) >> circ
    return -fidelity(s, target)

# Use case 0: Calculate the cost of a randomly initialized circuit 5 times without optimizing it.
results_0 = map_trainer(
    cost_fn=cost_fn,
    tasks=5,
)

# Use case 1: Run circuit optimization 5 times on randomly initialized circuits.
results_1 = map_trainer(
    cost_fn=cost_fn,
    device_factory=make_circ,
    tasks=5,
    max_steps=50,
    symplectic_lr=0.05,
)

# Use case 2: Run 2 sets of circuit optimization with custom parameters passed as list.
results_2 = map_trainer(
    cost_fn=cost_fn,
    device_factory=make_circ,
    tasks=[
        {'x': 0.1, 'euclidean_lr': 0.005, 'max_steps': 50},
        {'x': -0.7, 'euclidean_lr': 0.1, 'max_steps': 2},
    ],
    y_targ=0.35,
    symplectic_lr=0.05,
    AUTOCUTOFF_MAX_CUTOFF=7,
)

# Use case 3: Run 2 sets of circuit optimization with custom parameters passed as dict,
# with extra metric functions for evaluating the final optimized circuit.
results_3 = map_trainer(
    cost_fn=cost_fn,
    device_factory=make_circ,
    tasks={
        'my-job': {'x': 0.1, 'euclidean_lr': 0.005, 'max_steps': 50},
        'my-other-job': {'x': -0.7, 'euclidean_lr': 0.1, 'max_steps': 2},
    },
    y_targ=0.35,
    symplectic_lr=0.05,
    metric_fns={
        'is_gaussian': lambda c: c.is_gaussian,
        'foo': lambda _: 17.,
    },
)
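
The returned results presumably mirror the structure of tasks: a list for integer or list tasks, and a dict keyed by task name for dict tasks. As a rough sketch of how one might inspect them (the exact result keys such as 'cost' and the metric names are assumptions here, not a stable API):

# Assumption: each result dict collects the final cost, the optimized
# device, and the outputs of any metric_fns (key names assumed).
for name, res in results_3.items():
    print(name, res['cost'], res['is_gaussian'])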

Functions

curry_pop(fn, *args, **kwargs)

A poor man's reader monad bind function.
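
Reading the description loosely: call fn with whatever entries of kwargs match its signature, pop them out, and hand back the result together with the leftovers, so a chain of calls can share one kwargs pool. A minimal self-contained sketch of that pattern (an illustration, not the actual implementation):

from inspect import signature

def curry_pop_sketch(fn, **kwargs):
    # Pop the kwargs that fn accepts, call fn with them, and return
    # the result along with the unconsumed kwargs.
    accepted = signature(fn).parameters
    known = {k: kwargs.pop(k) for k in list(kwargs) if k in accepted}
    return fn(**known), kwargs

out, rest = curry_pop_sketch(lambda a=1: a, a=10, b=20)
# out == 10, rest == {'b': 20}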

kwargs_of(fn)

Gets the kwarg signature of a callable.
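
A plausible self-contained rendering of that idea, using the standard inspect module (the real return format may differ):

from inspect import signature, Parameter

def kwargs_of_sketch(fn):
    # Collect the parameter names of fn that can be passed as keywords,
    # plus whether fn also accepts arbitrary **kwargs.
    params = signature(fn).parameters.values()
    has_var_keyword = any(p.kind is Parameter.VAR_KEYWORD for p in params)
    keywords = [p.name for p in params
                if p.kind in (Parameter.POSITIONAL_OR_KEYWORD, Parameter.KEYWORD_ONLY)]
    return has_var_keyword, keywords

kwargs_of_sketch(cost_fn)  # (False, ['circ', 'y_targ']) for the cost_fn above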

map_trainer([trainer, tasks, pbar, unblock, ...])

Maps multiple training tasks across multiple workers using ray.
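
Beyond the use cases above, the bracketed signature mentions an unblock flag; presumably this makes map_trainer return immediately with a getter for the pending ray results instead of blocking. Hedged usage (treat the exact semantics as an assumption):

# Assumption: with unblock=True, map_trainer returns a callable that
# fetches the finished results, rather than the results themselves.
results_getter = map_trainer(
    cost_fn=cost_fn,
    device_factory=make_circ,
    tasks=5,
    max_steps=50,
    unblock=True,
)
results = results_getter()  # collect later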

partial_pop(fn, *args, **kwargs)

Partially applies known kwargs to fn and returns the rest.
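
Read together with kwargs_of above, a plausible sketch: bind the kwargs that fn recognizes via functools.partial and return the unconsumed remainder (illustrative only):

from functools import partial
from inspect import signature

def partial_pop_sketch(fn, *args, **kwargs):
    # Bind the kwargs that appear in fn's signature; return the
    # partially applied fn plus the kwargs it did not claim.
    accepted = signature(fn).parameters
    known = {k: kwargs.pop(k) for k in list(kwargs) if k in accepted}
    return partial(fn, *args, **known), kwargs

bound_cost, rest = partial_pop_sketch(cost_fn, y_targ=0.35, max_steps=50)
# bound_cost() now uses y_targ=0.35; rest == {'max_steps': 50}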

signature(obj, *[, follow_wrapped, globals, ...])

Get a signature object for the passed callable.
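
This matches the standard library's inspect.signature, re-exported here; for example:

from inspect import signature

signature(make_circ)  # <Signature (x=0.0)>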

track(sequence[, description, total, ...])

Track progress by iterating over a sequence.
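
This matches rich.progress.track, which wraps any iterable with a progress bar; for example:

from rich.progress import track

for step in track(range(100), description="Optimizing..."):
    pass  # one optimization step per iteration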

train_device(cost_fn[, device_factory, ...])

A general and flexible training loop for circuit optimizations with configurations adjustable through kwargs.
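
For a single, local (non-ray) run, the signature suggests train_device can be called directly with the same pieces that map_trainer distributes; a hedged sketch reusing the example functions above (the keyword names are taken from the map_trainer examples and are otherwise assumptions):

from mrmustard.training.trainer import train_device

result = train_device(
    cost_fn,
    device_factory=make_circ,
    max_steps=50,
    symplectic_lr=0.05,
)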

update_pop(obj, **kwargs)

Updates an object/dict while popping keys out and returns the updated dict and remaining kwargs.
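
A self-contained sketch of the described behavior for plain dicts (illustrative only; per the description, the real function also handles objects):

def update_pop_sketch(obj, **kwargs):
    # Overwrite matching keys of the dict, popping them from kwargs,
    # and return both the updated dict and the leftovers.
    for k in list(kwargs):
        if k in obj:
            obj[k] = kwargs.pop(k)
    return obj, kwargs

opts, rest = update_pop_sketch({'max_steps': 100}, max_steps=50, foo=1)
# opts == {'max_steps': 50}, rest == {'foo': 1}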