diff --git a/.github/workflows/run_periodic_tests.yml b/.github/workflows/run_periodic_tests.yml index 6e86de054e..49d960310d 100644 --- a/.github/workflows/run_periodic_tests.yml +++ b/.github/workflows/run_periodic_tests.yml @@ -15,7 +15,6 @@ on: env: FORCE_COLOR: 3 PYBAMM_IDAKLU_EXPR_CASADI: ON - PYBAMM_IDAKLU_EXPR_IREE: ON concurrency: # github.workflow: name of the workflow, so that we don't cancel other workflows diff --git a/.github/workflows/test_on_push.yml b/.github/workflows/test_on_push.yml index 32e3017446..8e88271b15 100644 --- a/.github/workflows/test_on_push.yml +++ b/.github/workflows/test_on_push.yml @@ -7,7 +7,6 @@ on: env: FORCE_COLOR: 3 PYBAMM_IDAKLU_EXPR_CASADI: ON - PYBAMM_IDAKLU_EXPR_IREE: ON concurrency: # github.workflow: name of the workflow, so that we don't cancel other workflows diff --git a/CMakeLists.txt b/CMakeLists.txt index ec594e5ca5..dd334ad5e1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -53,27 +53,6 @@ if(${PYBAMM_IDAKLU_EXPR_CASADI} STREQUAL "ON" ) ) endif() -# Check IREE build flag -if(NOT DEFINED PYBAMM_IDAKLU_EXPR_IREE) - set(PYBAMM_IDAKLU_EXPR_IREE OFF) -endif() -message("PYBAMM_IDAKLU_EXPR_IREE: ${PYBAMM_IDAKLU_EXPR_IREE}") - -# IREE (MLIR expression evaluation) PyBaMM source files -set(IDAKLU_EXPR_IREE_SOURCE_FILES "") -if(${PYBAMM_IDAKLU_EXPR_IREE} STREQUAL "ON" ) - add_compile_definitions(IREE_ENABLE) - # Source file list - set(IDAKLU_EXPR_IREE_SOURCE_FILES - src/pybamm/solvers/c_solvers/idaklu/Expressions/IREE/iree_jit.cpp - src/pybamm/solvers/c_solvers/idaklu/Expressions/IREE/iree_jit.hpp - src/pybamm/solvers/c_solvers/idaklu/Expressions/IREE/IREEFunctions.cpp - src/pybamm/solvers/c_solvers/idaklu/Expressions/IREE/IREEFunctions.hpp - src/pybamm/solvers/c_solvers/idaklu/Expressions/IREE/ModuleParser.cpp - src/pybamm/solvers/c_solvers/idaklu/Expressions/IREE/ModuleParser.hpp - ) -endif() - # The complete (all dependencies) sources list should be mirrored in setup.py pybind11_add_module(idaklu # pybind11 interface @@ -109,7 +88,6 @@ pybind11_add_module(idaklu src/pybamm/solvers/c_solvers/idaklu/observe.cpp # IDAKLU expressions - concrete implementations ${IDAKLU_EXPR_CASADI_SOURCE_FILES} - ${IDAKLU_EXPR_IREE_SOURCE_FILES} ) if (NOT DEFINED USE_PYTHON_CASADI) @@ -179,16 +157,3 @@ else() endif() include_directories(${SuiteSparse_INCLUDE_DIRS}) target_link_libraries(idaklu PRIVATE ${SuiteSparse_LIBRARIES}) - -# IREE (MLIR compiler and runtime library) build settings -if(${PYBAMM_IDAKLU_EXPR_IREE} STREQUAL "ON" ) - set(IREE_BUILD_COMPILER ON) - set(IREE_BUILD_TESTS OFF) - set(IREE_BUILD_SAMPLES OFF) - add_subdirectory(iree EXCLUDE_FROM_ALL) - set(IREE_COMPILER_ROOT "${CMAKE_CURRENT_SOURCE_DIR}/iree/compiler") - target_include_directories(idaklu SYSTEM PRIVATE "${IREE_COMPILER_ROOT}/bindings/c/iree/compiler") - target_compile_options(idaklu PRIVATE ${IREE_DEFAULT_COPTS}) - target_link_libraries(idaklu PRIVATE iree_compiler_bindings_c_loader) - target_link_libraries(idaklu PRIVATE iree_runtime_runtime) -endif() diff --git a/docs/source/user_guide/installation/gnu-linux-mac.rst b/docs/source/user_guide/installation/gnu-linux-mac.rst index 97171b53b7..116ffe48c8 100644 --- a/docs/source/user_guide/installation/gnu-linux-mac.rst +++ b/docs/source/user_guide/installation/gnu-linux-mac.rst @@ -101,18 +101,6 @@ Users can install ``jax`` and ``jaxlib`` to use the Jax solver. The ``pip install "pybamm[jax]"`` command automatically downloads and installs ``pybamm`` and the compatible versions of ``jax`` and ``jaxlib`` on your system. -.. _optional-iree-mlir-support: - -Optional - IREE / MLIR support -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -Users can install ``iree`` (for MLIR just-in-time compilation) to use for main expression evaluation in the IDAKLU solver. Requires ``jax``. - -.. code:: bash - - pip install "pybamm[iree,jax]" - -The ``pip install "pybamm[iree,jax]"`` command automatically downloads and installs ``pybamm`` and the compatible versions of ``jax`` and ``iree`` onto your system. Uninstall PyBaMM ---------------- diff --git a/docs/source/user_guide/installation/index.rst b/docs/source/user_guide/installation/index.rst index 9225f1ee98..7610aa1dba 100644 --- a/docs/source/user_guide/installation/index.rst +++ b/docs/source/user_guide/installation/index.rst @@ -48,7 +48,6 @@ Optional solvers The following solvers are optionally available: * `jax `_ -based solver, see `Optional - JaxSolver `_. -* `IREE `_ (`MLIR `_) support, see `Optional - IREE / MLIR Support `_. Dependencies ------------ @@ -207,17 +206,6 @@ Dependency Minimu `jaxlib `__ 0.4.20 jax Support library for JAX ========================================================================= ================== ================== ======================= -IREE dependencies -^^^^^^^^^^^^^^^^^^ - -Installable with ``pip install "pybamm[iree]"`` (requires ``jax`` dependencies to be installed). - -========================================================================= ================== ================== ======================= -Dependency Minimum Version pip extra Notes -========================================================================= ================== ================== ======================= -`iree-compiler `__ 20240507.886 iree IREE compiler -========================================================================= ================== ================== ======================= - Full installation guide ----------------------- diff --git a/noxfile.py b/noxfile.py index d65812b8ed..bcb32469b4 100644 --- a/noxfile.py +++ b/noxfile.py @@ -1,7 +1,6 @@ import nox import os import sys -import warnings from pathlib import Path @@ -13,31 +12,6 @@ else: nox.options.sessions = ["pre-commit", "unit"] - -def set_iree_state(): - """ - Check if IREE is enabled and set the environment variable accordingly. - - Returns - ------- - str - "ON" if IREE is enabled, "OFF" otherwise. - - """ - state = "ON" if os.getenv("PYBAMM_IDAKLU_EXPR_IREE", "OFF") == "ON" else "OFF" - if state == "ON": - if sys.platform == "win32" or sys.platform == "darwin": - warnings.warn( - ( - "IREE is not enabled on Windows and MacOS. " - "Setting PYBAMM_IDAKLU_EXPR_IREE=OFF." - ), - stacklevel=2, - ) - return "OFF" - return state - - homedir = os.getenv("HOME") PYBAMM_ENV = { "LD_LIBRARY_PATH": f"{homedir}/.local/lib", @@ -45,10 +19,6 @@ def set_iree_state(): "MPLBACKEND": "Agg", # Expression evaluators (...EXPR_CASADI cannot be fully disabled at this time) "PYBAMM_IDAKLU_EXPR_CASADI": os.getenv("PYBAMM_IDAKLU_EXPR_CASADI", "ON"), - "PYBAMM_IDAKLU_EXPR_IREE": set_iree_state(), - "IREE_INDEX_URL": os.getenv( - "IREE_INDEX_URL", "https://iree.dev/pip-release-links.html" - ), "PYBAMM_DISABLE_TELEMETRY": "true", } VENV_DIR = Path("./venv").resolve() @@ -91,29 +61,6 @@ def run_pybamm_requires(session): "advice.detachedHead=false", external=True, ) - if PYBAMM_ENV.get("PYBAMM_IDAKLU_EXPR_IREE") == "ON" and not os.path.exists( - "./iree" - ): - session.run( - "git", - "clone", - "--depth=1", - "--recurse-submodules", - "--shallow-submodules", - "--branch=candidate-20240507.886", - "https://github.com/openxla/iree", - "iree/", - external=True, - ) - with session.chdir("iree"): - session.run( - "git", - "submodule", - "update", - "--init", - "--recursive", - external=True, - ) else: session.error("nox -s pybamm-requires is only available on Linux & macOS.") @@ -128,15 +75,6 @@ def run_coverage(session): if "CI" in os.environ: session.install("pytest-github-actions-annotate-failures") session.install("-e", ".[all,dev,jax]", silent=False) - if PYBAMM_ENV.get("PYBAMM_IDAKLU_EXPR_IREE") == "ON": - # See comments in 'dev' session - session.install( - "-e", - ".[iree]", - "--find-links", - PYBAMM_ENV.get("IREE_INDEX_URL"), - silent=False, - ) session.run("pytest", "--cov=pybamm", "--cov-report=xml", "tests/unit") @@ -177,15 +115,6 @@ def run_unit(session): set_environment_variables(PYBAMM_ENV, session=session) session.install("setuptools", silent=False) session.install("-e", ".[all,dev,jax]", silent=False) - if PYBAMM_ENV.get("PYBAMM_IDAKLU_EXPR_IREE") == "ON": - # See comments in 'dev' session - session.install( - "-e", - ".[iree]", - "--find-links", - PYBAMM_ENV.get("IREE_INDEX_URL"), - silent=False, - ) session.run("python", "-m", "pytest", "-m", "unit") @@ -220,17 +149,6 @@ def set_dev(session): session.install("virtualenv", "cmake") session.run("virtualenv", os.fsdecode(VENV_DIR), silent=True) python = os.fsdecode(VENV_DIR.joinpath("bin/python")) - components = ["all", "dev", "jax"] - args = [] - if PYBAMM_ENV.get("PYBAMM_IDAKLU_EXPR_IREE") == "ON": - # Install IREE libraries for Jax-MLIR expression evaluation in the IDAKLU solver - # (optional). IREE is currently pre-release and relies on nightly jaxlib builds. - # When upgrading Jax/IREE ensure that the following are compatible with each other: - # - Jax and Jaxlib version [pyproject.toml] - # - IREE repository clone (use the matching nightly candidate) [noxfile.py] - # - IREE compiler matches Jaxlib (use the matching nightly build) [pyproject.toml] - components.append("iree") - args = ["--find-links", PYBAMM_ENV.get("IREE_INDEX_URL")] # Temporary fix for Python 3.12 CI. TODO: remove after # https://bitbucket.org/pybtex-devs/pybtex/issues/169/replace-pkg_resources-with # is fixed @@ -241,8 +159,7 @@ def set_dev(session): "pip", "install", "-e", - ".[{}]".format(",".join(components)), - *args, + ".[all,dev,jax]", external=True, ) diff --git a/pyproject.toml b/pyproject.toml index b83c3704fe..e40c68f8e3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -118,18 +118,12 @@ dev = [ "importlib-metadata; python_version < '3.10'", ] # For the Jax solver. -# Note: These must be kept in sync with the versions defined in pybamm/util.py, and -# must remain compatible with IREE (see noxfile.py for IREE compatibility). +# Note: These must be kept in sync with the versions defined in pybamm/util.py jax = [ "jax==0.4.27", "jaxlib==0.4.27", ] -# For MLIR expression evaluation (IDAKLU Solver) -iree = [ - # must be pip installed with --find-links=https://iree.dev/pip-release-links.html - "iree-compiler==20240507.886", # see IREE compatibility notes in noxfile.py -] -# Contains all optional dependencies, except for jax, iree, and dev dependencies +# Contains all optional dependencies, except for jax and dev dependencies all = [ "scikit-fem>=8.1.0", "pybamm[examples,plot,cite,bpx,tqdm]", diff --git a/setup.py b/setup.py index 8a49bfd715..f090e8e5e5 100644 --- a/setup.py +++ b/setup.py @@ -93,13 +93,11 @@ def run(self): build_type = os.getenv("PYBAMM_CPP_BUILD_TYPE", "RELEASE") idaklu_expr_casadi = os.getenv("PYBAMM_IDAKLU_EXPR_CASADI", "ON") - idaklu_expr_iree = os.getenv("PYBAMM_IDAKLU_EXPR_IREE", "OFF") cmake_args = [ f"-DCMAKE_BUILD_TYPE={build_type}", f"-DPYTHON_EXECUTABLE={sys.executable}", "-DUSE_PYTHON_CASADI={}".format("TRUE" if use_python_casadi else "FALSE"), f"-DPYBAMM_IDAKLU_EXPR_CASADI={idaklu_expr_casadi}", - f"-DPYBAMM_IDAKLU_EXPR_IREE={idaklu_expr_iree}", ] if self.suitesparse_root: cmake_args.append( @@ -302,14 +300,6 @@ def compile_KLU(): "src/pybamm/solvers/c_solvers/idaklu/Expressions/Base/ExpressionSparsity.hpp", "src/pybamm/solvers/c_solvers/idaklu/Expressions/Casadi/CasadiFunctions.cpp", "src/pybamm/solvers/c_solvers/idaklu/Expressions/Casadi/CasadiFunctions.hpp", - "src/pybamm/solvers/c_solvers/idaklu/Expressions/IREE/IREEBaseFunction.hpp", - "src/pybamm/solvers/c_solvers/idaklu/Expressions/IREE/IREEFunction.hpp", - "src/pybamm/solvers/c_solvers/idaklu/Expressions/IREE/IREEFunctions.cpp", - "src/pybamm/solvers/c_solvers/idaklu/Expressions/IREE/IREEFunctions.hpp", - "src/pybamm/solvers/c_solvers/idaklu/Expressions/IREE/iree_jit.cpp", - "src/pybamm/solvers/c_solvers/idaklu/Expressions/IREE/iree_jit.hpp", - "src/pybamm/solvers/c_solvers/idaklu/Expressions/IREE/ModuleParser.cpp", - "src/pybamm/solvers/c_solvers/idaklu/Expressions/IREE/ModuleParser.hpp", "src/pybamm/solvers/c_solvers/idaklu/idaklu_solver.hpp", "src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolver.cpp", "src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolver.hpp", diff --git a/src/pybamm/__init__.py b/src/pybamm/__init__.py index b466c3896b..3201191f5e 100644 --- a/src/pybamm/__init__.py +++ b/src/pybamm/__init__.py @@ -1,8 +1,5 @@ from pybamm.version import __version__ -# Demote expressions to 32-bit floats/ints - option used for IDAKLU-MLIR compilation -demote_expressions_to_32bit = False - # Utility classes and methods from .util import root_dir from .util import Timer, TimerTime, FuzzyDict @@ -173,7 +170,7 @@ from .solvers.jax_bdf_solver import jax_bdf_integrate from .solvers.idaklu_jax import IDAKLUJax -from .solvers.idaklu_solver import IDAKLUSolver, has_idaklu, has_iree +from .solvers.idaklu_solver import IDAKLUSolver, has_idaklu # Experiments from .experiment.experiment import Experiment diff --git a/src/pybamm/expression_tree/operations/evaluate_python.py b/src/pybamm/expression_tree/operations/evaluate_python.py index a8a37ea7b2..eb4a0f39b9 100644 --- a/src/pybamm/expression_tree/operations/evaluate_python.py +++ b/src/pybamm/expression_tree/operations/evaluate_python.py @@ -596,54 +596,9 @@ def __init__(self, symbol: pybamm.Symbol): static_argnums=self._static_argnums, ) - def _demote_constants(self): - """Demote 64-bit constants (f64, i64) to 32-bit (f32, i32)""" - if not pybamm.demote_expressions_to_32bit: - return # pragma: no cover - self._constants = EvaluatorJax._demote_64_to_32(self._constants) - - @classmethod - def _demote_64_to_32(cls, c): - """Demote 64-bit operations (f64, i64) to 32-bit (f32, i32)""" - - if not pybamm.demote_expressions_to_32bit: - return c - if isinstance(c, float): - c = jax.numpy.float32(c) - if isinstance(c, int): - c = jax.numpy.int32(c) - if isinstance(c, np.int64): - c = c.astype(jax.numpy.int32) - if isinstance(c, np.ndarray): - if c.dtype == np.float64: - c = c.astype(jax.numpy.float32) - if c.dtype == np.int64: - c = c.astype(jax.numpy.int32) - if isinstance(c, jax.numpy.ndarray): - if c.dtype == jax.numpy.float64: - c = c.astype(jax.numpy.float32) - if c.dtype == jax.numpy.int64: - c = c.astype(jax.numpy.int32) - if isinstance( - c, pybamm.expression_tree.operations.evaluate_python.JaxCooMatrix - ): - if c.data.dtype == np.float64: - c.data = c.data.astype(jax.numpy.float32) - if c.row.dtype == np.int64: - c.row = c.row.astype(jax.numpy.int32) - if c.col.dtype == np.int64: - c.col = c.col.astype(jax.numpy.int32) - if isinstance(c, dict): - c = {key: EvaluatorJax._demote_64_to_32(value) for key, value in c.items()} - if isinstance(c, tuple): - c = tuple(EvaluatorJax._demote_64_to_32(value) for value in c) - if isinstance(c, list): - c = [EvaluatorJax._demote_64_to_32(value) for value in c] - return c - @property def _constants(self): - return tuple(map(EvaluatorJax._demote_64_to_32, self.__constants)) + return self.__constants @_constants.setter def _constants(self, value): diff --git a/src/pybamm/solvers/c_solvers/idaklu.cpp b/src/pybamm/solvers/c_solvers/idaklu.cpp index 82a3cbe91c..180161ea81 100644 --- a/src/pybamm/solvers/c_solvers/idaklu.cpp +++ b/src/pybamm/solvers/c_solvers/idaklu.cpp @@ -15,11 +15,6 @@ #include "idaklu/common.hpp" #include "idaklu/Expressions/Casadi/CasadiFunctions.hpp" -#ifdef IREE_ENABLE -#include "idaklu/Expressions/IREE/IREEFunctions.hpp" -#endif - - casadi::Function generate_casadi_function(const std::string &data) { return casadi::Function::deserialize(data); @@ -96,34 +91,6 @@ PYBIND11_MODULE(idaklu, m) py::arg("shape"), py::return_value_policy::take_ownership); -#ifdef IREE_ENABLE - m.def("create_iree_solver_group", &create_idaklu_solver_group, - "Create a group of iree idaklu solver objects", - py::arg("number_of_states"), - py::arg("number_of_parameters"), - py::arg("rhs_alg"), - py::arg("jac_times_cjmass"), - py::arg("jac_times_cjmass_colptrs"), - py::arg("jac_times_cjmass_rowvals"), - py::arg("jac_times_cjmass_nnz"), - py::arg("jac_bandwidth_lower"), - py::arg("jac_bandwidth_upper"), - py::arg("jac_action"), - py::arg("mass_action"), - py::arg("sens"), - py::arg("events"), - py::arg("number_of_events"), - py::arg("rhs_alg_id"), - py::arg("atol"), - py::arg("rtol"), - py::arg("inputs"), - py::arg("var_fcns"), - py::arg("dvar_dy_fcns"), - py::arg("dvar_dp_fcns"), - py::arg("options"), - py::return_value_policy::take_ownership); -#endif - m.def("generate_function", &generate_casadi_function, "Generate a casadi function", py::arg("string"), @@ -174,20 +141,6 @@ PYBIND11_MODULE(idaklu, m) py::class_(m, "Function"); -#ifdef IREE_ENABLE - py::class_(m, "IREEBaseFunctionType") - .def(py::init<>()) - .def_readwrite("mlir", &IREEBaseFunctionType::mlir) - .def_readwrite("kept_var_idx", &IREEBaseFunctionType::kept_var_idx) - .def_readwrite("nnz", &IREEBaseFunctionType::nnz) - .def_readwrite("numel", &IREEBaseFunctionType::numel) - .def_readwrite("col", &IREEBaseFunctionType::col) - .def_readwrite("row", &IREEBaseFunctionType::row) - .def_readwrite("pytree_shape", &IREEBaseFunctionType::pytree_shape) - .def_readwrite("pytree_sizes", &IREEBaseFunctionType::pytree_sizes) - .def_readwrite("n_args", &IREEBaseFunctionType::n_args); -#endif - py::class_(m, "solution") .def_readwrite("t", &Solution::t) .def_readwrite("y", &Solution::y) diff --git a/src/pybamm/solvers/c_solvers/idaklu/Expressions/IREE/IREEBaseFunction.hpp b/src/pybamm/solvers/c_solvers/idaklu/Expressions/IREE/IREEBaseFunction.hpp deleted file mode 100644 index d2ba7e4de0..0000000000 --- a/src/pybamm/solvers/c_solvers/idaklu/Expressions/IREE/IREEBaseFunction.hpp +++ /dev/null @@ -1,27 +0,0 @@ -#ifndef PYBAMM_IDAKLU_IREE_BASE_FUNCTION_HPP -#define PYBAMM_IDAKLU_IREE_BASE_FUNCTION_HPP - -#include -#include - -/* - * @brief Function definition passed from PyBaMM - */ -class IREEBaseFunctionType -{ -public: // methods - const std::string& get_mlir() const { return mlir; } - -public: // data members - std::string mlir; // cppcheck-suppress unusedStructMember - std::vector kept_var_idx; // cppcheck-suppress unusedStructMember - expr_int nnz; // cppcheck-suppress unusedStructMember - expr_int numel; // cppcheck-suppress unusedStructMember - std::vector col; // cppcheck-suppress unusedStructMember - std::vector row; // cppcheck-suppress unusedStructMember - std::vector pytree_shape; // cppcheck-suppress unusedStructMember - std::vector pytree_sizes; // cppcheck-suppress unusedStructMember - expr_int n_args; // cppcheck-suppress unusedStructMember -}; - -#endif // PYBAMM_IDAKLU_IREE_BASE_FUNCTION_HPP diff --git a/src/pybamm/solvers/c_solvers/idaklu/Expressions/IREE/IREEFunction.hpp b/src/pybamm/solvers/c_solvers/idaklu/Expressions/IREE/IREEFunction.hpp deleted file mode 100644 index bcdae5eabf..0000000000 --- a/src/pybamm/solvers/c_solvers/idaklu/Expressions/IREE/IREEFunction.hpp +++ /dev/null @@ -1,59 +0,0 @@ -#ifndef PYBAMM_IDAKLU_IREE_FUNCTION_HPP -#define PYBAMM_IDAKLU_IREE_FUNCTION_HPP - -#include "../../Options.hpp" -#include "../Expressions.hpp" -#include -#include "iree_jit.hpp" -#include "IREEBaseFunction.hpp" - -/** - * @brief Class for handling individual iree functions - */ -class IREEFunction : public Expression -{ -public: - typedef IREEBaseFunctionType BaseFunctionType; - - /* - * @brief Constructor - */ - explicit IREEFunction(const BaseFunctionType &f); - - // Method overrides - void operator()() override; - void operator()(const std::vector& inputs, - const std::vector& results) override; - expr_int out_shape(int k) override; - expr_int nnz() override; - expr_int nnz_out() override; - const std::vector& get_col() override; - const std::vector& get_row() override; - - /* - * @brief Evaluate the MLIR function - */ - void evaluate(); - - /* - * @brief Evaluate the MLIR function - * @param n_outputs The number of outputs to return - */ - void evaluate(int n_outputs); - -public: - std::unique_ptr session; - std::vector> result; // cppcheck-suppress unusedStructMember - std::vector> input_shape; // cppcheck-suppress unusedStructMember - std::vector> output_shape; // cppcheck-suppress unusedStructMember - std::vector> input_data; // cppcheck-suppress unusedStructMember - - BaseFunctionType m_func; // cppcheck-suppress unusedStructMember - std::string module_name; // cppcheck-suppress unusedStructMember - std::string function_name; // cppcheck-suppress unusedStructMember - std::vector m_arg_argno; // cppcheck-suppress unusedStructMember - std::vector m_arg_argix; // cppcheck-suppress unusedStructMember - std::vector numel; // cppcheck-suppress unusedStructMember -}; - -#endif // PYBAMM_IDAKLU_IREE_FUNCTION_HPP diff --git a/src/pybamm/solvers/c_solvers/idaklu/Expressions/IREE/IREEFunctions.cpp b/src/pybamm/solvers/c_solvers/idaklu/Expressions/IREE/IREEFunctions.cpp deleted file mode 100644 index 3bde647113..0000000000 --- a/src/pybamm/solvers/c_solvers/idaklu/Expressions/IREE/IREEFunctions.cpp +++ /dev/null @@ -1,230 +0,0 @@ -#include -#include -#include -#include -#include - -#include "IREEFunctions.hpp" -#include "iree_jit.hpp" -#include "ModuleParser.hpp" - -IREEFunction::IREEFunction(const BaseFunctionType &f) : Expression(), m_func(f) -{ - DEBUG("IreeFunction constructor"); - const std::string& mlir = f.get_mlir(); - - // Parse IREE (MLIR) function string - if (mlir.size() == 0) { - DEBUG("Empty function --- skipping..."); - return; - } - - // Parse MLIR for module name, input and output shapes - ModuleParser parser(mlir); - module_name = parser.getModuleName(); - function_name = parser.getFunctionName(); - input_shape = parser.getInputShape(); - output_shape = parser.getOutputShape(); - - DEBUG("Compiling module: '" << module_name << "'"); - const char* device_uri = "local-sync"; - session = std::make_unique(device_uri, mlir); - DEBUG("compile complete."); - // Create index vectors into m_arg - // This is required since Jax expands input arguments through PyTrees, which need to - // be remapped to the corresponding expression call. For example: - // fcn(t, y, inputs, cj) with inputs = [[in1], [in2], [in3]] - // will produce a function with six inputs; we therefore need to be able to map - // arguments to their 1) corresponding input argument, and 2) the correct position - // within that argument. - m_arg_argno.clear(); - m_arg_argix.clear(); - int current_element = 0; - for (int i=0; i 2) || - ((input_shape[j].size() == 2) && (input_shape[j][1] > 1)) - ) { - std::cerr << "Unsupported input shape: " << input_shape[j].size() << " ["; - for (int k=0; k {res0} signature (i.e. x and z are reduced out) - // with kept_var_idx = [1] - // - // *********************************************************************************** - - DEBUG("Copying inputs, shape " << input_shape.size() << " - " << m_func.kept_var_idx.size()); - for (int j=0; j 1) { - // Index into argument using appropriate shape - for(int k=0; k(m_arg[m_arg_from][m_arg_argix[mlir_arg]+k]); - } - } else { - // Copy the entire vector - for(int k=0; k(m_arg[m_arg_from][k]); - } - } - } - - // Call the 'main' function of the module - const std::string mlir = m_func.get_mlir(); - DEBUG("Calling function '" << function_name << "'"); - auto status = session->iree_runtime_exec(function_name, input_shape, input_data, result); - if (!iree_status_is_ok(status)) { - iree_status_fprint(stderr, status); - std::cerr << "MLIR: " << mlir.substr(0,1000) << std::endl; - throw std::runtime_error("Execution failed"); - } - - // Copy results to output array - for(size_t k=0; k(result[k][j]); - } - } - - DEBUG("IreeFunction operator() complete"); -} - -expr_int IREEFunction::out_shape(int k) { - DEBUG("IreeFunction nnz(" << k << "): " << m_func.nnz); - auto elements = 1; - for (auto i : output_shape[k]) { - elements *= i; - } - return elements; -} - -expr_int IREEFunction::nnz() { - DEBUG("IreeFunction nnz: " << m_func.nnz); - return nnz_out(); -} - -expr_int IREEFunction::nnz_out() { - DEBUG("IreeFunction nnz_out" << m_func.nnz); - return m_func.nnz; -} - -const std::vector& IREEFunction::get_row() { - DEBUG("IreeFunction get_row" << m_func.row.size()); - return m_func.row; -} - -const std::vector& IREEFunction::get_col() { - DEBUG("IreeFunction get_col" << m_func.col.size()); - return m_func.col; -} - -void IREEFunction::operator()(const std::vector& inputs, - const std::vector& results) -{ - DEBUG("IreeFunction operator() with inputs and results"); - // Set-up input arguments, provide result vector, then execute function - // Example call: fcn({in1, in2, in3}, {out1}) - ASSERT(inputs.size() == m_func.n_args); - for(size_t k=0; k -#include "iree_jit.hpp" -#include "IREEFunction.hpp" - -/** - * @brief Class for handling iree functions - */ -class IREEFunctions : public ExpressionSet -{ -public: - std::unique_ptr iree_compiler; - - typedef IREEFunction::BaseFunctionType BaseFunctionType; // expose typedef in class - - int iree_init_status; - - int iree_init(const std::string& device_uri, const std::string& target_backends) { - // Initialise IREE - DEBUG("IREEFunctions: Initialising IREECompiler"); - iree_compiler = std::make_unique(device_uri.c_str()); - - int iree_argc = 2; - std::string target_backends_str = "--iree-hal-target-backends=" + target_backends; - const char* iree_argv[2] = {"iree", target_backends_str.c_str()}; - iree_compiler->init(iree_argc, iree_argv); - DEBUG("IREEFunctions: Initialised IREECompiler"); - return 0; - } - - int iree_init() { - return iree_init("local-sync", "llvm-cpu"); - } - - - /** - * @brief Create a new IREEFunctions object - */ - IREEFunctions( - const BaseFunctionType &rhs_alg, - const BaseFunctionType &jac_times_cjmass, - const int jac_times_cjmass_nnz, - const int jac_bandwidth_lower, - const int jac_bandwidth_upper, - const np_array_int &jac_times_cjmass_rowvals_arg, - const np_array_int &jac_times_cjmass_colptrs_arg, - const int inputs_length, - const BaseFunctionType &jac_action, - const BaseFunctionType &mass_action, - const BaseFunctionType &sens, - const BaseFunctionType &events, - const int n_s, - const int n_e, - const int n_p, - const std::vector& var_fcns, - const std::vector& dvar_dy_fcns, - const std::vector& dvar_dp_fcns, - const SetupOptions& setup_opts - ) : - iree_init_status(iree_init()), - rhs_alg_iree(rhs_alg), - jac_times_cjmass_iree(jac_times_cjmass), - jac_action_iree(jac_action), - mass_action_iree(mass_action), - sens_iree(sens), - events_iree(events), - ExpressionSet( - static_cast(&rhs_alg_iree), - static_cast(&jac_times_cjmass_iree), - jac_times_cjmass_nnz, - jac_bandwidth_lower, - jac_bandwidth_upper, - jac_times_cjmass_rowvals_arg, - jac_times_cjmass_colptrs_arg, - inputs_length, - static_cast(&jac_action_iree), - static_cast(&mass_action_iree), - static_cast(&sens_iree), - static_cast(&events_iree), - n_s, n_e, n_p, - setup_opts) - { - // convert BaseFunctionType list to IREEFunction list - // NOTE: You must allocate ALL std::vector elements before taking references - for (auto& var : var_fcns) - var_fcns_iree.push_back(IREEFunction(*var)); - for (int k = 0; k < var_fcns_iree.size(); k++) - ExpressionSet::var_fcns.push_back(&this->var_fcns_iree[k]); - - for (auto& var : dvar_dy_fcns) - dvar_dy_fcns_iree.push_back(IREEFunction(*var)); - for (int k = 0; k < dvar_dy_fcns_iree.size(); k++) - this->dvar_dy_fcns.push_back(&this->dvar_dy_fcns_iree[k]); - - for (auto& var : dvar_dp_fcns) - dvar_dp_fcns_iree.push_back(IREEFunction(*var)); - for (int k = 0; k < dvar_dp_fcns_iree.size(); k++) - this->dvar_dp_fcns.push_back(&this->dvar_dp_fcns_iree[k]); - - // copy across numpy array values - const int n_row_vals = jac_times_cjmass_rowvals_arg.request().size; - auto p_jac_times_cjmass_rowvals = jac_times_cjmass_rowvals_arg.unchecked<1>(); - jac_times_cjmass_rowvals.resize(n_row_vals); - for (int i = 0; i < n_row_vals; i++) { - jac_times_cjmass_rowvals[i] = p_jac_times_cjmass_rowvals[i]; - } - - const int n_col_ptrs = jac_times_cjmass_colptrs_arg.request().size; - auto p_jac_times_cjmass_colptrs = jac_times_cjmass_colptrs_arg.unchecked<1>(); - jac_times_cjmass_colptrs.resize(n_col_ptrs); - for (int i = 0; i < n_col_ptrs; i++) { - jac_times_cjmass_colptrs[i] = p_jac_times_cjmass_colptrs[i]; - } - - inputs.resize(inputs_length); - } - - IREEFunction rhs_alg_iree; - IREEFunction jac_times_cjmass_iree; - IREEFunction jac_action_iree; - IREEFunction mass_action_iree; - IREEFunction sens_iree; - IREEFunction events_iree; - - std::vector var_fcns_iree; - std::vector dvar_dy_fcns_iree; - std::vector dvar_dp_fcns_iree; - - realtype* get_tmp_state_vector() override { - return tmp_state_vector.data(); - } - realtype* get_tmp_sparse_jacobian_data() override { - return tmp_sparse_jacobian_data.data(); - } - - ~IREEFunctions() { - // cleanup IREE - iree_compiler->cleanup(); - } -}; - -#endif // PYBAMM_IDAKLU_IREE_FUNCTIONS_HPP diff --git a/src/pybamm/solvers/c_solvers/idaklu/Expressions/IREE/ModuleParser.cpp b/src/pybamm/solvers/c_solvers/idaklu/Expressions/IREE/ModuleParser.cpp deleted file mode 100644 index d1c5575ee2..0000000000 --- a/src/pybamm/solvers/c_solvers/idaklu/Expressions/IREE/ModuleParser.cpp +++ /dev/null @@ -1,91 +0,0 @@ -#include "ModuleParser.hpp" - -ModuleParser::ModuleParser(const std::string& mlir) : mlir(mlir) -{ - parse(); -} - -void ModuleParser::parse() -{ - // Parse module name - std::regex module_name_regex("module @([^\\s]+)"); // Match until first whitespace - std::smatch module_name_match; - std::regex_search(this->mlir, module_name_match, module_name_regex); - if (module_name_match.size() == 0) { - std::cerr << "Could not find module name in module" << std::endl; - std::cerr << "Module snippet: " << this->mlir.substr(0, 1000) << std::endl; - throw std::runtime_error("Could not find module name in module"); - } - module_name = module_name_match[1].str(); - DEBUG("Module name: " << module_name); - - // Assign function name - function_name = module_name + ".main"; - - // Isolate 'main' function call signature - std::regex main_func("public @main\\((.*?)\\) -> \\((.*?)\\)"); - std::smatch match; - std::regex_search(this->mlir, match, main_func); - if (match.size() == 0) { - std::cerr << "Could not find 'main' function in module" << std::endl; - std::cerr << "Module snippet: " << this->mlir.substr(0, 1000) << std::endl; - throw std::runtime_error("Could not find 'main' function in module"); - } - std::string main_sig_inputs = match[1].str(); - std::string main_sig_outputs = match[2].str(); - DEBUG( - "Main function signature: " << main_sig_inputs << " -> " << main_sig_outputs << '\n' - ); - - // Parse input sizes - input_shape.clear(); - std::regex input_size("tensor<(.*?)>"); - for(std::sregex_iterator i = std::sregex_iterator(main_sig_inputs.begin(), main_sig_inputs.end(), input_size); - i != std::sregex_iterator(); - ++i) - { - std::smatch matchi = *i; - std::string match_str = matchi.str(); - std::string shape_str = match_str.substr(7, match_str.size() - 8); // Remove 'tensor<>' from string - std::vector shape; - std::string dim_str; - for (char c : shape_str) { - if (c == 'x') { - shape.push_back(std::stoi(dim_str)); - dim_str = ""; - } else { - dim_str += c; - } - } - input_shape.push_back(shape); - } - - // Parse output sizes - output_shape.clear(); - std::regex output_size("tensor<(.*?)>"); - for( - std::sregex_iterator i = std::sregex_iterator(main_sig_outputs.begin(), main_sig_outputs.end(), output_size); - i != std::sregex_iterator(); - ++i - ) { - std::smatch matchi = *i; - std::string match_str = matchi.str(); - std::string shape_str = match_str.substr(7, match_str.size() - 8); // Remove 'tensor<>' from string - std::vector shape; - std::string dim_str; - for (char c : shape_str) { - if (c == 'x') { - shape.push_back(std::stoi(dim_str)); - dim_str = ""; - } else { - dim_str += c; - } - } - // If shape is empty, assume scalar (i.e. "tensor" or some singleton variant) - if (shape.size() == 0) { - shape.push_back(1); - } - // Add output to list - output_shape.push_back(shape); - } -} diff --git a/src/pybamm/solvers/c_solvers/idaklu/Expressions/IREE/ModuleParser.hpp b/src/pybamm/solvers/c_solvers/idaklu/Expressions/IREE/ModuleParser.hpp deleted file mode 100644 index 2fbfdc086c..0000000000 --- a/src/pybamm/solvers/c_solvers/idaklu/Expressions/IREE/ModuleParser.hpp +++ /dev/null @@ -1,55 +0,0 @@ -#ifndef PYBAMM_IDAKLU_IREE_MODULE_PARSER_HPP -#define PYBAMM_IDAKLU_IREE_MODULE_PARSER_HPP - -#include -#include -#include -#include -#include - -#include "../../common.hpp" - -class ModuleParser { -private: - std::string mlir; // cppcheck-suppress unusedStructMember - // codacy fix: member is referenced as this->mlir in parse() - std::string module_name; - std::string function_name; - std::vector> input_shape; - std::vector> output_shape; -public: - /** - * @brief Constructor - * @param mlir: string representation of MLIR code for the module - */ - explicit ModuleParser(const std::string& mlir); - - /** - * @brief Get the module name - * @return module name - */ - const std::string& getModuleName() const { return module_name; } - - /** - * @brief Get the function name - * @return function name - */ - const std::string& getFunctionName() const { return function_name; } - - /** - * @brief Get the input shape - * @return input shape - */ - const std::vector>& getInputShape() const { return input_shape; } - - /** - * @brief Get the output shape - * @return output shape - */ - const std::vector>& getOutputShape() const { return output_shape; } - -private: - void parse(); -}; - -#endif // PYBAMM_IDAKLU_IREE_MODULE_PARSER_HPP diff --git a/src/pybamm/solvers/c_solvers/idaklu/Expressions/IREE/iree_jit.cpp b/src/pybamm/solvers/c_solvers/idaklu/Expressions/IREE/iree_jit.cpp deleted file mode 100644 index c84c3928bd..0000000000 --- a/src/pybamm/solvers/c_solvers/idaklu/Expressions/IREE/iree_jit.cpp +++ /dev/null @@ -1,408 +0,0 @@ -#include "iree_jit.hpp" -#include "iree/hal/buffer_view.h" -#include "iree/hal/buffer_view_util.h" -#include "../../common.hpp" - -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include - -// Used to suppress stderr output (see initIREE below) -#ifdef _WIN32 -#include -#define close _close -#define dup _dup -#define fileno _fileno -#define open _open -#define dup2 _dup2 -#define NULL_DEVICE "NUL" -#else -#define NULL_DEVICE "/dev/null" -#endif - -void IREESession::handle_compiler_error(iree_compiler_error_t *error) { - const char *msg = ireeCompilerErrorGetMessage(error); - fprintf(stderr, "Error from compiler API:\n%s\n", msg); - ireeCompilerErrorDestroy(error); -} - -void IREESession::cleanup_compiler_state(compiler_state_t s) { - if (s.inv) - ireeCompilerInvocationDestroy(s.inv); - if (s.output) - ireeCompilerOutputDestroy(s.output); - if (s.source) - ireeCompilerSourceDestroy(s.source); - if (s.session) - ireeCompilerSessionDestroy(s.session); -} - -IREECompiler::IREECompiler() { - this->device_uri = "local-sync"; -}; - -IREECompiler::~IREECompiler() { - ireeCompilerGlobalShutdown(); -}; - -int IREECompiler::init(int argc, const char **argv) { - return initIREE(argc, argv); // Initialisation and version checking -}; - -int IREECompiler::cleanup() { - return 0; -}; - -IREESession::IREESession() { - s.session = NULL; - s.source = NULL; - s.output = NULL; - s.inv = NULL; -}; - -IREESession::IREESession(const char *device_uri, const std::string& mlir_code) : IREESession() { - this->device_uri=device_uri; - this->mlir_code=mlir_code; - init(); -} - -int IREESession::init() { - if (initCompiler() != 0) // Prepare compiler inputs and outputs - return 1; - if (initCompileToByteCode() != 0) // Compile to bytecode - return 1; - if (initRuntime() != 0) // Initialise runtime environment - return 1; - return 0; -}; - -int IREECompiler::initIREE(int argc, const char **argv) { - - if (device_uri == NULL) { - DEBUG("No device URI provided, using local-sync\n"); - this->device_uri = "local-sync"; - } - - int cl_argc = argc; - const char *iree_compiler_lib = std::getenv("IREE_COMPILER_LIB"); - - // Load the compiler library and initialize it - // NOTE: On second and subsequent calls, the function will return false and display - // a message on stderr, but it is safe to ignore this message. For an improved user - // experience we actively suppress stderr during the call to this function but since - // this also suppresses any other error message, we actively check for the presence - // of the library file prior to the call. - - // Check if the library file exists - if (iree_compiler_lib == NULL) { - fprintf(stderr, "Error: IREE_COMPILER_LIB environment variable not set\n"); - return 1; - } - if (access(iree_compiler_lib, F_OK) == -1) { - fprintf(stderr, "Error: IREE_COMPILER_LIB file not found\n"); - return 1; - } - // Suppress stderr - int saved_stderr = dup(fileno(stderr)); - if (!freopen(NULL_DEVICE, "w", stderr)) - DEBUG("Error: failed redirecting stderr"); - // Load library - bool result = ireeCompilerLoadLibrary(iree_compiler_lib); - // Restore stderr - fflush(stderr); - dup2(saved_stderr, fileno(stderr)); - close(saved_stderr); - // Process result - if (!result) { - // Library may have already been loaded (can be safely ignored), - // or may not be found (critical error), we cannot tell which from the return value. - return 1; - } - // Must be balanced with a call to ireeCompilerGlobalShutdown() - ireeCompilerGlobalInitialize(); - - // To set global options (see `iree-compile --help` for possibilities), use - // |ireeCompilerGetProcessCLArgs| and |ireeCompilerSetupGlobalCL| - ireeCompilerGetProcessCLArgs(&cl_argc, &argv); - ireeCompilerSetupGlobalCL(cl_argc, argv, "iree-jit", false); - - // Check the API version before proceeding any further - uint32_t api_version = (uint32_t)ireeCompilerGetAPIVersion(); - uint16_t api_version_major = (uint16_t)((api_version >> 16) & 0xFFFFUL); - uint16_t api_version_minor = (uint16_t)(api_version & 0xFFFFUL); - DEBUG("Compiler API version: " << api_version_major << "." << api_version_minor); - if (api_version_major > IREE_COMPILER_EXPECTED_API_MAJOR || - api_version_minor < IREE_COMPILER_EXPECTED_API_MINOR) { - fprintf(stderr, - "Error: incompatible API version; built for version %" PRIu16 - ".%" PRIu16 " but loaded version %" PRIu16 ".%" PRIu16 "\n", - IREE_COMPILER_EXPECTED_API_MAJOR, IREE_COMPILER_EXPECTED_API_MINOR, - api_version_major, api_version_minor); - ireeCompilerGlobalShutdown(); - return 1; - } - - // Check for a build tag with release version information - const char *revision = ireeCompilerGetRevision(); // cppcheck-suppress unreadVariable - DEBUG("Compiler revision: '" << revision << "'"); - return 0; -}; - -int IREESession::initCompiler() { - - // A session provides a scope where one or more invocations can be executed - s.session = ireeCompilerSessionCreate(); - - // Read the MLIR from memory - error = ireeCompilerSourceWrapBuffer( - s.session, - "expr_buffer", // name of the buffer (does not need to match MLIR) - mlir_code.c_str(), - mlir_code.length() + 1, - true, - &s.source - ); - if (error) { - fprintf(stderr, "Error wrapping source buffer\n"); - handle_compiler_error(error); - cleanup_compiler_state(s); - return 1; - } - DEBUG("Wrapped buffer as a compiler source"); - - return 0; -}; - -int IREESession::initCompileToByteCode() { - // Use an invocation to compile from the input source to the output stream - iree_compiler_invocation_t *inv = ireeCompilerInvocationCreate(s.session); - ireeCompilerInvocationEnableConsoleDiagnostics(inv); - - if (!ireeCompilerInvocationParseSource(inv, s.source)) { - fprintf(stderr, "Error parsing input source into invocation\n"); - cleanup_compiler_state(s); - return 1; - } - - // Compile, specifying the target dialect phase - ireeCompilerInvocationSetCompileToPhase(inv, "end"); - - // Run the compiler invocation pipeline - if (!ireeCompilerInvocationPipeline(inv, IREE_COMPILER_PIPELINE_STD)) { - fprintf(stderr, "Error running compiler invocation\n"); - cleanup_compiler_state(s); - return 1; - } - DEBUG("Compilation successful"); - - // Create compiler 'output' to a memory buffer - error = ireeCompilerOutputOpenMembuffer(&s.output); - if (error) { - fprintf(stderr, "Error opening output membuffer\n"); - handle_compiler_error(error); - cleanup_compiler_state(s); - return 1; - } - - // Create bytecode in memory - error = ireeCompilerInvocationOutputVMBytecode(inv, s.output); - if (error) { - fprintf(stderr, "Error creating VM bytecode\n"); - handle_compiler_error(error); - cleanup_compiler_state(s); - return 1; - } - - // Once the bytecode has been written, retrieve the memory map - ireeCompilerOutputMapMemory(s.output, &contents, &size); - - return 0; -}; - -int IREESession::initRuntime() { - // Setup the shared runtime instance - iree_runtime_instance_options_t instance_options; - iree_runtime_instance_options_initialize(&instance_options); - iree_runtime_instance_options_use_all_available_drivers(&instance_options); - status = iree_runtime_instance_create( - &instance_options, iree_allocator_system(), &instance); - - // Create the HAL device used to run the workloads - if (iree_status_is_ok(status)) { - status = iree_hal_create_device( - iree_runtime_instance_driver_registry(instance), - iree_make_cstring_view(device_uri), - iree_runtime_instance_host_allocator(instance), &device); - } - - // Set up the session to run the module - if (iree_status_is_ok(status)) { - iree_runtime_session_options_t session_options; - iree_runtime_session_options_initialize(&session_options); - status = iree_runtime_session_create_with_device( - instance, &session_options, device, - iree_runtime_instance_host_allocator(instance), &session); - } - - // Load the compiled user module from a file - if (iree_status_is_ok(status)) { - /*status = iree_runtime_session_append_bytecode_module_from_file(session, module_path);*/ - status = iree_runtime_session_append_bytecode_module_from_memory( - session, - iree_make_const_byte_span(contents, size), - iree_allocator_null()); - } - - if (!iree_status_is_ok(status)) - return 1; - - return 0; -}; - -// Release the session and free all cached resources. -int IREESession::cleanup() { - iree_runtime_session_release(session); - iree_hal_device_release(device); - iree_runtime_instance_release(instance); - - int ret = (int)iree_status_code(status); - if (!iree_status_is_ok(status)) { - iree_status_fprint(stderr, status); - iree_status_ignore(status); - } - cleanup_compiler_state(s); - return ret; -} - -iree_status_t IREESession::iree_runtime_exec( - const std::string& function_name, - const std::vector>& inputs, - const std::vector>& data, - std::vector>& result -) { - - // Initialize the call to the function. - status = iree_runtime_call_initialize_by_name( - session, iree_make_cstring_view(function_name.c_str()), &call); - if (!iree_status_is_ok(status)) { - std::cerr << "Error: iree_runtime_call_initialize_by_name failed" << std::endl; - iree_status_fprint(stderr, status); - return status; - } - - // Append the function inputs with the HAL device allocator in use by the - // session. The buffers will be usable within the session and _may_ be usable - // in other sessions depending on whether they share a compatible device. - iree_hal_allocator_t* device_allocator = - iree_runtime_session_device_allocator(session); - host_allocator = iree_runtime_session_host_allocator(session); - status = iree_ok_status(); - if (iree_status_is_ok(status)) { - - for(int k=0; k arg_shape(input_shape.size()); - for (int i = 0; i < input_shape.size(); i++) { - arg_shape[i] = input_shape[i]; - } - int numel = 1; - for(int i = 0; i < input_shape.size(); i++) { - numel *= input_shape[i]; - } - std::vector arg_data(numel); - for(int i = 0; i < numel; i++) { - arg_data[i] = input_data[i]; - } - - status = iree_hal_buffer_view_allocate_buffer_copy( - device, device_allocator, - // Shape rank and dimensions: - arg_shape.size(), arg_shape.data(), - // Element type: - IREE_HAL_ELEMENT_TYPE_FLOAT_32, - // Encoding type: - IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, - (iree_hal_buffer_params_t){ - // Intended usage of the buffer (transfers, dispatches, etc): - .usage = IREE_HAL_BUFFER_USAGE_DEFAULT, - // Access to allow to this memory: - .access = IREE_HAL_MEMORY_ACCESS_ALL, - // Where to allocate (host or device): - .type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL, - }, - // The actual heap buffer to wrap or clone and its allocator: - iree_make_const_byte_span(&arg_data[0], sizeof(float) * arg_data.size()), - // Buffer view + storage are returned and owned by the caller: - &arg); - } - if (iree_status_is_ok(status)) { - // Add to the call inputs list (which retains the buffer view). - status = iree_runtime_call_inputs_push_back_buffer_view(&call, arg); - if (!iree_status_is_ok(status)) { - std::cerr << "Error: iree_runtime_call_inputs_push_back_buffer_view failed" << std::endl; - iree_status_fprint(stderr, status); - } - } - // Since the call retains the buffer view we can release it here. - iree_hal_buffer_view_release(arg); - } - } - - // Synchronously perform the call. - if (iree_status_is_ok(status)) { - status = iree_runtime_call_invoke(&call, /*flags=*/0); - } - if (!iree_status_is_ok(status)) { - std::cerr << "Error: iree_runtime_call_invoke failed" << std::endl; - iree_status_fprint(stderr, status); - } - - for(int k=0; k -#include -#include -#include - -#include -#include -#include - -#define IREE_COMPILER_EXPECTED_API_MAJOR 1 // At most this major version -#define IREE_COMPILER_EXPECTED_API_MINOR 2 // At least this minor version - -// Forward declaration -class IREESession; - -/* - * @brief IREECompiler class - * @details This class is used to compile MLIR code to IREE bytecode and - * create IREE sessions. - */ -class IREECompiler { -private: - /* - * @brief Device Uniform Resource Identifier (URI) - * @details The device URI is used to specify the device to be used by the - * IREE runtime. E.g. "local-sync" for CPU, "vulkan" for GPU, etc. - */ - const char *device_uri = NULL; - -private: - /* - * @brief Initialize the IREE runtime - */ - int initIREE(int argc, const char **argv); - -public: - /* - * @brief Default constructor - */ - IREECompiler(); - - /* - * @brief Destructor - */ - ~IREECompiler(); - - /* - * @brief Constructor with device URI - * @param device_uri Device URI - */ - explicit IREECompiler(const char *device_uri) - : IREECompiler() { this->device_uri=device_uri; } - - /* - * @brief Initialize the compiler - */ - int init(int argc, const char **argv); - - /* - * @brief Cleanup the compiler - * @details This method cleans up the compiler and all the IREE sessions - * created by the compiler. Returns 0 on success. - */ - int cleanup(); -}; - -/* - * @brief Compiler state - */ -typedef struct compiler_state_t { - iree_compiler_session_t *session; // cppcheck-suppress unusedStructMember - iree_compiler_source_t *source; // cppcheck-suppress unusedStructMember - iree_compiler_output_t *output; // cppcheck-suppress unusedStructMember - iree_compiler_invocation_t *inv; // cppcheck-suppress unusedStructMember -} compiler_state_t; - -/* - * @brief IREE session class - */ -class IREESession { -private: // data members - const char *device_uri = NULL; - compiler_state_t s; - iree_compiler_error_t *error = NULL; - void *contents = NULL; - uint64_t size = 0; - iree_runtime_session_t* session = NULL; - iree_status_t status; - iree_hal_device_t* device = NULL; - iree_runtime_instance_t* instance = NULL; - std::string mlir_code; // cppcheck-suppress unusedStructMember - iree_runtime_call_t call; - iree_allocator_t host_allocator; - -private: // private methods - void handle_compiler_error(iree_compiler_error_t *error); - void cleanup_compiler_state(compiler_state_t s); - int init(); - int initCompiler(); - int initCompileToByteCode(); - int initRuntime(); - -public: // public methods - - /* - * @brief Default constructor - */ - IREESession(); - - /* - * @brief Constructor with device URI and MLIR code - * @param device_uri Device URI - * @param mlir_code MLIR code - */ - explicit IREESession(const char *device_uri, const std::string& mlir_code); - - /* - * @brief Cleanup the IREE session - */ - int cleanup(); - - /* - * @brief Execute the pre-compiled byte-code with the given inputs - * @param function_name Function name to execute - * @param inputs List of input shapes - * @param data List of input data - * @param result List of output data - */ - iree_status_t iree_runtime_exec( - const std::string& function_name, - const std::vector>& inputs, - const std::vector>& data, - std::vector>& result - ); -}; - -#endif // IREE_JIT_HPP diff --git a/src/pybamm/solvers/idaklu_solver.py b/src/pybamm/solvers/idaklu_solver.py index 484c1ed9b4..3e4c6c1e8d 100644 --- a/src/pybamm/solvers/idaklu_solver.py +++ b/src/pybamm/solvers/idaklu_solver.py @@ -1,28 +1,13 @@ -# -# Solver class using sundials with the KLU sparse linear solver -# # mypy: ignore-errors -import os import casadi import pybamm import numpy as np import numbers -import scipy.sparse as sparse -from scipy.linalg import bandwidth import importlib import warnings -if pybamm.has_jax(): - import jax - from jax import numpy as jnp - - try: - import iree.compiler - except ImportError: # pragma: no cover - pass - idaklu_spec = importlib.util.find_spec("pybamm.solvers.idaklu") if idaklu_spec is not None: try: @@ -39,15 +24,6 @@ def has_idaklu(): return idaklu_spec is not None -def has_iree(): - try: - import iree.compiler # noqa: F401 - - return True - except ImportError: # pragma: no cover - return False - - class IDAKLUSolver(pybamm.BaseSolver): """ Solve a discretised model, using sundials with the KLU sparse linear solver. @@ -84,8 +60,6 @@ class IDAKLUSolver(pybamm.BaseSolver): "num_threads": 1, # Number of solvers to use in parallel (for solving multiple sets of input parameters in parallel) "num_solvers": num_threads, - # Evaluation engine to use for jax, can be 'jax'(native) or 'iree' - "jax_evaluator": "jax", ## Linear solver interface # name of sundials linear solver to use options are: "SUNLinSol_KLU", # "SUNLinSol_Dense", "SUNLinSol_Band", "SUNLinSol_SPBCGS", @@ -191,7 +165,6 @@ def __init__( "precon_half_bandwidth_keep": 5, "num_threads": 1, "num_solvers": 1, - "jax_evaluator": "jax", "linear_solver": "SUNLinSol_KLU", "linsol_max_iterations": 5, "epsilon_linear_tolerance": 0.05, @@ -224,10 +197,6 @@ def __init__( for key, value in default_options.items(): if key not in options: options[key] = value - if options["jax_evaluator"] not in ["jax", "iree"]: - raise pybamm.SolverError( - "Evaluation engine must be 'jax' or 'iree' for IDAKLU solver" - ) self._options = options self.output_variables = [] if output_variables is None else output_variables @@ -280,19 +249,10 @@ def set_up(self, model, inputs=None, t_eval=None, ics_only=False): # stack inputs if inputs_dict: arrays_to_stack = [np.array(x).reshape(-1, 1) for x in inputs_dict.values()] - inputs_sizes = [len(array) for array in arrays_to_stack] inputs = np.vstack(arrays_to_stack) else: - inputs_sizes = [] inputs = np.array([[]]) - def inputs_to_dict(inputs): - index = 0 - for n, key in zip(inputs_sizes, inputs_dict.keys()): - inputs_dict[key] = inputs[index : (index + n)] - index += n - return inputs_dict - y0 = model.y0 if isinstance(y0, casadi.DM): y0 = y0.full() @@ -303,21 +263,12 @@ def inputs_to_dict(inputs): if model.convert_to_format not in ["casadi", "jax"]: msg = ( - "The python-idaklu solver has been deprecated. " - "To use the IDAKLU solver set `convert_to_format = 'casadi'`, or `jax`" - " if using IREE." + "The python-idaklu and IREE solvers have been deprecated. " + "To use the IDAKLU solver set `convert_to_format = 'casadi'` or `jax`" ) warnings.warn(msg, DeprecationWarning, stacklevel=2) - if model.convert_to_format == "jax": - if self._options["jax_evaluator"] != "iree": - raise pybamm.SolverError( - "Unsupported evaluation engine for convert_to_format=" - f"{model.convert_to_format} " - f"(jax_evaluator={self._options['jax_evaluator']})" - ) - mass_matrix = model.mass_matrix.entries.toarray() - elif model.convert_to_format == "casadi": + if model.convert_to_format == "casadi": if self._options["jacobian"] == "dense": mass_matrix = casadi.DM(model.mass_matrix.entries.toarray()) else: @@ -464,171 +415,6 @@ def inputs_to_dict(inputs): rootfn = idaklu.generate_function(rootfn.serialize()) mass_action = idaklu.generate_function(mass_action.serialize()) sensfn = idaklu.generate_function(sensfn.serialize()) - elif ( - model.convert_to_format == "jax" - and self._options["jax_evaluator"] == "iree" - ): - # Convert Jax functions to MLIR (also, demote to single precision) - idaklu_solver_fcn = idaklu.create_iree_solver_group - pybamm.demote_expressions_to_32bit = True - if pybamm.demote_expressions_to_32bit: - warnings.warn( - "Demoting expressions to 32-bit for MLIR conversion", - stacklevel=2, - ) - jnpfloat = jnp.float32 - else: # pragma: no cover - jnpfloat = jnp.float64 - raise pybamm.SolverError( - "Demoting expressions to 32-bit is required for MLIR conversion" - " at this time" - ) - - # input arguments (used for lowering) - t_eval = self._demote_64_to_32(jnp.array([0.0], dtype=jnpfloat)) - y0 = self._demote_64_to_32(model.y0) - inputs0 = self._demote_64_to_32(inputs_to_dict(inputs)) - cj = self._demote_64_to_32(jnp.array([1.0], dtype=jnpfloat)) # array - v0 = jnp.zeros(model.len_rhs_and_alg, jnpfloat) - mass_matrix = model.mass_matrix.entries.toarray() - mass_matrix_demoted = self._demote_64_to_32(mass_matrix) - - # rhs_algebraic - rhs_algebraic_demoted = model.rhs_algebraic_eval - rhs_algebraic_demoted._demote_constants() - - def fcn_rhs_algebraic(t, y, inputs): - # function wraps an expression tree (and names MLIR module) - return rhs_algebraic_demoted(t, y, inputs) - - rhs_algebraic = self._make_iree_function( - fcn_rhs_algebraic, t_eval, y0, inputs0 - ) - - # jac_times_cjmass - jac_rhs_algebraic_demoted = rhs_algebraic_demoted.get_jacobian() - - def fcn_jac_times_cjmass(t, y, p, cj): - return jac_rhs_algebraic_demoted(t, y, p) - cj * mass_matrix_demoted - - sparse_eval = sparse.csc_matrix( - fcn_jac_times_cjmass(t_eval, y0, inputs0, cj) - ) - jac_times_cjmass_nnz = sparse_eval.nnz - jac_times_cjmass_colptrs = sparse_eval.indptr - jac_times_cjmass_rowvals = sparse_eval.indices - jac_bw_lower, jac_bw_upper = bandwidth( - sparse_eval.todense() - ) # potentially slow - if jac_bw_upper <= 1: - jac_bw_upper = jac_bw_lower - 1 - if jac_bw_lower <= 1: - jac_bw_lower = jac_bw_upper + 1 - coo = sparse_eval.tocoo() # convert to COOrdinate format for indexing - - def fcn_jac_times_cjmass_sparse(t, y, p, cj): - return fcn_jac_times_cjmass(t, y, p, cj)[coo.row, coo.col] - - jac_times_cjmass = self._make_iree_function( - fcn_jac_times_cjmass_sparse, t_eval, y0, inputs0, cj - ) - - # Mass action - def fcn_mass_action(v): - return mass_matrix_demoted @ v - - mass_action_demoted = self._demote_64_to_32(fcn_mass_action) - mass_action = self._make_iree_function(mass_action_demoted, v0) - - # rootfn - for ix, _ in enumerate(model.terminate_events_eval): - model.terminate_events_eval[ix]._demote_constants() - - def fcn_rootfn(t, y, inputs): - return jnp.array( - [event(t, y, inputs) for event in model.terminate_events_eval], - dtype=jnpfloat, - ).reshape(-1) - - def fcn_rootfn_demoted(t, y, inputs): - return self._demote_64_to_32(fcn_rootfn)(t, y, inputs) - - rootfn = self._make_iree_function(fcn_rootfn_demoted, t_eval, y0, inputs0) - - # jac_rhs_algebraic_action - jac_rhs_algebraic_action_demoted = ( - rhs_algebraic_demoted.get_jacobian_action() - ) - - def fcn_jac_rhs_algebraic_action( - t, y, p, v - ): # sundials calls (t, y, inputs, v) - return jac_rhs_algebraic_action_demoted( - t, y, v, p - ) # jvp calls (t, y, v, inputs) - - jac_rhs_algebraic_action = self._make_iree_function( - fcn_jac_rhs_algebraic_action, t_eval, y0, inputs0, v0 - ) - - # sensfn - if model.jacp_rhs_algebraic_eval is None: - sensfn = idaklu.IREEBaseFunctionType() # empty equation - else: - sensfn_demoted = rhs_algebraic_demoted.get_sensitivities() - - def fcn_sensfn(t, y, p): - return sensfn_demoted(t, y, p) - - sensfn = self._make_iree_function( - fcn_sensfn, t_eval, jnp.zeros_like(y0), inputs0 - ) - - # output_variables - self.var_idaklu_fcns = [] - self.dvar_dy_idaklu_fcns = [] - self.dvar_dp_idaklu_fcns = [] - for key in self.output_variables: - fcn = self.computed_var_fcns[key] - fcn._demote_constants() - self.var_idaklu_fcns.append( - self._make_iree_function( - lambda t, y, p: fcn(t, y, p), # noqa: B023 - t_eval, - y0, - inputs0, - ) - ) - # Convert derivative functions for sensitivities - if (len(inputs) > 0) and (model.calculate_sensitivities): - dvar_dy = fcn.get_jacobian() - self.dvar_dy_idaklu_fcns.append( - self._make_iree_function( - lambda t, y, p: dvar_dy(t, y, p), # noqa: B023 - t_eval, - y0, - inputs0, - sparse_index=True, - ) - ) - dvar_dp = fcn.get_sensitivities() - self.dvar_dp_idaklu_fcns.append( - self._make_iree_function( - lambda t, y, p: dvar_dp(t, y, p), # noqa: B023 - t_eval, - y0, - inputs0, - ) - ) - - # Identify IREE library - iree_lib_path = os.path.join(iree.compiler.__path__[0], "_mlir_libs") - os.environ["IREE_COMPILER_LIB"] = os.path.join( - iree_lib_path, - next(f for f in os.listdir(iree_lib_path) if "IREECompiler" in f), - ) - - pybamm.demote_expressions_to_32bit = False else: # pragma: no cover raise pybamm.SolverError( "Unsupported evaluation engine for convert_to_format='jax'" @@ -687,57 +473,6 @@ def fcn_sensfn(t, y, p): return base_set_up_return - def _make_iree_function(self, fcn, *args, sparse_index=False): - # Initialise IREE function object - iree_fcn = idaklu.IREEBaseFunctionType() - # Get sparsity pattern index outputs as needed - try: - fcn_eval = fcn(*args) - if not isinstance(fcn_eval, np.ndarray): - fcn_eval = jax.flatten_util.ravel_pytree(fcn_eval)[0] - coo = sparse.coo_matrix(fcn_eval) - iree_fcn.nnz = coo.nnz - iree_fcn.numel = np.prod(coo.shape) - iree_fcn.col = coo.col - iree_fcn.row = coo.row - if sparse_index: - # Isolate NNZ elements while recording original sparsity structure - fcn_inner = fcn - - def fcn(*args): - return fcn_inner(*args)[coo.row, coo.col] - - elif coo.nnz != iree_fcn.numel: - iree_fcn.nnz = iree_fcn.numel - iree_fcn.col = list(range(iree_fcn.numel)) - iree_fcn.row = [0] * iree_fcn.numel - except (TypeError, AttributeError) as error: # pragma: no cover - raise pybamm.SolverError( - "Could not get sparsity pattern for function {fcn.__name__}" - ) from error - # Lower to MLIR - lowered = jax.jit(fcn).lower(*args) - iree_fcn.mlir = lowered.as_text() - self._check_mlir_conversion(fcn.__name__, iree_fcn.mlir) - iree_fcn.kept_var_idx = list(lowered._lowering.compile_args["kept_var_idx"]) - # Record number of variables in each argument (these will flatten in the mlir) - iree_fcn.pytree_shape = [ - len(jax.tree_util.tree_flatten(arg)[0]) for arg in args - ] - # Record array length of each mlir variable - iree_fcn.pytree_sizes = [ - len(arg) for arg in jax.tree_util.tree_flatten(args)[0] - ] - iree_fcn.n_args = len(args) - return iree_fcn - - def _check_mlir_conversion(self, name, mlir: str): - if mlir.count("f64") > 0: # pragma: no cover - warnings.warn(f"f64 found in {name} (x{mlir.count('f64')})", stacklevel=2) - - def _demote_64_to_32(self, x: pybamm.EvaluatorJax): - return pybamm.EvaluatorJax._demote_64_to_32(x) - @property def supports_parallel_solve(self): return True @@ -762,13 +497,7 @@ def _integrate(self, model, t_eval, inputs_list=None, t_interp=None): The times (in seconds) at which to interpolate the solution. Defaults to `None`, which returns the adaptive time-stepping times. """ - if not ( - model.convert_to_format == "casadi" - or ( - model.convert_to_format == "jax" - and self._options["jax_evaluator"] == "iree" - ) - ): # pragma: no cover + if not (model.convert_to_format == "casadi"): # pragma: no cover # Shouldn't ever reach this point raise pybamm.SolverError("Unsupported IDAKLU solver configuration.") @@ -884,18 +613,10 @@ def _post_process_solution(self, sol, model, integration_time, inputs_dict): self._setup["var_fcns"][var](0.0, 0.0, 0.0).sparsity().nnz() ) base_variables = [self._setup["var_fcns"][var]] - elif ( - model.convert_to_format == "jax" - and self._options["jax_evaluator"] == "iree" - ): - idx = self.output_variables.index(var) - len_of_var = self._setup["var_idaklu_fcns"][idx].nnz - base_variables = [self._setup["var_idaklu_fcns"][idx]] else: # pragma: no cover raise pybamm.SolverError( "Unsupported evaluation engine for convert_to_format=" - + f"{model.convert_to_format} " - + f"(jax_evaluator={self._options['jax_evaluator']})" + + f"{model.convert_to_format}" ) newsol._variables[var] = pybamm.ProcessedVariableComputed( [model.variables_and_events[var]], @@ -933,10 +654,6 @@ def _set_consistent_initialization(self, model, time, inputs_dict): super()._set_consistent_initialization(model, time, inputs_dict) casadi_format = model.convert_to_format == "casadi" - jax_iree_format = ( - model.convert_to_format == "jax" - and self._options["jax_evaluator"] == "iree" - ) y0 = model.y0 if isinstance(y0, casadi.DM): @@ -952,7 +669,7 @@ def _set_consistent_initialization(self, model, time, inputs_dict): else: ydot0 = np.zeros_like(y0) - sensitivity = (model.y0S is not None) and (jax_iree_format or casadi_format) + sensitivity = (model.y0S is not None) and casadi_format if sensitivity: y0full, ydot0full = self._sensitivity_consistent_initialization( y0, ydot0, model, time, inputs_dict @@ -961,12 +678,6 @@ def _set_consistent_initialization(self, model, time, inputs_dict): y0full = y0 ydot0full = ydot0 - if jax_iree_format: - pybamm.demote_expressions_to_32bit = True - y0full = self._demote_64_to_32(y0full) - ydot0full = self._demote_64_to_32(ydot0full) - pybamm.demote_expressions_to_32bit = False - model.y0full = y0full model.ydot0full = ydot0full @@ -1034,19 +745,9 @@ def _sensitivity_consistent_initialization( Any input parameters to pass to the model when solving. """ - - jax_iree_format = ( - model.convert_to_format == "jax" - and self._options["jax_evaluator"] == "iree" - ) - y0S = model.y0S - if jax_iree_format: - inputs_dict = inputs_dict or {} - inputs_dict_keys = list(inputs_dict.keys()) - y0S = np.concatenate([y0S[k] for k in inputs_dict_keys]) - elif isinstance(y0S, casadi.DM): + if isinstance(y0S, casadi.DM): y0S = (y0S,) if isinstance(y0S[0], casadi.DM): diff --git a/src/pybamm/solvers/processed_variable_computed.py b/src/pybamm/solvers/processed_variable_computed.py index befe6314b6..4602de4017 100644 --- a/src/pybamm/solvers/processed_variable_computed.py +++ b/src/pybamm/solvers/processed_variable_computed.py @@ -126,11 +126,6 @@ def _unroll_nnz(self, realdata=None): nnz = sp.nnz() numel = sp.numel() row = sp.row() - elif "nnz" in dir(self.base_variables_casadi[0]): # IREE fcn - sp = self.base_variables_casadi[0] - nnz = sp.nnz - numel = sp.numel - row = sp.row if nnz != numel: data = [None] * len(realdata) for datak in range(len(realdata)): diff --git a/tests/shared.py b/tests/shared.py index 48e54e19d8..0fdc67069d 100644 --- a/tests/shared.py +++ b/tests/shared.py @@ -336,7 +336,7 @@ def no_internet_connection(): conn = socket.create_connection((host, 80), 2) conn.close() return False - except socket.gaierror: + except (socket.gaierror, TimeoutError): return True diff --git a/tests/unit/test_expression_tree/test_operations/test_evaluate_python.py b/tests/unit/test_expression_tree/test_operations/test_evaluate_python.py index 14b980b358..2aa7bcaf30 100644 --- a/tests/unit/test_expression_tree/test_operations/test_evaluate_python.py +++ b/tests/unit/test_expression_tree/test_operations/test_evaluate_python.py @@ -671,68 +671,6 @@ def test_evaluator_jax_inputs(self): result = evaluator(inputs={"a": 2}) assert result == 4 - @pytest.mark.skipif(not pybamm.has_jax(), reason="jax or jaxlib is not installed") - def test_evaluator_jax_demotion(self): - for demote in [True, False]: - pybamm.demote_expressions_to_32bit = demote # global flag - target_dtype = "32" if demote else "64" - if demote: - # Test only works after conversion to jax.numpy - for c in [ - 1.0, - 1, - ]: - assert ( - str(pybamm.EvaluatorJax._demote_64_to_32(c).dtype)[-2:] - == target_dtype - ) - for c in [ - np.float64(1.0), - np.int64(1), - np.array([1.0], dtype=np.float64), - np.array([1], dtype=np.int64), - jax.numpy.array([1.0], dtype=np.float64), - jax.numpy.array([1], dtype=np.int64), - ]: - assert ( - str(pybamm.EvaluatorJax._demote_64_to_32(c).dtype)[-2:] - == target_dtype - ) - for c in [ - {key: np.float64(1.0) for key in ["a", "b"]}, - ]: - expr_demoted = pybamm.EvaluatorJax._demote_64_to_32(c) - assert all( - str(c_v.dtype)[-2:] == target_dtype - for c_k, c_v in expr_demoted.items() - ) - for c in [ - (np.float64(1.0), np.float64(2.0)), - [np.float64(1.0), np.float64(2.0)], - ]: - expr_demoted = pybamm.EvaluatorJax._demote_64_to_32(c) - assert all(str(c_i.dtype)[-2:] == target_dtype for c_i in expr_demoted) - for dtype in [ - np.float64, - jax.numpy.float64, - ]: - c = pybamm.JaxCooMatrix([0, 1], [0, 1], dtype([1.0, 2.0]), (2, 2)) - c_demoted = pybamm.EvaluatorJax._demote_64_to_32(c) - assert all( - str(c_i.dtype)[-2:] == target_dtype for c_i in c_demoted.data - ) - for dtype in [ - np.int64, - jax.numpy.int64, - ]: - c = pybamm.JaxCooMatrix( - dtype([0, 1]), dtype([0, 1]), [1.0, 2.0], (2, 2) - ) - c_demoted = pybamm.EvaluatorJax._demote_64_to_32(c) - assert all(str(c_i.dtype)[-2:] == target_dtype for c_i in c_demoted.row) - assert all(str(c_i.dtype)[-2:] == target_dtype for c_i in c_demoted.col) - pybamm.demote_expressions_to_32bit = False - @pytest.mark.skipif(not pybamm.has_jax(), reason="jax or jaxlib is not installed") def test_jax_coo_matrix(self): A = pybamm.JaxCooMatrix([0, 1], [0, 1], [1.0, 2.0], (2, 2)) diff --git a/tests/unit/test_solvers/test_idaklu_solver.py b/tests/unit/test_solvers/test_idaklu_solver.py index e4d6559e71..6121b4b9af 100644 --- a/tests/unit/test_solvers/test_idaklu_solver.py +++ b/tests/unit/test_solvers/test_idaklu_solver.py @@ -18,49 +18,43 @@ def test_ida_roberts_klu(self): # this test implements a python version of the ida Roberts # example provided in sundials # see sundials ida examples pdf - for form in ["casadi", "iree"]: - if (form == "iree") and (not pybamm.has_jax() or not pybamm.has_iree()): - continue - if form == "casadi": - root_method = "casadi" - else: - root_method = "lm" - model = pybamm.BaseModel() - model.convert_to_format = "jax" if form == "iree" else form - u = pybamm.Variable("u") - v = pybamm.Variable("v") - model.rhs = {u: 0.1 * v} - model.algebraic = {v: 1 - v} - model.initial_conditions = {u: 0, v: 1} - model.events = [pybamm.Event("1", 0.2 - u), pybamm.Event("2", v)] + form = "casadi" + root_method = "casadi" + model = pybamm.BaseModel() + model.convert_to_format = form + u = pybamm.Variable("u") + v = pybamm.Variable("v") + model.rhs = {u: 0.1 * v} + model.algebraic = {v: 1 - v} + model.initial_conditions = {u: 0, v: 1} + model.events = [pybamm.Event("1", 0.2 - u), pybamm.Event("2", v)] - disc = pybamm.Discretisation() - disc.process_model(model) + disc = pybamm.Discretisation() + disc.process_model(model) - solver = pybamm.IDAKLUSolver( - root_method=root_method, - options={"jax_evaluator": "iree"} if form == "iree" else {}, - ) + solver = pybamm.IDAKLUSolver( + root_method=root_method, + ) - # Test - t_eval = [0, 3] - solution = solver.solve(model, t_eval) + # Test + t_eval = [0, 3] + solution = solver.solve(model, t_eval) - # test that final time is time of event - # y = 0.1 t + y0 so y=0.2 when t=2 - np.testing.assert_array_almost_equal(solution.t[-1], 2.0) + # test that final time is time of event + # y = 0.1 t + y0 so y=0.2 when t=2 + np.testing.assert_array_almost_equal(solution.t[-1], 2.0) - # test that final value is the event value - np.testing.assert_array_almost_equal(solution.y[0, -1], 0.2) + # test that final value is the event value + np.testing.assert_array_almost_equal(solution.y[0, -1], 0.2) - # test that y[1] remains constant - np.testing.assert_array_almost_equal( - solution.y[1, :], np.ones(solution.t.shape) - ) + # test that y[1] remains constant + np.testing.assert_array_almost_equal( + solution.y[1, :], np.ones(solution.t.shape) + ) - # test that y[0] = to true solution - true_solution = 0.1 * solution.t - np.testing.assert_array_almost_equal(solution.y[0, :], true_solution) + # test that y[0] = to true solution + true_solution = 0.1 * solution.t + np.testing.assert_array_almost_equal(solution.y[0, :], true_solution) def test_multiple_inputs(self): model = pybamm.BaseModel() @@ -104,488 +98,445 @@ def test_multiple_inputs(self): ) def test_model_events(self): - for form in ["casadi", "iree"]: - if (form == "iree") and (not pybamm.has_jax() or not pybamm.has_iree()): - continue - if form == "casadi": - root_method = "casadi" - else: - root_method = "lm" - # Create model - model = pybamm.BaseModel() - model.convert_to_format = "jax" if form == "iree" else form - var = pybamm.Variable("var") - model.rhs = {var: 0.1 * var} - model.initial_conditions = {var: 1} + form = "casadi" + root_method = "casadi" + # Create model + model = pybamm.BaseModel() + model.convert_to_format = form + var = pybamm.Variable("var") + model.rhs = {var: 0.1 * var} + model.initial_conditions = {var: 1} - # create discretisation - disc = pybamm.Discretisation() - model_disc = disc.process_model(model, inplace=False) - # Solve - solver = pybamm.IDAKLUSolver( - rtol=1e-8, - atol=1e-8, - root_method=root_method, - options={"jax_evaluator": "iree"} if form == "iree" else {}, - ) + # create discretisation + disc = pybamm.Discretisation() + model_disc = disc.process_model(model, inplace=False) + # Solve + solver = pybamm.IDAKLUSolver( + rtol=1e-8, + atol=1e-8, + root_method=root_method, + ) - t_interp = np.linspace(0, 1, 100) - t_eval = [t_interp[0], t_interp[-1]] + t_interp = np.linspace(0, 1, 100) + t_eval = [t_interp[0], t_interp[-1]] - solution = solver.solve(model_disc, t_eval, t_interp=t_interp) - np.testing.assert_array_equal( - solution.t, t_interp, err_msg=f"Failed for form {form}" - ) - np.testing.assert_array_almost_equal( - solution.y[0], - np.exp(0.1 * solution.t), - decimal=5, - err_msg=f"Failed for form {form}", - ) + solution = solver.solve(model_disc, t_eval, t_interp=t_interp) + np.testing.assert_array_equal( + solution.t, t_interp, err_msg=f"Failed for form {form}" + ) + np.testing.assert_array_almost_equal( + solution.y[0], + np.exp(0.1 * solution.t), + decimal=5, + err_msg=f"Failed for form {form}", + ) - # Check invalid atol type raises an error - with pytest.raises(pybamm.SolverError): - solver._check_atol_type({"key": "value"}, []) + # Check invalid atol type raises an error + with pytest.raises(pybamm.SolverError): + solver._check_atol_type({"key": "value"}, []) - # enforce events that won't be triggered - model.events = [pybamm.Event("an event", var + 1)] - model_disc = disc.process_model(model, inplace=False) - solver = pybamm.IDAKLUSolver( - rtol=1e-8, - atol=1e-8, - root_method=root_method, - options={"jax_evaluator": "iree"} if form == "iree" else {}, - ) - solution = solver.solve(model_disc, t_eval, t_interp=t_interp) - np.testing.assert_array_equal(solution.t, t_interp) - np.testing.assert_array_almost_equal( - solution.y[0], - np.exp(0.1 * solution.t), - decimal=5, - err_msg=f"Failed for form {form}", - ) + # enforce events that won't be triggered + model.events = [pybamm.Event("an event", var + 1)] + model_disc = disc.process_model(model, inplace=False) + solver = pybamm.IDAKLUSolver( + rtol=1e-8, + atol=1e-8, + root_method=root_method, + ) + solution = solver.solve(model_disc, t_eval, t_interp=t_interp) + np.testing.assert_array_equal(solution.t, t_interp) + np.testing.assert_array_almost_equal( + solution.y[0], + np.exp(0.1 * solution.t), + decimal=5, + err_msg=f"Failed for form {form}", + ) - # enforce events that will be triggered - model.events = [pybamm.Event("an event", 1.01 - var)] - model_disc = disc.process_model(model, inplace=False) - solver = pybamm.IDAKLUSolver( - rtol=1e-8, - atol=1e-8, - root_method=root_method, - options={"jax_evaluator": "iree"} if form == "iree" else {}, - ) - solution = solver.solve(model_disc, t_eval, t_interp=t_interp) - assert len(solution.t) < len(t_interp) - np.testing.assert_array_almost_equal( - solution.y[0], - np.exp(0.1 * solution.t), - decimal=5, - err_msg=f"Failed for form {form}", - ) + # enforce events that will be triggered + model.events = [pybamm.Event("an event", 1.01 - var)] + model_disc = disc.process_model(model, inplace=False) + solver = pybamm.IDAKLUSolver( + rtol=1e-8, + atol=1e-8, + root_method=root_method, + ) + solution = solver.solve(model_disc, t_eval, t_interp=t_interp) + assert len(solution.t) < len(t_interp) + np.testing.assert_array_almost_equal( + solution.y[0], + np.exp(0.1 * solution.t), + decimal=5, + err_msg=f"Failed for form {form}", + ) - # bigger dae model with multiple events - model = pybamm.BaseModel() - whole_cell = ["negative electrode", "separator", "positive electrode"] - var1 = pybamm.Variable("var1", domain=whole_cell) - var2 = pybamm.Variable("var2", domain=whole_cell) - model.rhs = {var1: 0.1 * var1} - model.algebraic = {var2: 2 * var1 - var2} - model.initial_conditions = {var1: 1, var2: 2} - model.events = [ - pybamm.Event("var1 = 1.5", pybamm.min(1.5 - var1)), - pybamm.Event("var2 = 2.5", pybamm.min(2.5 - var2)), - ] - disc = get_discretisation_for_testing() - disc.process_model(model) + # bigger dae model with multiple events + model = pybamm.BaseModel() + whole_cell = ["negative electrode", "separator", "positive electrode"] + var1 = pybamm.Variable("var1", domain=whole_cell) + var2 = pybamm.Variable("var2", domain=whole_cell) + model.rhs = {var1: 0.1 * var1} + model.algebraic = {var2: 2 * var1 - var2} + model.initial_conditions = {var1: 1, var2: 2} + model.events = [ + pybamm.Event("var1 = 1.5", pybamm.min(1.5 - var1)), + pybamm.Event("var2 = 2.5", pybamm.min(2.5 - var2)), + ] + disc = get_discretisation_for_testing() + disc.process_model(model) - solver = pybamm.IDAKLUSolver( - rtol=1e-8, - atol=1e-8, - root_method=root_method, - options={"jax_evaluator": "iree"} if form == "iree" else {}, - ) - t_eval = np.array([0, 5]) - solution = solver.solve(model, t_eval) - np.testing.assert_array_less(solution.y[0, :-1], 1.5) - np.testing.assert_array_less(solution.y[-1, :-1], 2.5) - np.testing.assert_equal(solution.t_event[0], solution.t[-1]) - np.testing.assert_array_equal(solution.y_event[:, 0], solution.y[:, -1]) - np.testing.assert_array_almost_equal( - solution.y[0], - np.exp(0.1 * solution.t), - decimal=5, - err_msg=f"Failed for form {form}", - ) - np.testing.assert_array_almost_equal( - solution.y[-1], - 2 * np.exp(0.1 * solution.t), - decimal=5, - err_msg=f"Failed for form {form}", - ) + solver = pybamm.IDAKLUSolver( + rtol=1e-8, + atol=1e-8, + root_method=root_method, + ) + t_eval = np.array([0, 5]) + solution = solver.solve(model, t_eval) + np.testing.assert_array_less(solution.y[0, :-1], 1.5) + np.testing.assert_array_less(solution.y[-1, :-1], 2.5) + np.testing.assert_equal(solution.t_event[0], solution.t[-1]) + np.testing.assert_array_equal(solution.y_event[:, 0], solution.y[:, -1]) + np.testing.assert_array_almost_equal( + solution.y[0], + np.exp(0.1 * solution.t), + decimal=5, + err_msg=f"Failed for form {form}", + ) + np.testing.assert_array_almost_equal( + solution.y[-1], + 2 * np.exp(0.1 * solution.t), + decimal=5, + err_msg=f"Failed for form {form}", + ) def test_input_params(self): # test a mix of scalar and vector input params - for form in ["casadi", "iree"]: - if (form == "iree") and (not pybamm.has_jax() or not pybamm.has_iree()): - continue - if form == "casadi": - root_method = "casadi" - else: - root_method = "lm" - model = pybamm.BaseModel() - model.convert_to_format = "jax" if form == "iree" else form - u1 = pybamm.Variable("u1") - u2 = pybamm.Variable("u2") - u3 = pybamm.Variable("u3") - v = pybamm.Variable("v") - a = pybamm.InputParameter("a") - b = pybamm.InputParameter("b", expected_size=2) - model.rhs = {u1: a * v, u2: pybamm.Index(b, 0), u3: pybamm.Index(b, 1)} - model.algebraic = {v: 1 - v} - model.initial_conditions = {u1: 0, u2: 0, u3: 0, v: 1} + form = "casadi" + root_method = "casadi" + model = pybamm.BaseModel() + model.convert_to_format = form + u1 = pybamm.Variable("u1") + u2 = pybamm.Variable("u2") + u3 = pybamm.Variable("u3") + v = pybamm.Variable("v") + a = pybamm.InputParameter("a") + b = pybamm.InputParameter("b", expected_size=2) + model.rhs = {u1: a * v, u2: pybamm.Index(b, 0), u3: pybamm.Index(b, 1)} + model.algebraic = {v: 1 - v} + model.initial_conditions = {u1: 0, u2: 0, u3: 0, v: 1} - disc = pybamm.Discretisation() - disc.process_model(model) + disc = pybamm.Discretisation() + disc.process_model(model) - solver = pybamm.IDAKLUSolver( - root_method=root_method, - options={"jax_evaluator": "iree"} if form == "iree" else {}, - ) + solver = pybamm.IDAKLUSolver( + root_method=root_method, + ) - t_interp = np.linspace(0, 3, 100) - t_eval = [t_interp[0], t_interp[-1]] - a_value = 0.1 - b_value = np.array([[0.2], [0.3]]) + t_interp = np.linspace(0, 3, 100) + t_eval = [t_interp[0], t_interp[-1]] + a_value = 0.1 + b_value = np.array([[0.2], [0.3]]) - sol = solver.solve( - model, - t_eval, - inputs={"a": a_value, "b": b_value}, - t_interp=t_interp, - ) + sol = solver.solve( + model, + t_eval, + inputs={"a": a_value, "b": b_value}, + t_interp=t_interp, + ) - # test that y[3] remains constant - np.testing.assert_array_almost_equal( - sol.y[3], np.ones(sol.t.shape), err_msg=f"Failed for form {form}" - ) + # test that y[3] remains constant + np.testing.assert_array_almost_equal( + sol.y[3], np.ones(sol.t.shape), err_msg=f"Failed for form {form}" + ) - # test that y[0] = to true solution - true_solution = a_value * sol.t - np.testing.assert_array_almost_equal( - sol.y[0], true_solution, err_msg=f"Failed for form {form}" - ) + # test that y[0] = to true solution + true_solution = a_value * sol.t + np.testing.assert_array_almost_equal( + sol.y[0], true_solution, err_msg=f"Failed for form {form}" + ) - # test that y[1:3] = to true solution - true_solution = b_value * sol.t - np.testing.assert_array_almost_equal( - sol.y[1:3], true_solution, err_msg=f"Failed for form {form}" - ) + # test that y[1:3] = to true solution + true_solution = b_value * sol.t + np.testing.assert_array_almost_equal( + sol.y[1:3], true_solution, err_msg=f"Failed for form {form}" + ) def test_sensitivities_initial_condition(self): - for form in ["casadi", "iree"]: - for output_variables in [[], ["2v"]]: - if (form == "iree") and (not pybamm.has_jax() or not pybamm.has_iree()): - continue - if form == "casadi": - root_method = "casadi" - else: - root_method = "lm" - model = pybamm.BaseModel() - model.convert_to_format = "jax" if form == "iree" else form - u = pybamm.Variable("u") - v = pybamm.Variable("v") - a = pybamm.InputParameter("a") - model.rhs = {u: -u} - model.algebraic = {v: a * u - v} - model.initial_conditions = {u: 1, v: 1} - model.variables = {"2v": 2 * v} - - disc = pybamm.Discretisation() - disc.process_model(model) - solver = pybamm.IDAKLUSolver( - rtol=1e-6, - atol=1e-6, - root_method=root_method, - output_variables=output_variables, - options={"jax_evaluator": "iree"} if form == "iree" else {}, - ) - - t_eval = [0, 3] - - a_value = 0.1 - - sol = solver.solve( - model, - t_eval, - inputs={"a": a_value}, - calculate_sensitivities=True, - ) - - np.testing.assert_array_almost_equal( - sol["2v"].sensitivities["a"].full().flatten(), - np.exp(-sol.t) * 2, - decimal=4, - err_msg=f"Failed for form {form}", - ) - - def test_ida_roberts_klu_sensitivities(self): - # this test implements a python version of the ida Roberts - # example provided in sundials - # see sundials ida examples pdf - for form in ["casadi", "iree"]: - if (form == "iree") and (not pybamm.has_jax() or not pybamm.has_iree()): - continue - if form == "casadi": - root_method = "casadi" - else: - root_method = "lm" + form = "casadi" + root_method = "casadi" + for output_variables in [[], ["2v"]]: model = pybamm.BaseModel() - model.convert_to_format = "jax" if form == "iree" else form + model.convert_to_format = form u = pybamm.Variable("u") v = pybamm.Variable("v") a = pybamm.InputParameter("a") - model.rhs = {u: a * v} - model.algebraic = {v: 1 - v} - model.initial_conditions = {u: 0, v: 1} - model.variables = {"2u": 2 * u} + model.rhs = {u: -u} + model.algebraic = {v: a * u - v} + model.initial_conditions = {u: 1, v: 1} + model.variables = {"2v": 2 * v} disc = pybamm.Discretisation() disc.process_model(model) - solver = pybamm.IDAKLUSolver( + rtol=1e-6, + atol=1e-6, root_method=root_method, - options={"jax_evaluator": "iree"} if form == "iree" else {}, + output_variables=output_variables, ) - t_interp = np.linspace(0, 3, 100) - t_eval = [t_interp[0], t_interp[-1]] + t_eval = [0, 3] + a_value = 0.1 - # solve first without sensitivities sol = solver.solve( model, t_eval, inputs={"a": a_value}, - t_interp=t_interp, + calculate_sensitivities=True, ) - # test that y[1] remains constant np.testing.assert_array_almost_equal( - sol.y[1, :], np.ones(sol.t.shape), err_msg=f"Failed for form {form}" + sol["2v"].sensitivities["a"].full().flatten(), + np.exp(-sol.t) * 2, + decimal=4, + err_msg=f"Failed for form {form}", ) - # test that y[0] = to true solution - true_solution = a_value * sol.t - np.testing.assert_array_almost_equal( - sol.y[0, :], true_solution, err_msg=f"Failed for form {form}" - ) + def test_ida_roberts_klu_sensitivities(self): + # this test implements a python version of the ida Roberts + # example provided in sundials + # see sundials ida examples pdf + form = "casadi" + root_method = "casadi" + model = pybamm.BaseModel() + model.convert_to_format = form + u = pybamm.Variable("u") + v = pybamm.Variable("v") + a = pybamm.InputParameter("a") + model.rhs = {u: a * v} + model.algebraic = {v: 1 - v} + model.initial_conditions = {u: 0, v: 1} + model.variables = {"2u": 2 * u} - # should be no sensitivities calculated - with pytest.raises(KeyError): - print(sol.sensitivities["a"]) + disc = pybamm.Discretisation() + disc.process_model(model) - # now solve with sensitivities (this should cause set_up to be run again) - sol = solver.solve( - model, - t_eval, - inputs={"a": a_value}, - calculate_sensitivities=True, - t_interp=t_interp, - ) + solver = pybamm.IDAKLUSolver( + root_method=root_method, + ) - # test that y[1] remains constant - np.testing.assert_array_almost_equal( - sol.y[1, :], np.ones(sol.t.shape), err_msg=f"Failed for form {form}" - ) + t_interp = np.linspace(0, 3, 100) + t_eval = [t_interp[0], t_interp[-1]] + a_value = 0.1 - # test that y[0] = to true solution - true_solution = a_value * sol.t - np.testing.assert_array_almost_equal( - sol.y[0, :], true_solution, err_msg=f"Failed for form {form}" - ) + # solve first without sensitivities + sol = solver.solve( + model, + t_eval, + inputs={"a": a_value}, + t_interp=t_interp, + ) - # evaluate the sensitivities using idas - dyda_ida = sol.sensitivities["a"] + # test that y[1] remains constant + np.testing.assert_array_almost_equal( + sol.y[1, :], np.ones(sol.t.shape), err_msg=f"Failed for form {form}" + ) - # evaluate the sensitivities using finite difference - h = 1e-6 - sol_plus = solver.solve( - model, t_eval, inputs={"a": a_value + 0.5 * h}, t_interp=t_interp - ) - sol_neg = solver.solve( - model, t_eval, inputs={"a": a_value - 0.5 * h}, t_interp=t_interp - ) - dyda_fd = (sol_plus.y - sol_neg.y) / h - dyda_fd = dyda_fd.transpose().reshape(-1, 1) + # test that y[0] = to true solution + true_solution = a_value * sol.t + np.testing.assert_array_almost_equal( + sol.y[0, :], true_solution, err_msg=f"Failed for form {form}" + ) - decimal = ( - 2 if form == "iree" else 6 - ) # iree currently operates with single precision - np.testing.assert_array_almost_equal( - dyda_ida, dyda_fd, decimal=decimal, err_msg=f"Failed for form {form}" - ) + # should be no sensitivities calculated + with pytest.raises(KeyError): + print(sol.sensitivities["a"]) - # get the sensitivities for the variable - d2uda = sol["2u"].sensitivities["a"] - np.testing.assert_array_almost_equal( - 2 * dyda_ida[0:200:2], - d2uda, - decimal=decimal, - err_msg=f"Failed for form {form}", - ) + # now solve with sensitivities (this should cause set_up to be run again) + sol = solver.solve( + model, + t_eval, + inputs={"a": a_value}, + calculate_sensitivities=True, + t_interp=t_interp, + ) + + # test that y[1] remains constant + np.testing.assert_array_almost_equal( + sol.y[1, :], np.ones(sol.t.shape), err_msg=f"Failed for form {form}" + ) + + # test that y[0] = to true solution + true_solution = a_value * sol.t + np.testing.assert_array_almost_equal( + sol.y[0, :], true_solution, err_msg=f"Failed for form {form}" + ) + + # evaluate the sensitivities using idas + dyda_ida = sol.sensitivities["a"] + + # evaluate the sensitivities using finite difference + h = 1e-6 + sol_plus = solver.solve( + model, t_eval, inputs={"a": a_value + 0.5 * h}, t_interp=t_interp + ) + sol_neg = solver.solve( + model, t_eval, inputs={"a": a_value - 0.5 * h}, t_interp=t_interp + ) + dyda_fd = (sol_plus.y - sol_neg.y) / h + dyda_fd = dyda_fd.transpose().reshape(-1, 1) + + decimal = 6 + np.testing.assert_array_almost_equal( + dyda_ida, dyda_fd, decimal=decimal, err_msg=f"Failed for form {form}" + ) + + # get the sensitivities for the variable + d2uda = sol["2u"].sensitivities["a"] + np.testing.assert_array_almost_equal( + 2 * dyda_ida[0:200:2], + d2uda, + decimal=decimal, + err_msg=f"Failed for form {form}", + ) def test_ida_roberts_consistent_initialization(self): # this test implements a python version of the ida Roberts # example provided in sundials # see sundials ida examples pdf - for form in ["casadi", "iree"]: - if (form == "iree") and (not pybamm.has_jax() or not pybamm.has_iree()): - continue - if form == "casadi": - root_method = "casadi" - else: - root_method = "lm" - model = pybamm.BaseModel() - model.convert_to_format = "jax" if form == "iree" else form - u = pybamm.Variable("u") - v = pybamm.Variable("v") - model.rhs = {u: 0.1 * v} - model.algebraic = {v: 1 - v} - model.initial_conditions = {u: 0, v: 2} + form = "casadi" + root_method = "casadi" + model = pybamm.BaseModel() + model.convert_to_format = form + u = pybamm.Variable("u") + v = pybamm.Variable("v") + model.rhs = {u: 0.1 * v} + model.algebraic = {v: 1 - v} + model.initial_conditions = {u: 0, v: 2} - disc = pybamm.Discretisation() - disc.process_model(model) + disc = pybamm.Discretisation() + disc.process_model(model) - solver = pybamm.IDAKLUSolver( - root_method=root_method, - options={"jax_evaluator": "iree"} if form == "iree" else {}, - ) + solver = pybamm.IDAKLUSolver( + root_method=root_method, + ) - # Set up and model consistently initializate the model - solver.set_up(model) - t0 = 0.0 - solver._set_consistent_initialization(model, t0, inputs_dict={}) + # Set up and model consistently initializate the model + solver.set_up(model) + t0 = 0.0 + solver._set_consistent_initialization(model, t0, inputs_dict={}) - # u(t0) = 0, v(t0) = 1 - np.testing.assert_array_almost_equal( - model.y0full, [0, 1], err_msg=f"Failed for form {form}" - ) - # u'(t0) = 0.1 * v(t0) = 0.1 - # Since v is algebraic, the initial derivative is set to 0 - np.testing.assert_array_almost_equal( - model.ydot0full, [0.1, 0], err_msg=f"Failed for form {form}" - ) + # u(t0) = 0, v(t0) = 1 + np.testing.assert_array_almost_equal( + model.y0full, [0, 1], err_msg=f"Failed for form {form}" + ) + # u'(t0) = 0.1 * v(t0) = 0.1 + # Since v is algebraic, the initial derivative is set to 0 + np.testing.assert_array_almost_equal( + model.ydot0full, [0.1, 0], err_msg=f"Failed for form {form}" + ) def test_sensitivities_with_events(self): # this test implements a python version of the ida Roberts # example provided in sundials # see sundials ida examples pdf - for form in ["casadi", "iree"]: - if (form == "iree") and (not pybamm.has_jax() or not pybamm.has_iree()): - continue - if form == "casadi": - root_method = "casadi" - else: - root_method = "lm" - model = pybamm.BaseModel() - model.convert_to_format = "jax" if form == "iree" else form - u = pybamm.Variable("u") - v = pybamm.Variable("v") - a = pybamm.InputParameter("a") - b = pybamm.InputParameter("b") - model.rhs = {u: a * v + b} - model.algebraic = {v: 1 - v} - model.initial_conditions = {u: 0, v: 1} - model.events = [pybamm.Event("1", 0.2 - u)] - - disc = pybamm.Discretisation() - disc.process_model(model) - - solver = pybamm.IDAKLUSolver( - root_method=root_method, - options={"jax_evaluator": "iree"} if form == "iree" else {}, - ) + form = "casadi" + root_method = "casadi" + model = pybamm.BaseModel() + model.convert_to_format = form + u = pybamm.Variable("u") + v = pybamm.Variable("v") + a = pybamm.InputParameter("a") + b = pybamm.InputParameter("b") + model.rhs = {u: a * v + b} + model.algebraic = {v: 1 - v} + model.initial_conditions = {u: 0, v: 1} + model.events = [pybamm.Event("1", 0.2 - u)] - t_interp = np.linspace(0, 3, 100) - t_eval = [t_interp[0], t_interp[-1]] + disc = pybamm.Discretisation() + disc.process_model(model) - a_value = 0.1 - b_value = 0.0 + solver = pybamm.IDAKLUSolver( + root_method=root_method, + ) - # solve first without sensitivities - sol = solver.solve( - model, - t_eval, - inputs={"a": a_value, "b": b_value}, - calculate_sensitivities=True, - t_interp=t_interp, - ) + t_interp = np.linspace(0, 3, 100) + t_eval = [t_interp[0], t_interp[-1]] - # test that y[1] remains constant - np.testing.assert_array_almost_equal( - sol.y[1, :], np.ones(sol.t.shape), err_msg=f"Failed for form {form}" - ) + a_value = 0.1 + b_value = 0.0 - # test that y[0] = to true solution - true_solution = a_value * sol.t - np.testing.assert_array_almost_equal( - sol.y[0, :], true_solution, err_msg=f"Failed for form {form}" - ) + # solve first without sensitivities + sol = solver.solve( + model, + t_eval, + inputs={"a": a_value, "b": b_value}, + calculate_sensitivities=True, + t_interp=t_interp, + ) - # evaluate the sensitivities using idas - dyda_ida = sol.sensitivities["a"] - dydb_ida = sol.sensitivities["b"] + # test that y[1] remains constant + np.testing.assert_array_almost_equal( + sol.y[1, :], np.ones(sol.t.shape), err_msg=f"Failed for form {form}" + ) - # evaluate the sensitivities using finite difference - h = 1e-6 - sol_plus = solver.solve( - model, - t_eval, - inputs={"a": a_value + 0.5 * h, "b": b_value}, - t_interp=t_interp, - ) - sol_neg = solver.solve( - model, - t_eval, - inputs={"a": a_value - 0.5 * h, "b": b_value}, - t_interp=t_interp, - ) - max_index = min(sol_plus.y.shape[1], sol_neg.y.shape[1]) - 1 - dyda_fd = (sol_plus.y[:, :max_index] - sol_neg.y[:, :max_index]) / h - dyda_fd = dyda_fd.transpose().reshape(-1, 1) + # test that y[0] = to true solution + true_solution = a_value * sol.t + np.testing.assert_array_almost_equal( + sol.y[0, :], true_solution, err_msg=f"Failed for form {form}" + ) - decimal = ( - 2 if form == "iree" else 6 - ) # iree currently operates with single precision - np.testing.assert_array_almost_equal( - dyda_ida[: (2 * max_index), :], - dyda_fd, - decimal=decimal, - err_msg=f"Failed for form {form}", - ) + # evaluate the sensitivities using idas + dyda_ida = sol.sensitivities["a"] + dydb_ida = sol.sensitivities["b"] - sol_plus = solver.solve( - model, - t_eval, - inputs={"a": a_value, "b": b_value + 0.5 * h}, - t_interp=t_interp, - ) - sol_neg = solver.solve( - model, - t_eval, - inputs={"a": a_value, "b": b_value - 0.5 * h}, - t_interp=t_interp, - ) - max_index = min(sol_plus.y.shape[1], sol_neg.y.shape[1]) - 1 - dydb_fd = (sol_plus.y[:, :max_index] - sol_neg.y[:, :max_index]) / h - dydb_fd = dydb_fd.transpose().reshape(-1, 1) + # evaluate the sensitivities using finite difference + h = 1e-6 + sol_plus = solver.solve( + model, + t_eval, + inputs={"a": a_value + 0.5 * h, "b": b_value}, + t_interp=t_interp, + ) + sol_neg = solver.solve( + model, + t_eval, + inputs={"a": a_value - 0.5 * h, "b": b_value}, + t_interp=t_interp, + ) + max_index = min(sol_plus.y.shape[1], sol_neg.y.shape[1]) - 1 + dyda_fd = (sol_plus.y[:, :max_index] - sol_neg.y[:, :max_index]) / h + dyda_fd = dyda_fd.transpose().reshape(-1, 1) + + decimal = 6 + np.testing.assert_array_almost_equal( + dyda_ida[: (2 * max_index), :], + dyda_fd, + decimal=decimal, + err_msg=f"Failed for form {form}", + ) - np.testing.assert_array_almost_equal( - dydb_ida[: (2 * max_index), :], - dydb_fd, - decimal=decimal, - err_msg=f"Failed for form {form}", - ) + sol_plus = solver.solve( + model, + t_eval, + inputs={"a": a_value, "b": b_value + 0.5 * h}, + t_interp=t_interp, + ) + sol_neg = solver.solve( + model, + t_eval, + inputs={"a": a_value, "b": b_value - 0.5 * h}, + t_interp=t_interp, + ) + max_index = min(sol_plus.y.shape[1], sol_neg.y.shape[1]) - 1 + dydb_fd = (sol_plus.y[:, :max_index] - sol_neg.y[:, :max_index]) / h + dydb_fd = dydb_fd.transpose().reshape(-1, 1) + + np.testing.assert_array_almost_equal( + dydb_ida[: (2 * max_index), :], + dydb_fd, + decimal=decimal, + err_msg=f"Failed for form {form}", + ) def test_failures(self): # this test implements a python version of the ida Roberts @@ -639,34 +590,28 @@ def test_failures(self): solver.solve(model, t_eval) def test_dae_solver_algebraic_model(self): - for form in ["casadi", "iree"]: - if (form == "iree") and (not pybamm.has_jax() or not pybamm.has_iree()): - continue - if form == "casadi": - root_method = "casadi" - else: - root_method = "lm" - model = pybamm.BaseModel() - model.convert_to_format = "jax" if form == "iree" else form - var = pybamm.Variable("var") - model.algebraic = {var: var + 1} - model.initial_conditions = {var: 0} + form = "casadi" + root_method = "casadi" + model = pybamm.BaseModel() + model.convert_to_format = form + var = pybamm.Variable("var") + model.algebraic = {var: var + 1} + model.initial_conditions = {var: 0} - disc = pybamm.Discretisation() - disc.process_model(model) + disc = pybamm.Discretisation() + disc.process_model(model) - solver = pybamm.IDAKLUSolver( - root_method=root_method, - options={"jax_evaluator": "iree"} if form == "iree" else {}, - ) - t_eval = [0, 1] - solution = solver.solve(model, t_eval) - np.testing.assert_array_equal(solution.y, -1) + solver = pybamm.IDAKLUSolver( + root_method=root_method, + ) + t_eval = [0, 1] + solution = solver.solve(model, t_eval) + np.testing.assert_array_equal(solution.y, -1) - # change initial_conditions and re-solve (to test if ics_only works) - model.concatenated_initial_conditions = pybamm.Vector(np.array([[1]])) - solution = solver.solve(model, t_eval) - np.testing.assert_array_equal(solution.y, -1) + # change initial_conditions and re-solve (to test if ics_only works) + model.concatenated_initial_conditions = pybamm.Vector(np.array([[1]])) + solution = solver.solve(model, t_eval) + np.testing.assert_array_equal(solution.y, -1) def test_banded(self): model = pybamm.lithium_ion.SPM() @@ -950,113 +895,90 @@ def test_with_output_variables_and_sensitivities(self): # Construct a model and solve for all variables, then test # the 'output_variables' option for each variable in turn, confirming # equivalence + form = "casadi" + root_method = "casadi" + input_parameters = { # Sensitivities dictionary + "Current function [A]": 0.222, + "Separator porosity": 0.3, + } - for form in ["casadi", "iree"]: - if (form == "iree") and (not pybamm.has_jax() or not pybamm.has_iree()): - continue - if form == "casadi": - root_method = "casadi" - else: - root_method = "lm" - input_parameters = { # Sensitivities dictionary - "Current function [A]": 0.222, - "Separator porosity": 0.3, - } - - # construct model - model = pybamm.lithium_ion.DFN() - model.convert_to_format = "jax" if form == "iree" else form - geometry = model.default_geometry - param = model.default_parameter_values - param.update({key: "[input]" for key in input_parameters}) - param.process_model(model) - param.process_geometry(geometry) - var_pts = {"x_n": 50, "x_s": 50, "x_p": 50, "r_n": 5, "r_p": 5} - mesh = pybamm.Mesh(geometry, model.default_submesh_types, var_pts) - disc = pybamm.Discretisation(mesh, model.default_spatial_methods) - disc.process_model(model) - - t_interp = np.linspace(0, 100, 5) - t_eval = [t_interp[0], t_interp[-1]] + # construct model + model = pybamm.lithium_ion.DFN() + model.convert_to_format = form + geometry = model.default_geometry + param = model.default_parameter_values + param.update({key: "[input]" for key in input_parameters}) + param.process_model(model) + param.process_geometry(geometry) + var_pts = {"x_n": 50, "x_s": 50, "x_p": 50, "r_n": 5, "r_p": 5} + mesh = pybamm.Mesh(geometry, model.default_submesh_types, var_pts) + disc = pybamm.Discretisation(mesh, model.default_spatial_methods) + disc.process_model(model) - options = { - "linear_solver": "SUNLinSol_KLU", - "jacobian": "sparse", - "num_threads": 4, - "max_num_steps": 1000, - } - if form == "iree": - options["jax_evaluator"] = "iree" - - # Use a selection of variables of different types - output_variables = [ - "Voltage [V]", - "Time [min]", - "x [m]", - "Negative particle flux [mol.m-2.s-1]", - "Throughput capacity [A.h]", # ExplicitTimeIntegral - ] - - # Use the full model as comparison (tested separately) - solver_all = pybamm.IDAKLUSolver( - root_method=root_method, - atol=1e-8 if form != "iree" else 1e-0, # iree has reduced precision - rtol=1e-8 if form != "iree" else 1e-0, # iree has reduced precision - options=options, - ) - sol_all = solver_all.solve( - model, - t_eval, - inputs=input_parameters, - calculate_sensitivities=True, - t_interp=t_interp, - ) + t_interp = np.linspace(0, 100, 5) + t_eval = [t_interp[0], t_interp[-1]] - # Solve for a subset of variables and compare results - solver = pybamm.IDAKLUSolver( - root_method=root_method, - atol=1e-8 if form != "iree" else 1e-0, # iree has reduced precision - rtol=1e-8 if form != "iree" else 1e-0, # iree has reduced precision - options=options, - output_variables=output_variables, - ) - sol = solver.solve( - model, - t_eval, - inputs=input_parameters, - calculate_sensitivities=True, - t_interp=t_interp, - ) + options = { + "linear_solver": "SUNLinSol_KLU", + "jacobian": "sparse", + "num_threads": 4, + "max_num_steps": 1000, + } - # Compare output to sol_all - tol = 1e-5 if form != "iree" else 1e-2 # iree has reduced precision - for varname in output_variables: - np.testing.assert_array_almost_equal( - sol[varname](t_interp), - sol_all[varname](t_interp), - tol, - err_msg=f"Failed for {varname} with form {form}", - ) + # Use a selection of variables of different types + output_variables = [ + "Voltage [V]", + "Time [min]", + "x [m]", + "Negative particle flux [mol.m-2.s-1]", + "Throughput capacity [A.h]", # ExplicitTimeIntegral + ] - # Mock a 1D current collector and initialise (none in the model) - sol["x_s [m]"].domain = ["current collector"] - sol["x_s [m]"].entries + # Use the full model as comparison (tested separately) + solver_all = pybamm.IDAKLUSolver( + root_method=root_method, + atol=1e-8, + rtol=1e-8, + options=options, + ) + sol_all = solver_all.solve( + model, + t_eval, + inputs=input_parameters, + calculate_sensitivities=True, + t_interp=t_interp, + ) - def test_bad_jax_evaluator(self): - model = pybamm.lithium_ion.DFN() - model.convert_to_format = "jax" - with pytest.raises(pybamm.SolverError): - pybamm.IDAKLUSolver(options={"jax_evaluator": "bad_evaluator"}) + # Solve for a subset of variables and compare results + solver = pybamm.IDAKLUSolver( + root_method=root_method, + atol=1e-8, + rtol=1e-8, + options=options, + output_variables=output_variables, + ) + sol = solver.solve( + model, + t_eval, + inputs=input_parameters, + calculate_sensitivities=True, + t_interp=t_interp, + ) - def test_bad_jax_evaluator_output_variables(self): - model = pybamm.lithium_ion.DFN() - model.convert_to_format = "jax" - with pytest.raises(pybamm.SolverError): - pybamm.IDAKLUSolver( - options={"jax_evaluator": "bad_evaluator"}, - output_variables=["Terminal voltage [V]"], + # Compare output to sol_all + tol = 1e-5 + for varname in output_variables: + np.testing.assert_array_almost_equal( + sol[varname](t_interp), + sol_all[varname](t_interp), + tol, + err_msg=f"Failed for {varname} with form {form}", ) + # Mock a 1D current collector and initialise (none in the model) + sol["x_s [m]"].domain = ["current collector"] + sol["x_s [m]"].entries + def test_with_output_variables_and_event_termination(self): model = pybamm.lithium_ion.DFN() parameter_values = pybamm.ParameterValues("Chen2020") @@ -1145,7 +1067,7 @@ def experiment_setup(period=None): ) def test_python_idaklu_deprecation_errors(self): - for form in ["python", "", "jax"]: + for form in ["python", "jax"]: if form == "jax" and not pybamm.has_jax(): continue @@ -1174,13 +1096,13 @@ def test_python_idaklu_deprecation_errors(self): ): with pytest.raises( DeprecationWarning, - match="The python-idaklu solver has been deprecated.", + match="The python-idaklu and IREE solvers have been deprecated.", ): _ = solver.solve(model, t_eval) elif form == "jax": with pytest.raises( pybamm.SolverError, - match="Unsupported evaluation engine for convert_to_format=jax", + match="Unsupported option for convert_to_format=jax", ): _ = solver.solve(model, t_eval)