Source code for mrmustard.lab.samplers

# Copyright 2024 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Samplers for measurement devices."""

from __future__ import annotations

from abc import ABC, abstractmethod
from collections.abc import Sequence
from itertools import product
from typing import Any

import numpy as np

from mrmustard import math, settings

from .circuit_components import CircuitComponent
from .circuit_components_utils import BtoQ
from .states import Ket, Number, State

__all__ = ["HomodyneSampler", "PNRSampler", "Sampler"]


class Sampler(ABC):
    r"""
    A sampler for measurements of quantum circuits.

    Args:
        meas_outcomes: The measurement outcomes for this sampler.
        povms: The (optional) POVMs of this sampler.
    """

    def __init__(
        self,
        meas_outcomes: Sequence[Any],
        povms: CircuitComponent | Sequence[CircuitComponent] | None = None,
    ):
        self._povms = povms
        self._meas_outcomes = meas_outcomes
        # Name of the keyword argument of the POVM class that selects the outcome.
        # Subclasses set it (e.g. ``"n"`` in ``PNRSampler``); ``_get_povm`` uses it
        # when ``povms`` is a single template component.
        self._outcome_arg = None

    @property
    def povms(self) -> CircuitComponent | Sequence[CircuitComponent] | None:
        r"""
        The POVMs of this sampler.
        """
        return self._povms

    @property
    def meas_outcomes(self) -> Sequence[Any]:
        r"""
        The measurement outcomes of this sampler.
        """
        return self._meas_outcomes

    @abstractmethod
    def probabilities(self, state: State, atol: float = 1e-4) -> Sequence[float]:
        r"""
        Returns the probability distribution of a state w.r.t. measurement outcomes.

        Args:
            state: The state to generate the probability distribution of. Note: the
                input state must be normalized.
            atol: The absolute tolerance used for validating that the computed
                probability distribution sums to ``1``.

        Returns:
            One probability per element of
            ``product(self.meas_outcomes, repeat=len(state.modes))``, in that order.
        """

    def sample(self, state: State, n_samples: int = 1000, seed: int | None = None) -> np.ndarray:
        r"""
        Returns an array of samples given a state.

        A multi-mode state is sampled one mode at a time. The first mode is sampled
        from its marginal distribution. Then, for each distinct outcome, the state is
        projected with the matching POVM, renormalized, and the remaining modes are
        sampled recursively.

        Args:
            state: The state to sample.
            n_samples: The number of samples to generate.
            seed: An optional seed for random sampling. When given, the samples of
                every mode (not only the first) are reproducible.

        Returns:
            An array of samples such that the shape is ``(n_samples, n_modes)``.
            For multi-mode states the rows are grouped by the outcome on the first
            mode (in sorted order) rather than shuffled.
        """
        if len(state.modes) == 1:
            return self.sample_prob_dist(state, n_samples, seed)[0]

        initial_mode = state.modes[0]
        initial_samples, probs = self.sample_prob_dist(
            state.get_modes(initial_mode), n_samples, seed
        )
        # ``initial_samples`` has a single column, so the flat indices returned by
        # ``np.unique`` are row indices into ``probs``.
        unique_samples, idxs, counts = np.unique(
            initial_samples,
            return_index=True,
            return_counts=True,
        )

        # Derive one reproducible seed per branch. Previously the recursive call
        # dropped ``seed``, so only the first mode was reproducible. With
        # ``seed=None`` every branch still uses the default RNG, as before.
        if seed is None:
            branch_seeds = [None] * len(unique_samples)
        else:
            branch_seeds = [
                int(word)
                for word in np.random.SeedSequence(seed).generate_state(len(unique_samples))
            ]

        ret = []
        for unique_sample, idx, count, branch_seed in zip(
            unique_samples, idxs, counts, branch_seeds
        ):
            meas_op = self._get_povm(unique_sample, initial_mode).dual
            prob = probs[idx]
            # Renormalize the projected state: a Ket is divided by sqrt(prob),
            # any other state (e.g. a density matrix) by prob.
            norm = math.sqrt(prob) if isinstance(state, Ket) else prob
            reduced_state = (state >> meas_op) / norm
            samples = self.sample(reduced_state, count, branch_seed)
            ret.extend(np.append([unique_sample], sample) for sample in samples)
        return np.array(ret)

    def sample_prob_dist(
        self,
        state: State,
        n_samples: int = 1000,
        seed: int | None = None,
    ) -> tuple[np.ndarray, np.ndarray]:
        r"""
        Samples a state by computing the probability distribution.

        Args:
            state: The state to sample.
            n_samples: The number of samples to generate.
            seed: An optional seed for random sampling.

        Returns:
            A tuple of the generated samples (shape ``(n_samples, n_modes)``) and the
            probability of obtaining each sample.
        """
        rng = settings.get_rng(seed)
        probs = np.asarray(self.probabilities(state))
        meas_outcomes = np.asarray(
            list(product(self.meas_outcomes, repeat=len(state.modes)))
        )
        # Draw outcome *indices*. This consumes exactly the same random numbers as
        # ``rng.choice(a=meas_outcomes, p=probs)``, but lets us read each sample's
        # probability in O(1) instead of searching the outcome list per sample.
        indices = rng.choice(len(meas_outcomes), p=probs, size=n_samples)
        return meas_outcomes[indices], probs[indices]

    def _get_povm(self, meas_outcome: Any, mode: int) -> CircuitComponent:
        r"""
        Returns the POVM associated with a given outcome on a given mode.

        Args:
            meas_outcome: The measurement outcome.
            mode: The mode.

        Returns:
            The POVM circuit component.

        Raises:
            ValueError: If this sampler has no POVMs.
        """
        if self._povms is None:
            raise ValueError("This sampler has no POVMs defined.")
        if isinstance(self.povms, CircuitComponent):
            # A single template component: rebuild it on ``mode`` with the outcome
            # substituted for the ``_outcome_arg`` keyword.
            # NOTE(review): if ``parameters`` returns the component's own mapping,
            # the assignment below mutates the template in place — confirm.
            kwargs = self.povms.parameters
            kwargs[self._outcome_arg] = meas_outcome
            return self.povms.__class__(mode, **kwargs)
        # A sequence of POVMs aligned one-to-one with ``meas_outcomes``.
        return self.povms[self.meas_outcomes.index(meas_outcome)].on([mode])

    def _validate_probs(self, probs: Sequence[float], atol: float) -> Sequence[float]:
        r"""
        Validates that the given probability distribution sums to ``1`` within some
        tolerance and returns a renormalized probability distribution to account for
        small numerical errors.

        Args:
            probs: The probability distribution to validate.
            atol: The absolute tolerance to validate with. A falsy value (``None``
                or ``0``) falls back to ``settings.ATOL``.

        Returns:
            ``abs(probs)`` divided by its sum.
        """
        atol = atol or settings.ATOL
        probs = math.abs(probs)
        prob_sum = math.sum(probs)
        math.error_if(
            prob_sum,
            not math.allclose(prob_sum, 1, atol=atol),
            f"Probabilities sum to {prob_sum} and not 1.0.",
        )
        return probs / prob_sum
