Skip to content

Commit

Permalink
Using pip to install JAX+GPU and also adding a basic usage
Browse files Browse the repository at this point in the history
  • Loading branch information
alonkukl committed Sep 13, 2024
1 parent 2a0ae99 commit 3cfe480
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 17 deletions.
56 changes: 43 additions & 13 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,27 +1,57 @@
# QFactor and QFactor-Sample implementations on GPUs using JAX
# QFactor and QFactor-Sample Implementations on GPUs Using JAX
`bqskit-qfactor-jax` is a Python package that implements circuit instantiation using the [QFactor](https://ieeexplore.ieee.org/abstract/document/10313638) and [QFactor-Sample](https://arxiv.org/abs/2405.12866) algorithms on GPUs to accelerate [BQSKit](https://github.com/bqskit/bqskit). It uses [JAX](https://jax.readthedocs.io/en/latest/index.html) as an abstraction layer of the GPUs, seamlessly utilizing JIT compilation and GPU parallelism.

## Installation
`bqskit-qfactor-jax` is available for Python 3.8+ on Linux.
`bqskit-qfactor-jax` is available for Python 3.9+ on Linux. It can be installed using pip

First, install JAX with GPU support, you may refer to JAX's [installation instructions](https://github.com/google/jax#installation).
For users working on Perlmutter please use the following modules before installing JAX in your environment:
```sh
module load cudnn/8.9.3_cuda12
module load nccl/2.18.3-cu12
pip install bqskit-qfactor-jax
```

Next, install this package with pip:
If you are experiencing issues with JAX please refer to JAX's [installation instructions](https://github.com/google/jax#installation).

```sh
pip install bqskit-qfactor-jax

## Basic Usage
QFactor and QFactor-Sample are instantiation algorithms that, given a unitary matrix and a parameterized circuit, optimize the circuit parameters to best approximate the target unitary matrix.

```python
import numpy as np
from bqskit import Circuit
from bqskit.ir.gates import VariableUnitaryGate
from bqskit.qis.unitary import UnitaryMatrix

from qfactorjax.qfactor_sample_jax import QFactorSampleJax



# Load a circuit from QASM
circuit = Circuit.from_file("template.qasm")

# Load the target unitary
unitary_target = UnitaryMatrix.from_file("target.mat")

# Create the instantiator object
qfactor_sample_gpu_instantiator = QFactorSampleJax()

# Perform the instantiation
circuit.instantiate(
unitary_target,
multistarts=16,
method=qfactor_sample_gpu_instantiator,
)

# Calculate and print final distance
dist = circuit.get_unitary().get_distance_from(unitary_target, 1)

print('Final Distance: ', dist)
```

Please look at the [examples](https://github.com/BQSKit/bqskit-qfactor-jax/tree/main/examples) for a more detailed usage, especially at performance comparison between QFactor and QFactor-Sample.


# Running bqskit-qfactor-jax
## GPU Configuration and Memory Management
Please set the environment variable XLA_PYTHON_CLIENT_PREALLOCATE=False when using this package. Also, if you encounter OOM issues consider setting XLA_PYTHON_CLIENT_ALLOCATOR=platform.

Please look at the [examples](https://github.com/BQSKit/bqskit-qfactor-jax/tree/main/examples) for basic usage, especially at performance comparison between QFactor and QFactor-Sample.

When using several workers on the same GPU, we recommend using [Nvidia's MPS](https://docs.nvidia.com/deploy/mps/index.html). You may initiate it using the command line
```sh
Expand All @@ -33,8 +63,8 @@ You can disable it by running this command line:
echo quit | nvidia-cuda-mps-control
```

# References
If you are using QFactor-JAX please cite:\
## References
If you are using QFactor please cite:\
Kukliansky, Alon, et al. "QFactor: A Domain-Specific Optimizer for Quantum Circuit Instantiation." 2023 IEEE International Conference on Quantum Computing and Engineering (QCE). Vol. 1. IEEE, 2023. [Link](https://ieeexplore.ieee.org/abstract/document/10313638).

If you are using QFactor-Sample please cite:\
Expand Down
7 changes: 3 additions & 4 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@ classifiers =
License :: OSI Approved :: BSD License
Operating System :: POSIX :: Linux
Programming Language :: Python :: 3 :: Only
Programming Language :: Python :: 3.8
Programming Language :: Python :: 3.9
Programming Language :: Python :: 3.10
Programming Language :: Python :: 3.11
Programming Language :: Python :: 3.12
Topic :: Scientific/Engineering
Topic :: Scientific/Engineering :: Mathematics
Topic :: Scientific/Engineering :: Physics
Expand All @@ -39,11 +39,10 @@ install_requires =
numpy
bqskit>=1.1.0
typing-extensions>=4.0.0
; For the jax+gpu installation see JAX's wiki
jax
jax[cuda12]
jaxlib
jaxtyping
python_requires = >=3.8, <4
python_requires = >=3.9, <4

[options.packages.find]
exclude =
Expand Down

0 comments on commit 3cfe480

Please sign in to comment.