Source code for mrmustard.lab_dev.circuit_components

# Copyright 2023 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
A base class for the components of quantum circuits.
"""

# pylint: disable=super-init-not-called, protected-access

from __future__ import annotations

from typing import Iterable, Optional, Sequence, Union

import os
import numpy as np

from IPython.display import display, HTML
from mako.template import Template

from ..utils.typing import Scalar
from ..physics.converters import to_fock
from ..physics.representations import Representation, Bargmann, Fock
from ..math.parameter_set import ParameterSet
from ..math.parameters import Constant, Variable
from .wires import Wires

__all__ = ["CircuitComponent", "AdjointView", "DualView"]


[docs] class CircuitComponent: r""" A base class for the components (states, transformations, and measurements, or potentially unphysical ``wired'' objects) that can be placed in Mr Mustard's quantum circuits. Args: name: The name of this component. representation: A representation for this circuit component. modes_out_bra: The output modes on the bra side of this component. modes_in_bra: The input modes on the bra side of this component. modes_out_ket: The output modes on the ket side of this component. modes_in_ket: The input modes on the ket side of this component. """ def __init__( self, name: Optional[str] = None, representation: Optional[Bargmann | Fock] = None, modes_out_bra: Optional[Sequence[int]] = None, modes_in_bra: Optional[Sequence[int]] = None, modes_out_ket: Optional[Sequence[int]] = None, modes_in_ket: Optional[Sequence[int]] = None, ) -> None: modes_out_bra = modes_out_bra or () modes_in_bra = modes_in_bra or () modes_out_ket = modes_out_ket or () modes_in_ket = modes_in_ket or () self._wires = Wires( set(modes_out_bra), set(modes_in_bra), set(modes_out_ket), set(modes_in_ket) ) self._name = name or "CC" + "".join(str(m) for m in sorted(self.wires.modes)) self._parameter_set = ParameterSet() self._representation = representation # handle out-of-order modes ob = tuple(sorted(modes_out_bra)) ib = tuple(sorted(modes_in_bra)) ok = tuple(sorted(modes_out_ket)) ik = tuple(sorted(modes_in_ket)) if ob != modes_out_bra or ib != modes_in_bra or ok != modes_out_ket or ik != modes_in_ket: offsets = [len(ob), len(ob) + len(ib), len(ob) + len(ib) + len(ok)] perm = ( tuple(np.argsort(modes_out_bra)) + tuple(np.argsort(modes_in_bra) + offsets[0]) + tuple(np.argsort(modes_out_ket) + offsets[1]) + tuple(np.argsort(modes_in_ket) + offsets[2]) ) if self._representation: self._representation = self._representation.reorder(tuple(perm)) @classmethod def _from_attributes( cls, name: str, representation: Representation, wires: Wires ) -> CircuitComponent: r""" Initializes a circuit component from its attributes (a name, a ``Wires``, and a ``Representation``). If the Method Resolution Order (MRO) of ``cls`` contains one between ``Ket``, ``DM``, ``Unitary``, and ``Channel``, then the returned component is of that type. Otherwise, it is of type ``CircuitComponent``. This function needs to be used with caution, as it does not check that the attributes provided are consistent with the type of the returned component. If used improperly it may initialize, e.g., ``Ket``s with both input and output wires or ``Unitary``s with wires on the bra side. Args: name: The name of this component. representation: A representation for this circuit component. wires: The wires of this component. Returns: A circuit component of type ``cls`` with the given attributes. """ types = {"Ket", "DM", "Unitary", "Channel"} for tp in cls.mro(): if tp.__name__ in types: ret = tp() break else: ret = CircuitComponent() ret._name = name ret._representation = representation ret._wires = wires return ret def _add_parameter(self, parameter: Union[Constant, Variable]): r""" Adds a parameter to this circuit component. Args: parameter: The parameter to add. Raises: ValueError: If the length of the given parameter is incompatible with the number of modes. """ if parameter.value.shape != (): if len(parameter.value) != 1 and len(parameter.value) != len(self.modes): msg = f"Length of ``{parameter.name}`` must be 1 or {len(self.modes)}." raise ValueError(msg) self.parameter_set.add_parameter(parameter) self.__dict__[parameter.name] = parameter @property def bargmann(self) -> tuple: r""" The Bargmann parametrization of this circuit component, if available. """ if not isinstance(self.representation, Bargmann): raise ValueError( f"Cannot compute triple from representation of type ``{self.representation.__class__.__qualname__}``." ) return self.representation.triple @property def representation(self) -> Representation | None: r""" A representation of this circuit component. """ return self._representation @property def modes(self) -> list[int]: r""" The sorted list of modes of this component. """ return sorted(self.wires.modes) @property def n_modes(self) -> list[int]: r""" The number of modes in this component. """ return len(self.modes) @property def name(self) -> str: r""" The name of this component. """ return self._name @property def parameter_set(self) -> ParameterSet: r""" The set of parameters characterizing this component. """ return self._parameter_set @property def wires(self) -> Wires: r""" The wires of this component. """ return self._wires @property def adjoint(self) -> AdjointView: r""" The ``AdjointView`` of this component. """ return AdjointView(self) @property def dual(self) -> DualView: r""" The ``DualView`` of this component. """ return DualView(self)
[docs] def light_copy(self) -> CircuitComponent: r""" Creates a copy of this component by copying every data stored in memory for it by reference, except for its wires, which are copied by value. """ instance = super().__new__(self.__class__) instance.__dict__ = self.__dict__.copy() instance.__dict__["_wires"] = Wires(*self.wires.args) return instance
[docs] def on(self, modes: Sequence[int]) -> CircuitComponent: r""" Creates a copy of this component that acts on the given ``modes`` instead of on the original modes. Args: modes: The new modes that this component acts on. Returns: The component acting on the specified modes. Raises: ValueError: If ``modes`` contains more or less modes than the original component. """ modes = set(modes) ob = self.wires.output.bra ib = self.wires.input.bra ok = self.wires.output.ket ik = self.wires.input.ket for subset in [ob, ib, ok, ik]: if subset and len(subset.modes) != len(modes): msg = f"Expected ``{len(modes)}`` modes, found ``{len(subset.modes)}``." raise ValueError(msg) wires = Wires( modes_out_bra=modes if ob else set(), modes_in_bra=modes if ib else set(), modes_out_ket=modes if ok else set(), modes_in_ket=modes if ik else set(), ) ret = self.light_copy() ret._wires = wires return ret
[docs] def to_fock_component( self, shape: Optional[Union[int, Iterable[int]]] = None ) -> CircuitComponent: r""" Returns a circuit component with the same attributes as this component, but with ``Fock`` representation. Uses the :meth:`mrmustard.physics.converters.to_fock` method to convert the internal representation. .. code-block:: >>> from mrmustard.physics.converters import to_fock >>> from mrmustard.lab_dev import Dgate >>> d = Dgate([1], x=0.1, y=0.1) >>> d_fock = d.to_fock_component(shape=3) >>> assert d_fock.name == d.name >>> assert d_fock.wires == d.wires >>> assert d_fock.representation == to_fock(d.representation, shape=3) Args: shape: The shape of the returned representation. If ``shape``is given as an ``int``, it is broadcasted to all the dimensions. If ``None``, it defaults to the value of ``AUTOCUTOFF_MAX_CUTOFF`` in the settings. """ return self.__class__._from_attributes( self.name, to_fock(self.representation, shape=shape), self.wires, )
def __add__(self, other: CircuitComponent) -> CircuitComponent: r""" Implements the addition between circuit components. """ if self.wires != other.wires: msg = "Cannot add components with different wires." raise ValueError(msg) rep = self.representation + other.representation name = self.name if self.name == other.name else "" return self._from_attributes(name, rep, self.wires) def __sub__(self, other: CircuitComponent) -> CircuitComponent: r""" Implements the subtraction between circuit components. """ if self.wires != other.wires: msg = "Cannot subtract components with different wires." raise ValueError(msg) rep = self.representation - other.representation name = self.name if self.name == other.name else "" return self._from_attributes(name, rep, self.wires) def __mul__(self, other: Scalar) -> CircuitComponent: r""" Implements the multiplication by a scalar on the right. """ return self._from_attributes(self.name, self.representation * other, self.wires) def __rmul__(self, other: Scalar) -> CircuitComponent: r""" Implements the multiplication by a scalar on the left. """ return self.__mul__(other) def __truediv__(self, other: Scalar) -> CircuitComponent: r""" Implements the division by a scalar for circuit components. """ return self._from_attributes(self.name, self.representation / other, self.wires) def __eq__(self, other) -> bool: r""" Whether this component is equal to another component. Compares representations and wires, but not the other attributes (including name and parameter set). """ return self.representation == other.representation and self.wires == other.wires def _matmul_indices(self, other: CircuitComponent) -> tuple[tuple[int, ...], tuple[int, ...]]: r""" Finds the indices of the wires being contracted on the bra and ket sides of the components. """ # find the indices of the wires being contracted on the bra side bra_modes = tuple(self.wires.bra.output.modes & other.wires.bra.input.modes) idx_z = self.wires.bra.output[bra_modes].indices idx_zconj = other.wires.bra.input[bra_modes].indices # find the indices of the wires being contracted on the ket side ket_modes = tuple(self.wires.ket.output.modes & other.wires.ket.input.modes) idx_z += self.wires.ket.output[ket_modes].indices idx_zconj += other.wires.ket.input[ket_modes].indices return idx_z, idx_zconj def __matmul__(self, other: CircuitComponent) -> CircuitComponent: r""" Contracts ``self`` and ``other``, without adding adjoints. """ wires_ret, perm = self.wires @ other.wires idx_z, idx_zconj = self._matmul_indices(other) representation_ret = self.representation[idx_z] @ other.representation[idx_zconj] representation_ret = representation_ret.reorder(perm) if perm else representation_ret return CircuitComponent._from_attributes(None, representation_ret, wires_ret) def __rshift__(self, other: CircuitComponent) -> CircuitComponent: r""" Contracts ``self`` and ``other`` as it would in a circuit, adding the adjoints when they are missing. """ msg = f"``>>`` not supported between {self} and {other}, use ``@``." wires_s = self.wires wires_o = other.wires if wires_s.ket and wires_s.bra: if wires_o.ket and wires_o.bra: return self @ other return self @ other @ other.adjoint if wires_s.ket: if wires_o.ket and wires_o.bra: return self @ self.adjoint @ other if wires_o.ket: return self @ other raise ValueError(msg) if wires_s.bra: if wires_o.ket and wires_o.bra: return self @ self.adjoint @ other if wires_o.bra: return self @ other raise ValueError(msg) raise ValueError(msg) def __repr__(self) -> str: return f"CircuitComponent(name={self.name or None}, modes={self.modes})" def _repr_html_(self): # pragma: no cover temp = Template(filename=os.path.dirname(__file__) + "/assets/circuit_components.txt") wires_temp = Template(filename=os.path.dirname(__file__) + "/assets/wires.txt") wires_temp_uni = wires_temp.render_unicode(wires=self.wires) wires_temp_uni = ( wires_temp_uni.replace("<body>", "").replace("</body>", "").replace("h1", "h3") ) rep_temp = ( Template(filename=os.path.dirname(__file__) + "/../physics/assets/fock.txt") if isinstance(self.representation, Fock) else Template(filename=os.path.dirname(__file__) + "/../physics/assets/bargmann.txt") ) rep_temp_uni = rep_temp.render_unicode(rep=self.representation) rep_temp_uni = rep_temp_uni.replace("<body>", "").replace("</body>", "").replace("h1", "h3") display(HTML(temp.render(comp=self, wires=wires_temp_uni, rep=rep_temp_uni)))
class CCView(CircuitComponent): r"""A base class for views of circuit components. Args: component: The circuit component to take the view of. """ def __init__(self, component: CircuitComponent) -> None: self.__dict__ = component.__dict__.copy() self._component = component.light_copy() def __getattr__(self, name): r"""send calls to the component""" return getattr(self._component, name) def __repr__(self) -> str: return repr(self._component)
[docs] class AdjointView(CCView): r""" Adjoint view of a circuit component. Args: component: The circuit component to take the view of. """ @property def adjoint(self) -> CircuitComponent: r""" Returns a light-copy of the component that was used to generate the view. """ return self._component.light_copy() @property def representation(self): r""" A representation of this circuit component. """ bras = self._component.wires.bra.indices kets = self._component.wires.ket.indices return self._component.representation.reorder(kets + bras).conj() @property def wires(self): r""" The ``Wires`` in this component. """ return self._component.wires.adjoint
[docs] class DualView(CCView): r""" Dual view of a circuit component. Args: component: The circuit component to take the view of. """ @property def dual(self) -> CircuitComponent: r""" Returns a light-copy of the component that was used to generate the view. """ return self._component.light_copy() @property def representation(self): r""" A representation of this circuit component. """ ok = self._component.wires.ket.output.indices ik = self._component.wires.ket.input.indices ib = self._component.wires.bra.input.indices ob = self._component.wires.bra.output.indices return self._component.representation.reorder(ib + ob + ik + ok).conj() @property def wires(self): r""" The ``Wires`` in this component. """ return self._component.wires.dual