
# Copyright 2022 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""This module contains the implementation of distributed training utilities for parallelized
optimization of MrMustard circuits/devices through the function :meth:`map_trainer`.

This module requires extra dependencies; to install them:

.. code-block:: bash

    git clone https://github.com/XanaduAI/MrMustard
    cd MrMustard
    pip install -e .[ray]


User-provided Wrapper Functions
===============================
To distribute your optimization workflow, two user-defined functions are needed for wrapping up user logic:

* A `device_factory` that wraps around the logic for making your circuits/states to be optimized; it is expected to return a single, or list of, :class:`Circuit`(s).
* A `cost_fn` that takes the circuits made and additional keyword arguments and returns a backprop-able scalar cost.

Separating the circuit-making logic from the cost-calculation logic makes it possible to return the optimized circuit in the result dict for further inspection. One can also pass extra `metric_fns` to extract information directly from the optimized circuit.


Examples:
=========

.. code-block:: python

    from mrmustard.lab import Vacuum, Dgate, Ggate, Gaussian
    from mrmustard.physics import fidelity
    from mrmustard.training.trainer import map_trainer

    def make_circ(x=0.):
        return Ggate(num_modes=1, symplectic_trainable=True) >> Dgate(x=x, x_trainable=True, y_trainable=True)

    def cost_fn(circ=make_circ(0.1), y_targ=0.):
        target = Gaussian(1) >> Dgate(-1.5, y_targ)
        s = Vacuum(1) >> circ
        return -fidelity(s, target)

    # Use case 0: Calculate the cost of a randomly initialized circuit 5 times without optimizing it.
    results_0 = map_trainer(
        cost_fn=cost_fn,
        tasks=5,
    )

    # Use case 1: Run circuit optimization 5 times on randomly initialized circuits.
    results_1 = map_trainer(
        cost_fn=cost_fn,
        device_factory=make_circ,
        tasks=5,
        max_steps=50,
        symplectic_lr=0.05,
    )

    # Use case 2: Run 2 sets of circuit optimization with custom parameters passed as a list.
    results_2 = map_trainer(
        cost_fn=cost_fn,
        device_factory=make_circ,
        tasks=[
            {'x': 0.1, 'euclidean_lr': 0.005, 'max_steps': 50},
            {'x': -0.7, 'euclidean_lr': 0.1, 'max_steps': 2},
        ],
        y_targ=0.35,
        symplectic_lr=0.05,
        AUTOCUTOFF_MAX_CUTOFF=7,
    )

    # Use case 3: Run 2 sets of circuit optimization with custom parameters passed as
    # a dict, with extra metric functions for evaluating the final optimized circuit.
    results_3 = map_trainer(
        cost_fn=cost_fn,
        device_factory=make_circ,
        tasks={
            'my-job': {'x': 0.1, 'euclidean_lr': 0.005, 'max_steps': 50},
            'my-other-job': {'x': -0.7, 'euclidean_lr': 0.1, 'max_steps': 2},
        },
        y_targ=0.35,
        symplectic_lr=0.05,
        metric_fns={
            'is_gaussian': lambda c: c.is_gaussian,
            'foo': lambda _: 17.,
        },
    )


"""

import warnings
from functools import partial
from inspect import Parameter, signature
from typing import Mapping, Sequence

import numpy as np
from rich.progress import track

from mrmustard import settings
from .optimizer import Optimizer


def _apply_partial_cost(device, cost_fn, **kwargs):
    """Helper partial cost fn maker."""
    if isinstance(device, Sequence):
        cost_fn, kwargs = partial_pop(cost_fn, *device, **kwargs)
        optimized = device
    elif isinstance(device, Mapping):
        cost_fn, kwargs = partial_pop(cost_fn, **device, **kwargs)
        optimized = list(device.values())
    return cost_fn, kwargs, optimized
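
# Illustrative sketch (not part of the original module): `_apply_partial_cost`
# binds freshly made device(s) to `cost_fn` as positional args so the optimizer
# can later evaluate `cost_fn()` with no arguments. With a hypothetical cost
# function `my_cost(circ, y_targ=0.)` and a single device `circ`:
#
#     bound, rest, optimized = _apply_partial_cost([circ], my_cost, y_targ=0.3)
#     # bound() == my_cost(circ, y_targ=0.3); rest == {}; optimized == [circ]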


def train_device(
    cost_fn,
    device_factory=None,
    metric_fns=None,
    return_kwargs=True,
    skip_opt=False,
    tag=None,
    **kwargs,
):
    """A general and flexible training loop for circuit optimizations with configurations
    adjustable through kwargs.

    Args:
        cost_fn (callable): The cost function to be optimized and distributed. It is
            expected to accept the output of `device_factory` as *args, as well as
            user-defined **kwargs, and to return a scalar cost. Its user-defined **kwargs
            are passed from this function's **kwargs, which must include all of its
            required arguments.
        device_factory (callable): Function that (partially) takes `kwargs` and returns a
            device, or a list/dict of devices. If None, `cost_fn` is assumed to take no
            positional arguments (for example, when device-making is contained in
            `cost_fn`). Defaults to None.
        metric_fns (Union[Sequence[callable], Mapping[callable], callable]): Optional
            collection of functions that take the output of `device_factory` after
            optimization and return arbitrary evaluation/information.
        return_kwargs (bool): Whether to include the input config `kwargs` in the output
            dict. Defaults to True.
        skip_opt (bool): Whether to skip the optimization and directly calculate the cost.
        tag (str): Optional label of the training task associated with the `kwargs`, to be
            included in the output dict.
        kwargs: Dict containing all arguments to any of the functions below:

            - `cost_fn`: excluding the output of `device_factory`.
            - `device_factory`: e.g. `x`, `r`, `theta`, etc.
            - `Optimizer`: e.g. `euclidean_lr`.
            - `Optimizer.minimize`: excluding `cost_fn` and `by_optimizing`, e.g. `max_steps`.

    Returns:
        dict: A result dict summarizing the optimized circuit, cost, metrics, and/or input
        configs.
    """
    setting_updates, kwargs = update_pop(settings, **kwargs)

    input_kwargs = kwargs.copy() if return_kwargs else {}

    device, kwargs = (
        curry_pop(device_factory, **kwargs) if callable(device_factory) else ([], kwargs)
    )
    device = [device] if not isinstance(device, (Sequence, Mapping)) else device

    cost_fn, kwargs, optimized = _apply_partial_cost(device, cost_fn, **kwargs)

    opt = None
    if optimized and not skip_opt:
        opt, kwargs = curry_pop(Optimizer, **kwargs)
        _, kwargs = curry_pop(
            opt.minimize, **{"cost_fn": cost_fn, "by_optimizing": optimized}, **kwargs
        )

    if kwargs:
        warnings.warn(f"Unused kwargs: {kwargs}")

    final_cost = cost_fn()
    results = {
        "cost": np.array(final_cost).item(),
        "device": device,
        "optimizer": opt,
    }

    if callable(metric_fns):
        results["metrics"] = metric_fns(*device)
    elif isinstance(metric_fns, Sequence):
        results["metrics"] = [f(*device) for f in metric_fns if callable(f)]
    elif isinstance(metric_fns, Mapping):
        results = {
            **results,
            **{k: f(*device) for k, f in metric_fns.items() if callable(f)},
        }

    return {
        **({"tag": tag} if tag is not None else {}),
        **results,
        **input_kwargs,
        **setting_updates,
    }
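

# Usage sketch (illustrative, not part of the original module), assuming the
# `make_circ`/`cost_fn` pair from the module docstring above. `train_device`
# can also be called directly, without ray, to run a single optimization:
#
#     result = train_device(
#         cost_fn=cost_fn,
#         device_factory=make_circ,
#         x=0.1,               # forwarded to `make_circ`
#         euclidean_lr=0.005,  # forwarded to `Optimizer`
#         max_steps=50,        # forwarded to `Optimizer.minimize`
#     )
#     result["cost"]    # final scalar cost
#     result["device"]  # the optimized circuit(s)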


def _iter_futures(futures):
    """Makes ray futures iterable for easy passing to a progress bar.

    Hacky: https://github.com/ray-project/ray/issues/5554
    """
    import ray  # pylint: disable=import-outside-toplevel

    while futures:
        done, futures = ray.wait(futures)
        yield ray.get(done[0])
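

# Intended use (as in `map_trainer` below): yielding results one at a time as
# tasks finish lets `rich.progress.track` advance a progress bar per completion:
#
#     for result in track(_iter_futures(promises), total=len(promises)):
#         ...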


def map_trainer(trainer=train_device, tasks=1, pbar=True, unblock=False, num_cpus=None, **kwargs):
    """Maps multiple training tasks across multiple workers using `ray`.

    In practice, the most common use case is to leave the keywords `trainer` (as it
    defaults to :meth:`train_device`), `pbar`, `unblock`, etc. at their defaults and
    concentrate on `tasks` and `**kwargs`, which pass arguments to the wrapper functions
    that contain the task execution logic, as well as to the :class:`Optimizer` and its
    :meth:`Optimizer.minimize`. For example, with the default `trainer`
    :meth:`train_device`, two user-defined functions are used for wrapping up user logic:

    * A `device_factory` (optional) that wraps around the logic for making circuits/states
      to be optimized; it is expected to return a single, or list of, :class:`Circuit`(s).
    * A `cost_fn` (required) that takes the circuits made and additional keyword arguments
      and returns a backprop-able scalar cost.

    Refer to the `kwargs` section below for more available options.

    Args:
        trainer (callable): The function containing the training loop to be distributed,
            whose fixed arguments are to be passed by `**kwargs` and task-specific
            arguments iterated through `tasks`. Provide only when custom
            evaluation/training logic is needed. Defaults to :meth:`train_device`.
        tasks (Union[int, Sequence, Mapping]): Number of repeats or collection of
            task-specific training config arguments feeding into :meth:`train_device`.
            Refer to `kwargs` below for the available options. Defaults to 1, which runs
            `trainer` exactly once.
        pbar (bool): Whether to show a progress bar, available only in blocking mode
            (i.e. `unblock==False`). Defaults to True.
        unblock (bool): Whether to unblock the process and return a getter function that
            returns the available results. Defaults to False.
        num_cpus (int): Number of cpu workers used to initialize ray. Defaults to the
            number of virtual cores.
        kwargs: Additional arguments containing fixed training config kwargs feeding into
            `trainer`. For the default `trainer` :meth:`train_device`, available options
            are:

            - cost_fn (callable): The cost function to be optimized and distributed. It is
              expected to accept the output of `device_factory` as *args, as well as
              user-defined **kwargs, and to return a scalar cost. Its user-defined
              **kwargs are passed from this function's **kwargs, which must include all of
              its required arguments.
            - device_factory (callable): Function that (partially) takes `kwargs` and
              returns a device, or a list/dict of devices. If None, `cost_fn` is assumed
              to take no positional arguments (for example, when device-making is
              contained in `cost_fn`). Defaults to None.
            - metric_fns (Union[Sequence[callable], Mapping[callable], callable]):
              Optional collection of functions that take the output of `device_factory`
              after optimization and return arbitrary evaluation/information.
            - return_kwargs (bool): Whether to include the input config `kwargs` in the
              output dict. Defaults to True.
            - skip_opt (bool): Whether to skip the optimization and directly calculate the
              cost.
            - tag (str): Optional label of the training task associated with the `kwargs`,
              to be included in the output dict.
            - any kwargs to `cost_fn`: excluding the output of `device_factory`.
            - any kwargs to `device_factory`: e.g. `x`, `r`, `theta`, etc.
            - any kwargs to `Optimizer`: e.g. `euclidean_lr`.
            - any kwargs to `Optimizer.minimize`: excluding `cost_fn` and `by_optimizing`,
              e.g. `max_steps`.

    Returns:
        Union[List, Dict]: The collection of results from each training task:

        - a list if `tasks` is provided as an int or a list; or
        - a dict with the same keys if `tasks` is provided as a dict.

    Examples:

    .. code-block:: python

        from mrmustard.lab import Vacuum, Dgate, Ggate, Gaussian
        from mrmustard.physics import fidelity
        from mrmustard.training.trainer import map_trainer

        def make_circ(x=0.):
            return Ggate(num_modes=1, symplectic_trainable=True) >> Dgate(x=x, x_trainable=True, y_trainable=True)

        def cost_fn(circ=make_circ(0.1), y_targ=0.):
            target = Gaussian(1) >> Dgate(-1.5, y_targ)
            s = Vacuum(1) >> circ
            return -fidelity(s, target)

        # Use case 0: Calculate the cost of a randomly initialized circuit 5 times without optimizing it.
        results_0 = map_trainer(
            cost_fn=cost_fn,
            tasks=5,
        )

        # Use case 1: Run circuit optimization 5 times on randomly initialized circuits.
        results_1 = map_trainer(
            cost_fn=cost_fn,
            device_factory=make_circ,
            tasks=5,
            max_steps=50,
            symplectic_lr=0.05,
        )

        # Use case 2: Run 2 sets of circuit optimization with custom parameters passed as a list.
        results_2 = map_trainer(
            cost_fn=cost_fn,
            device_factory=make_circ,
            tasks=[
                {'x': 0.1, 'euclidean_lr': 0.005, 'max_steps': 50},
                {'x': -0.7, 'euclidean_lr': 0.1, 'max_steps': 2},
            ],
            y_targ=0.35,
            symplectic_lr=0.05,
            AUTOCUTOFF_MAX_CUTOFF=7,
        )

        # Use case 3: Run 2 sets of circuit optimization with custom parameters passed as
        # a dict, with extra metric functions for evaluating the final optimized circuit.
        results_3 = map_trainer(
            cost_fn=cost_fn,
            device_factory=make_circ,
            tasks={
                'my-job': {'x': 0.1, 'euclidean_lr': 0.005, 'max_steps': 50},
                'my-other-job': {'x': -0.7, 'euclidean_lr': 0.1, 'max_steps': 2},
            },
            y_targ=0.35,
            symplectic_lr=0.05,
            metric_fns={
                'is_gaussian': lambda c: c.is_gaussian,
                'foo': lambda _: 17.,
            },
        )
    """
    try:
        import ray  # pylint: disable=import-outside-toplevel
    except ImportError as e:
        raise ImportError(
            "Failed to import `ray` which is an extra dependency. "
            "Please install with `pip install -e .[ray]`."
        ) from e

    if not ray.is_initialized():  # pragma: no cover
        ray.init(num_cpus=num_cpus)

    return_dict = False
    if isinstance(tasks, int):
        tasks = [{} for _ in range(tasks)]
    elif isinstance(tasks, Mapping):
        return_dict = True
        tasks = [{"tag": tag, **task} for tag, task in tasks.items()]

    remote_trainer, kwargs = partial_pop(
        ray.remote(trainer).remote,
        **kwargs,
    )

    if isinstance(tasks, Sequence):
        promises = [
            curry_pop(
                remote_trainer,
                **task_kwargs,
            )[0]
            for task_kwargs in tasks
            if isinstance(task_kwargs, Mapping)
        ]
    else:
        raise ValueError(
            f"`tasks` is expected to be of type int, list, or dict; got {type(tasks)}: {tasks}"
        )

    if not unblock:
        # Blocks and waits till all tasks complete to return the end results.
        if pbar:
            results = list(
                track(
                    _iter_futures(promises),
                    description=f"{len(promises)} tasks running...",
                    total=len(promises),
                )
            )
        else:
            results = ray.get(promises)

        if return_dict:
            return {r["tag"]: r for r in results}
        return results

    # Does not block; returns a getter function that returns the results available so far.
    def get_avail_results():
        results, running_tasks = ray.wait(  # pylint: disable=unused-variable
            promises, num_returns=len(promises)
        )
        if return_dict:
            return {r["tag"]: r for r in ray.get(results)}
        return ray.get(results)

    return get_avail_results


def kwargs_of(fn):
    """Gets the kwarg signature of a callable."""
    params = signature(fn).parameters
    kwarg_kinds = [Parameter.POSITIONAL_OR_KEYWORD, Parameter.KEYWORD_ONLY]
    keywords = [k for k, p in params.items() if p.kind in kwarg_kinds]
    has_var_keyword = any(p.kind is Parameter.VAR_KEYWORD for p in params.values())
    return keywords, has_var_keyword
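

# Illustrative example (not in the original module):
#
#     def f(a, b=1, *, c, **rest):
#         ...
#
#     kwargs_of(f)  # -> (['a', 'b', 'c'], True)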


def partial_pop(fn, *args, **kwargs):
    """Partially applies known kwargs to fn and returns the rest."""
    keywords, has_var_keyword = kwargs_of(fn)
    known_kwargs = {k: kwargs.pop(k) for k in set(kwargs).intersection(keywords)}
    partial_fn = partial(fn, *args, **known_kwargs, **(kwargs if has_var_keyword else {}))
    return partial_fn, kwargs
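

# Illustrative example (not in the original module): only kwargs named in the
# function's signature are bound; the rest are handed back to the caller:
#
#     def g(x, y=0):
#         return x + y
#
#     fn, leftover = partial_pop(g, x=1, z=9)
#     fn(y=2)   # == 3
#     leftover  # == {'z': 9}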


def curry_pop(fn, *args, **kwargs):
    """A poor man's reader monad bind function: calls `fn` with the kwargs it
    accepts and returns the result along with the leftover kwargs.
    """
    partial_fn, kwargs = partial_pop(fn, *args, **kwargs)
    return partial_fn(), kwargs
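

# Illustrative example (not in the original module): `curry_pop` is `partial_pop`
# plus the call itself, threading the leftover kwargs along:
#
#     def h(x, y=0):
#         return x + y
#
#     value, leftover = curry_pop(h, x=1, y=2, z=9)
#     # value == 3; leftover == {'z': 9}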


def update_pop(obj, **kwargs):
    """Updates an object/dict while popping keys out and returns the updated dict
    and remaining kwargs.
    """
    updated = {}
    if isinstance(obj, Mapping):
        for k in set(kwargs).intersection(obj):
            obj[k] = kwargs.pop(k)
            updated[k] = obj[k]
    else:
        for k in set(kwargs).intersection(dir(obj)):
            setattr(obj, k, kwargs.pop(k))
            updated[k] = getattr(obj, k)
    return updated, kwargs
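

# Illustrative example (not in the original module). Applied to a dict:
#
#     conf = {"lr": 0.1, "steps": 10}
#     updated, rest = update_pop(conf, lr=0.5, foo=1)
#     # conf == {"lr": 0.5, "steps": 10}; updated == {"lr": 0.5}; rest == {"foo": 1}
#
# `train_device` uses it with the global `settings` object to apply and record
# settings overrides such as `AUTOCUTOFF_MAX_CUTOFF`.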