Skip to content

Commit

Permalink
adding cupy/cusignal import tests and how enables are set
Browse files Browse the repository at this point in the history
  • Loading branch information
alex-rakowski committed Dec 12, 2023
1 parent 7b309ad commit 7639f25
Showing 1 changed file with 136 additions and 55 deletions.
191 changes: 136 additions & 55 deletions pylops/utils/deps.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,61 +12,52 @@
]

import os
from importlib import import_module
from typing import Optional


def check_module_enabled(
module: str,
envrionment_str: Optional[str] = None,
envrionment_val: Optional[int] = 1,
) -> bool:
"""
Check whether a specific module can be imported in the current Python environment.
Parameters
----------
module : str
The name of the module to check import state for.
environment_str : str, optional
An optional environment variable name to check for. If provided, the function will return True
only if the environment variable is set to the specified value. Defaults to None.
environment_val : str, optional
The value to compare the environment variable against. Defaults to "1".
Returns
-------
bool
True if the module is available, False otherwise.
"""
# try to import the module
try:
_ = import_module(module) # noqa: F401
# run envrionment check if needed
if envrionment_str is not None:
# return True if the value matches expected value
return int(os.getenv(envrionment_str, envrionment_val)) == envrionment_val
# if no environment check return True as import_module worked
else:
return True
# if cannot import and provides expected Exceptions, return False
except (ImportError, ModuleNotFoundError):
return False
# raise warning if anyother exception raised in import
except Exception as e:
raise UserWarning(f"Unexpceted Exception when importing {module}") from e


cupy_enabled = check_module_enabled("cupy", "CUPY_PYLOPS")
cusignal_enabled = check_module_enabled("cusignal", "CUSIGNAL_PYLOPS")
devito_enabled = check_module_enabled("devito")
numba_enabled = check_module_enabled("numba")
pyfftw_enabled = check_module_enabled("pyfftw")
pywt_enabled = check_module_enabled("pywt")
skfmm_enabled = check_module_enabled("skfmm")
spgl1_enabled = check_module_enabled("spgl1")
sympy_enabled = check_module_enabled("sympy")
torch_enabled = check_module_enabled("torch")

# from importlib import import_module
from importlib import util

# from typing import Optional


# def check_module_enabled(
# module: str,
# envrionment_str: Optional[str] = None,
# envrionment_val: Optional[int] = 1,
# ) -> bool:
# """
# Check whether a specific module can be imported in the current Python environment.

# Parameters
# ----------
# module : str
# The name of the module to check import state for.
# environment_str : str, optional
# An optional environment variable name to check for. If provided, the function will return True
# only if the environment variable is set to the specified value. Defaults to None.
# environment_val : str, optional
# The value to compare the environment variable against. Defaults to "1".

# Returns
# -------
# bool
# True if the module is available, False otherwise.
# """
# # try to import the module
# try:
# _ = import_module(module) # noqa: F401
# # run envrionment check if needed
# if envrionment_str is not None:
# # return True if the value matches expected value
# return int(os.getenv(envrionment_str, envrionment_val)) == envrionment_val
# # if no environment check return True as import_module worked
# else:
# return True
# # if cannot import and provides expected Exceptions, return False
# except (ImportError, ModuleNotFoundError):
# return False
# # raise warning if anyother exception raised in import
# except Exception as e:
# raise UserWarning(f"Unexpceted Exception when importing {module}") from e


# error message at import of available package
Expand Down Expand Up @@ -194,3 +185,93 @@ def sympy_import(message):
f'"pip install sympy".'
)
return sympy_message


def cupy_import(message):
# detect if cupy should be importable
cupy_test = (
util.find_spec("cupy") is not None and int(os.getenv("CUPY_PYLOPS", 1)) == 1
)
# if cupy should be importable
if cupy_test:
# try importing it
try:
import cupy # noqa: F401

# if successful set the message to None.
cupy_message = None
# if unable to import but it is installed
except (ImportError, ModuleNotFoundError) as e:
cupy_message = (
f"Failed to import cupy. Falling back to CPU (error: {e}) ."
f"{message} run"
"Please ensure your CUDA envrionment is set up correctly"
"for more details visit 'https://docs.cupy.dev/en/stable/install.html'"
)
# if cupy_test is False it means not installed or envrionment variable set to 0
else:
cupy_message = (
f"cupy package not installed or os.getenv('CUPY_PYLOPS') == 0. In order to be able to use "
f"{message} "
"os.getenv('CUPY_PYLOPS') == 1 and run"
"'pip install cupy'."
"for more details visit 'https://docs.cupy.dev/en/stable/install.html'"
)

return cupy_message


def cusignal_import(message):
# detect if cupy should be importable
cusignal_test = (
util.find_spec("cusignal") is not None
and int(os.getenv("CUSIGNAL_PYLOPS", 1)) == 1
)
# if cupy should be importable
if cusignal_test:
# try importing it
try:
import cusignal # noqa: F401

# if successful set the message to None.
cusignal_message = None
# if unable to import but it is installed
except (ImportError, ModuleNotFoundError) as e:
cusignal_message = (
f"Failed to import cusignal. Falling back to CPU (error: {e}) ."
f"{message} run"
"Please ensure your CUDA envrionment is set up correctly"
"for more details visit 'https://github.com/rapidsai/cusignal#installation'"
)
# if cupy_test is False it means not installed or envrionment variable set to 0
else:
cusignal_message = (
f"cusignal package not installed or os.getenv('CUSIGNAL_PYLOPS') == 0. In order to be able to use "
f"{message} "
"os.getenv('CUSIGNAL_PYLOPS') == 1 and run"
"'pip install cupy'."
"for more details visit ''https://github.com/rapidsai/cusignal#installation''"
)

return cusignal_message


cupy_enabled = (
True
if (cupy_import() is not None and int(os.getenv("CUPY_PYLOPS", 1)) == 1)
else False # noqa:F821,E501
)
cusignal_enabled = (
True
if (cusignal_import() is not None and int(os.getenv("CUSIGNAL_PYLOPS", 1)) == 1)
else False # noqa:F821,E501
)
# cusignal_enabled = check_module_enabled("cusignal", "CUSIGNAL_PYLOPS")
devito_enabled = util.find_spec("devito") is not None
numba_enabled = util.find_spec("numba") is not None
pyfftw_enabled = util.find_spec("pyfftw") is not None
pywt_enabled = util.find_spec("pywt") is not None
skfmm_enabled = util.find_spec("skfmm") is not None
spgl1_enabled = util.find_spec("spgl1") is not None
sympy_enabled = util.find_spec("sympy") is not None
torch_enabled = util.find_spec("torch") is not None

0 comments on commit 7639f25

Please sign in to comment.