diff --git a/source/pip/benchmarks/bench_qre.py b/source/pip/benchmarks/bench_qre.py index e236594921..536669a8aa 100644 --- a/source/pip/benchmarks/bench_qre.py +++ b/source/pip/benchmarks/bench_qre.py @@ -5,7 +5,7 @@ from dataclasses import dataclass, KW_ONLY, field from qsharp.qre import linear_function, generic_function from qsharp.qre._architecture import _make_instruction -from qsharp.qre.models import AQREGateBased, SurfaceCode +from qsharp.qre.models import GateBased, SurfaceCode from qsharp.qre._enumeration import _enumerate_instances @@ -37,10 +37,10 @@ def bench_enumerate_isas(): # Add the tests directory to sys.path to import test_qre # TODO: Remove this once the models in test_qre are moved to a proper module - sys.path.append(os.path.join(os.path.dirname(__file__), "../tests")) - from test_qre import ExampleLogicalFactory, ExampleFactory # type: ignore + sys.path.append(os.path.join(os.path.dirname(__file__), "../tests/qre/")) + from conftest import ExampleLogicalFactory, ExampleFactory # type: ignore - ctx = AQREGateBased(gate_time=50, measurement_time=100).context() + ctx = GateBased(gate_time=50, measurement_time=100).context() # Hierarchical factory using from_components query = SurfaceCode.q() * ExampleLogicalFactory.q( diff --git a/source/pip/qsharp/qre/__init__.py b/source/pip/qsharp/qre/__init__.py index a17bc2122c..6ba945acf1 100644 --- a/source/pip/qsharp/qre/__init__.py +++ b/source/pip/qsharp/qre/__init__.py @@ -3,13 +3,7 @@ from ._application import Application from ._architecture import Architecture -from ._estimation import ( - estimate, - EstimationTable, - EstimationTableColumn, - EstimationTableEntry, - plot_estimates, -) +from ._estimation import estimate from ._instruction import ( LOGICAL, PHYSICAL, @@ -37,6 +31,12 @@ property_name, property_name_to_key, ) +from ._results import ( + EstimationTable, + EstimationTableColumn, + EstimationTableEntry, + plot_estimates, +) from ._trace import LatticeSurgery, PSSPC, TraceQuery, 
TraceTransform # Extend Rust Python types with additional Python-side functionality diff --git a/source/pip/qsharp/qre/_architecture.py b/source/pip/qsharp/qre/_architecture.py index 9045caee73..1bfb3f29ff 100644 --- a/source/pip/qsharp/qre/_architecture.py +++ b/source/pip/qsharp/qre/_architecture.py @@ -10,7 +10,7 @@ from ._qre import ( ISA, _ProvenanceGraph, - _Instruction, + Instruction, _IntFunction, _FloatFunction, constant_function, @@ -25,7 +25,7 @@ class Architecture(ABC): @abstractmethod - def provided_isa(self, ctx: _Context) -> ISA: + def provided_isa(self, ctx: ISAContext) -> ISA: """ Creates the ISA provided by this architecture, adding instructions directly to the context's provenance graph. @@ -39,12 +39,12 @@ def provided_isa(self, ctx: _Context) -> ISA: """ ... - def context(self) -> _Context: + def context(self) -> ISAContext: """Create a new enumeration context for this architecture.""" - return _Context(self) + return ISAContext(self) -class _Context: +class ISAContext: """ Context passed through enumeration, holding shared state. 
""" @@ -58,7 +58,7 @@ def __init__(self, arch: Architecture): self._bindings: dict[str, ISA] = {} self._transforms: dict[int, Architecture | ISATransform] = {0: arch} - def _with_binding(self, name: str, isa: ISA) -> _Context: + def _with_binding(self, name: str, isa: ISA) -> ISAContext: """Return a new context with an additional binding (internal use).""" ctx = copy.copy(self) ctx._bindings = {**self._bindings, name: isa} @@ -71,7 +71,7 @@ def isa(self) -> ISA: def add_instruction( self, - id_or_instruction: int | _Instruction, + id_or_instruction: int | Instruction, encoding: Encoding = 0, # type: ignore *, arity: Optional[int] = 1, @@ -80,7 +80,7 @@ def add_instruction( length: Optional[int | _IntFunction] = None, error_rate: float | _FloatFunction = 0.0, transform: ISATransform | None = None, - source: list[_Instruction] | None = None, + source: list[Instruction] | None = None, **kwargs: int, ) -> int: """ @@ -93,7 +93,7 @@ def add_instruction( ctx.add_instruction(T, encoding=LOGICAL, time=1000, error_rate=1e-8) - 2. With a pre-existing ``_Instruction`` object (e.g. from + 2. With a pre-existing ``Instruction`` object (e.g. from ``with_id()``):: ctx.add_instruction(existing_instruction) @@ -107,26 +107,26 @@ def add_instruction( Args: id_or_instruction: Either an instruction ID (int) for creating - a new instruction, or an existing ``_Instruction`` object. + a new instruction, or an existing ``Instruction`` object. encoding: The instruction encoding (0 = Physical, 1 = Logical). - Ignored when passing an existing ``_Instruction``. + Ignored when passing an existing ``Instruction``. arity: The instruction arity. ``None`` for variable arity. - Ignored when passing an existing ``_Instruction``. + Ignored when passing an existing ``Instruction``. time: Instruction time in ns (or ``_IntFunction`` for variable - arity). Ignored when passing an existing ``_Instruction``. + arity). Ignored when passing an existing ``Instruction``. 
space: Instruction space in physical qubits (or ``_IntFunction`` for variable arity). Ignored when passing an existing - ``_Instruction``. + ``Instruction``. length: Arity including ancilla qubits. Ignored when passing an - existing ``_Instruction``. + existing ``Instruction``. error_rate: Instruction error rate (or ``_FloatFunction`` for variable arity). Ignored when passing an existing - ``_Instruction``. + ``Instruction``. transform: The ``ISATransform`` that produced the instruction. - source: List of source ``_Instruction`` objects consumed by the + source: List of source ``Instruction`` objects consumed by the transform. **kwargs: Additional properties (e.g. ``distance=9``). Ignored - when passing an existing ``_Instruction``. + when passing an existing ``Instruction``. Returns: The node index in the provenance graph. @@ -146,7 +146,7 @@ def add_instruction( **kwargs, ) - if isinstance(id_or_instruction, _Instruction): + if isinstance(id_or_instruction, Instruction): instr = id_or_instruction else: instr = _make_instruction( @@ -193,10 +193,10 @@ def _make_instruction( length: int | _IntFunction | None, error_rate: float | _FloatFunction, properties: dict[str, int], -) -> _Instruction: - """Build an ``_Instruction`` from keyword arguments.""" +) -> Instruction: + """Build an ``Instruction`` from keyword arguments.""" if arity is not None: - instr = _Instruction.fixed_arity( + instr = Instruction.fixed_arity( id, encoding, arity, @@ -215,7 +215,7 @@ def _make_instruction( if isinstance(error_rate, (int, float)): error_rate = constant_function(float(error_rate)) - instr = _Instruction.variable_arity( + instr = Instruction.variable_arity( id, encoding, time, diff --git a/source/pip/qsharp/qre/_estimation.py b/source/pip/qsharp/qre/_estimation.py index b49f92d60b..7f39fd1683 100644 --- a/source/pip/qsharp/qre/_estimation.py +++ b/source/pip/qsharp/qre/_estimation.py @@ -3,30 +3,20 @@ from __future__ import annotations -from dataclasses import dataclass, field 
-from typing import cast, Optional, Callable, Any, Iterable +from typing import cast, Optional, Any -import pandas as pd from ._application import Application -from ._architecture import Architecture, _Context +from ._architecture import Architecture from ._qre import ( _estimate_parallel, _estimate_with_graph, _EstimationCollection, Trace, - FactoryResult, - instruction_name, - EstimationResult, ) from ._trace import TraceQuery, PSSPC, LatticeSurgery -from ._instruction import InstructionSource from ._isa_enumeration import ISAQuery -from .property_keys import ( - PHYSICAL_COMPUTE_QUBITS, - PHYSICAL_MEMORY_QUBITS, - PHYSICAL_FACTORY_QUBITS, -) +from ._results import EstimationTable, EstimationTableEntry def estimate( @@ -136,12 +126,12 @@ def estimate( # trace, not on the ISA. trace_multipliers: dict[int, tuple[float, float]] = {} trace_sample_isa: dict[int, int] = {} - for t_idx, i_idx, _q, r in summaries: + for t_idx, isa_idx, _q, r in summaries: if t_idx not in trace_sample_isa: - trace_sample_isa[t_idx] = i_idx - for t_idx, i_idx in trace_sample_isa.items(): + trace_sample_isa[t_idx] = isa_idx + for t_idx, isa_idx in trace_sample_isa.items(): params, trace = params_and_traces[t_idx] - sample = trace.estimate(isas[i_idx], max_error) + sample = trace.estimate(isas[isa_idx], max_error) if sample is not None: pre_q = sample.qubits pre_r = sample.runtime @@ -150,12 +140,14 @@ def estimate( trace_multipliers[t_idx] = (pp.qubits / pre_q, pp.runtime / pre_r) # Phase 3: Estimate post-pp values and filter to Pareto candidates. 
- estimated_pp: list[tuple[int, int, int, int]] = [] # (t, i, q, est_r) - for t_idx, i_idx, q, r in summaries: + estimated_pp: list[tuple[int, int, int, int]] = ( + [] + ) # (t_idx, isa_idx, est_q, est_r) + for t_idx, isa_idx, q, r in summaries: mult_q, mult_r = trace_multipliers.get(t_idx, (0.0, 0.0)) est_q = int(q * mult_q) if mult_q > 0 else q est_r = int(r * mult_r) if mult_r > 0 else r - estimated_pp.append((t_idx, i_idx, est_q, est_r)) + estimated_pp.append((t_idx, isa_idx, est_q, est_r)) # Build approximate post-pp Pareto frontier to identify candidates. estimated_pp.sort(key=lambda x: (x[2], x[3])) # sort by qubits, then runtime @@ -168,9 +160,9 @@ def estimate( # Phase 4: Re-estimate and post-process only the Pareto candidates. pp_collection = _EstimationCollection() - for t_idx, i_idx, _q, _r in approx_pareto: + for t_idx, isa_idx, _q, _r in approx_pareto: params, trace = params_and_traces[t_idx] - result = trace.estimate(isas[i_idx], max_error) + result = trace.estimate(isas[isa_idx], max_error) if result is not None: pp_result = app_ctx.application.post_process(params, result) if pp_result is not None: @@ -222,355 +214,3 @@ def estimate( table.stats.pareto_results = len(collection) return table - - -class EstimationTable(list["EstimationTableEntry"]): - """A table of quantum resource estimation results. - - Extends ``list[EstimationTableEntry]`` and provides configurable columns for - displaying estimation data. By default the table includes *qubits*, - *runtime* (displayed as a ``pandas.Timedelta``), and *error* columns. - Additional columns can be added or inserted with :meth:`add_column` and - :meth:`insert_column`. 
- """ - - def __init__(self): - """Initialize an empty estimation table with default columns.""" - super().__init__() - - self.name: Optional[str] = None - self.stats = EstimationTableStats() - - self._columns: list[tuple[str, EstimationTableColumn]] = [ - ("qubits", EstimationTableColumn(lambda entry: entry.qubits)), - ( - "runtime", - EstimationTableColumn( - lambda entry: entry.runtime, - formatter=lambda x: pd.Timedelta(x, unit="ns"), - ), - ), - ("error", EstimationTableColumn(lambda entry: entry.error)), - ] - - def add_column( - self, - name: str, - function: Callable[[EstimationTableEntry], Any], - formatter: Optional[Callable[[Any], Any]] = None, - ) -> None: - """Adds a column to the estimation table. - - Args: - name (str): The name of the column. - function (Callable[[EstimationTableEntry], Any]): A function that - takes an EstimationTableEntry and returns the value for this - column. - formatter (Optional[Callable[[Any], Any]]): An optional function - that formats the output of `function` for display purposes. - """ - self._columns.append((name, EstimationTableColumn(function, formatter))) - - def insert_column( - self, - index: int, - name: str, - function: Callable[[EstimationTableEntry], Any], - formatter: Optional[Callable[[Any], Any]] = None, - ) -> None: - """Inserts a column at the specified index in the estimation table. - - Args: - index (int): The index at which to insert the column. - name (str): The name of the column. - function (Callable[[EstimationTableEntry], Any]): A function that - takes an EstimationTableEntry and returns the value for this - column. - formatter (Optional[Callable[[Any], Any]]): An optional function - that formats the output of `function` for display purposes. 
- """ - self._columns.insert(index, (name, EstimationTableColumn(function, formatter))) - - def add_qubit_partition_column(self) -> None: - self.add_column( - "physical_compute_qubits", - lambda entry: entry.properties.get(PHYSICAL_COMPUTE_QUBITS, 0), - ) - self.add_column( - "physical_factory_qubits", - lambda entry: entry.properties.get(PHYSICAL_FACTORY_QUBITS, 0), - ) - self.add_column( - "physical_memory_qubits", - lambda entry: entry.properties.get(PHYSICAL_MEMORY_QUBITS, 0), - ) - - def add_factory_summary_column(self) -> None: - """Adds a column to the estimation table that summarizes the factories used in the estimation.""" - - def summarize_factories(entry: EstimationTableEntry) -> str: - if not entry.factories: - return "None" - return ", ".join( - f"{factory_result.copies}×{instruction_name(id)}" - for id, factory_result in entry.factories.items() - ) - - self.add_column("factories", summarize_factories) - - def as_frame(self): - """Convert the estimation table to a :class:`pandas.DataFrame`. - - Each row corresponds to an :class:`EstimationTableEntry` and each - column is determined by the columns registered on this table. Column - formatters, when present, are applied to the values before they are - placed in the frame. - - Returns: - pandas.DataFrame: A DataFrame representation of the estimation - results. - """ - return pd.DataFrame( - [ - { - column_name: ( - column.formatter(column.function(entry)) - if column.formatter is not None - else column.function(entry) - ) - for column_name, column in self._columns - } - for entry in self - ] - ) - - def plot(self, **kwargs): - """Plot this table's results. - - Convenience wrapper around :func:`plot_estimates`. All keyword - arguments are forwarded. - - Returns: - matplotlib.figure.Figure: The figure containing the plot. - """ - return plot_estimates(self, **kwargs) - - -@dataclass(frozen=True, slots=True) -class EstimationTableColumn: - """Definition of a single column in an :class:`EstimationTable`. 
- - Attributes: - function: A callable that extracts the raw column value from an - :class:`EstimationTableEntry`. - formatter: An optional callable that transforms the raw value for - display purposes (e.g. converting nanoseconds to a - ``pandas.Timedelta``). - """ - - function: Callable[[EstimationTableEntry], Any] - formatter: Optional[Callable[[Any], Any]] = None - - -@dataclass(frozen=True, slots=True) -class EstimationTableEntry: - """A single row in an :class:`EstimationTable`. - - Each entry represents one Pareto-optimal estimation result for a - particular combination of application trace and architecture ISA. - - Attributes: - qubits: Total number of physical qubits required. - runtime: Total runtime of the algorithm in nanoseconds. - error: Total estimated error probability. - source: The instruction source derived from the architecture ISA used - for this estimation. - factories: A mapping from instruction id to the - :class:`FactoryResult` describing the magic-state factory used - and the number of copies required. - properties: Additional key-value properties attached to the - estimation result. - """ - - qubits: int - runtime: int - error: float - source: InstructionSource - factories: dict[int, FactoryResult] = field(default_factory=dict) - properties: dict[int, int | float | bool | str] = field(default_factory=dict) - - @classmethod - def from_result( - cls, result: EstimationResult, ctx: _Context - ) -> EstimationTableEntry: - return cls( - qubits=result.qubits, - runtime=result.runtime, - error=result.error, - source=InstructionSource.from_isa(ctx, result.isa), - factories=result.factories.copy(), - properties=result.properties.copy(), - ) - - -@dataclass(slots=True) -class EstimationTableStats: - num_traces: int = 0 - num_isas: int = 0 - total_jobs: int = 0 - successful_estimates: int = 0 - pareto_results: int = 0 - - -# Mapping from runtime unit name to its value in nanoseconds. 
-_TIME_UNITS: dict[str, float] = { - "ns": 1, - "µs": 1e3, - "us": 1e3, - "ms": 1e6, - "s": 1e9, - "min": 60e9, - "hours": 3600e9, - "days": 86_400e9, - "weeks": 604_800e9, - "months": 31 * 86_400e9, - "years": 365 * 86_400e9, - "decades": 10 * 365 * 86_400e9, - "centuries": 100 * 365 * 86_400e9, -} - -# Ordered subset of _TIME_UNITS used for default x-axis tick labels. -_TICK_UNITS: list[tuple[str, float]] = [ - ("1 ns", _TIME_UNITS["ns"]), - ("1 µs", _TIME_UNITS["µs"]), - ("1 ms", _TIME_UNITS["ms"]), - ("1 s", _TIME_UNITS["s"]), - ("1 min", _TIME_UNITS["min"]), - ("1 hour", _TIME_UNITS["hours"]), - ("1 day", _TIME_UNITS["days"]), - ("1 week", _TIME_UNITS["weeks"]), - ("1 month", _TIME_UNITS["months"]), - ("1 year", _TIME_UNITS["years"]), - ("1 decade", _TIME_UNITS["decades"]), - ("1 century", _TIME_UNITS["centuries"]), -] - - -def plot_estimates( - data: EstimationTable | Iterable[EstimationTable], - *, - runtime_unit: Optional[str] = None, - figsize: tuple[float, float] = (15, 8), - scatter_args: dict[str, Any] = {"marker": "x"}, -): - """Returns a plot of the estimates displaying qubits vs runtime. - - Creates a log-log scatter plot where the x-axis shows the total runtime and - the y-axis shows the total number of physical qubits. - - *data* may be a single `EstimationTable` or an iterable of tables. When - multiple tables are provided, each is plotted as a separate series. If a - table has a `EstimationTable.name` (set via the *name* parameter of - `estimate`), it is used as the legend label for that series. - - When *runtime_unit* is ``None`` (the default), the x-axis uses - human-readable time-unit tick labels spanning nanoseconds to centuries. - When a unit string is given (e.g. ``"hours"``), all runtimes are scaled to - that unit and the x-axis label includes the unit while the ticks are plain - numbers. 
- - Supported *runtime_unit* values: ``"ns"``, ``"µs"`` (or ``"us"``), ``"ms"``, - ``"s"``, ``"min"``, ``"hours"``, ``"days"``, ``"weeks"``, ``"months"``, - ``"years"``. - - Args: - data: A single EstimationTable or an iterable of - EstimationTable objects to plot. - runtime_unit: Optional time unit to scale the x-axis to. - figsize: Figure dimensions in inches as ``(width, height)``. - scatter_args: Additional keyword arguments to pass to - ``matplotlib.axes.Axes.scatter`` when plotting the points. - - Returns: - matplotlib.figure.Figure: The figure containing the plot. - - Raises: - ImportError: If matplotlib is not installed. - ValueError: If all tables are empty or *runtime_unit* is not - recognised. - """ - try: - import matplotlib.pyplot as plt - except ImportError: - raise ImportError( - "Missing optional 'matplotlib' dependency. To install run: " - "pip install matplotlib" - ) - - # Normalize to a list of tables - if isinstance(data, EstimationTable): - tables = [data] - else: - tables = list(data) - - if not tables or all(len(t) == 0 for t in tables): - raise ValueError("Cannot plot an empty EstimationTable.") - - if runtime_unit is not None and runtime_unit not in _TIME_UNITS: - raise ValueError( - f"Unknown runtime_unit {runtime_unit!r}. 
" - f"Supported units: {', '.join(_TIME_UNITS)}" - ) - - fig, ax = plt.subplots(figsize=figsize) - ax.set_ylabel("Physical qubits") - ax.set_xscale("log") - ax.set_yscale("log") - - all_xs: list[float] = [] - has_labels = False - - for table in tables: - if len(table) == 0: - continue - - ys = [entry.qubits for entry in table] - - if runtime_unit is not None: - scale = _TIME_UNITS[runtime_unit] - xs = [entry.runtime / scale for entry in table] - else: - xs = [float(entry.runtime) for entry in table] - - all_xs.extend(xs) - - label = table.name - if label is not None: - has_labels = True - - ax.scatter(x=xs, y=ys, label=label, **scatter_args) - - if runtime_unit is not None: - ax.set_xlabel(f"Runtime ({runtime_unit})") - else: - ax.set_xlabel("Runtime") - - time_labels, time_units = zip(*_TICK_UNITS) - - cutoff = ( - next( - (i for i, x in enumerate(time_units) if x > max(all_xs)), - len(time_units) - 1, - ) - + 1 - ) - - ax.set_xticks(time_units[:cutoff]) - ax.set_xticklabels(time_labels[:cutoff], rotation=90) - - if has_labels: - ax.legend() - - plt.close(fig) - - return fig diff --git a/source/pip/qsharp/qre/_instruction.py b/source/pip/qsharp/qre/_instruction.py index de54bfd657..ab3c176e69 100644 --- a/source/pip/qsharp/qre/_instruction.py +++ b/source/pip/qsharp/qre/_instruction.py @@ -10,7 +10,7 @@ import pandas as pd -from ._architecture import _Context, Architecture +from ._architecture import ISAContext, Architecture from ._enumeration import _enumerate_instances from ._isa_enumeration import ( ISA_ROOT, @@ -22,7 +22,7 @@ ISA, Constraint, ConstraintBound, - _Instruction, + Instruction, ISARequirements, instruction_name, property_name_to_key, @@ -97,7 +97,9 @@ def required_isa() -> ISARequirements: ... 
@abstractmethod - def provided_isa(self, impl_isa: ISA, ctx: _Context) -> Generator[ISA, None, None]: + def provided_isa( + self, impl_isa: ISA, ctx: ISAContext + ) -> Generator[ISA, None, None]: """ Yields ISAs provided by this transform given an implementation ISA. @@ -113,7 +115,7 @@ def provided_isa(self, impl_isa: ISA, ctx: _Context) -> Generator[ISA, None, Non def enumerate_isas( cls, impl_isa: ISA | Iterable[ISA], - ctx: _Context, + ctx: ISAContext, **kwargs, ) -> Generator[ISA, None, None]: """ @@ -178,7 +180,7 @@ class InstructionSource: roots: list[int] = field(default_factory=list, init=False) @classmethod - def from_isa(cls, ctx: _Context, isa: ISA) -> InstructionSource: + def from_isa(cls, ctx: ISAContext, isa: ISA) -> InstructionSource: """ Constructs an InstructionSource graph from an ISA. @@ -187,7 +189,7 @@ def from_isa(cls, ctx: _Context, isa: ISA) -> InstructionSource: transforms and architectures that generated them. Args: - ctx (_Context): The enumeration context containing the provenance graph. + ctx (ISAContext): The enumeration context containing the provenance graph. isa (ISA): Instructions in the ISA will serve as root nodes in the source graph. 
Returns: @@ -231,7 +233,7 @@ def add_root(self, node_id: int) -> None: def add_node( self, - instruction: _Instruction, + instruction: Instruction, transform: Optional[ISATransform | Architecture], children: list[int], ) -> int: @@ -311,7 +313,7 @@ def get( @dataclass(frozen=True, slots=True) class _InstructionSourceNode: - instruction: _Instruction + instruction: Instruction transform: Optional[ISATransform | Architecture] children: list[int] @@ -322,7 +324,7 @@ def __init__(self, graph: InstructionSource, node_id: int): self.node_id = node_id @property - def instruction(self) -> _Instruction: + def instruction(self) -> Instruction: return self.graph.nodes[self.node_id].instruction @property diff --git a/source/pip/qsharp/qre/_isa_enumeration.py b/source/pip/qsharp/qre/_isa_enumeration.py index 5cbb9fa187..c33fdac435 100644 --- a/source/pip/qsharp/qre/_isa_enumeration.py +++ b/source/pip/qsharp/qre/_isa_enumeration.py @@ -9,7 +9,7 @@ from dataclasses import dataclass, field from typing import Generator -from ._architecture import _Context +from ._architecture import ISAContext from ._enumeration import _enumerate_instances from ._qre import ISA @@ -25,7 +25,7 @@ class ISAQuery(ABC): """ @abstractmethod - def enumerate(self, ctx: _Context) -> Generator[ISA, None, None]: + def enumerate(self, ctx: ISAContext) -> Generator[ISA, None, None]: """ Yields all ISA instances represented by this enumeration node. @@ -38,7 +38,7 @@ def enumerate(self, ctx: _Context) -> Generator[ISA, None, None]: """ pass - def populate(self, ctx: _Context) -> int: + def populate(self, ctx: ISAContext) -> int: """ Populates the provenance graph with instructions from this node. @@ -47,7 +47,7 @@ def populate(self, ctx: _Context) -> int: requirements, and adds produced instructions directly to the graph. Args: - ctx (_Context): The enumeration context whose provenance graph + ctx (ISAContext): The enumeration context whose provenance graph will be populated. 
Returns: @@ -158,7 +158,7 @@ class RootNode(ISAQuery): Reads from the context instead of holding a reference. """ - def enumerate(self, ctx: _Context) -> Generator[ISA, None, None]: + def enumerate(self, ctx: ISAContext) -> Generator[ISA, None, None]: """ Yields the architecture ISA from the context. @@ -170,8 +170,8 @@ def enumerate(self, ctx: _Context) -> Generator[ISA, None, None]: """ yield ctx._isa - def populate(self, ctx: _Context) -> int: - """Architecture ISA is already in the graph from ``_Context.__init__``. + def populate(self, ctx: ISAContext) -> int: + """Architecture ISA is already in the graph from ``ISAContext.__init__``. Returns: int: 1, since architecture nodes start at index 1. @@ -203,7 +203,7 @@ class _ComponentQuery(ISAQuery): source: ISAQuery = field(default_factory=lambda: ISA_ROOT) kwargs: dict = field(default_factory=dict) - def enumerate(self, ctx: _Context) -> Generator[ISA, None, None]: + def enumerate(self, ctx: ISAContext) -> Generator[ISA, None, None]: """ Yields ISAs generated by the component from source ISAs. @@ -216,7 +216,7 @@ def enumerate(self, ctx: _Context) -> Generator[ISA, None, None]: for isa in self.source.enumerate(ctx): yield from self.component.enumerate_isas(isa, ctx, **self.kwargs) - def populate(self, ctx: _Context) -> int: + def populate(self, ctx: ISAContext) -> int: """ Populates the graph by querying matching instructions. @@ -253,7 +253,7 @@ class _ProductNode(ISAQuery): sources: list[ISAQuery] - def enumerate(self, ctx: _Context) -> Generator[ISA, None, None]: + def enumerate(self, ctx: ISAContext) -> Generator[ISA, None, None]: """ Yields ISAs formed by combining ISAs from all source nodes. @@ -269,7 +269,7 @@ def enumerate(self, ctx: _Context) -> Generator[ISA, None, None]: for isa_tuple in itertools.product(*source_generators) ) - def populate(self, ctx: _Context) -> int: + def populate(self, ctx: ISAContext) -> int: """Populates the graph from each source sequentially (no cross product). 
Returns: @@ -292,7 +292,7 @@ class _SumNode(ISAQuery): sources: list[ISAQuery] - def enumerate(self, ctx: _Context) -> Generator[ISA, None, None]: + def enumerate(self, ctx: ISAContext) -> Generator[ISA, None, None]: """ Yields ISAs from each source node in sequence. @@ -305,7 +305,7 @@ def enumerate(self, ctx: _Context) -> Generator[ISA, None, None]: for source in self.sources: yield from source.enumerate(ctx) - def populate(self, ctx: _Context) -> int: + def populate(self, ctx: ISAContext) -> int: """Populates the graph from each source sequentially. Returns: @@ -330,7 +330,7 @@ class ISARefNode(ISAQuery): name: str - def enumerate(self, ctx: _Context) -> Generator[ISA, None, None]: + def enumerate(self, ctx: ISAContext) -> Generator[ISA, None, None]: """ Yields the bound ISA from the context. @@ -347,7 +347,7 @@ def enumerate(self, ctx: _Context) -> Generator[ISA, None, None]: raise ValueError(f"Undefined component reference: '{self.name}'") yield ctx._bindings[self.name] - def populate(self, ctx: _Context) -> int: + def populate(self, ctx: ISAContext) -> int: """Instructions already in graph from the bound component. Returns: @@ -401,7 +401,7 @@ class _BindingNode(ISAQuery): component: ISAQuery node: ISAQuery - def enumerate(self, ctx: _Context) -> Generator[ISA, None, None]: + def enumerate(self, ctx: ISAContext) -> Generator[ISA, None, None]: """ Enumerates child nodes with the bound component in context. @@ -417,7 +417,7 @@ def enumerate(self, ctx: _Context) -> Generator[ISA, None, None]: new_ctx = ctx._with_binding(self.name, isa) yield from self.node.enumerate(new_ctx) - def populate(self, ctx: _Context) -> int: + def populate(self, ctx: ISAContext) -> int: """Populates the graph from both the component and the child node. 
Returns: diff --git a/source/pip/qsharp/qre/_qre.py b/source/pip/qsharp/qre/_qre.py index f724349388..2d1aaa7aa5 100644 --- a/source/pip/qsharp/qre/_qre.py +++ b/source/pip/qsharp/qre/_qre.py @@ -19,7 +19,7 @@ _FloatFunction, generic_function, instruction_name, - _Instruction, + Instruction, InstructionFrontier, _IntFunction, ISA, diff --git a/source/pip/qsharp/qre/_qre.pyi b/source/pip/qsharp/qre/_qre.pyi index e143333df4..7e9f92ddc8 100644 --- a/source/pip/qsharp/qre/_qre.pyi +++ b/source/pip/qsharp/qre/_qre.pyi @@ -33,7 +33,7 @@ class ISA: """ ... - def __getitem__(self, id: int) -> _Instruction: + def __getitem__(self, id: int) -> Instruction: """ Gets an instruction by its ID. @@ -41,23 +41,23 @@ class ISA: id (int): The instruction ID. Returns: - _Instruction: The instruction. + Instruction: The instruction. """ ... def get( - self, id: int, default: Optional[_Instruction] = None - ) -> Optional[_Instruction]: + self, id: int, default: Optional[Instruction] = None + ) -> Optional[Instruction]: """ Gets an instruction by its ID, or returns a default value if not found. Args: id (int): The instruction ID. - default (Optional[_Instruction]): The default value to return if the + default (Optional[Instruction]): The default value to return if the instruction is not found. Returns: - Optional[_Instruction]: The instruction, or the default value if not found. + Optional[Instruction]: The instruction, or the default value if not found. """ ... @@ -105,7 +105,7 @@ class ISA: """ ... - def __iter__(self) -> Iterator[_Instruction]: + def __iter__(self) -> Iterator[Instruction]: """ Returns an iterator over the instructions. @@ -113,7 +113,7 @@ class ISA: The order of instructions is not guaranteed. Returns: - Iterator[_Instruction]: The instruction iterator. + Iterator[Instruction]: The instruction iterator. """ ... @@ -178,7 +178,7 @@ class ISARequirements: """ ... 
-class _Instruction: +class Instruction: @staticmethod def fixed_arity( id: int, @@ -188,7 +188,7 @@ class _Instruction: space: Optional[int], length: Optional[int], error_rate: float, - ) -> _Instruction: + ) -> Instruction: """ Creates an instruction with a fixed arity. @@ -207,7 +207,7 @@ class _Instruction: error_rate (float): The instruction error rate. Returns: - _Instruction: The instruction. + Instruction: The instruction. """ ... @@ -219,7 +219,7 @@ class _Instruction: space_fn: _IntFunction, error_rate_fn: _FloatFunction, length_fn: Optional[_IntFunction], - ) -> _Instruction: + ) -> Instruction: """ Creates an instruction with variable arity. @@ -236,11 +236,11 @@ class _Instruction: If None, space_fn is used. Returns: - _Instruction: The instruction. + Instruction: The instruction. """ ... - def with_id(self, id: int) -> _Instruction: + def with_id(self, id: int) -> Instruction: """ Returns a copy of the instruction with the given ID. @@ -252,7 +252,7 @@ class _Instruction: id (int): The instruction ID. Returns: - _Instruction: A copy of the instruction with the given ID. + Instruction: A copy of the instruction with the given ID. """ ... @@ -702,7 +702,7 @@ class _ProvenanceGraph: """ def add_node( - self, instruction: _Instruction, transform_id: int, children: list[int] + self, instruction: Instruction, transform_id: int, children: list[int] ) -> int: """ Adds a node to the provenance graph. @@ -717,7 +717,7 @@ class _ProvenanceGraph: """ ... - def instruction(self, node_index: int) -> _Instruction: + def instruction(self, node_index: int) -> Instruction: """ Returns the instruction for a given node index. @@ -774,7 +774,7 @@ class _ProvenanceGraph: @overload def add_instruction( self, - instruction: _Instruction, + instruction: Instruction, ) -> int: ... @overload def add_instruction( @@ -791,7 +791,7 @@ class _ProvenanceGraph: ) -> int: ... 
def add_instruction( self, - id_or_instruction: int | _Instruction, + id_or_instruction: int | Instruction, encoding: int = 0, *, arity: Optional[int] = 1, @@ -805,20 +805,20 @@ class _ProvenanceGraph: Adds an instruction to the provenance graph with no transform or children. - Can be called with a pre-existing ``_Instruction`` or with keyword + Can be called with a pre-existing ``Instruction`` or with keyword args to create one inline. Args: - id_or_instruction: An instruction ID (int) or ``_Instruction``. - encoding: 0 = Physical, 1 = Logical. Ignored for ``_Instruction``. + id_or_instruction: An instruction ID (int) or ``Instruction``. + encoding: 0 = Physical, 1 = Logical. Ignored for ``Instruction``. arity: Instruction arity, ``None`` for variable. Ignored for - ``_Instruction``. - time: Time in ns (or ``_IntFunction``). Ignored for ``_Instruction``. + ``Instruction``. + time: Time in ns (or ``_IntFunction``). Ignored for ``Instruction``. space: Space in physical qubits (or ``_IntFunction``). Ignored for - ``_Instruction``. - length: Arity including ancillas. Ignored for ``_Instruction``. + ``Instruction``. + length: Arity including ancillas. Ignored for ``Instruction``. error_rate: Error rate (or ``_FloatFunction``). Ignored for - ``_Instruction``. + ``Instruction``. **kwargs: Additional properties (e.g. ``distance=9``). Returns: @@ -1511,21 +1511,21 @@ class InstructionFrontier: """ ... - def insert(self, point: _Instruction): + def insert(self, point: Instruction): """ Inserts an instruction to the frontier. Args: - point (_Instruction): The instruction to insert. + point (Instruction): The instruction to insert. """ ... - def extend(self, points: list[_Instruction]) -> None: + def extend(self, points: list[Instruction]) -> None: """ Extends the frontier with a list of instructions. Args: - points (list[_Instruction]): The instructions to insert. + points (list[Instruction]): The instructions to insert. """ ... 
@@ -1538,12 +1538,12 @@ class InstructionFrontier: """ ... - def __iter__(self) -> Iterator[_Instruction]: + def __iter__(self) -> Iterator[Instruction]: """ Returns an iterator over the instructions in the frontier. Returns: - Iterator[_Instruction]: The iterator. + Iterator[Instruction]: The iterator. """ ... diff --git a/source/pip/qsharp/qre/_results.py b/source/pip/qsharp/qre/_results.py new file mode 100644 index 0000000000..efaa3be144 --- /dev/null +++ b/source/pip/qsharp/qre/_results.py @@ -0,0 +1,374 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Optional, Callable, Any, Iterable + +import pandas as pd + +from ._architecture import ISAContext +from ._qre import ( + FactoryResult, + instruction_name, + EstimationResult, +) +from ._instruction import InstructionSource +from .property_keys import ( + PHYSICAL_COMPUTE_QUBITS, + PHYSICAL_MEMORY_QUBITS, + PHYSICAL_FACTORY_QUBITS, +) + + +class EstimationTable(list["EstimationTableEntry"]): + """A table of quantum resource estimation results. + + Extends ``list[EstimationTableEntry]`` and provides configurable columns for + displaying estimation data. By default the table includes *qubits*, + *runtime* (displayed as a ``pandas.Timedelta``), and *error* columns. + Additional columns can be added or inserted with :meth:`add_column` and + :meth:`insert_column`. 
+ """ + + def __init__(self): + """Initialize an empty estimation table with default columns.""" + super().__init__() + + self.name: Optional[str] = None + self.stats = EstimationTableStats() + + self._columns: list[tuple[str, EstimationTableColumn]] = [ + ("qubits", EstimationTableColumn(lambda entry: entry.qubits)), + ( + "runtime", + EstimationTableColumn( + lambda entry: entry.runtime, + formatter=lambda x: pd.Timedelta(x, unit="ns"), + ), + ), + ("error", EstimationTableColumn(lambda entry: entry.error)), + ] + + def add_column( + self, + name: str, + function: Callable[[EstimationTableEntry], Any], + formatter: Optional[Callable[[Any], Any]] = None, + ) -> None: + """Adds a column to the estimation table. + + Args: + name (str): The name of the column. + function (Callable[[EstimationTableEntry], Any]): A function that + takes an EstimationTableEntry and returns the value for this + column. + formatter (Optional[Callable[[Any], Any]]): An optional function + that formats the output of `function` for display purposes. + """ + self._columns.append((name, EstimationTableColumn(function, formatter))) + + def insert_column( + self, + index: int, + name: str, + function: Callable[[EstimationTableEntry], Any], + formatter: Optional[Callable[[Any], Any]] = None, + ) -> None: + """Inserts a column at the specified index in the estimation table. + + Args: + index (int): The index at which to insert the column. + name (str): The name of the column. + function (Callable[[EstimationTableEntry], Any]): A function that + takes an EstimationTableEntry and returns the value for this + column. + formatter (Optional[Callable[[Any], Any]]): An optional function + that formats the output of `function` for display purposes. 
+ """ + self._columns.insert(index, (name, EstimationTableColumn(function, formatter))) + + def add_qubit_partition_column(self) -> None: + self.add_column( + "physical_compute_qubits", + lambda entry: entry.properties.get(PHYSICAL_COMPUTE_QUBITS, 0), + ) + self.add_column( + "physical_factory_qubits", + lambda entry: entry.properties.get(PHYSICAL_FACTORY_QUBITS, 0), + ) + self.add_column( + "physical_memory_qubits", + lambda entry: entry.properties.get(PHYSICAL_MEMORY_QUBITS, 0), + ) + + def add_factory_summary_column(self) -> None: + """Adds a column to the estimation table that summarizes the factories used in the estimation.""" + + def summarize_factories(entry: EstimationTableEntry) -> str: + if not entry.factories: + return "None" + return ", ".join( + f"{factory_result.copies}×{instruction_name(id)}" + for id, factory_result in entry.factories.items() + ) + + self.add_column("factories", summarize_factories) + + def as_frame(self): + """Convert the estimation table to a :class:`pandas.DataFrame`. + + Each row corresponds to an :class:`EstimationTableEntry` and each + column is determined by the columns registered on this table. Column + formatters, when present, are applied to the values before they are + placed in the frame. + + Returns: + pandas.DataFrame: A DataFrame representation of the estimation + results. + """ + return pd.DataFrame( + [ + { + column_name: ( + column.formatter(column.function(entry)) + if column.formatter is not None + else column.function(entry) + ) + for column_name, column in self._columns + } + for entry in self + ] + ) + + def plot(self, **kwargs): + """Plot this table's results. + + Convenience wrapper around :func:`plot_estimates`. All keyword + arguments are forwarded. + + Returns: + matplotlib.figure.Figure: The figure containing the plot. + """ + return plot_estimates(self, **kwargs) + + +@dataclass(frozen=True, slots=True) +class EstimationTableColumn: + """Definition of a single column in an :class:`EstimationTable`. 
+ + Attributes: + function: A callable that extracts the raw column value from an + :class:`EstimationTableEntry`. + formatter: An optional callable that transforms the raw value for + display purposes (e.g. converting nanoseconds to a + ``pandas.Timedelta``). + """ + + function: Callable[[EstimationTableEntry], Any] + formatter: Optional[Callable[[Any], Any]] = None + + +@dataclass(frozen=True, slots=True) +class EstimationTableEntry: + """A single row in an :class:`EstimationTable`. + + Each entry represents one Pareto-optimal estimation result for a + particular combination of application trace and architecture ISA. + + Attributes: + qubits: Total number of physical qubits required. + runtime: Total runtime of the algorithm in nanoseconds. + error: Total estimated error probability. + source: The instruction source derived from the architecture ISA used + for this estimation. + factories: A mapping from instruction id to the + :class:`FactoryResult` describing the magic-state factory used + and the number of copies required. + properties: Additional key-value properties attached to the + estimation result. + """ + + qubits: int + runtime: int + error: float + source: InstructionSource + factories: dict[int, FactoryResult] = field(default_factory=dict) + properties: dict[int, int | float | bool | str] = field(default_factory=dict) + + @classmethod + def from_result( + cls, result: EstimationResult, ctx: ISAContext + ) -> EstimationTableEntry: + return cls( + qubits=result.qubits, + runtime=result.runtime, + error=result.error, + source=InstructionSource.from_isa(ctx, result.isa), + factories=result.factories.copy(), + properties=result.properties.copy(), + ) + + +@dataclass(slots=True) +class EstimationTableStats: + num_traces: int = 0 + num_isas: int = 0 + total_jobs: int = 0 + successful_estimates: int = 0 + pareto_results: int = 0 + + +# Mapping from runtime unit name to its value in nanoseconds. 
+_TIME_UNITS: dict[str, float] = { + "ns": 1, + "µs": 1e3, + "us": 1e3, + "ms": 1e6, + "s": 1e9, + "min": 60e9, + "hours": 3600e9, + "days": 86_400e9, + "weeks": 604_800e9, + "months": 31 * 86_400e9, + "years": 365 * 86_400e9, + "decades": 10 * 365 * 86_400e9, + "centuries": 100 * 365 * 86_400e9, +} + +# Ordered subset of _TIME_UNITS used for default x-axis tick labels. +_TICK_UNITS: list[tuple[str, float]] = [ + ("1 ns", _TIME_UNITS["ns"]), + ("1 µs", _TIME_UNITS["µs"]), + ("1 ms", _TIME_UNITS["ms"]), + ("1 s", _TIME_UNITS["s"]), + ("1 min", _TIME_UNITS["min"]), + ("1 hour", _TIME_UNITS["hours"]), + ("1 day", _TIME_UNITS["days"]), + ("1 week", _TIME_UNITS["weeks"]), + ("1 month", _TIME_UNITS["months"]), + ("1 year", _TIME_UNITS["years"]), + ("1 decade", _TIME_UNITS["decades"]), + ("1 century", _TIME_UNITS["centuries"]), +] + + +def plot_estimates( + data: EstimationTable | Iterable[EstimationTable], + *, + runtime_unit: Optional[str] = None, + figsize: tuple[float, float] = (15, 8), + scatter_args: dict[str, Any] = {"marker": "x"}, +): + """Returns a plot of the estimates displaying qubits vs runtime. + + Creates a log-log scatter plot where the x-axis shows the total runtime and + the y-axis shows the total number of physical qubits. + + *data* may be a single `EstimationTable` or an iterable of tables. When + multiple tables are provided, each is plotted as a separate series. If a + table has a `EstimationTable.name` (set via the *name* parameter of + `estimate`), it is used as the legend label for that series. + + When *runtime_unit* is ``None`` (the default), the x-axis uses + human-readable time-unit tick labels spanning nanoseconds to centuries. + When a unit string is given (e.g. ``"hours"``), all runtimes are scaled to + that unit and the x-axis label includes the unit while the ticks are plain + numbers. 
+ + Supported *runtime_unit* values: ``"ns"``, ``"µs"`` (or ``"us"``), ``"ms"``, + ``"s"``, ``"min"``, ``"hours"``, ``"days"``, ``"weeks"``, ``"months"``, + ``"years"``. + + Args: + data: A single EstimationTable or an iterable of + EstimationTable objects to plot. + runtime_unit: Optional time unit to scale the x-axis to. + figsize: Figure dimensions in inches as ``(width, height)``. + scatter_args: Additional keyword arguments to pass to + ``matplotlib.axes.Axes.scatter`` when plotting the points. + + Returns: + matplotlib.figure.Figure: The figure containing the plot. + + Raises: + ImportError: If matplotlib is not installed. + ValueError: If all tables are empty or *runtime_unit* is not + recognised. + """ + try: + import matplotlib.pyplot as plt + except ImportError: + raise ImportError( + "Missing optional 'matplotlib' dependency. To install run: " + "pip install matplotlib" + ) + + # Normalize to a list of tables + if isinstance(data, EstimationTable): + tables = [data] + else: + tables = list(data) + + if not tables or all(len(t) == 0 for t in tables): + raise ValueError("Cannot plot an empty EstimationTable.") + + if runtime_unit is not None and runtime_unit not in _TIME_UNITS: + raise ValueError( + f"Unknown runtime_unit {runtime_unit!r}. 
" + f"Supported units: {', '.join(_TIME_UNITS)}" + ) + + fig, ax = plt.subplots(figsize=figsize) + ax.set_ylabel("Physical qubits") + ax.set_xscale("log") + ax.set_yscale("log") + + all_xs: list[float] = [] + has_labels = False + + for table in tables: + if len(table) == 0: + continue + + ys = [entry.qubits for entry in table] + + if runtime_unit is not None: + scale = _TIME_UNITS[runtime_unit] + xs = [entry.runtime / scale for entry in table] + else: + xs = [float(entry.runtime) for entry in table] + + all_xs.extend(xs) + + label = table.name + if label is not None: + has_labels = True + + ax.scatter(x=xs, y=ys, label=label, **scatter_args) + + if runtime_unit is not None: + ax.set_xlabel(f"Runtime ({runtime_unit})") + else: + ax.set_xlabel("Runtime") + + time_labels, time_units = zip(*_TICK_UNITS) + + cutoff = ( + next( + (i for i, x in enumerate(time_units) if x > max(all_xs)), + len(time_units) - 1, + ) + + 1 + ) + + ax.set_xticks(time_units[:cutoff]) + ax.set_xticklabels(time_labels[:cutoff], rotation=90) + + if has_labels: + ax.legend() + + plt.close(fig) + + return fig diff --git a/source/pip/qsharp/qre/interop/_cirq.py b/source/pip/qsharp/qre/interop/_cirq.py index 0153c00320..b685456d84 100644 --- a/source/pip/qsharp/qre/interop/_cirq.py +++ b/source/pip/qsharp/qre/interop/_cirq.py @@ -90,7 +90,7 @@ def trace_from_cirq( # circuit is OP_TREE circuit = cirq.Circuit(circuit) - context = _Context(circuit, classical_control_probability) + context = _CirqTraceBuilder(circuit, classical_control_probability) for moment in circuit: for op in moment.operations: @@ -99,11 +99,25 @@ def trace_from_cirq( return context.trace -class _Context: - """Tracks the current trace and block nesting during trace generation. +class _CirqTraceBuilder: + """Builds a resource estimation ``Trace`` from a Cirq circuit. - Maintains a stack of blocks so that ``PushBlock`` and ``PopBlock`` - operations can create nested repeated sections in the trace. 
+ This class walks the operations produced by ``trace_from_cirq`` and + translates each one into trace instructions. It maintains the state + needed during the conversion: + + * A ``Trace`` instance that accumulates the result. + * A stack of ``Block`` objects so that ``PushBlock`` / ``PopBlock`` + markers can create nested repeated sections. + * A qubit-id mapping (``_QidToTraceId``) that assigns each Cirq qubit + a sequential integer index. + * A Cirq ``DecompositionContext`` for gates that need recursive + decomposition. + + Args: + circuit: The Cirq circuit being converted. + classical_control_probability: Probability that a classically + controlled operation is included in the trace. """ def __init__(self, circuit: cirq.Circuit, classical_control_probability: float): @@ -116,31 +130,41 @@ def __init__(self, circuit: cirq.Circuit, classical_control_probability: float): ) def push_block(self, repetitions: int): + """Open a new repeated block with the given number of repetitions.""" block = self.block.add_block(repetitions) self._blocks.append(block) def pop_block(self): + """Close the current repeated block, returning to the parent.""" self._blocks.pop() @property def trace(self) -> Trace: + """The accumulated trace, with ``compute_qubits`` updated to reflect + all qubits seen so far (including any allocated during decomposition).""" self._trace.compute_qubits = len(self._q_to_id) return self._trace @property def block(self) -> Block: + """The innermost open block in the trace.""" return self._blocks[-1] @property def q_to_id(self) -> _QidToTraceId: + """Mapping from Cirq ``Qid`` to integer trace qubit index.""" return self._q_to_id @property def classical_control_probability(self) -> float: + """Probability used to stochastically include classically controlled + operations.""" return self._classical_control_probability @property def decomp_context(self) -> cirq.DecompositionContext: + """Cirq decomposition context shared across all recursive + 
decompositions.""" return self._decomp_context def handle_op( @@ -151,15 +175,18 @@ def handle_op( Supported operation forms: - - ``TraceGate``: A raw trace instruction, added directly to the current block. - - ``PushBlock`` / ``PopBlock``: Control block nesting with repetitions. - - ``GateOperation``: Dispatched via ``_to_trace`` if available on the - gate, otherwise decomposed via ``_decompose_with_context_`` or - ``_decompose_``. + - ``TraceGate``: A raw trace instruction, added directly to the + current block. + - ``PushBlock`` / ``PopBlock``: Control block nesting with + repetitions. + - ``GateOperation``: Dispatched via ``_to_trace`` if available on + the gate, otherwise decomposed via + ``_decompose_with_context_`` or ``_decompose_``. - ``ClassicallyControlledOperation``: Included with the probability - specified in the generation context. - - ``list``: Each element is handled recursively. - - Any other operation: Decomposed via ``_decompose_with_context_``. + given by ``classical_control_probability``. + - ``list`` / iterable: Each element is handled recursively. + - Any other ``cirq.Operation``: Decomposed via + ``_decompose_with_context_``. Args: op: The operation to convert. 
diff --git a/source/pip/qsharp/qre/models/__init__.py b/source/pip/qsharp/qre/models/__init__.py index 5b8400002c..3da76797ac 100644 --- a/source/pip/qsharp/qre/models/__init__.py +++ b/source/pip/qsharp/qre/models/__init__.py @@ -8,10 +8,10 @@ OneDimensionalYokedSurfaceCode, TwoDimensionalYokedSurfaceCode, ) -from .qubits import AQREGateBased, Majorana +from .qubits import GateBased, Majorana __all__ = [ - "AQREGateBased", + "GateBased", "Litinski19Factory", "Majorana", "MagicUpToClifford", diff --git a/source/pip/qsharp/qre/models/factories/_litinski.py b/source/pip/qsharp/qre/models/factories/_litinski.py index d4f35117e4..30d3b444c6 100644 --- a/source/pip/qsharp/qre/models/factories/_litinski.py +++ b/source/pip/qsharp/qre/models/factories/_litinski.py @@ -7,7 +7,7 @@ from math import ceil from typing import Generator -from ..._architecture import _Context +from ..._architecture import ISAContext from ..._qre import ISARequirements, ConstraintBound, ISA from ..._instruction import ISATransform, constraint, LOGICAL from ...instruction_ids import T, CNOT, H, MEAS_Z, CCZ @@ -48,7 +48,9 @@ def required_isa() -> ISARequirements: constraint(MEAS_Z, error_rate=ConstraintBound.le(1e-3)), ) - def provided_isa(self, impl_isa: ISA, ctx: _Context) -> Generator[ISA, None, None]: + def provided_isa( + self, impl_isa: ISA, ctx: ISAContext + ) -> Generator[ISA, None, None]: h = impl_isa[H] cnot = impl_isa[CNOT] meas_z = impl_isa[MEAS_Z] diff --git a/source/pip/qsharp/qre/models/factories/_round_based.py b/source/pip/qsharp/qre/models/factories/_round_based.py index aed95e1243..5f746595bd 100644 --- a/source/pip/qsharp/qre/models/factories/_round_based.py +++ b/source/pip/qsharp/qre/models/factories/_round_based.py @@ -11,7 +11,7 @@ from pathlib import Path from typing import Callable, Generator, Iterable, Optional, Sequence -from ..._qre import ISA, InstructionFrontier, ISARequirements, _Instruction, _binom_ppf +from ..._qre import ISA, InstructionFrontier, ISARequirements, 
Instruction, _binom_ppf from ..._instruction import ( LOGICAL, PHYSICAL, @@ -19,7 +19,7 @@ ISATransform, constraint, ) -from ..._architecture import _Context +from ..._architecture import ISAContext from ...instruction_ids import CNOT, LATTICE_SURGERY, T, MEAS_ZZ from ..qec import SurfaceCode @@ -103,7 +103,9 @@ def required_isa() -> ISARequirements: constraint(T), ) - def provided_isa(self, impl_isa: ISA, ctx: _Context) -> Generator[ISA, None, None]: + def provided_isa( + self, impl_isa: ISA, ctx: ISAContext + ) -> Generator[ISA, None, None]: cache_path = self._cache_path(impl_isa) # 1) Try to load from cache @@ -190,7 +192,7 @@ def _physical_units(self, gate_time, clifford_error) -> list[_DistillationUnit]: ] def _logical_units( - self, lattice_surgery_instruction: _Instruction + self, lattice_surgery_instruction: Instruction ) -> list[_DistillationUnit]: logical_cycle_time = lattice_surgery_instruction.expect_time(1) logical_error = lattice_surgery_instruction.expect_error_rate(1) @@ -214,8 +216,8 @@ def _logical_units( ), ] - def _state_from_pipeline(self, pipeline: _Pipeline) -> _Instruction: - return _Instruction.fixed_arity( + def _state_from_pipeline(self, pipeline: _Pipeline) -> Instruction: + return Instruction.fixed_arity( T, int(LOGICAL), 1, diff --git a/source/pip/qsharp/qre/models/factories/_utils.py b/source/pip/qsharp/qre/models/factories/_utils.py index dcd72c6afe..a0efbc4ec5 100644 --- a/source/pip/qsharp/qre/models/factories/_utils.py +++ b/source/pip/qsharp/qre/models/factories/_utils.py @@ -3,7 +3,7 @@ from typing import Generator -from ..._architecture import _Context +from ..._architecture import ISAContext from ..._qre import ISARequirements, ISA from ..._instruction import ISATransform from ...instruction_ids import ( @@ -58,7 +58,7 @@ class MagicUpToClifford(ISATransform): def required_isa() -> ISARequirements: return ISARequirements() - def provided_isa(self, impl_isa, ctx: _Context) -> Generator[ISA, None, None]: + def provided_isa(self, 
impl_isa, ctx: ISAContext) -> Generator[ISA, None, None]: # Families of equivalent gates under Clifford conjugation. families = [ [ diff --git a/source/pip/qsharp/qre/models/qec/_surface_code.py b/source/pip/qsharp/qre/models/qec/_surface_code.py index e402ea9c41..ee5cc8bace 100644 --- a/source/pip/qsharp/qre/models/qec/_surface_code.py +++ b/source/pip/qsharp/qre/models/qec/_surface_code.py @@ -12,7 +12,7 @@ ConstraintBound, LOGICAL, ) -from ..._isa_enumeration import _Context +from ..._isa_enumeration import ISAContext from ..._qre import linear_function from ...instruction_ids import CNOT, H, LATTICE_SURGERY, MEAS_Z from ...property_keys import ( @@ -73,7 +73,9 @@ def required_isa() -> ISARequirements: constraint(MEAS_Z, error_rate=ConstraintBound.lt(0.01)), ) - def provided_isa(self, impl_isa: ISA, ctx: _Context) -> Generator[ISA, None, None]: + def provided_isa( + self, impl_isa: ISA, ctx: ISAContext + ) -> Generator[ISA, None, None]: cnot = impl_isa[CNOT] h = impl_isa[H] meas_z = impl_isa[MEAS_Z] diff --git a/source/pip/qsharp/qre/models/qec/_three_aux.py b/source/pip/qsharp/qre/models/qec/_three_aux.py index 2af1879205..5f7cff6da3 100644 --- a/source/pip/qsharp/qre/models/qec/_three_aux.py +++ b/source/pip/qsharp/qre/models/qec/_three_aux.py @@ -6,7 +6,7 @@ from dataclasses import KW_ONLY, dataclass, field from typing import Generator -from ..._architecture import _Context +from ..._architecture import ISAContext from ..._instruction import ( LOGICAL, ISATransform, @@ -59,7 +59,9 @@ def required_isa() -> ISARequirements: constraint(MEAS_ZZ, arity=2), ) - def provided_isa(self, impl_isa: ISA, ctx: _Context) -> Generator[ISA, None, None]: + def provided_isa( + self, impl_isa: ISA, ctx: ISAContext + ) -> Generator[ISA, None, None]: meas_x = impl_isa[MEAS_X] meas_z = impl_isa[MEAS_Z] meas_xx = impl_isa[MEAS_XX] diff --git a/source/pip/qsharp/qre/models/qec/_yoked.py b/source/pip/qsharp/qre/models/qec/_yoked.py index 8bb9bf9597..9cb1b26527 100644 --- 
a/source/pip/qsharp/qre/models/qec/_yoked.py +++ b/source/pip/qsharp/qre/models/qec/_yoked.py @@ -7,7 +7,7 @@ from ..._instruction import ISATransform, constraint, LOGICAL from ..._qre import ISA, ISARequirements, generic_function -from ..._architecture import _Context +from ..._architecture import ISAContext from ...instruction_ids import LATTICE_SURGERY, MEMORY from ...property_keys import DISTANCE @@ -58,7 +58,9 @@ def required_isa() -> ISARequirements: constraint(LATTICE_SURGERY, LOGICAL, arity=None, distance=True), ) - def provided_isa(self, impl_isa: ISA, ctx: _Context) -> Generator[ISA, None, None]: + def provided_isa( + self, impl_isa: ISA, ctx: ISAContext + ) -> Generator[ISA, None, None]: lattice_surgery = impl_isa[LATTICE_SURGERY] distance = lattice_surgery.get_property(DISTANCE) assert distance is not None @@ -178,7 +180,9 @@ def required_isa() -> ISARequirements: constraint(LATTICE_SURGERY, LOGICAL, arity=None, distance=True), ) - def provided_isa(self, impl_isa: ISA, ctx: _Context) -> Generator[ISA, None, None]: + def provided_isa( + self, impl_isa: ISA, ctx: ISAContext + ) -> Generator[ISA, None, None]: lattice_surgery = impl_isa[LATTICE_SURGERY] distance = lattice_surgery.get_property(DISTANCE) assert distance is not None diff --git a/source/pip/qsharp/qre/models/qubits/__init__.py b/source/pip/qsharp/qre/models/qubits/__init__.py index 99c9e1c156..ab7887faf3 100644 --- a/source/pip/qsharp/qre/models/qubits/__init__.py +++ b/source/pip/qsharp/qre/models/qubits/__init__.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
-from ._aqre import AQREGateBased
+from ._gate_based import GateBased
 from ._msft import Majorana
 
-__all__ = ["AQREGateBased", "Majorana"]
+__all__ = ["GateBased", "Majorana"]
diff --git a/source/pip/qsharp/qre/models/qubits/_aqre.py b/source/pip/qsharp/qre/models/qubits/_gate_based.py
similarity index 86%
rename from source/pip/qsharp/qre/models/qubits/_aqre.py
rename to source/pip/qsharp/qre/models/qubits/_gate_based.py
index 6e6f09b8be..d9ee589485 100644
--- a/source/pip/qsharp/qre/models/qubits/_aqre.py
+++ b/source/pip/qsharp/qre/models/qubits/_gate_based.py
@@ -4,7 +4,7 @@
 from dataclasses import KW_ONLY, dataclass, field
 from typing import Optional
 
-from ..._architecture import Architecture, _Context
+from ..._architecture import Architecture, ISAContext
 from ..._instruction import ISA, Encoding
 from ...instruction_ids import (
     CNOT,
@@ -36,15 +36,10 @@
 
 
 @dataclass
-class AQREGateBased(Architecture):
+class GateBased(Architecture):
     """
-    A generic gate-based architecture based on the qubit parameters in Azure
-    Quantum Resource Estimator (AQRE,
-    [arXiv:2211.07629](https://arxiv.org/abs/2211.07629)). The error rate can
-    be set arbitrarily and is either 1e-3 or 1e-4 in the reference. Typical
-    gate times are 50ns and measurement times are 100ns for superconducting
-    transmon qubits
-    [arXiv:cond-mat/0703002](https://arxiv.org/abs/cond-mat/0703002).
+    A generic gate-based architecture. The error rate can be set
+    arbitrarily; typical values are 1e-3 or 1e-4.
 
     Args:
         error_rate: The error rate for all gates. Defaults to 1e-4.
@@ -76,7 +71,7 @@ def __post_init__(self): if self.two_qubit_gate_time is None: self.two_qubit_gate_time = self.gate_time - def provided_isa(self, ctx: _Context) -> ISA: + def provided_isa(self, ctx: ISAContext) -> ISA: # Value is initialized in __post_init__ assert self.two_qubit_gate_time is not None diff --git a/source/pip/qsharp/qre/models/qubits/_msft.py b/source/pip/qsharp/qre/models/qubits/_msft.py index 022157c1d4..1d74300e3e 100644 --- a/source/pip/qsharp/qre/models/qubits/_msft.py +++ b/source/pip/qsharp/qre/models/qubits/_msft.py @@ -3,7 +3,7 @@ from dataclasses import KW_ONLY, dataclass, field -from ..._architecture import Architecture, _Context +from ..._architecture import Architecture, ISAContext from ...instruction_ids import ( T, PREP_X, @@ -47,7 +47,7 @@ class Majorana(Architecture): _: KW_ONLY error_rate: float = field(default=1e-5, metadata={"domain": [1e-4, 1e-5, 1e-6]}) - def provided_isa(self, ctx: _Context) -> ISA: + def provided_isa(self, ctx: ISAContext) -> ISA: if abs(self.error_rate - 1e-4) <= 1e-8: t_error_rate = 0.05 elif abs(self.error_rate - 1e-5) <= 1e-8: diff --git a/source/pip/src/qre.rs b/source/pip/src/qre.rs index 23b5f6baf7..0e9daa1686 100644 --- a/source/pip/src/qre.rs +++ b/source/pip/src/qre.rs @@ -205,7 +205,7 @@ impl ISARequirementsIterator { } #[allow(clippy::unsafe_derive_deserialize)] -#[pyclass(name = "_Instruction")] +#[pyclass(from_py_object)] #[derive(Clone, Serialize, Deserialize)] #[serde(transparent)] pub struct Instruction(qre::Instruction); @@ -566,7 +566,7 @@ impl ConstraintBound { } #[derive(Clone)] -#[pyclass(name = "_ProvenanceGraph")] +#[pyclass(name = "_ProvenanceGraph", from_py_object)] pub struct ProvenanceGraph(Arc>); impl Default for ProvenanceGraph { diff --git a/source/pip/tests/qre/__init__.py b/source/pip/tests/qre/__init__.py new file mode 100644 index 0000000000..59e481eb93 --- /dev/null +++ b/source/pip/tests/qre/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Microsoft Corporation. 
+# Licensed under the MIT License. diff --git a/source/pip/tests/qre/conftest.py b/source/pip/tests/qre/conftest.py new file mode 100644 index 0000000000..c779e6ff31 --- /dev/null +++ b/source/pip/tests/qre/conftest.py @@ -0,0 +1,58 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from dataclasses import KW_ONLY, dataclass, field +from typing import Generator + +from qsharp.qre import ( + ISA, + LOGICAL, + ISARequirements, + ISATransform, + constraint, +) +from qsharp.qre._architecture import ISAContext +from qsharp.qre.instruction_ids import LATTICE_SURGERY, T + + +# NOTE These classes will be generalized as part of the QRE API in the following +# pull requests and then moved out of the tests. + + +@dataclass +class ExampleFactory(ISATransform): + _: KW_ONLY + level: int = field(default=1, metadata={"domain": range(1, 4)}) + + @staticmethod + def required_isa() -> ISARequirements: + return ISARequirements( + constraint(T), + ) + + def provided_isa( + self, impl_isa: ISA, ctx: ISAContext + ) -> Generator[ISA, None, None]: + yield ctx.make_isa( + ctx.add_instruction(T, encoding=LOGICAL, time=1000, error_rate=1e-8), + ) + + +@dataclass +class ExampleLogicalFactory(ISATransform): + _: KW_ONLY + level: int = field(default=1, metadata={"domain": range(1, 4)}) + + @staticmethod + def required_isa() -> ISARequirements: + return ISARequirements( + constraint(LATTICE_SURGERY, encoding=LOGICAL), + constraint(T, encoding=LOGICAL), + ) + + def provided_isa( + self, impl_isa: ISA, ctx: ISAContext + ) -> Generator[ISA, None, None]: + yield ctx.make_isa( + ctx.add_instruction(T, encoding=LOGICAL, time=1000, error_rate=1e-10), + ) diff --git a/source/pip/tests/qre/test_application.py b/source/pip/tests/qre/test_application.py new file mode 100644 index 0000000000..6b73222e12 --- /dev/null +++ b/source/pip/tests/qre/test_application.py @@ -0,0 +1,216 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +from dataclasses import dataclass, field + +import qsharp + +from qsharp.qre import ( + Application, + ISA, + LOGICAL, + PSSPC, + EstimationResult, + LatticeSurgery, + Trace, + linear_function, +) +from qsharp.qre._qre import _ProvenanceGraph +from qsharp.qre._enumeration import _enumerate_instances +from qsharp.qre.application import QSharpApplication +from qsharp.qre.instruction_ids import CCX, LATTICE_SURGERY, T, RZ +from qsharp.qre.property_keys import ( + ALGORITHM_COMPUTE_QUBITS, + ALGORITHM_MEMORY_QUBITS, + LOGICAL_COMPUTE_QUBITS, + LOGICAL_MEMORY_QUBITS, +) + + +def _assert_estimation_result(trace: Trace, result: EstimationResult, isa: ISA): + actual_qubits = ( + isa[LATTICE_SURGERY].expect_space(trace.compute_qubits) + + isa[T].expect_space() * result.factories[T].copies + ) + if CCX in trace.resource_states: + actual_qubits += isa[CCX].expect_space() * result.factories[CCX].copies + assert result.qubits == actual_qubits + + assert ( + result.runtime + == isa[LATTICE_SURGERY].expect_time(trace.compute_qubits) * trace.depth + ) + + actual_error = ( + trace.base_error + + isa[LATTICE_SURGERY].expect_error_rate(trace.compute_qubits) * trace.depth + + isa[T].expect_error_rate() * result.factories[T].states + ) + if CCX in trace.resource_states: + actual_error += isa[CCX].expect_error_rate() * result.factories[CCX].states + assert abs(result.error - actual_error) <= 1e-8 + + +def test_trace_properties(): + trace = Trace(42) + + INT = 0 + FLOAT = 1 + BOOL = 2 + STR = 3 + + trace.set_property(INT, 42) + assert trace.get_property(INT) == 42 + assert isinstance(trace.get_property(INT), int) + + trace.set_property(FLOAT, 3.14) + assert trace.get_property(FLOAT) == 3.14 + assert isinstance(trace.get_property(FLOAT), float) + + trace.set_property(BOOL, True) + assert trace.get_property(BOOL) is True + assert isinstance(trace.get_property(BOOL), bool) + + trace.set_property(STR, "hello") + assert trace.get_property(STR) == "hello" + assert 
isinstance(trace.get_property(STR), str) + + +def test_qsharp_application(): + code = """ + {{ + use (a, b, c) = (Qubit(), Qubit(), Qubit()); + T(a); + CCNOT(a, b, c); + Rz(1.2345, a); + }} + """ + + app = QSharpApplication(code) + trace = app.get_trace() + + assert trace.compute_qubits == 3 + assert trace.depth == 3 + assert trace.resource_states == {} + + assert {c.id for c in trace.required_isa} == {CCX, T, RZ} + + graph = _ProvenanceGraph() + isa = graph.make_isa( + [ + graph.add_instruction( + LATTICE_SURGERY, + encoding=LOGICAL, + arity=None, + time=1000, + space=linear_function(50), + error_rate=linear_function(1e-6), + ), + graph.add_instruction( + T, encoding=LOGICAL, time=1000, space=400, error_rate=1e-8 + ), + graph.add_instruction( + CCX, encoding=LOGICAL, time=2000, space=800, error_rate=1e-10 + ), + ] + ) + + # Properties from the program + counts = qsharp.logical_counts(code) + num_ts = counts["tCount"] + num_ccx = counts["cczCount"] + num_rotations = counts["rotationCount"] + rotation_depth = counts["rotationDepth"] + + lattice_surgery = LatticeSurgery() + + counter = 0 + for psspc in _enumerate_instances(PSSPC): + counter += 1 + trace2 = psspc.transform(trace) + assert trace2 is not None + trace2 = lattice_surgery.transform(trace2) + assert trace2 is not None + assert trace2.compute_qubits == 12 + assert ( + trace2.depth + == num_ts + + num_ccx * 3 + + num_rotations + + rotation_depth * psspc.num_ts_per_rotation + ) + if psspc.ccx_magic_states: + assert trace2.resource_states == { + T: num_ts + psspc.num_ts_per_rotation * num_rotations, + CCX: num_ccx, + } + assert {c.id for c in trace2.required_isa} == {CCX, T, LATTICE_SURGERY} + else: + assert trace2.resource_states == { + T: num_ts + psspc.num_ts_per_rotation * num_rotations + 4 * num_ccx + } + assert {c.id for c in trace2.required_isa} == {T, LATTICE_SURGERY} + assert trace2.get_property(ALGORITHM_COMPUTE_QUBITS) == 3 + assert trace2.get_property(ALGORITHM_MEMORY_QUBITS) == 0 + result = 
trace2.estimate(isa, max_error=float("inf")) + assert result is not None + assert result.properties[ALGORITHM_COMPUTE_QUBITS] == 3 + assert result.properties[ALGORITHM_MEMORY_QUBITS] == 0 + assert result.properties[LOGICAL_COMPUTE_QUBITS] == 12 + assert result.properties[LOGICAL_MEMORY_QUBITS] == 0 + _assert_estimation_result(trace2, result, isa) + assert counter == 32 + + +def test_application_enumeration(): + @dataclass(kw_only=True) + class _Params: + size: int = field(default=1, metadata={"domain": range(1, 4)}) + + class TestApp(Application[_Params]): + def get_trace(self, parameters: _Params) -> Trace: + return Trace(parameters.size) + + app = TestApp() + assert sum(1 for _ in TestApp.q().enumerate(app.context())) == 3 + assert sum(1 for _ in TestApp.q(size=1).enumerate(app.context())) == 1 + assert sum(1 for _ in TestApp.q(size=[4, 5]).enumerate(app.context())) == 2 + + +def test_trace_enumeration(): + code = """ + {{ + use (a, b, c) = (Qubit(), Qubit(), Qubit()); + T(a); + CCNOT(a, b, c); + Rz(1.2345, a); + }} + """ + + app = QSharpApplication(code) + + ctx = app.context() + assert sum(1 for _ in QSharpApplication.q().enumerate(ctx)) == 1 + + assert sum(1 for _ in PSSPC.q().enumerate(ctx)) == 32 + + assert sum(1 for _ in LatticeSurgery.q().enumerate(ctx)) == 1 + + q = PSSPC.q() * LatticeSurgery.q() + assert sum(1 for _ in q.enumerate(ctx)) == 32 + + +def test_rotation_error_psspc(): + # This test helps to bound the variables for the number of rotations in PSSPC + + # Create a trace with a single rotation gate and ensure that the base error + # after PSSPC transformation is less than 1. 
+ trace = Trace(1) + trace.add_operation(RZ, [0]) + + for psspc in _enumerate_instances(PSSPC, ccx_magic_states=False): + transformed = psspc.transform(trace) + assert transformed is not None + assert ( + transformed.base_error < 1.0 + ), f"Base error too high: {transformed.base_error} for {psspc.num_ts_per_rotation} T states per rotation" diff --git a/source/pip/tests/qre/test_enumeration.py b/source/pip/tests/qre/test_enumeration.py new file mode 100644 index 0000000000..476e65f22b --- /dev/null +++ b/source/pip/tests/qre/test_enumeration.py @@ -0,0 +1,527 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from dataclasses import KW_ONLY, dataclass, field +from enum import Enum +from typing import cast + +import pytest + +from qsharp.qre import LOGICAL +from qsharp.qre.models import SurfaceCode, GateBased +from qsharp.qre._isa_enumeration import ( + ISARefNode, + _ComponentQuery, + _ProductNode, + _SumNode, +) + +from .conftest import ExampleFactory, ExampleLogicalFactory + + +def test_enumerate_instances(): + from qsharp.qre._enumeration import _enumerate_instances + + instances = list(_enumerate_instances(SurfaceCode)) + + # There are 12 instances with distances from 3 to 25 + assert len(instances) == 12 + expected_distances = list(range(3, 26, 2)) + for instance, expected_distance in zip(instances, expected_distances): + assert instance.distance == expected_distance + + # Test with specific distances + instances = list(_enumerate_instances(SurfaceCode, distance=[3, 5, 7])) + assert len(instances) == 3 + expected_distances = [3, 5, 7] + for instance, expected_distance in zip(instances, expected_distances): + assert instance.distance == expected_distance + + # Test with fixed distance + instances = list(_enumerate_instances(SurfaceCode, distance=9)) + assert len(instances) == 1 + assert instances[0].distance == 9 + + +def test_enumerate_instances_bool(): + from qsharp.qre._enumeration import _enumerate_instances + + @dataclass + 
class BoolConfig: + _: KW_ONLY + flag: bool + + instances = list(_enumerate_instances(BoolConfig)) + assert len(instances) == 2 + assert instances[0].flag is True + assert instances[1].flag is False + + +def test_enumerate_instances_enum(): + from qsharp.qre._enumeration import _enumerate_instances + + class Color(Enum): + RED = 1 + GREEN = 2 + BLUE = 3 + + @dataclass + class EnumConfig: + _: KW_ONLY + color: Color + + instances = list(_enumerate_instances(EnumConfig)) + assert len(instances) == 3 + assert instances[0].color == Color.RED + assert instances[1].color == Color.GREEN + assert instances[2].color == Color.BLUE + + +def test_enumerate_instances_failure(): + from qsharp.qre._enumeration import _enumerate_instances + + @dataclass + class InvalidConfig: + _: KW_ONLY + # This field has no domain, is not bool/enum, and has no default + value: int + + with pytest.raises(ValueError, match="Cannot enumerate field value"): + list(_enumerate_instances(InvalidConfig)) + + +def test_enumerate_instances_single(): + from qsharp.qre._enumeration import _enumerate_instances + + @dataclass + class SingleConfig: + value: int = 42 + + instances = list(_enumerate_instances(SingleConfig)) + assert len(instances) == 1 + assert instances[0].value == 42 + + +def test_enumerate_instances_literal(): + from qsharp.qre._enumeration import _enumerate_instances + + from typing import Literal + + @dataclass + class LiteralConfig: + _: KW_ONLY + mode: Literal["fast", "slow"] + + instances = list(_enumerate_instances(LiteralConfig)) + assert len(instances) == 2 + assert instances[0].mode == "fast" + assert instances[1].mode == "slow" + + +def test_enumerate_instances_nested(): + from qsharp.qre._enumeration import _enumerate_instances + + @dataclass + class InnerConfig: + _: KW_ONLY + option: bool + + @dataclass + class OuterConfig: + _: KW_ONLY + inner: InnerConfig + + instances = list(_enumerate_instances(OuterConfig)) + assert len(instances) == 2 + assert instances[0].inner.option is 
True + assert instances[1].inner.option is False + + +def test_enumerate_instances_union(): + from qsharp.qre._enumeration import _enumerate_instances + + @dataclass + class OptionA: + _: KW_ONLY + value: bool + + @dataclass + class OptionB: + _: KW_ONLY + number: int = field(default=1, metadata={"domain": [1, 2, 3]}) + + @dataclass + class UnionConfig: + _: KW_ONLY + option: OptionA | OptionB + + instances = list(_enumerate_instances(UnionConfig)) + assert len(instances) == 5 + assert isinstance(instances[0].option, OptionA) + assert instances[0].option.value is True + assert isinstance(instances[2].option, OptionB) + assert instances[2].option.number == 1 + + +def test_enumerate_instances_nested_with_constraints(): + from qsharp.qre._enumeration import _enumerate_instances + + @dataclass + class InnerConfig: + _: KW_ONLY + option: bool + + @dataclass + class OuterConfig: + _: KW_ONLY + inner: InnerConfig + + # Constrain nested field via dict + instances = list(_enumerate_instances(OuterConfig, inner={"option": True})) + assert len(instances) == 1 + assert instances[0].inner.option is True + + +def test_enumerate_instances_union_single_type(): + from qsharp.qre._enumeration import _enumerate_instances + + @dataclass + class OptionA: + _: KW_ONLY + value: bool + + @dataclass + class OptionB: + _: KW_ONLY + number: int = field(default=1, metadata={"domain": [1, 2, 3]}) + + @dataclass + class UnionConfig: + _: KW_ONLY + option: OptionA | OptionB + + # Restrict to OptionB only - uses its default domain + instances = list(_enumerate_instances(UnionConfig, option=OptionB)) + assert len(instances) == 3 + assert all(isinstance(i.option, OptionB) for i in instances) + assert [cast(OptionB, i.option).number for i in instances] == [1, 2, 3] + + # Restrict to OptionA only + instances = list(_enumerate_instances(UnionConfig, option=OptionA)) + assert len(instances) == 2 + assert all(isinstance(i.option, OptionA) for i in instances) + assert cast(OptionA, 
instances[0].option).value is True + assert cast(OptionA, instances[1].option).value is False + + +def test_enumerate_instances_union_list_of_types(): + from qsharp.qre._enumeration import _enumerate_instances + + @dataclass + class OptionA: + _: KW_ONLY + value: bool + + @dataclass + class OptionB: + _: KW_ONLY + number: int = field(default=1, metadata={"domain": [1, 2, 3]}) + + @dataclass + class OptionC: + _: KW_ONLY + flag: bool + + @dataclass + class UnionConfig: + _: KW_ONLY + option: OptionA | OptionB | OptionC + + # Select a subset: only OptionA and OptionB + instances = list(_enumerate_instances(UnionConfig, option=[OptionA, OptionB])) + assert len(instances) == 5 # 2 from OptionA + 3 from OptionB + assert all(isinstance(i.option, (OptionA, OptionB)) for i in instances) + + +def test_enumerate_instances_union_constraint_dict(): + from qsharp.qre._enumeration import _enumerate_instances + + @dataclass + class OptionA: + _: KW_ONLY + value: bool + + @dataclass + class OptionB: + _: KW_ONLY + number: int = field(default=1, metadata={"domain": [1, 2, 3]}) + + @dataclass + class UnionConfig: + _: KW_ONLY + option: OptionA | OptionB + + # Constrain OptionA, enumerate only that member + instances = list( + _enumerate_instances(UnionConfig, option={OptionA: {"value": True}}) + ) + assert len(instances) == 1 + assert isinstance(instances[0].option, OptionA) + assert instances[0].option.value is True + + # Constrain OptionB with a domain, enumerate only that member + instances = list( + _enumerate_instances(UnionConfig, option={OptionB: {"number": [2, 3]}}) + ) + assert len(instances) == 2 + assert all(isinstance(i.option, OptionB) for i in instances) + assert cast(OptionB, instances[0].option).number == 2 + assert cast(OptionB, instances[1].option).number == 3 + + # Constrain one member and keep another with defaults + instances = list( + _enumerate_instances( + UnionConfig, + option={OptionA: {"value": True}, OptionB: {}}, + ) + ) + assert len(instances) == 4 # 1 
from OptionA + 3 from OptionB + assert isinstance(instances[0].option, OptionA) + assert instances[0].option.value is True + assert all(isinstance(i.option, OptionB) for i in instances[1:]) + assert [cast(OptionB, i.option).number for i in instances[1:]] == [1, 2, 3] + + +def test_enumerate_isas(): + ctx = GateBased(gate_time=50, measurement_time=100).context() + + # This will enumerate the 12 ISAs for the error correction code + count = sum(1 for _ in SurfaceCode.q().enumerate(ctx)) + assert count == 12 + + # This will enumerate the 2 ISAs for the error correction code when + # restricting the domain + count = sum(1 for _ in SurfaceCode.q(distance=[3, 4]).enumerate(ctx)) + assert count == 2 + + # This will enumerate the 3 ISAs for the factory + count = sum(1 for _ in ExampleFactory.q().enumerate(ctx)) + assert count == 3 + + # This will enumerate 36 ISAs for all products between the 12 error + # correction code ISAs and the 3 factory ISAs + count = sum(1 for _ in (SurfaceCode.q() * ExampleFactory.q()).enumerate(ctx)) + assert count == 36 + + # When providing a list, components are chained (OR operation). This + # enumerates ISAs from first factory instance OR second factory instance + count = sum( + 1 + for _ in ( + SurfaceCode.q() * (ExampleFactory.q() + ExampleFactory.q()) + ).enumerate(ctx) + ) + assert count == 72 + + # When providing separate arguments, components are combined via product + # (AND). 
This enumerates ISAs from first factory instance AND second + # factory instance + count = sum( + 1 + for _ in (SurfaceCode.q() * ExampleFactory.q() * ExampleFactory.q()).enumerate( + ctx + ) + ) + assert count == 108 + + # Hierarchical factory using from_components: the component receives ISAs + # from the product of other components as its source + count = sum( + 1 + for _ in ( + SurfaceCode.q() + * ExampleLogicalFactory.q(source=(SurfaceCode.q() * ExampleFactory.q())) + ).enumerate(ctx) + ) + assert count == 1296 + + +def test_binding_node(): + """Test binding nodes with ISARefNode for component bindings""" + ctx = GateBased(gate_time=50, measurement_time=100).context() + + # Test basic binding: same code used twice + # Without binding: 12 codes × 12 codes = 144 combinations + count_without = sum(1 for _ in (SurfaceCode.q() * SurfaceCode.q()).enumerate(ctx)) + assert count_without == 144 + + # With binding: 12 codes (same instance used twice) + count_with = sum( + 1 + for _ in SurfaceCode.bind("c", ISARefNode("c") * ISARefNode("c")).enumerate(ctx) + ) + assert count_with == 12 + + # Verify the binding works: with binding, both should use same params + for isa in SurfaceCode.bind("c", ISARefNode("c") * ISARefNode("c")).enumerate(ctx): + logical_gates = [g for g in isa if g.encoding == LOGICAL] + # Should have 1 logical gate (LATTICE_SURGERY) + assert len(logical_gates) == 1 + + # Test binding with factories (nested bindings) + count_without = sum( + 1 + for _ in ( + SurfaceCode.q() * ExampleFactory.q() * SurfaceCode.q() * ExampleFactory.q() + ).enumerate(ctx) + ) + assert count_without == 1296 # 12 * 3 * 12 * 3 + + count_with = sum( + 1 + for _ in SurfaceCode.bind( + "c", + ExampleFactory.bind( + "f", + ISARefNode("c") * ISARefNode("f") * ISARefNode("c") * ISARefNode("f"), + ), + ).enumerate(ctx) + ) + assert count_with == 36 # 12 * 3 + + # Test binding with from_components equivalent (hierarchical) + # Without binding: 12 outer codes × (12 inner codes × 3 
factories × 3 levels) + count_without = sum( + 1 + for _ in ( + SurfaceCode.q() + * ExampleLogicalFactory.q( + source=(SurfaceCode.q() * ExampleFactory.q()), + ) + ).enumerate(ctx) + ) + assert count_without == 1296 # 12 * 12 * 3 * 3 + + # With binding: 12 codes (same used twice) × 3 factories × 3 levels + count_with = sum( + 1 + for _ in SurfaceCode.bind( + "c", + ISARefNode("c") + * ExampleLogicalFactory.q( + source=(ISARefNode("c") * ExampleFactory.q()), + ), + ).enumerate(ctx) + ) + assert count_with == 108 # 12 * 3 * 3 + + # Test binding with kwargs + count_with_kwargs = sum( + 1 + for _ in SurfaceCode.q(distance=5) + .bind("c", ISARefNode("c") * ISARefNode("c")) + .enumerate(ctx) + ) + assert count_with_kwargs == 1 # Only distance=5 + + # Verify kwargs are applied + for isa in ( + SurfaceCode.q(distance=5) + .bind("c", ISARefNode("c") * ISARefNode("c")) + .enumerate(ctx) + ): + logical_gates = [g for g in isa if g.encoding == LOGICAL] + assert all(g.space(1) == 49 for g in logical_gates) + + # Test multiple independent bindings (nested) + count = sum( + 1 + for _ in SurfaceCode.bind( + "c1", + ExampleFactory.bind( + "c2", + ISARefNode("c1") + * ISARefNode("c1") + * ISARefNode("c2") + * ISARefNode("c2"), + ), + ).enumerate(ctx) + ) + # 12 codes for c1 × 3 factories for c2 + assert count == 36 + + +def test_binding_node_errors(): + """Test error handling for binding nodes""" + ctx = GateBased(gate_time=50, measurement_time=100).context() + + # Test ISARefNode enumerate with undefined binding raises ValueError + try: + list(ISARefNode("test").enumerate(ctx)) + assert False, "Should have raised ValueError" + except ValueError as e: + assert "Undefined component reference: 'test'" in str(e) + + +def test_product_isa_enumeration_nodes(): + terminal = SurfaceCode.q() + query = terminal * terminal + + # Multiplication should create ProductNode + assert isinstance(query, _ProductNode) + assert len(query.sources) == 2 + for source in query.sources: + assert 
isinstance(source, _ComponentQuery) + + # Multiplying again should extend the sources + query = query * terminal + assert isinstance(query, _ProductNode) + assert len(query.sources) == 3 + for source in query.sources: + assert isinstance(source, _ComponentQuery) + + # Also from the other side + query = terminal * query + assert isinstance(query, _ProductNode) + assert len(query.sources) == 4 + for source in query.sources: + assert isinstance(source, _ComponentQuery) + + # Also for two ProductNodes + query = query * query + assert isinstance(query, _ProductNode) + assert len(query.sources) == 8 + for source in query.sources: + assert isinstance(source, _ComponentQuery) + + +def test_sum_isa_enumeration_nodes(): + terminal = SurfaceCode.q() + query = terminal + terminal + + # Addition should create SumNode + assert isinstance(query, _SumNode) + assert len(query.sources) == 2 + for source in query.sources: + assert isinstance(source, _ComponentQuery) + + # Adding again should extend the sources + query = query + terminal + assert isinstance(query, _SumNode) + assert len(query.sources) == 3 + for source in query.sources: + assert isinstance(source, _ComponentQuery) + + # Also from the other side + query = terminal + query + assert isinstance(query, _SumNode) + assert len(query.sources) == 4 + for source in query.sources: + assert isinstance(source, _ComponentQuery) + + # Also for two SumNodes + query = query + query + assert isinstance(query, _SumNode) + assert len(query.sources) == 8 + for source in query.sources: + assert isinstance(source, _ComponentQuery) diff --git a/source/pip/tests/qre/test_estimation.py b/source/pip/tests/qre/test_estimation.py new file mode 100644 index 0000000000..bb857115ed --- /dev/null +++ b/source/pip/tests/qre/test_estimation.py @@ -0,0 +1,98 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +import os + +import pytest + +from qsharp.estimator import LogicalCounts +from qsharp.qre import ( + PSSPC, + LatticeSurgery, + estimate, +) +from qsharp.qre.application import QSharpApplication +from qsharp.qre.models import ( + SurfaceCode, + GateBased, + RoundBasedFactory, + TwoDimensionalYokedSurfaceCode, +) + +from .conftest import ExampleFactory + + +def test_estimation_max_error(): + app = QSharpApplication(LogicalCounts({"numQubits": 100, "measurementCount": 100})) + arch = GateBased(gate_time=50, measurement_time=100) + + for max_error in [1e-1, 1e-2, 1e-3, 1e-4]: + results = estimate( + app, + arch, + SurfaceCode.q() * ExampleFactory.q(), + PSSPC.q() * LatticeSurgery.q(), + max_error=max_error, + ) + + assert len(results) == 1 + assert next(iter(results)).error <= max_error + + +@pytest.mark.skipif( + "SLOW_TESTS" not in os.environ, + reason="turn on slow tests by setting SLOW_TESTS=1 in the environment", +) +@pytest.mark.parametrize( + "post_process, use_graph", + [ + (False, False), + (True, False), + (False, True), + (True, True), + ], +) +def test_estimation_methods(post_process, use_graph): + counts = LogicalCounts( + { + "numQubits": 1000, + "tCount": 1_500_000, + "rotationCount": 0, + "rotationDepth": 0, + "cczCount": 1_000_000_000, + "ccixCount": 0, + "measurementCount": 25_000_000, + "numComputeQubits": 200, + "readFromMemoryCount": 30_000_000, + "writeToMemoryCount": 30_000_000, + } + ) + + trace_query = PSSPC.q() * LatticeSurgery.q(slow_down_factor=[1.0, 2.0]) + isa_query = ( + SurfaceCode.q() + * RoundBasedFactory.q() + * TwoDimensionalYokedSurfaceCode.q(source=SurfaceCode.q()) + ) + + app = QSharpApplication(counts) + arch = GateBased(gate_time=50, measurement_time=100) + + results = estimate( + app, + arch, + isa_query, + trace_query, + max_error=1 / 3, + post_process=post_process, + use_graph=use_graph, + ) + results.add_factory_summary_column() + + assert [(result.qubits, result.runtime) for result in results] == [ + (238707, 
23997050000000), + (240407, 11998525000000), + ] + + print() + print(results.stats) diff --git a/source/pip/tests/qre/test_estimation_table.py b/source/pip/tests/qre/test_estimation_table.py new file mode 100644 index 0000000000..d2a25ae31b --- /dev/null +++ b/source/pip/tests/qre/test_estimation_table.py @@ -0,0 +1,439 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from typing import cast, Sized + +import pytest +import pandas as pd + +from qsharp.qre import ( + PSSPC, + LatticeSurgery, + estimate, +) +from qsharp.qre.application import QSharpApplication +from qsharp.qre.models import SurfaceCode, GateBased +from qsharp.qre._results import ( + EstimationTable, + EstimationTableEntry, +) +from qsharp.qre._instruction import InstructionSource +from qsharp.qre.instruction_ids import LATTICE_SURGERY +from qsharp.qre.property_keys import DISTANCE, NUM_TS_PER_ROTATION + +from .conftest import ExampleFactory + + +def _make_entry(qubits, runtime, error, properties=None): + """Helper to create an EstimationTableEntry with a dummy InstructionSource.""" + return EstimationTableEntry( + qubits=qubits, + runtime=runtime, + error=error, + source=InstructionSource(), + properties=properties or {}, + ) + + +def test_estimation_table_default_columns(): + """Test that a new EstimationTable has the three default columns.""" + table = EstimationTable() + table.append(_make_entry(100, 5000, 0.01)) + + frame = table.as_frame() + assert list(frame.columns) == ["qubits", "runtime", "error"] + assert frame["qubits"][0] == 100 + assert frame["runtime"][0] == pd.Timedelta(5000, unit="ns") + assert frame["error"][0] == 0.01 + + +def test_estimation_table_multiple_rows(): + """Test as_frame with multiple entries.""" + table = EstimationTable() + table.append(_make_entry(100, 5000, 0.01)) + table.append(_make_entry(200, 10000, 0.02)) + table.append(_make_entry(300, 15000, 0.03)) + + frame = table.as_frame() + assert len(frame) == 3 + assert 
list(frame["qubits"]) == [100, 200, 300] + assert list(frame["error"]) == [0.01, 0.02, 0.03] + + +def test_estimation_table_empty(): + """Test as_frame with no entries produces an empty DataFrame.""" + table = EstimationTable() + frame = table.as_frame() + assert len(frame) == 0 + + +def test_estimation_table_add_column(): + """Test adding a column to the table.""" + VAL = 0 + + table = EstimationTable() + table.append(_make_entry(100, 5000, 0.01, properties={VAL: 42})) + table.append(_make_entry(200, 10000, 0.02, properties={VAL: 84})) + + table.add_column("val", lambda e: e.properties[VAL]) + + frame = table.as_frame() + assert list(frame.columns) == ["qubits", "runtime", "error", "val"] + assert list(frame["val"]) == [42, 84] + + +def test_estimation_table_add_column_with_formatter(): + """Test adding a column with a formatter.""" + NS = 0 + + table = EstimationTable() + table.append(_make_entry(100, 5000, 0.01, properties={NS: 1000})) + + table.add_column( + "duration", + lambda e: e.properties[NS], + formatter=lambda x: pd.Timedelta(x, unit="ns"), + ) + + frame = table.as_frame() + assert frame["duration"][0] == pd.Timedelta(1000, unit="ns") + + +def test_estimation_table_add_multiple_columns(): + """Test adding multiple columns preserves order.""" + A = 0 + B = 1 + C = 2 + + table = EstimationTable() + table.append(_make_entry(100, 5000, 0.01, properties={A: 1, B: 2, C: 3})) + + table.add_column("a", lambda e: e.properties[A]) + table.add_column("b", lambda e: e.properties[B]) + table.add_column("c", lambda e: e.properties[C]) + + frame = table.as_frame() + assert list(frame.columns) == ["qubits", "runtime", "error", "a", "b", "c"] + assert frame["a"][0] == 1 + assert frame["b"][0] == 2 + assert frame["c"][0] == 3 + + +def test_estimation_table_insert_column_at_beginning(): + """Test inserting a column at index 0.""" + NAME = 0 + + table = EstimationTable() + table.append(_make_entry(100, 5000, 0.01, properties={NAME: "test"})) + + table.insert_column(0, 
"name", lambda e: e.properties[NAME]) + + frame = table.as_frame() + assert list(frame.columns) == ["name", "qubits", "runtime", "error"] + assert frame["name"][0] == "test" + + +def test_estimation_table_insert_column_in_middle(): + """Test inserting a column between existing default columns.""" + EXTRA = 0 + + table = EstimationTable() + table.append(_make_entry(100, 5000, 0.01, properties={EXTRA: 99})) + + # Insert between qubits and runtime (index 1) + table.insert_column(1, "extra", lambda e: e.properties[EXTRA]) + + frame = table.as_frame() + assert list(frame.columns) == ["qubits", "extra", "runtime", "error"] + assert frame["extra"][0] == 99 + + +def test_estimation_table_insert_column_at_end(): + """Test inserting a column at the end (same effect as add_column).""" + LAST = 0 + + table = EstimationTable() + table.append(_make_entry(100, 5000, 0.01, properties={LAST: True})) + + # 3 default columns, inserting at index 3 = end + table.insert_column(3, "last", lambda e: e.properties[LAST]) + + frame = table.as_frame() + assert list(frame.columns) == ["qubits", "runtime", "error", "last"] + assert frame["last"][0] + + +def test_estimation_table_insert_column_with_formatter(): + """Test inserting a column with a formatter.""" + NS = 0 + + table = EstimationTable() + table.append(_make_entry(100, 5000, 0.01, properties={NS: 2000})) + + table.insert_column( + 0, + "custom_time", + lambda e: e.properties[NS], + formatter=lambda x: pd.Timedelta(x, unit="ns"), + ) + + frame = table.as_frame() + assert frame["custom_time"][0] == pd.Timedelta(2000, unit="ns") + assert list(frame.columns)[0] == "custom_time" + + +def test_estimation_table_insert_and_add_columns(): + """Test combining insert_column and add_column.""" + A = 0 + B = 0 + + table = EstimationTable() + table.append(_make_entry(100, 5000, 0.01, properties={A: 1, B: 2})) + + table.add_column("b", lambda e: e.properties[B]) + table.insert_column(0, "a", lambda e: e.properties[A]) + + frame = table.as_frame() + 
assert list(frame.columns) == ["a", "qubits", "runtime", "error", "b"] + + +def test_estimation_table_factory_summary_no_factories(): + """Test factory summary column when entries have no factories.""" + table = EstimationTable() + table.append(_make_entry(100, 5000, 0.01)) + + table.add_factory_summary_column() + + frame = table.as_frame() + assert "factories" in frame.columns + assert frame["factories"][0] == "None" + + +def test_estimation_table_factory_summary_with_estimation(): + """Test factory summary column with real estimation results.""" + code = """ + { + use (a, b, c) = (Qubit(), Qubit(), Qubit()); + T(a); + CCNOT(a, b, c); + Rz(1.2345, a); + } + """ + app = QSharpApplication(code) + arch = GateBased(gate_time=50, measurement_time=100) + results = estimate( + app, + arch, + SurfaceCode.q() * ExampleFactory.q(), + PSSPC.q() * LatticeSurgery.q(), + max_error=0.5, + ) + + assert len(results) >= 1 + + results.add_factory_summary_column() + frame = results.as_frame() + + assert "factories" in frame.columns + # Each result should mention T in the factory summary + for val in frame["factories"]: + assert "T" in val + + +def test_estimation_table_add_column_from_source(): + """Test adding a column that accesses the InstructionSource (like distance).""" + code = """ + { + use (a, b, c) = (Qubit(), Qubit(), Qubit()); + T(a); + CCNOT(a, b, c); + Rz(1.2345, a); + } + """ + app = QSharpApplication(code) + arch = GateBased(gate_time=50, measurement_time=100) + results = estimate( + app, + arch, + SurfaceCode.q() * ExampleFactory.q(), + PSSPC.q() * LatticeSurgery.q(), + max_error=0.5, + ) + + assert len(results) >= 1 + + results.add_column( + "compute_distance", + lambda entry: entry.source[LATTICE_SURGERY].instruction[DISTANCE], + ) + + frame = results.as_frame() + assert "compute_distance" in frame.columns + for d in frame["compute_distance"]: + assert isinstance(d, int) + assert d >= 3 + + +def test_estimation_table_add_column_from_properties(): + """Test adding 
columns that access trace properties from estimation.""" + code = """ + { + use (a, b, c) = (Qubit(), Qubit(), Qubit()); + T(a); + CCNOT(a, b, c); + Rz(1.2345, a); + } + """ + app = QSharpApplication(code) + arch = GateBased(gate_time=50, measurement_time=100) + results = estimate( + app, + arch, + SurfaceCode.q() * ExampleFactory.q(), + PSSPC.q() * LatticeSurgery.q(), + max_error=0.5, + ) + + assert len(results) >= 1 + + results.add_column( + "num_ts_per_rotation", + lambda entry: entry.properties[NUM_TS_PER_ROTATION], + ) + + frame = results.as_frame() + assert "num_ts_per_rotation" in frame.columns + for val in frame["num_ts_per_rotation"]: + assert isinstance(val, int) + assert val >= 1 + + +def test_estimation_table_insert_column_before_defaults(): + """Test inserting a name column before all default columns, similar to the factoring notebook.""" + code = """ + { + use (a, b, c) = (Qubit(), Qubit(), Qubit()); + T(a); + CCNOT(a, b, c); + Rz(1.2345, a); + } + """ + app = QSharpApplication(code) + arch = GateBased(gate_time=50, measurement_time=100) + results = estimate( + app, + arch, + SurfaceCode.q() * ExampleFactory.q(), + PSSPC.q() * LatticeSurgery.q(), + max_error=0.5, + name="test_experiment", + ) + + assert len(results) >= 1 + + # Add a factory summary at the end + results.add_factory_summary_column() + + frame = results.as_frame() + assert frame.columns[0] == "name" + assert frame.columns[-1] == "factories" + # Default columns should still be in order + assert list(frame.columns[1:4]) == ["qubits", "runtime", "error"] + + +def test_estimation_table_as_frame_sortable(): + """Test that the DataFrame from as_frame can be sorted, as done in the factoring tests.""" + table = EstimationTable() + table.append(_make_entry(300, 15000, 0.03)) + table.append(_make_entry(100, 5000, 0.01)) + table.append(_make_entry(200, 10000, 0.02)) + + frame = table.as_frame() + sorted_frame = frame.sort_values(by=["qubits", "runtime"]).reset_index(drop=True) + + assert 
list(sorted_frame["qubits"]) == [100, 200, 300] + assert list(sorted_frame["error"]) == [0.01, 0.02, 0.03] + + +def test_estimation_table_computed_column(): + """Test adding a column that computes a derived value from the entry.""" + table = EstimationTable() + table.append(_make_entry(100, 5_000_000, 0.01)) + table.append(_make_entry(200, 10_000_000, 0.02)) + + # Compute qubits * error as a derived metric + table.add_column("qubit_error_product", lambda e: e.qubits * e.error) + + frame = table.as_frame() + assert frame["qubit_error_product"][0] == pytest.approx(1.0) + assert frame["qubit_error_product"][1] == pytest.approx(4.0) + + +def test_estimation_table_plot_returns_figure(): + """Test that plot() returns a matplotlib Figure with correct axes.""" + from matplotlib.figure import Figure + + table = EstimationTable() + table.append(_make_entry(100, 5_000_000_000, 0.01)) + table.append(_make_entry(200, 10_000_000_000, 0.02)) + table.append(_make_entry(50, 50_000_000_000, 0.005)) + + fig = table.plot() + + assert isinstance(fig, Figure) + ax = fig.axes[0] + assert ax.get_ylabel() == "Physical qubits" + assert ax.get_xlabel() == "Runtime" + assert ax.get_xscale() == "log" + assert ax.get_yscale() == "log" + + # Verify data points + offsets = ax.collections[0].get_offsets() + assert len(cast(Sized, offsets)) == 3 + + +def test_estimation_table_plot_empty_raises(): + """Test that plot() raises ValueError on an empty table.""" + table = EstimationTable() + with pytest.raises(ValueError, match="Cannot plot an empty EstimationTable"): + table.plot() + + +def test_estimation_table_plot_single_entry(): + """Test that plot() works with a single entry.""" + from matplotlib.figure import Figure + + table = EstimationTable() + table.append(_make_entry(100, 1_000_000, 0.01)) + + fig = table.plot() + assert isinstance(fig, Figure) + + offsets = fig.axes[0].collections[0].get_offsets() + assert len(cast(Sized, offsets)) == 1 + + +def 
test_estimation_table_plot_with_runtime_unit(): + """Test that plot(runtime_unit=...) scales x values and labels the axis.""" + table = EstimationTable() + # 1 hour = 3600e9 ns, 2 hours = 7200e9 ns + table.append(_make_entry(100, int(3600e9), 0.01)) + table.append(_make_entry(200, int(7200e9), 0.02)) + + fig = table.plot(runtime_unit="hours") + + ax = fig.axes[0] + assert ax.get_xlabel() == "Runtime (hours)" + + # Verify the x data is scaled: should be 1.0 and 2.0 hours + offsets = cast(list, ax.collections[0].get_offsets()) + assert offsets[0][0] == pytest.approx(1.0) + assert offsets[1][0] == pytest.approx(2.0) + + +def test_estimation_table_plot_invalid_runtime_unit(): + """Test that plot() raises ValueError for an unknown runtime_unit.""" + table = EstimationTable() + table.append(_make_entry(100, 1000, 0.01)) + with pytest.raises(ValueError, match="Unknown runtime_unit"): + table.plot(runtime_unit="fortnights") diff --git a/source/pip/tests/qre/test_interop.py b/source/pip/tests/qre/test_interop.py new file mode 100644 index 0000000000..a8f7900abb --- /dev/null +++ b/source/pip/tests/qre/test_interop.py @@ -0,0 +1,213 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from pathlib import Path + +import pytest + +from qsharp.qre.interop import trace_from_qir + + +def _ll_files(): + ll_dir = ( + Path(__file__).parent.parent.parent + / "tests-integration" + / "resources" + / "adaptive_ri" + / "output" + ) + return sorted(ll_dir.glob("*.ll")) + + +@pytest.mark.parametrize("ll_file", _ll_files(), ids=lambda p: p.stem) +def test_trace_from_qir(ll_file): + # NOTE: This test is primarily to ensure that the function can parse real + # QIR output without errors, rather than checking specific properties of the + # trace. 
+ try: + trace_from_qir(ll_file.read_text()) + except ValueError as e: + # The only reason for failure is the presence of control flow + assert ( + str(e) + == "simulation of programs with branching control flow is not supported" + ) + + + def test_trace_from_qir_handles_all_instruction_ids(): + """Verify that trace_from_qir handles every QirInstructionId except CorrelatedNoise. + + Generates a synthetic QIR program containing one instance of each gate + intrinsic recognised by AggregateGatesPass and asserts that trace_from_qir + processes all of them without error. + """ + import pyqir + import pyqir.qis as qis + from qsharp._native import QirInstructionId + from qsharp.qre.interop._qir import _GATE_MAP, _MEAS_MAP, _SKIP + + # -- Completeness check: every QirInstructionId must be covered -------- + handled_ids = ( + [qir_id for qir_id, _, _ in _GATE_MAP] + + [qir_id for qir_id, _ in _MEAS_MAP] + + list(_SKIP) + ) + # Exhaustive list of all QirInstructionId variants (pyo3 enums are not iterable) + all_ids = [ + QirInstructionId.I, + QirInstructionId.H, + QirInstructionId.X, + QirInstructionId.Y, + QirInstructionId.Z, + QirInstructionId.S, + QirInstructionId.SAdj, + QirInstructionId.SX, + QirInstructionId.SXAdj, + QirInstructionId.T, + QirInstructionId.TAdj, + QirInstructionId.CNOT, + QirInstructionId.CX, + QirInstructionId.CY, + QirInstructionId.CZ, + QirInstructionId.CCX, + QirInstructionId.SWAP, + QirInstructionId.RX, + QirInstructionId.RY, + QirInstructionId.RZ, + QirInstructionId.RXX, + QirInstructionId.RYY, + QirInstructionId.RZZ, + QirInstructionId.RESET, + QirInstructionId.M, + QirInstructionId.MResetZ, + QirInstructionId.MZ, + QirInstructionId.Move, + QirInstructionId.ReadResult, + QirInstructionId.ResultRecordOutput, + QirInstructionId.BoolRecordOutput, + QirInstructionId.IntRecordOutput, + QirInstructionId.DoubleRecordOutput, + QirInstructionId.TupleRecordOutput, + QirInstructionId.ArrayRecordOutput, + QirInstructionId.CorrelatedNoise, + ] + unhandled = [ + i + 
for i in all_ids + if i not in handled_ids and i != QirInstructionId.CorrelatedNoise + ] + assert unhandled == [], ( + f"QirInstructionId values not covered by _GATE_MAP, _MEAS_MAP, or _SKIP: " + f"{', '.join(str(i) for i in unhandled)}" + ) + + # -- Generate a QIR program with every producible gate ----------------- + simple = pyqir.SimpleModule("test_all_gates", num_qubits=4, num_results=3) + builder = simple.builder + ctx = simple.context + q = simple.qubits + r = simple.results + + void_ty = pyqir.Type.void(ctx) + qubit_ty = pyqir.qubit_type(ctx) + result_ty = pyqir.result_type(ctx) + double_ty = pyqir.Type.double(ctx) + i64_ty = pyqir.IntType(ctx, 64) + + def declare(name, param_types): + return simple.add_external_function( + name, pyqir.FunctionType(void_ty, param_types) + ) + + # Single-qubit gates (pyqir.qis builtins) + qis.h(builder, q[0]) + qis.x(builder, q[0]) + qis.y(builder, q[0]) + qis.z(builder, q[0]) + qis.s(builder, q[0]) + qis.s_adj(builder, q[0]) + qis.t(builder, q[0]) + qis.t_adj(builder, q[0]) + + # SX — not in pyqir.qis + sx_fn = declare("__quantum__qis__sx__body", [qubit_ty]) + builder.call(sx_fn, [q[0]]) + + # Two-qubit gates (qis.cx emits __quantum__qis__cnot__body which the + # pass does not handle, so use builder.call with the correct name) + cx_fn = declare("__quantum__qis__cx__body", [qubit_ty, qubit_ty]) + builder.call(cx_fn, [q[0], q[1]]) + qis.cz(builder, q[0], q[1]) + qis.swap(builder, q[0], q[1]) + + cy_fn = declare("__quantum__qis__cy__body", [qubit_ty, qubit_ty]) + builder.call(cy_fn, [q[0], q[1]]) + + # Three-qubit gate + qis.ccx(builder, q[0], q[1], q[2]) + + # Single-qubit rotations + qis.rx(builder, 1.0, q[0]) + qis.ry(builder, 1.0, q[0]) + qis.rz(builder, 1.0, q[0]) + + # Two-qubit rotations — not in pyqir.qis + rot2_ty = [double_ty, qubit_ty, qubit_ty] + angle = pyqir.const(double_ty, 1.0) + for name in ("rxx", "ryy", "rzz"): + fn = declare(f"__quantum__qis__{name}__body", rot2_ty) + builder.call(fn, [angle, q[0], q[1]]) + 
+ # Measurements + qis.mz(builder, q[0], r[0]) + + m_fn = declare("__quantum__qis__m__body", [qubit_ty, result_ty]) + builder.call(m_fn, [q[1], r[1]]) + + mresetz_fn = declare("__quantum__qis__mresetz__body", [qubit_ty, result_ty]) + builder.call(mresetz_fn, [q[2], r[2]]) + + # Reset / Move + qis.reset(builder, q[0]) + + move_fn = declare("__quantum__qis__move__body", [qubit_ty]) + builder.call(move_fn, [q[0]]) + + # Output recording + tag = simple.add_byte_string(b"tag") + arr_fn = declare("__quantum__rt__array_record_output", [i64_ty, tag.type]) + builder.call(arr_fn, [pyqir.const(i64_ty, 1), tag]) + + rec_fn = declare("__quantum__rt__result_record_output", [result_ty, tag.type]) + builder.call(rec_fn, [r[0], tag]) + + tup_fn = declare("__quantum__rt__tuple_record_output", [i64_ty, tag.type]) + builder.call(tup_fn, [pyqir.const(i64_ty, 1), tag]) + + # -- Run trace_from_qir and verify it succeeds ------------------------- + trace = trace_from_qir(simple.ir()) + assert trace is not None + + +def test_rotation_buckets(): + from qsharp.qre.interop._qsharp import _bucketize_rotation_counts + + print() + + r_count = 15066 + r_depth = 14756 + q_count = 291 + + result = _bucketize_rotation_counts(r_count, r_depth) + + a_count = 0 + a_depth = 0 + for c, d in result: + print(c, d) + assert c <= q_count + assert c > 0 + a_count += c * d + a_depth += d + + assert a_count == r_count + assert a_depth == r_depth diff --git a/source/pip/tests/qre/test_isa.py b/source/pip/tests/qre/test_isa.py new file mode 100644 index 0000000000..6c1e8e318a --- /dev/null +++ b/source/pip/tests/qre/test_isa.py @@ -0,0 +1,181 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +import pytest + +from qsharp.qre import ( + LOGICAL, + ISARequirements, + constraint, + generic_function, + property_name, + property_name_to_key, +) +from qsharp.qre._qre import _ProvenanceGraph +from qsharp.qre.models import SurfaceCode, GateBased +from qsharp.qre._architecture import _make_instruction +from qsharp.qre.instruction_ids import CCX, CCZ, LATTICE_SURGERY, T +from qsharp.qre.property_keys import DISTANCE + + +def test_isa(): + graph = _ProvenanceGraph() + isa = graph.make_isa( + [ + graph.add_instruction( + T, encoding=LOGICAL, time=1000, space=400, error_rate=1e-8 + ), + graph.add_instruction( + CCX, encoding=LOGICAL, arity=3, time=2000, space=800, error_rate=1e-10 + ), + ] + ) + + assert T in isa + assert CCX in isa + assert LATTICE_SURGERY not in isa + + t_instr = isa[T] + assert t_instr.time() == 1000 + assert t_instr.error_rate() == 1e-8 + assert t_instr.space() == 400 + + assert len(isa) == 2 + ccz_instr = isa[CCX].with_id(CCZ) + assert ccz_instr.arity == 3 + assert ccz_instr.time() == 2000 + assert ccz_instr.error_rate() == 1e-10 + assert ccz_instr.space() == 800 + + # Add another instruction to the graph and register it in the ISA + ccz_node = graph.add_instruction(ccz_instr) + isa.add_node(CCZ, ccz_node) + assert CCZ in isa + assert len(isa) == 3 + + # Adding the same instruction ID should not increase the count + isa.add_node(CCZ, ccz_node) + assert len(isa) == 3 + + +def test_instruction_properties(): + # Test instruction with no properties + instr_no_props = _make_instruction(T, 1, 1, 1000, None, None, 1e-8, {}) + assert instr_no_props.get_property(DISTANCE) is None + assert instr_no_props.has_property(DISTANCE) is False + assert instr_no_props.get_property_or(DISTANCE, 5) == 5 + + # Test instruction with valid property (distance) + instr_with_distance = _make_instruction( + T, 1, 1, 1000, None, None, 1e-8, {"distance": 9} + ) + assert instr_with_distance.get_property(DISTANCE) == 9 + assert instr_with_distance.has_property(DISTANCE) is 
True + assert instr_with_distance.get_property_or(DISTANCE, 5) == 9 + + # Test instruction with invalid property name + with pytest.raises(ValueError, match="Unknown property 'invalid_prop'"): + _make_instruction(T, 1, 1, 1000, None, None, 1e-8, {"invalid_prop": 42}) + + +def test_instruction_constraints(): + # Test constraint without properties + c_no_props = constraint(T, encoding=LOGICAL) + assert c_no_props.has_property(DISTANCE) is False + + # Test constraint with valid property (distance=True) + c_with_distance = constraint(T, encoding=LOGICAL, distance=True) + assert c_with_distance.has_property(DISTANCE) is True + + # Test constraint with distance=False (should not add the property) + c_distance_false = constraint(T, encoding=LOGICAL, distance=False) + assert c_distance_false.has_property(DISTANCE) is False + + # Test constraint with invalid property name + with pytest.raises(ValueError, match="Unknown property 'invalid_prop'"): + constraint(T, encoding=LOGICAL, invalid_prop=True) + + # Test ISA.satisfies with property constraints + graph = _ProvenanceGraph() + isa_no_dist = graph.make_isa( + [ + graph.add_instruction(T, encoding=LOGICAL, time=1000, error_rate=1e-8), + ] + ) + isa_with_dist = graph.make_isa( + [ + graph.add_instruction( + T, encoding=LOGICAL, time=1000, error_rate=1e-8, distance=9 + ), + ] + ) + + reqs_no_prop = ISARequirements(constraint(T, encoding=LOGICAL)) + reqs_with_prop = ISARequirements(constraint(T, encoding=LOGICAL, distance=True)) + + # ISA without distance property + assert isa_no_dist.satisfies(reqs_no_prop) is True + assert isa_no_dist.satisfies(reqs_with_prop) is False + + # ISA with distance property + assert isa_with_dist.satisfies(reqs_no_prop) is True + assert isa_with_dist.satisfies(reqs_with_prop) is True + + +def test_property_names(): + assert property_name(DISTANCE) == "DISTANCE" + + # An unregistered property + UNKNOWN = 10_000 + assert property_name(UNKNOWN) is None + + # But using an existing property key with a 
different variable name will + # still return something + UNKNOWN = 0 + assert property_name(UNKNOWN) == "DISTANCE" + + assert property_name_to_key("DISTANCE") == DISTANCE + + # But we also allow case-insensitive lookup + assert property_name_to_key("distance") == DISTANCE + + + def test_generic_function(): + from qsharp.qre._qre import _IntFunction, _FloatFunction + + def time(x: int) -> int: + return x * x + + time_fn = generic_function(time) + assert isinstance(time_fn, _IntFunction) + + def error_rate(x: int) -> float: + return x / 2.0 + + error_rate_fn = generic_function(error_rate) + assert isinstance(error_rate_fn, _FloatFunction) + + # Without annotations, defaults to FloatFunction + space_fn = generic_function(lambda x: 12) + assert isinstance(space_fn, _FloatFunction) + + i = _make_instruction(42, 0, None, time_fn, 12, None, error_rate_fn, {}) + assert i.space(5) == 12 + assert i.time(5) == 25 + assert i.error_rate(5) == 2.5 + + + def test_isa_from_architecture(): + arch = GateBased(gate_time=50, measurement_time=100) + code = SurfaceCode() + ctx = arch.context() + + # Verify that the architecture satisfies the code requirements + assert ctx.isa.satisfies(SurfaceCode.required_isa()) + + # Generate logical ISAs + isas = list(code.provided_isa(ctx.isa, ctx)) + + # There is one ISA with one instruction + assert len(isas) == 1 + assert len(isas[0]) == 1 diff --git a/source/pip/tests/test_qre_models.py b/source/pip/tests/qre/test_models.py similarity index 90% rename from source/pip/tests/test_qre_models.py rename to source/pip/tests/qre/test_models.py index ef03a1eb42..728d557169 100644 --- a/source/pip/tests/test_qre_models.py +++ b/source/pip/tests/qre/test_models.py @@ -27,7 +27,7 @@ SQRT_SQRT_Z_DAG, ) from qsharp.qre.models import ( - AQREGateBased, + GateBased, Majorana, RoundBasedFactory, MagicUpToClifford, @@ -40,21 +40,21 @@ # --------------------------------------------------------------------------- -# AQREGateBased architecture tests +# GateBased 
architecture tests # --------------------------------------------------------------------------- -class TestAQREGateBased: +class TestGateBased: def test_default_error_rate(self): - arch = AQREGateBased(gate_time=50, measurement_time=100) + arch = GateBased(gate_time=50, measurement_time=100) assert arch.error_rate == 1e-4 def test_custom_error_rate(self): - arch = AQREGateBased(error_rate=1e-3, gate_time=50, measurement_time=100) + arch = GateBased(error_rate=1e-3, gate_time=50, measurement_time=100) assert arch.error_rate == 1e-3 def test_provided_isa_contains_expected_instructions(self): - arch = AQREGateBased(gate_time=50, measurement_time=100) + arch = GateBased(gate_time=50, measurement_time=100) ctx = arch.context() isa = ctx.isa @@ -62,7 +62,7 @@ def test_provided_isa_contains_expected_instructions(self): assert instr_id in isa def test_instruction_encodings_are_physical(self): - arch = AQREGateBased(gate_time=50, measurement_time=100) + arch = GateBased(gate_time=50, measurement_time=100) ctx = arch.context() isa = ctx.isa @@ -71,7 +71,7 @@ def test_instruction_encodings_are_physical(self): def test_instruction_error_rates_match(self): rate = 1e-3 - arch = AQREGateBased(error_rate=rate, gate_time=50, measurement_time=100) + arch = GateBased(error_rate=rate, gate_time=50, measurement_time=100) ctx = arch.context() isa = ctx.isa @@ -79,7 +79,7 @@ def test_instruction_error_rates_match(self): assert isa[instr_id].expect_error_rate() == rate def test_gate_times(self): - arch = AQREGateBased(gate_time=50, measurement_time=100) + arch = GateBased(gate_time=50, measurement_time=100) ctx = arch.context() isa = ctx.isa @@ -95,7 +95,7 @@ def test_gate_times(self): assert isa[MEAS_Z].expect_time() == 100 def test_arities(self): - arch = AQREGateBased(gate_time=50, measurement_time=100) + arch = GateBased(gate_time=50, measurement_time=100) ctx = arch.context() isa = ctx.isa @@ -106,7 +106,7 @@ def test_arities(self): assert isa[MEAS_Z].arity == 1 def 
test_context_creation(self): - arch = AQREGateBased(gate_time=50, measurement_time=100) + arch = GateBased(gate_time=50, measurement_time=100) ctx = arch.context() assert ctx is not None @@ -180,7 +180,7 @@ def test_default_distance(self): assert sc.distance == 3 def test_provides_lattice_surgery(self): - arch = AQREGateBased(gate_time=50, measurement_time=100) + arch = GateBased(gate_time=50, measurement_time=100) ctx = arch.context() sc = SurfaceCode(distance=3) @@ -195,7 +195,7 @@ def test_provides_lattice_surgery(self): def test_space_scales_with_distance(self): """Space = 2*d^2 - 1 physical qubits per logical qubit.""" - arch = AQREGateBased(gate_time=50, measurement_time=100) + arch = GateBased(gate_time=50, measurement_time=100) for d in [3, 5, 7, 9]: ctx = arch.context() @@ -207,8 +207,8 @@ def test_space_scales_with_distance(self): def test_time_scales_with_distance(self): """Time = (h_time + 4*cnot_time + meas_time) * d.""" - arch = AQREGateBased(gate_time=50, measurement_time=100) - # h=50, cnot=50, meas=100 for AQREGateBased + arch = GateBased(gate_time=50, measurement_time=100) + # h=50, cnot=50, meas=100 for GateBased syndrome_time = 50 + 4 * 50 + 100 # = 350 for d in [3, 5, 7]: @@ -219,7 +219,7 @@ def test_time_scales_with_distance(self): assert ls.expect_time(1) == syndrome_time * d def test_error_rate_decreases_with_distance(self): - arch = AQREGateBased(gate_time=50, measurement_time=100) + arch = GateBased(gate_time=50, measurement_time=100) errors = [] for d in [3, 5, 7, 9, 11]: @@ -234,7 +234,7 @@ def test_error_rate_decreases_with_distance(self): def test_enumeration_via_query(self): """Enumerating SurfaceCode.q() should yield multiple distances.""" - arch = AQREGateBased(gate_time=50, measurement_time=100) + arch = GateBased(gate_time=50, measurement_time=100) ctx = arch.context() count = 0 @@ -246,7 +246,7 @@ def test_enumeration_via_query(self): assert count == 12 def test_custom_crossing_prefactor(self): - arch = 
AQREGateBased(gate_time=50, measurement_time=100) + arch = GateBased(gate_time=50, measurement_time=100) ctx = arch.context() sc_default = SurfaceCode(distance=5) @@ -265,7 +265,7 @@ def test_custom_crossing_prefactor(self): assert abs(custom_error - 2 * default_error) < 1e-20 def test_custom_error_correction_threshold(self): - arch = AQREGateBased(gate_time=50, measurement_time=100) + arch = GateBased(gate_time=50, measurement_time=100) ctx1 = arch.context() sc_low_threshold = SurfaceCode(error_correction_threshold=0.005, distance=5) @@ -395,7 +395,7 @@ def test_enumeration_via_query(self): class TestYokedSurfaceCode: def _get_lattice_surgery_isa(self, distance=5): """Helper to get a lattice surgery ISA from SurfaceCode.""" - arch = AQREGateBased(gate_time=50, measurement_time=100) + arch = GateBased(gate_time=50, measurement_time=100) ctx = arch.context() sc = SurfaceCode(distance=distance) isas = list(sc.provided_isa(ctx.isa, ctx)) @@ -479,9 +479,9 @@ def test_required_isa(self): reqs = Litinski19Factory.required_isa() assert reqs is not None - def test_table1_aqre_yields_t_and_ccz(self): - """AQREGateBased (error 1e-4) matches Table 1 scenario: T & CCZ.""" - arch = AQREGateBased(gate_time=50, measurement_time=100) + def test_table1_yields_t_and_ccz(self): + """GateBased (error 1e-4) matches Table 1 scenario: T & CCZ.""" + arch = GateBased(gate_time=50, measurement_time=100) ctx = arch.context() factory = Litinski19Factory() @@ -496,7 +496,7 @@ def test_table1_aqre_yields_t_and_ccz(self): assert len(isa) == 2 def test_table1_instruction_properties(self): - arch = AQREGateBased(gate_time=50, measurement_time=100) + arch = GateBased(gate_time=50, measurement_time=100) ctx = arch.context() factory = Litinski19Factory() @@ -516,7 +516,7 @@ def test_table1_instruction_properties(self): def test_table1_t_error_rates_are_diverse(self): """T entries in Table 1 should span a range of error rates.""" - arch = AQREGateBased(gate_time=50, measurement_time=100) + arch = 
GateBased(gate_time=50, measurement_time=100) ctx = arch.context() factory = Litinski19Factory() @@ -532,8 +532,8 @@ def test_table1_t_error_rates_are_diverse(self): assert 0 < err < 1e-5 def test_table1_1e3_clifford_yields_6_isas(self): - """AQREGateBased with 1e-3 error matches Table 1 at 1e-3 Clifford.""" - arch = AQREGateBased(error_rate=1e-3, gate_time=50, measurement_time=100) + """GateBased with 1e-3 error matches Table 1 at 1e-3 Clifford.""" + arch = GateBased(error_rate=1e-3, gate_time=50, measurement_time=100) ctx = arch.context() factory = Litinski19Factory() @@ -550,7 +550,7 @@ def test_table2_scenario_no_ccz(self): """Table 2 scenario: T error ~10x higher than Clifford, no CCZ.""" from qsharp.qre._qre import _ProvenanceGraph - arch = AQREGateBased(gate_time=50, measurement_time=100) + arch = GateBased(gate_time=50, measurement_time=100) ctx = arch.context() # Manually create ISA with T error rate 10x Clifford @@ -578,7 +578,7 @@ def test_no_yield_when_error_too_high(self): """If T error > 10x Clifford, no entries match.""" from qsharp.qre._qre import _ProvenanceGraph - arch = AQREGateBased(gate_time=50, measurement_time=100) + arch = GateBased(gate_time=50, measurement_time=100) ctx = arch.context() graph = _ProvenanceGraph() @@ -597,11 +597,11 @@ def test_no_yield_when_error_too_high(self): def test_time_based_on_syndrome_extraction(self): """Time should be based on syndrome extraction time × cycles.""" - arch = AQREGateBased(gate_time=50, measurement_time=100) + arch = GateBased(gate_time=50, measurement_time=100) ctx = arch.context() factory = Litinski19Factory() - # For AQREGateBased: syndrome_extraction_time = 4*50 + 50 + 100 = 350 + # For GateBased: syndrome_extraction_time = 4*50 + 50 + 100 = 350 syndrome_time = 4 * 50 + 50 + 100 # 350 ns isas = list(factory.provided_isa(ctx.isa, ctx)) @@ -625,7 +625,7 @@ def test_required_isa_is_empty(self): def test_adds_clifford_equivalent_t_gates(self): """Given T gate, should add SQRT_SQRT_X/Y/Z and dagger 
variants.""" - arch = AQREGateBased(gate_time=50, measurement_time=100) + arch = GateBased(gate_time=50, measurement_time=100) ctx = arch.context() factory = Litinski19Factory() modifier = MagicUpToClifford() @@ -650,7 +650,7 @@ def test_adds_clifford_equivalent_t_gates(self): def test_adds_clifford_equivalent_ccz(self): """Given CCZ, should add CCX and CCY.""" - arch = AQREGateBased(gate_time=50, measurement_time=100) + arch = GateBased(gate_time=50, measurement_time=100) ctx = arch.context() factory = Litinski19Factory() modifier = MagicUpToClifford() @@ -666,7 +666,7 @@ def test_adds_clifford_equivalent_ccz(self): def test_full_count_of_instructions(self): """T gate (1) + 5 equivalents (SQRT_SQRT_*) + CCZ (1) + 2 equivalents (CCX, CCY) = 9.""" - arch = AQREGateBased(gate_time=50, measurement_time=100) + arch = GateBased(gate_time=50, measurement_time=100) ctx = arch.context() factory = Litinski19Factory() modifier = MagicUpToClifford() @@ -678,7 +678,7 @@ def test_full_count_of_instructions(self): def test_equivalent_instructions_share_properties(self): """Clifford equivalents should have same time, space, error rate.""" - arch = AQREGateBased(gate_time=50, measurement_time=100) + arch = GateBased(gate_time=50, measurement_time=100) ctx = arch.context() factory = Litinski19Factory() modifier = MagicUpToClifford() @@ -709,7 +709,7 @@ def test_equivalent_instructions_share_properties(self): def test_modification_count_matches_factory_output(self): """MagicUpToClifford should produce one modified ISA per input ISA.""" - arch = AQREGateBased(gate_time=50, measurement_time=100) + arch = GateBased(gate_time=50, measurement_time=100) ctx = arch.context() factory = Litinski19Factory() modifier = MagicUpToClifford() @@ -725,7 +725,7 @@ def test_no_family_present_passes_through(self): """If no family member is present, ISA passes through unchanged.""" from qsharp.qre._qre import _ProvenanceGraph - arch = AQREGateBased(gate_time=50, measurement_time=100) + arch = 
GateBased(gate_time=50, measurement_time=100) ctx = arch.context() modifier = MagicUpToClifford() @@ -758,7 +758,7 @@ def test_no_family_present_passes_through(self): def test_isa_manipulation(): - arch = AQREGateBased(gate_time=50, measurement_time=100) + arch = GateBased(gate_time=50, measurement_time=100) factory = Litinski19Factory() modifier = MagicUpToClifford() @@ -813,7 +813,7 @@ def test_required_isa(self): assert reqs is not None def test_produces_logical_t_gates(self): - arch = AQREGateBased(gate_time=50, measurement_time=100) + arch = GateBased(gate_time=50, measurement_time=100) for isa in RoundBasedFactory.q(use_cache=False).enumerate(arch.context()): t = isa[T] @@ -826,7 +826,7 @@ def test_produces_logical_t_gates(self): def test_error_rates_are_bounded(self): """Distilled T error rates should be bounded and mostly small.""" - arch = AQREGateBased(gate_time=50, measurement_time=100) # T error rate is 1e-4 + arch = GateBased(gate_time=50, measurement_time=100) # T error rate is 1e-4 errors = [] for isa in RoundBasedFactory.q(use_cache=False).enumerate(arch.context()): @@ -843,7 +843,7 @@ def test_error_rates_are_bounded(self): def test_max_produces_fewer_or_equal_results_than_sum(self): """Using max for physical_qubit_calculation may filter differently.""" - arch = AQREGateBased(gate_time=50, measurement_time=100) + arch = GateBased(gate_time=50, measurement_time=100) sum_count = sum( 1 for _ in RoundBasedFactory.q(use_cache=False).enumerate(arch.context()) @@ -859,7 +859,7 @@ def test_max_produces_fewer_or_equal_results_than_sum(self): def test_max_space_less_than_or_equal_sum_space(self): """max-aggregated space should be <= sum-aggregated space for each.""" - arch = AQREGateBased(gate_time=50, measurement_time=100) + arch = GateBased(gate_time=50, measurement_time=100) sum_spaces = sorted( isa[T].expect_space() @@ -890,8 +890,8 @@ def test_with_three_aux_code_query(self): assert count > 0 - def test_round_based_aqre_sum(self): - arch = 
AQREGateBased(gate_time=50, measurement_time=100) + def test_round_based_gate_based_sum(self): + arch = GateBased(gate_time=50, measurement_time=100) total_space = 0 total_time = 0 @@ -909,8 +909,8 @@ def test_round_based_aqre_sum(self): assert abs(total_error - 0.001_463_030_863_973_197_8) < 1e-8 assert count == 107 - def test_round_based_aqre_max(self): - arch = AQREGateBased(gate_time=50, measurement_time=100) + def test_round_based_gate_based_max(self): + arch = GateBased(gate_time=50, measurement_time=100) total_space = 0 total_time = 0 @@ -960,10 +960,10 @@ def test_round_based_msft_sum(self): class TestCrossModelIntegration: def test_surface_code_feeds_into_litinski(self): """SurfaceCode -> Litinski19Factory pipeline works end to end.""" - arch = AQREGateBased(gate_time=50, measurement_time=100) + arch = GateBased(gate_time=50, measurement_time=100) ctx = arch.context() - # SurfaceCode takes AQRE physical ISA -> LATTICE_SURGERY + # SurfaceCode takes gate-based physical ISA -> LATTICE_SURGERY sc = SurfaceCode(distance=5) sc_isas = list(sc.provided_isa(ctx.isa, ctx)) assert len(sc_isas) == 1 @@ -989,7 +989,7 @@ def test_three_aux_feeds_into_round_based(self): def test_litinski_with_magic_up_to_clifford_query(self): """Full query chain: Litinski19Factory -> MagicUpToClifford.""" - arch = AQREGateBased(gate_time=50, measurement_time=100) + arch = GateBased(gate_time=50, measurement_time=100) ctx = arch.context() count = 0 @@ -1004,7 +1004,7 @@ def test_litinski_with_magic_up_to_clifford_query(self): def test_surface_code_with_yoked_surface_code(self): """SurfaceCode -> YokedSurfaceCode pipeline provides MEMORY.""" - arch = AQREGateBased(gate_time=50, measurement_time=100) + arch = GateBased(gate_time=50, measurement_time=100) ctx = arch.context() count = 0 diff --git a/source/pip/tests/test_qre.py b/source/pip/tests/test_qre.py deleted file mode 100644 index 2dc318fa5e..0000000000 --- a/source/pip/tests/test_qre.py +++ /dev/null @@ -1,1666 +0,0 @@ -# Copyright 
(c) Microsoft Corporation. -# Licensed under the MIT License. - -from dataclasses import KW_ONLY, dataclass, field -from enum import Enum -from pathlib import Path -from typing import cast, Generator, Sized -import os -import pytest - -import pandas as pd -import qsharp -from qsharp.estimator import LogicalCounts -from qsharp.qre import ( - Application, - ISA, - LOGICAL, - PSSPC, - EstimationResult, - ISARequirements, - ISATransform, - LatticeSurgery, - Trace, - constraint, - estimate, - linear_function, - generic_function, - property_name, - property_name_to_key, -) -from qsharp.qre._qre import _ProvenanceGraph -from qsharp.qre.application import QSharpApplication -from qsharp.qre.models import ( - SurfaceCode, - AQREGateBased, - RoundBasedFactory, - TwoDimensionalYokedSurfaceCode, -) -from qsharp.qre.interop import trace_from_qir -from qsharp.qre._architecture import _Context, _make_instruction -from qsharp.qre._estimation import ( - EstimationTable, - EstimationTableEntry, -) -from qsharp.qre._instruction import InstructionSource -from qsharp.qre._isa_enumeration import ( - ISARefNode, -) -from qsharp.qre.instruction_ids import CCX, CCZ, LATTICE_SURGERY, T, RZ -from qsharp.qre.property_keys import ( - DISTANCE, - NUM_TS_PER_ROTATION, - ALGORITHM_COMPUTE_QUBITS, - ALGORITHM_MEMORY_QUBITS, - LOGICAL_COMPUTE_QUBITS, - LOGICAL_MEMORY_QUBITS, -) - -# NOTE These classes will be generalized as part of the QRE API in the following -# pull requests and then moved out of the tests. 
- - -@dataclass -class ExampleFactory(ISATransform): - _: KW_ONLY - level: int = field(default=1, metadata={"domain": range(1, 4)}) - - @staticmethod - def required_isa() -> ISARequirements: - return ISARequirements( - constraint(T), - ) - - def provided_isa(self, impl_isa: ISA, ctx: _Context) -> Generator[ISA, None, None]: - yield ctx.make_isa( - ctx.add_instruction(T, encoding=LOGICAL, time=1000, error_rate=1e-8), - ) - - -@dataclass -class ExampleLogicalFactory(ISATransform): - _: KW_ONLY - level: int = field(default=1, metadata={"domain": range(1, 4)}) - - @staticmethod - def required_isa() -> ISARequirements: - return ISARequirements( - constraint(LATTICE_SURGERY, encoding=LOGICAL), - constraint(T, encoding=LOGICAL), - ) - - def provided_isa(self, impl_isa: ISA, ctx: _Context) -> Generator[ISA, None, None]: - yield ctx.make_isa( - ctx.add_instruction(T, encoding=LOGICAL, time=1000, error_rate=1e-10), - ) - - -def test_isa(): - graph = _ProvenanceGraph() - isa = graph.make_isa( - [ - graph.add_instruction( - T, encoding=LOGICAL, time=1000, space=400, error_rate=1e-8 - ), - graph.add_instruction( - CCX, encoding=LOGICAL, arity=3, time=2000, space=800, error_rate=1e-10 - ), - ] - ) - - assert T in isa - assert CCX in isa - assert LATTICE_SURGERY not in isa - - t_instr = isa[T] - assert t_instr.time() == 1000 - assert t_instr.error_rate() == 1e-8 - assert t_instr.space() == 400 - - assert len(isa) == 2 - ccz_instr = isa[CCX].with_id(CCZ) - assert ccz_instr.arity == 3 - assert ccz_instr.time() == 2000 - assert ccz_instr.error_rate() == 1e-10 - assert ccz_instr.space() == 800 - - # Add another instruction to the graph and register it in the ISA - ccz_node = graph.add_instruction(ccz_instr) - isa.add_node(CCZ, ccz_node) - assert CCZ in isa - assert len(isa) == 3 - - # Adding the same instruction ID should not increase the count - isa.add_node(CCZ, ccz_node) - assert len(isa) == 3 - - -def test_instruction_properties(): - # Test instruction with no properties - 
instr_no_props = _make_instruction(T, 1, 1, 1000, None, None, 1e-8, {}) - assert instr_no_props.get_property(DISTANCE) is None - assert instr_no_props.has_property(DISTANCE) is False - assert instr_no_props.get_property_or(DISTANCE, 5) == 5 - - # Test instruction with valid property (distance) - instr_with_distance = _make_instruction( - T, 1, 1, 1000, None, None, 1e-8, {"distance": 9} - ) - assert instr_with_distance.get_property(DISTANCE) == 9 - assert instr_with_distance.has_property(DISTANCE) is True - assert instr_with_distance.get_property_or(DISTANCE, 5) == 9 - - # Test instruction with invalid property name - with pytest.raises(ValueError, match="Unknown property 'invalid_prop'"): - _make_instruction(T, 1, 1, 1000, None, None, 1e-8, {"invalid_prop": 42}) - - -def test_instruction_constraints(): - # Test constraint without properties - c_no_props = constraint(T, encoding=LOGICAL) - assert c_no_props.has_property(DISTANCE) is False - - # Test constraint with valid property (distance=True) - c_with_distance = constraint(T, encoding=LOGICAL, distance=True) - assert c_with_distance.has_property(DISTANCE) is True - - # Test constraint with distance=False (should not add the property) - c_distance_false = constraint(T, encoding=LOGICAL, distance=False) - assert c_distance_false.has_property(DISTANCE) is False - - # Test constraint with invalid property name - with pytest.raises(ValueError, match="Unknown property 'invalid_prop'"): - constraint(T, encoding=LOGICAL, invalid_prop=True) - - # Test ISA.satisfies with property constraints - graph = _ProvenanceGraph() - isa_no_dist = graph.make_isa( - [ - graph.add_instruction(T, encoding=LOGICAL, time=1000, error_rate=1e-8), - ] - ) - isa_with_dist = graph.make_isa( - [ - graph.add_instruction( - T, encoding=LOGICAL, time=1000, error_rate=1e-8, distance=9 - ), - ] - ) - - reqs_no_prop = ISARequirements(constraint(T, encoding=LOGICAL)) - reqs_with_prop = ISARequirements(constraint(T, encoding=LOGICAL, distance=True)) - - 
# ISA without distance property - assert isa_no_dist.satisfies(reqs_no_prop) is True - assert isa_no_dist.satisfies(reqs_with_prop) is False - - # ISA with distance property - assert isa_with_dist.satisfies(reqs_no_prop) is True - assert isa_with_dist.satisfies(reqs_with_prop) is True - - -def test_property_names(): - assert property_name(DISTANCE) == "DISTANCE" - - # An unregistered property - UNKNOWN = 10_000 - assert property_name(UNKNOWN) is None - - # But using an existing property key with a different variable name will - # still return something - UNKNOWN = 0 - assert property_name(UNKNOWN) == "DISTANCE" - - assert property_name_to_key("DISTANCE") == DISTANCE - - # But we also allow case-insensitive lookup - assert property_name_to_key("distance") == DISTANCE - - -def test_generic_function(): - from qsharp.qre._qre import _IntFunction, _FloatFunction - - def time(x: int) -> int: - return x * x - - time_fn = generic_function(time) - assert isinstance(time_fn, _IntFunction) - - def error_rate(x: int) -> float: - return x / 2.0 - - error_rate_fn = generic_function(error_rate) - assert isinstance(error_rate_fn, _FloatFunction) - - # Without annotations, defaults to FloatFunction - space_fn = generic_function(lambda x: 12) - assert isinstance(space_fn, _FloatFunction) - - i = _make_instruction(42, 0, None, time_fn, 12, None, error_rate_fn, {}) - assert i.space(5) == 12 - assert i.time(5) == 25 - assert i.error_rate(5) == 2.5 - - -def test_isa_from_architecture(): - arch = AQREGateBased(gate_time=50, measurement_time=100) - code = SurfaceCode() - ctx = arch.context() - - # Verify that the architecture satisfies the code requirements - assert ctx.isa.satisfies(SurfaceCode.required_isa()) - - # Generate logical ISAs - isas = list(code.provided_isa(ctx.isa, ctx)) - - # There is one ISA with one instructions - assert len(isas) == 1 - assert len(isas[0]) == 1 - - -def test_enumerate_instances(): - from qsharp.qre._enumeration import _enumerate_instances - - instances = 
list(_enumerate_instances(SurfaceCode)) - - # There are 12 instances with distances from 3 to 25 - assert len(instances) == 12 - expected_distances = list(range(3, 26, 2)) - for instance, expected_distance in zip(instances, expected_distances): - assert instance.distance == expected_distance - - # Test with specific distances - instances = list(_enumerate_instances(SurfaceCode, distance=[3, 5, 7])) - assert len(instances) == 3 - expected_distances = [3, 5, 7] - for instance, expected_distance in zip(instances, expected_distances): - assert instance.distance == expected_distance - - # Test with fixed distance - instances = list(_enumerate_instances(SurfaceCode, distance=9)) - assert len(instances) == 1 - assert instances[0].distance == 9 - - -def test_enumerate_instances_bool(): - from qsharp.qre._enumeration import _enumerate_instances - - @dataclass - class BoolConfig: - _: KW_ONLY - flag: bool - - instances = list(_enumerate_instances(BoolConfig)) - assert len(instances) == 2 - assert instances[0].flag is True - assert instances[1].flag is False - - -def test_enumerate_instances_enum(): - from qsharp.qre._enumeration import _enumerate_instances - - class Color(Enum): - RED = 1 - GREEN = 2 - BLUE = 3 - - @dataclass - class EnumConfig: - _: KW_ONLY - color: Color - - instances = list(_enumerate_instances(EnumConfig)) - assert len(instances) == 3 - assert instances[0].color == Color.RED - assert instances[1].color == Color.GREEN - assert instances[2].color == Color.BLUE - - -def test_enumerate_instances_failure(): - from qsharp.qre._enumeration import _enumerate_instances - - import pytest - - @dataclass - class InvalidConfig: - _: KW_ONLY - # This field has no domain, is not bool/enum, and has no default - value: int - - with pytest.raises(ValueError, match="Cannot enumerate field value"): - list(_enumerate_instances(InvalidConfig)) - - -def test_enumerate_instances_single(): - from qsharp.qre._enumeration import _enumerate_instances - - @dataclass - class 
SingleConfig: - value: int = 42 - - instances = list(_enumerate_instances(SingleConfig)) - assert len(instances) == 1 - assert instances[0].value == 42 - - -def test_enumerate_instances_literal(): - from qsharp.qre._enumeration import _enumerate_instances - - from typing import Literal - - @dataclass - class LiteralConfig: - _: KW_ONLY - mode: Literal["fast", "slow"] - - instances = list(_enumerate_instances(LiteralConfig)) - assert len(instances) == 2 - assert instances[0].mode == "fast" - assert instances[1].mode == "slow" - - -def test_enumerate_instances_nested(): - from qsharp.qre._enumeration import _enumerate_instances - - @dataclass - class InnerConfig: - _: KW_ONLY - option: bool - - @dataclass - class OuterConfig: - _: KW_ONLY - inner: InnerConfig - - instances = list(_enumerate_instances(OuterConfig)) - assert len(instances) == 2 - assert instances[0].inner.option is True - assert instances[1].inner.option is False - - -def test_enumerate_instances_union(): - from qsharp.qre._enumeration import _enumerate_instances - - @dataclass - class OptionA: - _: KW_ONLY - value: bool - - @dataclass - class OptionB: - _: KW_ONLY - number: int = field(default=1, metadata={"domain": [1, 2, 3]}) - - @dataclass - class UnionConfig: - _: KW_ONLY - option: OptionA | OptionB - - instances = list(_enumerate_instances(UnionConfig)) - assert len(instances) == 5 - assert isinstance(instances[0].option, OptionA) - assert instances[0].option.value is True - assert isinstance(instances[2].option, OptionB) - assert instances[2].option.number == 1 - - -def test_enumerate_instances_nested_with_constraints(): - from qsharp.qre._enumeration import _enumerate_instances - - @dataclass - class InnerConfig: - _: KW_ONLY - option: bool - - @dataclass - class OuterConfig: - _: KW_ONLY - inner: InnerConfig - - # Constrain nested field via dict - instances = list(_enumerate_instances(OuterConfig, inner={"option": True})) - assert len(instances) == 1 - assert instances[0].inner.option is True 
- - -def test_enumerate_instances_union_single_type(): - from qsharp.qre._enumeration import _enumerate_instances - - @dataclass - class OptionA: - _: KW_ONLY - value: bool - - @dataclass - class OptionB: - _: KW_ONLY - number: int = field(default=1, metadata={"domain": [1, 2, 3]}) - - @dataclass - class UnionConfig: - _: KW_ONLY - option: OptionA | OptionB - - # Restrict to OptionB only - uses its default domain - instances = list(_enumerate_instances(UnionConfig, option=OptionB)) - assert len(instances) == 3 - assert all(isinstance(i.option, OptionB) for i in instances) - assert [cast(OptionB, i.option).number for i in instances] == [1, 2, 3] - - # Restrict to OptionA only - instances = list(_enumerate_instances(UnionConfig, option=OptionA)) - assert len(instances) == 2 - assert all(isinstance(i.option, OptionA) for i in instances) - assert cast(OptionA, instances[0].option).value is True - assert cast(OptionA, instances[1].option).value is False - - -def test_enumerate_instances_union_list_of_types(): - from qsharp.qre._enumeration import _enumerate_instances - - @dataclass - class OptionA: - _: KW_ONLY - value: bool - - @dataclass - class OptionB: - _: KW_ONLY - number: int = field(default=1, metadata={"domain": [1, 2, 3]}) - - @dataclass - class OptionC: - _: KW_ONLY - flag: bool - - @dataclass - class UnionConfig: - _: KW_ONLY - option: OptionA | OptionB | OptionC - - # Select a subset: only OptionA and OptionB - instances = list(_enumerate_instances(UnionConfig, option=[OptionA, OptionB])) - assert len(instances) == 5 # 2 from OptionA + 3 from OptionB - assert all(isinstance(i.option, (OptionA, OptionB)) for i in instances) - - -def test_enumerate_instances_union_constraint_dict(): - from qsharp.qre._enumeration import _enumerate_instances - - @dataclass - class OptionA: - _: KW_ONLY - value: bool - - @dataclass - class OptionB: - _: KW_ONLY - number: int = field(default=1, metadata={"domain": [1, 2, 3]}) - - @dataclass - class UnionConfig: - _: KW_ONLY - 
option: OptionA | OptionB - - # Constrain OptionA, enumerate only that member - instances = list( - _enumerate_instances(UnionConfig, option={OptionA: {"value": True}}) - ) - assert len(instances) == 1 - assert isinstance(instances[0].option, OptionA) - assert instances[0].option.value is True - - # Constrain OptionB with a domain, enumerate only that member - instances = list( - _enumerate_instances(UnionConfig, option={OptionB: {"number": [2, 3]}}) - ) - assert len(instances) == 2 - assert all(isinstance(i.option, OptionB) for i in instances) - assert cast(OptionB, instances[0].option).number == 2 - assert cast(OptionB, instances[1].option).number == 3 - - # Constrain one member and keep another with defaults - instances = list( - _enumerate_instances( - UnionConfig, - option={OptionA: {"value": True}, OptionB: {}}, - ) - ) - assert len(instances) == 4 # 1 from OptionA + 3 from OptionB - assert isinstance(instances[0].option, OptionA) - assert instances[0].option.value is True - assert all(isinstance(i.option, OptionB) for i in instances[1:]) - assert [cast(OptionB, i.option).number for i in instances[1:]] == [1, 2, 3] - - -def test_enumerate_isas(): - ctx = AQREGateBased(gate_time=50, measurement_time=100).context() - - # This will enumerate the 4 ISAs for the error correction code - count = sum(1 for _ in SurfaceCode.q().enumerate(ctx)) - assert count == 12 - - # This will enumerate the 2 ISAs for the error correction code when - # restricting the domain - count = sum(1 for _ in SurfaceCode.q(distance=[3, 4]).enumerate(ctx)) - assert count == 2 - - # This will enumerate the 3 ISAs for the factory - count = sum(1 for _ in ExampleFactory.q().enumerate(ctx)) - assert count == 3 - - # This will enumerate 36 ISAs for all products between the 12 error - # correction code ISAs and the 3 factory ISAs - count = sum(1 for _ in (SurfaceCode.q() * ExampleFactory.q()).enumerate(ctx)) - assert count == 36 - - # When providing a list, components are chained (OR operation). 
This - # enumerates ISAs from first factory instance OR second factory instance - count = sum( - 1 - for _ in ( - SurfaceCode.q() * (ExampleFactory.q() + ExampleFactory.q()) - ).enumerate(ctx) - ) - assert count == 72 - - # When providing separate arguments, components are combined via product - # (AND). This enumerates ISAs from first factory instance AND second - # factory instance - count = sum( - 1 - for _ in (SurfaceCode.q() * ExampleFactory.q() * ExampleFactory.q()).enumerate( - ctx - ) - ) - assert count == 108 - - # Hierarchical factory using from_components: the component receives ISAs - # from the product of other components as its source - count = sum( - 1 - for _ in ( - SurfaceCode.q() - * ExampleLogicalFactory.q(source=(SurfaceCode.q() * ExampleFactory.q())) - ).enumerate(ctx) - ) - assert count == 1296 - - -def test_binding_node(): - """Test binding nodes with ISARefNode for component bindings""" - ctx = AQREGateBased(gate_time=50, measurement_time=100).context() - - # Test basic binding: same code used twice - # Without binding: 12 codes × 12 codes = 144 combinations - count_without = sum(1 for _ in (SurfaceCode.q() * SurfaceCode.q()).enumerate(ctx)) - assert count_without == 144 - - # With binding: 12 codes (same instance used twice) - count_with = sum( - 1 - for _ in SurfaceCode.bind("c", ISARefNode("c") * ISARefNode("c")).enumerate(ctx) - ) - assert count_with == 12 - - # Verify the binding works: with binding, both should use same params - for isa in SurfaceCode.bind("c", ISARefNode("c") * ISARefNode("c")).enumerate(ctx): - logical_gates = [g for g in isa if g.encoding == LOGICAL] - # Should have 1 logical gate (LATTICE_SURGERY) - assert len(logical_gates) == 1 - - # Test binding with factories (nested bindings) - count_without = sum( - 1 - for _ in ( - SurfaceCode.q() * ExampleFactory.q() * SurfaceCode.q() * ExampleFactory.q() - ).enumerate(ctx) - ) - assert count_without == 1296 # 12 * 3 * 12 * 3 - - count_with = sum( - 1 - for _ in 
SurfaceCode.bind( - "c", - ExampleFactory.bind( - "f", - ISARefNode("c") * ISARefNode("f") * ISARefNode("c") * ISARefNode("f"), - ), - ).enumerate(ctx) - ) - assert count_with == 36 # 12 * 3 - - # Test binding with from_components equivalent (hierarchical) - # Without binding: 4 outer codes × (4 inner codes × 3 factories × 3 levels) - count_without = sum( - 1 - for _ in ( - SurfaceCode.q() - * ExampleLogicalFactory.q( - source=(SurfaceCode.q() * ExampleFactory.q()), - ) - ).enumerate(ctx) - ) - assert count_without == 1296 # 12 * 12 * 3 * 3 - - # With binding: 4 codes (same used twice) × 3 factories × 3 levels - count_with = sum( - 1 - for _ in SurfaceCode.bind( - "c", - ISARefNode("c") - * ExampleLogicalFactory.q( - source=(ISARefNode("c") * ExampleFactory.q()), - ), - ).enumerate(ctx) - ) - assert count_with == 108 # 12 * 3 * 3 - - # Test binding with kwargs - count_with_kwargs = sum( - 1 - for _ in SurfaceCode.q(distance=5) - .bind("c", ISARefNode("c") * ISARefNode("c")) - .enumerate(ctx) - ) - assert count_with_kwargs == 1 # Only distance=5 - - # Verify kwargs are applied - for isa in ( - SurfaceCode.q(distance=5) - .bind("c", ISARefNode("c") * ISARefNode("c")) - .enumerate(ctx) - ): - logical_gates = [g for g in isa if g.encoding == LOGICAL] - assert all(g.space(1) == 49 for g in logical_gates) - - # Test multiple independent bindings (nested) - count = sum( - 1 - for _ in SurfaceCode.bind( - "c1", - ExampleFactory.bind( - "c2", - ISARefNode("c1") - * ISARefNode("c1") - * ISARefNode("c2") - * ISARefNode("c2"), - ), - ).enumerate(ctx) - ) - # 12 codes for c1 × 3 factories for c2 - assert count == 36 - - -def test_binding_node_errors(): - """Test error handling for binding nodes""" - ctx = AQREGateBased(gate_time=50, measurement_time=100).context() - - # Test ISARefNode enumerate with undefined binding raises ValueError - try: - list(ISARefNode("test").enumerate(ctx)) - assert False, "Should have raised ValueError" - except ValueError as e: - assert "Undefined 
component reference: 'test'" in str(e) - - -def test_product_isa_enumeration_nodes(): - from qsharp.qre._isa_enumeration import _ComponentQuery, _ProductNode - - terminal = SurfaceCode.q() - query = terminal * terminal - - # Multiplication should create ProductNode - assert isinstance(query, _ProductNode) - assert len(query.sources) == 2 - for source in query.sources: - assert isinstance(source, _ComponentQuery) - - # Multiplying again should extend the sources - query = query * terminal - assert isinstance(query, _ProductNode) - assert len(query.sources) == 3 - for source in query.sources: - assert isinstance(source, _ComponentQuery) - - # Also from the other side - query = terminal * query - assert isinstance(query, _ProductNode) - assert len(query.sources) == 4 - for source in query.sources: - assert isinstance(source, _ComponentQuery) - - # Also for two ProductNodes - query = query * query - assert isinstance(query, _ProductNode) - assert len(query.sources) == 8 - for source in query.sources: - assert isinstance(source, _ComponentQuery) - - -def test_sum_isa_enumeration_nodes(): - from qsharp.qre._isa_enumeration import _ComponentQuery, _SumNode - - terminal = SurfaceCode.q() - query = terminal + terminal - - # Multiplication should create SumNode - assert isinstance(query, _SumNode) - assert len(query.sources) == 2 - for source in query.sources: - assert isinstance(source, _ComponentQuery) - - # Multiplying again should extend the sources - query = query + terminal - assert isinstance(query, _SumNode) - assert len(query.sources) == 3 - for source in query.sources: - assert isinstance(source, _ComponentQuery) - - # Also from the other side - query = terminal + query - assert isinstance(query, _SumNode) - assert len(query.sources) == 4 - for source in query.sources: - assert isinstance(source, _ComponentQuery) - - # Also for two SumNodes - query = query + query - assert isinstance(query, _SumNode) - assert len(query.sources) == 8 - for source in query.sources: - 
assert isinstance(source, _ComponentQuery) - - -def test_trace_properties(): - trace = Trace(42) - - INT = 0 - FLOAT = 1 - BOOL = 2 - STR = 3 - - trace.set_property(INT, 42) - assert trace.get_property(INT) == 42 - assert isinstance(trace.get_property(INT), int) - - trace.set_property(FLOAT, 3.14) - assert trace.get_property(FLOAT) == 3.14 - assert isinstance(trace.get_property(FLOAT), float) - - trace.set_property(BOOL, True) - assert trace.get_property(BOOL) is True - assert isinstance(trace.get_property(BOOL), bool) - - trace.set_property(STR, "hello") - assert trace.get_property(STR) == "hello" - assert isinstance(trace.get_property(STR), str) - - -def test_qsharp_application(): - from qsharp.qre._enumeration import _enumerate_instances - - code = """ - {{ - use (a, b, c) = (Qubit(), Qubit(), Qubit()); - T(a); - CCNOT(a, b, c); - Rz(1.2345, a); - }} - """ - - app = QSharpApplication(code) - trace = app.get_trace() - - assert trace.compute_qubits == 3 - assert trace.depth == 3 - assert trace.resource_states == {} - - assert {c.id for c in trace.required_isa} == {CCX, T, RZ} - - graph = _ProvenanceGraph() - isa = graph.make_isa( - [ - graph.add_instruction( - LATTICE_SURGERY, - encoding=LOGICAL, - arity=None, - time=1000, - space=linear_function(50), - error_rate=linear_function(1e-6), - ), - graph.add_instruction( - T, encoding=LOGICAL, time=1000, space=400, error_rate=1e-8 - ), - graph.add_instruction( - CCX, encoding=LOGICAL, time=2000, space=800, error_rate=1e-10 - ), - ] - ) - - # Properties from the program - counts = qsharp.logical_counts(code) - num_ts = counts["tCount"] - num_ccx = counts["cczCount"] - num_rotations = counts["rotationCount"] - rotation_depth = counts["rotationDepth"] - - lattice_surgery = LatticeSurgery() - - counter = 0 - for psspc in _enumerate_instances(PSSPC): - counter += 1 - trace2 = psspc.transform(trace) - assert trace2 is not None - trace2 = lattice_surgery.transform(trace2) - assert trace2 is not None - assert 
trace2.compute_qubits == 12 - assert ( - trace2.depth - == num_ts - + num_ccx * 3 - + num_rotations - + rotation_depth * psspc.num_ts_per_rotation - ) - if psspc.ccx_magic_states: - assert trace2.resource_states == { - T: num_ts + psspc.num_ts_per_rotation * num_rotations, - CCX: num_ccx, - } - assert {c.id for c in trace2.required_isa} == {CCX, T, LATTICE_SURGERY} - else: - assert trace2.resource_states == { - T: num_ts + psspc.num_ts_per_rotation * num_rotations + 4 * num_ccx - } - assert {c.id for c in trace2.required_isa} == {T, LATTICE_SURGERY} - assert trace2.get_property(ALGORITHM_COMPUTE_QUBITS) == 3 - assert trace2.get_property(ALGORITHM_MEMORY_QUBITS) == 0 - result = trace2.estimate(isa, max_error=float("inf")) - assert result is not None - assert result.properties[ALGORITHM_COMPUTE_QUBITS] == 3 - assert result.properties[ALGORITHM_MEMORY_QUBITS] == 0 - assert result.properties[LOGICAL_COMPUTE_QUBITS] == 12 - assert result.properties[LOGICAL_MEMORY_QUBITS] == 0 - _assert_estimation_result(trace2, result, isa) - assert counter == 32 - - -def test_application_enumeration(): - @dataclass(kw_only=True) - class _Params: - size: int = field(default=1, metadata={"domain": range(1, 4)}) - - class TestApp(Application[_Params]): - def get_trace(self, parameters: _Params) -> Trace: - return Trace(parameters.size) - - app = TestApp() - assert sum(1 for _ in TestApp.q().enumerate(app.context())) == 3 - assert sum(1 for _ in TestApp.q(size=1).enumerate(app.context())) == 1 - assert sum(1 for _ in TestApp.q(size=[4, 5]).enumerate(app.context())) == 2 - - -def test_trace_enumeration(): - code = """ - {{ - use (a, b, c) = (Qubit(), Qubit(), Qubit()); - T(a); - CCNOT(a, b, c); - Rz(1.2345, a); - }} - """ - - app = QSharpApplication(code) - - ctx = app.context() - assert sum(1 for _ in QSharpApplication.q().enumerate(ctx)) == 1 - - assert sum(1 for _ in PSSPC.q().enumerate(ctx)) == 32 - - assert sum(1 for _ in LatticeSurgery.q().enumerate(ctx)) == 1 - - q = PSSPC.q() * 
LatticeSurgery.q() - assert sum(1 for _ in q.enumerate(ctx)) == 32 - - -def test_rotation_error_psspc(): - from qsharp.qre._enumeration import _enumerate_instances - - # This test helps to bound the variables for the number of rotations in PSSPC - - # Create a trace with a single rotation gate and ensure that the base error - # after PSSPC transformation is less than 1. - trace = Trace(1) - trace.add_operation(RZ, [0]) - - for psspc in _enumerate_instances(PSSPC, ccx_magic_states=False): - transformed = psspc.transform(trace) - assert transformed is not None - assert ( - transformed.base_error < 1.0 - ), f"Base error too high: {transformed.base_error} for {psspc.num_ts_per_rotation} T states per rotation" - - -def test_estimation_max_error(): - from qsharp.estimator import LogicalCounts - - app = QSharpApplication(LogicalCounts({"numQubits": 100, "measurementCount": 100})) - arch = AQREGateBased(gate_time=50, measurement_time=100) - - for max_error in [1e-1, 1e-2, 1e-3, 1e-4]: - results = estimate( - app, - arch, - SurfaceCode.q() * ExampleFactory.q(), - PSSPC.q() * LatticeSurgery.q(), - max_error=max_error, - ) - - assert len(results) == 1 - assert next(iter(results)).error <= max_error - - -def _assert_estimation_result(trace: Trace, result: EstimationResult, isa: ISA): - actual_qubits = ( - isa[LATTICE_SURGERY].expect_space(trace.compute_qubits) - + isa[T].expect_space() * result.factories[T].copies - ) - if CCX in trace.resource_states: - actual_qubits += isa[CCX].expect_space() * result.factories[CCX].copies - assert result.qubits == actual_qubits - - assert ( - result.runtime - == isa[LATTICE_SURGERY].expect_time(trace.compute_qubits) * trace.depth - ) - - actual_error = ( - trace.base_error - + isa[LATTICE_SURGERY].expect_error_rate(trace.compute_qubits) * trace.depth - + isa[T].expect_error_rate() * result.factories[T].states - ) - if CCX in trace.resource_states: - actual_error += isa[CCX].expect_error_rate() * result.factories[CCX].states - assert 
abs(result.error - actual_error) <= 1e-8 - - -# --- EstimationTable tests --- - - -def _make_entry(qubits, runtime, error, properties=None): - """Helper to create an EstimationTableEntry with a dummy InstructionSource.""" - return EstimationTableEntry( - qubits=qubits, - runtime=runtime, - error=error, - source=InstructionSource(), - properties=properties or {}, - ) - - -def test_estimation_table_default_columns(): - """Test that a new EstimationTable has the three default columns.""" - table = EstimationTable() - table.append(_make_entry(100, 5000, 0.01)) - - frame = table.as_frame() - assert list(frame.columns) == ["qubits", "runtime", "error"] - assert frame["qubits"][0] == 100 - assert frame["runtime"][0] == pd.Timedelta(5000, unit="ns") - assert frame["error"][0] == 0.01 - - -def test_estimation_table_multiple_rows(): - """Test as_frame with multiple entries.""" - table = EstimationTable() - table.append(_make_entry(100, 5000, 0.01)) - table.append(_make_entry(200, 10000, 0.02)) - table.append(_make_entry(300, 15000, 0.03)) - - frame = table.as_frame() - assert len(frame) == 3 - assert list(frame["qubits"]) == [100, 200, 300] - assert list(frame["error"]) == [0.01, 0.02, 0.03] - - -def test_estimation_table_empty(): - """Test as_frame with no entries produces an empty DataFrame.""" - table = EstimationTable() - frame = table.as_frame() - assert len(frame) == 0 - - -def test_estimation_table_add_column(): - """Test adding a column to the table.""" - VAL = 0 - - table = EstimationTable() - table.append(_make_entry(100, 5000, 0.01, properties={VAL: 42})) - table.append(_make_entry(200, 10000, 0.02, properties={VAL: 84})) - - table.add_column("val", lambda e: e.properties[VAL]) - - frame = table.as_frame() - assert list(frame.columns) == ["qubits", "runtime", "error", "val"] - assert list(frame["val"]) == [42, 84] - - -def test_estimation_table_add_column_with_formatter(): - """Test adding a column with a formatter.""" - NS = 0 - - table = EstimationTable() - 
table.append(_make_entry(100, 5000, 0.01, properties={NS: 1000})) - - table.add_column( - "duration", - lambda e: e.properties[NS], - formatter=lambda x: pd.Timedelta(x, unit="ns"), - ) - - frame = table.as_frame() - assert frame["duration"][0] == pd.Timedelta(1000, unit="ns") - - -def test_estimation_table_add_multiple_columns(): - """Test adding multiple columns preserves order.""" - A = 0 - B = 1 - C = 2 - - table = EstimationTable() - table.append(_make_entry(100, 5000, 0.01, properties={A: 1, B: 2, C: 3})) - - table.add_column("a", lambda e: e.properties[A]) - table.add_column("b", lambda e: e.properties[B]) - table.add_column("c", lambda e: e.properties[C]) - - frame = table.as_frame() - assert list(frame.columns) == ["qubits", "runtime", "error", "a", "b", "c"] - assert frame["a"][0] == 1 - assert frame["b"][0] == 2 - assert frame["c"][0] == 3 - - -def test_estimation_table_insert_column_at_beginning(): - """Test inserting a column at index 0.""" - NAME = 0 - - table = EstimationTable() - table.append(_make_entry(100, 5000, 0.01, properties={NAME: "test"})) - - table.insert_column(0, "name", lambda e: e.properties[NAME]) - - frame = table.as_frame() - assert list(frame.columns) == ["name", "qubits", "runtime", "error"] - assert frame["name"][0] == "test" - - -def test_estimation_table_insert_column_in_middle(): - """Test inserting a column between existing default columns.""" - EXTRA = 0 - - table = EstimationTable() - table.append(_make_entry(100, 5000, 0.01, properties={EXTRA: 99})) - - # Insert between qubits and runtime (index 1) - table.insert_column(1, "extra", lambda e: e.properties[EXTRA]) - - frame = table.as_frame() - assert list(frame.columns) == ["qubits", "extra", "runtime", "error"] - assert frame["extra"][0] == 99 - - -def test_estimation_table_insert_column_at_end(): - """Test inserting a column at the end (same effect as add_column).""" - LAST = 0 - - table = EstimationTable() - table.append(_make_entry(100, 5000, 0.01, properties={LAST: 
True})) - - # 3 default columns, inserting at index 3 = end - table.insert_column(3, "last", lambda e: e.properties[LAST]) - - frame = table.as_frame() - assert list(frame.columns) == ["qubits", "runtime", "error", "last"] - assert frame["last"][0] - - -def test_estimation_table_insert_column_with_formatter(): - """Test inserting a column with a formatter.""" - NS = 0 - - table = EstimationTable() - table.append(_make_entry(100, 5000, 0.01, properties={NS: 2000})) - - table.insert_column( - 0, - "custom_time", - lambda e: e.properties[NS], - formatter=lambda x: pd.Timedelta(x, unit="ns"), - ) - - frame = table.as_frame() - assert frame["custom_time"][0] == pd.Timedelta(2000, unit="ns") - assert list(frame.columns)[0] == "custom_time" - - -def test_estimation_table_insert_and_add_columns(): - """Test combining insert_column and add_column.""" - A = 0 - B = 0 - - table = EstimationTable() - table.append(_make_entry(100, 5000, 0.01, properties={A: 1, B: 2})) - - table.add_column("b", lambda e: e.properties[B]) - table.insert_column(0, "a", lambda e: e.properties[A]) - - frame = table.as_frame() - assert list(frame.columns) == ["a", "qubits", "runtime", "error", "b"] - - -def test_estimation_table_factory_summary_no_factories(): - """Test factory summary column when entries have no factories.""" - table = EstimationTable() - table.append(_make_entry(100, 5000, 0.01)) - - table.add_factory_summary_column() - - frame = table.as_frame() - assert "factories" in frame.columns - assert frame["factories"][0] == "None" - - -def test_estimation_table_factory_summary_with_estimation(): - """Test factory summary column with real estimation results.""" - code = """ - { - use (a, b, c) = (Qubit(), Qubit(), Qubit()); - T(a); - CCNOT(a, b, c); - Rz(1.2345, a); - } - """ - app = QSharpApplication(code) - arch = AQREGateBased(gate_time=50, measurement_time=100) - results = estimate( - app, - arch, - SurfaceCode.q() * ExampleFactory.q(), - PSSPC.q() * LatticeSurgery.q(), - 
max_error=0.5, - ) - - assert len(results) >= 1 - - results.add_factory_summary_column() - frame = results.as_frame() - - assert "factories" in frame.columns - # Each result should mention T in the factory summary - for val in frame["factories"]: - assert "T" in val - - -def test_estimation_table_add_column_from_source(): - """Test adding a column that accesses the InstructionSource (like distance).""" - code = """ - { - use (a, b, c) = (Qubit(), Qubit(), Qubit()); - T(a); - CCNOT(a, b, c); - Rz(1.2345, a); - } - """ - app = QSharpApplication(code) - arch = AQREGateBased(gate_time=50, measurement_time=100) - results = estimate( - app, - arch, - SurfaceCode.q() * ExampleFactory.q(), - PSSPC.q() * LatticeSurgery.q(), - max_error=0.5, - ) - - assert len(results) >= 1 - - results.add_column( - "compute_distance", - lambda entry: entry.source[LATTICE_SURGERY].instruction[DISTANCE], - ) - - frame = results.as_frame() - assert "compute_distance" in frame.columns - for d in frame["compute_distance"]: - assert isinstance(d, int) - assert d >= 3 - - -def test_estimation_table_add_column_from_properties(): - """Test adding columns that access trace properties from estimation.""" - code = """ - { - use (a, b, c) = (Qubit(), Qubit(), Qubit()); - T(a); - CCNOT(a, b, c); - Rz(1.2345, a); - } - """ - app = QSharpApplication(code) - arch = AQREGateBased(gate_time=50, measurement_time=100) - results = estimate( - app, - arch, - SurfaceCode.q() * ExampleFactory.q(), - PSSPC.q() * LatticeSurgery.q(), - max_error=0.5, - ) - - assert len(results) >= 1 - - results.add_column( - "num_ts_per_rotation", - lambda entry: entry.properties[NUM_TS_PER_ROTATION], - ) - - frame = results.as_frame() - assert "num_ts_per_rotation" in frame.columns - for val in frame["num_ts_per_rotation"]: - assert isinstance(val, int) - assert val >= 1 - - -def test_estimation_table_insert_column_before_defaults(): - """Test inserting a name column before all default columns, similar to the factoring notebook.""" - 
code = """ - { - use (a, b, c) = (Qubit(), Qubit(), Qubit()); - T(a); - CCNOT(a, b, c); - Rz(1.2345, a); - } - """ - app = QSharpApplication(code) - arch = AQREGateBased(gate_time=50, measurement_time=100) - results = estimate( - app, - arch, - SurfaceCode.q() * ExampleFactory.q(), - PSSPC.q() * LatticeSurgery.q(), - max_error=0.5, - name="test_experiment", - ) - - assert len(results) >= 1 - - # Add a factory summary at the end - results.add_factory_summary_column() - - frame = results.as_frame() - assert frame.columns[0] == "name" - assert frame.columns[-1] == "factories" - # Default columns should still be in order - assert list(frame.columns[1:4]) == ["qubits", "runtime", "error"] - - -def test_estimation_table_as_frame_sortable(): - """Test that the DataFrame from as_frame can be sorted, as done in the factoring tests.""" - table = EstimationTable() - table.append(_make_entry(300, 15000, 0.03)) - table.append(_make_entry(100, 5000, 0.01)) - table.append(_make_entry(200, 10000, 0.02)) - - frame = table.as_frame() - sorted_frame = frame.sort_values(by=["qubits", "runtime"]).reset_index(drop=True) - - assert list(sorted_frame["qubits"]) == [100, 200, 300] - assert list(sorted_frame["error"]) == [0.01, 0.02, 0.03] - - -def test_estimation_table_computed_column(): - """Test adding a column that computes a derived value from the entry.""" - table = EstimationTable() - table.append(_make_entry(100, 5_000_000, 0.01)) - table.append(_make_entry(200, 10_000_000, 0.02)) - - # Compute qubits * error as a derived metric - table.add_column("qubit_error_product", lambda e: e.qubits * e.error) - - frame = table.as_frame() - assert frame["qubit_error_product"][0] == pytest.approx(1.0) - assert frame["qubit_error_product"][1] == pytest.approx(4.0) - - -def test_estimation_table_plot_returns_figure(): - """Test that plot() returns a matplotlib Figure with correct axes.""" - from matplotlib.figure import Figure - - table = EstimationTable() - table.append(_make_entry(100, 
5_000_000_000, 0.01)) - table.append(_make_entry(200, 10_000_000_000, 0.02)) - table.append(_make_entry(50, 50_000_000_000, 0.005)) - - fig = table.plot() - - assert isinstance(fig, Figure) - ax = fig.axes[0] - assert ax.get_ylabel() == "Physical qubits" - assert ax.get_xlabel() == "Runtime" - assert ax.get_xscale() == "log" - assert ax.get_yscale() == "log" - - # Verify data points - offsets = ax.collections[0].get_offsets() - assert len(cast(Sized, offsets)) == 3 - - -def test_estimation_table_plot_empty_raises(): - """Test that plot() raises ValueError on an empty table.""" - table = EstimationTable() - with pytest.raises(ValueError, match="Cannot plot an empty EstimationTable"): - table.plot() - - -def test_estimation_table_plot_single_entry(): - """Test that plot() works with a single entry.""" - from matplotlib.figure import Figure - - table = EstimationTable() - table.append(_make_entry(100, 1_000_000, 0.01)) - - fig = table.plot() - assert isinstance(fig, Figure) - - offsets = fig.axes[0].collections[0].get_offsets() - assert len(cast(Sized, offsets)) == 1 - - -def test_estimation_table_plot_with_runtime_unit(): - """Test that plot(runtime_unit=...) 
scales x values and labels the axis.""" - table = EstimationTable() - # 1 hour = 3600e9 ns, 2 hours = 7200e9 ns - table.append(_make_entry(100, int(3600e9), 0.01)) - table.append(_make_entry(200, int(7200e9), 0.02)) - - fig = table.plot(runtime_unit="hours") - - ax = fig.axes[0] - assert ax.get_xlabel() == "Runtime (hours)" - - # Verify the x data is scaled: should be 1.0 and 2.0 hours - offsets = cast(list, ax.collections[0].get_offsets()) - assert offsets[0][0] == pytest.approx(1.0) - assert offsets[1][0] == pytest.approx(2.0) - - -def test_estimation_table_plot_invalid_runtime_unit(): - """Test that plot() raises ValueError for an unknown runtime_unit.""" - table = EstimationTable() - table.append(_make_entry(100, 1000, 0.01)) - with pytest.raises(ValueError, match="Unknown runtime_unit"): - table.plot(runtime_unit="fortnights") - - -def _ll_files(): - ll_dir = ( - Path(__file__).parent.parent - / "tests-integration" - / "resources" - / "adaptive_ri" - / "output" - ) - return sorted(ll_dir.glob("*.ll")) - - -@pytest.mark.parametrize("ll_file", _ll_files(), ids=lambda p: p.stem) -def test_trace_from_qir(ll_file): - # NOTE: This test is primarily to ensure that the function can parse real - # QIR output without errors, rather than checking specific properties of the - # trace. - try: - trace_from_qir(ll_file.read_text()) - except ValueError as e: - # The only reason of failure is presence of control flow - assert ( - str(e) - == "simulation of programs with branching control flow is not supported" - ) - - -def test_trace_from_qir_handles_all_instruction_ids(): - """Verify that trace_from_qir handles every QirInstructionId except CorrelatedNoise. - - Generates a synthetic QIR program containing one instance of each gate - intrinsic recognised by AggregateGatesPass and asserts that trace_from_qir - processes all of them without error. 
- """ - import pyqir - import pyqir.qis as qis - from qsharp._native import QirInstructionId - from qsharp.qre.interop._qir import _GATE_MAP, _MEAS_MAP, _SKIP - - # -- Completeness check: every QirInstructionId must be covered -------- - handled_ids = ( - [qir_id for qir_id, _, _ in _GATE_MAP] - + [qir_id for qir_id, _ in _MEAS_MAP] - + list(_SKIP) - ) - # Exhaustive list of all QirInstructionId variants (pyo3 enums are not iterable) - all_ids = [ - QirInstructionId.I, - QirInstructionId.H, - QirInstructionId.X, - QirInstructionId.Y, - QirInstructionId.Z, - QirInstructionId.S, - QirInstructionId.SAdj, - QirInstructionId.SX, - QirInstructionId.SXAdj, - QirInstructionId.T, - QirInstructionId.TAdj, - QirInstructionId.CNOT, - QirInstructionId.CX, - QirInstructionId.CY, - QirInstructionId.CZ, - QirInstructionId.CCX, - QirInstructionId.SWAP, - QirInstructionId.RX, - QirInstructionId.RY, - QirInstructionId.RZ, - QirInstructionId.RXX, - QirInstructionId.RYY, - QirInstructionId.RZZ, - QirInstructionId.RESET, - QirInstructionId.M, - QirInstructionId.MResetZ, - QirInstructionId.MZ, - QirInstructionId.Move, - QirInstructionId.ReadResult, - QirInstructionId.ResultRecordOutput, - QirInstructionId.BoolRecordOutput, - QirInstructionId.IntRecordOutput, - QirInstructionId.DoubleRecordOutput, - QirInstructionId.TupleRecordOutput, - QirInstructionId.ArrayRecordOutput, - QirInstructionId.CorrelatedNoise, - ] - unhandled = [ - i - for i in all_ids - if i not in handled_ids and i != QirInstructionId.CorrelatedNoise - ] - assert unhandled == [], ( - f"QirInstructionId values not covered by _GATE_MAP, _MEAS_MAP, or _SKIP: " - f"{', '.join(str(i) for i in unhandled)}" - ) - - # -- Generate a QIR program with every producible gate ----------------- - simple = pyqir.SimpleModule("test_all_gates", num_qubits=4, num_results=3) - builder = simple.builder - ctx = simple.context - q = simple.qubits - r = simple.results - - void_ty = pyqir.Type.void(ctx) - qubit_ty = pyqir.qubit_type(ctx) - 
result_ty = pyqir.result_type(ctx) - double_ty = pyqir.Type.double(ctx) - i64_ty = pyqir.IntType(ctx, 64) - - def declare(name, param_types): - return simple.add_external_function( - name, pyqir.FunctionType(void_ty, param_types) - ) - - # Single-qubit gates (pyqir.qis builtins) - qis.h(builder, q[0]) - qis.x(builder, q[0]) - qis.y(builder, q[0]) - qis.z(builder, q[0]) - qis.s(builder, q[0]) - qis.s_adj(builder, q[0]) - qis.t(builder, q[0]) - qis.t_adj(builder, q[0]) - - # SX — not in pyqir.qis - sx_fn = declare("__quantum__qis__sx__body", [qubit_ty]) - builder.call(sx_fn, [q[0]]) - - # Two-qubit gates (qis.cx emits __quantum__qis__cnot__body which the - # pass does not handle, so use builder.call with the correct name) - cx_fn = declare("__quantum__qis__cx__body", [qubit_ty, qubit_ty]) - builder.call(cx_fn, [q[0], q[1]]) - qis.cz(builder, q[0], q[1]) - qis.swap(builder, q[0], q[1]) - - cy_fn = declare("__quantum__qis__cy__body", [qubit_ty, qubit_ty]) - builder.call(cy_fn, [q[0], q[1]]) - - # Three-qubit gate - qis.ccx(builder, q[0], q[1], q[2]) - - # Single-qubit rotations - qis.rx(builder, 1.0, q[0]) - qis.ry(builder, 1.0, q[0]) - qis.rz(builder, 1.0, q[0]) - - # Two-qubit rotations — not in pyqir.qis - rot2_ty = [double_ty, qubit_ty, qubit_ty] - angle = pyqir.const(double_ty, 1.0) - for name in ("rxx", "ryy", "rzz"): - fn = declare(f"__quantum__qis__{name}__body", rot2_ty) - builder.call(fn, [angle, q[0], q[1]]) - - # Measurements - qis.mz(builder, q[0], r[0]) - - m_fn = declare("__quantum__qis__m__body", [qubit_ty, result_ty]) - builder.call(m_fn, [q[1], r[1]]) - - mresetz_fn = declare("__quantum__qis__mresetz__body", [qubit_ty, result_ty]) - builder.call(mresetz_fn, [q[2], r[2]]) - - # Reset / Move - qis.reset(builder, q[0]) - - move_fn = declare("__quantum__qis__move__body", [qubit_ty]) - builder.call(move_fn, [q[0]]) - - # Output recording - tag = simple.add_byte_string(b"tag") - arr_fn = declare("__quantum__rt__array_record_output", [i64_ty, tag.type]) - 
builder.call(arr_fn, [pyqir.const(i64_ty, 1), tag]) - - rec_fn = declare("__quantum__rt__result_record_output", [result_ty, tag.type]) - builder.call(rec_fn, [r[0], tag]) - - tup_fn = declare("__quantum__rt__tuple_record_output", [i64_ty, tag.type]) - builder.call(tup_fn, [pyqir.const(i64_ty, 1), tag]) - - # -- Run trace_from_qir and verify it succeeds ------------------------- - trace = trace_from_qir(simple.ir()) - assert trace is not None - - -@pytest.mark.skipif( - "SLOW_TESTS" not in os.environ, - reason="turn on slow tests by setting SLOW_TESTS=1 in the environment", -) -@pytest.mark.parametrize( - "post_process, use_graph", - [ - (False, False), - (True, False), - (False, True), - (True, True), - ], -) -def test_estimation_methods(post_process, use_graph): - counts = LogicalCounts( - { - "numQubits": 1000, - "tCount": 1_500_000, - "rotationCount": 0, - "rotationDepth": 0, - "cczCount": 1_000_000_000, - "ccixCount": 0, - "measurementCount": 25_000_000, - "numComputeQubits": 200, - "readFromMemoryCount": 30_000_000, - "writeToMemoryCount": 30_000_000, - } - ) - - trace_query = PSSPC.q() * LatticeSurgery.q(slow_down_factor=[1.0, 2.0]) - isa_query = ( - SurfaceCode.q() - * RoundBasedFactory.q() - * TwoDimensionalYokedSurfaceCode.q(source=SurfaceCode.q()) - ) - - app = QSharpApplication(counts) - arch = AQREGateBased(gate_time=50, measurement_time=100) - - results = estimate( - app, - arch, - isa_query, - trace_query, - max_error=1 / 3, - post_process=post_process, - use_graph=use_graph, - ) - results.add_factory_summary_column() - - assert [(result.qubits, result.runtime) for result in results] == [ - (238707, 23997050000000), - (240407, 11998525000000), - ] - - print() - print(results.stats) - - -def test_rotation_buckets(): - from qsharp.qre.interop._qsharp import _bucketize_rotation_counts - - print() - - r_count = 15066 - r_depth = 14756 - q_count = 291 - - result = _bucketize_rotation_counts(r_count, r_depth) - - a_count = 0 - a_depth = 0 - for c, d in 
result: - print(c, d) - assert c <= q_count - assert c > 0 - a_count += c * d - a_depth += d - - assert a_count == r_count - assert a_depth == r_depth diff --git a/source/qre/src/isa.rs b/source/qre/src/isa.rs index 3fcde87d89..cdeac80524 100644 --- a/source/qre/src/isa.rs +++ b/source/qre/src/isa.rs @@ -12,10 +12,13 @@ use num_traits::FromPrimitive; use rustc_hash::{FxHashMap, FxHashSet}; use serde::{Deserialize, Serialize}; -use crate::{ParetoFrontier3D, trace::instruction_ids::instruction_name}; +use crate::trace::instruction_ids::instruction_name; pub mod property_keys; +mod provenance; +pub use provenance::ProvenanceGraph; + #[cfg(test)] mod tests; @@ -703,327 +706,3 @@ impl ConstraintBound { } } } - -pub struct ProvenanceGraph { - nodes: Vec, - // A consecutive list of child node indices for each node, where the - // children of node i are located at children[offset..offset+num_children] - // in the children vector. - children: Vec, - // Per-instruction-ID index of Pareto-optimal node indices. - // Built by `build_pareto_index()` after all nodes have been added. - pareto_index: FxHashMap>, -} - -impl Default for ProvenanceGraph { - fn default() -> Self { - // Initialize with a dummy node at index 0 to simplify indexing logic - // (so that 0 can be used as a "null" provenance) - let empty = ProvenanceNode::default(); - ProvenanceGraph { - nodes: vec![empty], - children: Vec::new(), - pareto_index: FxHashMap::default(), - } - } -} - -/// Thin wrapper for 3D Pareto comparison of instructions at arity 1. 
-struct InstructionParetoItem { - node_index: usize, - space: u64, - time: u64, - error: f64, -} - -impl crate::ParetoItem3D for InstructionParetoItem { - type Objective1 = u64; - type Objective2 = u64; - type Objective3 = f64; - - fn objective1(&self) -> u64 { - self.space - } - fn objective2(&self) -> u64 { - self.time - } - fn objective3(&self) -> f64 { - self.error - } -} - -impl ProvenanceGraph { - #[must_use] - pub fn new() -> Self { - Self::default() - } - - pub fn add_node( - &mut self, - mut instruction: Instruction, - transform_id: u64, - children: &[usize], - ) -> usize { - let node_index = self.nodes.len(); - instruction.source = node_index; - let offset = self.children.len(); - let num_children = children.len(); - self.children.extend_from_slice(children); - self.nodes.push(ProvenanceNode { - instruction, - transform_id, - offset, - num_children, - }); - node_index - } - - #[must_use] - pub fn instruction(&self, node_index: usize) -> &Instruction { - &self.nodes[node_index].instruction - } - - #[must_use] - pub fn transform_id(&self, node_index: usize) -> u64 { - self.nodes[node_index].transform_id - } - - #[must_use] - pub fn children(&self, node_index: usize) -> &[usize] { - let node = &self.nodes[node_index]; - &self.children[node.offset..node.offset + node.num_children] - } - - #[must_use] - pub fn num_nodes(&self) -> usize { - self.nodes.len() - 1 - } - - #[must_use] - pub fn num_edges(&self) -> usize { - self.children.len() - } - - /// Builds the per-instruction-ID Pareto index. - /// - /// For each instruction ID in the graph, collects all nodes and retains - /// only the Pareto-optimal subset with respect to (space, time, `error_rate`) - /// evaluated at arity 1. Instructions with different encodings or - /// properties are never in competition. - /// - /// Must be called after all nodes have been added. 
- pub fn build_pareto_index(&mut self) { - // Group node indices by (instruction_id, encoding, properties) - let mut groups: FxHashMap> = FxHashMap::default(); - for idx in 1..self.nodes.len() { - let instr = &self.nodes[idx].instruction; - groups.entry(instr.id).or_default().push(idx); - } - - let mut pareto_index = FxHashMap::default(); - for (id, node_indices) in groups { - // Sub-partition by encoding and property keys to avoid comparing - // incompatible instructions (Risk R2 mitigation) - #[allow(clippy::type_complexity)] - let mut sub_groups: FxHashMap<(Encoding, Vec<(u64, u64)>), Vec> = - FxHashMap::default(); - for &idx in &node_indices { - let instr = &self.nodes[idx].instruction; - let mut prop_vec: Vec<(u64, u64)> = instr - .properties - .as_ref() - .map(|p| { - let mut v: Vec<_> = p.iter().map(|(&k, &v)| (k, v)).collect(); - v.sort_unstable(); - v - }) - .unwrap_or_default(); - prop_vec.sort_unstable(); - sub_groups - .entry((instr.encoding, prop_vec)) - .or_default() - .push(idx); - } - - let mut pareto_nodes = Vec::new(); - for (_key, indices) in sub_groups { - let items: Vec = indices - .iter() - .filter_map(|&idx| { - let instr = &self.nodes[idx].instruction; - let space = instr.space(Some(1))?; - let time = instr.time(Some(1))?; - let error = instr.error_rate(Some(1))?; - Some(InstructionParetoItem { - node_index: idx, - space, - time, - error, - }) - }) - .collect(); - - let frontier: ParetoFrontier3D = items.into_iter().collect(); - pareto_nodes.extend(frontier.into_iter().map(|item| item.node_index)); - } - - pareto_index.insert(id, pareto_nodes); - } - - self.pareto_index = pareto_index; - } - - /// Returns the Pareto-optimal node indices for a given instruction ID. - #[must_use] - pub fn pareto_nodes(&self, instruction_id: u64) -> Option<&[usize]> { - self.pareto_index.get(&instruction_id).map(Vec::as_slice) - } - - /// Returns all instruction IDs that have Pareto-optimal entries. 
- #[must_use] - pub fn pareto_instruction_ids(&self) -> Vec { - self.pareto_index.keys().copied().collect() - } - - /// Returns the raw node count (including the sentinel at index 0). - #[must_use] - pub fn raw_node_count(&self) -> usize { - self.nodes.len() - } - - /// Returns the total number of ISAs that can be formed from Pareto-optimal - /// nodes. - /// - /// Requires [`build_pareto_index`](Self::build_pareto_index) to have - /// been called. - #[must_use] - pub fn total_isa_count(&self) -> usize { - self.pareto_index.values().map(Vec::len).product() - } - - /// Returns ISAs formed from Pareto-optimal nodes that satisfy the given - /// requirements. - /// - /// For each constraint, selects matching Pareto-optimal nodes. Produces - /// the Cartesian product of per-constraint match sets, each augmented - /// with one representative node per unconstrained instruction ID (so - /// that returned ISAs contain entries for all instruction types in the - /// graph). - /// - /// When `min_node_idx` is `Some(n)`, only Pareto nodes with index ≥ n - /// are considered for constrained groups. Unconstrained "extra" nodes - /// are not filtered since they serve only as default placeholders. - /// - /// Requires [`build_pareto_index`](Self::build_pareto_index) to have - /// been called. - #[must_use] - pub fn query_satisfying( - &self, - graph_arc: &Arc>, - requirements: &ISARequirements, - min_node_idx: Option, - ) -> Vec { - let min_idx = min_node_idx.unwrap_or(0); - - let mut constrained_groups: Vec> = Vec::new(); - let mut constrained_ids: FxHashSet = FxHashSet::default(); - - for constraint in requirements.constraints.values() { - constrained_ids.insert(constraint.id()); - - // When a node range is specified, scan ALL nodes in the range - // instead of using the global Pareto index. The global index - // may have pruned nodes from this range as duplicates of - // earlier, equivalent nodes outside the range. 
- let matching: Vec<(u64, usize)> = if min_idx > 0 { - (min_idx..self.nodes.len()) - .filter(|&node_idx| { - let instr = &self.nodes[node_idx].instruction; - instr.id == constraint.id() && constraint.is_satisfied_by(instr) - }) - .map(|node_idx| (constraint.id(), node_idx)) - .collect() - } else { - let Some(pareto) = self.pareto_index.get(&constraint.id()) else { - return Vec::new(); - }; - pareto - .iter() - .filter(|&&node_idx| constraint.is_satisfied_by(self.instruction(node_idx))) - .map(|&node_idx| (constraint.id(), node_idx)) - .collect() - }; - - if matching.is_empty() { - return Vec::new(); - } - constrained_groups.push(matching); - } - - // One representative node per unconstrained instruction ID. - // When a Pareto index is available, use it; otherwise scan all - // nodes (this path is used during populate() before the index - // is built). - let extra_nodes: Vec<(u64, usize)> = if self.pareto_index.is_empty() { - let mut seen: FxHashMap = FxHashMap::default(); - for idx in 1..self.nodes.len() { - let id = self.nodes[idx].instruction.id; - if !constrained_ids.contains(&id) { - seen.entry(id).or_insert(idx); - } - } - seen.into_iter().collect() - } else { - self.pareto_index - .iter() - .filter(|(id, _)| !constrained_ids.contains(id)) - .filter_map(|(&id, nodes)| nodes.first().map(|&n| (id, n))) - .collect() - }; - - // Cartesian product of constrained groups - let mut combinations: Vec> = vec![Vec::new()]; - for group in &constrained_groups { - let mut next = Vec::with_capacity(combinations.len() * group.len()); - for combo in &combinations { - for &item in group { - let mut extended = combo.clone(); - extended.push(item); - next.push(extended); - } - } - combinations = next; - } - - // Build ISAs from selections - combinations - .into_iter() - .map(|mut combo| { - combo.extend(extra_nodes.iter().copied()); - let mut isa = ISA::with_graph(Arc::clone(graph_arc)); - for (id, node_idx) in combo { - isa.add_node(id, node_idx); - } - isa - }) - .collect() - } 
-} - -struct ProvenanceNode { - instruction: Instruction, - transform_id: u64, - offset: usize, - num_children: usize, -} - -impl Default for ProvenanceNode { - fn default() -> Self { - ProvenanceNode { - instruction: Instruction::fixed_arity(0, Encoding::Physical, 0, 0, None, None, 0.0), - transform_id: 0, - offset: 0, - num_children: 0, - } - } -} diff --git a/source/qre/src/isa/provenance.rs b/source/qre/src/isa/provenance.rs new file mode 100644 index 0000000000..8b59660639 --- /dev/null +++ b/source/qre/src/isa/provenance.rs @@ -0,0 +1,332 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use std::sync::{Arc, RwLock}; + +use rustc_hash::{FxHashMap, FxHashSet}; + +use crate::{Encoding, ISA, ISARequirements, Instruction, ParetoFrontier3D}; + +pub struct ProvenanceGraph { + nodes: Vec, + // A consecutive list of child node indices for each node, where the + // children of node i are located at children[offset..offset+num_children] + // in the children vector. + children: Vec, + // Per-instruction-ID index of Pareto-optimal node indices. + // Built by `build_pareto_index()` after all nodes have been added. + pareto_index: FxHashMap>, +} + +impl Default for ProvenanceGraph { + fn default() -> Self { + // Initialize with a dummy node at index 0 to simplify indexing logic + // (so that 0 can be used as a "null" provenance) + let empty = ProvenanceNode::default(); + ProvenanceGraph { + nodes: vec![empty], + children: Vec::new(), + pareto_index: FxHashMap::default(), + } + } +} + +/// Thin wrapper for 3D Pareto comparison of instructions at arity 1. 
+struct InstructionParetoItem { + node_index: usize, + space: u64, + time: u64, + error: f64, +} + +impl crate::ParetoItem3D for InstructionParetoItem { + type Objective1 = u64; + type Objective2 = u64; + type Objective3 = f64; + + fn objective1(&self) -> u64 { + self.space + } + fn objective2(&self) -> u64 { + self.time + } + fn objective3(&self) -> f64 { + self.error + } +} + +impl ProvenanceGraph { + #[must_use] + pub fn new() -> Self { + Self::default() + } + + pub fn add_node( + &mut self, + mut instruction: Instruction, + transform_id: u64, + children: &[usize], + ) -> usize { + let node_index = self.nodes.len(); + instruction.source = node_index; + let offset = self.children.len(); + let num_children = children.len(); + self.children.extend_from_slice(children); + self.nodes.push(ProvenanceNode { + instruction, + transform_id, + offset, + num_children, + }); + node_index + } + + #[must_use] + pub fn instruction(&self, node_index: usize) -> &Instruction { + &self.nodes[node_index].instruction + } + + #[must_use] + pub fn transform_id(&self, node_index: usize) -> u64 { + self.nodes[node_index].transform_id + } + + #[must_use] + pub fn children(&self, node_index: usize) -> &[usize] { + let node = &self.nodes[node_index]; + &self.children[node.offset..node.offset + node.num_children] + } + + #[must_use] + pub fn num_nodes(&self) -> usize { + self.nodes.len() - 1 + } + + #[must_use] + pub fn num_edges(&self) -> usize { + self.children.len() + } + + /// Builds the per-instruction-ID Pareto index. + /// + /// For each instruction ID in the graph, collects all nodes and retains + /// only the Pareto-optimal subset with respect to (space, time, `error_rate`) + /// evaluated at arity 1. Instructions with different encodings or + /// properties are never in competition. + /// + /// Must be called after all nodes have been added. 
+ pub fn build_pareto_index(&mut self) { + // Group node indices by (instruction_id, encoding, properties) + let mut groups: FxHashMap> = FxHashMap::default(); + for idx in 1..self.nodes.len() { + let instr = &self.nodes[idx].instruction; + groups.entry(instr.id).or_default().push(idx); + } + + let mut pareto_index = FxHashMap::default(); + for (id, node_indices) in groups { + // Sub-partition by encoding and property keys to avoid comparing + // incompatible instructions (Risk R2 mitigation) + #[allow(clippy::type_complexity)] + let mut sub_groups: FxHashMap<(Encoding, Vec<(u64, u64)>), Vec> = + FxHashMap::default(); + for &idx in &node_indices { + let instr = &self.nodes[idx].instruction; + let mut prop_vec: Vec<(u64, u64)> = instr + .properties + .as_ref() + .map(|p| { + let mut v: Vec<_> = p.iter().map(|(&k, &v)| (k, v)).collect(); + v.sort_unstable(); + v + }) + .unwrap_or_default(); + prop_vec.sort_unstable(); + sub_groups + .entry((instr.encoding, prop_vec)) + .or_default() + .push(idx); + } + + let mut pareto_nodes = Vec::new(); + for (_key, indices) in sub_groups { + let items: Vec = indices + .iter() + .filter_map(|&idx| { + let instr = &self.nodes[idx].instruction; + let space = instr.space(Some(1))?; + let time = instr.time(Some(1))?; + let error = instr.error_rate(Some(1))?; + Some(InstructionParetoItem { + node_index: idx, + space, + time, + error, + }) + }) + .collect(); + + let frontier: ParetoFrontier3D = items.into_iter().collect(); + pareto_nodes.extend(frontier.into_iter().map(|item| item.node_index)); + } + + pareto_index.insert(id, pareto_nodes); + } + + self.pareto_index = pareto_index; + } + + /// Returns the Pareto-optimal node indices for a given instruction ID. + #[must_use] + pub fn pareto_nodes(&self, instruction_id: u64) -> Option<&[usize]> { + self.pareto_index.get(&instruction_id).map(Vec::as_slice) + } + + /// Returns all instruction IDs that have Pareto-optimal entries. 
+ #[must_use] + pub fn pareto_instruction_ids(&self) -> Vec { + self.pareto_index.keys().copied().collect() + } + + /// Returns the raw node count (including the sentinel at index 0). + #[must_use] + pub fn raw_node_count(&self) -> usize { + self.nodes.len() + } + + /// Returns the total number of ISAs that can be formed from Pareto-optimal + /// nodes. + /// + /// Requires [`build_pareto_index`](Self::build_pareto_index) to have + /// been called. + #[must_use] + pub fn total_isa_count(&self) -> usize { + self.pareto_index.values().map(Vec::len).product() + } + + /// Returns ISAs formed from Pareto-optimal nodes that satisfy the given + /// requirements. + /// + /// For each constraint, selects matching Pareto-optimal nodes. Produces + /// the Cartesian product of per-constraint match sets, each augmented + /// with one representative node per unconstrained instruction ID (so + /// that returned ISAs contain entries for all instruction types in the + /// graph). + /// + /// When `min_node_idx` is `Some(n)`, only Pareto nodes with index ≥ n + /// are considered for constrained groups. Unconstrained "extra" nodes + /// are not filtered since they serve only as default placeholders. + /// + /// Requires [`build_pareto_index`](Self::build_pareto_index) to have + /// been called. + #[must_use] + pub fn query_satisfying( + &self, + graph_arc: &Arc>, + requirements: &ISARequirements, + min_node_idx: Option, + ) -> Vec { + let min_idx = min_node_idx.unwrap_or(0); + + let mut constrained_groups: Vec> = Vec::new(); + let mut constrained_ids: FxHashSet = FxHashSet::default(); + + for constraint in requirements.constraints.values() { + constrained_ids.insert(constraint.id()); + + // When a node range is specified, scan ALL nodes in the range + // instead of using the global Pareto index. The global index + // may have pruned nodes from this range as duplicates of + // earlier, equivalent nodes outside the range. 
+ let matching: Vec<(u64, usize)> = if min_idx > 0 { + (min_idx..self.nodes.len()) + .filter(|&node_idx| { + let instr = &self.nodes[node_idx].instruction; + instr.id == constraint.id() && constraint.is_satisfied_by(instr) + }) + .map(|node_idx| (constraint.id(), node_idx)) + .collect() + } else { + let Some(pareto) = self.pareto_index.get(&constraint.id()) else { + return Vec::new(); + }; + pareto + .iter() + .filter(|&&node_idx| constraint.is_satisfied_by(self.instruction(node_idx))) + .map(|&node_idx| (constraint.id(), node_idx)) + .collect() + }; + + if matching.is_empty() { + return Vec::new(); + } + constrained_groups.push(matching); + } + + // One representative node per unconstrained instruction ID. + // When a Pareto index is available, use it; otherwise scan all + // nodes (this path is used during populate() before the index + // is built). + let extra_nodes: Vec<(u64, usize)> = if self.pareto_index.is_empty() { + let mut seen: FxHashMap = FxHashMap::default(); + for idx in 1..self.nodes.len() { + let id = self.nodes[idx].instruction.id; + if !constrained_ids.contains(&id) { + seen.entry(id).or_insert(idx); + } + } + seen.into_iter().collect() + } else { + self.pareto_index + .iter() + .filter(|(id, _)| !constrained_ids.contains(id)) + .filter_map(|(&id, nodes)| nodes.first().map(|&n| (id, n))) + .collect() + }; + + // Cartesian product of constrained groups + let mut combinations: Vec> = vec![Vec::new()]; + for group in &constrained_groups { + let mut next = Vec::with_capacity(combinations.len() * group.len()); + for combo in &combinations { + for &item in group { + let mut extended = combo.clone(); + extended.push(item); + next.push(extended); + } + } + combinations = next; + } + + // Build ISAs from selections + combinations + .into_iter() + .map(|mut combo| { + combo.extend(extra_nodes.iter().copied()); + let mut isa = ISA::with_graph(Arc::clone(graph_arc)); + for (id, node_idx) in combo { + isa.add_node(id, node_idx); + } + isa + }) + .collect() + } 
+} + +struct ProvenanceNode { + instruction: Instruction, + transform_id: u64, + offset: usize, + num_children: usize, +} + +impl Default for ProvenanceNode { + fn default() -> Self { + ProvenanceNode { + instruction: Instruction::fixed_arity(0, Encoding::Physical, 0, 0, None, None, 0.0), + transform_id: 0, + offset: 0, + num_children: 0, + } + } +} diff --git a/source/qre/src/trace.rs b/source/qre/src/trace.rs index 08858a4551..0e2d1ef106 100644 --- a/source/qre/src/trace.rs +++ b/source/qre/src/trace.rs @@ -2,11 +2,7 @@ // Licensed under the MIT License. use std::{ - collections::hash_map::DefaultHasher, fmt::{Display, Formatter}, - hash::{Hash, Hasher}, - iter::repeat_with, - sync::{Arc, RwLock, atomic::AtomicUsize}, vec, }; @@ -14,14 +10,17 @@ use rustc_hash::{FxHashMap, FxHashSet}; use serde::{Deserialize, Serialize}; use crate::{ - ConstraintBound, Encoding, Error, EstimationCollection, EstimationResult, FactoryResult, ISA, - ISARequirements, Instruction, InstructionConstraint, LockedISA, ProvenanceGraph, ResultSummary, + ConstraintBound, Encoding, Error, EstimationResult, FactoryResult, ISA, ISARequirements, + Instruction, InstructionConstraint, LockedISA, property_keys::{ LOGICAL_COMPUTE_QUBITS, LOGICAL_MEMORY_QUBITS, PHYSICAL_COMPUTE_QUBITS, PHYSICAL_FACTORY_QUBITS, PHYSICAL_MEMORY_QUBITS, }, }; +mod estimation; +pub use estimation::{estimate_parallel, estimate_with_graph}; + pub mod instruction_ids; use instruction_ids::instruction_name; #[cfg(test)] @@ -752,452 +751,3 @@ fn get_error_rate_by_id(isa: &LockedISA<'_>, id: u64) -> Result { .error_rate(None) .ok_or(Error::CannotExtractErrorRate(id)) } - -/// Estimates all (trace, ISA) combinations in parallel, returning only the -/// successful results collected into an [`EstimationCollection`]. -/// -/// This uses a shared atomic counter as a lock-free work queue. Each worker -/// thread atomically claims the next job index, maps it to a `(trace, isa)` -/// pair, and runs the estimation. 
This keeps all available cores busy until -/// the last job completes. -/// -/// # Work distribution -/// -/// Jobs are numbered `0 .. traces.len() * isas.len()`. For job index `j`: -/// - `trace_idx = j / isas.len()` -/// - `isa_idx = j % isas.len()` -/// -/// Each worker accumulates results locally and sends them back over a bounded -/// channel once it runs out of work, avoiding contention on the shared -/// collection. -#[must_use] -pub fn estimate_parallel<'a>( - traces: &[&'a Trace], - isas: &[&'a ISA], - max_error: Option, - post_process: bool, -) -> EstimationCollection { - let total_jobs = traces.len() * isas.len(); - let num_isas = isas.len(); - - // Shared atomic counter acts as a lock-free work queue. Workers call - // fetch_add to claim the next job index. - let next_job = AtomicUsize::new(0); - - let mut collection = EstimationCollection::new(); - collection.set_total_jobs(total_jobs); - - std::thread::scope(|scope| { - let num_threads = std::thread::available_parallelism() - .map(std::num::NonZero::get) - .unwrap_or(1); - - // Bounded channel so each worker can send its batch of results back - // to the main thread without unbounded buffering. - let (tx, rx) = std::sync::mpsc::sync_channel(num_threads); - - for _ in 0..num_threads { - let tx = tx.clone(); - let next_job = &next_job; - scope.spawn(move || { - let mut local_results = Vec::new(); - loop { - // Atomically claim the next job. Relaxed ordering is - // sufficient because there is no dependent data between - // jobs — each (trace, isa) pair is independent. - let job = next_job.fetch_add(1, std::sync::atomic::Ordering::Relaxed); - if job >= total_jobs { - break; - } - - // Map the flat job index to a (trace, ISA) pair. 
- let trace_idx = job / num_isas; - let isa_idx = job % num_isas; - - if let Ok(mut estimation) = traces[trace_idx].estimate(isas[isa_idx], max_error) - { - estimation.set_isa_index(isa_idx); - estimation.set_trace_index(trace_idx); - - local_results.push(estimation); - } - } - // Send all results from this worker in one batch. - let _ = tx.send(local_results); - }); - } - // Drop the cloned sender so the receiver iterator terminates once all - // workers have finished. - drop(tx); - - // Collect results from all workers into the shared collection. - let mut successful = 0; - for local_results in rx { - if post_process { - for result in &local_results { - collection.push_summary(ResultSummary { - trace_index: result.trace_index().unwrap_or(0), - isa_index: result.isa_index().unwrap_or(0), - qubits: result.qubits(), - runtime: result.runtime(), - }); - } - } - successful += local_results.len(); - collection.extend(local_results.into_iter()); - } - collection.set_successful_estimates(successful); - }); - - // Attach ISAs only to Pareto-surviving results, avoiding O(M) HashMap - // clones for discarded results. - for result in collection.iter_mut() { - if let Some(idx) = result.isa_index() { - result.set_isa(isas[idx].clone()); - } - } - - collection -} - -/// A node in the provenance graph along with pre-computed (space, time) values -/// for pruning. -#[derive(Clone, Copy, Hash, PartialEq, Eq)] -struct NodeProfile { - node_index: usize, - space: u64, - time: u64, -} - -/// A single entry in a combination of instruction choices for estimation. -#[derive(Clone, Copy, Hash, Eq, PartialEq)] -struct CombinationEntry { - instruction_id: u64, - node: NodeProfile, -} - -/// Per-slot pruning witnesses: maps a context hash to the `(space, time)` -/// pairs observed in successful estimations. -type SlotWitnesses = RwLock>>; - -/// Computes a hash of the combination context (all slots except the excluded -/// one). 
Two combinations that agree on every slot except `exclude_idx` -/// produce the same context hash. -fn combination_context_hash(combination: &[CombinationEntry], exclude_idx: usize) -> u64 { - let mut hasher = DefaultHasher::new(); - for (i, entry) in combination.iter().enumerate() { - if i != exclude_idx { - entry.instruction_id.hash(&mut hasher); - entry.node.node_index.hash(&mut hasher); - } - } - hasher.finish() -} - -/// Checks whether a combination is dominated by a previously successful one. -/// -/// A combination is prunable if, for any instruction slot, there exists a -/// successful combination with the same instructions in all other slots and -/// an instruction at that slot with `space <=` and `time <=`. -fn is_dominated(combination: &[CombinationEntry], trace_pruning: &[SlotWitnesses]) -> bool { - for (slot_idx, entry) in combination.iter().enumerate() { - let ctx_hash = combination_context_hash(combination, slot_idx); - let map = trace_pruning[slot_idx] - .read() - .expect("Pruning lock poisoned"); - if map.get(&ctx_hash).is_some_and(|w| { - w.iter() - .any(|&(ws, wt)| ws <= entry.node.space && wt <= entry.node.time) - }) { - return true; - } - } - false -} - -/// Records a successful estimation as a pruning witness for future -/// combinations. 
-fn record_success(combination: &[CombinationEntry], trace_pruning: &[SlotWitnesses]) { - for (slot_idx, entry) in combination.iter().enumerate() { - let ctx_hash = combination_context_hash(combination, slot_idx); - let mut map = trace_pruning[slot_idx] - .write() - .expect("Pruning lock poisoned"); - map.entry(ctx_hash) - .or_default() - .push((entry.node.space, entry.node.time)); - } -} - -#[derive(Default)] -struct ISAIndex { - index: FxHashMap, usize>, - isas: Vec, -} - -impl From for Vec { - fn from(value: ISAIndex) -> Self { - value.isas - } -} - -impl ISAIndex { - pub fn push(&mut self, combination: &Vec, isa: &ISA) -> usize { - if let Some(&idx) = self.index.get(combination) { - idx - } else { - let idx = self.isas.len(); - self.isas.push(isa.clone()); - self.index.insert(combination.clone(), idx); - idx - } - } -} - -/// Generates the cartesian product of `id_and_nodes` and pushes each -/// combination directly into `jobs`, avoiding intermediate allocations. -/// -/// The cartesian product is enumerated using mixed-radix indexing. Given -/// dimensions with sizes `[n0, n1, n2, …]`, the total number of combinations -/// is `n0 * n1 * n2 * …`. Each combination index `i` in `0..total` uniquely -/// identifies one element from every dimension: the index into dimension `d` is -/// `(i / (n0 * n1 * … * n(d-1))) % nd`, which we compute incrementally by -/// repeatedly taking `i % nd` and then dividing `i` by `nd`. This is -/// analogous to extracting digits from a number in a mixed-radix system. -fn push_cartesian_product( - id_and_nodes: &[(u64, Vec)], - trace_idx: usize, - jobs: &mut Vec<(usize, Vec)>, - max_slots: &mut usize, -) { - // The product of all dimension sizes gives the total number of - // combinations. If any dimension is empty the product is zero and there - // are no valid combinations to generate. 
- let total: usize = id_and_nodes.iter().map(|(_, nodes)| nodes.len()).product(); - if total == 0 { - return; - } - - *max_slots = (*max_slots).max(id_and_nodes.len()); - jobs.reserve(total); - - // Enumerate every combination by treating the combination index `i` as a - // mixed-radix number. The inner loop "peels off" one digit per dimension: - // node_idx = i % nodes.len() — selects this dimension's element - // i /= nodes.len() — shifts to the next dimension's digit - // After processing all dimensions, `i` is exhausted (becomes 0), and - // `combo` contains exactly one entry per instruction id. - for mut i in 0..total { - let mut combo = Vec::with_capacity(id_and_nodes.len()); - for (id, nodes) in id_and_nodes { - let node_idx = i % nodes.len(); - i /= nodes.len(); - let profile = nodes[node_idx]; - combo.push(CombinationEntry { - instruction_id: *id, - node: profile, - }); - } - jobs.push((trace_idx, combo)); - } -} - -#[must_use] -#[allow(clippy::cast_precision_loss, clippy::too_many_lines)] -pub fn estimate_with_graph( - traces: &[&Trace], - graph: &Arc>, - max_error: Option, - post_process: bool, -) -> EstimationCollection { - let max_error = max_error.unwrap_or(1.0); - - // Phase 1: Pre-compute all (trace_index, combination) jobs sequentially. - // This reads the provenance graph once per trace and generates the - // cartesian product of Pareto-filtered nodes. Each node carries - // pre-computed (space, time) values for dominance pruning in Phase 2. - let mut jobs: Vec<(usize, Vec)> = Vec::new(); - - // Use the maximum number of instruction slots across all combinations to - // size the pruning witness structure. This will be updated while we generate - // jobs. 
- let mut max_slots = 0; - - for (trace_idx, trace) in traces.iter().enumerate() { - if trace.base_error() > max_error { - continue; - } - - let required = trace.required_instruction_ids(Some(max_error)); - - let graph_lock = graph.read().expect("Graph lock poisoned"); - let id_and_nodes: Vec<_> = required - .constraints() - .iter() - .filter_map(|constraint| { - graph_lock.pareto_nodes(constraint.id()).map(|nodes| { - ( - constraint.id(), - nodes - .iter() - .filter(|&&node| { - // Filter out nodes that don't meet the constraint bounds. - let instruction = graph_lock.instruction(node); - constraint.error_rate().is_none_or(|c| { - c.evaluate(&instruction.error_rate(Some(1)).unwrap_or(0.0)) - }) - }) - .map(|&node| { - let instruction = graph_lock.instruction(node); - let space = instruction.space(Some(1)).unwrap_or(0); - let time = instruction.time(Some(1)).unwrap_or(0); - NodeProfile { - node_index: node, - space, - time, - } - }) - .collect::>(), - ) - }) - }) - .collect(); - drop(graph_lock); - - if id_and_nodes.len() != required.len() { - // If any required instruction is missing from the graph, we can't - // run any estimation for this trace. - continue; - } - - push_cartesian_product(&id_and_nodes, trace_idx, &mut jobs, &mut max_slots); - } - - // Sort jobs so that combinations with smaller total (space + time) are - // processed first. This maximises the effectiveness of dominance pruning - // because successful "cheap" combinations establish witnesses that let us - // skip more expensive ones. - jobs.sort_by_key(|(_, combo)| { - combo - .iter() - .map(|entry| entry.node.space + entry.node.time) - .sum::() - }); - - let total_jobs = jobs.len(); - - // Phase 2: Run estimations in parallel with dominance-based pruning. - // - // For each instruction slot in a combination, we track (space, time) - // witnesses from successful estimations keyed by the "context", which is a - // hash of the node indices in all *other* slots. 
Before running an - // estimation, we check every slot: if a witness with space ≤ and time ≤ - // exists for that context, the combination is dominated and skipped. - let next_job = AtomicUsize::new(0); - - let pruning_witnesses: Vec> = repeat_with(|| { - repeat_with(|| RwLock::new(FxHashMap::default())) - .take(max_slots) - .collect() - }) - .take(traces.len()) - .collect(); - - // There are no explicit ISAs in this estimation function, as we create them - // on the fly from the graph nodes. For successful jobs, we will attach the - // ISAs to the results collection in a vector with the ISA index addressing - // that vector. In order to avoid storing duplicate ISAs we hash the ISA - // index. - let isa_index = Arc::new(RwLock::new(ISAIndex::default())); - - let mut collection = EstimationCollection::new(); - collection.set_total_jobs(total_jobs); - - std::thread::scope(|scope| { - let num_threads = std::thread::available_parallelism() - .map(std::num::NonZero::get) - .unwrap_or(1); - - let (tx, rx) = std::sync::mpsc::sync_channel(num_threads); - - for _ in 0..num_threads { - let tx = tx.clone(); - let next_job = &next_job; - let jobs = &jobs; - let pruning_witnesses = &pruning_witnesses; - let isa_index = Arc::clone(&isa_index); - scope.spawn(move || { - let mut local_results = Vec::new(); - loop { - let job_idx = next_job.fetch_add(1, std::sync::atomic::Ordering::Relaxed); - if job_idx >= total_jobs { - break; - } - - let (trace_idx, combination) = &jobs[job_idx]; - - // Dominance pruning: skip if a cheaper instruction at any - // slot already succeeded with the same surrounding context. 
- if is_dominated(combination, &pruning_witnesses[*trace_idx]) { - continue; - } - - let mut isa = ISA::with_graph(graph.clone()); - for entry in combination { - isa.add_node(entry.instruction_id, entry.node.node_index); - } - - if let Ok(mut result) = traces[*trace_idx].estimate(&isa, Some(max_error)) { - let isa_idx = isa_index - .write() - .expect("RwLock should not be poisoned") - .push(combination, &isa); - result.set_isa_index(isa_idx); - - result.set_trace_index(*trace_idx); - - local_results.push(result); - record_success(combination, &pruning_witnesses[*trace_idx]); - } - } - let _ = tx.send(local_results); - }); - } - drop(tx); - - let mut successful = 0; - for local_results in rx { - if post_process { - for result in &local_results { - collection.push_summary(ResultSummary { - trace_index: result.trace_index().unwrap_or(0), - isa_index: result.isa_index().unwrap_or(0), - qubits: result.qubits(), - runtime: result.runtime(), - }); - } - } - successful += local_results.len(); - collection.extend(local_results.into_iter()); - } - collection.set_successful_estimates(successful); - }); - - let isa_index = Arc::try_unwrap(isa_index) - .ok() - .expect("all threads joined; Arc refcount should be 1") - .into_inner() - .expect("RwLock should not be poisoned"); - - // Attach ISAs only to Pareto-surviving results, avoiding O(M) HashMap - // clones for discarded results. - for result in collection.iter_mut() { - if let Some(idx) = result.isa_index() { - result.set_isa(isa_index.isas[idx].clone()); - } - } - - collection.set_isas(isa_index.into()); - - collection -} diff --git a/source/qre/src/trace/estimation.rs b/source/qre/src/trace/estimation.rs new file mode 100644 index 0000000000..b75ab35fde --- /dev/null +++ b/source/qre/src/trace/estimation.rs @@ -0,0 +1,462 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+ +use std::{ + collections::hash_map::DefaultHasher, + hash::{Hash, Hasher}, + iter::repeat_with, + sync::{Arc, RwLock, atomic::AtomicUsize}, +}; + +use rustc_hash::FxHashMap; + +use crate::{EstimationCollection, ISA, ProvenanceGraph, ResultSummary, Trace}; + +/// Estimates all (trace, ISA) combinations in parallel, returning only the +/// successful results collected into an [`EstimationCollection`]. +/// +/// This uses a shared atomic counter as a lock-free work queue. Each worker +/// thread atomically claims the next job index, maps it to a `(trace, isa)` +/// pair, and runs the estimation. This keeps all available cores busy until +/// the last job completes. +/// +/// # Work distribution +/// +/// Jobs are numbered `0 .. traces.len() * isas.len()`. For job index `j`: +/// - `trace_idx = j / isas.len()` +/// - `isa_idx = j % isas.len()` +/// +/// Each worker accumulates results locally and sends them back over a bounded +/// channel once it runs out of work, avoiding contention on the shared +/// collection. +#[must_use] +pub fn estimate_parallel<'a>( + traces: &[&'a Trace], + isas: &[&'a ISA], + max_error: Option, + post_process: bool, +) -> EstimationCollection { + let total_jobs = traces.len() * isas.len(); + let num_isas = isas.len(); + + // Shared atomic counter acts as a lock-free work queue. Workers call + // fetch_add to claim the next job index. + let next_job = AtomicUsize::new(0); + + let mut collection = EstimationCollection::new(); + collection.set_total_jobs(total_jobs); + + std::thread::scope(|scope| { + let num_threads = std::thread::available_parallelism() + .map(std::num::NonZero::get) + .unwrap_or(1); + + // Bounded channel so each worker can send its batch of results back + // to the main thread without unbounded buffering. 
+ let (tx, rx) = std::sync::mpsc::sync_channel(num_threads); + + for _ in 0..num_threads { + let tx = tx.clone(); + let next_job = &next_job; + scope.spawn(move || { + let mut local_results = Vec::new(); + loop { + // Atomically claim the next job. Relaxed ordering is + // sufficient because there is no dependent data between + // jobs — each (trace, isa) pair is independent. + let job = next_job.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + if job >= total_jobs { + break; + } + + // Map the flat job index to a (trace, ISA) pair. + let trace_idx = job / num_isas; + let isa_idx = job % num_isas; + + if let Ok(mut estimation) = traces[trace_idx].estimate(isas[isa_idx], max_error) + { + estimation.set_isa_index(isa_idx); + estimation.set_trace_index(trace_idx); + + local_results.push(estimation); + } + } + // Send all results from this worker in one batch. + let _ = tx.send(local_results); + }); + } + // Drop the cloned sender so the receiver iterator terminates once all + // workers have finished. + drop(tx); + + // Collect results from all workers into the shared collection. + let mut successful = 0; + for local_results in rx { + if post_process { + for result in &local_results { + collection.push_summary(ResultSummary { + trace_index: result.trace_index().unwrap_or(0), + isa_index: result.isa_index().unwrap_or(0), + qubits: result.qubits(), + runtime: result.runtime(), + }); + } + } + successful += local_results.len(); + collection.extend(local_results.into_iter()); + } + collection.set_successful_estimates(successful); + }); + + // Attach ISAs only to Pareto-surviving results, avoiding O(M) HashMap + // clones for discarded results. + for result in collection.iter_mut() { + if let Some(idx) = result.isa_index() { + result.set_isa(isas[idx].clone()); + } + } + + collection +} + +/// A node in the provenance graph along with pre-computed (space, time) values +/// for pruning. 
+#[derive(Clone, Copy, Hash, PartialEq, Eq)] +struct NodeProfile { + node_index: usize, + space: u64, + time: u64, +} + +/// A single entry in a combination of instruction choices for estimation. +#[derive(Clone, Copy, Hash, Eq, PartialEq)] +struct CombinationEntry { + instruction_id: u64, + node: NodeProfile, +} + +/// Per-slot pruning witnesses: maps a context hash to the `(space, time)` +/// pairs observed in successful estimations. +type SlotWitnesses = RwLock>>; + +/// Computes a hash of the combination context (all slots except the excluded +/// one). Two combinations that agree on every slot except `exclude_idx` +/// produce the same context hash. +fn combination_context_hash(combination: &[CombinationEntry], exclude_idx: usize) -> u64 { + let mut hasher = DefaultHasher::new(); + for (i, entry) in combination.iter().enumerate() { + if i != exclude_idx { + entry.instruction_id.hash(&mut hasher); + entry.node.node_index.hash(&mut hasher); + } + } + hasher.finish() +} + +/// Checks whether a combination is dominated by a previously successful one. +/// +/// A combination is prunable if, for any instruction slot, there exists a +/// successful combination with the same instructions in all other slots and +/// an instruction at that slot with `space <=` and `time <=`. +fn is_dominated(combination: &[CombinationEntry], trace_pruning: &[SlotWitnesses]) -> bool { + for (slot_idx, entry) in combination.iter().enumerate() { + let ctx_hash = combination_context_hash(combination, slot_idx); + let map = trace_pruning[slot_idx] + .read() + .expect("Pruning lock poisoned"); + if map.get(&ctx_hash).is_some_and(|w| { + w.iter() + .any(|&(ws, wt)| ws <= entry.node.space && wt <= entry.node.time) + }) { + return true; + } + } + false +} + +/// Records a successful estimation as a pruning witness for future +/// combinations. 
+fn record_success(combination: &[CombinationEntry], trace_pruning: &[SlotWitnesses]) { + for (slot_idx, entry) in combination.iter().enumerate() { + let ctx_hash = combination_context_hash(combination, slot_idx); + let mut map = trace_pruning[slot_idx] + .write() + .expect("Pruning lock poisoned"); + map.entry(ctx_hash) + .or_default() + .push((entry.node.space, entry.node.time)); + } +} + +#[derive(Default)] +struct ISAIndex { + index: FxHashMap, usize>, + isas: Vec, +} + +impl From for Vec { + fn from(value: ISAIndex) -> Self { + value.isas + } +} + +impl ISAIndex { + pub fn push(&mut self, combination: &Vec, isa: &ISA) -> usize { + if let Some(&idx) = self.index.get(combination) { + idx + } else { + let idx = self.isas.len(); + self.isas.push(isa.clone()); + self.index.insert(combination.clone(), idx); + idx + } + } +} + +/// Generates the cartesian product of `id_and_nodes` and pushes each +/// combination directly into `jobs`, avoiding intermediate allocations. +/// +/// The cartesian product is enumerated using mixed-radix indexing. Given +/// dimensions with sizes `[n0, n1, n2, …]`, the total number of combinations +/// is `n0 * n1 * n2 * …`. Each combination index `i` in `0..total` uniquely +/// identifies one element from every dimension: the index into dimension `d` is +/// `(i / (n0 * n1 * … * n(d-1))) % nd`, which we compute incrementally by +/// repeatedly taking `i % nd` and then dividing `i` by `nd`. This is +/// analogous to extracting digits from a number in a mixed-radix system. +fn push_cartesian_product( + id_and_nodes: &[(u64, Vec)], + trace_idx: usize, + jobs: &mut Vec<(usize, Vec)>, + max_slots: &mut usize, +) { + // The product of all dimension sizes gives the total number of + // combinations. If any dimension is empty the product is zero and there + // are no valid combinations to generate. 
+ let total: usize = id_and_nodes.iter().map(|(_, nodes)| nodes.len()).product(); + if total == 0 { + return; + } + + *max_slots = (*max_slots).max(id_and_nodes.len()); + jobs.reserve(total); + + // Enumerate every combination by treating the combination index `i` as a + // mixed-radix number. The inner loop "peels off" one digit per dimension: + // node_idx = i % nodes.len() — selects this dimension's element + // i /= nodes.len() — shifts to the next dimension's digit + // After processing all dimensions, `i` is exhausted (becomes 0), and + // `combo` contains exactly one entry per instruction id. + for mut i in 0..total { + let mut combo = Vec::with_capacity(id_and_nodes.len()); + for (id, nodes) in id_and_nodes { + let node_idx = i % nodes.len(); + i /= nodes.len(); + let profile = nodes[node_idx]; + combo.push(CombinationEntry { + instruction_id: *id, + node: profile, + }); + } + jobs.push((trace_idx, combo)); + } +} + +#[must_use] +#[allow(clippy::cast_precision_loss, clippy::too_many_lines)] +pub fn estimate_with_graph( + traces: &[&Trace], + graph: &Arc>, + max_error: Option, + post_process: bool, +) -> EstimationCollection { + let max_error = max_error.unwrap_or(1.0); + + // Phase 1: Pre-compute all (trace_index, combination) jobs sequentially. + // This reads the provenance graph once per trace and generates the + // cartesian product of Pareto-filtered nodes. Each node carries + // pre-computed (space, time) values for dominance pruning in Phase 2. + let mut jobs: Vec<(usize, Vec)> = Vec::new(); + + // Use the maximum number of instruction slots across all combinations to + // size the pruning witness structure. This will be updated while we generate + // jobs. 
+ let mut max_slots = 0; + + for (trace_idx, trace) in traces.iter().enumerate() { + if trace.base_error() > max_error { + continue; + } + + let required = trace.required_instruction_ids(Some(max_error)); + + let graph_lock = graph.read().expect("Graph lock poisoned"); + let id_and_nodes: Vec<_> = required + .constraints() + .iter() + .filter_map(|constraint| { + graph_lock.pareto_nodes(constraint.id()).map(|nodes| { + ( + constraint.id(), + nodes + .iter() + .filter(|&&node| { + // Filter out nodes that don't meet the constraint bounds. + let instruction = graph_lock.instruction(node); + constraint.error_rate().is_none_or(|c| { + c.evaluate(&instruction.error_rate(Some(1)).unwrap_or(0.0)) + }) + }) + .map(|&node| { + let instruction = graph_lock.instruction(node); + let space = instruction.space(Some(1)).unwrap_or(0); + let time = instruction.time(Some(1)).unwrap_or(0); + NodeProfile { + node_index: node, + space, + time, + } + }) + .collect::>(), + ) + }) + }) + .collect(); + drop(graph_lock); + + if id_and_nodes.len() != required.len() { + // If any required instruction is missing from the graph, we can't + // run any estimation for this trace. + continue; + } + + push_cartesian_product(&id_and_nodes, trace_idx, &mut jobs, &mut max_slots); + } + + // Sort jobs so that combinations with smaller total (space + time) are + // processed first. This maximises the effectiveness of dominance pruning + // because successful "cheap" combinations establish witnesses that let us + // skip more expensive ones. + jobs.sort_by_key(|(_, combo)| { + combo + .iter() + .map(|entry| entry.node.space + entry.node.time) + .sum::() + }); + + let total_jobs = jobs.len(); + + // Phase 2: Run estimations in parallel with dominance-based pruning. + // + // For each instruction slot in a combination, we track (space, time) + // witnesses from successful estimations keyed by the "context", which is a + // hash of the node indices in all *other* slots. 
Before running an + // estimation, we check every slot: if a witness with space ≤ and time ≤ + // exists for that context, the combination is dominated and skipped. + let next_job = AtomicUsize::new(0); + + let pruning_witnesses: Vec> = repeat_with(|| { + repeat_with(|| RwLock::new(FxHashMap::default())) + .take(max_slots) + .collect() + }) + .take(traces.len()) + .collect(); + + // There are no explicit ISAs in this estimation function, as we create them + // on the fly from the graph nodes. For successful jobs, we will attach the + // ISAs to the results collection in a vector with the ISA index addressing + // that vector. In order to avoid storing duplicate ISAs we hash the ISA + // index. + let isa_index = Arc::new(RwLock::new(ISAIndex::default())); + + let mut collection = EstimationCollection::new(); + collection.set_total_jobs(total_jobs); + + std::thread::scope(|scope| { + let num_threads = std::thread::available_parallelism() + .map(std::num::NonZero::get) + .unwrap_or(1); + + let (tx, rx) = std::sync::mpsc::sync_channel(num_threads); + + for _ in 0..num_threads { + let tx = tx.clone(); + let next_job = &next_job; + let jobs = &jobs; + let pruning_witnesses = &pruning_witnesses; + let isa_index = Arc::clone(&isa_index); + scope.spawn(move || { + let mut local_results = Vec::new(); + loop { + let job_idx = next_job.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + if job_idx >= total_jobs { + break; + } + + let (trace_idx, combination) = &jobs[job_idx]; + + // Dominance pruning: skip if a cheaper instruction at any + // slot already succeeded with the same surrounding context. 
+ if is_dominated(combination, &pruning_witnesses[*trace_idx]) { + continue; + } + + let mut isa = ISA::with_graph(graph.clone()); + for entry in combination { + isa.add_node(entry.instruction_id, entry.node.node_index); + } + + if let Ok(mut result) = traces[*trace_idx].estimate(&isa, Some(max_error)) { + let isa_idx = isa_index + .write() + .expect("RwLock should not be poisoned") + .push(combination, &isa); + result.set_isa_index(isa_idx); + + result.set_trace_index(*trace_idx); + + local_results.push(result); + record_success(combination, &pruning_witnesses[*trace_idx]); + } + } + let _ = tx.send(local_results); + }); + } + drop(tx); + + let mut successful = 0; + for local_results in rx { + if post_process { + for result in &local_results { + collection.push_summary(ResultSummary { + trace_index: result.trace_index().unwrap_or(0), + isa_index: result.isa_index().unwrap_or(0), + qubits: result.qubits(), + runtime: result.runtime(), + }); + } + } + successful += local_results.len(); + collection.extend(local_results.into_iter()); + } + collection.set_successful_estimates(successful); + }); + + let isa_index = Arc::try_unwrap(isa_index) + .ok() + .expect("all threads joined; Arc refcount should be 1") + .into_inner() + .expect("RwLock should not be poisoned"); + + // Attach ISAs only to Pareto-surviving results, avoiding O(M) HashMap + // clones for discarded results. + for result in collection.iter_mut() { + if let Some(idx) = result.isa_index() { + result.set_isa(isa_index.isas[idx].clone()); + } + } + + collection.set_isas(isa_index.into()); + + collection +}