Source code for mrmustard.lab.abstract.measurement
# Copyright 2021 Xanadu Quantum Technologies Inc.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module contains the implementation of the abstract class :class:`Measurement` and of the class :class:`FockMeasurement`."""
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Iterable, Sequence, Union
from mrmustard import math, settings
from mrmustard.math.parameter_set import ParameterSet
from mrmustard.math.parameters import Constant, Variable
from mrmustard.utils.typing import Tensor
from .state import State
[docs]
class Measurement(ABC):
    r"""Abstract base class holding the methods and properties shared by all measurements.

    Args:
        outcome (optional, List[float] or Tensor): the result of the measurement. If not
            ``None``, the measurement is postselected on this outcome.
        modes (List[int]): the modes on which the measurement acts

    Raises:
        ValueError: if ``modes`` is ``None``
    """

    def __init__(self, outcome: Tensor, modes: Sequence[int]) -> None:
        super().__init__()
        if modes is None:
            raise ValueError(f"Modes not defined for {self.__class__.__name__}.")
        self._modes = modes
        # Records whether the outcome was already defined by the user (postselection)
        # rather than left to be sampled.
        self._is_postselected = outcome is not None
        self._parameter_set = ParameterSet()

    def _add_parameter(self, parameter: Union[Constant, Variable]):
        r"""Adds a parameter to this measurement.

        The parameter is registered in :attr:`parameter_set` and also stored as an instance
        attribute named ``parameter.name``.

        Args:
            parameter: The parameter to add.
        """
        self.parameter_set.add_parameter(parameter)
        self.__dict__[parameter.name] = parameter

    @property
    def parameter_set(self):
        r"""The set of parameters of this measurement."""
        return self._parameter_set

    @property
    def modes(self):
        r"""Returns the modes being measured."""
        return self._modes

    @property
    def num_modes(self):
        r"""Returns the number of modes being measured."""
        return len(self.modes)

    @property
    def postselected(self):
        r"""Returns whether the measurement is postselected, i.e. whether an outcome was provided."""
        return self._is_postselected

    @property
    @abstractmethod
    def outcome(self):
        r"""Returns the outcome of the measurement. If no measurement has been carried out,
        returns ``None``."""
        ...

    @abstractmethod
    def _measure_fock(self, other: State) -> Union[State, float]: ...

    @abstractmethod
    def _measure_gaussian(self, other: State) -> Union[State, float]: ...

    def primal(self, other: State) -> Union[State, float]:
        r"""Performs the measurement on ``other``.

        Dispatches to :meth:`_measure_gaussian` when ``other.is_gaussian`` is true, and to
        :meth:`_measure_fock` otherwise.

        Args:
            other (State): the state to measure

        Returns:
            whatever the selected ``_measure_*`` implementation returns
        """
        if other.is_gaussian:
            return self._measure_gaussian(other)
        return self._measure_fock(other)

    def __lshift__(self, other) -> Union[State, float]:
        r"""Applies the measurement to ``other`` with the ``meas << state`` syntax.

        Args:
            other: the object to measure; must be a :class:`State`

        Returns:
            the result of :meth:`primal` applied to ``other``

        Raises:
            TypeError: if ``other`` is not a :class:`State`
        """
        if isinstance(other, State):
            return self.primal(other)
        # ``__qualname__`` is a class attribute, not an instance attribute, so read it from
        # the types.
        raise TypeError(
            f"Cannot apply Measurement '{type(self).__qualname__}' to '{type(other).__qualname__}'."
        )

    def __getitem__(self, items) -> Measurement:
        r"""Assigns modes via the getitem syntax, e.g. ``meas[0, 1]`` measures modes 0 and 1.

        The modes are replaced in place and ``self`` is returned, so the result can be used
        directly.

        Args:
            items: an ``int``, a ``slice`` with an explicit ``stop`` (the total number of modes
                is not known here), or an iterable of mode indices

        Returns:
            Measurement: ``self``, with its modes updated

        Raises:
            ValueError: if ``items`` is not a valid mode, slice or collection of modes
        """
        if isinstance(items, int):
            modes = [items]
        elif isinstance(items, slice):
            # Slice components may be ``None`` (e.g. ``meas[:3]``), which ``range`` rejects.
            if items.stop is None:
                raise ValueError(f"{items} is not a valid slice or list of modes.")
            start = 0 if items.start is None else items.start
            step = 1 if items.step is None else items.step
            modes = list(range(start, items.stop, step))
        elif isinstance(items, (Sequence, Iterable)):
            modes = list(items)
        else:
            raise ValueError(f"{items} is not a valid slice or list of modes.")
        self._modes = modes
        return self
