Source code for mrmustard.lab_dev.wires
# Copyright 2023 Xanadu Quantum Technologies Inc.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" ``Wires`` class for supporting tensor network functionalities."""
from __future__ import annotations
from functools import cached_property
from typing import Optional
import os
import numpy as np
from IPython.display import display, HTML
from mako.template import Template
# Public names exported by this module (what ``from ... import *`` brings into scope).
__all__ = ["Wires"]
[docs]
class Wires:
    r"""
    A class with wire functionality for tensor network applications.

    In MrMustard, instances of ``CircuitComponent`` have a ``Wires`` attribute.
    The wires describe how they connect with the surrounding components in a tensor network picture,
    where states flow from left to right. ``CircuitComponent``\s can have wires on the
    bra and/or on the ket side. Here are some examples for the types of components available on
    ``mrmustard.lab_dev``.

    A channel acting on mode ``1`` has input and output wires on both ket and bra sides:

    .. code-block::

        ┌──────┐   1  ╔═════════╗   1  ┌───────┐
        │Bra in│─────▶║         ║─────▶│Bra out│
        └──────┘      ║ Channel ║      └───────┘
        ┌──────┐   1  ║         ║   1  ┌───────┐
        │Ket in│─────▶║         ║─────▶│Ket out│
        └──────┘      ╚═════════╝      └───────┘

    A unitary acting on mode ``2`` has input and output wires on the ket side:

    .. code-block::

        ┌──────┐   2  ╔═════════╗   2  ┌───────┐
        │Ket in│─────▶║ Unitary ║─────▶│Ket out│
        └──────┘      ╚═════════╝      └───────┘

    A density matrix representing the state of mode ``0`` has only output wires:

    .. code-block::

        ╔═════════╗   0  ┌───────┐
        ║         ║─────▶│Bra out│
        ║ Density ║      └───────┘
        ║ Matrix  ║   0  ┌───────┐
        ║         ║─────▶│Ket out│
        ╚═════════╝      └───────┘

    Also a ket representing the state of mode ``1`` has only output wires:

    .. code-block::

        ╔═════════╗   1  ┌───────┐
        ║   Ket   ║─────▶│Ket out│
        ╚═════════╝      └───────┘

    The ``Wires`` class can then be used to create subsets of wires:

    .. code-block::

        >>> from mrmustard.lab_dev.wires import Wires

        >>> modes_out_bra={0, 1}
        >>> modes_in_bra={1, 2}
        >>> modes_out_ket={0, 13}
        >>> modes_in_ket={1, 2, 13}
        >>> w = Wires(modes_out_bra, modes_in_bra, modes_out_ket, modes_in_ket)

        >>> # all the modes
        >>> modes = w.modes
        >>> assert w.modes == {0, 1, 2, 13}

        >>> # input/output modes
        >>> assert w.input.modes == {1, 2, 13}
        >>> assert w.output.modes == {0, 1, 13}

        >>> # get ket/bra modes
        >>> assert w.ket.modes == {0, 1, 2, 13}
        >>> assert w.bra.modes == {0, 1, 2}

        >>> # combined subsets
        >>> assert w.output.ket.modes == {0, 13}
        >>> assert w.input.bra.modes == {1, 2}

    Here's a diagram of the original ``Wires`` object in the example above,
    with the indices of the wires (the number in parentheses) given in the "standard" order
    (``bra_out``, ``bra_in``, ``ket_out``, ``ket_in``, and the modes in sorted increasing order):

    .. code-block::

                         ╔═════════════╗
           1 (2) ─────▶  ║             ║  ─────▶ 0 (0)
           2 (3) ─────▶  ║             ║  ─────▶ 1 (1)
                         ║             ║
                         ║  ``Wires``  ║
           1 (6) ─────▶  ║             ║
           2 (7) ─────▶  ║             ║  ─────▶ 0 (4)
          13 (8) ─────▶  ║             ║  ─────▶ 13 (5)
                         ╚═════════════╝

    To access the index of a subset of wires in standard order we can use the ``indices``
    property:

    .. code-block::

        >>> assert w.indices == (0,1,2,3,4,5,6,7,8)
        >>> assert w.input.indices == (2,3,6,7,8)

    Another important application of the ``Wires`` class is to contract the wires of two components.
    This is done using the ``@`` operator. The result is a new ``Wires`` object that combines the wires
    of the two components. Here's an example of a contraction of a single-mode density matrix going
    into a single-mode channel:

    .. code-block::

        >>> rho = Wires(modes_out_bra={0}, modes_out_ket={0})
        >>> Phi = Wires(modes_out_bra={0}, modes_in_bra={0}, modes_out_ket={0}, modes_in_ket={0})
        >>> rho_out, perm = rho @ Phi
        >>> assert rho_out.modes == {0}

    Here's a diagram of the result of the contraction:

    .. code-block::

        ╔═══════╗      ╔═══════╗
        ║       ║─────▶║       ║─────▶ 0
        ║  rho  ║      ║  Phi  ║
        ║       ║─────▶║       ║─────▶ 0
        ╚═══════╝      ╚═══════╝

    The permutation that takes the contracted representations to the standard order is also returned.

    Args:
        modes_out_bra: The output modes on the bra side.
        modes_in_bra: The input modes on the bra side.
        modes_out_ket: The output modes on the ket side.
        modes_in_ket: The input modes on the ket side.

    Each argument may be ``None`` (no wires) or any iterable of ints. It is copied into a new
    ``set``, so later changes to the caller's object do not affect the wires.
    """

    def __init__(
        self,
        modes_out_bra: Optional[set[int]] = None,
        modes_in_bra: Optional[set[int]] = None,
        modes_out_ket: Optional[set[int]] = None,
        modes_in_ket: Optional[set[int]] = None,
    ) -> None:
        # Defensive copies: the wires must not alias a set owned by the caller (or by a parent
        # ``Wires`` when a subset is built), otherwise mutating that set would silently
        # change these wires and invalidate the cached properties below.
        self.args: tuple[set, ...] = tuple(
            set() if modes is None else set(modes)
            for modes in (modes_out_bra, modes_in_bra, modes_out_ket, modes_in_ket)
        )
        # The "parent" wires object, if any. This is ``None`` for freshly initialized
        # wires objects, and ``not None`` for subsets.
        self._original: Optional[Wires] = None
        # Cache for ``__getitem__``, keyed by the frozenset of requested modes.
        self._mode_cache: dict[frozenset[int], Wires] = {}

    @cached_property
    def id(self) -> int:
        r"""
        A numerical identifier for this ``Wires`` object.

        The ``id`` is drawn at random (so it is unique with high probability) and is
        preserved when taking subsets.
        """
        if self.original is not None:
            return self.original.id
        # An explicit 64-bit dtype keeps the exclusive upper bound ``2**32`` valid on platforms
        # where NumPy's default integer is 32 bits wide.
        return int(np.random.randint(0, 2**32, dtype=np.int64))

    @cached_property
    def ids(self) -> list[int]:
        r"""
        A list of numerical identifiers for the wires in this ``Wires`` object, in
        the standard order.

        For a freshly initialized object, the ``ids`` are derived incrementally from the ``id``.

        .. code-block::

            >>> w = Wires(modes_in_ket = {0,1}, modes_out_ket = {0,1})
            >>> id = w.id
            >>> ids = w.ids
            >>> assert ids == [id, id+1, id+2, id+3]
        """
        if self.original is not None:
            return [self.original.ids[i] for i in self.indices]
        return [wire_id for ids_dict in self.ids_dicts for wire_id in ids_dict.values()]

    @cached_property
    def indices(self) -> tuple[int, ...]:
        r"""
        The array of indices of this ``Wires`` in the standard order.

        When a subset is selected (e.g. ``.ket``), it doesn't include wires that do not belong
        to the subset, but it still counts them because indices refer to the original modes.

        .. code-block::

            >>> w = Wires(modes_in_ket = {0,1}, modes_out_ket = {0,1})
            >>> assert w.indices == (0,1,2,3)
            >>> assert w.input.indices == (2,3)
        """
        return tuple(
            self.index_dicts[t][m] for t, modes in enumerate(self.sorted_args) for m in modes
        )

    @cached_property
    def index_dicts(self) -> list[dict[int, int]]:
        r"""
        A list of dictionaries mapping modes to indices, one for each of the subsets
        (``output.bra``, ``input.bra``, ``output.ket``, and ``input.ket``).

        If subsets are taken, ``index_dicts`` refers to the parent object rather than to the
        child.
        """
        if self.original is not None:
            return self.original.index_dicts
        dicts = []
        offset = 0  # number of wires of all the preceding types
        for modes in self.sorted_args:
            dicts.append({m: offset + i for i, m in enumerate(modes)})
            offset += len(modes)
        return dicts

    @cached_property
    def ids_dicts(self) -> list[dict[int, int]]:
        r"""
        A list of dictionaries mapping modes to ``ids``, one for each of the subsets
        (``output.bra``, ``input.bra``, ``output.ket``, and ``input.ket``).

        If subsets are taken, ``ids_dicts`` refers to the parent object rather than to the
        child.
        """
        if self.original is not None:
            return self.original.ids_dicts
        return [{m: i + self.id for m, i in d.items()} for d in self.index_dicts]

    @cached_property
    def sorted_args(self) -> tuple[list[int], ...]:
        r"The sorted arguments. Allows to sort them only once."
        return tuple(sorted(s) for s in self.args)

    @property
    def original(self):
        r"""
        The parent wire, if any.
        """
        return self._original

    @cached_property
    def modes(self) -> set[int]:
        r"The modes spanned by the wires."
        return set.union(*self.args)

    def _subset(self, *args: set[int]) -> Wires:
        r"Builds a ``Wires`` from ``args`` whose parent is the root ``Wires`` of ``self``."
        ret = Wires(*args)
        # Always point at the root object (never at an intermediate subset), so that
        # ``index_dicts``/``ids_dicts`` of every subset refer to the same parent.
        root = self if self.original is None else self.original
        ret._original = root  # pylint: disable=protected-access
        return ret

    @cached_property
    def input(self) -> Wires:
        r"New ``Wires`` object without output wires."
        return self._subset(set(), self.args[1], set(), self.args[3])

    @cached_property
    def output(self) -> Wires:
        r"New ``Wires`` object without input wires."
        return self._subset(self.args[0], set(), self.args[2], set())

    @cached_property
    def ket(self) -> Wires:
        r"New ``Wires`` object without bra wires."
        return self._subset(set(), set(), self.args[2], self.args[3])

    @cached_property
    def bra(self) -> Wires:
        r"New ``Wires`` object without ket wires."
        return self._subset(self.args[0], self.args[1], set(), set())

    @cached_property
    def adjoint(self) -> Wires:
        r"New ``Wires`` object obtained by swapping ket and bra wires."
        return Wires(self.args[2], self.args[3], self.args[0], self.args[1])

    @cached_property
    def dual(self) -> Wires:
        r"New ``Wires`` object obtained by swapping input and output wires."
        return Wires(self.args[1], self.args[0], self.args[3], self.args[2])

    def __getitem__(self, modes: tuple[int, ...] | int) -> Wires:
        r"""
        New ``Wires`` object with wires only on the given modes.

        ``modes`` may be a single int or any iterable of ints. Results are cached per *set*
        of modes, so ``w[0, 1]`` and ``w[1, 0]`` return the same object.
        """
        key = frozenset((modes,) if isinstance(modes, int) else modes)
        if key not in self._mode_cache:
            self._mode_cache[key] = self._subset(*(arg & key for arg in self.args))
        return self._mode_cache[key]

    def __add__(self, other: Wires) -> Wires:
        r"""
        New ``Wires`` object that combines the wires of ``self`` and those of ``other``.

        Raises:
            ValueError: If any wires of the same type (e.g. two output bra wires) share a mode.
        """
        new_args = []
        for t, (m1, m2) in enumerate(zip(self.args, other.args)):
            if m := m1 & m2:
                raise ValueError(f"{t}-type wires overlap at mode {m}")
            new_args.append(m1 | m2)
        return Wires(*new_args)

    def __bool__(self) -> bool:
        r"Returns ``True`` if this ``Wires`` object has any wires, ``False`` otherwise."
        return any(self.args)

    def __eq__(self, other) -> bool:
        # Defer to the other operand for foreign types instead of crashing on ``other.args``.
        if not isinstance(other, Wires):
            return NotImplemented
        return self.args == other.args

    def __matmul__(self, other: Wires) -> tuple[Wires, list[int]]:
        r"""
        Returns the wires of the circuit composition of self and other without adding missing
        adjoints. It also returns the permutation that takes the contracted representations
        to the standard order. An exception is raised if any leftover wires would overlap.

        Consider the following example (capital letters are bra wires, lowercase are ket wires):

        .. code-block::

                ╔═══════╗          ╔═══════╗
            B───║ self  ║───A  D───║ other ║───C
            b───║       ║───a  d───║       ║───c
                ╚═══════╝          ╚═══════╝

        B and D-A must not overlap, same for b and d-a, etc. The result is a new ``Wires`` object

        .. code-block::

                       ╔═══════╗
            B|(D-A)────║self @ ║────C|(A-D)
            b|(d-a)────║ other ║────c|(a-d)
                       ╚═══════╝

        In comparison, contracting the representations rather than the wires corresponds to
        an order where we start from juxtaposing the objects and then removing pairs of contracted
        indices, i.e. the leftover wires of ``self`` (A-D, B, a-d, b) followed by the leftover
        wires of ``other`` (C, D-A, c, d-a), each group in standard order. The returned
        permutation is the one that takes the result of multiplying representations to the
        standard order.

        Args:
            other: The wires of the other circuit component.

        Returns:
            The wires of the circuit composition and the permutation.

        Raises:
            ValueError: If either operand is a subset of wires, or if any leftover wires
                would overlap.
        """
        if self.original is not None or other.original is not None:
            raise ValueError("cannot contract a subset of wires")
        A, B, a, b = self.args
        C, D, c, d = other.args
        # sets[0:4] are the leftover wires of ``self``, sets[4:8] those of ``other``.
        sets = (A - D, B, a - d, b, C, D - A, c, d - a)
        if m := sets[0] & sets[4]:
            raise ValueError(f"output bra modes {m} overlap")
        if m := sets[1] & sets[5]:
            raise ValueError(f"input bra modes {m} overlap")
        if m := sets[2] & sets[6]:
            raise ValueError(f"output ket modes {m} overlap")
        if m := sets[3] & sets[7]:
            raise ValueError(f"input ket modes {m} overlap")
        bra_out = sets[0] | sets[4]  # (self.output.bra - other.input.bra) | other.output.bra
        bra_in = sets[1] | sets[5]  # self.input.bra | (other.input.bra - self.output.bra)
        ket_out = sets[2] | sets[6]  # (self.output.ket - other.input.ket) | other.output.ket
        ket_in = sets[3] | sets[7]  # self.input.ket | (other.input.ket - self.output.ket)
        w = Wires(bra_out, bra_in, ket_out, ket_in)

        # preserve ids: every surviving wire keeps the id it had in the component it came from
        for t in (0, 1, 2, 3):
            for m in w.args[t]:
                w.ids_dicts[t][m] = self.ids_dicts[t][m] if m in sets[t] else other.ids_dicts[t][m]

        # calculate permutation: position of each surviving wire in the juxtaposed
        # representation (self's leftovers first, then other's, each in standard order).
        # Keying by (side, type, mode) is deterministic and cannot suffer id collisions.
        position: dict[tuple[int, int, int], int] = {}
        for side in (0, 1):
            for t in (0, 1, 2, 3):
                for m in sorted(sets[t + 4 * side]):
                    position[(side, t, m)] = len(position)
        perm = [
            position[(0 if m in sets[t] else 1, t, m)]
            for t in (0, 1, 2, 3)
            for m in w.sorted_args[t]
        ]
        return w, perm

    def __repr__(self) -> str:
        return f"Wires{self.args}"

    def _repr_html_(self):  # pragma: no cover
        template = Template(filename=os.path.dirname(__file__) + "/assets/wires.txt")
        display(HTML(template.render(wires=self)))
_modules/mrmustard/lab_dev/wires
Download Python script
Download Notebook
View on GitHub