# Source code for mrmustard.training.progress_bar
# Copyright 2023 Xanadu Quantum Technologies Inc.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A module containing classes and methods for progress bars."""
from rich.progress import BarColumn, Progress, TextColumn, TimeElapsedColumn, TimeRemainingColumn
from mrmustard import settings
__all__ = ["ProgressBar"]
# [docs]
class ProgressBar:
    r"""A spiffy loading bar to display the progress during an optimization.

    The class wraps a ``rich.progress.Progress`` instance (``self.bar``) holding a single
    task (``self.taskID``) and can be used as a context manager.

    Args:
        max_steps: total number of steps to display. ``0`` selects the open-ended layout
            (``Step n/∞`` with elapsed time only); any other value selects the layout with
            speed, percentage, remaining time and elapsed time.
    """

    def __init__(self, max_steps: int):
        if max_steps == 0:
            # Open-ended run: no total is known, so show only the step count,
            # the cost and the elapsed time.
            columns = [
                TextColumn("Step {task.completed}/∞"),
                BarColumn(),
                TextColumn("Cost = {task.fields[loss]:.5f}"),
                TextColumn(" | ⏱️ "),
                TimeElapsedColumn(),
            ]
        else:
            columns = [
                TextColumn("Step {task.completed}/{task.total} | {task.fields[speed]:.1f} it/s"),
                BarColumn(),
                TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
                TextColumn("Cost = {task.fields[loss]:.5f} | ⏳ "),
                TimeRemainingColumn(),
                TextColumn(" | ⏱️ "),
                TimeElapsedColumn(),
            ]
        self.bar = Progress(*columns)

        # ``speed`` and ``loss`` are custom task fields read by the text columns above.
        # ``max_steps or None`` turns 0 into ``None`` so the task has no fixed total.
        # No ``refresh`` keyword is passed: ``add_task`` has no such parameter and would
        # only store it as an unused custom field.
        self.taskID = self.bar.add_task(
            description="Optimizing...",
            start=True,
            speed=0.0,
            total=max_steps or None,
            loss=1.0,
            visible=settings.PROGRESSBAR,
        )

    def step(self, loss):
        r"""Update bar step and the loss information associated with it.

        Args:
            loss: the current cost value; it is rendered with the ``:.5f`` format spec,
                so it must support float formatting.
        """
        # The task's speed can be falsy (e.g. ``None``); fall back to 0.0 so that the
        # ``:.1f`` format of the speed column always receives a number.
        speed = self.bar.tasks[0].speed or 0.0
        self.bar.update(self.taskID, advance=1, refresh=True, speed=speed, loss=loss)

    def __enter__(self):
        # Delegates to the wrapped bar and returns whatever its ``__enter__`` returns
        # (not this ``ProgressBar`` instance).
        return self.bar.__enter__()

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Delegates to the wrapped bar, forwarding the exception details unchanged.
        return self.bar.__exit__(exc_type, exc_val, exc_tb)
# _modules/mrmustard/training/progress_bar
# Download Python script
# Download Notebook
# View on GitHub