[docs]
class FockMeasurement(Measurement):
    r"""A Fock measurement projecting onto a Fock measurement pattern.

    It works by representing the state in the Fock basis and then applying a stochastic channel
    matrix ``P(meas|n)`` to the Fock probabilities (belief propagation).
    It outputs the measurement probabilities and the remaining post-measurement state (if any)
    in the Fock basis.

    This class reads ``self._internal_stochastic_channel`` (one matrix per measured mode) but
    never sets it, and :meth:`recompute_stochastic_channel` raises ``NotImplementedError`` here.
    Presumably concrete subclasses provide both.

    Args:
        outcome (optional, Tensor): the measurement outcome; if not ``None`` the measurement
            is postselected
        modes (Sequence[int]): the modes being measured
        cutoffs (optional, Sequence[int]): for each measured mode, how many rows (outcome
            values) of its stochastic channel are kept. If falsy (e.g. ``None``), every measured
            mode uses ``settings.PNR_INTERNAL_CUTOFF``.

    Raises:
        ValueError: if ``modes`` is ``None``
    """

    def __init__(self, outcome: Tensor, modes: Sequence[int], cutoffs: Sequence[int]) -> None:
        # Let the base class validate ``modes`` first. Otherwise ``len(modes)`` below raises an
        # unhelpful TypeError when ``modes`` is None.
        super().__init__(outcome, modes)
        self._cutoffs = cutoffs or [settings.PNR_INTERNAL_CUTOFF] * len(modes)

    @property
    def outcome(self):
        r"""Not implemented for :class:`FockMeasurement`; always raises ``NotImplementedError``."""
        raise NotImplementedError

    def _measure_gaussian(self, other: State) -> Union[State, float]:
        r"""Gaussian states are measured by the same Fock-basis procedure as :meth:`_measure_fock`."""
        return self._measure_fock(other)

    def _measure_fock(self, other: State) -> Union[State, float]:
        r"""Returns a tensor representing the post-measurement state in the Fock basis.

        The first ``N`` indices of the returned tensor correspond to the Fock measurements of the
        ``N`` modes that the detector is measuring. The remaining indices correspond to the
        density matrix of the unmeasured modes. If all modes of ``other`` are measured, only the
        real part is returned (the measurement probabilities).

        Args:
            other (State): the quantum state

        Returns:
            Tensor: a tensor representing the post-measurement state
        """
        # Measured modes get at least PNR_INTERNAL_CUTOFF Fock levels. Unmeasured modes keep
        # the incoming state's cutoff.
        cutoffs = []
        for mode in other.modes:
            if mode in self._modes:
                cutoffs.append(
                    max(settings.PNR_INTERNAL_CUTOFF, other.cutoffs[other.indices(mode)])
                )
            else:
                cutoffs.append(other.cutoffs[other.indices(mode)])

        if self.should_recompute_stochastic_channel() or math.any(
            [c > settings.PNR_INTERNAL_CUTOFF for c in other.cutoffs]
        ):
            self.recompute_stochastic_channel(cutoffs)

        dm = other.dm(cutoffs)
        # NOTE(review): the index arithmetic below treats ``dm`` as having ket axes 0..N-1
        # followed by bra axes N..2N-1 (N = other.num_modes). It also uses each measured mode
        # label directly as an axis position, which appears to assume that self._modes is
        # sorted ascending and matches positions in other.modes -- TODO confirm.
        for k, (mode, stoch) in enumerate(zip(self._modes, self._internal_stochastic_channel)):
            # After k previous iterations, this mode's ket axis has shifted left by k (k earlier
            # ket axes were removed) and its bra axis by 2k (k ket + k bra axes were removed).
            # Outcome axes are appended at the end, so they do not shift these positions.
            last = [mode - k, mode + other.num_modes - 2 * k]
            # move the mode indices to the end
            perm = [m for m in range(dm.ndim) if m not in last] + last
            dm = math.transpose(dm, perm)
            # compute sum_m P(meas|m) rho_mm: keep the diagonal, then contract it with the
            # channel's second axis. The new outcome axis (truncated to self._cutoffs[k]
            # values) ends up last.
            dm = math.diag_part(dm)
            dm = math.tensordot(dm, stoch[: self._cutoffs[k], : dm.shape[-1]], [[-1], [1]])

        # put back the last len(self.modes) (outcome) axes at the beginning
        output = math.transpose(
            dm,
            list(range(dm.ndim - len(self._modes), dm.ndim))
            + list(range(dm.ndim - len(self._modes))),
        )
        if len(output.shape) == len(self._modes):  # all modes are measured
            output = math.real(output)  # return probabilities
        return output

    # pylint: disable=no-self-use
    def should_recompute_stochastic_channel(self) -> bool:  # override in subclasses
        r"""Returns ``True`` if the stochastic channel has to be recomputed.

        Always returns ``False`` here. This method should be overridden by subclasses as needed.
        """
        return False

    def recompute_stochastic_channel(self, cutoffs: Sequence[int]) -> None:
        r"""Recomputes the stochastic channel.

        Not implemented here. This method should be overridden by subclasses as needed.

        Args:
            cutoffs (Sequence[int]): the Fock cutoffs used for each mode of the measured state

        Raises:
            NotImplementedError: always, in this base implementation
        """
        raise NotImplementedError
_modules/mrmustard/lab/abstract/measurement
Download Python script
Download Notebook
View on GitHub