class PNRSampler(Sampler):
    r"""
    A sampler for photon-number resolving (PNR) detectors.

    Args:
        cutoff: The photon number cutoff, i.e. the number of Fock levels
            (``0, 1, ..., cutoff - 1``) considered per mode.
    """

    def __init__(self, cutoff: int) -> None:
        # One outcome per Fock level. The probabilities come from
        # ``fock_distribution(cutoff)`` and the POVM template is
        # ``Number(0, 0, cutoff)``; both are assumed to cover the ``cutoff`` levels
        # ``0..cutoff-1``. The former ``range(cutoff + 1)`` added an extra outcome
        # with no matching probability, so ``rng.choice`` failed on size mismatch.
        super().__init__(list(range(cutoff)), Number(0, 0, cutoff))
        self._cutoff = cutoff
        self._outcome_arg = "n"

    def probabilities(self, state, atol=1e-4):
        r"""
        Returns the validated photon-number distribution of ``state``, truncated at
        ``cutoff``.

        Args:
            state: The (normalized) state to compute the distribution of.
            atol: The absolute tolerance used to check that the distribution sums
                to ``1``.

        Returns:
            The renormalized Fock distribution.
        """
        return self._validate_probs(state.fock_distribution(self._cutoff), atol)
class HomodyneSampler(Sampler):
    r"""
    A sampler for homodyne measurements.

    Args:
        phi: The quadrature angle where ``0`` corresponds to ``x`` and ``\pi/2`` to ``p``.
        bounds: The range of values to discretize over.
        num: The number of points to discretize over.
    """

    def __init__(
        self,
        phi: float = 0,
        bounds: tuple[float, float] = (-10, 10),
        num: int = 1000,
    ) -> None:
        # The outcomes are the ``num`` evenly spaced grid points; ``step`` is the
        # grid spacing used to turn densities into per-point probabilities.
        meas_outcomes, step = np.linspace(*bounds, num, retstep=True)
        super().__init__(list(meas_outcomes))
        self._step = step
        self._phi = phi

    def probabilities(self, state, atol=1e-4):
        r"""
        Returns the validated, discretized quadrature distribution of ``state``.

        The quadrature distribution evaluated on the grid is multiplied by
        ``step ** n_modes`` to obtain per-grid-point probabilities.

        Args:
            state: The (normalized) state to compute the distribution of.
            atol: The absolute tolerance used to check that the distribution sums
                to ``1``.

        Returns:
            The renormalized discretized distribution.
        """
        probs = state.quadrature_distribution(
            math.astensor(self.meas_outcomes),
            phi=self._phi,  # TODO: revisit meas_outcomes
        ) * self._step ** len(state.modes)
        return self._validate_probs(probs, atol)

    def sample(self, state: State, n_samples: int = 1000, seed: int | None = None) -> np.ndarray:
        r"""
        Returns an array of homodyne samples given a state.

        The first mode is sampled from its marginal. For each distinct outcome,
        the state's Bargmann ansatz (after ``BtoQ`` on that mode) is evaluated at
        the outcome. The result is renormalized and the remaining modes are sampled
        recursively.

        Args:
            state: The state to sample.
            n_samples: The number of samples to generate.
            seed: An optional seed for random sampling. When given, the samples of
                every mode (not only the first) are reproducible.

        Returns:
            An array of samples of shape ``(n_samples, n_modes)``. For multi-mode
            states the rows are grouped by the outcome on the first mode.
        """
        if len(state.modes) == 1:
            return self.sample_prob_dist(state, n_samples, seed)[0]

        initial_mode = state.modes[0]
        initial_samples, probs = self.sample_prob_dist(
            state.get_modes(initial_mode), n_samples, seed
        )
        unique_samples, idxs, counts = np.unique(
            initial_samples,
            return_index=True,
            return_counts=True,
        )

        # Derive one reproducible seed per branch. Previously the recursive call
        # dropped ``seed``. With ``seed=None`` the default RNG is used, as before.
        if seed is None:
            branch_seeds = [None] * len(unique_samples)
        else:
            branch_seeds = [
                int(word)
                for word in np.random.SeedSequence(seed).generate_state(len(unique_samples))
            ]

        btoq_ansatz = (state >> BtoQ([initial_mode], phi=self._phi)).ansatz

        ret = []
        for unique_sample, idx, count, branch_seed in zip(
            unique_samples, idxs, counts, branch_seeds
        ):
            # Evaluate the ansatz at the sampled quadrature value of the first mode.
            # Presumably this is a partial evaluation that leaves the remaining
            # modes' variables free — confirm.
            reduced_ansatz = btoq_ansatz(unique_sample)
            reduced_state = state.from_bargmann(state.modes[1:], reduced_ansatz.triple)
            # ``probs`` holds per-grid-point probabilities; dividing by ``step``
            # recovers the density value used for normalization.
            prob = probs[idx] / self._step
            norm = math.sqrt(prob) if isinstance(state, Ket) else prob
            normalized_reduced_state = reduced_state / norm
            samples = self.sample(normalized_reduced_state, count, branch_seed)
            ret.extend(np.append([unique_sample], sample) for sample in samples)
        return np.array(ret)