Skip to content

Commit

Permalink
First pass of pre-commit
Browse files Browse the repository at this point in the history
  • Loading branch information
Alon committed Sep 29, 2023
1 parent 82f704f commit ebcd433
Show file tree
Hide file tree
Showing 11 changed files with 106 additions and 111 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,4 @@ for more information.
Quantum Fast Circuit Optimizer (QFactor) JAX implementation Copyright (c) 2023,
The Regents of the University of California, through Lawrence
Berkeley National Laboratory (subject to receipt of any required
approvals from the U.S. Dept. of Energy) and Alon Kukliansky. All rights reserved.
approvals from the U.S. Dept. of Energy) and Alon Kukliansky. All rights reserved.
4 changes: 2 additions & 2 deletions bqskitqfactorjax/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from __future__ import annotations
__all__ = [
'qfactor_jax','unitary_acc', 'unitarybuilderjax', 'unitarymatrixjax'
'qfactor_jax','unitary_acc', 'unitarybuilderjax', 'unitarymatrixjax',
]

73 changes: 37 additions & 36 deletions bqskitqfactorjax/qfactor_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@
import jax.numpy as jnp
import numpy as np
import numpy.typing as npt
from scipy.stats import unitary_group

from bqskit.ir.gates.constantgate import ConstantGate
from bqskit.ir.gates.parameterized.u3 import U3Gate
from bqskit.ir.gates.parameterized.unitary import VariableUnitaryGate
from bqskit.qis.unitary.unitarymatrix import UnitaryMatrix
from bqskit.ir.opt.instantiater import Instantiater
from bqskit.qis.state.state import StateVector
from bqskit.qis.state.system import StateSystem
from bqskit.qis.unitary.unitarymatrix import UnitaryMatrix
from scipy.stats import unitary_group

from bqskitqfactorjax.unitary_acc import VariableUnitaryGateAcc
from bqskitqfactorjax.unitarybuilderjax import UnitaryBuilderJax
from bqskitqfactorjax.unitarymatrixjax import UnitaryMatrixJax
Expand Down Expand Up @@ -85,7 +85,7 @@ def instantiate(
) -> npt.NDArray[np.float64]:

return self.multi_start_instantiate(circuit, target, 1)


