Source code for mrmustard.parameters.parameter_dict
# Copyright 2023 Xanadu Quantum Technologies Inc.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module contains the classes to describe a dictionary of parameters."""
from __future__ import annotations
import io
from collections import UserDict
from copy import copy as _copy
from copy import deepcopy as _deepcopy
from typing import Any
import numpy as np
from rich.console import Console
from rich.table import Table
from mrmustard.math.backend_manager import BackendManager
from .parameters import Constant, Variable, format_dtype, format_value
# Names exported by ``from ... import *``.
__all__ = ["ParameterDict"]

# Backend-dispatching math namespace (used below by ``ParameterDict.to_string``).
math = BackendManager()
[docs]
class _ParameterLookupError(AttributeError, KeyError):
    r"""Raised when a name is neither a regular attribute nor a stored parameter.

    It subclasses ``AttributeError`` so that ``hasattr``, ``getattr(obj, name, default)``,
    ``copy`` and ``pickle`` behave as they do for normal objects. It also subclasses
    ``KeyError`` so that code that caught the ``KeyError`` previously raised by attribute
    access keeps working.
    """


class ParameterDict(UserDict):
    r"""A dictionary-like class for storing parameters.

    Stored parameters can also be read as attributes, e.g. ``pd.const1``.

    >>> c1 = Constant(1.2345, "const1")
    >>> c2 = Constant(2.3456, "const2")
    >>> v1 = Variable(3.4567, "var1")
    >>> pd = ParameterDict(c1, c2, some_name=v1)
    >>> assert pd.names == ['const1', 'const2', 'some_name']
    >>> assert pd.constants == {"const1": c1, "const2": c2}
    >>> assert pd.variables == {"some_name": v1}

    Args:
        *args: Constant or Variable parameters, stored under their own ``name``.
        **kwargs: Constant or Variable parameters by name.

    Raises:
        ValueError: if a positional argument is not a Constant or Variable.
    """

    def __init__(self, *args: Constant | Variable, **kwargs: Constant | Variable) -> None:
        super().__init__()
        for arg in args:
            # Validate before touching ``arg.name``. Otherwise a bad argument raises an
            # AttributeError instead of the documented ValueError, or gets stored first.
            if not isinstance(arg, (Constant, Variable)):
                raise ValueError(f"Argument {arg} is not a Constant or Variable")
            self.data[arg.name] = arg
        # NOTE(review): keyword values are stored without a type check, unlike the
        # positional ones; left unchanged for backward compatibility.
        self.data.update(kwargs)

    def __deepcopy__(self, memo: dict[int, Any]) -> ParameterDict:
        r"""Returns a copy whose stored parameters are deep-copied as well."""
        new = type(self)()
        # Register early so that parameters referring back to this dict don't recurse forever.
        memo[id(self)] = new
        new.data = _deepcopy(self.data, memo)
        return new

    def copy(self) -> ParameterDict:
        r"""Returns a shallow copy (via ``copy.copy``) holding the same parameter objects."""
        return _copy(self)

    def __setitem__(self, key: str, value: Constant | Variable) -> None:
        self.data[key] = value

    def __getattr__(self, key: str) -> Constant | Variable:
        r"""Returns the parameter stored under ``key``.

        Python calls this only after normal attribute lookup has failed.

        Raises:
            AttributeError: if no parameter is stored under ``key``. The raised error is
                also a ``KeyError``, for backward compatibility.
        """
        try:
            # Read ``data`` from the instance dict directly. Going through ``self.data``
            # would re-enter ``__getattr__`` and recurse forever on instances created
            # without ``__init__`` (e.g. while unpickling).
            return self.__dict__["data"][key]
        except KeyError:
            raise _ParameterLookupError(
                f"{type(self).__name__!r} object has no attribute or parameter {key!r}"
            ) from None

    def __hash__(self) -> int:
        r"""Returns a hash consistent with mapping equality.

        ``UserDict`` equality, like ``dict`` equality, ignores insertion order, so the hash
        must not depend on it either. The stored parameters must be hashable.
        """
        return hash(frozenset(self.data.items()))

    def __tuple__(self) -> tuple[Constant | Variable, ...]:
        r"""Returns the stored parameters as a tuple, in insertion order.

        ``__tuple__`` is not a Python protocol hook: ``tuple(pd)`` yields the names, so
        call this method explicitly to get the parameters.
        """
        return tuple(self.data.values())

    @property
    def names(self) -> list[str]:
        r"""Returns the parameter names, in insertion order."""
        return list(self.data)

    @property
    def constants(self) -> ParameterDict:
        r"""Returns a ParameterDict of constant parameters in this ParameterDict."""
        return ParameterDict(
            **{name: param for name, param in self.data.items() if isinstance(param, Constant)}
        )

    @property
    def variables(self) -> ParameterDict:
        r"""Returns a ParameterDict of variable parameters in this ParameterDict."""
        return ParameterDict(
            **{name: param for name, param in self.data.items() if isinstance(param, Variable)}
        )

    def to_string(self, decimals: int) -> str:
        r"""Returns a string representation of the parameter values, separated by commas and rounded
        to the specified number of decimals.

        Scalar (0-dimensional) values are shown rounded; array-valued parameters are shown
        by their name only.

        Args:
            decimals (int): number of decimals to round to

        Returns:
            str: string representation of the parameter values
        """
        strings = []
        for name, param in self.data.items():
            if len(param.value.shape) == 0:  # scalar: show the rounded value
                strings.append(str(np.round(math.asnumpy(param.value), decimals)))
            else:  # don't show arrays, only their name
                strings.append(f"{name}")
        return ", ".join(strings)

    def __repr__(self) -> str:
        r"""Returns a rich-formatted string representation of this parameter set."""
        if not self:
            return "ParameterDict()"

        table = Table(title=f"ParameterDict ({len(self)} parameters)", show_header=True)
        table.add_column("Name", style="#FFB3B3", header_style="#FFB3B3", no_wrap=True)
        table.add_column("Type", style="#FFCC99", header_style="#FFCC99", width=9)
        table.add_column("Value", style="#FFFFBA", header_style="#FFFFBA")
        table.add_column("Dtype", style="#BAFFC9", header_style="#BAFFC9", width=10)
        table.add_column("Shape", style="#E1BAFF", header_style="#E1BAFF")

        for name, param in self.data.items():
            param_type = "Constant" if isinstance(param, Constant) else "Variable"
            value_str, shape_str = format_value(param)
            dtype_str = format_dtype(param)
            table.add_row(name, param_type, value_str, dtype_str, shape_str)

        # Render the table into an in-memory buffer instead of the terminal.
        with io.StringIO() as string_buffer:
            console = Console(file=string_buffer, width=100, legacy_windows=False)
            console.print(table)
            return string_buffer.getvalue().strip()
_modules/mrmustard/parameters/parameter_dict
Download Python script
Download Notebook
View on GitHub