mm.training.trainer.train_device

mrmustard.training.trainer.train_device(cost_fn, device_factory=None, metric_fns=None, return_kwargs=True, skip_opt=False, tag=None, **kwargs)

A general, flexible training loop for circuit optimization, configured through keyword arguments.

Parameters:
  • cost_fn (callable) – The cost function to be optimized and distributed. It is expected to accept the output of device_factory as positional arguments (*args), along with user-defined keyword arguments (**kwargs), and to return a scalar cost. Its user-defined keyword arguments are taken from this function's **kwargs, which must therefore include all of cost_fn's required arguments.

  • device_factory (callable) – Function that takes a subset of kwargs and returns a device, or a list/dict of devices. If None, cost_fn is assumed to take no positional arguments (for example, when the device is constructed inside cost_fn). Defaults to None.

  • metric_fns (Union[Sequence[callable], Mapping[callable], callable]) – Optional collection of functions that take the output of device_factory after optimization and return arbitrary evaluation metrics or information. Defaults to None.

  • return_kwargs (bool) – Whether to include the input config kwargs in the output dict. Defaults to True.

  • skip_opt (bool) – Whether to skip the optimization and compute the cost directly. Defaults to False.

  • tag (str) – Optional label for the training task associated with these kwargs, included in the output dict. Defaults to None.

  • kwargs

    Keyword arguments for any of the following functions:
    • cost_fn: excluding the output of device_factory.

    • device_factory: e.g. x, r, theta.

    • Optimizer: e.g. euclidean_lr.

    • Optimizer.minimize: excluding cost_fn and by_optimizing, e.g. max_steps.
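The kwargs routing described above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not MrMustard's implementation: split_kwargs, make_device, cost, and run are hypothetical names, and each keyword is routed by matching its name against a function's signature via inspect.signature. Devices are passed to cost_fn positionally, mirroring the cost_fn contract above. In this sketch leftover keys are simply returned; per the list above, train_device would pass such keys on to Optimizer and Optimizer.minimize.

```python
import inspect
from typing import Any, Callable, Dict, List, Optional, Tuple


def split_kwargs(fn: Callable, kwargs: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Split kwargs into (names fn's signature accepts, everything else).

    Hypothetical helper for illustration; not part of MrMustard's API.
    Functions that themselves take **kwargs are not special-cased here.
    """
    params = inspect.signature(fn).parameters
    accepted = {k: v for k, v in kwargs.items() if k in params}
    leftover = {k: v for k, v in kwargs.items() if k not in params}
    return accepted, leftover


# Toy stand-ins (hypothetical): a "device" is just a dict of parameters here.
def make_device(x: float = 0.0, r: float = 0.0) -> Dict[str, float]:
    return {"x": x, "r": r}


def cost(device: Dict[str, float], target: float = 1.0) -> float:
    # Scalar cost: squared distance of the device's x from the target.
    return (device["x"] - target) ** 2


def run(
    cost_fn: Callable, device_factory: Optional[Callable] = None, **kwargs: Any
) -> Tuple[float, Dict[str, Any]]:
    """Route kwargs to device_factory, then to cost_fn; return (cost, leftover kwargs)."""
    devices: List[Any] = []
    if device_factory is not None:
        factory_kwargs, kwargs = split_kwargs(device_factory, kwargs)
        devices = [device_factory(**factory_kwargs)]
    # With no factory, cost_fn is called with no positional arguments.
    cost_kwargs, leftover = split_kwargs(cost_fn, kwargs)
    return cost_fn(*devices, **cost_kwargs), leftover


if __name__ == "__main__":
    value, leftover = run(cost, make_device, x=0.5, r=0.1, target=1.0, max_steps=100)
    print(value, leftover)  # 0.25 {'max_steps': 100}
```

Here x and r reach the factory, target reaches the cost function, and max_steps (an Optimizer.minimize argument in the real function) is left over.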

Returns:

A result dict summarizing the optimized circuit, cost, metrics and/or input configs.

Return type:

dict
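A sketch of how such a result dict might be assembled, including the three accepted forms of metric_fns (a single callable, a sequence, or a mapping) and the effect of return_kwargs and tag. The helper names collect_metrics and build_result, and the keys "tag", "cost", "device", "metric" and "metric_<i>", are hypothetical choices for illustration; the keys train_device actually returns may differ.

```python
from collections.abc import Mapping, Sequence
from typing import Any, Dict, List


def collect_metrics(metric_fns: Any, devices: List[Any]) -> Dict[str, Any]:
    """Evaluate metric_fns on the devices (hypothetical helper).

    Accepts a single callable, a mapping of name -> callable, or a sequence
    of callables; the result-key naming is an illustrative assumption.
    """
    if metric_fns is None:
        return {}
    if callable(metric_fns):
        return {"metric": metric_fns(*devices)}
    if isinstance(metric_fns, Mapping):
        return {name: fn(*devices) for name, fn in metric_fns.items()}
    if isinstance(metric_fns, Sequence):
        return {f"metric_{i}": fn(*devices) for i, fn in enumerate(metric_fns)}
    raise TypeError(f"unsupported metric_fns type: {type(metric_fns).__name__}")


def build_result(
    cost: float,
    devices: List[Any],
    metric_fns: Any = None,
    return_kwargs: bool = True,
    tag: Any = None,
    **kwargs: Any,
) -> Dict[str, Any]:
    """Assemble a summary dict (hypothetical helper, not MrMustard's code)."""
    # Echo the input configs only when return_kwargs is True.
    result: Dict[str, Any] = dict(kwargs) if return_kwargs else {}
    if tag is not None:
        result["tag"] = tag
    result["cost"] = cost
    result["device"] = devices
    # Metrics are merged in last, one entry per metric function.
    result.update(collect_metrics(metric_fns, devices))
    return result
```

Mapping keys become result keys directly, while sequence entries are numbered by position; a single callable yields one entry.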