mm.training.trainer.train_device
- mrmustard.training.trainer.train_device(cost_fn, device_factory=None, metric_fns=None, return_kwargs=True, skip_opt=False, tag=None, **kwargs)
A general, flexible training loop for circuit optimization, configured entirely through kwargs.
- Parameters:
cost_fn (callable) – The cost function to be optimized and distributed. It is expected to accept the output of device_factory as positional arguments (*args) together with user-defined keyword arguments (**kwargs), and to return a scalar cost. Its user-defined keyword arguments are passed through from this function's **kwargs, which must therefore include all of its required arguments.
device_factory (callable) – Function that takes (a subset of) the kwargs and returns a device, or a list/dict of devices. If None, cost_fn is assumed to take no positional arguments (for example, when device creation happens inside cost_fn). Defaults to None.
metric_fns (Union[Sequence[callable], Mapping[callable], callable]) – Optional collection of functions that take the output of device_factory after optimization and return arbitrary evaluation results or information. Defaults to None.
return_kwargs (bool) – Whether to include the input config kwargs in the output dict. Defaults to True.
skip_opt (bool) – Whether to skip the optimization and compute the cost directly. Defaults to False.
tag (str) – Optional label for the training task associated with these kwargs; it is included in the output dict. Defaults to None.
kwargs –
- Keyword arguments for any of the following functions:
cost_fn: excluding the output of device_factory.
device_factory: e.g. x, r, theta.
Optimizer: e.g. euclidean_lr.
Optimizer.minimize: excluding cost_fn and by_optimizing, e.g. max_steps.
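The kwargs list above means that one flat dict of configuration values is shared among several callables. A minimal sketch of that routing follows, assuming each callable receives only the keys that match its parameter names (looked up with `inspect.signature`). `split_kwargs` and the toy `device_factory`/`cost_fn` are hypothetical and use no Mr Mustard objects; this is not the library's implementation.

```python
import inspect
from typing import Any, Callable, Dict


def split_kwargs(fn: Callable, kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Return the subset of `kwargs` whose keys name a parameter of `fn`.

    Hypothetical helper for illustration only.
    """
    params = inspect.signature(fn).parameters
    return {k: v for k, v in kwargs.items() if k in params}


# Toy stand-ins (hypothetical): a "device" is just a dict of parameters here.
def device_factory(x=0.0, r=0.0):
    return {"x": x, "r": r}


def cost_fn(device, target=1.0):
    # The device arrives positionally; `target` is a user-defined kwarg.
    return (device["x"] - target) ** 2 + device["r"] ** 2


kwargs = {"x": 0.5, "r": 0.1, "target": 2.0, "max_steps": 100}
device = device_factory(**split_kwargs(device_factory, kwargs))
cost = cost_fn(device, **split_kwargs(cost_fn, kwargs))
# "max_steps" matches neither toy function; per the list above,
# train_device would route it to Optimizer.minimize instead.
```

Routing by parameter name keeps each function's signature as the single source of truth for what it consumes, at the cost of requiring that no two functions use the same name for different settings.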
- Returns:
A result dict summarizing the optimized circuit, cost, metrics and/or input configs.
- Return type:
dict
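The shape of the result dict can be pictured with the sketch below, which evaluates metric_fns in each of its three documented forms (single callable, sequence, mapping) and optionally attaches the tag and the input kwargs. `evaluate_metrics`, `summarize`, and the key names `device`, `cost`, `metrics` and `tag` are hypothetical choices for illustration; the keys actually returned by `train_device` may differ, and neither the optimization step nor skip_opt is modeled.

```python
from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Union

# Accepted forms of metric_fns, mirroring the documented type hint.
MetricFns = Union[Sequence[Callable], Mapping[str, Callable], Callable, None]


def evaluate_metrics(device: Any, metric_fns: MetricFns) -> Dict[str, Any]:
    """Apply each metric function to the (optimized) device."""
    if metric_fns is None:
        return {}
    if callable(metric_fns):
        # A single function is treated as a one-element sequence.
        metric_fns = [metric_fns]
    if isinstance(metric_fns, Mapping):
        # Mapping form: the user-chosen keys name the results.
        return {name: fn(device) for name, fn in metric_fns.items()}
    # Sequence form: naming results by __name__ is an assumption of this sketch.
    return {fn.__name__: fn(device) for fn in metric_fns}


def summarize(device: Any, cost: float, metric_fns: MetricFns = None,
              return_kwargs: bool = True, tag: Optional[str] = None,
              **kwargs: Any) -> Dict[str, Any]:
    """Assemble a result dict (hypothetical key names)."""
    result: Dict[str, Any] = {
        "device": device,
        "cost": cost,
        "metrics": evaluate_metrics(device, metric_fns),
    }
    if tag is not None:
        result["tag"] = tag
    if return_kwargs:
        # Echo the input configuration so each result is self-describing.
        result.update(kwargs)
    return result


# Usage with a toy device (a plain dict standing in for a circuit).
def r_squared(device):
    return device["r"] ** 2


res = summarize({"r": 0.5}, cost=0.1, metric_fns=r_squared, tag="run-1", x=0.3)
```

Nesting the metrics under their own key, rather than merging them into the top level, is a choice made here to avoid collisions with the echoed kwargs.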