def multi_start_instantiate_inplace(
self,
Expand All @@ -107,15 +107,15 @@ def multi_start_instantiate_inplace(
circuit.set_params(params)




async def multi_start_instantiate_async(
self,
circuit: Circuit,
target: UnitaryLike | StateLike | StateSystemLike,
num_starts: int,
) -> npt.NDArray[np.float64]:

return self.multi_start_instantiate(circuit, target, num_starts)

def multi_start_instantiate(
Expand Down Expand Up @@ -167,9 +167,9 @@ def multi_start_instantiate(
res_var = _sweep2_jited(
target, locations, gates, untrys, self.reset_iter, self.dist_tol,
self.diff_tol_a, self.diff_tol_r, self.plateau_windows_size,
self.max_iters, self.min_iters, num_starts, self.diff_tol_step_r, self.diff_tol_step, self.beta
self.max_iters, self.min_iters, num_starts, self.diff_tol_step_r, self.diff_tol_step, self.beta,
)

it = res_var['iteration_counts'][0]
c1s = res_var['c1s']
untrys = res_var['untrys']
Expand Down Expand Up @@ -202,7 +202,7 @@ def multi_start_instantiate(

elif it >= self.max_iters:
_logger.debug('Terminated: iteration limit reached.')

else:
_logger.error(
f'Terminated with no good reason after {it} iterstion '
Expand All @@ -212,15 +212,15 @@ def multi_start_instantiate(
for untry, gate in zip(untrys[best_start], gates):
if isinstance(gate, ConstantGate):
params.extend([])
else:
else:
params.extend(
gate.get_params(
_remove_padding_and_create_matrix(untry, gate),
),
)
)

return np.array(params)

@staticmethod
def get_method_name() -> str:
"""Return the name of this method."""
Expand Down Expand Up @@ -302,7 +302,7 @@ def _initilize_circuit_tensor(

def _single_sweep(
locations, gates, amount_of_gates, target_untry_builder:UnitaryBuilderJax,
untrys, beta=0
untrys, beta=0,
):
# from right to left
for k in reversed(range(amount_of_gates)):
Expand Down Expand Up @@ -354,9 +354,9 @@ def _single_sweep(

def _single_sweep_sim(
locations, gates, amount_of_gates, target_untry_builder,
untrys, beta=0
untrys, beta=0,
):

new_untrys = []
# from right to left
for k in reversed(range(amount_of_gates)):
Expand All @@ -376,7 +376,7 @@ def _single_sweep_sim(
else:
new_untrys.append(untry)


target_untry_builder.apply_left(
untry, location, check_arguments=False,
)
Expand All @@ -403,7 +403,7 @@ def _remove_padding_and_create_matrix(untry, gate):
def Loop_vars(
untrys, c1s, plateau_windows, curr_plateau_calc_l,
curr_reached_required_tol_l, iteration_counts,
target_untry_builders, prev_step_c1s, curr_step_calc_l
target_untry_builders, prev_step_c1s, curr_step_calc_l,
):
d = {}
d['untrys'] = untrys
Expand All @@ -422,7 +422,7 @@ def Loop_vars(
def _sweep2(
target, locations, gates, untrys, n, dist_tol, diff_tol_a,
diff_tol_r, plateau_windows_size, max_iters, min_iters,
amount_of_starts, diff_tol_step_r, diff_tol_step, beta
amount_of_starts, diff_tol_step_r, diff_tol_step, beta,
):
c1s = jnp.array([1.0] * amount_of_starts)
plateau_windows = jnp.array(
Expand All @@ -442,8 +442,8 @@ def should_continue(var):
var['iteration_counts'][0] > min_iters,
jnp.logical_or(
jnp.all(var['curr_plateau_calc_l']),
jnp.all(var['curr_step_calc_l'])
)
jnp.all(var['curr_step_calc_l']),
),
),
),
),
Expand All @@ -452,7 +452,7 @@ def should_continue(var):
def _while_body_to_be_vmaped(
untrys, c1, plateau_window, curr_plateau_calc,
curr_reached_required_tol, iteration_count,
target_untry_builder_tensor, prev_step_c1, curr_step_calc
target_untry_builder_tensor, prev_step_c1, curr_step_calc,
):
amount_of_gates = len(gates)
amount_of_qudits = target.num_qudits
Expand All @@ -473,19 +473,19 @@ def _while_body_to_be_vmaped(

target_untry_builder_tensor = _initilize_circuit_tensor(
amount_of_qudits, target_radixes, locations, target.numpy, untrys,
).tensor
).tensor

target_untry_builder = UnitaryBuilderJax(
amount_of_qudits, target_radixes,
tensor=target_untry_builder_tensor,
)


iteration_count = iteration_count + 1


untrys = _single_sweep_sim(
locations, gates, amount_of_gates, target_untry_builder, untrys, beta
locations, gates, amount_of_gates, target_untry_builder, untrys, beta,
)

target_untry_builder_tensor = _initilize_circuit_tensor(
Expand Down Expand Up @@ -519,9 +519,9 @@ def _while_body_to_be_vmaped(
iteration_count = iteration_count + 1

target_untry_builder, untrys = _single_sweep(
locations, gates, amount_of_gates, target_untry_builder, untrys, beta
locations, gates, amount_of_gates, target_untry_builder, untrys, beta,
)

c2 = c1
dim = target_untry_builder.dim
untry_res = target_untry_builder.tensor.reshape((dim, dim))
Expand All @@ -546,7 +546,8 @@ def _while_body_to_be_vmaped(

prev_step_c1, curr_step_calc = jax.lax.cond(
(iteration_count+1) % diff_tol_step == 0,
reached_step_body, not_reached_step_body, operand_for_if)
reached_step_body, not_reached_step_body, operand_for_if,
)

biggest_gate_size = max(gate.num_qudits for gate in gates)
final_untrys_padded = jnp.array([
Expand All @@ -559,7 +560,7 @@ def _while_body_to_be_vmaped(
return (
final_untrys_padded, c1, plateau_window, curr_plateau_calc,
curr_reached_required_tol, iteration_count,
target_untry_builder.tensor, prev_step_c1, curr_step_calc
target_untry_builder.tensor, prev_step_c1, curr_step_calc,
)

while_body_vmaped = jax.vmap(_while_body_to_be_vmaped)
Expand All @@ -574,7 +575,7 @@ def while_body(var):
var['iteration_counts'],
var['target_untry_builders'],
var['prev_step_c1s'],
var['curr_step_calc_l']
var['curr_step_calc_l'],
),
)

Expand All @@ -590,17 +591,17 @@ def while_body(var):
jnp.array([False] * amount_of_starts),
jnp.array([0] * amount_of_starts),
initial_untray_builders_values, prev_step_c1s,
jnp.array([False] * amount_of_starts)
jnp.array([False] * amount_of_starts),
)


if 'PRINT_LOSS_QFACTOR' in os.environ:
loop_var = initial_loop_var
i = 1
while(should_continue(loop_var)):
loop_var = while_body(loop_var)

print("LOSS:",i , loop_var['c1s'])
print('LOSS:',i , loop_var['c1s'])
i +=1
res_var = loop_var
else:
Expand All @@ -614,6 +615,6 @@ def while_body(var):
else:
_sweep2_jited = jax.jit(
_sweep2, static_argnums=(
1, 2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14
1, 2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14,
),
)
8 changes: 4 additions & 4 deletions bqskitqfactorjax/unitary_acc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@

import jax.numpy as jnp
import jax.scipy.linalg as jla
from jax import Array

from bqskit.ir.gates.parameterized.unitary import VariableUnitaryGate
from bqskit.qis.unitary.unitary import RealVector
from bqskit.qis.unitary.unitarymatrix import UnitaryLike
from bqskit.qis.unitary.unitarymatrix import UnitaryMatrix
from jax import Array

from bqskitqfactorjax.unitarymatrixjax import UnitaryMatrixJax


Expand Down Expand Up @@ -36,11 +36,11 @@ def optimize(self, env_matrix, get_untry: bool = False, prev_utry=None, beta:flo
See :class:`LocallyOptimizableUnitary` for more info.
"""

U, _, Vh = jla.svd((1-beta) * env_matrix + beta*prev_utry._utry.conj().T)
utry = Vh.conj().T @ U.conj().T

if get_untry:
if get_untry:
return UnitaryMatrixJax(utry, radixes=self.radixes)

x = jnp.reshape(utry, (self.num_params // 2,))
Expand Down
9 changes: 5 additions & 4 deletions bqskitqfactorjax/unitarybuilderjax.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,21 @@
from __future__ import annotations

import logging
from typing import Sequence, cast
from typing import cast
from typing import Sequence

import jax
import jax.numpy as jnp
import numpy as np
import numpy.typing as npt

from bqskit.ir.location import CircuitLocation
from bqskit.ir.location import CircuitLocationLike
from bqskit.qis.unitary.unitary import RealVector
from bqskit.qis.unitary.unitarybuilder import UnitaryBuilder
from bqskit.qis.unitary.unitarymatrix import UnitaryMatrix
from bqskit.utils.typing import is_integer
from bqskit.utils.typing import is_valid_radixes
from bqskit.ir.location import CircuitLocationLike
from bqskit.ir.location import CircuitLocation

from bqskitqfactorjax.unitarymatrixjax import UnitaryMatrixJax

logger = logging.getLogger(__name__)
Expand Down
Loading

0 comments on commit ebcd433

Please sign in to comment.