diff --git a/README.md b/README.md index f5ad790..1b308fa 100644 --- a/README.md +++ b/README.md @@ -42,4 +42,4 @@ for more information. Quantum Fast Circuit Optimizer (QFactor) JAX implementation Copyright (c) 2023, The Regents of the University of California, through Lawrence Berkeley National Laboratory (subject to receipt of any required -approvals from the U.S. Dept. of Energy) and Alon Kukliansky. All rights reserved. \ No newline at end of file +approvals from the U.S. Dept. of Energy) and Alon Kukliansky. All rights reserved. diff --git a/bqskitqfactorjax/__init__.py b/bqskitqfactorjax/__init__.py index 01947e6..d80f942 100644 --- a/bqskitqfactorjax/__init__.py +++ b/bqskitqfactorjax/__init__.py @@ -1,4 +1,4 @@ +from __future__ import annotations __all__ = [ - 'qfactor_jax','unitary_acc', 'unitarybuilderjax', 'unitarymatrixjax' + 'qfactor_jax','unitary_acc', 'unitarybuilderjax', 'unitarymatrixjax', ] - diff --git a/bqskitqfactorjax/qfactor_jax.py b/bqskitqfactorjax/qfactor_jax.py index 8b5e3a7..ef94271 100644 --- a/bqskitqfactorjax/qfactor_jax.py +++ b/bqskitqfactorjax/qfactor_jax.py @@ -9,15 +9,15 @@ import jax.numpy as jnp import numpy as np import numpy.typing as npt -from scipy.stats import unitary_group - from bqskit.ir.gates.constantgate import ConstantGate from bqskit.ir.gates.parameterized.u3 import U3Gate from bqskit.ir.gates.parameterized.unitary import VariableUnitaryGate -from bqskit.qis.unitary.unitarymatrix import UnitaryMatrix from bqskit.ir.opt.instantiater import Instantiater from bqskit.qis.state.state import StateVector from bqskit.qis.state.system import StateSystem +from bqskit.qis.unitary.unitarymatrix import UnitaryMatrix +from scipy.stats import unitary_group + from bqskitqfactorjax.unitary_acc import VariableUnitaryGateAcc from bqskitqfactorjax.unitarybuilderjax import UnitaryBuilderJax from bqskitqfactorjax.unitarymatrixjax import UnitaryMatrixJax @@ -85,7 +85,7 @@ def instantiate( ) -> npt.NDArray[np.float64]: return self.multi_start_instantiate(circuit, target, 1) - + def multi_start_instantiate_inplace( self, @@ -107,15 +107,15 @@ def multi_start_instantiate_inplace( circuit.set_params(params) - - + + async def multi_start_instantiate_async( self, circuit: Circuit, target: UnitaryLike | StateLike | StateSystemLike, num_starts: int, ) -> npt.NDArray[np.float64]: - + return self.multi_start_instantiate(circuit, target, num_starts) def multi_start_instantiate( @@ -167,9 +167,9 @@ def multi_start_instantiate( res_var = _sweep2_jited( target, locations, gates, untrys, self.reset_iter, self.dist_tol, self.diff_tol_a, self.diff_tol_r, self.plateau_windows_size, - self.max_iters, self.min_iters, num_starts, self.diff_tol_step_r, self.diff_tol_step, self.beta + self.max_iters, self.min_iters, num_starts, self.diff_tol_step_r, self.diff_tol_step, self.beta, ) - + it = res_var['iteration_counts'][0] c1s = res_var['c1s'] untrys = res_var['untrys'] @@ -202,7 +202,7 @@ def multi_start_instantiate( elif it >= self.max_iters: _logger.debug('Terminated: iteration limit reached.') - + else: _logger.error( f'Terminated with no good reason after {it} iterstion ' @@ -212,15 +212,15 @@ def multi_start_instantiate( for untry, gate in zip(untrys[best_start], gates): if isinstance(gate, ConstantGate): params.extend([]) - else: + else: params.extend( gate.get_params( _remove_padding_and_create_matrix(untry, gate), ), - ) - + ) + return np.array(params) - + @staticmethod def get_method_name() -> str: """Return the name of this method.""" @@ -302,7 +302,7 @@ def _initilize_circuit_tensor( def _single_sweep( locations, gates, amount_of_gates, target_untry_builder:UnitaryBuilderJax, - untrys, beta=0 + untrys, beta=0, ): # from right to left for k in reversed(range(amount_of_gates)): @@ -354,9 +354,9 @@ def _single_sweep( def _single_sweep_sim( locations, gates, amount_of_gates, target_untry_builder, - untrys, beta=0 + untrys, beta=0, ): - + new_untrys = [] # from right to left for k in reversed(range(amount_of_gates)): @@ -376,7 +376,7 @@ def _single_sweep_sim( else: new_untrys.append(untry) - + target_untry_builder.apply_left( untry, location, check_arguments=False, ) @@ -403,7 +403,7 @@ def _remove_padding_and_create_matrix(untry, gate): def Loop_vars( untrys, c1s, plateau_windows, curr_plateau_calc_l, curr_reached_required_tol_l, iteration_counts, - target_untry_builders, prev_step_c1s, curr_step_calc_l + target_untry_builders, prev_step_c1s, curr_step_calc_l, ): d = {} d['untrys'] = untrys @@ -422,7 +422,7 @@ def Loop_vars( def _sweep2( target, locations, gates, untrys, n, dist_tol, diff_tol_a, diff_tol_r, plateau_windows_size, max_iters, min_iters, - amount_of_starts, diff_tol_step_r, diff_tol_step, beta + amount_of_starts, diff_tol_step_r, diff_tol_step, beta, ): c1s = jnp.array([1.0] * amount_of_starts) plateau_windows = jnp.array( @@ -442,8 +442,8 @@ def should_continue(var): var['iteration_counts'][0] > min_iters, jnp.logical_or( jnp.all(var['curr_plateau_calc_l']), - jnp.all(var['curr_step_calc_l']) - ) + jnp.all(var['curr_step_calc_l']), + ), ), ), ), @@ -452,7 +452,7 @@ def should_continue(var): def _while_body_to_be_vmaped( untrys, c1, plateau_window, curr_plateau_calc, curr_reached_required_tol, iteration_count, - target_untry_builder_tensor, prev_step_c1, curr_step_calc + target_untry_builder_tensor, prev_step_c1, curr_step_calc, ): amount_of_gates = len(gates) amount_of_qudits = target.num_qudits @@ -473,19 +473,19 @@ def _while_body_to_be_vmaped( target_untry_builder_tensor = _initilize_circuit_tensor( amount_of_qudits, target_radixes, locations, target.numpy, untrys, - ).tensor + ).tensor target_untry_builder = UnitaryBuilderJax( amount_of_qudits, target_radixes, tensor=target_untry_builder_tensor, ) - + iteration_count = iteration_count + 1 - + untrys = _single_sweep_sim( - locations, gates, amount_of_gates, target_untry_builder, untrys, beta + locations, gates, amount_of_gates, target_untry_builder, untrys, beta, ) target_untry_builder_tensor = _initilize_circuit_tensor( @@ -519,9 +519,9 @@ def _while_body_to_be_vmaped( iteration_count = iteration_count + 1 target_untry_builder, untrys = _single_sweep( - locations, gates, amount_of_gates, target_untry_builder, untrys, beta + locations, gates, amount_of_gates, target_untry_builder, untrys, beta, ) - + c2 = c1 dim = target_untry_builder.dim untry_res = target_untry_builder.tensor.reshape((dim, dim)) @@ -546,7 +546,8 @@ def _while_body_to_be_vmaped( prev_step_c1, curr_step_calc = jax.lax.cond( (iteration_count+1) % diff_tol_step == 0, - reached_step_body, not_reached_step_body, operand_for_if) + reached_step_body, not_reached_step_body, operand_for_if, + ) biggest_gate_size = max(gate.num_qudits for gate in gates) final_untrys_padded = jnp.array([ @@ -559,7 +560,7 @@ def _while_body_to_be_vmaped( return ( final_untrys_padded, c1, plateau_window, curr_plateau_calc, curr_reached_required_tol, iteration_count, - target_untry_builder.tensor, prev_step_c1, curr_step_calc + target_untry_builder.tensor, prev_step_c1, curr_step_calc, ) while_body_vmaped = jax.vmap(_while_body_to_be_vmaped) @@ -574,7 +575,7 @@ def while_body(var): var['iteration_counts'], var['target_untry_builders'], var['prev_step_c1s'], - var['curr_step_calc_l'] + var['curr_step_calc_l'], ), ) @@ -590,17 +591,17 @@ def while_body(var): jnp.array([False] * amount_of_starts), jnp.array([0] * amount_of_starts), initial_untray_builders_values, prev_step_c1s, - jnp.array([False] * amount_of_starts) + jnp.array([False] * amount_of_starts), ) - + if 'PRINT_LOSS_QFACTOR' in os.environ: loop_var = initial_loop_var i = 1 while(should_continue(loop_var)): loop_var = while_body(loop_var) - print("LOSS:",i , loop_var['c1s']) + print('LOSS:',i , loop_var['c1s']) i +=1 res_var = loop_var else: @@ -614,6 +615,6 @@ def while_body(var): else: _sweep2_jited = jax.jit( _sweep2, static_argnums=( - 1, 2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14 + 1, 2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, ), ) diff --git a/bqskitqfactorjax/unitary_acc.py b/bqskitqfactorjax/unitary_acc.py index a42f207..d7adaf3 100644 --- a/bqskitqfactorjax/unitary_acc.py +++ b/bqskitqfactorjax/unitary_acc.py @@ -2,12 +2,12 @@ import jax.numpy as jnp import jax.scipy.linalg as jla -from jax import Array - from bqskit.ir.gates.parameterized.unitary import VariableUnitaryGate from bqskit.qis.unitary.unitary import RealVector from bqskit.qis.unitary.unitarymatrix import UnitaryLike from bqskit.qis.unitary.unitarymatrix import UnitaryMatrix +from jax import Array + from bqskitqfactorjax.unitarymatrixjax import UnitaryMatrixJax @@ -36,11 +36,11 @@ def optimize(self, env_matrix, get_untry: bool = False, prev_utry=None, beta:flo See :class:`LocallyOptimizableUnitary` for more info. """ - + U, _, Vh = jla.svd((1-beta) * env_matrix + beta*prev_utry._utry.conj().T) utry = Vh.conj().T @ U.conj().T - if get_untry: + if get_untry: return UnitaryMatrixJax(utry, radixes=self.radixes) x = jnp.reshape(utry, (self.num_params // 2,)) diff --git a/bqskitqfactorjax/unitarybuilderjax.py b/bqskitqfactorjax/unitarybuilderjax.py index d3bb581..80bea1e 100644 --- a/bqskitqfactorjax/unitarybuilderjax.py +++ b/bqskitqfactorjax/unitarybuilderjax.py @@ -2,20 +2,21 @@ from __future__ import annotations import logging -from typing import Sequence, cast +from typing import cast +from typing import Sequence import jax import jax.numpy as jnp import numpy as np import numpy.typing as npt - +from bqskit.ir.location import CircuitLocation +from bqskit.ir.location import CircuitLocationLike from bqskit.qis.unitary.unitary import RealVector from bqskit.qis.unitary.unitarybuilder import UnitaryBuilder from bqskit.qis.unitary.unitarymatrix import UnitaryMatrix from bqskit.utils.typing import is_integer from bqskit.utils.typing import is_valid_radixes -from bqskit.ir.location import CircuitLocationLike -from bqskit.ir.location import CircuitLocation + from bqskitqfactorjax.unitarymatrixjax import UnitaryMatrixJax logger = logging.getLogger(__name__) diff --git a/bqskitqfactorjax/unitarymatrixjax.py b/bqskitqfactorjax/unitarymatrixjax.py index 092bc86..a2cd63d 100644 --- a/bqskitqfactorjax/unitarymatrixjax.py +++ b/bqskitqfactorjax/unitarymatrixjax.py @@ -6,15 +6,12 @@ import jax.numpy as jnp import jax.scipy.linalg as jla import numpy as np -from jax import Array - +from bqskit.qis.unitary.unitary import Unitary from bqskit.qis.unitary.unitarymatrix import UnitaryLike from bqskit.qis.unitary.unitarymatrix import UnitaryMatrix from bqskit.utils.docs import building_docs from bqskit.utils.typing import is_square_matrix - -from bqskit.qis.unitary.unitary import Unitary -from bqskit.qis.unitary.unitarymatrix import UnitaryMatrix +from jax import Array if not building_docs(): from numpy.lib.mixins import NDArrayOperatorsMixin @@ -83,7 +80,7 @@ def dim(self) -> int: return self._dim return int(np.prod(self.radixes)) - + @property def num_params(self) -> int: @@ -154,7 +151,7 @@ def closest_to( def numpy(self) -> Array: """The JaxNumPy array holding the unitary.""" return self._utry - + @property def jaxnumpy(self) -> Array: """The JaxNumPy array holding the unitary.""" @@ -173,15 +170,15 @@ def from_file(filename: str): def T(self) -> UnitaryMatrixJax: """The transpose of the unitary.""" return UnitaryMatrixJax(self._utry.T, self.radixes) - + def conj(self) -> UnitaryMatrixJax: """Return the complex conjugate unitary matrix.""" return UnitaryMatrixJax(self._utry.conj(), self.radixes) - + def get_unitary(self, params) -> UnitaryMatrixJax: """Return the same object, satisfies the :class:`Unitary` API.""" return self - + @property def dagger(self) -> UnitaryMatrixJax: """The conjugate transpose of the unitary.""" @@ -198,28 +195,21 @@ def __array__( ) return self._utry - + def get_tensor_format(self) -> Array: """ Converts the unitary matrix operation into a tensor network format. - Indices are counted top to bottom, right to left: - .-----. - n -| |- 0 - n+1 -| |- 1 - . . - . . - . . - 2n-1 -| |- n-1 - '-----' - + Indices are counted top to bottom, right to left: .-----. n -| + |- 0 n+1 -| |- 1 . . . . . . 2n-1 -| + |- n-1 '-----' - Returns - Union[DeviceArray, np.ndarray]: A tensor representing this matrix. + Returns Union[DeviceArray, np.ndarray]: A tensor representing this + matrix. """ return self._utry.reshape(self.radixes + self.radixes) - + def __eq__(self, other: object) -> bool: """Check if `self` is approximately equal to `other`.""" @@ -233,7 +223,7 @@ def __eq__(self, other: object) -> bool: return np.allclose(self, other) return NotImplemented - + def __array__( self, diff --git a/examples/gate_deletion_syth.py b/examples/gate_deletion_syth.py index ac59cea..74648ae 100644 --- a/examples/gate_deletion_syth.py +++ b/examples/gate_deletion_syth.py @@ -1,20 +1,20 @@ -""" -This example shows how to resynthesize a circuit using a gate deletion flow, -that utilizses Qfacto's GPU implementation - -""" +"""This example shows how to resynthesize a circuit using a gate deletion flow, +that utilizses Qfacto's GPU implementation.""" +from __future__ import annotations import logging import os from timeit import default_timer as timer + from bqskit import Circuit from bqskit.compiler import Compiler -from bqskit.passes import QuickPartitioner from bqskit.passes import ForEachBlockPass +from bqskit.passes import QuickPartitioner from bqskit.passes import ScanningGateRemovalPass -from bqskit.passes import UnfoldPass from bqskit.passes import ToU3Pass from bqskit.passes import ToVariablePass +from bqskit.passes import UnfoldPass + from bqskitqfactorjax.qfactor_jax import QFactor_jax @@ -29,7 +29,7 @@ def run_gate_del_flow_example(): # Set the size of paritions partition_size = 4 - # QFactor hyperparameters - + # QFactor hyperparameters - # see intantiation example for more detiles on the parameters num_multistarts = 32 max_iters = 100000 @@ -45,7 +45,7 @@ def run_gate_del_flow_example(): - print(f"Will compile {file_path}") + print(f'Will compile {file_path}') # Read the QASM circuit in_circuit = Circuit.from_file(file_path) @@ -59,11 +59,12 @@ def run_gate_del_flow_example(): dist_tol=dist_tol, diff_tol_step_r=diff_tol_step_r, diff_tol_step = diff_tol_step, - beta=beta) + beta=beta, + ) instantiate_options={ 'method': batched_instantiation, 'multistarts': num_multistarts, - } + } # Prepare the comiplation passes @@ -76,26 +77,26 @@ def run_gate_del_flow_example(): # For each partition perform scanning gate removal using QFactor jax ForEachBlockPass([ - ScanningGateRemovalPass(instantiate_options=instantiate_options), - ]), + ScanningGateRemovalPass(instantiate_options=instantiate_options), + ]), # Combine the partitions back into a circuit UnfoldPass(), - + # Convert back the VariablueUnitaires into U3s - ToU3Pass() - ] + ToU3Pass(), + ] # Create the compilation task - + with Compiler( num_workers=amount_of_workers, runtime_log_level=logging.INFO, - ) as compiler: + ) as compiler: - print(f"Starting gate deletion flow using Qfactor JAX") + print(f'Starting gate deletion flow using Qfactor JAX') start = timer() out_circuit = compiler.compile(in_circuit, passes) end = timer() @@ -110,11 +111,11 @@ def run_gate_del_flow_example(): in_circuit, out_circuit, run_time = run_gate_del_flow_example() print( - f"Partitioning + Synthesis took {run_time}" - f"seconds using Qfactor JAX instantiation method." + f'Partitioning + Synthesis took {run_time}' + f'seconds using Qfactor JAX instantiation method.', ) print( - f"Circuit finished with gates: {out_circuit.gate_counts}, " - f"while started with {in_circuit.gate_counts}" + f'Circuit finished with gates: {out_circuit.gate_counts}, ' + f'while started with {in_circuit.gate_counts}', ) diff --git a/examples/toffoli_instantiation.py b/examples/toffoli_instantiation.py index 8587050..bbdb258 100644 --- a/examples/toffoli_instantiation.py +++ b/examples/toffoli_instantiation.py @@ -1,13 +1,12 @@ """ Numerical Instantiation is the foundation of many of BQSKit's algorithms. - This is the same instantiation example as in BQSKit using the GPU - implementation of QFactor +This is the same instantiation example as in BQSKit using the GPU implementation +of QFactor """ from __future__ import annotations import numpy as np - from bqskit.ir.circuit import Circuit from bqskit.ir.gates import VariableUnitaryGate from bqskit.qis.unitary import UnitaryMatrix @@ -17,26 +16,26 @@ def run_toffoli_instantiation(dist_tol_requested = 1e-10): qfactr_gpu_instantiator = QFactor_jax( - + dist_tol = dist_tol_requested, # Stopping criteria for distance max_iters = 100000, # Maximum number of iterations min_iters = 10, # Minimum number of iterations - + #One step plateau detection - #diff_tol_a + diff_tol_r ∗ |c(i)| <= |c(i)|-|c(i-1)| diff_tol_a = 0.0, # Stopping criteria for distance change diff_tol_r = 1e-10, # Relative criteria for distance change - - #Long plateau detection - + + #Long plateau detection - # diff_tol_step_r*|c(i-diff_tol_step)| <= |c(i)|-|c(i-diff_tol_step)| diff_tol_step_r = 0.1, #The relative improvment expected diff_tol_step = 200, #The interval in which to check the improvment - + #Regularization parameter - [0.0 - 1.0] # Increase to overcome local minimumas at the price of longer compute - beta = 0.0 - ) + beta = 0.0, + ) @@ -75,4 +74,4 @@ def run_toffoli_instantiation(dist_tol_requested = 1e-10): if __name__ == '__main__': dist = run_toffoli_instantiation() - print('Final Distance: ', dist) \ No newline at end of file + print('Final Distance: ', dist) diff --git a/pyproject.toml b/pyproject.toml index 61ee1ba..a033506 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,4 +3,4 @@ requires = [ "wheel", "setuptools>=40.1.0", "oldest-supported-numpy", -] \ No newline at end of file +] diff --git a/setup.py b/setup.py index e0a76d4..abbaa42 100644 --- a/setup.py +++ b/setup.py @@ -3,4 +3,4 @@ from setuptools import setup # Everything is in setup.cfg, this is for compatibility # pip install works the same. -setup() \ No newline at end of file +setup() diff --git a/tests/test_examples.py b/tests/test_examples.py index 9a738ed..2d9ae58 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -1,6 +1,10 @@ -from bqskit.ir.gates import CNOTGate, U3Gate -from examples.toffoli_instantiation import run_toffoli_instantiation +from __future__ import annotations + +from bqskit.ir.gates import CNOTGate +from bqskit.ir.gates import U3Gate + from examples.gate_deletion_syth import run_gate_del_flow_example +from examples.toffoli_instantiation import run_toffoli_instantiation def test_toffoli_instantiation(): @@ -14,4 +18,3 @@ def test_gate_del_synth(): out_circuit_gates_count = out_circuit.gate_counts assert out_circuit_gates_count[CNOTGate()] == 44 assert out_circuit_gates_count[U3Gate()] == 56 -