Skip to content

Commit

Permalink
add examples for fitting with Adam optimizer
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed Sep 6, 2024
1 parent f4bdf83 commit d2fdb1a
Show file tree
Hide file tree
Showing 7 changed files with 275 additions and 422 deletions.
21 changes: 13 additions & 8 deletions dendritex/_integrators.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from __future__ import annotations

from typing import Optional, Tuple, Callable, Dict
from typing import Optional, Tuple, Callable, Dict, Union

import brainstate as bst
import brainunit as u
Expand Down Expand Up @@ -72,12 +72,17 @@ def _diffrax_solve(
atol: Optional[float] = None,
max_steps: int = None,
):
if adjoint == 'adjoint':
adjoint = dfx.BacksolveAdjoint()
elif adjoint == 'checkpoint':
adjoint = dfx.RecursiveCheckpointAdjoint()
elif adjoint == 'direct':
adjoint = dfx.DirectAdjoint()
if isinstance(adjoint, str):
if adjoint == 'adjoint':
adjoint = dfx.BacksolveAdjoint()
elif adjoint == 'checkpoint':
adjoint = dfx.RecursiveCheckpointAdjoint()
elif adjoint == 'direct':
adjoint = dfx.DirectAdjoint()
else:
raise ValueError(f"Unknown adjoint method: {adjoint}. Only support 'checkpoint', 'direct', and 'adjoint'.")
elif isinstance(adjoint, dfx.AbstractAdjoint):
adjoint = adjoint
else:
raise ValueError(f"Unknown adjoint method: {adjoint}. Only support 'checkpoint', 'direct', and 'adjoint'.")

Expand Down Expand Up @@ -280,7 +285,7 @@ def diffrax_solve(
rtol: Optional[float] = None,
atol: Optional[float] = None,
max_steps: Optional[int] = None,
adjoint: str = 'checkpoint',
adjoint: Union[str, dfx.AbstractAdjoint] = 'checkpoint',
) -> Tuple[u.Quantity, bst.typing.PyTree[u.Quantity], Dict]:
"""
Solve the differential equations using `diffrax <https://docs.kidger.site/diffrax>`_.
Expand Down
2 changes: 1 addition & 1 deletion dendritex/channels/potassium_calcium.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def p_vdep(self, V):
return bu.math.exp((V + 70.) / 27.)

def p_concdep(self, Ca):
# concdep_1 = 500 * (0.015 - Ca.C / bu.mM) / (bu.math.exp((0.015 - Ca.C / bu.mM) / 0.0013) - 1)
# concdep_1 = 500 * (0.015 - Ca.C / u.mM) / (u.math.exp((0.015 - Ca.C / u.mM) / 0.0013) - 1)
concdep_1 = 500 * 0.0013 / bu.math.exprel((0.015 - Ca.C / bu.mM) / 0.0013)
with jax.ensure_compile_time_eval():
concdep_2 = 500 * 0.005 / (bu.math.exp(0.005 / 0.0013) - 1)
Expand Down
214 changes: 0 additions & 214 deletions examples/fitting_simple_dendrite_model.py

This file was deleted.

Loading

0 comments on commit d2fdb1a

Please sign in to comment.