diff --git a/dendritex/_base.py b/dendritex/_base.py index 5f9a5cd..7a2802a 100644 --- a/dendritex/_base.py +++ b/dendritex/_base.py @@ -27,575 +27,559 @@ from ._misc import set_module_as __all__ = [ - 'DendriticDynamics', - 'State4Integral', - 'HHTypedNeuron', - 'IonChannel', - 'Ion', - 'MixIons', - 'Channel', - - 'IonInfo', + 'State4Integral', + 'HHTypedNeuron', + 'IonChannel', + 'Ion', + 'MixIons', + 'Channel', + 'IonInfo', ] - -# -# - DendriticDynamics -# - HHTypedNeuron -# - SingleCompartment -# - IonChannel -# - Ion -# - Calcium -# - Potassium -# - Sodium -# - MixIons -# - Channel -# +''' +- HHTypedNeuron + - SingleCompartment +- IonChannel + - Ion + - Calcium + - Potassium + - Sodium + - MixIons + - Channel +''' class State4Integral(bst.ShortTermState): - """ - A state that integrates the state of the system to the integral of the state. - - Attributes - ---------- - derivative: The derivative of the state. - - """ - - __module__ = 'dendritex' - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.derivative = None - - -class DendriticDynamics(bst.Dynamics): - """ - Base class for dendritic dynamics. - - Attributes: - size: The size of the simulation target. - pop_size: The size of the population, storing the number of neurons in each population. - n_compartment: The number of compartments in each neuron. - varshape: The shape of the state variables. - """ - __module__ = 'dendritex' - - def __init__( - self, - size: bst.typing.Size, - name: Optional[str] = None, - mode: Optional[bst.mixin.Mode] = None, - ): - # size - if isinstance(size, (list, tuple)): - if len(size) <= 0: - raise ValueError(f'size must be int, or a tuple/list of int. ' - f'But we got {type(size)}') - if not isinstance(size[0], (int, np.integer)): - raise ValueError('size must be int, or a tuple/list of int.' - f'But we got {type(size)}') - size = tuple(size) - elif isinstance(size, (int, np.integer)): - size = (size,) - else: - raise ValueError('size must be int, or a tuple/list of int.' - f'But we got {type(size)}') - self.size = size - assert len(size) >= 1, 'The size of the dendritic dynamics should be at least 1D: (..., n_neuron, n_compartment).' - self.pop_size: Tuple[int, ...] = size[:-1] - self.n_compartment: int = size[-1] - - # -- Attribute for "InputProjMixIn" -- # - # each instance of "SupportInputProj" should have - # "_current_inputs" and "_delta_inputs" attributes - self._current_inputs: Optional[Dict[str, Callable]] = None - self._delta_inputs: Optional[Dict[str, Callable]] = None - - # initialize - super().__init__(size, name=name, mode=mode, keep_size=True) - - def current(self, *args, **kwargs): - raise NotImplementedError('Must be implemented by the subclass.') - - def before_integral(self, *args, **kwargs): - raise NotImplementedError - - def compute_derivative(self, *args, **kwargs): - raise NotImplementedError('Must be implemented by the subclass.') - - def after_integral(self, *args, **kwargs): - raise NotImplementedError - - def init_state(self, *args, **kwargs): - raise NotImplementedError - - def reset_state(self, *args, **kwargs): - raise NotImplementedError + """ + A state that integrates the state of the system to the integral of the state. + Attributes + ---------- + derivative: The derivative of the state. -class Container(bst.mixin.Mixin): - __module__ = 'dendritex' - - _container_name: str - - @staticmethod - def _get_elem_name(elem): - if isinstance(elem, bst.Module): - return elem.name - else: - return bst.util.get_unique_name('ContainerElem') - - @staticmethod - def _format_elements(child_type: type, *children_as_tuple, **children_as_dict): - res = dict() - - # add tuple-typed components - for module in children_as_tuple: - if isinstance(module, child_type): - res[Container._get_elem_name(module)] = module - elif isinstance(module, (list, tuple)): - for m in module: - if not isinstance(m, child_type): - raise TypeError(f'Should be instance of {child_type.__name__}. ' - f'But we got {type(m)}') - res[Container._get_elem_name(m)] = m - elif isinstance(module, dict): - for k, v in module.items(): - if not isinstance(v, child_type): - raise TypeError(f'Should be instance of {child_type.__name__}. ' - f'But we got {type(v)}') - res[k] = v - else: - raise TypeError(f'Cannot parse sub-systems. They should be {child_type.__name__} ' - f'or a list/tuple/dict of {child_type.__name__}.') - # add dict-typed components - for k, v in children_as_dict.items(): - if not isinstance(v, child_type): - raise TypeError(f'Should be instance of {child_type.__name__}. ' - f'But we got {type(v)}') - res[k] = v - return res - - def __getitem__(self, item): - """Overwrite the slice access (`self['']`). """ - children = self.__getattr__(self._container_name) - if item in children: - return children[item] - else: - raise ValueError(f'Unknown item {item}, we only found {list(children.keys())}') - - def __getattr__(self, item): - """Overwrite the dot access (`self.`). """ - name = super().__getattribute__('_container_name') - if item == '_container_name': - return name - children = super().__getattribute__(name) - if item == name: - return children - if item in children: - return children[item] - else: - return super().__getattribute__(item) - - def add_elem(self, *elems, **elements): """ - Add new elements. - Args: - elements: children objects. - """ - raise NotImplementedError('Must be implemented by the subclass.') + __module__ = 'dendritex' + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.derivative = None -class TreeNode(bst.mixin.Mixin): - __module__ = 'dendritex' - - root_type: type - - @staticmethod - def _root_leaf_pair_check(root: type, leaf: 'TreeNode'): - if hasattr(leaf, 'root_type'): - root_type = leaf.root_type - else: - raise ValueError('Child class should define "root_type" to ' - 'specify the type of the root node. ' - f'But we did not found it in {leaf}') - if not issubclass(root, root_type): - raise TypeError(f'Type does not match. {leaf} requires a root with type ' - f'of {leaf.root_type}, but the root now is {root}.') - - @staticmethod - def check_hierarchies(root: type, *leaves, check_fun: Callable = None, **named_leaves): - if check_fun is None: - check_fun = TreeNode._root_leaf_pair_check - - for leaf in leaves: - if isinstance(leaf, bst.Module): - check_fun(root, leaf) - elif isinstance(leaf, (list, tuple)): - TreeNode.check_hierarchies(root, *leaf, check_fun=check_fun) - elif isinstance(leaf, dict): - TreeNode.check_hierarchies(root, **leaf, check_fun=check_fun) - else: - raise ValueError(f'Do not support {type(leaf)}.') - for leaf in named_leaves.values(): - if not isinstance(leaf, bst.Module): - raise ValueError(f'Do not support {type(leaf)}. Must be instance of {bst.Module}') - check_fun(root, leaf) - - -class HHTypedNeuron(DendriticDynamics, Container): - """ - The base class for the Hodgkin-Huxley typed neuronal membrane dynamics. - """ - __module__ = 'dendritex' - _container_name = 'ion_channels' - - def __init__( - self, - size: bst.typing.Size, - name: Optional[str] = None, - mode: Optional[bst.mixin.Mode] = None, - **ion_channels - ): - super().__init__(size, mode=mode, name=name) - - # attribute for ``Container`` - self.ion_channels = bst.visible_module_dict(self._format_elements(IonChannel, **ion_channels)) - - def init_state(self, batch_size=None): - nodes = self.nodes(level=1, include_self=False).subset(IonChannel).values() - TreeNode.check_hierarchies(self.__class__, *nodes) - for channel in nodes: - channel.init_state(self.V.value, batch_size=batch_size) - - def reset_state(self, batch_size=None): - nodes = self.nodes(level=1, include_self=False).subset(IonChannel).values() - for channel in nodes: - channel.reset_state(self.V.value, batch_size=batch_size) - - def add_elem(self, *elems, **elements): - """ - Add new elements. - Args: - elements: children objects. - """ - TreeNode.check_hierarchies(type(self), *elems, **elements) - self.ion_channels.update(self._format_elements(object, *elems, **elements)) +class DendriteDynamics(bst.mixin.Mixin): + def current(self, *args, **kwargs): + raise NotImplementedError -class IonChannel(DendriticDynamics, TreeNode): - """ - The base class for ion channel modeling. + def before_integral(self, *args, **kwargs): + pass - :py:class:`IonChannel` can be used to model the dynamics of an ion (instance of :py:class:`Ion`), or - a mixture of ions (instance of :py:class:`MixIons`), or a channel (instance of :py:class:`Channel`). + def compute_derivative(self, *args, **kwargs): + raise NotImplementedError - Particularly, an implementation of a :py:class:`IonChannel` should implement the following methods: + def post_derivative(self, *args, **kwargs): + pass - - :py:meth:`current`: Calculate the current of the ion channel. - - :py:meth:`before_integral`: Calculate the state variables before the integral. - - :py:meth:`compute_derivative`: Calculate the derivative of the state variables. - - :py:meth:`after_integral`: Calculate the state variables after the integral. - - :py:meth:`init_state`: Initialize the state variables. - - :py:meth:`reset_state`: Reset the state variables. + def reset_state(self, *args, **kwargs): + pass - """ - __module__ = 'dendritex' + def init_state(self, *args, **kwargs): + pass - def current(self, *args, **kwargs): - raise NotImplementedError - def before_integral(self, *args, **kwargs): - pass +class Container(bst.mixin.Mixin): + __module__ = 'dendritex' + + _container_name: str + + @staticmethod + def _format_elements(child_type: type, **children_as_dict): + res = dict() + + # add dict-typed components + for k, v in children_as_dict.items(): + if not isinstance(v, child_type): + raise TypeError(f'Should be instance of {child_type.__name__}. ' + f'But we got {type(v)}') + res[k] = v + return res + + def __getitem__(self, item): + """ + Overwrite the slice access (`self['']`). + """ + children = self.__getattr__(self._container_name) + if item in children: + return children[item] + else: + raise ValueError(f'Unknown item {item}, we only found {list(children.keys())}') + + def __getattr__(self, item): + """ + Overwrite the dot access (`self.`). + """ + name = super().__getattribute__('_container_name') + if item == '_container_name': + return name + children = super().__getattribute__(name) + if item == name: + return children + if item in children: + return children[item] + else: + return super().__getattribute__(item) + + def add_elem(self, *elems, **elements): + """ + Add new elements. + + Args: + elements: children objects. + """ + raise NotImplementedError('Must be implemented by the subclass.') - def compute_derivative(self, *args, **kwargs): - raise NotImplementedError - def after_integral(self, *args, **kwargs): - pass +class TreeNode(bst.mixin.Mixin): + __module__ = 'dendritex' + + root_type: type + + @staticmethod + def _root_leaf_pair_check(root: type, leaf: 'TreeNode'): + if hasattr(leaf, 'root_type'): + root_type = leaf.root_type + else: + raise ValueError('Child class should define "root_type" to ' + 'specify the type of the root node. ' + f'But we did not found it in {leaf}') + if not issubclass(root, root_type): + raise TypeError(f'Type does not match. {leaf} requires a root with type ' + f'of {leaf.root_type}, but the root now is {root}.') + + @staticmethod + def check_hierarchies(root: type, *leaves, check_fun: Callable = None, **named_leaves): + if check_fun is None: + check_fun = TreeNode._root_leaf_pair_check + + for leaf in leaves: + if isinstance(leaf, bst.graph.Node): + check_fun(root, leaf) + elif isinstance(leaf, (list, tuple)): + TreeNode.check_hierarchies(root, *leaf, check_fun=check_fun) + elif isinstance(leaf, dict): + TreeNode.check_hierarchies(root, **leaf, check_fun=check_fun) + else: + raise ValueError(f'Do not support {type(leaf)}.') + for leaf in named_leaves.values(): + if not isinstance(leaf, bst.graph.Node): + raise ValueError(f'Do not support {type(leaf)}. Must be instance of {bst.graph.Node}') + check_fun(root, leaf) + + +class HHTypedNeuron(bst.nn.Dynamics, Container, DendriteDynamics): + """ + The base class for the Hodgkin-Huxley typed neuronal membrane dynamics. + """ + __module__ = 'dendritex' + _container_name = 'ion_channels' + + def __init__( + self, + size: bst.typing.Size, + name: Optional[str] = None, + **ion_channels + ): + # size + if isinstance(size, (list, tuple)): + if len(size) <= 0: + raise ValueError(f'size must be int, or a tuple/list of int. ' + f'But we got {type(size)}') + if not isinstance(size[0], (int, np.integer)): + raise ValueError('size must be int, or a tuple/list of int.' + f'But we got {type(size)}') + size = tuple(size) + elif isinstance(size, (int, np.integer)): + size = (size,) + else: + raise ValueError('size must be int, or a tuple/list of int.' + f'But we got {type(size)}') + self.size = size + assert len(size) >= 1, ('The size of the dendritic dynamics should be at ' + 'least 1D: (..., n_neuron, n_compartment).') + self.pop_size: Tuple[int, ...] = size[:-1] + self.n_compartment: int = size[-1] + + # initialize + super().__init__(size, name=name) + + # attribute for ``Container`` + self.ion_channels = self._format_elements(IonChannel, **ion_channels) + + def current(self, *args, **kwargs): + raise NotImplementedError('Must be implemented by the subclass.') + + def before_integral(self, *args, **kwargs): + raise NotImplementedError + + def compute_derivative(self, *args, **kwargs): + raise NotImplementedError('Must be implemented by the subclass.') + + def post_derivative(self, *args, **kwargs): + raise NotImplementedError('Must be implemented by the subclass.') + + def init_state(self, batch_size=None): + nodes = self.nodes(allowed_hierarchy=(1, 1)).filter(IonChannel).values() + TreeNode.check_hierarchies(self.__class__, *nodes) + for channel in nodes: + channel.init_state(self.V.value, batch_size=batch_size) + + def reset_state(self, batch_size=None): + nodes = self.nodes(allowed_hierarchy=(1, 1)).filter(IonChannel).values() + for channel in nodes: + channel.reset_state(self.V.value, batch_size=batch_size) + + def add_elem(self, **elements): + """ + Add new elements. + + Args: + elements: children objects. + """ + TreeNode.check_hierarchies(type(self), **elements) + self.ion_channels.update(self._format_elements(IonChannel, **elements)) + + +class IonChannel(bst.graph.Node, TreeNode, DendriteDynamics): + """ + The base class for ion channel modeling. - def reset_state(self, *args, **kwargs): - pass + :py:class:`IonChannel` can be used to model the dynamics of an ion (instance of :py:class:`Ion`), or + a mixture of ions (instance of :py:class:`MixIons`), or a channel (instance of :py:class:`Channel`). - def init_state(self, *args, **kwargs): - pass + Particularly, an implementation of a :py:class:`IonChannel` should implement the following methods: + - :py:meth:`current`: Calculate the current of the ion channel. + - :py:meth:`before_integral`: Calculate the state variables before the integral. + - :py:meth:`compute_derivative`: Calculate the derivative of the state variables. + - :py:meth:`after_integral`: Calculate the state variables after the integral. + - :py:meth:`init_state`: Initialize the state variables. + - :py:meth:`reset_state`: Reset the state variables. -class IonInfo(NamedTuple): - C: bst.typing.ArrayLike - E: bst.typing.ArrayLike + """ + __module__ = 'dendritex' + + def __init__( + self, + size: bst.typing.Size, + name: Optional[str] = None, + ): + # size + if isinstance(size, (list, tuple)): + if len(size) <= 0: + raise ValueError(f'size must be int, or a tuple/list of int. ' + f'But we got {type(size)}') + if not isinstance(size[0], (int, np.integer)): + raise ValueError('size must be int, or a tuple/list of int.' + f'But we got {type(size)}') + size = tuple(size) + elif isinstance(size, (int, np.integer)): + size = (size,) + else: + raise ValueError('size must be int, or a tuple/list of int.' + f'But we got {type(size)}') + self.size = size + assert len(size) >= 1, ('The size of the dendritic dynamics should be at ' + 'least 1D: (..., n_neuron, n_compartment).') + self.name = name + + @property + def varshape(self): + """The shape of variables in the neuron group.""" + return self.size + + def current(self, *args, **kwargs): + raise NotImplementedError + + def before_integral(self, *args, **kwargs): + pass + + def compute_derivative(self, *args, **kwargs): + raise NotImplementedError + + def post_derivative(self, *args, **kwargs): + raise NotImplementedError('Must be implemented by the subclass.') + + def reset_state(self, *args, **kwargs): + pass + + def init_state(self, *args, **kwargs): + pass -class Ion(IonChannel, Container): - """ - The base class for modeling the Ion dynamics. - - Args: - size: The size of the simulation target. - name: The name of the object. - """ - __module__ = 'dendritex' - _container_name = 'channels' - - # The type of the master object. - root_type = HHTypedNeuron - - # Reversal potential. - E: bst.typing.ArrayLike | bst.State - - # Ion concentration. - C: bst.typing.ArrayLike | bst.State - - def __init__( - self, - size: bst.typing.Size, - name: Optional[str] = None, - mode: Optional[bst.mixin.Mode] = None, - **channels - ): - super().__init__(size, mode=mode, name=name, **channels) - self.channels: Dict[str, Channel] = bst.visible_module_dict() - self.channels.update(self._format_elements(Channel, **channels)) - - self._external_currents: Dict[str, Callable] = dict() - - def before_integral(self, V): - nodes = self.nodes(level=1, include_self=False).subset(Channel) - for node in nodes.values(): - node.before_integral(V, self.pack_info()) - - def compute_derivative(self, V): - nodes = self.nodes(level=1, include_self=False).subset(Channel) - for node in nodes.values(): - node.compute_derivative(V, self.pack_info()) - - def after_integral(self, V): - nodes = self.nodes(level=1, include_self=False).subset(Channel) - for node in nodes.values(): - node.after_integral(V, self.pack_info()) - - def current(self, V, include_external: bool = False): - """ - Generate ion channel current. +class IonInfo(NamedTuple): + C: bst.typing.ArrayLike + E: bst.typing.ArrayLike - Args: - V: The membrane potential. - include_external: Include the external current. - Returns: - Current. - """ - nodes = tuple(self.nodes(level=1, include_self=False).subset(Channel).values()) - - ion_info = self.pack_info() - current = None - if len(nodes) > 0: - for node in nodes: - node: Channel - new_current = node.current(V, ion_info) - current = new_current if current is None else (current + new_current) - if include_external: - for key, node in self._external_currents.items(): - node: Callable - current = current + node(V, ion_info) - return current - - def init_state(self, V, batch_size: int = None): - nodes = self.nodes(level=1, include_self=False).subset(Channel).values() - self.check_hierarchies(type(self), *tuple(nodes)) - ion_info = self.pack_info() - for node in nodes: - node: Channel - node.init_state(V, ion_info, batch_size) - - def reset_state(self, V, batch_size: int = None): - nodes = self.nodes(level=1, include_self=False).subset(Channel).values() - ion_info = self.pack_info() - for node in nodes: - node: Channel - node.reset_state(V, ion_info, batch_size) - - def register_external_current(self, key: str, fun: Callable): - if key in self._external_currents: - raise ValueError - self._external_currents[key] = fun - - def pack_info(self): - E = self.E - E = E.value if isinstance(E, bst.State) else E - C = self.C.value if isinstance(self.C, bst.State) else self.C - return IonInfo(E=E, C=C) - - def add_elem(self, *elems, **elements): +class Ion(IonChannel, Container): """ - Add new elements. + The base class for modeling the Ion dynamics. Args: - elements: children objects. + size: The size of the simulation target. + name: The name of the object. """ - self.check_hierarchies(type(self), *elems, **elements) - self.channels.update(self._format_elements(object, *elems, **elements)) + __module__ = 'dendritex' + _container_name = 'channels' + + # The type of the master object. + root_type = HHTypedNeuron + + # Reversal potential. + E: bst.typing.ArrayLike | bst.State + + # Ion concentration. + C: bst.typing.ArrayLike | bst.State + + def __init__( + self, + size: bst.typing.Size, + name: Optional[str] = None, + **channels + ): + super().__init__(size, name=name, **channels) + self.channels: Dict[str, Channel] = dict() + self.channels.update(self._format_elements(Channel, **channels)) + + self._external_currents: Dict[str, Callable] = dict() + + def before_integral(self, V): + nodes = bst.graph.nodes(self, Channel, allowed_hierarchy=(1, 1)) + for node in nodes.values(): + node.before_integral(V, self.pack_info()) + + def compute_derivative(self, V): + nodes = bst.graph.nodes(self, Channel, allowed_hierarchy=(1, 1)) + for node in nodes.values(): + node.compute_derivative(V, self.pack_info()) + + def post_derivative(self, V): + nodes = bst.graph.nodes(self, Channel, allowed_hierarchy=(1, 1)) + for node in nodes.values(): + node.post_derivative(V, self.pack_info()) + + def current(self, V, include_external: bool = False): + """ + Generate ion channel current. + + Args: + V: The membrane potential. + include_external: Include the external current. + + Returns: + Current. + """ + nodes = tuple(bst.graph.nodes(self, Channel, allowed_hierarchy=(1, 1)).values()) + + ion_info = self.pack_info() + current = None + if len(nodes) > 0: + for node in nodes: + node: Channel + new_current = node.current(V, ion_info) + current = new_current if current is None else (current + new_current) + if include_external: + for key, node in self._external_currents.items(): + node: Callable + current = current + node(V, ion_info) + return current + + def init_state(self, V, batch_size: int = None): + nodes = bst.graph.nodes(self, Channel, allowed_hierarchy=(1, 1)).values() + self.check_hierarchies(type(self), *tuple(nodes)) + ion_info = self.pack_info() + for node in nodes: + node: Channel + node.init_state(V, ion_info, batch_size) + + def reset_state(self, V, batch_size: int = None): + nodes = bst.graph.nodes(self, Channel, allowed_hierarchy=(1, 1)).values() + ion_info = self.pack_info() + for node in nodes: + node: Channel + node.reset_state(V, ion_info, batch_size) + + def register_external_current(self, key: str, fun: Callable): + if key in self._external_currents: + raise ValueError + self._external_currents[key] = fun + + def pack_info(self): + E = self.E + E = E.value if isinstance(E, bst.State) else E + C = self.C.value if isinstance(self.C, bst.State) else self.C + return IonInfo(E=E, C=C) + + def add_elem(self, **elements): + """ + Add new elements. + + Args: + elements: children objects. + """ + self.check_hierarchies(type(self), **elements) + self.channels.update(self._format_elements(object, **elements)) class MixIons(IonChannel, Container): - """ - Mixing Ions. - - Args: - ions: Instances of ions. This option defines the master types of all children objects. - """ - __module__ = 'dendritex' - - root_type = HHTypedNeuron - _container_name = 'channels' - - def __init__( - self, - *ions, - name: Optional[str] = None, - mode: Optional[bst.mixin.Mode] = None, - **channels - ): - # TODO: check "ions" should be independent from each other - assert len(ions) >= 2, f'{self.__class__.__name__} requires at least two ions. ' - assert all([isinstance(cls, Ion) for cls in ions]), f'Must be a sequence of Ion. But got {ions}.' - size = ions[0].size - for ion in ions: - assert ion.size == size, f'The size of all ions should be the same. But we got {ions}.' - super().__init__(size=size, name=name, mode=mode) - - # Store the ion instances - self.ions: Sequence['Ion'] = tuple(ions) - self._ion_types = tuple([type(ion) for ion in self.ions]) - - # Store the ion channel channels - self.channels: Dict[str, Channel] = bst.visible_module_dict() - self.channels.update(self._format_elements(Channel, **channels)) - - def before_integral(self, V): - nodes = tuple(self.nodes(level=1, include_self=False).subset(Channel).values()) - for node in nodes: - ion_infos = tuple([self._get_ion(ion).pack_info() for ion in node.root_type.__args__]) - node.before_integral(V, *ion_infos) - - def compute_derivative(self, V): - nodes = tuple(self.nodes(level=1, include_self=False).subset(Channel).values()) - for node in nodes: - ion_infos = tuple([self._get_ion(ion).pack_info() for ion in node.root_type.__args__]) - node.compute_derivative(V, *ion_infos) - - def after_integral(self, V): - nodes = tuple(self.nodes(level=1, include_self=False).subset(Channel).values()) - for node in nodes: - ion_infos = tuple([self._get_ion(ion).pack_info() for ion in node.root_type.__args__]) - node.after_integral(V, *ion_infos) - - def current(self, V): - """Generate ion channel current. - - Args: - V: The membrane potential. - - Returns: - Current. - """ - nodes = tuple(self.nodes(level=1, include_self=False).subset(Channel).values()) - - if len(nodes) == 0: - return 0. - else: - current = None - for node in nodes: - infos = tuple([self._get_ion(root).pack_info() for root in node.root_type.__args__]) - current = node.current(V, *infos) if current is None else (current + node.current(V, *infos)) - return current - - def init_state(self, V, batch_size: int = None): - nodes = self.nodes(level=1, include_self=False).subset(Channel).values() - self.check_hierarchies(self._ion_types, *tuple(nodes), check_fun=self._check_hierarchy) - for node in nodes: - node: Channel - infos = tuple([self._get_ion(root).pack_info() for root in node.root_type.__args__]) - node.init_state(V, *infos, batch_size) - - def reset_state(self, V, batch_size=None): - nodes = tuple(self.nodes(level=1, include_self=False).subset(Channel).values()) - for node in nodes: - infos = tuple([self._get_ion(root).pack_info() for root in node.root_type.__args__]) - node.reset_state(V, *infos, batch_size) - - def _check_hierarchy(self, ions, leaf): - # 'root_type' should be a brainpy.mixin.JointType - self._check_root(leaf) - for cls in leaf.root_type.__args__: - if not any([issubclass(root, cls) for root in ions]): - raise TypeError( - f'Type does not match. {leaf} requires a master with type ' - f'of {leaf.root_type}, but the master type now is {ions}.' - ) - - def add_elem(self, *elems, **elements): """ - Add new elements. + Mixing Ions. Args: - elements: children objects. + ions: Instances of ions. This option defines the master types of all children objects. """ - self.check_hierarchies(self._ion_types, *elems, check_fun=self._check_hierarchy, **elements) - self.channels.update(self._format_elements(Channel, *elems, **elements)) - for elem in tuple(elems) + tuple(elements.values()): - elem: Channel - for ion_root in elem.root_type.__args__: - ion = self._get_ion(ion_root) - ion.register_external_current(elem.name, self._get_ion_fun(ion, elem)) - - def _get_ion_fun(self, ion: 'Ion', node: 'Channel'): - def fun(V, ion_info): - infos = tuple( - [(ion_info if isinstance(ion, root) else self._get_ion(root).pack_info()) - for root in node.root_type.__args__] - ) - return node.current(V, *infos) - - return fun - - def _get_ion(self, cls): - for ion in self.ions: - if isinstance(ion, cls): - return ion - else: - raise ValueError(f'No instance of {cls} is found.') - - def _check_root(self, leaf): - if not isinstance(leaf.root_type, _JointGenericAlias): - raise TypeError( - f'{self.__class__.__name__} requires leaf nodes that have the root_type of ' - f'"brainpy.mixin.JointType". However, we got {leaf.root_type}' - ) + __module__ = 'dendritex' + + root_type = HHTypedNeuron + _container_name = 'channels' + + def __init__(self, *ions, name: Optional[str] = None, **channels): + # TODO: check "ions" should be independent from each other + assert len(ions) >= 2, f'{self.__class__.__name__} requires at least two ions. ' + assert all([isinstance(cls, Ion) for cls in ions]), f'Must be a sequence of Ion. But got {ions}.' + size = ions[0].size + for ion in ions: + assert ion.size == size, f'The size of all ions should be the same. But we got {ions}.' + super().__init__(size=size, name=name) + + # Store the ion instances + self.ions: Sequence['Ion'] = tuple(ions) + self._ion_types = tuple([type(ion) for ion in self.ions]) + + # Store the ion channel channels + self.channels: Dict[str, Channel] = dict() + self.channels.update(self._format_elements(Channel, **channels)) + + def before_integral(self, V): + nodes = tuple(bst.graph.nodes(self, Channel, allowed_hierarchy=(1, 1)).values()) + for node in nodes: + ion_infos = tuple([self._get_ion(ion).pack_info() for ion in node.root_type.__args__]) + node.before_integral(V, *ion_infos) + + def compute_derivative(self, V): + nodes = tuple(bst.graph.nodes(self, Channel, allowed_hierarchy=(1, 1)).values()) + for node in nodes: + ion_infos = tuple([self._get_ion(ion).pack_info() for ion in node.root_type.__args__]) + node.compute_derivative(V, *ion_infos) + + def post_derivative(self, V): + nodes = tuple(bst.graph.nodes(self, Channel, allowed_hierarchy=(1, 1)).values()) + for node in nodes: + ion_infos = tuple([self._get_ion(ion).pack_info() for ion in node.root_type.__args__]) + node.post_derivative(V, *ion_infos) + + def current(self, V): + """Generate ion channel current. + + Args: + V: The membrane potential. + + Returns: + Current. + """ + nodes = tuple(bst.graph.nodes(self, Channel, allowed_hierarchy=(1, 1)).values()) + + if len(nodes) == 0: + return 0. + else: + current = None + for node in nodes: + infos = tuple([self._get_ion(root).pack_info() for root in node.root_type.__args__]) + current = node.current(V, *infos) if current is None else (current + node.current(V, *infos)) + return current + + def init_state(self, V, batch_size: int = None): + nodes = bst.graph.nodes(self, Channel, allowed_hierarchy=(1, 1)).values() + self.check_hierarchies(self._ion_types, *tuple(nodes), check_fun=self._check_hierarchy) + for node in nodes: + node: Channel + infos = tuple([self._get_ion(root).pack_info() for root in node.root_type.__args__]) + node.init_state(V, *infos, batch_size) + + def reset_state(self, V, batch_size=None): + nodes = tuple(bst.graph.nodes(self, Channel, allowed_hierarchy=(1, 1)).values()) + for node in nodes: + infos = tuple([self._get_ion(root).pack_info() for root in node.root_type.__args__]) + node.reset_state(V, *infos, batch_size) + + def _check_hierarchy(self, ions, leaf): + # 'root_type' should be a brainpy.mixin.JointType + self._check_root(leaf) + for cls in leaf.root_type.__args__: + if not any([issubclass(root, cls) for root in ions]): + raise TypeError( + f'Type does not match. {leaf} requires a master with type ' + f'of {leaf.root_type}, but the master type now is {ions}.' + ) + + def add_elem(self, **elements): + """ + Add new elements. + + Args: + elements: children objects. + """ + self.check_hierarchies(self._ion_types, check_fun=self._check_hierarchy, **elements) + self.channels.update(self._format_elements(Channel, **elements)) + for elem in tuple(elements.values()): + elem: Channel + for ion_root in elem.root_type.__args__: + ion = self._get_ion(ion_root) + ion.register_external_current(elem.name, self._get_ion_fun(ion, elem)) + + def _get_ion_fun(self, ion: 'Ion', node: 'Channel'): + def fun(V, ion_info): + infos = tuple( + [(ion_info if isinstance(ion, root) else self._get_ion(root).pack_info()) + for root in node.root_type.__args__] + ) + return node.current(V, *infos) + + return fun + + def _get_ion(self, cls): + for ion in self.ions: + if isinstance(ion, cls): + return ion + else: + raise ValueError(f'No instance of {cls} is found.') + + def _check_root(self, leaf): + if not isinstance(leaf.root_type, _JointGenericAlias): + raise TypeError( + f'{self.__class__.__name__} requires leaf nodes that have the root_type of ' + f'"brainpy.mixin.JointType". However, we got {leaf.root_type}' + ) @set_module_as('dendritex') def mix_ions(*ions) -> MixIons: - """Create mixed ions. + """Create mixed ions. - Args: - ions: Ion instances. + Args: + ions: Ion instances. - Returns: - Instance of MixIons. - """ - for ion in ions: - assert isinstance(ion, Ion), f'Must be instance of {Ion.__name__}. But got {type(ion)}' - assert len(ions) > 0, '' - return MixIons(*ions) + Returns: + Instance of MixIons. + """ + for ion in ions: + assert isinstance(ion, Ion), f'Must be instance of {Ion.__name__}. But got {type(ion)}' + assert len(ions) > 0, '' + return MixIons(*ions) class Channel(IonChannel): - """ - The base class for modeling channel dynamics. - """ - __module__ = 'dendritex' + """ + The base class for modeling channel dynamics. + """ + __module__ = 'dendritex' diff --git a/dendritex/_base_test.py b/dendritex/_base_test.py index fa6d048..bef5b7d 100644 --- a/dendritex/_base_test.py +++ b/dendritex/_base_test.py @@ -17,5 +17,5 @@ class TestIon(unittest.TestCase): - def test1(self): - pass + def test1(self): + pass diff --git a/dendritex/_integrators.py b/dendritex/_integrators.py index 25ac895..2c82e53 100644 --- a/dendritex/_integrators.py +++ b/dendritex/_integrators.py @@ -15,47 +15,55 @@ from __future__ import annotations -from typing import Optional, Tuple, Callable, Dict, Union +from dataclasses import dataclass +from typing import Optional, Tuple, Callable, Dict, Union, Sequence import brainstate as bst import brainunit as u import diffrax as dfx import jax -from ._base import State4Integral, DendriticDynamics +from ._base import State4Integral, DendriteDynamics from ._misc import set_module_as __all__ = [ - 'diffrax_solve_adjoint', - 'diffrax_solve', - 'euler_step', - 'rk2_step', - 'rk3_step', - 'rk4_step', + 'diffrax_solve_adjoint', + 'diffrax_solve', + 'euler_step', + 'midpoint_step', + 'rk2_step', + 'heun2_step', + 'ralston2_step', + 'rk3_step', + 'heun3_step', + 'ssprk3_step', + 'ralston3_step', + 'rk4_step', + 'ralston4_step', ] diffrax_solvers = { - # explicit RK - 'euler': dfx.Euler, - 'revheun': dfx.ReversibleHeun, - 'heun': dfx.Heun, - 'midpoint': dfx.Midpoint, - 'ralston': dfx.Ralston, - 'bosh3': dfx.Bosh3, - 'tsit5': dfx.Tsit5, - 'dopri5': dfx.Dopri5, - 'dopri8': dfx.Dopri8, - - # implicit RK - 'ieuler': dfx.ImplicitEuler, - 'kvaerno3': dfx.Kvaerno3, - 'kvaerno4': dfx.Kvaerno4, - 'kvaerno5': dfx.Kvaerno5, + # explicit RK + 'euler': dfx.Euler, + 'revheun': dfx.ReversibleHeun, + 'heun': dfx.Heun, + 'midpoint': dfx.Midpoint, + 'ralston': dfx.Ralston, + 'bosh3': dfx.Bosh3, + 'tsit5': dfx.Tsit5, + 'dopri5': dfx.Dopri5, + 'dopri8': dfx.Dopri8, + + # implicit RK + 'ieuler': dfx.ImplicitEuler, + 'kvaerno3': dfx.Kvaerno3, + 'kvaerno4': dfx.Kvaerno4, + 'kvaerno5': dfx.Kvaerno5, } def _is_quantity(x): - return isinstance(x, u.Quantity) + return isinstance(x, u.Quantity) def _diffrax_solve( @@ -72,140 +80,141 @@ def _diffrax_solve( atol: Optional[float] = None, max_steps: int = None, ): - if isinstance(adjoint, str): - if adjoint == 'adjoint': - adjoint = dfx.BacksolveAdjoint() - elif adjoint == 'checkpoint': - adjoint = dfx.RecursiveCheckpointAdjoint() - elif adjoint == 'direct': - adjoint = dfx.DirectAdjoint() + if isinstance(adjoint, str): + if adjoint == 'adjoint': + adjoint = dfx.BacksolveAdjoint() + elif adjoint == 'checkpoint': + adjoint = dfx.RecursiveCheckpointAdjoint() + elif adjoint == 'direct': + adjoint = dfx.DirectAdjoint() + else: + raise ValueError(f"Unknown adjoint method: {adjoint}. Only support 'checkpoint', 'direct', and 'adjoint'.") + elif isinstance(adjoint, dfx.AbstractAdjoint): + adjoint = adjoint else: - raise ValueError(f"Unknown adjoint method: {adjoint}. Only support 'checkpoint', 'direct', and 'adjoint'.") - elif isinstance(adjoint, dfx.AbstractAdjoint): - adjoint = adjoint - else: - raise ValueError(f"Unknown adjoint method: {adjoint}. Only support 'checkpoint', 'direct', and 'adjoint'.") - - # processing times - dt0 = dt0.in_unit(u.ms) - t0 = t0.in_unit(u.ms) - t1 = t1.in_unit(u.ms) - if saveat is not None: - saveat = saveat.in_unit(u.ms) - - # stepsize controller - if rtol is None and atol is None: - stepsize_controller = dfx.ConstantStepSize() - else: - if rtol is None: - rtol = atol - if atol is None: - atol = rtol - stepsize_controller = dfx.PIDController(rtol=rtol, atol=atol) - - # numerical solver - if solver not in diffrax_solvers: - raise ValueError(f"Unknown solver: {solver}") - solver = diffrax_solvers[solver]() - - def model_to_derivative(t, *args): - with bst.environ.context(t=t * u.ms): - with bst.StateTrace() as trace: - model(t * u.ms, *args) - derivatives = [] - for st in trace.states: - if isinstance(st, State4Integral): - a = u.get_unit(st.derivative) * u.ms - b = u.get_unit(st.value) - assert a.has_same_dim(b), f'Unit mismatch. Got {a} != {b}' - if isinstance(st.derivative, u.Quantity): - st.derivative = st.derivative.in_unit(u.get_unit(st.value) / u.ms) - derivatives.append(st.derivative) - else: - raise ValueError(f"State {st} is not for integral.") + raise ValueError(f"Unknown adjoint method: {adjoint}. Only support 'checkpoint', 'direct', and 'adjoint'.") + + # processing times + dt0 = dt0.in_unit(u.ms) + t0 = t0.in_unit(u.ms) + t1 = t1.in_unit(u.ms) + if saveat is not None: + saveat = saveat.in_unit(u.ms) + + # stepsize controller + if rtol is None and atol is None: + stepsize_controller = dfx.ConstantStepSize() + else: + if rtol is None: + rtol = atol + if atol is None: + atol = rtol + stepsize_controller = dfx.PIDController(rtol=rtol, atol=atol) + + # numerical solver + if solver not in diffrax_solvers: + raise ValueError(f"Unknown solver: {solver}") + solver = diffrax_solvers[solver]() + + def model_to_derivative(t, *args): + with bst.environ.context(t=t * u.ms): + with bst.StateTraceStack() as trace: + model(t * u.ms, *args) + derivatives = [] + for st in trace.states: + if isinstance(st, State4Integral): + a = u.get_unit(st.derivative) * u.ms + b = u.get_unit(st.value) + assert a.has_same_dim(b), f'Unit mismatch. Got {a} != {b}' + if isinstance(st.derivative, u.Quantity): + st.derivative = st.derivative.in_unit(u.get_unit(st.value) / u.ms) + derivatives.append(st.derivative) + else: + raise ValueError(f"State {st} is not for integral.") + return derivatives + + # stateful function and make jaxpr + stateful_fn = bst.compile.StatefulFunction(model_to_derivative).make_jaxpr(0., *args) + + # states + states = stateful_fn.get_states() + + def vector_filed(t, state_vals, args): + new_state_vals, derivatives = stateful_fn.jaxpr_call(state_vals, t, *args) + derivatives = tuple(d.mantissa if isinstance(d, u.Quantity) else d + for d in derivatives) return derivatives - # stateful function and make jaxpr - stateful_fn = bst.transform.StatefulFunction(model_to_derivative).make_jaxpr(0., *args) - - # states - states = stateful_fn.get_states() - - def vector_filed(t, state_vals, args): - new_state_vals, derivatives = stateful_fn.jaxpr_call(state_vals, t, *args) - derivatives = tuple(d.mantissa if isinstance(d, u.Quantity) else d - for d in derivatives) - return derivatives - - def save_y(t, state_vals, args): - for st, st_val in zip(states, state_vals): - st.value = u.Quantity(st_val, unit=st.value.unit) if isinstance(st.value, u.Quantity) else st_val - assert callable(savefn), 'savefn must be callable.' - rets = savefn(t * u.ms, *args) - nonlocal return_units - if return_units is None: - return_units = jax.tree.map(lambda x: x.unit if isinstance(x, u.Quantity) else None, rets, is_leaf=_is_quantity) - return jax.tree.map(lambda x: x.mantissa if isinstance(x, u.Quantity) else x, rets, is_leaf=_is_quantity) - - return_units = None - if savefn is None: - return_units = tuple(st.value.unit if isinstance(st.value, u.Quantity) else None for st in states) - if saveat is None: - if isinstance(adjoint, dfx.BacksolveAdjoint): - raise ValueError('saveat must be specified when using backsolve adjoint.') - saveat = dfx.SaveAt(steps=True) - else: - saveat = dfx.SaveAt(ts=saveat.mantissa, t1=True) - else: - subsaveat_a = dfx.SubSaveAt(t1=True) - if saveat is None: - subsaveat_b = dfx.SubSaveAt(steps=True, fn=save_y) + def save_y(t, state_vals, args): + for st, st_val in zip(states, state_vals): + st.value = u.Quantity(st_val, unit=st.value.unit) if isinstance(st.value, u.Quantity) else st_val + assert callable(savefn), 'savefn must be callable.' + rets = savefn(t * u.ms, *args) + nonlocal return_units + if return_units is None: + return_units = jax.tree.map(lambda x: x.unit if isinstance(x, u.Quantity) else None, rets, + is_leaf=_is_quantity) + return jax.tree.map(lambda x: x.mantissa if isinstance(x, u.Quantity) else x, rets, is_leaf=_is_quantity) + + return_units = None + if savefn is None: + return_units = tuple(st.value.unit if isinstance(st.value, u.Quantity) else None for st in states) + if saveat is None: + if isinstance(adjoint, dfx.BacksolveAdjoint): + raise ValueError('saveat must be specified when using backsolve adjoint.') + saveat = dfx.SaveAt(steps=True) + else: + saveat = dfx.SaveAt(ts=saveat.mantissa, t1=True) else: - subsaveat_b = dfx.SubSaveAt(ts=saveat.mantissa, fn=save_y) - saveat = dfx.SaveAt(subs=[subsaveat_a, subsaveat_b]) - - # solving differential equations - sol = dfx.diffeqsolve( - dfx.ODETerm(vector_filed), - solver, - t0=t0.mantissa, - t1=t1.mantissa, - dt0=dt0.mantissa, - y0=tuple((v.value.mantissa if isinstance(v.value, u.Quantity) else v.value) for v in states), - args=args, - saveat=saveat, - adjoint=adjoint, - stepsize_controller=stepsize_controller, - max_steps=max_steps, - ) - if savefn is None: - # assign values back to states - for st, st_value in zip(states, sol.ys): - st.value = u.Quantity(st_value[-1], unit=st.unit) if isinstance(st, u.Quantity) else st_value[-1] - # record solver state - return ( - sol.ts * u.ms, - jax.tree.map( - lambda y, unit: (u.Quantity(y, unit=unit) if unit is not None else y), - sol.ys, - return_units - ), - sol.stats - ) - else: - # assign values back to states - for st, st_value in zip(states, sol.ys[0]): - st.value = u.Quantity(st_value[0], unit=st.unit) if isinstance(st, u.Quantity) else st_value[0] - # record solver state - return ( - sol.ts[1] * u.ms, - jax.tree.map( - lambda y, unit: (u.Quantity(y, unit=unit) if unit is not None else y), - sol.ys[1], - return_units - ), - sol.stats + subsaveat_a = dfx.SubSaveAt(t1=True) + if saveat is None: + subsaveat_b = dfx.SubSaveAt(steps=True, fn=save_y) + else: + subsaveat_b = dfx.SubSaveAt(ts=saveat.mantissa, fn=save_y) + saveat = dfx.SaveAt(subs=[subsaveat_a, subsaveat_b]) + + # solving differential equations + sol = dfx.diffeqsolve( + dfx.ODETerm(vector_filed), + solver, + t0=t0.mantissa, + t1=t1.mantissa, + dt0=dt0.mantissa, + y0=tuple((v.value.mantissa if isinstance(v.value, u.Quantity) else v.value) for v in states), + args=args, + saveat=saveat, + adjoint=adjoint, + stepsize_controller=stepsize_controller, + max_steps=max_steps, ) + if savefn is None: + # assign values back to states + for st, st_value in zip(states, sol.ys): + st.value = u.Quantity(st_value[-1], unit=st.unit) if isinstance(st, u.Quantity) else st_value[-1] + # record solver state + return ( + sol.ts * u.ms, + jax.tree.map( + lambda y, unit: (u.Quantity(y, unit=unit) if unit is not None else y), + sol.ys, + return_units + ), + sol.stats + ) + else: + # assign values back to states + for st, st_value in zip(states, sol.ys[0]): + st.value = u.Quantity(st_value[0], unit=st.unit) if isinstance(st, u.Quantity) else st_value[0] + # record solver state + return ( + sol.ts[1] * u.ms, + jax.tree.map( + lambda y, unit: (u.Quantity(y, unit=unit) if unit is not None else y), + sol.ys[1], + return_units + ), + sol.stats + ) def diffrax_solve_adjoint( @@ -221,56 +230,56 @@ def diffrax_solve_adjoint( atol: Optional[float] = None, max_steps: Optional[int] = None, ): - """ - Solve the differential equations using `diffrax `_ which - is compatible with the adjoint backpropagation. - - Args: - model: The model function to solve. - solver: The solver to use. Available solvers are: - - 'euler' - - 'revheun' - - 'heun' - - 'midpoint' - - 'ralston' - - 'bosh3' - - 'tsit5' - - 'dopri5' - - 'dopri8' - - 'ieuler' - - 'kvaerno3' - - 'kvaerno4' - - 'kvaerno5' - t0: The initial time. - t1: The final time. - dt0: The initial step size. - saveat: The time points to save the solution. If None, save the solution at every step. - savefn: The function to save the solution. If None, save the solution at every step. - args: The arguments to pass to the model function. - rtol: The relative tolerance. - atol: The absolute tolerance. - max_steps: The maximum number of steps. - - Returns: - The solution of the differential equations, including the following items: - - The time points. - - The solution. - - The running step statistics. - """ - return _diffrax_solve( - model=model, - solver=solver, - t0=t0, - t1=t1, - dt0=dt0, - saveat=saveat, - savefn=savefn, - adjoint='adjoint', - max_steps=max_steps, - args=args, - rtol=rtol, - atol=atol, - ) + """ + Solve the differential equations using `diffrax `_ which + is compatible with the adjoint backpropagation. + + Args: + model: The model function to solve. + solver: The solver to use. Available solvers are: + - 'euler' + - 'revheun' + - 'heun' + - 'midpoint' + - 'ralston' + - 'bosh3' + - 'tsit5' + - 'dopri5' + - 'dopri8' + - 'ieuler' + - 'kvaerno3' + - 'kvaerno4' + - 'kvaerno5' + t0: The initial time. + t1: The final time. + dt0: The initial step size. + saveat: The time points to save the solution. If None, save the solution at every step. + savefn: The function to save the solution. If None, save the solution at every step. + args: The arguments to pass to the model function. + rtol: The relative tolerance. + atol: The absolute tolerance. + max_steps: The maximum number of steps. + + Returns: + The solution of the differential equations, including the following items: + - The time points. + - The solution. + - The running step statistics. + """ + return _diffrax_solve( + model=model, + solver=solver, + t0=t0, + t1=t1, + dt0=dt0, + saveat=saveat, + savefn=savefn, + adjoint='adjoint', + max_steps=max_steps, + args=args, + rtol=rtol, + atol=atol, + ) def diffrax_solve( @@ -287,241 +296,243 @@ def diffrax_solve( max_steps: Optional[int] = None, adjoint: Union[str, dfx.AbstractAdjoint] = 'checkpoint', ) -> Tuple[u.Quantity, bst.typing.PyTree[u.Quantity], Dict]: - """ - Solve the differential equations using `diffrax `_. - - Args: - model: The model function to solve. - solver: The solver to use. Available solvers are: - - 'euler' - - 'revheun' - - 'heun' - - 'midpoint' - - 'ralston' - - 'bosh3' - - 'tsit5' - - 'dopri5' - - 'dopri8' - - 'ieuler' - - 'kvaerno3' - - 'kvaerno4' - - 'kvaerno5' - t0: The initial time. - t1: The final time. - dt0: The initial step size. - saveat: The time points to save the solution. If None, save the solution at every step. - savefn: The function to save the solution. If None, save the solution at every step. - args: The arguments to pass to the model function. - rtol: The relative tolerance. - atol: The absolute tolerance. - max_steps: The maximum number of steps. - adjoint: The adjoint method. Available methods are: - - 'adjoint' - - 'checkpoint' - - 'direct' - - Returns: - The solution of the differential equations, including the following items: - - The time points. - - The solution. - - The running step statistics. - """ - return _diffrax_solve( - model=model, - solver=solver, - t0=t0, - t1=t1, - dt0=dt0, - saveat=saveat, - savefn=savefn, - adjoint=adjoint, - args=args, - rtol=rtol, - atol=atol, - max_steps=max_steps, - ) - - -def tree_map(f, tree, *rest): - return jax.tree.map(f, tree, *rest, is_leaf=lambda a: isinstance(a, u.Quantity)) + """ + Solve the differential equations using `diffrax `_. + + Args: + model: The model function to solve. + solver: The solver to use. Available solvers are: + - 'euler' + - 'revheun' + - 'heun' + - 'midpoint' + - 'ralston' + - 'bosh3' + - 'tsit5' + - 'dopri5' + - 'dopri8' + - 'ieuler' + - 'kvaerno3' + - 'kvaerno4' + - 'kvaerno5' + t0: The initial time. + t1: The final time. + dt0: The initial step size. + saveat: The time points to save the solution. If None, save the solution at every step. + savefn: The function to save the solution. If None, save the solution at every step. + args: The arguments to pass to the model function. + rtol: The relative tolerance. + atol: The absolute tolerance. + max_steps: The maximum number of steps. + adjoint: The adjoint method. Available methods are: + - 'adjoint' + - 'checkpoint' + - 'direct' + + Returns: + The solution of the differential equations, including the following items: + - The time points. + - The solution. + - The running step statistics. + """ + return _diffrax_solve( + model=model, + solver=solver, + t0=t0, + t1=t1, + dt0=dt0, + saveat=saveat, + savefn=savefn, + adjoint=adjoint, + args=args, + rtol=rtol, + atol=atol, + max_steps=max_steps, + ) -@set_module_as('dendritex') -def euler_step(target: DendriticDynamics, t: jax.typing.ArrayLike, *args): - dt = bst.environ.get_dt() - - with bst.environ.context(t=t): - with bst.StateTrace() as trace: - target.before_integral(*args) - target.compute_derivative(*args) - - # state collection - states = tuple([st for st in trace.states if isinstance(st, State4Integral)]) - # initial values - ys = list([val for st, val in zip(trace.states, trace._org_values) if isinstance(st, State4Integral)]) - # derivatives - k1hs = [st.derivative for st in states] - - # y + k1 - with bst.check_state_value_tree(): - # update states with derivatives - for st, y, k1h in zip(states, ys, k1hs): - st.value = tree_map(lambda y_, k1_: y_ + k1_ * dt, y, k1h) - # update other states - target.after_integral(*args) +@dataclass(frozen=True) +class ButcherTableau: + """The Butcher tableau for an explicit or diagonal Runge--Kutta method.""" + A: Sequence[Sequence] # The A matrix in the Butcher tableau. + B: Sequence # The B vector in the Butcher tableau. + C: Sequence # The C vector in the Butcher tableau. -@set_module_as('dendritex') -def rk2_step(target: DendriticDynamics, t: jax.typing.ArrayLike, *args): - dt = bst.environ.get_dt() - - # k1 - with bst.environ.context(t=t): - with bst.StateTrace() as trace: - target.before_integral(*args) - target.compute_derivative(*args) - - # state collection - states = tuple([st for st in trace.states if isinstance(st, State4Integral)]) - # initial values - ys = list([val for st, val in zip(trace.states, trace._org_values) if isinstance(st, State4Integral)]) - # derivatives - k1hs = [st.derivative for st in states] - - # k2 - with bst.environ.context(t=t + dt): - with bst.check_state_value_tree(): - for st, y, k1h in zip(states, ys, k1hs): - st.value = tree_map(lambda y_, k1_: y_ + k1_ * dt, y, k1h) - # update other states - target.after_integral(*args) - target.before_integral(*args) - target.compute_derivative(*args) - k2s = [st.derivative for st in states] - # y + (k1 + k2) / 2 - with bst.check_state_value_tree(): - # update states with derivatives - for st, y, k1h, k2h in zip(states, ys, k1hs, k2s): - st.value = tree_map(lambda y_, k1_, k2_: y_ + 0.5 * (k1_ + k2_) * dt, y, k1h, k2h) - # update other states - target.after_integral(*args) +def _update(dt, coeff, st, y0, *ks): + assert len(coeff) == len(ks), 'The number of coefficients must be equal to the number of ks.' + + def _step(y0_, *k_): + kds = [c_ * k_ for c_, k_ in zip(coeff, k_)] + update = kds[0] + for kd in kds[1:]: + update += kd + return y0_ + update * dt + + st.value = jax.tree.map(_step, y0, *ks, is_leaf=u.math.is_quantity) @set_module_as('dendritex') -def rk3_step(target: DendriticDynamics, t: jax.typing.ArrayLike, *args): - dt = bst.environ.get_dt() - - # k1 - with bst.environ.context(t=t): - with bst.StateTrace() as trace: - target.before_integral(*args) - target.compute_derivative(*args) - - # state collection - states = tuple([st for st in trace.states if isinstance(st, State4Integral)]) - # initial values - ys = list([val for st, val in zip(trace.states, trace._org_values) if isinstance(st, State4Integral)]) - # derivatives - k1hs = [st.derivative for st in states] - - # k2 - with bst.environ.context(t=t + dt * 0.5): - with bst.check_state_value_tree(): - for st, y, k1 in zip(states, ys, k1hs): - st.value = tree_map( - lambda y_, k1_: y_ + k1_ * 0.5 * dt, - y, k1 - ) - # update other states - target.after_integral(*args) +def _general_rk_step( + tableau: ButcherTableau, + target: DendriteDynamics, + t: jax.typing.ArrayLike, + *args +): + dt = bst.environ.get_dt() + time_dtype = u.math.get_dtype(dt) + ks = [] + + # before one-step integration target.before_integral(*args) - target.compute_derivative(*args) - k2hs = [st.derivative for st in states] - # k3 - with bst.environ.context(t=t + dt): + # k1 + with bst.environ.context(t=t + u.math.asarray(tableau.C[0], dtype=time_dtype) * dt): + with bst.StateTraceStack() as trace: + # compute derivative + target.compute_derivative(*args) + + # state collection + states = tuple([st for st in trace.states if isinstance(st, State4Integral)]) + + # initial values + y0 = list([ + val + for st, val in zip(trace.states, trace.original_state_values) + if isinstance(st, State4Integral) + ]) + + # derivatives + k1hs = [st.derivative for st in states] + ks.append(k1hs) + + for i in range(1, len(tableau.C)): + with bst.environ.context(t=t + u.math.asarray(tableau.C[i], dtype=time_dtype) * dt): + with bst.check_state_value_tree(): + for st, y0_, *ks_ in zip(states, y0, *ks): + _update(dt, tableau.A[i], st, y0_, *ks_) + # after one-step derivative + target.post_derivative(*args) + target.compute_derivative(*args) + ks.append([st.derivative for st in states]) + + # final step with bst.check_state_value_tree(): - for st, y, k2h, k1 in zip(states, ys, k2hs, k1hs): - st.value = tree_map( - lambda y_, k2_, k1_: y_ + (2.0 * k2_ - k1_) * dt, - y, k2h, k1 - ) - # update other states - target.after_integral(*args) - target.before_integral(*args) - target.compute_derivative(*args) - k3hs = [st.derivative for st in states] + # update states with derivatives + for st, y0_, *ks_ in zip(states, y0, *ks): + _update(dt, tableau.B, st, y0_, *ks_) + # update other states + target.post_derivative(*args) + + +euler_tableau = ButcherTableau( + A=((),), + B=(1.0,), + C=(0.0,), +) +midpoint_tableau = ButcherTableau( + A=[(), (0.5,)], + B=(0.0, 1.0), + C=(0.0, 0.5), +) +rk2_tableau = ButcherTableau( + A=[(), (2 / 3,)], + B=(1 / 4, 3 / 4), + C=(0.0, 2 / 3), +) +heun2_tableau = ButcherTableau( + A=[(), (1.,)], + B=[0.5, 0.5], + C=[0, 1], +) +ralston2_tableau = ButcherTableau( + A=[(), (2 / 3,)], + B=[0.25, 0.75], + C=[0, 2 / 3], +) +rk3_tableau = ButcherTableau( + A=[(), (0.5,), (-1, 2)], + B=[1 / 6, 2 / 3, 1 / 6], + C=[0, 0.5, 1], +) +heun3_tableau = ButcherTableau( + A=[(), (1 / 3,), (0, 2 / 3)], + B=[0.25, 0, 0.75], + C=[0, 1 / 3, 2 / 3], +) +ralston3_tableau = ButcherTableau( + A=[(), (0.5,), (0, 0.75)], + B=[2 / 9, 1 / 3, 4 / 9], + C=[0, 0.5, 0.75], +) +ssprk3_tableau = ButcherTableau( + A=[(), (1,), (0.25, 0.25)], + B=[1 / 6, 1 / 6, 2 / 3], + C=[0, 1, 0.5], +) +rk4_tableau = ButcherTableau( + A=[(), (0.5,), (0., 0.5), (0., 0., 1)], + B=[1 / 6, 1 / 3, 1 / 3, 1 / 6], + C=[0, 0.5, 0.5, 1], +) +ralston4_tableau = ButcherTableau( + A=[(), (.4,), (.29697761, .15875964), (.21810040, -3.05096516, 3.83286476)], + B=[.17476028, -.55148066, 1.20553560, .17118478], + C=[0, .4, .45573725, 1], +) - # y + (k1 + 4 * k2 + k3) / 2 - with bst.check_state_value_tree(): - # update states with derivatives - for st, y, k1, k2h, k3h in zip(states, ys, k1hs, k2hs, k3hs): - st.value = tree_map( - lambda y_, k1_, k2_, k3_: y_ + (k1_ + 4 * k2_ + k3_) / 6 * dt, - y, k1, k2h, k3h - ) - # update other states - target.after_integral(*args) + +@set_module_as('dendritex') +def euler_step(target: DendriteDynamics, t: bst.typing.ArrayLike, *args): + _general_rk_step(euler_tableau, target, t, *args) @set_module_as('dendritex') -def rk4_step(target: DendriticDynamics, t: jax.typing.ArrayLike, *args): - dt = bst.environ.get_dt() - - # k1 - with bst.environ.context(t=t): - with bst.StateTrace() as trace: - target.before_integral(*args) - target.compute_derivative(*args) - - # state collection - states = tuple([st for st in trace.states if isinstance(st, State4Integral)]) - # initial values - ys = list([val for st, val in zip(trace.states, trace._org_values) if isinstance(st, State4Integral)]) - # derivatives - k1hs = [st.derivative for st in states] - - # k2 - with bst.environ.context(t=t + 0.5 * dt): - with bst.check_state_value_tree(): - for st, y, k1h in zip(states, ys, k1hs): - st.value = tree_map(lambda y_, k1_: y_ + 0.5 * k1_ * dt, y, k1h) - # update other states - target.after_integral(*args) - target.before_integral(*args) - target.compute_derivative(*args) - k2hs = [st.derivative for st in states] +def midpoint_step(target: DendriteDynamics, t: bst.typing.ArrayLike, *args): + _general_rk_step(midpoint_tableau, target, t, *args) - # k3 - with bst.environ.context(t=t + 0.5 * dt): - with bst.check_state_value_tree(): - for st, y, k2h in zip(states, ys, k2hs): - st.value = tree_map(lambda y_, k2_: y_ + 0.5 * k2_ * dt, y, k2h) - # update other states - target.after_integral(*args) - target.before_integral(*args) - target.compute_derivative(*args) - k3hs = [st.derivative for st in states] - # k4 - with bst.environ.context(t=t + dt): - with bst.check_state_value_tree(): - for st, y, k3h in zip(states, ys, k3hs): - st.value = tree_map(lambda y_, k3_: y_ + k3_ * dt, y, k3h) - # update other states - target.after_integral(*args) - target.before_integral(*args) - target.compute_derivative(*args) - k4hs = [st.derivative for st in states] - - # y + (k1 + 2 * k2 + 2 * k3 + k4) / 6 - with bst.check_state_value_tree(): - # update states with derivatives - for st, y, k1h, k2h, k3h, k4h in zip(states, ys, k1hs, k2hs, k3hs, k4hs): - st.value = tree_map( - lambda y_, k1_, k2_, k3_, k4_: y_ + (k1_ + 2 * k2_ + 2 * k3_ + k4_) / 6 * dt, - y, k1h, k2h, k3h, k4h - ) - # update other states - target.after_integral(*args) +@set_module_as('dendritex') +def rk2_step(target: DendriteDynamics, t: bst.typing.ArrayLike, *args): + _general_rk_step(rk2_tableau, target, t, *args) + + +@set_module_as('dendritex') +def heun2_step(target: DendriteDynamics, t: bst.typing.ArrayLike, *args): + _general_rk_step(heun2_tableau, target, t, *args) + + +@set_module_as('dendritex') +def ralston2_step(target: DendriteDynamics, t: bst.typing.ArrayLike, *args): + _general_rk_step(ralston2_tableau, target, t, *args) + + +@set_module_as('dendritex') +def rk3_step(target: DendriteDynamics, t: bst.typing.ArrayLike, *args): + _general_rk_step(rk3_tableau, target, t, *args) + + +@set_module_as('dendritex') +def heun3_step(target: DendriteDynamics, t: bst.typing.ArrayLike, *args): + _general_rk_step(heun3_tableau, target, t, *args) + + +@set_module_as('dendritex') +def ssprk3_step(target: DendriteDynamics, t: bst.typing.ArrayLike, *args): + _general_rk_step(ssprk3_tableau, target, t, *args) + + +@set_module_as('dendritex') +def ralston3_step(target: DendriteDynamics, t: bst.typing.ArrayLike, *args): + _general_rk_step(ralston3_tableau, target, t, *args) + + +@set_module_as('dendritex') +def rk4_step(target: DendriteDynamics, t: bst.typing.ArrayLike, *args): + _general_rk_step(rk4_tableau, target, t, *args) + + +@set_module_as('dendritex') +def ralston4_step(target: DendriteDynamics, t: bst.typing.ArrayLike, *args): + _general_rk_step(ralston4_tableau, target, t, *args) diff --git a/dendritex/_misc.py b/dendritex/_misc.py index 80f2e32..32b800d 100644 --- a/dendritex/_misc.py +++ b/dendritex/_misc.py @@ -14,13 +14,9 @@ # ============================================================================== -import brainunit as bu - - def set_module_as(name: str): - def decorator(module): - module.__name__ = name - return module - - return decorator + def decorator(module): + module.__name__ = name + return module + return decorator diff --git a/dendritex/channels/calcium.py b/dendritex/channels/calcium.py index 4d2b347..3d9e751 100644 --- a/dendritex/channels/calcium.py +++ b/dendritex/channels/calcium.py @@ -12,1195 +12,1148 @@ import brainstate as bst import brainunit as u -from .._base import Channel, IonInfo, State4Integral -from ..ions import Calcium +from dendritex._base import Channel, IonInfo, State4Integral +from dendritex.ions import Calcium __all__ = [ - 'CalciumChannel', - - 'ICaN_IS2008', - 'ICaT_HM1992', - 'ICaT_HP1992', - 'ICaHT_HM1992', - 'ICaHT_Re1993', - 'ICaL_IS2008', - "ICav12_Ma2020", - "ICav13_Ma2020", - "ICav23_Ma2020", - "ICav31_Ma2020", - 'ICaGrc_Ma2020', + 'CalciumChannel', + + 'ICaN_IS2008', + 'ICaT_HM1992', + 'ICaT_HP1992', + 'ICaHT_HM1992', + 'ICaHT_Re1993', + 'ICaL_IS2008', + "ICav12_Ma2020", + "ICav13_Ma2020", + "ICav23_Ma2020", + "ICav31_Ma2020", + 'ICaGrc_Ma2020', ] class CalciumChannel(Channel): - """Base class for Calcium ion channels.""" + """Base class for Calcium ion channels.""" - __module__ = 'dendritex.channels' + __module__ = 'dendritex.channels' - root_type = Calcium + root_type = Calcium - def before_integral(self, V, Ca: IonInfo): - pass + def before_integral(self, V, Ca: IonInfo): + pass - def after_integral(self, V, Ca: IonInfo): - pass + def post_derivative(self, V, Ca: IonInfo): + pass - def compute_derivative(self, V, Ca: IonInfo): - pass + def compute_derivative(self, V, Ca: IonInfo): + pass - def current(self, V, Ca: IonInfo): - raise NotImplementedError + def current(self, V, Ca: IonInfo): + raise NotImplementedError - def init_state(self, V, Ca: IonInfo, batch_size: int = None): - pass + def init_state(self, V, Ca: IonInfo, batch_size: int = None): + pass - def reset_state(self, V, Ca: IonInfo, batch_size: int = None): - pass + def reset_state(self, V, Ca: IonInfo, batch_size: int = None): + pass class ICaN_IS2008(CalciumChannel): - r"""The calcium-activated non-selective cation channel model - proposed by (Inoue & Strowbridge, 2008) [2]_. - - The dynamics of the calcium-activated non-selective cation channel model [1]_ [2]_ is given by: - - .. math:: - - \begin{aligned} - I_{CAN} &=g_{\mathrm{max}} M\left([Ca^{2+}]_{i}\right) p \left(V-E\right)\\ - &M\left([Ca^{2+}]_{i}\right) ={[Ca^{2+}]_{i} \over 0.2+[Ca^{2+}]_{i}} \\ - &{dp \over dt} = {\phi \cdot (p_{\infty}-p)\over \tau_p} \\ - &p_{\infty} = {1.0 \over 1 + \exp(-(V + 43) / 5.2)} \\ - &\tau_{p} = {2.7 \over \exp(-(V + 55) / 15) + \exp((V + 55) / 15)} + 1.6 - \end{aligned} - - where :math:`\phi` is the temperature factor. - - Parameters - ---------- - g_max : float - The maximal conductance density (:math:`mS/cm^2`). - E : float - The reversal potential (mV). - phi : float - The temperature factor. - - References - ---------- - - .. [1] Destexhe, Alain, et al. "A model of spindle rhythmicity in the isolated - thalamic reticular nucleus." Journal of neurophysiology 72.2 (1994): 803-818. - .. [2] Inoue T, Strowbridge BW (2008) Transient activity induces a long-lasting - increase in the excitability of olfactory bulb interneurons. - J Neurophysiol 99: 187–199. - """ - __module__ = 'dendritex.channels' - - root_type = Calcium - - def __init__( - self, - size: bst.typing.Size, - E: Union[bst.typing.ArrayLike, Callable] = 10. * u.mV, - g_max: Union[bst.typing.ArrayLike, Callable] = 1. * (u.mS / u.cm ** 2), - phi: Union[bst.typing.ArrayLike, Callable] = 1., - name: Optional[str] = None, - mode: Optional[bst.mixin.Mode] = None, - ): - super().__init__( - size, - name=name, - mode=mode - ) - - # parameters - self.E = bst.init.param(E, self.varshape, allow_none=False) - self.g_max = bst.init.param(g_max, self.varshape, allow_none=False) - self.phi = bst.init.param(phi, self.varshape, allow_none=False) - - def init_state(self, V, Ca: IonInfo, batch_size: int = None): - self.p = State4Integral(bst.init.param(u.math.zeros, self.varshape, batch_size)) - - def reset_state(self, V, Ca, batch_size=None): - V = V.to_decimal(u.mV) - self.p.value = 1.0 / (1 + u.math.exp(-(V + 43.) / 5.2)) - - def compute_derivative(self, V, Ca): - V = V.to_decimal(u.mV) - phi_p = 1.0 / (1 + u.math.exp(-(V + 43.) / 5.2)) - p_inf = 2.7 / (u.math.exp(-(V + 55.) / 15.) + u.math.exp((V + 55.) / 15.)) + 1.6 - self.p.derivative = self.phi * (phi_p - self.p.value) / p_inf / u.ms - - def current(self, V, Ca): - M = Ca.C / (Ca.C + 0.2 * u.mM) - g = self.g_max * M * self.p.value - return g * (self.E - V) + r"""The calcium-activated non-selective cation channel model + proposed by (Inoue & Strowbridge, 2008) [2]_. + + The dynamics of the calcium-activated non-selective cation channel model [1]_ [2]_ is given by: + + .. math:: + + \begin{aligned} + I_{CAN} &=g_{\mathrm{max}} M\left([Ca^{2+}]_{i}\right) p \left(V-E\right)\\ + &M\left([Ca^{2+}]_{i}\right) ={[Ca^{2+}]_{i} \over 0.2+[Ca^{2+}]_{i}} \\ + &{dp \over dt} = {\phi \cdot (p_{\infty}-p)\over \tau_p} \\ + &p_{\infty} = {1.0 \over 1 + \exp(-(V + 43) / 5.2)} \\ + &\tau_{p} = {2.7 \over \exp(-(V + 55) / 15) + \exp((V + 55) / 15)} + 1.6 + \end{aligned} + + where :math:`\phi` is the temperature factor. + + Parameters + ---------- + g_max : float + The maximal conductance density (:math:`mS/cm^2`). + E : float + The reversal potential (mV). + phi : float + The temperature factor. + + References + ---------- + + .. [1] Destexhe, Alain, et al. "A model of spindle rhythmicity in the isolated + thalamic reticular nucleus." Journal of neurophysiology 72.2 (1994): 803-818. + .. [2] Inoue T, Strowbridge BW (2008) Transient activity induces a long-lasting + increase in the excitability of olfactory bulb interneurons. + J Neurophysiol 99: 187–199. + """ + __module__ = 'dendritex.channels' + + root_type = Calcium + + def __init__( + self, + size: bst.typing.Size, + E: Union[bst.typing.ArrayLike, Callable] = 10. * u.mV, + g_max: Union[bst.typing.ArrayLike, Callable] = 1. * (u.mS / u.cm ** 2), + phi: Union[bst.typing.ArrayLike, Callable] = 1., + name: Optional[str] = None, + ): + super().__init__(size=size, name=name, ) + # parameters + self.E = bst.init.param(E, self.varshape, allow_none=False) + self.g_max = bst.init.param(g_max, self.varshape, allow_none=False) + self.phi = bst.init.param(phi, self.varshape, allow_none=False) + + def init_state(self, V, Ca: IonInfo, batch_size: int = None): + self.p = State4Integral(bst.init.param(u.math.zeros, self.varshape, batch_size)) + + def reset_state(self, V, Ca, batch_size=None): + V = V.to_decimal(u.mV) + self.p.value = 1.0 / (1 + u.math.exp(-(V + 43.) / 5.2)) + + def compute_derivative(self, V, Ca): + V = V.to_decimal(u.mV) + phi_p = 1.0 / (1 + u.math.exp(-(V + 43.) / 5.2)) + p_inf = 2.7 / (u.math.exp(-(V + 55.) / 15.) + u.math.exp((V + 55.) / 15.)) + 1.6 + self.p.derivative = self.phi * (phi_p - self.p.value) / p_inf / u.ms + + def current(self, V, Ca): + M = Ca.C / (Ca.C + 0.2 * u.mM) + g = self.g_max * M * self.p.value + return g * (self.E - V) class _ICa_p2q_ss(CalciumChannel): - r"""The calcium current model of :math:`p^2q` current which described with steady-state format. - - The dynamics of this generalized calcium current model is given by: - - .. math:: - - I_{CaT} &= g_{max} p^2 q(V-E_{Ca}) \\ - {dp \over dt} &= {\phi_p \cdot (p_{\infty}-p)\over \tau_p} \\ - {dq \over dt} &= {\phi_q \cdot (q_{\infty}-q) \over \tau_q} \\ - - where :math:`\phi_p` and :math:`\phi_q` are temperature-dependent factors, - :math:`E_{Ca}` is the reversal potential of Calcium channel. - - Parameters - ---------- - size: int, tuple of int - The size of the simulation target. - name: str - The name of the object. - g_max : float, ArrayType, Callable, Initializer - The maximum conductance. - phi_p : float, ArrayType, Callable, Initializer - The temperature factor for channel :math:`p`. - phi_q : float, ArrayType, Callable, Initializer - The temperature factor for channel :math:`q`. - - """ - - def __init__( - self, - size: bst.typing.Size, - phi_p: Union[bst.typing.ArrayLike, Callable] = 3., - phi_q: Union[bst.typing.ArrayLike, Callable] = 3., - g_max: Union[bst.typing.ArrayLike, Callable] = 2. * (u.mS / u.cm ** 2), - mode: Optional[bst.mixin.Mode] = None, - name: Optional[str] = None - ): - super().__init__( - size, - name=name, - mode=mode, - ) - - # parameters - self.phi_p = bst.init.param(phi_p, self.varshape, allow_none=False) - self.phi_q = bst.init.param(phi_q, self.varshape, allow_none=False) - self.g_max = bst.init.param(g_max, self.varshape, allow_none=False) - - def init_state(self, V, Ca: IonInfo, batch_size: int = None): - self.p = State4Integral(bst.init.param(u.math.zeros, self.varshape, batch_size)) - self.q = State4Integral(bst.init.param(u.math.zeros, self.varshape, batch_size)) - - def reset_state(self, V, Ca, batch_size=None): - self.p.value = self.f_p_inf(V) - self.q.value = self.f_q_inf(V) - if batch_size is not None: - assert self.p.value.shape[0] == batch_size - assert self.q.value.shape[0] == batch_size - - def compute_derivative(self, V, Ca): - self.p.derivative = self.phi_p * (self.f_p_inf(V) - self.p.value) / self.f_p_tau(V) / u.ms - self.q.derivative = self.phi_q * (self.f_q_inf(V) - self.q.value) / self.f_q_tau(V) / u.ms - - def current(self, V, Ca): - return self.g_max * self.p.value * self.p.value * self.q.value * (Ca.E - V) - - def f_p_inf(self, V): - raise NotImplementedError - - def f_p_tau(self, V): - raise NotImplementedError - - def f_q_inf(self, V): - raise NotImplementedError - - def f_q_tau(self, V): - raise NotImplementedError + r"""The calcium current model of :math:`p^2q` current which described with steady-state format. + + The dynamics of this generalized calcium current model is given by: + + .. math:: + + I_{CaT} &= g_{max} p^2 q(V-E_{Ca}) \\ + {dp \over dt} &= {\phi_p \cdot (p_{\infty}-p)\over \tau_p} \\ + {dq \over dt} &= {\phi_q \cdot (q_{\infty}-q) \over \tau_q} \\ + + where :math:`\phi_p` and :math:`\phi_q` are temperature-dependent factors, + :math:`E_{Ca}` is the reversal potential of Calcium channel. + + Parameters + ---------- + size: int, tuple of int + The size of the simulation target. + name: str + The name of the object. + g_max : float, ArrayType, Callable, Initializer + The maximum conductance. + phi_p : float, ArrayType, Callable, Initializer + The temperature factor for channel :math:`p`. + phi_q : float, ArrayType, Callable, Initializer + The temperature factor for channel :math:`q`. + + """ + + def __init__( + self, + size: bst.typing.Size, + phi_p: Union[bst.typing.ArrayLike, Callable] = 3., + phi_q: Union[bst.typing.ArrayLike, Callable] = 3., + g_max: Union[bst.typing.ArrayLike, Callable] = 2. * (u.mS / u.cm ** 2), + name: Optional[str] = None + ): + super().__init__(size=size, name=name, ) + + # parameters + self.phi_p = bst.init.param(phi_p, self.varshape, allow_none=False) + self.phi_q = bst.init.param(phi_q, self.varshape, allow_none=False) + self.g_max = bst.init.param(g_max, self.varshape, allow_none=False) + + def init_state(self, V, Ca: IonInfo, batch_size: int = None): + self.p = State4Integral(bst.init.param(u.math.zeros, self.varshape, batch_size)) + self.q = State4Integral(bst.init.param(u.math.zeros, self.varshape, batch_size)) + + def reset_state(self, V, Ca, batch_size=None): + self.p.value = self.f_p_inf(V) + self.q.value = self.f_q_inf(V) + if batch_size is not None: + assert self.p.value.shape[0] == batch_size + assert self.q.value.shape[0] == batch_size + + def compute_derivative(self, V, Ca): + self.p.derivative = self.phi_p * (self.f_p_inf(V) - self.p.value) / self.f_p_tau(V) / u.ms + self.q.derivative = self.phi_q * (self.f_q_inf(V) - self.q.value) / self.f_q_tau(V) / u.ms + + def current(self, V, Ca): + return self.g_max * self.p.value * self.p.value * self.q.value * (Ca.E - V) + + def f_p_inf(self, V): + raise NotImplementedError + + def f_p_tau(self, V): + raise NotImplementedError + + def f_q_inf(self, V): + raise NotImplementedError + + def f_q_tau(self, V): + raise NotImplementedError class _ICa_p2q_markov(CalciumChannel): - r"""The calcium current model of :math:`p^2q` current which described with first-order Markov chain. - - The dynamics of this generalized calcium current model is given by: - - .. math:: - - I_{CaT} &= g_{max} p^2 q(V-E_{Ca}) \\ - {dp \over dt} &= \phi_p (\alpha_p(V)(1-p) - \beta_p(V)p) \\ - {dq \over dt} &= \phi_q (\alpha_q(V)(1-q) - \beta_q(V)q) \\ - - where :math:`\phi_p` and :math:`\phi_q` are temperature-dependent factors, - :math:`E_{Ca}` is the reversal potential of Calcium channel. - - Parameters - ---------- - size: int, tuple of int - The size of the simulation target. - name: str - The name of the object. - g_max : float, ArrayType, Callable, Initializer - The maximum conductance. - phi_p : float, ArrayType, Callable, Initializer - The temperature factor for channel :math:`p`. - phi_q : float, ArrayType, Callable, Initializer - The temperature factor for channel :math:`q`. - - """ - - def __init__( - self, - size: bst.typing.Size, - phi_p: Union[bst.typing.ArrayLike, Callable] = 3., - phi_q: Union[bst.typing.ArrayLike, Callable] = 3., - g_max: Union[bst.typing.ArrayLike, Callable] = 2. * (u.mS / u.cm ** 2), - name: Optional[str] = None, - mode: Optional[bst.mixin.Mode] = None, - ): - super().__init__( - size, - name=name, - mode=mode - ) - - # parameters - self.phi_p = bst.init.param(phi_p, self.varshape, allow_none=False) - self.phi_q = bst.init.param(phi_q, self.varshape, allow_none=False) - self.g_max = bst.init.param(g_max, self.varshape, allow_none=False) - - def init_state(self, V, Ca: IonInfo, batch_size: int = None): - self.p = State4Integral(bst.init.param(u.math.zeros, self.varshape, batch_size)) - self.q = State4Integral(bst.init.param(u.math.zeros, self.varshape, batch_size)) - - def reset_state(self, V, Ca, batch_size=None): - alpha, beta = self.f_p_alpha(V), self.f_p_beta(V) - self.p.value = alpha / (alpha + beta) - alpha, beta = self.f_q_alpha(V), self.f_q_beta(V) - self.q.value = alpha / (alpha + beta) - - def compute_derivative(self, V, Ca): - p = self.p.value - q = self.q.value - self.p.derivative = self.phi_p * (self.f_p_alpha(V) * (1 - p) - self.f_p_beta(V) * p) / u.ms - self.q.derivative = self.phi_q * (self.f_q_alpha(V) * (1 - q) - self.f_q_beta(V) * q) / u.ms - - def current(self, V, Ca): - return self.g_max * self.p.value * self.p.value * self.q.value * (Ca.E - V) - - def f_p_alpha(self, V): - raise NotImplementedError - - def f_p_beta(self, V): - raise NotImplementedError - - def f_q_alpha(self, V): - raise NotImplementedError - - def f_q_beta(self, V): - raise NotImplementedError + r"""The calcium current model of :math:`p^2q` current which described with first-order Markov chain. + + The dynamics of this generalized calcium current model is given by: + + .. math:: + + I_{CaT} &= g_{max} p^2 q(V-E_{Ca}) \\ + {dp \over dt} &= \phi_p (\alpha_p(V)(1-p) - \beta_p(V)p) \\ + {dq \over dt} &= \phi_q (\alpha_q(V)(1-q) - \beta_q(V)q) \\ + + where :math:`\phi_p` and :math:`\phi_q` are temperature-dependent factors, + :math:`E_{Ca}` is the reversal potential of Calcium channel. + + Parameters + ---------- + size: int, tuple of int + The size of the simulation target. + name: str + The name of the object. + g_max : float, ArrayType, Callable, Initializer + The maximum conductance. + phi_p : float, ArrayType, Callable, Initializer + The temperature factor for channel :math:`p`. + phi_q : float, ArrayType, Callable, Initializer + The temperature factor for channel :math:`q`. + + """ + + def __init__( + self, + size: bst.typing.Size, + phi_p: Union[bst.typing.ArrayLike, Callable] = 3., + phi_q: Union[bst.typing.ArrayLike, Callable] = 3., + g_max: Union[bst.typing.ArrayLike, Callable] = 2. * (u.mS / u.cm ** 2), + name: Optional[str] = None, + ): + super().__init__(size=size, name=name, ) + + # parameters + self.phi_p = bst.init.param(phi_p, self.varshape, allow_none=False) + self.phi_q = bst.init.param(phi_q, self.varshape, allow_none=False) + self.g_max = bst.init.param(g_max, self.varshape, allow_none=False) + + def init_state(self, V, Ca: IonInfo, batch_size: int = None): + self.p = State4Integral(bst.init.param(u.math.zeros, self.varshape, batch_size)) + self.q = State4Integral(bst.init.param(u.math.zeros, self.varshape, batch_size)) + + def reset_state(self, V, Ca, batch_size=None): + alpha, beta = self.f_p_alpha(V), self.f_p_beta(V) + self.p.value = alpha / (alpha + beta) + alpha, beta = self.f_q_alpha(V), self.f_q_beta(V) + self.q.value = alpha / (alpha + beta) + + def compute_derivative(self, V, Ca): + p = self.p.value + q = self.q.value + self.p.derivative = self.phi_p * (self.f_p_alpha(V) * (1 - p) - self.f_p_beta(V) * p) / u.ms + self.q.derivative = self.phi_q * (self.f_q_alpha(V) * (1 - q) - self.f_q_beta(V) * q) / u.ms + + def current(self, V, Ca): + return self.g_max * self.p.value * self.p.value * self.q.value * (Ca.E - V) + + def f_p_alpha(self, V): + raise NotImplementedError + + def f_p_beta(self, V): + raise NotImplementedError + + def f_q_alpha(self, V): + raise NotImplementedError + + def f_q_beta(self, V): + raise NotImplementedError class ICaT_HM1992(_ICa_p2q_ss): - r""" - The low-threshold T-type calcium current model proposed by (Huguenard & McCormick, 1992) [1]_. - - The dynamics of the low-threshold T-type calcium current model [1]_ is given by: - - .. math:: - - I_{CaT} &= g_{max} p^2 q(V-E_{Ca}) \\ - {dp \over dt} &= {\phi_p \cdot (p_{\infty}-p)\over \tau_p} \\ - &p_{\infty} = {1 \over 1+\exp [-(V+59-V_{sh}) / 6.2]} \\ - &\tau_{p} = 0.612 + {1 \over \exp [-(V+132.-V_{sh}) / 16.7]+\exp [(V+16.8-V_{sh}) / 18.2]} \\ - {dq \over dt} &= {\phi_q \cdot (q_{\infty}-q) \over \tau_q} \\ - &q_{\infty} = {1 \over 1+\exp [(V+83-V_{sh}) / 4]} \\ - & \begin{array}{l} \tau_{q} = \exp \left(\frac{V+467-V_{sh}}{66.6}\right) \quad V< (-80 +V_{sh})\, mV \\ - \tau_{q} = \exp \left(\frac{V+22-V_{sh}}{-10.5}\right)+28 \quad V \geq (-80 + V_{sh})\, mV \end{array} - - where :math:`\phi_p = 3.55^{\frac{T-24}{10}}` and :math:`\phi_q = 3^{\frac{T-24}{10}}` - are temperature-dependent factors (:math:`T` is the temperature in Celsius), - :math:`E_{Ca}` is the reversal potential of Calcium channel. - - Parameters - ---------- - T : float, ArrayType - The temperature. - T_base_p : float, ArrayType - The brainpy_object temperature factor of :math:`p` channel. - T_base_q : float, ArrayType - The brainpy_object temperature factor of :math:`q` channel. - g_max : float, ArrayType, Callable, Initializer - The maximum conductance. - V_sh : float, ArrayType, Callable, Initializer - The membrane potential shift. - phi_p : optional, float, ArrayType, Callable, Initializer - The temperature factor for channel :math:`p`. - phi_q : optional, float, ArrayType, Callable, Initializer - The temperature factor for channel :math:`q`. - - References - ---------- - - .. [1] Huguenard JR, McCormick DA (1992) Simulation of the currents involved in - rhythmic oscillations in thalamic relay neurons. J Neurophysiol 68:1373–1383. - - See Also - -------- - ICa_p2q_form - """ - __module__ = 'dendritex.channels' - - def __init__( - self, - size: bst.typing.Size, - T: bst.typing.ArrayLike = 36., - T_base_p: bst.typing.ArrayLike = 3.55, - T_base_q: bst.typing.ArrayLike = 3., - g_max: Union[bst.typing.ArrayLike, Callable] = 2. * (u.mS / u.cm ** 2), - V_sh: Union[bst.typing.ArrayLike, Callable] = -3. * u.mV, - phi_p: Union[bst.typing.ArrayLike, Callable] = None, - phi_q: Union[bst.typing.ArrayLike, Callable] = None, - name: Optional[str] = None, - mode: Optional[bst.mixin.Mode] = None, - ): - phi_p = T_base_p ** ((T - 24) / 10) if phi_p is None else phi_p - phi_q = T_base_q ** ((T - 24) / 10) if phi_q is None else phi_q - super().__init__( - size, - name=name, - g_max=g_max, - phi_p=phi_p, - phi_q=phi_q, - mode=mode - ) - - # parameters - self.T = bst.init.param(T, self.varshape, allow_none=False) - self.T_base_p = bst.init.param(T_base_p, self.varshape, allow_none=False) - self.T_base_q = bst.init.param(T_base_q, self.varshape, allow_none=False) - self.V_sh = bst.init.param(V_sh, self.varshape, allow_none=False) - - def f_p_inf(self, V): - V = (V - self.V_sh).to_decimal(u.mV) - return 1. / (1 + u.math.exp(-(V + 59.) / 6.2)) - - def f_p_tau(self, V): - V = (V - self.V_sh).to_decimal(u.mV) - return 1. / (u.math.exp(-(V + 132.) / 16.7) + - u.math.exp((V + 16.8) / 18.2)) + 0.612 - - def f_q_inf(self, V): - V = (V - self.V_sh).to_decimal(u.mV) - return 1. / (1. + u.math.exp((V + 83.) / 4.0)) - - def f_q_tau(self, V): - V = (V - self.V_sh).to_decimal(u.mV) - return u.math.where(V >= -80., - u.math.exp(-(V + 22.) / 10.5) + 28., - u.math.exp((V + 467.) / 66.6)) + r""" + The low-threshold T-type calcium current model proposed by (Huguenard & McCormick, 1992) [1]_. + + The dynamics of the low-threshold T-type calcium current model [1]_ is given by: + + .. math:: + + I_{CaT} &= g_{max} p^2 q(V-E_{Ca}) \\ + {dp \over dt} &= {\phi_p \cdot (p_{\infty}-p)\over \tau_p} \\ + &p_{\infty} = {1 \over 1+\exp [-(V+59-V_{sh}) / 6.2]} \\ + &\tau_{p} = 0.612 + {1 \over \exp [-(V+132.-V_{sh}) / 16.7]+\exp [(V+16.8-V_{sh}) / 18.2]} \\ + {dq \over dt} &= {\phi_q \cdot (q_{\infty}-q) \over \tau_q} \\ + &q_{\infty} = {1 \over 1+\exp [(V+83-V_{sh}) / 4]} \\ + & \begin{array}{l} \tau_{q} = \exp \left(\frac{V+467-V_{sh}}{66.6}\right) \quad V< (-80 +V_{sh})\, mV \\ + \tau_{q} = \exp \left(\frac{V+22-V_{sh}}{-10.5}\right)+28 \quad V \geq (-80 + V_{sh})\, mV \end{array} + + where :math:`\phi_p = 3.55^{\frac{T-24}{10}}` and :math:`\phi_q = 3^{\frac{T-24}{10}}` + are temperature-dependent factors (:math:`T` is the temperature in Celsius), + :math:`E_{Ca}` is the reversal potential of Calcium channel. + + Parameters + ---------- + T : float, ArrayType + The temperature. + T_base_p : float, ArrayType + The brainpy_object temperature factor of :math:`p` channel. + T_base_q : float, ArrayType + The brainpy_object temperature factor of :math:`q` channel. + g_max : float, ArrayType, Callable, Initializer + The maximum conductance. + V_sh : float, ArrayType, Callable, Initializer + The membrane potential shift. + phi_p : optional, float, ArrayType, Callable, Initializer + The temperature factor for channel :math:`p`. + phi_q : optional, float, ArrayType, Callable, Initializer + The temperature factor for channel :math:`q`. + + References + ---------- + + .. [1] Huguenard JR, McCormick DA (1992) Simulation of the currents involved in + rhythmic oscillations in thalamic relay neurons. J Neurophysiol 68:1373–1383. + + See Also + -------- + ICa_p2q_form + """ + __module__ = 'dendritex.channels' + + def __init__( + self, + size: bst.typing.Size, + T: bst.typing.ArrayLike = 36., + T_base_p: bst.typing.ArrayLike = 3.55, + T_base_q: bst.typing.ArrayLike = 3., + g_max: Union[bst.typing.ArrayLike, Callable] = 2. * (u.mS / u.cm ** 2), + V_sh: Union[bst.typing.ArrayLike, Callable] = -3. * u.mV, + phi_p: Union[bst.typing.ArrayLike, Callable] = None, + phi_q: Union[bst.typing.ArrayLike, Callable] = None, + name: Optional[str] = None, + ): + phi_p = T_base_p ** ((T - 24) / 10) if phi_p is None else phi_p + phi_q = T_base_q ** ((T - 24) / 10) if phi_q is None else phi_q + super().__init__( + size, + name=name, + g_max=g_max, + phi_p=phi_p, + phi_q=phi_q, + ) + + # parameters + self.T = bst.init.param(T, self.varshape, allow_none=False) + self.T_base_p = bst.init.param(T_base_p, self.varshape, allow_none=False) + self.T_base_q = bst.init.param(T_base_q, self.varshape, allow_none=False) + self.V_sh = bst.init.param(V_sh, self.varshape, allow_none=False) + + def f_p_inf(self, V): + V = (V - self.V_sh).to_decimal(u.mV) + return 1. / (1 + u.math.exp(-(V + 59.) / 6.2)) + + def f_p_tau(self, V): + V = (V - self.V_sh).to_decimal(u.mV) + return 1. / (u.math.exp(-(V + 132.) / 16.7) + + u.math.exp((V + 16.8) / 18.2)) + 0.612 + + def f_q_inf(self, V): + V = (V - self.V_sh).to_decimal(u.mV) + return 1. / (1. + u.math.exp((V + 83.) / 4.0)) + + def f_q_tau(self, V): + V = (V - self.V_sh).to_decimal(u.mV) + return u.math.where(V >= -80., + u.math.exp(-(V + 22.) / 10.5) + 28., + u.math.exp((V + 467.) / 66.6)) class ICaT_HP1992(_ICa_p2q_ss): - r"""The low-threshold T-type calcium current model for thalamic - reticular nucleus proposed by (Huguenard & Prince, 1992) [1]_. - - The dynamics of the low-threshold T-type calcium current model in thalamic - reticular nucleus neurons [1]_ is given by: - - .. math:: - - I_{CaT} &= g_{max} p^2 q(V-E_{Ca}) \\ - {dp \over dt} &= {\phi_p \cdot (p_{\infty}-p)\over \tau_p} \\ - &p_{\infty} = {1 \over 1+\exp [-(V+52-V_{sh}) / 7.4]} \\ - &\tau_{p} = 3+{1 \over \exp [(V+27-V_{sh}) / 10]+\exp [-(V+102-V_{sh}) / 15]} \\ - {dq \over dt} &= {\phi_q \cdot (q_{\infty}-q) \over \tau_q} \\ - &q_{\infty} = {1 \over 1+\exp [(V+80-V_{sh}) / 5]} \\ - & \tau_q = 85+ {1 \over \exp [(V+48-V_{sh}) / 4]+\exp [-(V+407-V_{sh}) / 50]} - - where :math:`\phi_p = 5^{\frac{T-24}{10}}` and :math:`\phi_q = 3^{\frac{T-24}{10}}` - are temperature-dependent factors (:math:`T` is the temperature in Celsius), - :math:`E_{Ca}` is the reversal potential of Calcium channel. - - Parameters - ---------- - T : float, ArrayType - The temperature. - T_base_p : float, ArrayType - The brainpy_object temperature factor of :math:`p` channel. - T_base_q : float, ArrayType - The brainpy_object temperature factor of :math:`q` channel. - g_max : float, ArrayType, Callable, Initializer - The maximum conductance. - V_sh : float, ArrayType, Callable, Initializer - The membrane potential shift. - phi_p : optional, float, ArrayType, Callable, Initializer - The temperature factor for channel :math:`p`. - phi_q : optional, float, ArrayType, Callable, Initializer - The temperature factor for channel :math:`q`. - - References - ---------- - - .. [1] Huguenard JR, Prince DA (1992) A novel T-type current underlies - prolonged Ca2+- dependent burst firing in GABAergic neurons of rat - thalamic reticular nucleus. J Neurosci 12: 3804–3817. - - See Also - -------- - ICa_p2q_form - """ - __module__ = 'dendritex.channels' - - def __init__( - self, - size: bst.typing.Size, - T: bst.typing.ArrayLike = 36., - T_base_p: bst.typing.ArrayLike = 5., - T_base_q: bst.typing.ArrayLike = 3., - g_max: Union[bst.typing.ArrayLike, Callable] = 1.75 * (u.mS / u.cm ** 2), - V_sh: Union[bst.typing.ArrayLike, Callable] = -3. * u.mV, - phi_p: Union[bst.typing.ArrayLike, Callable] = None, - phi_q: Union[bst.typing.ArrayLike, Callable] = None, - name: Optional[str] = None, - mode: Optional[bst.mixin.Mode] = None, - ): - phi_p = T_base_p ** ((T - 24) / 10) if phi_p is None else phi_p - phi_q = T_base_q ** ((T - 24) / 10) if phi_q is None else phi_q - super().__init__( - size, - name=name, - g_max=g_max, - phi_p=phi_p, - phi_q=phi_q, - mode=mode - ) - - # parameters - self.T = bst.init.param(T, self.varshape, allow_none=False) - self.T_base_p = bst.init.param(T_base_p, self.varshape, allow_none=False) - self.T_base_q = bst.init.param(T_base_q, self.varshape, allow_none=False) - self.V_sh = bst.init.param(V_sh, self.varshape, allow_none=False) - - def f_p_inf(self, V): - V = (V - self.V_sh).to_decimal(u.mV) - return 1. / (1. + u.math.exp(-(V + 52.) / 7.4)) - - def f_p_tau(self, V): - V = (V - self.V_sh).to_decimal(u.mV) - return 3. + 1. / (u.math.exp((V + 27.) / 10.) + - u.math.exp(-(V + 102.) / 15.)) - - def f_q_inf(self, V): - V = (V - self.V_sh).to_decimal(u.mV) - return 1. / (1. + u.math.exp((V + 80.) / 5.)) - - def f_q_tau(self, V): - V = (V - self.V_sh).to_decimal(u.mV) - return 85. + 1. / (u.math.exp((V + 48.) / 4.) + - u.math.exp(-(V + 407.) / 50.)) + r"""The low-threshold T-type calcium current model for thalamic + reticular nucleus proposed by (Huguenard & Prince, 1992) [1]_. + + The dynamics of the low-threshold T-type calcium current model in thalamic + reticular nucleus neurons [1]_ is given by: + + .. math:: + + I_{CaT} &= g_{max} p^2 q(V-E_{Ca}) \\ + {dp \over dt} &= {\phi_p \cdot (p_{\infty}-p)\over \tau_p} \\ + &p_{\infty} = {1 \over 1+\exp [-(V+52-V_{sh}) / 7.4]} \\ + &\tau_{p} = 3+{1 \over \exp [(V+27-V_{sh}) / 10]+\exp [-(V+102-V_{sh}) / 15]} \\ + {dq \over dt} &= {\phi_q \cdot (q_{\infty}-q) \over \tau_q} \\ + &q_{\infty} = {1 \over 1+\exp [(V+80-V_{sh}) / 5]} \\ + & \tau_q = 85+ {1 \over \exp [(V+48-V_{sh}) / 4]+\exp [-(V+407-V_{sh}) / 50]} + + where :math:`\phi_p = 5^{\frac{T-24}{10}}` and :math:`\phi_q = 3^{\frac{T-24}{10}}` + are temperature-dependent factors (:math:`T` is the temperature in Celsius), + :math:`E_{Ca}` is the reversal potential of Calcium channel. + + Parameters + ---------- + T : float, ArrayType + The temperature. + T_base_p : float, ArrayType + The brainpy_object temperature factor of :math:`p` channel. + T_base_q : float, ArrayType + The brainpy_object temperature factor of :math:`q` channel. + g_max : float, ArrayType, Callable, Initializer + The maximum conductance. + V_sh : float, ArrayType, Callable, Initializer + The membrane potential shift. + phi_p : optional, float, ArrayType, Callable, Initializer + The temperature factor for channel :math:`p`. + phi_q : optional, float, ArrayType, Callable, Initializer + The temperature factor for channel :math:`q`. + + References + ---------- + + .. [1] Huguenard JR, Prince DA (1992) A novel T-type current underlies + prolonged Ca2+- dependent burst firing in GABAergic neurons of rat + thalamic reticular nucleus. J Neurosci 12: 3804–3817. + + See Also + -------- + ICa_p2q_form + """ + __module__ = 'dendritex.channels' + + def __init__( + self, + size: bst.typing.Size, + T: bst.typing.ArrayLike = 36., + T_base_p: bst.typing.ArrayLike = 5., + T_base_q: bst.typing.ArrayLike = 3., + g_max: Union[bst.typing.ArrayLike, Callable] = 1.75 * (u.mS / u.cm ** 2), + V_sh: Union[bst.typing.ArrayLike, Callable] = -3. * u.mV, + phi_p: Union[bst.typing.ArrayLike, Callable] = None, + phi_q: Union[bst.typing.ArrayLike, Callable] = None, + name: Optional[str] = None, + ): + phi_p = T_base_p ** ((T - 24) / 10) if phi_p is None else phi_p + phi_q = T_base_q ** ((T - 24) / 10) if phi_q is None else phi_q + super().__init__( + size, + name=name, + g_max=g_max, + phi_p=phi_p, + phi_q=phi_q, + ) + + # parameters + self.T = bst.init.param(T, self.varshape, allow_none=False) + self.T_base_p = bst.init.param(T_base_p, self.varshape, allow_none=False) + self.T_base_q = bst.init.param(T_base_q, self.varshape, allow_none=False) + self.V_sh = bst.init.param(V_sh, self.varshape, allow_none=False) + + def f_p_inf(self, V): + V = (V - self.V_sh).to_decimal(u.mV) + return 1. / (1. + u.math.exp(-(V + 52.) / 7.4)) + + def f_p_tau(self, V): + V = (V - self.V_sh).to_decimal(u.mV) + return 3. + 1. / (u.math.exp((V + 27.) / 10.) + + u.math.exp(-(V + 102.) / 15.)) + + def f_q_inf(self, V): + V = (V - self.V_sh).to_decimal(u.mV) + return 1. / (1. + u.math.exp((V + 80.) / 5.)) + + def f_q_tau(self, V): + V = (V - self.V_sh).to_decimal(u.mV) + return 85. + 1. / (u.math.exp((V + 48.) / 4.) + + u.math.exp(-(V + 407.) / 50.)) class ICaHT_HM1992(_ICa_p2q_ss): - r"""The high-threshold T-type calcium current model proposed by (Huguenard & McCormick, 1992) [1]_. - - The high-threshold T-type calcium current model is adopted from [1]_. - Its dynamics is given by - - .. math:: - - \begin{aligned} - I_{\mathrm{Ca/HT}} &= g_{\mathrm{max}} p^2 q (V-E_{Ca}) - \\ - {dp \over dt} &= {\phi_{p} \cdot (p_{\infty} - p) \over \tau_{p}} \\ - &\tau_{p} =\frac{1}{\exp \left(\frac{V+132-V_{sh}}{-16.7}\right)+\exp \left(\frac{V+16.8-V_{sh}}{18.2}\right)}+0.612 \\ - & p_{\infty} = {1 \over 1+exp[-(V+59-V_{sh}) / 6.2]} - \\ - {dq \over dt} &= {\phi_{q} \cdot (q_{\infty} - h) \over \tau_{q}} \\ - & \begin{array}{l} \tau_q = \exp \left(\frac{V+467-V_{sh}}{66.6}\right) \quad V< (-80 +V_{sh})\, mV \\ - \tau_q = \exp \left(\frac{V+22-V_{sh}}{-10.5}\right)+28 \quad V \geq (-80 + V_{sh})\, mV \end{array} \\ - &q_{\infty} = {1 \over 1+exp[(V+83 -V_{shift})/4]} - \end{aligned} - - where :math:`phi_p = 3.55^{\frac{T-24}{10}}` and :math:`phi_q = 3^{\frac{T-24}{10}}` - are temperature-dependent factors (:math:`T` is the temperature in Celsius), - :math:`E_{Ca}` is the reversal potential of Calcium channel. - - Parameters - ---------- - T : float, ArrayType - The temperature. - T_base_p : float, ArrayType - The brainpy_object temperature factor of :math:`p` channel. - T_base_q : float, ArrayType - The brainpy_object temperature factor of :math:`q` channel. - g_max : bst.typing.ArrayLike, Callable - The maximum conductance. - V_sh : bst.typing.ArrayLike, Callable - The membrane potential shift. - - References - ---------- - .. [1] Huguenard JR, McCormick DA (1992) Simulation of the currents involved in - rhythmic oscillations in thalamic relay neurons. J Neurophysiol 68:1373–1383. - - See Also - -------- - ICa_p2q_form - """ - __module__ = 'dendritex.channels' - - def __init__( - self, - size: bst.typing.Size, - T: bst.typing.ArrayLike = 36., - T_base_p: bst.typing.ArrayLike = 3.55, - T_base_q: bst.typing.ArrayLike = 3., - g_max: Union[bst.typing.ArrayLike, Callable] = 2. * (u.mS / u.cm ** 2), - V_sh: Union[bst.typing.ArrayLike, Callable] = 25. * u.mV, - name: Optional[str] = None, - mode: Optional[bst.mixin.Mode] = None, - ): - super().__init__( - size, - name=name, - g_max=g_max, - phi_p=T_base_p ** ((T - 24) / 10), - phi_q=T_base_q ** ((T - 24) / 10), - mode=mode - ) - - # parameters - self.T = bst.init.param(T, self.varshape, allow_none=False) - self.T_base_p = bst.init.param(T_base_p, self.varshape, allow_none=False) - self.T_base_q = bst.init.param(T_base_q, self.varshape, allow_none=False) - self.V_sh = bst.init.param(V_sh, self.varshape, allow_none=False) - - def f_p_inf(self, V): - V = (V - self.V_sh).to_decimal(u.mV) - return 1. / (1. + u.math.exp(-(V + 59.) / 6.2)) - - def f_p_tau(self, V): - V = (V - self.V_sh).to_decimal(u.mV) - return 1. / (u.math.exp(-(V + 132.) / 16.7) + - u.math.exp((V + 16.8) / 18.2)) + 0.612 - - def f_q_inf(self, V): - V = (V - self.V_sh).to_decimal(u.mV) - return 1. / (1. + u.math.exp((V + 83.) / 4.)) - - def f_q_tau(self, V): - V = (V - self.V_sh).to_decimal(u.mV) - return u.math.where(V >= -80., - u.math.exp(-(V + 22.) / 10.5) + 28., - u.math.exp((V + 467.) / 66.6)) + r"""The high-threshold T-type calcium current model proposed by (Huguenard & McCormick, 1992) [1]_. + + The high-threshold T-type calcium current model is adopted from [1]_. + Its dynamics is given by + + .. math:: + + \begin{aligned} + I_{\mathrm{Ca/HT}} &= g_{\mathrm{max}} p^2 q (V-E_{Ca}) + \\ + {dp \over dt} &= {\phi_{p} \cdot (p_{\infty} - p) \over \tau_{p}} \\ + &\tau_{p} =\frac{1}{\exp \left(\frac{V+132-V_{sh}}{-16.7}\right)+\exp \left(\frac{V+16.8-V_{sh}}{18.2}\right)}+0.612 \\ + & p_{\infty} = {1 \over 1+exp[-(V+59-V_{sh}) / 6.2]} + \\ + {dq \over dt} &= {\phi_{q} \cdot (q_{\infty} - h) \over \tau_{q}} \\ + & \begin{array}{l} \tau_q = \exp \left(\frac{V+467-V_{sh}}{66.6}\right) \quad V< (-80 +V_{sh})\, mV \\ + \tau_q = \exp \left(\frac{V+22-V_{sh}}{-10.5}\right)+28 \quad V \geq (-80 + V_{sh})\, mV \end{array} \\ + &q_{\infty} = {1 \over 1+exp[(V+83 -V_{shift})/4]} + \end{aligned} + + where :math:`phi_p = 3.55^{\frac{T-24}{10}}` and :math:`phi_q = 3^{\frac{T-24}{10}}` + are temperature-dependent factors (:math:`T` is the temperature in Celsius), + :math:`E_{Ca}` is the reversal potential of Calcium channel. + + Parameters + ---------- + T : float, ArrayType + The temperature. + T_base_p : float, ArrayType + The brainpy_object temperature factor of :math:`p` channel. + T_base_q : float, ArrayType + The brainpy_object temperature factor of :math:`q` channel. + g_max : bst.typing.ArrayLike, Callable + The maximum conductance. + V_sh : bst.typing.ArrayLike, Callable + The membrane potential shift. + + References + ---------- + .. [1] Huguenard JR, McCormick DA (1992) Simulation of the currents involved in + rhythmic oscillations in thalamic relay neurons. J Neurophysiol 68:1373–1383. + + See Also + -------- + ICa_p2q_form + """ + __module__ = 'dendritex.channels' + + def __init__( + self, + size: bst.typing.Size, + T: bst.typing.ArrayLike = 36., + T_base_p: bst.typing.ArrayLike = 3.55, + T_base_q: bst.typing.ArrayLike = 3., + g_max: Union[bst.typing.ArrayLike, Callable] = 2. * (u.mS / u.cm ** 2), + V_sh: Union[bst.typing.ArrayLike, Callable] = 25. * u.mV, + name: Optional[str] = None, + ): + super().__init__( + size, + name=name, + g_max=g_max, + phi_p=T_base_p ** ((T - 24) / 10), + phi_q=T_base_q ** ((T - 24) / 10), + ) + + # parameters + self.T = bst.init.param(T, self.varshape, allow_none=False) + self.T_base_p = bst.init.param(T_base_p, self.varshape, allow_none=False) + self.T_base_q = bst.init.param(T_base_q, self.varshape, allow_none=False) + self.V_sh = bst.init.param(V_sh, self.varshape, allow_none=False) + + def f_p_inf(self, V): + V = (V - self.V_sh).to_decimal(u.mV) + return 1. / (1. + u.math.exp(-(V + 59.) / 6.2)) + + def f_p_tau(self, V): + V = (V - self.V_sh).to_decimal(u.mV) + return 1. / (u.math.exp(-(V + 132.) / 16.7) + + u.math.exp((V + 16.8) / 18.2)) + 0.612 + + def f_q_inf(self, V): + V = (V - self.V_sh).to_decimal(u.mV) + return 1. / (1. + u.math.exp((V + 83.) / 4.)) + + def f_q_tau(self, V): + V = (V - self.V_sh).to_decimal(u.mV) + return u.math.where(V >= -80., + u.math.exp(-(V + 22.) / 10.5) + 28., + u.math.exp((V + 467.) / 66.6)) class ICaHT_Re1993(_ICa_p2q_markov): - r"""The high-threshold T-type calcium current model proposed by (Reuveni, et al., 1993) [1]_. - - HVA Calcium current was described for neocortical neurons by Sayer et al. (1990). - Its dynamics is given by (the rate functions are measured under 36 Celsius): - - .. math:: - - \begin{aligned} - I_{L} &=\bar{g}_{L} q^{2} r\left(V-E_{\mathrm{Ca}}\right) \\ - \frac{\mathrm{d} q}{\mathrm{~d} t} &= \phi_p (\alpha_{q}(V)(1-q)-\beta_{q}(V) q) \\ - \frac{\mathrm{d} r}{\mathrm{~d} t} &= \phi_q (\alpha_{r}(V)(1-r)-\beta_{r}(V) r) \\ - \alpha_{q} &=\frac{0.055(-27-V+V_{sh})}{\exp [(-27-V+V_{sh}) / 3.8]-1} \\ - \beta_{q} &=0.94 \exp [(-75-V+V_{sh}) / 17] \\ - \alpha_{r} &=0.000457 \exp [(-13-V+V_{sh}) / 50] \\ - \beta_{r} &=\frac{0.0065}{\exp [(-15-V+V_{sh}) / 28]+1}, - \end{aligned} - - Parameters - ---------- - size: int, tuple of int - The size of the simulation target. - name: str - The name of the object. - g_max : float, ArrayType, Callable, Initializer - The maximum conductance. - V_sh : float, ArrayType, Callable, Initializer - The membrane potential shift. - T : float, ArrayType - The temperature. - T_base_p : float, ArrayType - The brainpy_object temperature factor of :math:`p` channel. - T_base_q : float, ArrayType - The brainpy_object temperature factor of :math:`q` channel. - phi_p : optional, float, ArrayType, Callable, Initializer - The temperature factor for channel :math:`p`. - If `None`, :math:`\phi_p = \mathrm{T_base_p}^{\frac{T-23}{10}}`. - phi_q : optional, float, ArrayType, Callable, Initializer - The temperature factor for channel :math:`q`. - If `None`, :math:`\phi_q = \mathrm{T_base_q}^{\frac{T-23}{10}}`. - - References - ---------- - .. [1] Reuveni, I., et al. "Stepwise repolarization from Ca2+ plateaus - in neocortical pyramidal cells: evidence for nonhomogeneous - distribution of HVA Ca2+ channels in dendrites." Journal of - Neuroscience 13.11 (1993): 4609-4621. - - """ - __module__ = 'dendritex.channels' - - def __init__( - self, - size: bst.typing.Size, - T: bst.typing.ArrayLike = 36., - T_base_p: bst.typing.ArrayLike = 2.3, - T_base_q: bst.typing.ArrayLike = 2.3, - phi_p: Union[bst.typing.ArrayLike, Callable] = None, - phi_q: Union[bst.typing.ArrayLike, Callable] = None, - g_max: Union[bst.typing.ArrayLike, Callable] = 1. * (u.mS / u.cm ** 2), - V_sh: Union[bst.typing.ArrayLike, Callable] = 0. * u.mV, - name: Optional[str] = None, - mode: Optional[bst.mixin.Mode] = None, - ): - phi_p = T_base_p ** ((T - 23.) / 10.) if phi_p is None else phi_p - phi_q = T_base_q ** ((T - 23.) / 10.) if phi_q is None else phi_q - super().__init__( - size, - name=name, - g_max=g_max, - phi_p=phi_p, - phi_q=phi_q, - mode=mode - ) - self.T = bst.init.param(T, self.varshape, allow_none=False) - self.T_base_p = bst.init.param(T_base_p, self.varshape, allow_none=False) - self.T_base_q = bst.init.param(T_base_q, self.varshape, allow_none=False) - self.V_sh = bst.init.param(V_sh, self.varshape, allow_none=False) - - def f_p_alpha(self, V): - V = (- V + self.V_sh).to_decimal(u.mV) - temp = -27 + V - return 0.055 * temp / (u.math.exp(temp / 3.8) - 1) - - def f_p_beta(self, V): - V = (- V + self.V_sh).to_decimal(u.mV) - return 0.94 * u.math.exp((-75. + V) / 17.) - - def f_q_alpha(self, V): - V = (- V + self.V_sh).to_decimal(u.mV) - return 0.000457 * u.math.exp((-13. + V) / 50.) - - def f_q_beta(self, V): - V = (- V + self.V_sh).to_decimal(u.mV) - return 0.0065 / (u.math.exp((-15. + V) / 28.) + 1.) + r"""The high-threshold T-type calcium current model proposed by (Reuveni, et al., 1993) [1]_. + + HVA Calcium current was described for neocortical neurons by Sayer et al. (1990). + Its dynamics is given by (the rate functions are measured under 36 Celsius): + + .. math:: + + \begin{aligned} + I_{L} &=\bar{g}_{L} q^{2} r\left(V-E_{\mathrm{Ca}}\right) \\ + \frac{\mathrm{d} q}{\mathrm{~d} t} &= \phi_p (\alpha_{q}(V)(1-q)-\beta_{q}(V) q) \\ + \frac{\mathrm{d} r}{\mathrm{~d} t} &= \phi_q (\alpha_{r}(V)(1-r)-\beta_{r}(V) r) \\ + \alpha_{q} &=\frac{0.055(-27-V+V_{sh})}{\exp [(-27-V+V_{sh}) / 3.8]-1} \\ + \beta_{q} &=0.94 \exp [(-75-V+V_{sh}) / 17] \\ + \alpha_{r} &=0.000457 \exp [(-13-V+V_{sh}) / 50] \\ + \beta_{r} &=\frac{0.0065}{\exp [(-15-V+V_{sh}) / 28]+1}, + \end{aligned} + + Parameters + ---------- + size: int, tuple of int + The size of the simulation target. + name: str + The name of the object. + g_max : float, ArrayType, Callable, Initializer + The maximum conductance. + V_sh : float, ArrayType, Callable, Initializer + The membrane potential shift. + T : float, ArrayType + The temperature. + T_base_p : float, ArrayType + The brainpy_object temperature factor of :math:`p` channel. + T_base_q : float, ArrayType + The brainpy_object temperature factor of :math:`q` channel. + phi_p : optional, float, ArrayType, Callable, Initializer + The temperature factor for channel :math:`p`. + If `None`, :math:`\phi_p = \mathrm{T_base_p}^{\frac{T-23}{10}}`. + phi_q : optional, float, ArrayType, Callable, Initializer + The temperature factor for channel :math:`q`. + If `None`, :math:`\phi_q = \mathrm{T_base_q}^{\frac{T-23}{10}}`. + + References + ---------- + .. [1] Reuveni, I., et al. "Stepwise repolarization from Ca2+ plateaus + in neocortical pyramidal cells: evidence for nonhomogeneous + distribution of HVA Ca2+ channels in dendrites." Journal of + Neuroscience 13.11 (1993): 4609-4621. + + """ + __module__ = 'dendritex.channels' + + def __init__( + self, + size: bst.typing.Size, + T: bst.typing.ArrayLike = 36., + T_base_p: bst.typing.ArrayLike = 2.3, + T_base_q: bst.typing.ArrayLike = 2.3, + phi_p: Union[bst.typing.ArrayLike, Callable] = None, + phi_q: Union[bst.typing.ArrayLike, Callable] = None, + g_max: Union[bst.typing.ArrayLike, Callable] = 1. * (u.mS / u.cm ** 2), + V_sh: Union[bst.typing.ArrayLike, Callable] = 0. * u.mV, + name: Optional[str] = None, + ): + phi_p = T_base_p ** ((T - 23.) / 10.) if phi_p is None else phi_p + phi_q = T_base_q ** ((T - 23.) / 10.) if phi_q is None else phi_q + super().__init__( + size, + name=name, + g_max=g_max, + phi_p=phi_p, + phi_q=phi_q, + ) + self.T = bst.init.param(T, self.varshape, allow_none=False) + self.T_base_p = bst.init.param(T_base_p, self.varshape, allow_none=False) + self.T_base_q = bst.init.param(T_base_q, self.varshape, allow_none=False) + self.V_sh = bst.init.param(V_sh, self.varshape, allow_none=False) + + def f_p_alpha(self, V): + V = (- V + self.V_sh).to_decimal(u.mV) + temp = -27 + V + return 0.055 * temp / (u.math.exp(temp / 3.8) - 1) + + def f_p_beta(self, V): + V = (- V + self.V_sh).to_decimal(u.mV) + return 0.94 * u.math.exp((-75. + V) / 17.) + + def f_q_alpha(self, V): + V = (- V + self.V_sh).to_decimal(u.mV) + return 0.000457 * u.math.exp((-13. + V) / 50.) + + def f_q_beta(self, V): + V = (- V + self.V_sh).to_decimal(u.mV) + return 0.0065 / (u.math.exp((-15. + V) / 28.) + 1.) class ICaL_IS2008(_ICa_p2q_ss): - r"""The L-type calcium channel model proposed by (Inoue & Strowbridge, 2008) [1]_. - - The L-type calcium channel model is adopted from (Inoue, et, al., 2008) [1]_. - Its dynamics is given by: - - .. math:: - - I_{CaL} &= g_{max} p^2 q(V-E_{Ca}) \\ - {dp \over dt} &= {\phi_p \cdot (p_{\infty}-p)\over \tau_p} \\ - & p_{\infty} = {1 \over 1+\exp [-(V+10-V_{sh}) / 4.]} \\ - & \tau_{p} = 0.4+{0.7 \over \exp [(V+5-V_{sh}) / 15]+\exp [-(V+5-V_{sh}) / 15]} \\ - {dq \over dt} &= {\phi_q \cdot (q_{\infty}-q) \over \tau_q} \\ - & q_{\infty} = {1 \over 1+\exp [(V+25-V_{sh}) / 2]} \\ - & \tau_q = 300 + {100 \over \exp [(V+40-V_{sh}) / 9.5]+\exp [-(V+40-V_{sh}) / 9.5]} - - where :math:`phi_p = 3.55^{\frac{T-24}{10}}` and :math:`phi_q = 3^{\frac{T-24}{10}}` - are temperature-dependent factors (:math:`T` is the temperature in Celsius), - :math:`E_{Ca}` is the reversal potential of Calcium channel. - - Parameters - ---------- - T : float - The temperature. - T_base_p : float - The brainpy_object temperature factor of :math:`p` channel. - T_base_q : float - The brainpy_object temperature factor of :math:`q` channel. - g_max : float - The maximum conductance. - V_sh : float - The membrane potential shift. - - References - ---------- - - .. [1] Inoue, Tsuyoshi, and Ben W. Strowbridge. "Transient activity induces a long-lasting - increase in the excitability of olfactory bulb interneurons." Journal of - neurophysiology 99, no. 1 (2008): 187-199. - - See Also - -------- - ICa_p2q_form - """ - __module__ = 'dendritex.channels' - - def __init__( - self, - size: bst.typing.Size, - T: Union[bst.typing.ArrayLike, Callable] = 36., - T_base_p: Union[bst.typing.ArrayLike, Callable] = 3.55, - T_base_q: Union[bst.typing.ArrayLike, Callable] = 3., - g_max: Union[bst.typing.ArrayLike, Callable] = 1. * (u.mS / u.cm ** 2), - V_sh: Union[bst.typing.ArrayLike, Callable] = 0. * u.mV, - name: Optional[str] = None, - mode: Optional[bst.mixin.Mode] = None, - ): - super().__init__( - size, - name=name, - g_max=g_max, - phi_p=T_base_p ** ((T - 24) / 10), - phi_q=T_base_q ** ((T - 24) / 10), - mode=mode - ) - - # parameters - self.T = bst.init.param(T, self.varshape, allow_none=False) - self.T_base_p = bst.init.param(T_base_p, self.varshape, allow_none=False) - self.T_base_q = bst.init.param(T_base_q, self.varshape, allow_none=False) - self.V_sh = bst.init.param(V_sh, self.varshape, allow_none=False) - - def f_p_inf(self, V): - V = (V - self.V_sh).to_decimal(u.mV) - return 1. / (1 + u.math.exp(-(V + 10.) / 4.)) - - def f_p_tau(self, V): - V = (V - self.V_sh).to_decimal(u.mV) - return 0.4 + .7 / (u.math.exp(-(V + 5.) / 15.) + u.math.exp((V + 5.) / 15.)) - - def f_q_inf(self, V): - V = (V - self.V_sh).to_decimal(u.mV) - return 1. / (1. + u.math.exp((V + 25.) / 2.)) - - def f_q_tau(self, V): - V = (V - self.V_sh).to_decimal(u.mV) - return 300. + 100. / (u.math.exp((V + 40) / 9.5) + u.math.exp(-(V + 40) / 9.5)) + r"""The L-type calcium channel model proposed by (Inoue & Strowbridge, 2008) [1]_. + + The L-type calcium channel model is adopted from (Inoue, et, al., 2008) [1]_. + Its dynamics is given by: + + .. math:: + + I_{CaL} &= g_{max} p^2 q(V-E_{Ca}) \\ + {dp \over dt} &= {\phi_p \cdot (p_{\infty}-p)\over \tau_p} \\ + & p_{\infty} = {1 \over 1+\exp [-(V+10-V_{sh}) / 4.]} \\ + & \tau_{p} = 0.4+{0.7 \over \exp [(V+5-V_{sh}) / 15]+\exp [-(V+5-V_{sh}) / 15]} \\ + {dq \over dt} &= {\phi_q \cdot (q_{\infty}-q) \over \tau_q} \\ + & q_{\infty} = {1 \over 1+\exp [(V+25-V_{sh}) / 2]} \\ + & \tau_q = 300 + {100 \over \exp [(V+40-V_{sh}) / 9.5]+\exp [-(V+40-V_{sh}) / 9.5]} + + where :math:`phi_p = 3.55^{\frac{T-24}{10}}` and :math:`phi_q = 3^{\frac{T-24}{10}}` + are temperature-dependent factors (:math:`T` is the temperature in Celsius), + :math:`E_{Ca}` is the reversal potential of Calcium channel. + + Parameters + ---------- + T : float + The temperature. + T_base_p : float + The brainpy_object temperature factor of :math:`p` channel. + T_base_q : float + The brainpy_object temperature factor of :math:`q` channel. + g_max : float + The maximum conductance. + V_sh : float + The membrane potential shift. + + References + ---------- + + .. [1] Inoue, Tsuyoshi, and Ben W. Strowbridge. "Transient activity induces a long-lasting + increase in the excitability of olfactory bulb interneurons." Journal of + neurophysiology 99, no. 1 (2008): 187-199. + + See Also + -------- + ICa_p2q_form + """ + __module__ = 'dendritex.channels' + + def __init__( + self, + size: bst.typing.Size, + T: Union[bst.typing.ArrayLike, Callable] = 36., + T_base_p: Union[bst.typing.ArrayLike, Callable] = 3.55, + T_base_q: Union[bst.typing.ArrayLike, Callable] = 3., + g_max: Union[bst.typing.ArrayLike, Callable] = 1. * (u.mS / u.cm ** 2), + V_sh: Union[bst.typing.ArrayLike, Callable] = 0. * u.mV, + name: Optional[str] = None, + ): + super().__init__( + size, + name=name, + g_max=g_max, + phi_p=T_base_p ** ((T - 24) / 10), + phi_q=T_base_q ** ((T - 24) / 10), + ) + + # parameters + self.T = bst.init.param(T, self.varshape, allow_none=False) + self.T_base_p = bst.init.param(T_base_p, self.varshape, allow_none=False) + self.T_base_q = bst.init.param(T_base_q, self.varshape, allow_none=False) + self.V_sh = bst.init.param(V_sh, self.varshape, allow_none=False) + + def f_p_inf(self, V): + V = (V - self.V_sh).to_decimal(u.mV) + return 1. / (1 + u.math.exp(-(V + 10.) / 4.)) + + def f_p_tau(self, V): + V = (V - self.V_sh).to_decimal(u.mV) + return 0.4 + .7 / (u.math.exp(-(V + 5.) / 15.) + u.math.exp((V + 5.) / 15.)) + + def f_q_inf(self, V): + V = (V - self.V_sh).to_decimal(u.mV) + return 1. / (1. + u.math.exp((V + 25.) / 2.)) + + def f_q_tau(self, V): + V = (V - self.V_sh).to_decimal(u.mV) + return 300. + 100. / (u.math.exp((V + 40) / 9.5) + u.math.exp(-(V + 40) / 9.5)) class ICav12_Ma2020(CalciumChannel): - r""" - : model from Evans et al 2013, transferred from GENESIS to NEURON by Beining et al (2016), "A novel comprehensive and consistent electrophysiologcal model of dentate granule cells" - : also added Calcium dependent inactivation - """ - - __module__ = 'dendritex.channels' - - root_type = Calcium - - def __init__( - self, - size: bst.typing.Size, - g_max: Union[bst.typing.ArrayLike, Callable] = 0 * (u.mS / u.cm ** 2), - V_sh: Union[bst.typing.ArrayLike, Callable] = 0 * u.mV, - T_base: bst.typing.ArrayLike = 3, - T: bst.typing.ArrayLike = 22., - name: Optional[str] = None, - mode: Optional[bst.mixin.Mode] = None, - ): - super().__init__( - size=size, - name=name, - mode=mode - ) - - # parameters - self.g_max = bst.init.param(g_max, self.varshape, allow_none=False) - self.T = bst.init.param(T, self.varshape, allow_none=False) - self.T_base = bst.init.param(T_base, self.varshape, allow_none=False) - self.V_sh = bst.init.param(V_sh, self.varshape, allow_none=False) - self.phi = bst.init.param(1., self.varshape, allow_none=False) - - self.kf = 0.0005 - self.VDI = 0.17 - - def init_state(self, V, Ca: IonInfo, batch_size: int = None): - self.m = State4Integral(bst.init.param(u.math.zeros, self.varshape, batch_size)) - self.h = State4Integral(bst.init.param(u.math.zeros, self.varshape, batch_size)) - self.n = State4Integral(bst.init.param(u.math.zeros, self.varshape, batch_size)) - - def reset_state(self, V, Ca, batch_size=None): - self.m.value = self.f_m_inf(V) - self.h.value = self.f_h_inf(V) - self.n.value = self.f_n_inf(V, Ca) - - def compute_derivative(self, V, Ca): - self.m.derivative = self.phi * (self.f_m_inf(V) - self.m.value) / self.f_m_tau(V) / u.ms - self.h.derivative = self.phi * (self.f_h_inf(V) - self.h.value) / self.f_h_tau(V) / u.ms - self.n.derivative = self.phi * (self.f_n_inf(V, Ca) - self.n.value) / self.f_n_tau(V) / u.ms - - def f_m_inf(self, V): - V = V.to_decimal(u.mV) - return 1 / (1 + u.math.exp((V + 8.9) / (-6.7))) - - def f_h_inf(self, V): - V = V.to_decimal(u.mV) - return self.VDI / (1 + u.math.exp((V + 55) / 8)) + (1 - self.VDI) - - def f_n_inf(self, V, Ca): - V = V.to_decimal(u.mV) - return u.math.ones_like(V) * self.kf / (self.kf + Ca.C / u.mM) - - def f_m_tau(self, V): - V = V.to_decimal(u.mV) - mA = 39800 * (V + 8.124) / (u.math.exp((V + 8.124) / 9.005) - 1) - mB = 990 * u.math.exp(V / 31.4) - return 1 / (mA + mB) - - def f_h_tau(self, V): - return 44.3 - - def f_n_tau(self, V): - return 0.5 - - def current(self, V, Ca: IonInfo): - return self.g_max * self.m.value * self.h.value * self.n.value * (Ca.E - V) + r""" + : model from Evans et al 2013, transferred from GENESIS to NEURON by Beining et al (2016), "A novel comprehensive and consistent electrophysiologcal model of dentate granule cells" + : also added Calcium dependent inactivation + """ + + __module__ = 'dendritex.channels' + + root_type = Calcium + + def __init__( + self, + size: bst.typing.Size, + g_max: Union[bst.typing.ArrayLike, Callable] = 0 * (u.mS / u.cm ** 2), + V_sh: Union[bst.typing.ArrayLike, Callable] = 0 * u.mV, + T_base: bst.typing.ArrayLike = 3, + T: bst.typing.ArrayLike = 22., + name: Optional[str] = None, + ): + super().__init__(size=size, name=name, ) + + # parameters + self.g_max = bst.init.param(g_max, self.varshape, allow_none=False) + self.T = bst.init.param(T, self.varshape, allow_none=False) + self.T_base = bst.init.param(T_base, self.varshape, allow_none=False) + self.V_sh = bst.init.param(V_sh, self.varshape, allow_none=False) + self.phi = bst.init.param(1., self.varshape, allow_none=False) + + self.kf = 0.0005 + self.VDI = 0.17 + + def init_state(self, V, Ca: IonInfo, batch_size: int = None): + self.m = State4Integral(bst.init.param(u.math.zeros, self.varshape, batch_size)) + self.h = State4Integral(bst.init.param(u.math.zeros, self.varshape, batch_size)) + self.n = State4Integral(bst.init.param(u.math.zeros, self.varshape, batch_size)) + + def reset_state(self, V, Ca, batch_size=None): + self.m.value = self.f_m_inf(V) + self.h.value = self.f_h_inf(V) + self.n.value = self.f_n_inf(V, Ca) + + def compute_derivative(self, V, Ca): + self.m.derivative = self.phi * (self.f_m_inf(V) - self.m.value) / self.f_m_tau(V) / u.ms + self.h.derivative = self.phi * (self.f_h_inf(V) - self.h.value) / self.f_h_tau(V) / u.ms + self.n.derivative = self.phi * (self.f_n_inf(V, Ca) - self.n.value) / self.f_n_tau(V) / u.ms + + def f_m_inf(self, V): + V = V.to_decimal(u.mV) + return 1 / (1 + u.math.exp((V + 8.9) / (-6.7))) + + def f_h_inf(self, V): + V = V.to_decimal(u.mV) + return self.VDI / (1 + u.math.exp((V + 55) / 8)) + (1 - self.VDI) + + def f_n_inf(self, V, Ca): + V = V.to_decimal(u.mV) + return u.math.ones_like(V) * self.kf / (self.kf + Ca.C / u.mM) + + def f_m_tau(self, V): + V = V.to_decimal(u.mV) + mA = 39800 * (V + 8.124) / (u.math.exp((V + 8.124) / 9.005) - 1) + mB = 990 * u.math.exp(V / 31.4) + return 1 / (mA + mB) + + def f_h_tau(self, V): + return 44.3 + + def f_n_tau(self, V): + return 0.5 + + def current(self, V, Ca: IonInfo): + return self.g_max * self.m.value * self.h.value * self.n.value * (Ca.E - V) class ICav13_Ma2020(CalciumChannel): - r""" - : model from Evans et al 2013, transferred from GENESIS to NEURON by Beining et al (2016), "A novel comprehensive and consistent electrophysiologcal model of dentate granule cells" - : also added Calcium dependent inactivation - """ - __module__ = 'dendritex.channels' - - root_type = Calcium - - def __init__( - self, - size: bst.typing.Size, - g_max: Union[bst.typing.ArrayLike, Callable] = 0 * (u.mS / u.cm ** 2), - V_sh: Union[bst.typing.ArrayLike, Callable] = 0 * u.mV, - T_base: bst.typing.ArrayLike = 3, - T: bst.typing.ArrayLike = 22., - name: Optional[str] = None, - mode: Optional[bst.mixin.Mode] = None, - ): - super().__init__( - size=size, - name=name, - mode=mode - ) - - # parameters - self.g_max = bst.init.param(g_max, self.varshape, allow_none=False) - self.T = bst.init.param(T, self.varshape, allow_none=False) - self.T_base = bst.init.param(T_base, self.varshape, allow_none=False) - self.V_sh = bst.init.param(V_sh, self.varshape, allow_none=False) - self.phi = bst.init.param(1., self.varshape, allow_none=False) - - self.kf = 0.0005 - self.VDI = 1 - - def init_state(self, V, Ca: IonInfo, batch_size: int = None): - self.m = State4Integral(bst.init.param(u.math.zeros, self.varshape, batch_size)) - self.h = State4Integral(bst.init.param(u.math.zeros, self.varshape, batch_size)) - self.n = State4Integral(bst.init.param(u.math.zeros, self.varshape, batch_size)) - - def reset_state(self, V, Ca, batch_size=None): - self.m.value = self.f_m_inf(V) - self.h.value = self.f_h_inf(V) - self.n.value = self.f_n_inf(V, Ca) - - def compute_derivative(self, V, Ca): - self.m.derivative = self.phi * (self.f_m_inf(V) - self.m.value) / self.f_m_tau(V) / u.ms - self.h.derivative = self.phi * (self.f_h_inf(V) - self.h.value) / self.f_h_tau(V) / u.ms - self.n.derivative = self.phi * (self.f_n_inf(V, Ca) - self.n.value) / self.f_n_tau(V) / u.ms - - def f_m_inf(self, V): - V = V.to_decimal(u.mV) - return 1.0 / ((u.math.exp((V - (-40.0)) / (-5))) + 1.0) - - def f_h_inf(self, V): - V = V.to_decimal(u.mV) - return self.VDI / ((u.math.exp((V - (-37)) / 5)) + 1.0) + (1 - self.VDI) - - def f_n_inf(self, V, Ca): - V = V.to_decimal(u.mV) - return u.math.ones_like(V) * self.kf / (self.kf + Ca.C / u.mM) - - def f_m_tau(self, V): - V = V.to_decimal(u.mV) - # mA = (39800 * (V + 67.24)) / (u.math.exp((V + 67.24) / 15.005) - 1.0) - mA = 39800 * 15.005 / u.math.exprel((V + 67.24) / 15.005) - mB = 3500 * u.math.exp(V / 31.4) - return 1 / (mA + mB) - - def f_h_tau(self, V): - return 44.3 - - def f_n_tau(self, V): - return 0.5 - - def current(self, V, Ca: IonInfo): - return self.g_max * self.m.value * self.h.value * self.n.value * (Ca.E - V) + r""" + : model from Evans et al 2013, transferred from GENESIS to NEURON by Beining et al (2016), "A novel comprehensive and consistent electrophysiologcal model of dentate granule cells" + : also added Calcium dependent inactivation + """ + __module__ = 'dendritex.channels' + + root_type = Calcium + + def __init__( + self, + size: bst.typing.Size, + g_max: Union[bst.typing.ArrayLike, Callable] = 0 * (u.mS / u.cm ** 2), + V_sh: Union[bst.typing.ArrayLike, Callable] = 0 * u.mV, + T_base: bst.typing.ArrayLike = 3, + T: bst.typing.ArrayLike = 22., + name: Optional[str] = None, + ): + super().__init__(size=size, name=name, ) + + # parameters + self.g_max = bst.init.param(g_max, self.varshape, allow_none=False) + self.T = bst.init.param(T, self.varshape, allow_none=False) + self.T_base = bst.init.param(T_base, self.varshape, allow_none=False) + self.V_sh = bst.init.param(V_sh, self.varshape, allow_none=False) + self.phi = bst.init.param(1., self.varshape, allow_none=False) + + self.kf = 0.0005 + self.VDI = 1 + + def init_state(self, V, Ca: IonInfo, batch_size: int = None): + self.m = State4Integral(bst.init.param(u.math.zeros, self.varshape, batch_size)) + self.h = State4Integral(bst.init.param(u.math.zeros, self.varshape, batch_size)) + self.n = State4Integral(bst.init.param(u.math.zeros, self.varshape, batch_size)) + + def reset_state(self, V, Ca, batch_size=None): + self.m.value = self.f_m_inf(V) + self.h.value = self.f_h_inf(V) + self.n.value = self.f_n_inf(V, Ca) + + def compute_derivative(self, V, Ca): + self.m.derivative = self.phi * (self.f_m_inf(V) - self.m.value) / self.f_m_tau(V) / u.ms + self.h.derivative = self.phi * (self.f_h_inf(V) - self.h.value) / self.f_h_tau(V) / u.ms + self.n.derivative = self.phi * (self.f_n_inf(V, Ca) - self.n.value) / self.f_n_tau(V) / u.ms + + def f_m_inf(self, V): + V = V.to_decimal(u.mV) + return 1.0 / ((u.math.exp((V - (-40.0)) / (-5))) + 1.0) + + def f_h_inf(self, V): + V = V.to_decimal(u.mV) + return self.VDI / ((u.math.exp((V - (-37)) / 5)) + 1.0) + (1 - self.VDI) + + def f_n_inf(self, V, Ca): + V = V.to_decimal(u.mV) + return u.math.ones_like(V) * self.kf / (self.kf + Ca.C / u.mM) + + def f_m_tau(self, V): + V = V.to_decimal(u.mV) + # mA = (39800 * (V + 67.24)) / (u.math.exp((V + 67.24) / 15.005) - 1.0) + mA = 39800 * 15.005 / u.math.exprel((V + 67.24) / 15.005) + mB = 3500 * u.math.exp(V / 31.4) + return 1 / (mA + mB) + + def f_h_tau(self, V): + return 44.3 + + def f_n_tau(self, V): + return 0.5 + + def current(self, V, Ca: IonInfo): + return self.g_max * self.m.value * self.h.value * self.n.value * (Ca.E - V) class ICav23_Ma2020(CalciumChannel): - r""" - Ca R-type channel with medium threshold for activation. - - : used in distal dendritic regions, together with calH.mod, to help - : the generation of Ca++ spikes in these regions - : uses channel conductance (not permeability) - : written by Yiota Poirazi on 11/13/00 poirazi@LNC.usc.edu - : From car to Cav2_3 - """ - - __module__ = 'dendritex.channels' - - root_type = Calcium - - def __init__( - self, - size: bst.typing.Size, - g_max: Union[bst.typing.ArrayLike, Callable] = 0 * (u.mS / u.cm ** 2), - V_sh: Union[bst.typing.ArrayLike, Callable] = 0 * u.mV, - T_base: bst.typing.ArrayLike = 3, - T: bst.typing.ArrayLike = 22., - name: Optional[str] = None, - mode: Optional[bst.mixin.Mode] = None, - ): - super().__init__( - size=size, - name=name, - mode=mode - ) - - # parameters - self.g_max = bst.init.param(g_max, self.varshape, allow_none=False) - self.T = bst.init.param(T, self.varshape, allow_none=False) - self.T_base = bst.init.param(T_base, self.varshape, allow_none=False) - self.V_sh = bst.init.param(V_sh, self.varshape, allow_none=False) - self.phi = bst.init.param(1., self.varshape, allow_none=False) - - self.eca = 140 * u.mV - - def init_state(self, V, Ca: IonInfo, batch_size: int = None): - self.m = State4Integral(bst.init.param(u.math.zeros, self.varshape, batch_size)) - self.h = State4Integral(bst.init.param(u.math.zeros, self.varshape, batch_size)) - - def reset_state(self, V, Ca, batch_size=None): - self.m.value = self.f_m_inf(V) - self.h.value = self.f_h_inf(V) - - def compute_derivative(self, V, Ca): - self.m.derivative = self.phi * (self.f_m_inf(V) - self.m.value) / self.f_m_tau(V) / u.ms - self.h.derivative = self.phi * (self.f_h_inf(V) - self.h.value) / self.f_h_tau(V) / u.ms - - def current(self, V, Ca: IonInfo): - return self.g_max * self.m.value ** 3 * self.h.value * (self.eca - V) - - def f_m_inf(self, V): - V = V.to_decimal(u.mV) - return 1 / (1 + u.math.exp((V + 48.5) / (-3))) - - def f_h_inf(self, V): - V = V.to_decimal(u.mV) - return 1 / (1 + u.math.exp((V + 53) / 1.)) - - def f_m_tau(self, V): - return 50. - - def f_h_tau(self, V): - return 5. + r""" + Ca R-type channel with medium threshold for activation. + + : used in distal dendritic regions, together with calH.mod, to help + : the generation of Ca++ spikes in these regions + : uses channel conductance (not permeability) + : written by Yiota Poirazi on 11/13/00 poirazi@LNC.usc.edu + : From car to Cav2_3 + """ + + __module__ = 'dendritex.channels' + + root_type = Calcium + + def __init__( + self, + size: bst.typing.Size, + g_max: Union[bst.typing.ArrayLike, Callable] = 0 * (u.mS / u.cm ** 2), + V_sh: Union[bst.typing.ArrayLike, Callable] = 0 * u.mV, + T_base: bst.typing.ArrayLike = 3, + T: bst.typing.ArrayLike = 22., + name: Optional[str] = None, + ): + super().__init__(size=size, name=name, ) + + # parameters + self.g_max = bst.init.param(g_max, self.varshape, allow_none=False) + self.T = bst.init.param(T, self.varshape, allow_none=False) + self.T_base = bst.init.param(T_base, self.varshape, allow_none=False) + self.V_sh = bst.init.param(V_sh, self.varshape, allow_none=False) + self.phi = bst.init.param(1., self.varshape, allow_none=False) + + self.eca = 140 * u.mV + + def init_state(self, V, Ca: IonInfo, batch_size: int = None): + self.m = State4Integral(bst.init.param(u.math.zeros, self.varshape, batch_size)) + self.h = State4Integral(bst.init.param(u.math.zeros, self.varshape, batch_size)) + + def reset_state(self, V, Ca, batch_size=None): + self.m.value = self.f_m_inf(V) + self.h.value = self.f_h_inf(V) + + def compute_derivative(self, V, Ca): + self.m.derivative = self.phi * (self.f_m_inf(V) - self.m.value) / self.f_m_tau(V) / u.ms + self.h.derivative = self.phi * (self.f_h_inf(V) - self.h.value) / self.f_h_tau(V) / u.ms + + def current(self, V, Ca: IonInfo): + return self.g_max * self.m.value ** 3 * self.h.value * (self.eca - V) + + def f_m_inf(self, V): + V = V.to_decimal(u.mV) + return 1 / (1 + u.math.exp((V + 48.5) / (-3))) + + def f_h_inf(self, V): + V = V.to_decimal(u.mV) + return 1 / (1 + u.math.exp((V + 53) / 1.)) + + def f_m_tau(self, V): + return 50. + + def f_h_tau(self, V): + return 5. class ICav31_Ma2020(CalciumChannel): - r""" - Low threshold calcium current Cerebellum Purkinje Cell Model. - - Kinetics adapted to fit the Cav3.1 Iftinca et al 2006, Temperature dependence of T-type Calcium channel gating, NEUROSCIENCE - - Reference: Anwar H, Hong S, De Schutter E (2010) Controlling Ca2+-activated K+ channels with models of Ca2+ buffering in Purkinje cell. Cerebellum - - Article available as Open Access - - PubMed link: http://www.ncbi.nlm.nih.gov/pubmed/20981513 - - Written by Haroon Anwar, Computational Neuroscience Unit, Okinawa Institute of Science and Technology, 2010. - Contact: Haroon Anwar (anwar@oist.jp) - - """ - __module__ = 'dendritex.channels' - root_type = Calcium - - def __init__( - self, - size: bst.typing.Size, - g_max: Union[bst.typing.ArrayLike, Callable] = 2.5e-4 * (u.cm / u.second), - V_sh: Union[bst.typing.ArrayLike, Callable] = 0 * u.mV, - T_base: bst.typing.ArrayLike = 3, - T: bst.typing.ArrayLike = 22., - name: Optional[str] = None, - mode: Optional[bst.mixin.Mode] = None, - ): - super().__init__(size=size, name=name, mode=mode) - - # parameters - self.g_max = bst.init.param(g_max, self.varshape, allow_none=False) - self.T = bst.init.param(T, self.varshape, allow_none=False) - self.T_base = bst.init.param(T_base, self.varshape, allow_none=False) - self.V_sh = bst.init.param(V_sh, self.varshape, allow_none=False) - self.phi = bst.init.param(T_base ** ((T - 37) / 10), self.varshape, allow_none=False) - - self.v0_m_inf = -52 * u.mV - self.v0_h_inf = -72 * u.mV - self.k_m_inf = -5 * u.mV - self.k_h_inf = 7 * u.mV - - self.C_tau_m = 1 - self.A_tau_m = 1.0 - self.v0_tau_m1 = -40 * u.mV - self.v0_tau_m2 = -102 * u.mV - self.k_tau_m1 = 9 * u.mV - self.k_tau_m2 = -18 * u.mV - - self.C_tau_h = 15 - self.A_tau_h = 1.0 - self.v0_tau_h1 = -32 * u.mV - self.k_tau_h1 = 7 * u.mV - self.z = 2 - - def init_state(self, V, Ca: IonInfo, batch_size: int = None): - self.p = State4Integral(bst.init.param(u.math.zeros, self.varshape, batch_size)) - self.q = State4Integral(bst.init.param(u.math.zeros, self.varshape, batch_size)) - - def reset_state(self, V, Ca, batch_size=None): - self.p.value = self.f_p_inf(V) - self.q.value = self.f_q_inf(V) - - def compute_derivative(self, V, Ca): - self.p.derivative = self.phi * (self.f_p_inf(V) - self.p.value) / self.f_p_tau(V) / u.ms - self.q.derivative = self.phi * (self.f_q_inf(V) - self.q.value) / self.f_q_tau(V) / u.ms - - def f_p_inf(self, V): - return 1.0 / (1 + u.math.exp((V - self.v0_m_inf) / self.k_m_inf)) - - def f_q_inf(self, V): - return 1.0 / (1 + u.math.exp((V - self.v0_h_inf) / self.k_h_inf)) - - def f_p_tau(self, V): - return u.math.where( - V <= -90 * u.mV, - 1., - (self.C_tau_m + - self.A_tau_m / (u.math.exp((V - self.v0_tau_m1) / self.k_tau_m1) + - u.math.exp((V - self.v0_tau_m2) / self.k_tau_m2))) - ) - - def f_q_tau(self, V): - return self.C_tau_h + self.A_tau_h / u.math.exp((V - self.v0_tau_h1) / self.k_tau_h1) - - def ghk(self, V, Ca: IonInfo): - E = (1e-3) * V - zeta = (self.z * u.faraday_constant * E) / (u.gas_constant * (273.15 + self.T) * u.kelvin) - ci = Ca.C - co = 2 * u.mM # co = Ca.C0 for Calciumdetailed - g_1 = 1e-6 * (self.z * u.faraday_constant) * (ci - co * u.math.exp(-zeta)) * (1 + zeta / 2) - g_2 = 1e-6 * (self.z * zeta * u.faraday_constant) * (ci - co * u.math.exp(-zeta)) / (1 - u.math.exp(-zeta)) - return u.math.where(u.math.abs((1 - u.math.exp(-zeta))) <= 1e-6, g_1, g_2) - - def current(self, V, Ca: IonInfo): - return -1e3 * self.g_max * self.p.value ** 2 * self.q.value * self.ghk(V, Ca) + r""" + Low threshold calcium current Cerebellum Purkinje Cell Model. + + Kinetics adapted to fit the Cav3.1 Iftinca et al 2006, Temperature dependence of T-type Calcium channel gating, NEUROSCIENCE + + Reference: Anwar H, Hong S, De Schutter E (2010) Controlling Ca2+-activated K+ channels with models of Ca2+ buffering in Purkinje cell. Cerebellum + + Article available as Open Access + + PubMed link: http://www.ncbi.nlm.nih.gov/pubmed/20981513 + + Written by Haroon Anwar, Computational Neuroscience Unit, Okinawa Institute of Science and Technology, 2010. + Contact: Haroon Anwar (anwar@oist.jp) + + """ + __module__ = 'dendritex.channels' + root_type = Calcium + + def __init__( + self, + size: bst.typing.Size, + g_max: Union[bst.typing.ArrayLike, Callable] = 2.5e-4 * (u.cm / u.second), + V_sh: Union[bst.typing.ArrayLike, Callable] = 0 * u.mV, + T_base: bst.typing.ArrayLike = 3, + T: bst.typing.ArrayLike = 22., + name: Optional[str] = None, + ): + super().__init__(size=size, name=name, ) + + # parameters + self.g_max = bst.init.param(g_max, self.varshape, allow_none=False) + self.T = bst.init.param(T, self.varshape, allow_none=False) + self.T_base = bst.init.param(T_base, self.varshape, allow_none=False) + self.V_sh = bst.init.param(V_sh, self.varshape, allow_none=False) + self.phi = bst.init.param(T_base ** ((T - 37) / 10), self.varshape, allow_none=False) + + self.v0_m_inf = -52 * u.mV + self.v0_h_inf = -72 * u.mV + self.k_m_inf = -5 * u.mV + self.k_h_inf = 7 * u.mV + + self.C_tau_m = 1 + self.A_tau_m = 1.0 + self.v0_tau_m1 = -40 * u.mV + self.v0_tau_m2 = -102 * u.mV + self.k_tau_m1 = 9 * u.mV + self.k_tau_m2 = -18 * u.mV + + self.C_tau_h = 15 + self.A_tau_h = 1.0 + self.v0_tau_h1 = -32 * u.mV + self.k_tau_h1 = 7 * u.mV + self.z = 2 + + def init_state(self, V, Ca: IonInfo, batch_size: int = None): + self.p = State4Integral(bst.init.param(u.math.zeros, self.varshape, batch_size)) + self.q = State4Integral(bst.init.param(u.math.zeros, self.varshape, batch_size)) + + def reset_state(self, V, Ca, batch_size=None): + self.p.value = self.f_p_inf(V) + self.q.value = self.f_q_inf(V) + + def compute_derivative(self, V, Ca): + self.p.derivative = self.phi * (self.f_p_inf(V) - self.p.value) / self.f_p_tau(V) / u.ms + self.q.derivative = self.phi * (self.f_q_inf(V) - self.q.value) / self.f_q_tau(V) / u.ms + + def f_p_inf(self, V): + return 1.0 / (1 + u.math.exp((V - self.v0_m_inf) / self.k_m_inf)) + + def f_q_inf(self, V): + return 1.0 / (1 + u.math.exp((V - self.v0_h_inf) / self.k_h_inf)) + + def f_p_tau(self, V): + return u.math.where( + V <= -90 * u.mV, + 1., + (self.C_tau_m + + self.A_tau_m / (u.math.exp((V - self.v0_tau_m1) / self.k_tau_m1) + + u.math.exp((V - self.v0_tau_m2) / self.k_tau_m2))) + ) + + def f_q_tau(self, V): + return self.C_tau_h + self.A_tau_h / u.math.exp((V - self.v0_tau_h1) / self.k_tau_h1) + + def ghk(self, V, Ca: IonInfo): + E = (1e-3) * V + zeta = (self.z * u.faraday_constant * E) / (u.gas_constant * (273.15 + self.T) * u.kelvin) + ci = Ca.C + co = 2 * u.mM # co = Ca.C0 for Calciumdetailed + g_1 = 1e-6 * (self.z * u.faraday_constant) * (ci - co * u.math.exp(-zeta)) * (1 + zeta / 2) + g_2 = 1e-6 * (self.z * zeta * u.faraday_constant) * (ci - co * u.math.exp(-zeta)) / (1 - u.math.exp(-zeta)) + return u.math.where(u.math.abs((1 - u.math.exp(-zeta))) <= 1e-6, g_1, g_2) + + def current(self, V, Ca: IonInfo): + return -1e3 * self.g_max * self.p.value ** 2 * self.q.value * self.ghk(V, Ca) class ICaGrc_Ma2020(CalciumChannel): - r""" - Cerebellum Granule Cell Model. - - COMMENT - CaHVA channel - - Author: E.D'Angelo, T.Nieus, A. Fontana - Last revised: 8.5.2000 - """ - - __module__ = 'dendritex.channels' - - root_type = Calcium - - def __init__( - self, - size: bst.typing.Size, - g_max: Union[bst.typing.ArrayLike, Callable] = 0.46 * (u.mS / u.cm ** 2), - V_sh: Union[bst.typing.ArrayLike, Callable] = 0 * u.mV, - T_base: bst.typing.ArrayLike = 3, - T: bst.typing.ArrayLike = 22., - name: Optional[str] = None, - mode: Optional[bst.mixin.Mode] = None, - ): - super().__init__( - size=size, - name=name, - mode=mode - ) - - # parameters - self.g_max = bst.init.param(g_max, self.varshape, allow_none=False) - self.T = bst.init.param(T, self.varshape, allow_none=False) - self.T_base = bst.init.param(T_base, self.varshape, allow_none=False) - self.V_sh = bst.init.param(V_sh, self.varshape, allow_none=False) - self.phi = bst.init.param(T_base ** ((T - 20) / 10), self.varshape, allow_none=False) - - self.eca = 129.33 * u.mV - - self.Aalpha_s = 0.04944 - self.Kalpha_s = 15.87301587302 - self.V0alpha_s = -29.06 - - self.Abeta_s = 0.08298 - self.Kbeta_s = -25.641 - self.V0beta_s = -18.66 - - self.Aalpha_u = 0.0013 - self.Kalpha_u = -18.183 - self.V0alpha_u = -48 - - self.Abeta_u = 0.0013 - self.Kbeta_u = 83.33 - self.V0beta_u = -48 - - def init_state(self, V, Ca: IonInfo, batch_size: int = None): - self.m = State4Integral(bst.init.param(u.math.zeros, self.varshape, batch_size)) - self.h = State4Integral(bst.init.param(u.math.zeros, self.varshape, batch_size)) - - def reset_state(self, V, Ca, batch_size=None): - self.m.value = self.f_m_inf(V) - self.h.value = self.f_h_inf(V) - - def compute_derivative(self, V, Ca): - self.m.derivative = self.phi * (self.f_m_inf(V) - self.m.value) / self.f_m_tau(V) / u.ms - self.h.derivative = self.phi * (self.f_h_inf(V) - self.h.value) / self.f_h_tau(V) / u.ms - - def current(self, V, Ca: IonInfo): - return self.g_max * self.m.value ** 2 * self.h.value * (self.eca - V) - - def f_m_inf(self, V): - return self.alpha_m(V) / (self.alpha_m(V) + self.beta_m(V)) - - def f_h_inf(self, V): - return self.alpha_h(V) / (self.alpha_h(V) + self.beta_h(V)) - - def f_m_tau(self, V): - return 1. / (self.alpha_m(V) + self.beta_m(V)) - - def f_h_tau(self, V): - return 1. / (self.alpha_h(V) + self.beta_h(V)) - - def alpha_m(self, V): - V = (V - self.V_sh).to_decimal(u.mV) - return self.Aalpha_s * u.math.exp((V - self.V0alpha_s) / self.Kalpha_s) - - def beta_m(self, V): - V = (V - self.V_sh).to_decimal(u.mV) - return self.Abeta_s * u.math.exp((V - self.V0beta_s) / self.Kbeta_s) - - def alpha_h(self, V): - V = (V - self.V_sh).to_decimal(u.mV) - return self.Aalpha_u * u.math.exp((V - self.V0alpha_u) / self.Kalpha_u) - - def beta_h(self, V): - V = (V - self.V_sh).to_decimal(u.mV) - return self.Abeta_u * u.math.exp((V - self.V0beta_u) / self.Kbeta_u) + r""" + Cerebellum Granule Cell Model. + + COMMENT + CaHVA channel + + Author: E.D'Angelo, T.Nieus, A. Fontana + Last revised: 8.5.2000 + """ + + __module__ = 'dendritex.channels' + + root_type = Calcium + + def __init__( + self, + size: bst.typing.Size, + g_max: Union[bst.typing.ArrayLike, Callable] = 0.46 * (u.mS / u.cm ** 2), + V_sh: Union[bst.typing.ArrayLike, Callable] = 0 * u.mV, + T_base: bst.typing.ArrayLike = 3, + T: bst.typing.ArrayLike = 22., + name: Optional[str] = None, + ): + super().__init__(size=size, name=name, ) + + # parameters + self.g_max = bst.init.param(g_max, self.varshape, allow_none=False) + self.T = bst.init.param(T, self.varshape, allow_none=False) + self.T_base = bst.init.param(T_base, self.varshape, allow_none=False) + self.V_sh = bst.init.param(V_sh, self.varshape, allow_none=False) + self.phi = bst.init.param(T_base ** ((T - 20) / 10), self.varshape, allow_none=False) + + self.eca = 129.33 * u.mV + + self.Aalpha_s = 0.04944 + self.Kalpha_s = 15.87301587302 + self.V0alpha_s = -29.06 + + self.Abeta_s = 0.08298 + self.Kbeta_s = -25.641 + self.V0beta_s = -18.66 + + self.Aalpha_u = 0.0013 + self.Kalpha_u = -18.183 + self.V0alpha_u = -48 + + self.Abeta_u = 0.0013 + self.Kbeta_u = 83.33 + self.V0beta_u = -48 + + def init_state(self, V, Ca: IonInfo, batch_size: int = None): + self.m = State4Integral(bst.init.param(u.math.zeros, self.varshape, batch_size)) + self.h = State4Integral(bst.init.param(u.math.zeros, self.varshape, batch_size)) + + def reset_state(self, V, Ca, batch_size=None): + self.m.value = self.f_m_inf(V) + self.h.value = self.f_h_inf(V) + + def compute_derivative(self, V, Ca): + self.m.derivative = self.phi * (self.f_m_inf(V) - self.m.value) / self.f_m_tau(V) / u.ms + self.h.derivative = self.phi * (self.f_h_inf(V) - self.h.value) / self.f_h_tau(V) / u.ms + + def current(self, V, Ca: IonInfo): + return self.g_max * self.m.value ** 2 * self.h.value * (self.eca - V) + + def f_m_inf(self, V): + return self.alpha_m(V) / (self.alpha_m(V) + self.beta_m(V)) + + def f_h_inf(self, V): + return self.alpha_h(V) / (self.alpha_h(V) + self.beta_h(V)) + + def f_m_tau(self, V): + return 1. / (self.alpha_m(V) + self.beta_m(V)) + + def f_h_tau(self, V): + return 1. / (self.alpha_h(V) + self.beta_h(V)) + + def alpha_m(self, V): + V = (V - self.V_sh).to_decimal(u.mV) + return self.Aalpha_s * u.math.exp((V - self.V0alpha_s) / self.Kalpha_s) + + def beta_m(self, V): + V = (V - self.V_sh).to_decimal(u.mV) + return self.Abeta_s * u.math.exp((V - self.V0beta_s) / self.Kbeta_s) + + def alpha_h(self, V): + V = (V - self.V_sh).to_decimal(u.mV) + return self.Aalpha_u * u.math.exp((V - self.V0alpha_u) / self.Kalpha_u) + + def beta_h(self, V): + V = (V - self.V_sh).to_decimal(u.mV) + return self.Abeta_u * u.math.exp((V - self.V0beta_u) / self.Kbeta_u) diff --git a/dendritex/channels/hyperpolarization_activated.py b/dendritex/channels/hyperpolarization_activated.py index 242cfe2..3f1777d 100644 --- a/dendritex/channels/hyperpolarization_activated.py +++ b/dendritex/channels/hyperpolarization_activated.py @@ -11,98 +11,94 @@ import brainstate as bst import brainunit as bu -from .._base import Channel, HHTypedNeuron, State4Integral +from dendritex._base import Channel, HHTypedNeuron, State4Integral __all__ = [ - 'Ih_HM1992', - 'Ih1_Ma2020', - 'Ih2_Ma2020', + 'Ih_HM1992', + 'Ih1_Ma2020', + 'Ih2_Ma2020', ] class Ih_HM1992(Channel): - r""" - The hyperpolarization-activated cation current model propsoed by (Huguenard & McCormick, 1992) [1]_. - - The hyperpolarization-activated cation current model is adopted from - (Huguenard, et, al., 1992) [1]_. Its dynamics is given by: - - .. math:: - - \begin{aligned} - I_h &= g_{\mathrm{max}} p \\ - \frac{dp}{dt} &= \phi \frac{p_{\infty} - p}{\tau_p} \\ - p_{\infty} &=\frac{1}{1+\exp ((V+75) / 5.5)} \\ - \tau_{p} &=\frac{1}{\exp (-0.086 V-14.59)+\exp (0.0701 V-1.87)} - \end{aligned} - - where :math:`\phi=1` is a temperature-dependent factor. - - Parameters - ---------- - g_max : float - The maximal conductance density (:math:`mS/cm^2`). - E : float - The reversal potential (mV). - phi : float - The temperature-dependent factor. - - References - ---------- - .. [1] Huguenard, John R., and David A. McCormick. "Simulation of the currents - involved in rhythmic oscillations in thalamic relay neurons." Journal - of neurophysiology 68, no. 4 (1992): 1373-1383. - - """ - __module__ = 'dendritex.channels' - - root_type = HHTypedNeuron - - def __init__( - self, - size: bst.typing.Size, - g_max: Union[bst.typing.ArrayLike, Callable] = 10. * (bu.mS / bu.cm ** 2), - E: Union[bst.typing.ArrayLike, Callable] = 43. * bu.mV, - phi: Union[bst.typing.ArrayLike, Callable] = 1., - name: Optional[str] = None, - mode: Optional[bst.mixin.Mode] = None, - ): - super().__init__( - size, - name=name, - mode=mode - ) - - # parameters - self.phi = bst.init.param(phi, self.varshape, allow_none=False) - self.g_max = bst.init.param(g_max, self.varshape, allow_none=False) - self.E = bst.init.param(E, self.varshape, allow_none=False) - - def init_state(self, V, batch_size=None): - self.p = State4Integral(bst.init.param(bu.math.zeros, self.varshape, batch_size)) - - def reset_state(self, V, batch_size=None): - self.p.value = self.f_p_inf(V) - - def before_integral(self, V): - pass - - def compute_derivative(self, V): - self.p.derivative = self.phi * (self.f_p_inf(V) - self.p.value) / self.f_p_tau(V) / bu.ms - - def after_integral(self, V): - pass - - def current(self, V): - return self.g_max * self.p.value * (self.E - V) - - def f_p_inf(self, V): - V = V.to_decimal(bu.mV) - return 1. / (1. + bu.math.exp((V + 75.) / 5.5)) - - def f_p_tau(self, V): - V = V.to_decimal(bu.mV) - return 1. / (bu.math.exp(-0.086 * V - 14.59) + bu.math.exp(0.0701 * V - 1.87)) + r""" + The hyperpolarization-activated cation current model propsoed by (Huguenard & McCormick, 1992) [1]_. + + The hyperpolarization-activated cation current model is adopted from + (Huguenard, et, al., 1992) [1]_. Its dynamics is given by: + + .. math:: + + \begin{aligned} + I_h &= g_{\mathrm{max}} p \\ + \frac{dp}{dt} &= \phi \frac{p_{\infty} - p}{\tau_p} \\ + p_{\infty} &=\frac{1}{1+\exp ((V+75) / 5.5)} \\ + \tau_{p} &=\frac{1}{\exp (-0.086 V-14.59)+\exp (0.0701 V-1.87)} + \end{aligned} + + where :math:`\phi=1` is a temperature-dependent factor. + + Parameters + ---------- + g_max : float + The maximal conductance density (:math:`mS/cm^2`). + E : float + The reversal potential (mV). + phi : float + The temperature-dependent factor. + + References + ---------- + .. [1] Huguenard, John R., and David A. McCormick. "Simulation of the currents + involved in rhythmic oscillations in thalamic relay neurons." Journal + of neurophysiology 68, no. 4 (1992): 1373-1383. + + """ + __module__ = 'dendritex.channels' + + root_type = HHTypedNeuron + + def __init__( + self, + size: bst.typing.Size, + g_max: Union[bst.typing.ArrayLike, Callable] = 10. * (bu.mS / bu.cm ** 2), + E: Union[bst.typing.ArrayLike, Callable] = 43. * bu.mV, + phi: Union[bst.typing.ArrayLike, Callable] = 1., + name: Optional[str] = None, + ): + super().__init__(size=size, name=name, ) + + # parameters + self.phi = bst.init.param(phi, self.varshape, allow_none=False) + self.g_max = bst.init.param(g_max, self.varshape, allow_none=False) + self.E = bst.init.param(E, self.varshape, allow_none=False) + + def init_state(self, V, batch_size=None): + self.p = State4Integral(bst.init.param(bu.math.zeros, self.varshape, batch_size)) + + def reset_state(self, V, batch_size=None): + self.p.value = self.f_p_inf(V) + + def before_integral(self, V): + pass + + def compute_derivative(self, V): + self.p.derivative = self.phi * (self.f_p_inf(V) - self.p.value) / self.f_p_tau(V) / bu.ms + + def post_derivative(self, V): + pass + + def current(self, V): + return self.g_max * self.p.value * (self.E - V) + + def f_p_inf(self, V): + V = V.to_decimal(bu.mV) + return 1. / (1. + bu.math.exp((V + 75.) / 5.5)) + + def f_p_tau(self, V): + V = V.to_decimal(bu.mV) + return 1. / (bu.math.exp(-0.086 * V - 14.59) + bu.math.exp(0.0701 * V - 1.87)) + # class Ih_De1996(Channel): # r""" @@ -248,200 +244,185 @@ def f_p_tau(self, V): class Ih1_Ma2020(Channel): - r""" - TITLE Cerebellum Golgi Cell Model - - COMMENT - - Author:L. Forti & S. Solinas - Data from: Santoro et al. J Neurosci. 2000 - Last revised: April 2006 - - From Golgi_hcn1 to HCN1 - - """ - __module__ = 'dendritex.channels' - - root_type = HHTypedNeuron - - def __init__( - self, - size: bst.typing.Size, - g_max: Union[bst.typing.ArrayLike, Callable] = 5e-2 * (bu.mS / bu.cm ** 2), - E: Union[bst.typing.ArrayLike, Callable] = -20 * bu.mV, - V_sh: Union[bst.typing.ArrayLike, Callable] = 0. * bu.mV, - T_base_g: bst.typing.ArrayLike = 1.5, - T_base_channel: bst.typing.ArrayLike = 3., - T: bst.typing.ArrayLike = 22, - name: Optional[str] = None, - mode: Optional[bst.mixin.Mode] = None, - ): - super().__init__( - size, - name=name, - mode=mode - ) - - # parameters - self.g_max = bst.init.param(g_max, self.varshape, allow_none=False) - self.E = bst.init.param(E, self.varshape, allow_none=False) - self.T = bst.init.param(T, self.varshape, allow_none=False) - self.T_base_g = bst.init.param(T_base_g, self.varshape, allow_none=False) - self.T_base_channel = bst.init.param(T_base_channel, self.varshape, allow_none=False) - self.phi_g = bst.init.param(T_base_g ** ((T - 23) / 10), self.varshape, allow_none=False) - self.phi_channel = bst.init.param(T_base_channel ** ((T - 23) / 10), self.varshape, allow_none=False) - self.V_sh = bst.init.param(V_sh, self.varshape, allow_none=False) - - - self.Ehalf = -72.49 - self.c = 0.11305 - - self.rA = 0.002096 - self.rB = 0.97596 - self.tCf = 0.01371 - self.tDf = -3.368 - self.tEf = 2.302585092 - self.tCs = 0.01451 - self.tDs = -4.056 - self.tEs = 2.302585092 - - def init_state(self, V, batch_size=None): - self.p = State4Integral(bst.init.param(bu.math.zeros, self.varshape, batch_size)) - self.q = State4Integral(bst.init.param(bu.math.zeros, self.varshape, batch_size)) - - def reset_state(self, V, batch_size=None): - self.p.value = self.f_p_inf(V) - self.q.value = self.f_q_inf(V) - - def before_integral(self, V): - pass - - def compute_derivative(self, V): - self.p.derivative = self.phi_channel * (self.f_p_inf(V) - self.p.value) / self.f_p_tau(V) / bu.ms - self.q.derivative = self.phi_channel * (self.f_q_inf(V) - self.q.value) / self.f_q_tau(V) / bu.ms - - def after_integral(self, V): - pass - - def current(self, V): - return self.phi_g * self.g_max * (self.p.value + self.q.value) * (self.E - V) - - def f_p_inf(self, V): - V = (V - self.V_sh) / bu.mV - return self.r(V) / (1 + bu.math.exp((V - self.Ehalf) * self.c)) - - def f_q_inf(self, V): - V = (V - self.V_sh) / bu.mV - return (1 - self.r(V)) / (1 + bu.math.exp((V - self.Ehalf) * self.c)) - def f_p_tau(self, V): - V = (V - self.V_sh) / bu.mV - return bu.math.exp(((self.tCf * V) - self.tDf) * self.tEf) - - def f_q_tau(self, V): - V = (V - self.V_sh) / bu.mV - return bu.math.exp(((self.tCs * V) - self.tDs) * self.tEs) - - def r(self, V): - return self.rA * V + self.rB + r""" + TITLE Cerebellum Golgi Cell Model + + COMMENT + + Author:L. Forti & S. Solinas + Data from: Santoro et al. J Neurosci. 2000 + Last revised: April 2006 + + From Golgi_hcn1 to HCN1 + + """ + __module__ = 'dendritex.channels' + + root_type = HHTypedNeuron + + def __init__( + self, + size: bst.typing.Size, + g_max: Union[bst.typing.ArrayLike, Callable] = 5e-2 * (bu.mS / bu.cm ** 2), + E: Union[bst.typing.ArrayLike, Callable] = -20 * bu.mV, + V_sh: Union[bst.typing.ArrayLike, Callable] = 0. * bu.mV, + T_base_g: bst.typing.ArrayLike = 1.5, + T_base_channel: bst.typing.ArrayLike = 3., + T: bst.typing.ArrayLike = 22, + name: Optional[str] = None, + ): + super().__init__(size=size, name=name, ) + # parameters + self.g_max = bst.init.param(g_max, self.varshape, allow_none=False) + self.E = bst.init.param(E, self.varshape, allow_none=False) + self.T = bst.init.param(T, self.varshape, allow_none=False) + self.T_base_g = bst.init.param(T_base_g, self.varshape, allow_none=False) + self.T_base_channel = bst.init.param(T_base_channel, self.varshape, allow_none=False) + self.phi_g = bst.init.param(T_base_g ** ((T - 23) / 10), self.varshape, allow_none=False) + self.phi_channel = bst.init.param(T_base_channel ** ((T - 23) / 10), self.varshape, allow_none=False) + self.V_sh = bst.init.param(V_sh, self.varshape, allow_none=False) + + self.Ehalf = -72.49 + self.c = 0.11305 + + self.rA = 0.002096 + self.rB = 0.97596 + self.tCf = 0.01371 + self.tDf = -3.368 + self.tEf = 2.302585092 + self.tCs = 0.01451 + self.tDs = -4.056 + self.tEs = 2.302585092 + + def init_state(self, V, batch_size=None): + self.p = State4Integral(bst.init.param(bu.math.zeros, self.varshape, batch_size)) + self.q = State4Integral(bst.init.param(bu.math.zeros, self.varshape, batch_size)) + + def reset_state(self, V, batch_size=None): + self.p.value = self.f_p_inf(V) + self.q.value = self.f_q_inf(V) + + def before_integral(self, V): + pass + + def compute_derivative(self, V): + self.p.derivative = self.phi_channel * (self.f_p_inf(V) - self.p.value) / self.f_p_tau(V) / bu.ms + self.q.derivative = self.phi_channel * (self.f_q_inf(V) - self.q.value) / self.f_q_tau(V) / bu.ms + + def post_derivative(self, V): + pass + + def current(self, V): + return self.phi_g * self.g_max * (self.p.value + self.q.value) * (self.E - V) + + def f_p_inf(self, V): + V = (V - self.V_sh) / bu.mV + return self.r(V) / (1 + bu.math.exp((V - self.Ehalf) * self.c)) + + def f_q_inf(self, V): + V = (V - self.V_sh) / bu.mV + return (1 - self.r(V)) / (1 + bu.math.exp((V - self.Ehalf) * self.c)) + + def f_p_tau(self, V): + V = (V - self.V_sh) / bu.mV + return bu.math.exp(((self.tCf * V) - self.tDf) * self.tEf) + + def f_q_tau(self, V): + V = (V - self.V_sh) / bu.mV + return bu.math.exp(((self.tCs * V) - self.tDs) * self.tEs) + + def r(self, V): + return self.rA * V + self.rB class Ih2_Ma2020(Channel): - r""" - TITLE Cerebellum Golgi Cell Model - - COMMENT - - Author:L. Forti & S. Solinas - Data from: Santoro et al. J Neurosci. 2000 - Last revised: April 2006 - """ - __module__ = 'dendritex.channels' - - root_type = HHTypedNeuron - - def __init__( - self, - size: bst.typing.Size, - g_max: Union[bst.typing.ArrayLike, Callable] = 8e-2 * (bu.mS / bu.cm ** 2), - E: Union[bst.typing.ArrayLike, Callable] = -20 * bu.mV, - V_sh: Union[bst.typing.ArrayLike, Callable] = 0. * bu.mV, - T_base_g: bst.typing.ArrayLike = 1.5, - T_base_channel: bst.typing.ArrayLike = 3., - T: bst.typing.ArrayLike = 22, - name: Optional[str] = None, - mode: Optional[bst.mixin.Mode] = None, - ): - super().__init__( - size, - name=name, - mode=mode - ) - - # parameters - self.g_max = bst.init.param(g_max, self.varshape, allow_none=False) - self.E = bst.init.param(E, self.varshape, allow_none=False) - self.T = bst.init.param(T, self.varshape, allow_none=False) - self.T_base_g = bst.init.param(T_base_g, self.varshape, allow_none=False) - self.T_base_channel = bst.init.param(T_base_channel, self.varshape, allow_none=False) - self.phi_g = bst.init.param(T_base_g ** ((T - 23) / 10), self.varshape, allow_none=False) - self.phi_channel = bst.init.param(T_base_channel ** ((T - 23) / 10), self.varshape, allow_none=False) - self.V_sh = bst.init.param(V_sh, self.varshape, allow_none=False) - - - self.Ehalf = -81.95 - self.c = 0.1661 - - self.rA = -0.0227 - self.rB = -1.4694 - self.tCf = 0.0269 - self.tDf = -5.6111 - self.tEf = 2.3026 - self.tCs = 0.0152 - self.tDs = -5.2944 - self.tEs = 2.3026 - - - def init_state(self, V, batch_size=None): - self.p = State4Integral(bst.init.param(bu.math.zeros, self.varshape, batch_size)) - self.q = State4Integral(bst.init.param(bu.math.zeros, self.varshape, batch_size)) - - def reset_state(self, V, batch_size=None): - self.p.value = self.f_p_inf(V) - self.q.value = self.f_q_inf(V) - - def before_integral(self, V): - pass - - def compute_derivative(self, V): - self.p.derivative = self.phi_channel * (self.f_p_inf(V) - self.p.value) / self.f_p_tau(V) / bu.ms - self.q.derivative = self.phi_channel * (self.f_q_inf(V) - self.q.value) / self.f_q_tau(V) / bu.ms - - def after_integral(self, V): - pass - - def current(self, V): - return self.phi_g * self.g_max * (self.p.value + self.q.value) * (self.E - V) - - def f_p_inf(self, V): - V = (V - self.V_sh) / bu.mV - return self.r(V,self.rA,self.rB) / (1 + bu.math.exp((V - self.Ehalf) * self.c)) - - def f_q_inf(self, V): - V = (V - self.V_sh) / bu.mV - return (1 - self.r(V,self.rA,self.rB)) / (1 + bu.math.exp((V - self.Ehalf) * self.c)) - def f_p_tau(self, V): - V = (V - self.V_sh) / bu.mV - return bu.math.exp(((self.tCf * V) - self.tDf) * self.tEf) - - def f_q_tau(self, V): - V = (V - self.V_sh) / bu.mV - return bu.math.exp(((self.tCs * V) - self.tDs) * self.tEs) - - def r(self, V, r1, r2): - - return bu.math.where(V >= -64.70, - 0, - bu.math.where(V <= -108.70, - 1, - r1 * V + r2)) - + r""" + TITLE Cerebellum Golgi Cell Model + + COMMENT + + Author:L. Forti & S. Solinas + Data from: Santoro et al. J Neurosci. 2000 + Last revised: April 2006 + """ + __module__ = 'dendritex.channels' + + root_type = HHTypedNeuron + + def __init__( + self, + size: bst.typing.Size, + g_max: Union[bst.typing.ArrayLike, Callable] = 8e-2 * (bu.mS / bu.cm ** 2), + E: Union[bst.typing.ArrayLike, Callable] = -20 * bu.mV, + V_sh: Union[bst.typing.ArrayLike, Callable] = 0. * bu.mV, + T_base_g: bst.typing.ArrayLike = 1.5, + T_base_channel: bst.typing.ArrayLike = 3., + T: bst.typing.ArrayLike = 22, + name: Optional[str] = None, + ): + super().__init__(size=size, name=name, ) + # parameters + self.g_max = bst.init.param(g_max, self.varshape, allow_none=False) + self.E = bst.init.param(E, self.varshape, allow_none=False) + self.T = bst.init.param(T, self.varshape, allow_none=False) + self.T_base_g = bst.init.param(T_base_g, self.varshape, allow_none=False) + self.T_base_channel = bst.init.param(T_base_channel, self.varshape, allow_none=False) + self.phi_g = bst.init.param(T_base_g ** ((T - 23) / 10), self.varshape, allow_none=False) + self.phi_channel = bst.init.param(T_base_channel ** ((T - 23) / 10), self.varshape, allow_none=False) + self.V_sh = bst.init.param(V_sh, self.varshape, allow_none=False) + + self.Ehalf = -81.95 + self.c = 0.1661 + + self.rA = -0.0227 + self.rB = -1.4694 + self.tCf = 0.0269 + self.tDf = -5.6111 + self.tEf = 2.3026 + self.tCs = 0.0152 + self.tDs = -5.2944 + self.tEs = 2.3026 + + def init_state(self, V, batch_size=None): + self.p = State4Integral(bst.init.param(bu.math.zeros, self.varshape, batch_size)) + self.q = State4Integral(bst.init.param(bu.math.zeros, self.varshape, batch_size)) + + def reset_state(self, V, batch_size=None): + self.p.value = self.f_p_inf(V) + self.q.value = self.f_q_inf(V) + + def before_integral(self, V): + pass + + def compute_derivative(self, V): + self.p.derivative = self.phi_channel * (self.f_p_inf(V) - self.p.value) / self.f_p_tau(V) / bu.ms + self.q.derivative = self.phi_channel * (self.f_q_inf(V) - self.q.value) / self.f_q_tau(V) / bu.ms + + def post_derivative(self, V): + pass + + def current(self, V): + return self.phi_g * self.g_max * (self.p.value + self.q.value) * (self.E - V) + + def f_p_inf(self, V): + V = (V - self.V_sh) / bu.mV + return self.r(V, self.rA, self.rB) / (1 + bu.math.exp((V - self.Ehalf) * self.c)) + + def f_q_inf(self, V): + V = (V - self.V_sh) / bu.mV + return (1 - self.r(V, self.rA, self.rB)) / (1 + bu.math.exp((V - self.Ehalf) * self.c)) + + def f_p_tau(self, V): + V = (V - self.V_sh) / bu.mV + return bu.math.exp(((self.tCf * V) - self.tDf) * self.tEf) + + def f_q_tau(self, V): + V = (V - self.V_sh) / bu.mV + return bu.math.exp(((self.tCs * V) - self.tDs) * self.tEs) + + def r(self, V, r1, r2): + return bu.math.where(V >= -64.70, + 0, + bu.math.where(V <= -108.70, + 1, + r1 * V + r2)) diff --git a/dendritex/channels/leaky.py b/dendritex/channels/leaky.py index 7009d0f..672ed70 100644 --- a/dendritex/channels/leaky.py +++ b/dendritex/channels/leaky.py @@ -12,69 +12,64 @@ import brainstate as bst import brainunit as bu -from .._base import HHTypedNeuron, Channel +from dendritex._base import HHTypedNeuron, Channel __all__ = [ - 'LeakageChannel', - 'IL', + 'LeakageChannel', + 'IL', ] class LeakageChannel(Channel): - """ - Base class for leakage channel dynamics. - """ - __module__ = 'dendritex.channels' + """ + Base class for leakage channel dynamics. + """ + __module__ = 'dendritex.channels' - root_type = HHTypedNeuron + root_type = HHTypedNeuron - def before_integral(self, V): - pass + def before_integral(self, V): + pass - def after_integral(self, V): - pass + def post_derivative(self, V): + pass - def compute_derivative(self, V): - pass + def compute_derivative(self, V): + pass - def current(self, V): - raise NotImplementedError + def current(self, V): + raise NotImplementedError - def init_state(self, V, batch_size: int = None): - pass + def init_state(self, V, batch_size: int = None): + pass - def reset_state(self, V, batch_size: int = None): - pass + def reset_state(self, V, batch_size: int = None): + pass class IL(LeakageChannel): - """The leakage channel current. - - Parameters - ---------- - g_max : float - The leakage conductance. - E : float - The reversal potential. - """ - __module__ = 'dendritex.channels' - - def __init__( - self, - size: Union[int, Sequence[int]], - g_max: Union[bst.typing.ArrayLike, Callable] = 0.1 * (bu.mS / bu.cm ** 2), - E: Union[bst.typing.ArrayLike, Callable] = -70. * bu.mV, - name: Optional[str] = None, - mode: Optional[bst.mixin.Mode] = None, - ): - super().__init__( - size, - name=name, - mode=mode - ) - - self.E = bst.init.param(E, self.varshape, allow_none=False) - self.g_max = bst.init.param(g_max, self.varshape, allow_none=False) - - def current(self, V): - return self.g_max * (self.E - V) + """The leakage channel current. + + Parameters + ---------- + g_max : float + The leakage conductance. + E : float + The reversal potential. + """ + __module__ = 'dendritex.channels' + + def __init__( + self, + size: Union[int, Sequence[int]], + g_max: Union[bst.typing.ArrayLike, Callable] = 0.1 * (bu.mS / bu.cm ** 2), + E: Union[bst.typing.ArrayLike, Callable] = -70. * bu.mV, + name: Optional[str] = None, + ): + super().__init__(size=size, name=name, ) + + self.E = bst.init.param(E, self.varshape, allow_none=False) + self.g_max = bst.init.param(g_max, self.varshape, allow_none=False) + + def current(self, V): + return self.g_max * (self.E - V) diff --git a/dendritex/channels/potassium.py b/dendritex/channels/potassium.py index a8c4e39..d943407 100644 --- a/dendritex/channels/potassium.py +++ b/dendritex/channels/potassium.py @@ -12,1373 +12,1315 @@ import brainstate as bst import brainunit as bu -from .._base import Channel, IonInfo, State4Integral -from ..ions import Potassium +from dendritex._base import Channel, IonInfo, State4Integral +from dendritex.ions import Potassium __all__ = [ - 'PotassiumChannel', - 'IKDR_Ba2002', - 'IK_TM1991', - 'IK_HH1952', - 'IKA1_HM1992', - 'IKA2_HM1992', - 'IKK2A_HM1992', - 'IKK2B_HM1992', - 'IKNI_Ya1989', - 'IK_Leak', - 'IKv11_Ak2007', - 'IKv34_Ma2020', - 'IKv43_Ma2020', - 'IKM_Grc_Ma2020', - + 'PotassiumChannel', + 'IKDR_Ba2002', + 'IK_TM1991', + 'IK_HH1952', + 'IKA1_HM1992', + 'IKA2_HM1992', + 'IKK2A_HM1992', + 'IKK2B_HM1992', + 'IKNI_Ya1989', + 'IK_Leak', + 'IKv11_Ak2007', + 'IKv34_Ma2020', + 'IKv43_Ma2020', + 'IKM_Grc_Ma2020', ] class PotassiumChannel(Channel): - """Base class for sodium channel dynamics.""" - __module__ = 'dendritex.channels' + """Base class for sodium channel dynamics.""" + __module__ = 'dendritex.channels' - root_type = Potassium + root_type = Potassium - def before_integral(self, V, K: IonInfo): - pass + def before_integral(self, V, K: IonInfo): + pass - def after_integral(self, V, K: IonInfo): - pass + def post_derivative(self, V, K: IonInfo): + pass - def compute_derivative(self, V, K: IonInfo): - pass + def compute_derivative(self, V, K: IonInfo): + pass - def current(self, V, K: IonInfo): - raise NotImplementedError + def current(self, V, K: IonInfo): + raise NotImplementedError - def init_state(self, V, K: IonInfo, batch_size: int = None): - pass + def init_state(self, V, K: IonInfo, batch_size: int = None): + pass - def reset_state(self, V, K: IonInfo, batch_size: int = None): - pass + def reset_state(self, V, K: IonInfo, batch_size: int = None): + pass class _IK_p4_markov(PotassiumChannel): - r"""The delayed rectifier potassium channel of :math:`p^4` - current which described with first-order Markov chain. + r"""The delayed rectifier potassium channel of :math:`p^4` + current which described with first-order Markov chain. - This general potassium current model should have the form of + This general potassium current model should have the form of - .. math:: + .. math:: - \begin{aligned} - I_{\mathrm{K}} &= g_{\mathrm{max}} * p^4 \\ - \frac{dp}{dt} &= \phi * (\alpha_p (1-p) - \beta_p p) - \end{aligned} + \begin{aligned} + I_{\mathrm{K}} &= g_{\mathrm{max}} * p^4 \\ + \frac{dp}{dt} &= \phi * (\alpha_p (1-p) - \beta_p p) + \end{aligned} - where :math:`\phi` is a temperature-dependent factor. - - Parameters - ---------- - size: int, sequence of int - The object size. - g_max : bst.typing.ArrayLike, Callable - The maximal conductance density (:math:`mS/cm^2`). - phi : bst.typing.ArrayLike, Callable - The temperature-dependent factor. - name: Optional[str] - The object name. - - """ - - def __init__( - self, - size: Union[int, Sequence[int]], - g_max: Union[bst.typing.ArrayLike, Callable] = 10. * (bu.mS / bu.cm ** 2), - phi: Union[bst.typing.ArrayLike, Callable] = 1., - name: Optional[str] = None, - mode: Optional[bst.mixin.Mode] = None, - ): - super().__init__( - size, - name=name, - mode=mode - ) - - self.g_max = bst.init.param(g_max, self.varshape, allow_none=False) - self.phi = bst.init.param(phi, self.varshape, allow_none=False) - - def init_state(self, V, K: IonInfo, batch_size=None): - self.p = State4Integral(bst.init.param(bu.math.zeros, self.varshape, batch_size)) - - def reset_state(self, V, K: IonInfo, batch_size: int = None): - alpha = self.f_p_alpha(V) - beta = self.f_p_beta(V) - self.p.value = alpha / (alpha + beta) - if isinstance(batch_size, int): - assert self.p.value.shape[0] == batch_size - - def compute_derivative(self, V, K: IonInfo): - p = self.p.value - dp = self.phi * (self.f_p_alpha(V) * (1. - p) - self.f_p_beta(V) * p) / bu.ms - self.p.derivative = dp - - def current(self, V, K: IonInfo): - return self.g_max * self.p.value ** 4 * (K.E - V) - - def f_p_alpha(self, V): - raise NotImplementedError - - def f_p_beta(self, V): - raise NotImplementedError + where :math:`\phi` is a temperature-dependent factor. + Parameters + ---------- + size: int, sequence of int + The object size. + g_max : bst.typing.ArrayLike, Callable + The maximal conductance density (:math:`mS/cm^2`). + phi : bst.typing.ArrayLike, Callable + The temperature-dependent factor. + name: Optional[str] + The object name. -class IKDR_Ba2002(_IK_p4_markov): - r"""The delayed rectifier potassium channel current. - - The potassium current model is adopted from (Bazhenov, et, al. 2002) [1]_. - It's dynamics is given by: + """ - .. math:: + def __init__( + self, + size: Union[int, Sequence[int]], + g_max: Union[bst.typing.ArrayLike, Callable] = 10. * (bu.mS / bu.cm ** 2), + phi: Union[bst.typing.ArrayLike, Callable] = 1., + name: Optional[str] = None, + ): + super().__init__(size=size, name=name, ) - \begin{aligned} - I_{\mathrm{K}} &= g_{\mathrm{max}} * p^4 \\ - \frac{dp}{dt} &= \phi * (\alpha_p (1-p) - \beta_p p) \\ - \alpha_{p} &=\frac{0.032\left(V-V_{sh}-15\right)}{1-\exp \left(-\left(V-V_{sh}-15\right) / 5\right)} \\ - \beta_p &= 0.5 \exp \left(-\left(V-V_{sh}-10\right) / 40\right) - \end{aligned} + self.g_max = bst.init.param(g_max, self.varshape, allow_none=False) + self.phi = bst.init.param(phi, self.varshape, allow_none=False) - where :math:`\phi` is a temperature-dependent factor, which is given by - :math:`\phi=3^{\frac{T-36}{10}}` (:math:`T` is the temperature in Celsius). - - Parameters - ---------- - size: int, sequence of int - The object size. - g_max : bst.typing.ArrayLike, Callable - The maximal conductance density (:math:`mS/cm^2`). - T_base : float, ArrayType - The brainpy_object of temperature factor. - T : bst.typing.ArrayLike, Callable - The temperature (Celsius, :math:`^{\circ}C`). - V_sh : bst.typing.ArrayLike, Callable - The shift of the membrane potential to spike. - name: Optional[str] - The object name. - - References - ---------- - .. [1] Bazhenov, Maxim, et al. "Model of thalamocortical slow-wave sleep oscillations - and transitions to activated states." Journal of neuroscience 22.19 (2002): 8691-8704. - - """ - __module__ = 'dendritex.channels' - - def __init__( - self, - size: Union[int, Sequence[int]], - g_max: Union[bst.typing.ArrayLike, Callable] = 10. * (bu.mS / bu.cm ** 2), - V_sh: Union[bst.typing.ArrayLike, Callable] = -50. * bu.mV, - T_base: bst.typing.ArrayLike = 3., - T: bst.typing.ArrayLike = 36., - phi: Optional[Union[bst.typing.ArrayLike, Callable]] = None, - name: Optional[str] = None, - mode: Optional[bst.mixin.Mode] = None, - ): - phi = T_base ** ((T - 36) / 10) if phi is None else phi - super().__init__( - size, - name=name, - g_max=g_max, - phi=phi, - mode=mode - ) - - # parameters - self.T = bst.init.param(T, self.varshape, allow_none=False) - self.T_base = bst.init.param(T_base, self.varshape, allow_none=False) - self.V_sh = bst.init.param(V_sh, self.varshape, allow_none=False) - - def f_p_alpha(self, V): - V = (V - self.V_sh).to_decimal(bu.mV) - tmp = V - 15. - return 0.032 * tmp / (1. - bu.math.exp(-tmp / 5.)) - - def f_p_beta(self, V): - V = (V - self.V_sh).to_decimal(bu.mV) - return 0.5 * bu.math.exp(-(V - 10.) / 40.) + def init_state(self, V, K: IonInfo, batch_size=None): + self.p = State4Integral(bst.init.param(bu.math.zeros, self.varshape, batch_size)) + def reset_state(self, V, K: IonInfo, batch_size: int = None): + alpha = self.f_p_alpha(V) + beta = self.f_p_beta(V) + self.p.value = alpha / (alpha + beta) + if isinstance(batch_size, int): + assert self.p.value.shape[0] == batch_size -class IK_TM1991(_IK_p4_markov): - r"""The potassium channel described by (Traub and Miles, 1991) [1]_. + def compute_derivative(self, V, K: IonInfo): + p = self.p.value + dp = self.phi * (self.f_p_alpha(V) * (1. - p) - self.f_p_beta(V) * p) / bu.ms + self.p.derivative = dp - The dynamics of this channel is given by: + def current(self, V, K: IonInfo): + return self.g_max * self.p.value ** 4 * (K.E - V) - .. math:: + def f_p_alpha(self, V): + raise NotImplementedError - \begin{aligned} - I_{\mathrm{K}} &= g_{\mathrm{max}} * p^4 \\ - \frac{dp}{dt} &= \phi * (\alpha_p (1-p) - \beta_p p) \\ - \alpha_{p} &= 0.032 \frac{(15 - V + V_{sh})}{(\exp((15 - V + V_{sh}) / 5) - 1.)} \\ - \beta_p &= 0.5 * \exp((10 - V + V_{sh}) / 40) - \end{aligned} + def f_p_beta(self, V): + raise NotImplementedError - where :math:`V_{sh}` is the membrane shift (default -63 mV), and - :math:`\phi` is the temperature-dependent factor (default 1.). - - Parameters - ---------- - size: int, sequence of int - The geometry size. - g_max : bst.typing.ArrayLike, Callable - The maximal conductance density (:math:`mS/cm^2`). - name: Optional[str] - The object name. - - References - ---------- - .. [1] Traub, Roger D., and Richard Miles. Neuronal networks of the hippocampus. - Vol. 777. Cambridge University Press, 1991. - - See Also - -------- - INa_TM1991 - """ - __module__ = 'dendritex.channels' - - def __init__( - self, - size: Union[int, Sequence[int]], - g_max: Union[bst.typing.ArrayLike, Callable] = 10. * (bu.mS / bu.cm ** 2), - phi: Union[bst.typing.ArrayLike, Callable] = 1., - V_sh: Union[int, bst.typing.ArrayLike, Callable] = -60. * bu.mV, - name: Optional[str] = None, - mode: Optional[bst.mixin.Mode] = None, - ): - super().__init__( - size, - name=name, - phi=phi, - g_max=g_max, - mode=mode - ) - self.V_sh = bst.init.param(V_sh, self.varshape, allow_none=False) - - def f_p_alpha(self, V): - c = 15 + (- V + self.V_sh).to_decimal(bu.mV) - return 0.032 * c / (bu.math.exp(c / 5) - 1.) - - def f_p_beta(self, V): - V = (self.V_sh - V).to_decimal(bu.mV) - return 0.5 * bu.math.exp((10 + V) / 40) +class IKDR_Ba2002(_IK_p4_markov): + r"""The delayed rectifier potassium channel current. + + The potassium current model is adopted from (Bazhenov, et, al. 2002) [1]_. + It's dynamics is given by: + + .. math:: + + \begin{aligned} + I_{\mathrm{K}} &= g_{\mathrm{max}} * p^4 \\ + \frac{dp}{dt} &= \phi * (\alpha_p (1-p) - \beta_p p) \\ + \alpha_{p} &=\frac{0.032\left(V-V_{sh}-15\right)}{1-\exp \left(-\left(V-V_{sh}-15\right) / 5\right)} \\ + \beta_p &= 0.5 \exp \left(-\left(V-V_{sh}-10\right) / 40\right) + \end{aligned} + + where :math:`\phi` is a temperature-dependent factor, which is given by + :math:`\phi=3^{\frac{T-36}{10}}` (:math:`T` is the temperature in Celsius). + + Parameters + ---------- + size: int, sequence of int + The object size. + g_max : bst.typing.ArrayLike, Callable + The maximal conductance density (:math:`mS/cm^2`). + T_base : float, ArrayType + The brainpy_object of temperature factor. + T : bst.typing.ArrayLike, Callable + The temperature (Celsius, :math:`^{\circ}C`). + V_sh : bst.typing.ArrayLike, Callable + The shift of the membrane potential to spike. + name: Optional[str] + The object name. + + References + ---------- + .. [1] Bazhenov, Maxim, et al. "Model of thalamocortical slow-wave sleep oscillations + and transitions to activated states." Journal of neuroscience 22.19 (2002): 8691-8704. + + """ + __module__ = 'dendritex.channels' + + def __init__( + self, + size: Union[int, Sequence[int]], + g_max: Union[bst.typing.ArrayLike, Callable] = 10. * (bu.mS / bu.cm ** 2), + V_sh: Union[bst.typing.ArrayLike, Callable] = -50. * bu.mV, + T_base: bst.typing.ArrayLike = 3., + T: bst.typing.ArrayLike = 36., + phi: Optional[Union[bst.typing.ArrayLike, Callable]] = None, + name: Optional[str] = None, + ): + phi = T_base ** ((T - 36) / 10) if phi is None else phi + super().__init__( + size, + name=name, + g_max=g_max, + phi=phi, + ) + + # parameters + self.T = bst.init.param(T, self.varshape, allow_none=False) + self.T_base = bst.init.param(T_base, self.varshape, allow_none=False) + self.V_sh = bst.init.param(V_sh, self.varshape, allow_none=False) + + def f_p_alpha(self, V): + V = (V - self.V_sh).to_decimal(bu.mV) + tmp = V - 15. + return 0.032 * tmp / (1. - bu.math.exp(-tmp / 5.)) + + def f_p_beta(self, V): + V = (V - self.V_sh).to_decimal(bu.mV) + return 0.5 * bu.math.exp(-(V - 10.) / 40.) -class IK_HH1952(_IK_p4_markov): - r"""The potassium channel described by Hodgkin–Huxley model [1]_. - The dynamics of this channel is given by: +class IK_TM1991(_IK_p4_markov): + r"""The potassium channel described by (Traub and Miles, 1991) [1]_. + + The dynamics of this channel is given by: + + .. math:: + + \begin{aligned} + I_{\mathrm{K}} &= g_{\mathrm{max}} * p^4 \\ + \frac{dp}{dt} &= \phi * (\alpha_p (1-p) - \beta_p p) \\ + \alpha_{p} &= 0.032 \frac{(15 - V + V_{sh})}{(\exp((15 - V + V_{sh}) / 5) - 1.)} \\ + \beta_p &= 0.5 * \exp((10 - V + V_{sh}) / 40) + \end{aligned} + + where :math:`V_{sh}` is the membrane shift (default -63 mV), and + :math:`\phi` is the temperature-dependent factor (default 1.). + + Parameters + ---------- + size: int, sequence of int + The geometry size. + g_max : bst.typing.ArrayLike, Callable + The maximal conductance density (:math:`mS/cm^2`). + name: Optional[str] + The object name. + + References + ---------- + .. [1] Traub, Roger D., and Richard Miles. Neuronal networks of the hippocampus. + Vol. 777. Cambridge University Press, 1991. + + See Also + -------- + INa_TM1991 + """ + __module__ = 'dendritex.channels' + + def __init__( + self, + size: Union[int, Sequence[int]], + g_max: Union[bst.typing.ArrayLike, Callable] = 10. * (bu.mS / bu.cm ** 2), + phi: Union[bst.typing.ArrayLike, Callable] = 1., + V_sh: Union[int, bst.typing.ArrayLike, Callable] = -60. * bu.mV, + name: Optional[str] = None, + ): + super().__init__( + size, + name=name, + phi=phi, + g_max=g_max, + ) + self.V_sh = bst.init.param(V_sh, self.varshape, allow_none=False) + + def f_p_alpha(self, V): + c = 15 + (- V + self.V_sh).to_decimal(bu.mV) + return 0.032 * c / (bu.math.exp(c / 5) - 1.) + + def f_p_beta(self, V): + V = (self.V_sh - V).to_decimal(bu.mV) + return 0.5 * bu.math.exp((10 + V) / 40) - .. math:: - \begin{aligned} - I_{\mathrm{K}} &= g_{\mathrm{max}} * p^4 \\ - \frac{dp}{dt} &= \phi * (\alpha_p (1-p) - \beta_p p) \\ - \alpha_{p} &= \frac{0.01 (V -V_{sh} + 10)}{1-\exp \left(-\left(V-V_{sh}+ 10\right) / 10\right)} \\ - \beta_p &= 0.125 \exp \left(-\left(V-V_{sh}+20\right) / 80\right) - \end{aligned} - - where :math:`V_{sh}` is the membrane shift (default -45 mV), and - :math:`\phi` is the temperature-dependent factor (default 1.). - - Parameters - ---------- - size: int, sequence of int - The geometry size. - g_max : bst.typing.ArrayLike, Callable - The maximal conductance density (:math:`mS/cm^2`). - name: Optional[str] - The object name. - - References - ---------- - .. [1] Hodgkin, Alan L., and Andrew F. Huxley. "A quantitative description of - membrane current and its application to conduction and excitation in - nerve." The Journal of physiology 117.4 (1952): 500. - - See Also - -------- - INa_HH1952 - """ - __module__ = 'dendritex.channels' - - def __init__( - self, - size: Union[int, Sequence[int]], - g_max: Union[bst.typing.ArrayLike, Callable] = 10. * (bu.mS / bu.cm ** 2), - phi: Union[bst.typing.ArrayLike, Callable] = 1., - V_sh: Union[int, bst.typing.ArrayLike, Callable] = -45. * bu.mV, - name: Optional[str] = None, - mode: Optional[bst.mixin.Mode] = None, - ): - super().__init__( - size, - name=name, - phi=phi, - g_max=g_max, - mode=mode - ) - self.V_sh = bst.init.param(V_sh, self.varshape, allow_none=False) - - def f_p_alpha(self, V): - V = (V - self.V_sh).to_decimal(bu.mV) - temp = V + 10 - return 0.01 * temp / (1 - bu.math.exp(-temp / 10)) - - def f_p_beta(self, V): - V = (V - self.V_sh).to_decimal(bu.mV) - return 0.125 * bu.math.exp(-(V + 20) / 80) +class IK_HH1952(_IK_p4_markov): + r"""The potassium channel described by Hodgkin–Huxley model [1]_. + + The dynamics of this channel is given by: + + .. math:: + + \begin{aligned} + I_{\mathrm{K}} &= g_{\mathrm{max}} * p^4 \\ + \frac{dp}{dt} &= \phi * (\alpha_p (1-p) - \beta_p p) \\ + \alpha_{p} &= \frac{0.01 (V -V_{sh} + 10)}{1-\exp \left(-\left(V-V_{sh}+ 10\right) / 10\right)} \\ + \beta_p &= 0.125 \exp \left(-\left(V-V_{sh}+20\right) / 80\right) + \end{aligned} + + where :math:`V_{sh}` is the membrane shift (default -45 mV), and + :math:`\phi` is the temperature-dependent factor (default 1.). + + Parameters + ---------- + size: int, sequence of int + The geometry size. + g_max : bst.typing.ArrayLike, Callable + The maximal conductance density (:math:`mS/cm^2`). + name: Optional[str] + The object name. + + References + ---------- + .. [1] Hodgkin, Alan L., and Andrew F. Huxley. "A quantitative description of + membrane current and its application to conduction and excitation in + nerve." The Journal of physiology 117.4 (1952): 500. + + See Also + -------- + INa_HH1952 + """ + __module__ = 'dendritex.channels' + + def __init__( + self, + size: Union[int, Sequence[int]], + g_max: Union[bst.typing.ArrayLike, Callable] = 10. * (bu.mS / bu.cm ** 2), + phi: Union[bst.typing.ArrayLike, Callable] = 1., + V_sh: Union[int, bst.typing.ArrayLike, Callable] = -45. * bu.mV, + name: Optional[str] = None, + ): + super().__init__( + size, + name=name, + phi=phi, + g_max=g_max, + ) + self.V_sh = bst.init.param(V_sh, self.varshape, allow_none=False) + + def f_p_alpha(self, V): + V = (V - self.V_sh).to_decimal(bu.mV) + temp = V + 10 + return 0.01 * temp / (1 - bu.math.exp(-temp / 10)) + + def f_p_beta(self, V): + V = (V - self.V_sh).to_decimal(bu.mV) + return 0.125 * bu.math.exp(-(V + 20) / 80) class _IKA_p4q_ss(PotassiumChannel): - r""" - The rapidly inactivating Potassium channel of :math:`p^4q` - current which described with steady-state format. - - This model is developed according to the average behavior of - rapidly inactivating Potassium channel in Thalamus relay neurons [2]_ [3]_. - - .. math:: - - &IA = g_{\mathrm{max}} p^4 q (E-V) \\ - &\frac{dp}{dt} = \phi_p \frac{p_{\infty} - p}{\tau_p} \\ - &\frac{dq}{dt} = \phi_q \frac{q_{\infty} - q}{\tau_q} \\ - - where :math:`\phi_p` and :math:`\phi_q` are the temperature dependent factors (default 1.). - - Parameters - ---------- - size: int, sequence of int - The geometry size. - name: Optional[str] - The object name. - g_max : bst.typing.ArrayLike, Callable - The maximal conductance density (:math:`mS/cm^2`). - phi_p : optional, float, ArrayType, Callable, Initializer - The temperature factor for channel :math:`p`. - phi_q : optional, float, ArrayType, Callable, Initializer - The temperature factor for channel :math:`q`. - - References - ---------- - .. [2] Huguenard, John R., and David A. McCormick. "Simulation of the - currents involved in rhythmic oscillations in thalamic relay - neurons." Journal of neurophysiology 68.4 (1992): 1373-1383. - .. [3] Huguenard, J. R., and D. A. Prince. "Slow inactivation of a - TEA-sensitive K current in acutely isolated rat thalamic relay - neurons." Journal of neurophysiology 66.4 (1991): 1316-1328. - """ - - def __init__( - self, - size: Union[int, Sequence[int]], - g_max: Union[bst.typing.ArrayLike, Callable] = 10. * (bu.mS / bu.cm ** 2), - phi_p: Union[bst.typing.ArrayLike, Callable] = 1., - phi_q: Union[bst.typing.ArrayLike, Callable] = 1., - name: Optional[str] = None, - mode: Optional[bst.mixin.Mode] = None, - ): - super().__init__( - size, - name=name, - mode=mode - ) - - # parameters - self.g_max = bst.init.param(g_max, self.varshape, allow_none=False) - self.phi_p = bst.init.param(phi_p, self.varshape, allow_none=False) - self.phi_q = bst.init.param(phi_q, self.varshape, allow_none=False) - - def compute_derivative(self, V, K: IonInfo): - self.p.derivative = self.phi_p * (self.f_p_inf(V) - self.p.value) / self.f_p_tau(V) / bu.ms - self.q.derivative = self.phi_q * (self.f_q_inf(V) - self.q.value) / self.f_q_tau(V) / bu.ms - - def current(self, V, K: IonInfo): - return self.g_max * self.p.value ** 4 * self.q.value * (K.E - V) - - def init_state(self, V, K: IonInfo, batch_size: int = None): - self.p = State4Integral(bst.init.param(bu.math.zeros, self.varshape, batch_size)) - self.q = State4Integral(bst.init.param(bu.math.zeros, self.varshape, batch_size)) - - def reset_state(self, V, K: IonInfo, batch_size=None): - self.p.value = self.f_p_inf(V) - self.q.value = self.f_q_inf(V) - if isinstance(batch_size, int): - assert self.p.value.shape[0] == batch_size - assert self.q.value.shape[0] == batch_size - - def f_p_inf(self, V): - raise NotImplementedError - - def f_p_tau(self, V): - raise NotImplementedError - - def f_q_inf(self, V): - raise NotImplementedError - - def f_q_tau(self, V): - raise NotImplementedError + r""" + The rapidly inactivating Potassium channel of :math:`p^4q` + current which described with steady-state format. + + This model is developed according to the average behavior of + rapidly inactivating Potassium channel in Thalamus relay neurons [2]_ [3]_. + + .. math:: + + &IA = g_{\mathrm{max}} p^4 q (E-V) \\ + &\frac{dp}{dt} = \phi_p \frac{p_{\infty} - p}{\tau_p} \\ + &\frac{dq}{dt} = \phi_q \frac{q_{\infty} - q}{\tau_q} \\ + + where :math:`\phi_p` and :math:`\phi_q` are the temperature dependent factors (default 1.). + + Parameters + ---------- + size: int, sequence of int + The geometry size. + name: Optional[str] + The object name. + g_max : bst.typing.ArrayLike, Callable + The maximal conductance density (:math:`mS/cm^2`). + phi_p : optional, float, ArrayType, Callable, Initializer + The temperature factor for channel :math:`p`. + phi_q : optional, float, ArrayType, Callable, Initializer + The temperature factor for channel :math:`q`. + + References + ---------- + .. [2] Huguenard, John R., and David A. McCormick. "Simulation of the + currents involved in rhythmic oscillations in thalamic relay + neurons." Journal of neurophysiology 68.4 (1992): 1373-1383. + .. [3] Huguenard, J. R., and D. A. Prince. "Slow inactivation of a + TEA-sensitive K current in acutely isolated rat thalamic relay + neurons." Journal of neurophysiology 66.4 (1991): 1316-1328. + """ + + def __init__( + self, + size: Union[int, Sequence[int]], + g_max: Union[bst.typing.ArrayLike, Callable] = 10. * (bu.mS / bu.cm ** 2), + phi_p: Union[bst.typing.ArrayLike, Callable] = 1., + phi_q: Union[bst.typing.ArrayLike, Callable] = 1., + name: Optional[str] = None, + ): + super().__init__(size=size, name=name, ) + + # parameters + self.g_max = bst.init.param(g_max, self.varshape, allow_none=False) + self.phi_p = bst.init.param(phi_p, self.varshape, allow_none=False) + self.phi_q = bst.init.param(phi_q, self.varshape, allow_none=False) + + def compute_derivative(self, V, K: IonInfo): + self.p.derivative = self.phi_p * (self.f_p_inf(V) - self.p.value) / self.f_p_tau(V) / bu.ms + self.q.derivative = self.phi_q * (self.f_q_inf(V) - self.q.value) / self.f_q_tau(V) / bu.ms + + def current(self, V, K: IonInfo): + return self.g_max * self.p.value ** 4 * self.q.value * (K.E - V) + + def init_state(self, V, K: IonInfo, batch_size: int = None): + self.p = State4Integral(bst.init.param(bu.math.zeros, self.varshape, batch_size)) + self.q = State4Integral(bst.init.param(bu.math.zeros, self.varshape, batch_size)) + + def reset_state(self, V, K: IonInfo, batch_size=None): + self.p.value = self.f_p_inf(V) + self.q.value = self.f_q_inf(V) + if isinstance(batch_size, int): + assert self.p.value.shape[0] == batch_size + assert self.q.value.shape[0] == batch_size + + def f_p_inf(self, V): + raise NotImplementedError + + def f_p_tau(self, V): + raise NotImplementedError + + def f_q_inf(self, V): + raise NotImplementedError + + def f_q_tau(self, V): + raise NotImplementedError class IKA1_HM1992(_IKA_p4q_ss): - r"""The rapidly inactivating Potassium channel (IA1) model proposed by (Huguenard & McCormick, 1992) [2]_. - - This model is developed according to the average behavior of - rapidly inactivating Potassium channel in Thalamus relay neurons [2]_ [1]_. - - .. math:: - - &IA = g_{\mathrm{max}} p^4 q (E-V) \\ - &\frac{dp}{dt} = \phi_p \frac{p_{\infty} - p}{\tau_p} \\ - &p_{\infty} = \frac{1}{1+ \exp[-(V -V_{sh}+ 60)/8.5]} \\ - &\tau_{p}=\frac{1}{\exp \left(\frac{V -V_{sh}+35.8}{19.7}\right)+ \exp \left(\frac{V -V_{sh}+79.7}{-12.7}\right)}+0.37 \\ - &\frac{dq}{dt} = \phi_q \frac{q_{\infty} - q}{\tau_q} \\ - &q_{\infty} = \frac{1}{1+ \exp[(V -V_{sh} + 78)/6]} \\ - &\begin{array}{l} \tau_{q} = \frac{1}{\exp((V -V_{sh}+46)/5.) + \exp((V -V_{sh}+238)/-37.5)} \quad V<(-63+V_{sh})\, mV \\ - \tau_{q} = 19 \quad V \geq (-63 + V_{sh})\, mV \end{array} - - where :math:`\phi_p` and :math:`\phi_q` are the temperature dependent factors (default 1.). - - Parameters - ---------- - size: int, sequence of int - The geometry size. - name: Optional[str] - The object name. - g_max : bst.typing.ArrayLike, Callable - The maximal conductance density (:math:`mS/cm^2`). - V_sh : float, ArrayType, Callable, Initializer - The membrane potential shift. - phi_p : optional, float, ArrayType, Callable, Initializer - The temperature factor for channel :math:`p`. - phi_q : optional, float, ArrayType, Callable, Initializer - The temperature factor for channel :math:`q`. - - References - ---------- - .. [2] Huguenard, John R., and David A. McCormick. "Simulation of the - currents involved in rhythmic oscillations in thalamic relay - neurons." Journal of neurophysiology 68.4 (1992): 1373-1383. - .. [1] Huguenard, J. R., and D. A. Prince. "Slow inactivation of a - TEA-sensitive K current in acutely isolated rat thalamic relay - neurons." Journal of neurophysiology 66.4 (1991): 1316-1328. - - See Also - -------- - IKA2_HM1992 - """ - __module__ = 'dendritex.channels' - - def __init__( - self, - size: Union[int, Sequence[int]], - g_max: Union[bst.typing.ArrayLike, Callable] = 30. * (bu.mS / bu.cm ** 2), - V_sh: Union[bst.typing.ArrayLike, Callable] = 0. * bu.mV, - phi_p: Union[bst.typing.ArrayLike, Callable] = 1., - phi_q: Union[bst.typing.ArrayLike, Callable] = 1., - name: Optional[str] = None, - mode: Optional[bst.mixin.Mode] = None, - ): - super().__init__( - size, - name=name, - g_max=g_max, - phi_p=phi_p, - phi_q=phi_q, - mode=mode - ) - - # parameters - self.V_sh = bst.init.param(V_sh, self.varshape, allow_none=False) - - def f_p_inf(self, V): - V = (V - self.V_sh).to_decimal(bu.mV) - return 1. / (1. + bu.math.exp(-(V + 60.) / 8.5)) - - def f_p_tau(self, V): - V = (V - self.V_sh).to_decimal(bu.mV) - return 1. / (bu.math.exp((V + 35.8) / 19.7) + - bu.math.exp(-(V + 79.7) / 12.7)) + 0.37 - - def f_q_inf(self, V): - V = (V - self.V_sh).to_decimal(bu.mV) - return 1. / (1. + bu.math.exp((V + 78.) / 6.)) - - def f_q_tau(self, V): - V = (V - self.V_sh).to_decimal(bu.mV) - return bu.math.where( - V < -63, - 1. / (bu.math.exp((V + 46.) / 5.) + - bu.math.exp(-(V + 238.) / 37.5)), - 19. - ) + r"""The rapidly inactivating Potassium channel (IA1) model proposed by (Huguenard & McCormick, 1992) [2]_. + + This model is developed according to the average behavior of + rapidly inactivating Potassium channel in Thalamus relay neurons [2]_ [1]_. + + .. math:: + + &IA = g_{\mathrm{max}} p^4 q (E-V) \\ + &\frac{dp}{dt} = \phi_p \frac{p_{\infty} - p}{\tau_p} \\ + &p_{\infty} = \frac{1}{1+ \exp[-(V -V_{sh}+ 60)/8.5]} \\ + &\tau_{p}=\frac{1}{\exp \left(\frac{V -V_{sh}+35.8}{19.7}\right)+ \exp \left(\frac{V -V_{sh}+79.7}{-12.7}\right)}+0.37 \\ + &\frac{dq}{dt} = \phi_q \frac{q_{\infty} - q}{\tau_q} \\ + &q_{\infty} = \frac{1}{1+ \exp[(V -V_{sh} + 78)/6]} \\ + &\begin{array}{l} \tau_{q} = \frac{1}{\exp((V -V_{sh}+46)/5.) + \exp((V -V_{sh}+238)/-37.5)} \quad V<(-63+V_{sh})\, mV \\ + \tau_{q} = 19 \quad V \geq (-63 + V_{sh})\, mV \end{array} + + where :math:`\phi_p` and :math:`\phi_q` are the temperature dependent factors (default 1.). + + Parameters + ---------- + size: int, sequence of int + The geometry size. + name: Optional[str] + The object name. + g_max : bst.typing.ArrayLike, Callable + The maximal conductance density (:math:`mS/cm^2`). + V_sh : float, ArrayType, Callable, Initializer + The membrane potential shift. + phi_p : optional, float, ArrayType, Callable, Initializer + The temperature factor for channel :math:`p`. + phi_q : optional, float, ArrayType, Callable, Initializer + The temperature factor for channel :math:`q`. + + References + ---------- + .. [2] Huguenard, John R., and David A. McCormick. "Simulation of the + currents involved in rhythmic oscillations in thalamic relay + neurons." Journal of neurophysiology 68.4 (1992): 1373-1383. + .. [1] Huguenard, J. R., and D. A. Prince. "Slow inactivation of a + TEA-sensitive K current in acutely isolated rat thalamic relay + neurons." Journal of neurophysiology 66.4 (1991): 1316-1328. + + See Also + -------- + IKA2_HM1992 + """ + __module__ = 'dendritex.channels' + + def __init__( + self, + size: Union[int, Sequence[int]], + g_max: Union[bst.typing.ArrayLike, Callable] = 30. * (bu.mS / bu.cm ** 2), + V_sh: Union[bst.typing.ArrayLike, Callable] = 0. * bu.mV, + phi_p: Union[bst.typing.ArrayLike, Callable] = 1., + phi_q: Union[bst.typing.ArrayLike, Callable] = 1., + name: Optional[str] = None, + ): + super().__init__( + size, + name=name, + g_max=g_max, + phi_p=phi_p, + phi_q=phi_q, + ) + + # parameters + self.V_sh = bst.init.param(V_sh, self.varshape, allow_none=False) + + def f_p_inf(self, V): + V = (V - self.V_sh).to_decimal(bu.mV) + return 1. / (1. + bu.math.exp(-(V + 60.) / 8.5)) + + def f_p_tau(self, V): + V = (V - self.V_sh).to_decimal(bu.mV) + return 1. / (bu.math.exp((V + 35.8) / 19.7) + + bu.math.exp(-(V + 79.7) / 12.7)) + 0.37 + + def f_q_inf(self, V): + V = (V - self.V_sh).to_decimal(bu.mV) + return 1. / (1. + bu.math.exp((V + 78.) / 6.)) + + def f_q_tau(self, V): + V = (V - self.V_sh).to_decimal(bu.mV) + return bu.math.where( + V < -63, + 1. / (bu.math.exp((V + 46.) / 5.) + + bu.math.exp(-(V + 238.) / 37.5)), + 19. + ) class IKA2_HM1992(_IKA_p4q_ss): - r"""The rapidly inactivating Potassium channel (IA2) model proposed by (Huguenard & McCormick, 1992) [2]_. - - This model is developed according to the average behavior of - rapidly inactivating Potassium channel in Thalamus relay neurons [2]_ [1]_. - - .. math:: - - &IA = g_{\mathrm{max}} p^4 q (E-V) \\ - &\frac{dp}{dt} = \phi_p \frac{p_{\infty} - p}{\tau_p} \\ - &p_{\infty} = \frac{1}{1+ \exp[-(V -V_{sh}+ 36)/20.]} \\ - &\tau_{p}=\frac{1}{\exp \left(\frac{V -V_{sh}+35.8}{19.7}\right)+ \exp \left(\frac{V -V_{sh}+79.7}{-12.7}\right)}+0.37 \\ - &\frac{dq}{dt} = \phi_q \frac{q_{\infty} - q}{\tau_q} \\ - &q_{\infty} = \frac{1}{1+ \exp[(V -V_{sh} + 78)/6]} \\ - &\begin{array}{l} \tau_{q} = \frac{1}{\exp((V -V_{sh}+46)/5.) + \exp((V -V_{sh}+238)/-37.5)} \quad V<(-63+V_{sh})\, mV \\ - \tau_{q} = 19 \quad V \geq (-63 + V_{sh})\, mV \end{array} - - where :math:`\phi_p` and :math:`\phi_q` are the temperature dependent factors (default 1.). - - Parameters - ---------- - size: int, sequence of int - The geometry size. - name: Optional[str] - The object name. - g_max : bst.typing.ArrayLike, Callable - The maximal conductance density (:math:`mS/cm^2`). - V_sh : float, ArrayType, Callable, Initializer - The membrane potential shift. - phi_p : optional, float, ArrayType, Callable, Initializer - The temperature factor for channel :math:`p`. - phi_q : optional, float, ArrayType, Callable, Initializer - The temperature factor for channel :math:`q`. - - References - ---------- - .. [2] Huguenard, John R., and David A. McCormick. "Simulation of the - currents involved in rhythmic oscillations in thalamic relay - neurons." Journal of neurophysiology 68.4 (1992): 1373-1383. - .. [1] Huguenard, J. R., and D. A. Prince. "Slow inactivation of a - TEA-sensitive K current in acutely isolated rat thalamic relay - neurons." Journal of neurophysiology 66.4 (1991): 1316-1328. - - See Also - -------- - IKA1_HM1992 - """ - __module__ = 'dendritex.channels' - - def __init__( - self, - size: Union[int, Sequence[int]], - g_max: Union[bst.typing.ArrayLike, Callable] = 20. * (bu.mS / bu.cm ** 2), - V_sh: Union[bst.typing.ArrayLike, Callable] = 0. * bu.mV, - phi_p: Union[bst.typing.ArrayLike, Callable] = 1., - phi_q: Union[bst.typing.ArrayLike, Callable] = 1., - name: Optional[str] = None, - mode: Optional[bst.mixin.Mode] = None, - ): - super().__init__( - size, - name=name, - g_max=g_max, - phi_q=phi_q, - phi_p=phi_p, - mode=mode - ) - - # parameters - self.V_sh = bst.init.param(V_sh, self.varshape, allow_none=False) - - def f_p_inf(self, V): - V = (V - self.V_sh).to_decimal(bu.mV) - return 1. / (1. + bu.math.exp(-(V + 36.) / 20.)) - - def f_p_tau(self, V): - V = (V - self.V_sh).to_decimal(bu.mV) - return 1. / (bu.math.exp((V + 35.8) / 19.7) + - bu.math.exp(-(V + 79.7) / 12.7)) + 0.37 - - def f_q_inf(self, V): - V = (V - self.V_sh).to_decimal(bu.mV) - return 1. / (1. + bu.math.exp((V + 78.) / 6.)) - - def f_q_tau(self, V): - V = (V - self.V_sh).to_decimal(bu.mV) - return bu.math.where( - V < -63, - 1. / (bu.math.exp((V + 46.) / 5.) + - bu.math.exp(-(V + 238.) / 37.5)), - 19. - ) + r"""The rapidly inactivating Potassium channel (IA2) model proposed by (Huguenard & McCormick, 1992) [2]_. + + This model is developed according to the average behavior of + rapidly inactivating Potassium channel in Thalamus relay neurons [2]_ [1]_. + + .. math:: + + &IA = g_{\mathrm{max}} p^4 q (E-V) \\ + &\frac{dp}{dt} = \phi_p \frac{p_{\infty} - p}{\tau_p} \\ + &p_{\infty} = \frac{1}{1+ \exp[-(V -V_{sh}+ 36)/20.]} \\ + &\tau_{p}=\frac{1}{\exp \left(\frac{V -V_{sh}+35.8}{19.7}\right)+ \exp \left(\frac{V -V_{sh}+79.7}{-12.7}\right)}+0.37 \\ + &\frac{dq}{dt} = \phi_q \frac{q_{\infty} - q}{\tau_q} \\ + &q_{\infty} = \frac{1}{1+ \exp[(V -V_{sh} + 78)/6]} \\ + &\begin{array}{l} \tau_{q} = \frac{1}{\exp((V -V_{sh}+46)/5.) + \exp((V -V_{sh}+238)/-37.5)} \quad V<(-63+V_{sh})\, mV \\ + \tau_{q} = 19 \quad V \geq (-63 + V_{sh})\, mV \end{array} + + where :math:`\phi_p` and :math:`\phi_q` are the temperature dependent factors (default 1.). + + Parameters + ---------- + size: int, sequence of int + The geometry size. + name: Optional[str] + The object name. + g_max : bst.typing.ArrayLike, Callable + The maximal conductance density (:math:`mS/cm^2`). + V_sh : float, ArrayType, Callable, Initializer + The membrane potential shift. + phi_p : optional, float, ArrayType, Callable, Initializer + The temperature factor for channel :math:`p`. + phi_q : optional, float, ArrayType, Callable, Initializer + The temperature factor for channel :math:`q`. + + References + ---------- + .. [2] Huguenard, John R., and David A. McCormick. "Simulation of the + currents involved in rhythmic oscillations in thalamic relay + neurons." Journal of neurophysiology 68.4 (1992): 1373-1383. + .. [1] Huguenard, J. R., and D. A. Prince. "Slow inactivation of a + TEA-sensitive K current in acutely isolated rat thalamic relay + neurons." Journal of neurophysiology 66.4 (1991): 1316-1328. + + See Also + -------- + IKA1_HM1992 + """ + __module__ = 'dendritex.channels' + + def __init__( + self, + size: Union[int, Sequence[int]], + g_max: Union[bst.typing.ArrayLike, Callable] = 20. * (bu.mS / bu.cm ** 2), + V_sh: Union[bst.typing.ArrayLike, Callable] = 0. * bu.mV, + phi_p: Union[bst.typing.ArrayLike, Callable] = 1., + phi_q: Union[bst.typing.ArrayLike, Callable] = 1., + name: Optional[str] = None, + ): + super().__init__( + size, + name=name, + g_max=g_max, + phi_q=phi_q, + phi_p=phi_p, + ) + + # parameters + self.V_sh = bst.init.param(V_sh, self.varshape, allow_none=False) + + def f_p_inf(self, V): + V = (V - self.V_sh).to_decimal(bu.mV) + return 1. / (1. + bu.math.exp(-(V + 36.) / 20.)) + + def f_p_tau(self, V): + V = (V - self.V_sh).to_decimal(bu.mV) + return 1. / (bu.math.exp((V + 35.8) / 19.7) + + bu.math.exp(-(V + 79.7) / 12.7)) + 0.37 + + def f_q_inf(self, V): + V = (V - self.V_sh).to_decimal(bu.mV) + return 1. / (1. + bu.math.exp((V + 78.) / 6.)) + + def f_q_tau(self, V): + V = (V - self.V_sh).to_decimal(bu.mV) + return bu.math.where( + V < -63, + 1. / (bu.math.exp((V + 46.) / 5.) + + bu.math.exp(-(V + 238.) / 37.5)), + 19. + ) class _IKK2_pq_ss(PotassiumChannel): - r"""The slowly inactivating Potassium channel of :math:`pq` - current which described with steady-state format. - - The dynamics of the model is given as [2]_ [3]_. - - .. math:: - - &IK2 = g_{\mathrm{max}} p q (E-V) \\ - &\frac{dp}{dt} = \phi_p \frac{p_{\infty} - p}{\tau_p} \\ - &\frac{dq}{dt} = \phi_q \frac{q_{\infty} - q}{\tau_q} \\ - - where :math:`\phi_p` and :math:`\phi_q` are the temperature dependent factors (default 1.). - - Parameters - ---------- - size: int, sequence of int - The geometry size. - name: Optional[str] - The object name. - g_max : bst.typing.ArrayLike, Callable - The maximal conductance density (:math:`mS/cm^2`). - phi_p : optional, float, ArrayType, Callable, Initializer - The temperature factor for channel :math:`p`. - phi_q : optional, float, ArrayType, Callable, Initializer - The temperature factor for channel :math:`q`. - - References - ---------- - .. [2] Huguenard, John R., and David A. McCormick. "Simulation of the - currents involved in rhythmic oscillations in thalamic relay - neurons." Journal of neurophysiology 68.4 (1992): 1373-1383. - .. [3] Huguenard, J. R., and D. A. Prince. "Slow inactivation of a - TEA-sensitive K current in acutely isolated rat thalamic relay - neurons." Journal of neurophysiology 66.4 (1991): 1316-1328. - - """ - - def __init__( - self, - size: Union[int, Sequence[int]], - g_max: Union[bst.typing.ArrayLike, Callable] = 10. * (bu.mS / bu.cm ** 2), - phi_p: Union[bst.typing.ArrayLike, Callable] = 1., - phi_q: Union[bst.typing.ArrayLike, Callable] = 1., - name: Optional[str] = None, - mode: Optional[bst.mixin.Mode] = None, - ): - super().__init__( - size, - name=name, - mode=mode - ) - - # parameters - self.g_max = bst.init.param(g_max, self.varshape, allow_none=False) - self.phi_p = bst.init.param(phi_p, self.varshape, allow_none=False) - self.phi_q = bst.init.param(phi_q, self.varshape, allow_none=False) - - def compute_derivative(self, V, K: IonInfo): - self.p.derivative = self.phi_p * (self.f_p_inf(V) - self.p.value) / self.f_p_tau(V) - self.q.derivative = self.phi_q * (self.f_q_inf(V) - self.q.value) / self.f_q_tau(V) - - def current(self, V, K: IonInfo): - return self.g_max * self.p.value * self.q.value * (K.E - V) - - def init_state(self, V, Ca: IonInfo, batch_size: int = None): - self.p = State4Integral(bst.init.param(bu.math.zeros, self.varshape, batch_size)) - self.q = State4Integral(bst.init.param(bu.math.zeros, self.varshape, batch_size)) - - def reset_state(self, V, K: IonInfo, batch_size=None): - self.p.value = self.f_p_inf(V) - self.q.value = self.f_q_inf(V) - if isinstance(batch_size, int): - assert self.p.value.shape[0] == batch_size - assert self.q.value.shape[0] == batch_size - - def f_p_inf(self, V): - raise NotImplementedError - - def f_p_tau(self, V): - raise NotImplementedError - - def f_q_inf(self, V): - raise NotImplementedError - - def f_q_tau(self, V): - raise NotImplementedError + r"""The slowly inactivating Potassium channel of :math:`pq` + current which described with steady-state format. + + The dynamics of the model is given as [2]_ [3]_. + + .. math:: + + &IK2 = g_{\mathrm{max}} p q (E-V) \\ + &\frac{dp}{dt} = \phi_p \frac{p_{\infty} - p}{\tau_p} \\ + &\frac{dq}{dt} = \phi_q \frac{q_{\infty} - q}{\tau_q} \\ + + where :math:`\phi_p` and :math:`\phi_q` are the temperature dependent factors (default 1.). + + Parameters + ---------- + size: int, sequence of int + The geometry size. + name: Optional[str] + The object name. + g_max : bst.typing.ArrayLike, Callable + The maximal conductance density (:math:`mS/cm^2`). + phi_p : optional, float, ArrayType, Callable, Initializer + The temperature factor for channel :math:`p`. + phi_q : optional, float, ArrayType, Callable, Initializer + The temperature factor for channel :math:`q`. + + References + ---------- + .. [2] Huguenard, John R., and David A. McCormick. "Simulation of the + currents involved in rhythmic oscillations in thalamic relay + neurons." Journal of neurophysiology 68.4 (1992): 1373-1383. + .. [3] Huguenard, J. R., and D. A. Prince. "Slow inactivation of a + TEA-sensitive K current in acutely isolated rat thalamic relay + neurons." Journal of neurophysiology 66.4 (1991): 1316-1328. + + """ + + def __init__( + self, + size: Union[int, Sequence[int]], + g_max: Union[bst.typing.ArrayLike, Callable] = 10. * (bu.mS / bu.cm ** 2), + phi_p: Union[bst.typing.ArrayLike, Callable] = 1., + phi_q: Union[bst.typing.ArrayLike, Callable] = 1., + name: Optional[str] = None, + ): + super().__init__(size=size, name=name, ) + # parameters + self.g_max = bst.init.param(g_max, self.varshape, allow_none=False) + self.phi_p = bst.init.param(phi_p, self.varshape, allow_none=False) + self.phi_q = bst.init.param(phi_q, self.varshape, allow_none=False) + + def compute_derivative(self, V, K: IonInfo): + self.p.derivative = self.phi_p * (self.f_p_inf(V) - self.p.value) / self.f_p_tau(V) + self.q.derivative = self.phi_q * (self.f_q_inf(V) - self.q.value) / self.f_q_tau(V) + + def current(self, V, K: IonInfo): + return self.g_max * self.p.value * self.q.value * (K.E - V) + + def init_state(self, V, Ca: IonInfo, batch_size: int = None): + self.p = State4Integral(bst.init.param(bu.math.zeros, self.varshape, batch_size)) + self.q = State4Integral(bst.init.param(bu.math.zeros, self.varshape, batch_size)) + + def reset_state(self, V, K: IonInfo, batch_size=None): + self.p.value = self.f_p_inf(V) + self.q.value = self.f_q_inf(V) + if isinstance(batch_size, int): + assert self.p.value.shape[0] == batch_size + assert self.q.value.shape[0] == batch_size + + def f_p_inf(self, V): + raise NotImplementedError + + def f_p_tau(self, V): + raise NotImplementedError + + def f_q_inf(self, V): + raise NotImplementedError + + def f_q_tau(self, V): + raise NotImplementedError class IKK2A_HM1992(_IKK2_pq_ss): - r"""The slowly inactivating Potassium channel (IK2a) model proposed by (Huguenard & McCormick, 1992) [2]_. - - The dynamics of the model is given as [2]_ [3]_. - - .. math:: - - &IK2 = g_{\mathrm{max}} p q (E-V) \\ - &\frac{dp}{dt} = \phi_p \frac{p_{\infty} - p}{\tau_p} \\ - &p_{\infty} = \frac{1}{1+ \exp[-(V -V_{sh}+ 43)/17]} \\ - &\tau_{p}=\frac{1}{\exp \left(\frac{V -V_{sh}-81.}{25.6}\right)+ - \exp \left(\frac{V -V_{sh}+132}{-18}\right)}+9.9 \\ - &\frac{dq}{dt} = \phi_q \frac{q_{\infty} - q}{\tau_q} \\ - &q_{\infty} = \frac{1}{1+ \exp[(V -V_{sh} + 59)/10.6]} \\ - & \tau_{q} = \frac{1}{\exp((V -V_{sh}+1329)/200.) + \exp((V -V_{sh}+130)/-7.1)} + 120 \\ - - where :math:`\phi_p` and :math:`\phi_q` are the temperature dependent factors (default 1.). - - Parameters - ---------- - size: int, sequence of int - The geometry size. - name: Optional[str] - The object name. - g_max : bst.typing.ArrayLike, Callable - The maximal conductance density (:math:`mS/cm^2`). - V_sh : float, ArrayType, Callable, Initializer - The membrane potential shift. - phi_p : optional, float, ArrayType, Callable, Initializer - The temperature factor for channel :math:`p`. - phi_q : optional, float, ArrayType, Callable, Initializer - The temperature factor for channel :math:`q`. - - References - ---------- - .. [2] Huguenard, John R., and David A. McCormick. "Simulation of the - currents involved in rhythmic oscillations in thalamic relay - neurons." Journal of neurophysiology 68.4 (1992): 1373-1383. - .. [3] Huguenard, J. R., and D. A. Prince. "Slow inactivation of a - TEA-sensitive K current in acutely isolated rat thalamic relay - neurons." Journal of neurophysiology 66.4 (1991): 1316-1328. - - """ - __module__ = 'dendritex.channels' - - def __init__( - self, - size: Union[int, Sequence[int]], - g_max: Union[bst.typing.ArrayLike, Callable] = 10. * (bu.mS * bu.cm ** -2), - V_sh: Union[bst.typing.ArrayLike, Callable] = 0. * bu.mV, - phi_p: Union[bst.typing.ArrayLike, Callable] = 1., - phi_q: Union[bst.typing.ArrayLike, Callable] = 1., - name: Optional[str] = None, - mode: Optional[bst.mixin.Mode] = None, - ): - super().__init__( - size, - name=name, - phi_p=phi_p, - phi_q=phi_q, - g_max=g_max, - mode=mode - ) - - # parameters - self.V_sh = bst.init.param(V_sh, self.varshape, allow_none=False) - - def f_p_inf(self, V): - V = (V - self.V_sh).to_decimal(bu.mV) - return 1. / (1. + bu.math.exp(-(V + 43.) / 17.)) - - def f_p_tau(self, V): - V = (V - self.V_sh).to_decimal(bu.mV) - return 1. / (bu.math.exp((V - 81.) / 25.6) + - bu.math.exp(-(V + 132) / 18.)) + 9.9 - - def f_q_inf(self, V): - V = (V - self.V_sh).to_decimal(bu.mV) - return 1. / (1. + bu.math.exp((V + 58.) / 10.6)) - - def f_q_tau(self, V): - V = (V - self.V_sh).to_decimal(bu.mV) - return 1. / (bu.math.exp((V - 1329.) / 200.) + - bu.math.exp(-(V + 130.) / 7.1)) + r"""The slowly inactivating Potassium channel (IK2a) model proposed by (Huguenard & McCormick, 1992) [2]_. + + The dynamics of the model is given as [2]_ [3]_. + + .. math:: + + &IK2 = g_{\mathrm{max}} p q (E-V) \\ + &\frac{dp}{dt} = \phi_p \frac{p_{\infty} - p}{\tau_p} \\ + &p_{\infty} = \frac{1}{1+ \exp[-(V -V_{sh}+ 43)/17]} \\ + &\tau_{p}=\frac{1}{\exp \left(\frac{V -V_{sh}-81.}{25.6}\right)+ + \exp \left(\frac{V -V_{sh}+132}{-18}\right)}+9.9 \\ + &\frac{dq}{dt} = \phi_q \frac{q_{\infty} - q}{\tau_q} \\ + &q_{\infty} = \frac{1}{1+ \exp[(V -V_{sh} + 59)/10.6]} \\ + & \tau_{q} = \frac{1}{\exp((V -V_{sh}+1329)/200.) + \exp((V -V_{sh}+130)/-7.1)} + 120 \\ + + where :math:`\phi_p` and :math:`\phi_q` are the temperature dependent factors (default 1.). + + Parameters + ---------- + size: int, sequence of int + The geometry size. + name: Optional[str] + The object name. + g_max : bst.typing.ArrayLike, Callable + The maximal conductance density (:math:`mS/cm^2`). + V_sh : float, ArrayType, Callable, Initializer + The membrane potential shift. + phi_p : optional, float, ArrayType, Callable, Initializer + The temperature factor for channel :math:`p`. + phi_q : optional, float, ArrayType, Callable, Initializer + The temperature factor for channel :math:`q`. + + References + ---------- + .. [2] Huguenard, John R., and David A. McCormick. "Simulation of the + currents involved in rhythmic oscillations in thalamic relay + neurons." Journal of neurophysiology 68.4 (1992): 1373-1383. + .. [3] Huguenard, J. R., and D. A. Prince. "Slow inactivation of a + TEA-sensitive K current in acutely isolated rat thalamic relay + neurons." Journal of neurophysiology 66.4 (1991): 1316-1328. + + """ + __module__ = 'dendritex.channels' + + def __init__( + self, + size: Union[int, Sequence[int]], + g_max: Union[bst.typing.ArrayLike, Callable] = 10. * (bu.mS * bu.cm ** -2), + V_sh: Union[bst.typing.ArrayLike, Callable] = 0. * bu.mV, + phi_p: Union[bst.typing.ArrayLike, Callable] = 1., + phi_q: Union[bst.typing.ArrayLike, Callable] = 1., + name: Optional[str] = None, + ): + super().__init__( + size, + name=name, + phi_p=phi_p, + phi_q=phi_q, + g_max=g_max, + ) + + # parameters + self.V_sh = bst.init.param(V_sh, self.varshape, allow_none=False) + + def f_p_inf(self, V): + V = (V - self.V_sh).to_decimal(bu.mV) + return 1. / (1. + bu.math.exp(-(V + 43.) / 17.)) + + def f_p_tau(self, V): + V = (V - self.V_sh).to_decimal(bu.mV) + return 1. / (bu.math.exp((V - 81.) / 25.6) + + bu.math.exp(-(V + 132) / 18.)) + 9.9 + + def f_q_inf(self, V): + V = (V - self.V_sh).to_decimal(bu.mV) + return 1. / (1. + bu.math.exp((V + 58.) / 10.6)) + + def f_q_tau(self, V): + V = (V - self.V_sh).to_decimal(bu.mV) + return 1. / (bu.math.exp((V - 1329.) / 200.) + + bu.math.exp(-(V + 130.) / 7.1)) class IKK2B_HM1992(_IKK2_pq_ss): - r"""The slowly inactivating Potassium channel (IK2b) model proposed by (Huguenard & McCormick, 1992) [2]_. - - The dynamics of the model is given as [2]_ [3]_. - - .. math:: - - &IK2 = g_{\mathrm{max}} p q (E-V) \\ - &\frac{dp}{dt} = \phi_p \frac{p_{\infty} - p}{\tau_p} \\ - &p_{\infty} = \frac{1}{1+ \exp[-(V -V_{sh}+ 43)/17]} \\ - &\tau_{p}=\frac{1}{\exp \left(\frac{V -V_{sh}-81.}{25.6}\right)+ - \exp \left(\frac{V -V_{sh}+132}{-18}\right)}+9.9 \\ - &\frac{dq}{dt} = \phi_q \frac{q_{\infty} - q}{\tau_q} \\ - &q_{\infty} = \frac{1}{1+ \exp[(V -V_{sh} + 59)/10.6]} \\ - &\begin{array}{l} \tau_{q} = \frac{1}{\exp((V -V_{sh}+1329)/200.) + - \exp((V -V_{sh}+130)/-7.1)} + 120 \quad V<(-70+V_{sh})\, mV \\ - \tau_{q} = 8.9 \quad V \geq (-70 + V_{sh})\, mV \end{array} - - where :math:`\phi_p` and :math:`\phi_q` are the temperature dependent factors (default 1.). - - Parameters - ---------- - size: int, sequence of int - The geometry size. - name: Optional[str] - The object name. - g_max : bst.typing.ArrayLike, Callable - The maximal conductance density (:math:`mS/cm^2`). - V_sh : float, ArrayType, Callable, Initializer - The membrane potential shift. - phi_p : optional, float, ArrayType, Callable, Initializer - The temperature factor for channel :math:`p`. - phi_q : optional, float, ArrayType, Callable, Initializer - The temperature factor for channel :math:`q`. - - References - ---------- - .. [2] Huguenard, John R., and David A. McCormick. "Simulation of the - currents involved in rhythmic oscillations in thalamic relay - neurons." Journal of neurophysiology 68.4 (1992): 1373-1383. - .. [3] Huguenard, J. R., and D. A. Prince. "Slow inactivation of a - TEA-sensitive K current in acutely isolated rat thalamic relay - neurons." Journal of neurophysiology 66.4 (1991): 1316-1328. - - """ - __module__ = 'dendritex.channels' - - def __init__( - self, - size: Union[int, Sequence[int]], - g_max: Union[bst.typing.ArrayLike, Callable] = 10. * (bu.mS * bu.cm ** -2), - V_sh: Union[bst.typing.ArrayLike, Callable] = 0. * bu.mV, - phi_p: Union[bst.typing.ArrayLike, Callable] = 1., - phi_q: Union[bst.typing.ArrayLike, Callable] = 1., - name: Optional[str] = None, - mode: Optional[bst.mixin.Mode] = None, - ): - super().__init__( - size, - name=name, - phi_p=phi_p, - phi_q=phi_q, - g_max=g_max, - mode=mode - ) - - # parameters - self.V_sh = bst.init.param(V_sh, self.varshape, allow_none=False) - - def f_p_inf(self, V): - V = (V - self.V_sh).to_decimal(bu.mV) - return 1. / (1. + bu.math.exp(-(V + 43.) / 17.)) - - def f_p_tau(self, V): - V = (V - self.V_sh).to_decimal(bu.mV) - return 1. / (bu.math.exp((V - 81.) / 25.6) + - bu.math.exp(-(V + 132) / 18.)) + 9.9 - - def f_q_inf(self, V): - V = (V - self.V_sh).to_decimal(bu.mV) - return 1. / (1. + bu.math.exp((V + 58.) / 10.6)) - - def f_q_tau(self, V): - V = (V - self.V_sh).to_decimal(bu.mV) - return bu.math.where( - V < -70, - 1. / (bu.math.exp((V - 1329.) / 200.) + - bu.math.exp(-(V + 130.) / 7.1)), - 8.9 - ) + r"""The slowly inactivating Potassium channel (IK2b) model proposed by (Huguenard & McCormick, 1992) [2]_. + + The dynamics of the model is given as [2]_ [3]_. + + .. math:: + + &IK2 = g_{\mathrm{max}} p q (E-V) \\ + &\frac{dp}{dt} = \phi_p \frac{p_{\infty} - p}{\tau_p} \\ + &p_{\infty} = \frac{1}{1+ \exp[-(V -V_{sh}+ 43)/17]} \\ + &\tau_{p}=\frac{1}{\exp \left(\frac{V -V_{sh}-81.}{25.6}\right)+ + \exp \left(\frac{V -V_{sh}+132}{-18}\right)}+9.9 \\ + &\frac{dq}{dt} = \phi_q \frac{q_{\infty} - q}{\tau_q} \\ + &q_{\infty} = \frac{1}{1+ \exp[(V -V_{sh} + 59)/10.6]} \\ + &\begin{array}{l} \tau_{q} = \frac{1}{\exp((V -V_{sh}+1329)/200.) + + \exp((V -V_{sh}+130)/-7.1)} + 120 \quad V<(-70+V_{sh})\, mV \\ + \tau_{q} = 8.9 \quad V \geq (-70 + V_{sh})\, mV \end{array} + + where :math:`\phi_p` and :math:`\phi_q` are the temperature dependent factors (default 1.). + + Parameters + ---------- + size: int, sequence of int + The geometry size. + name: Optional[str] + The object name. + g_max : bst.typing.ArrayLike, Callable + The maximal conductance density (:math:`mS/cm^2`). + V_sh : float, ArrayType, Callable, Initializer + The membrane potential shift. + phi_p : optional, float, ArrayType, Callable, Initializer + The temperature factor for channel :math:`p`. + phi_q : optional, float, ArrayType, Callable, Initializer + The temperature factor for channel :math:`q`. + + References + ---------- + .. [2] Huguenard, John R., and David A. McCormick. "Simulation of the + currents involved in rhythmic oscillations in thalamic relay + neurons." Journal of neurophysiology 68.4 (1992): 1373-1383. + .. [3] Huguenard, J. R., and D. A. Prince. "Slow inactivation of a + TEA-sensitive K current in acutely isolated rat thalamic relay + neurons." Journal of neurophysiology 66.4 (1991): 1316-1328. + + """ + __module__ = 'dendritex.channels' + + def __init__( + self, + size: Union[int, Sequence[int]], + g_max: Union[bst.typing.ArrayLike, Callable] = 10. * (bu.mS * bu.cm ** -2), + V_sh: Union[bst.typing.ArrayLike, Callable] = 0. * bu.mV, + phi_p: Union[bst.typing.ArrayLike, Callable] = 1., + phi_q: Union[bst.typing.ArrayLike, Callable] = 1., + name: Optional[str] = None, + ): + super().__init__( + size, + name=name, + phi_p=phi_p, + phi_q=phi_q, + g_max=g_max, + ) + + # parameters + self.V_sh = bst.init.param(V_sh, self.varshape, allow_none=False) + + def f_p_inf(self, V): + V = (V - self.V_sh).to_decimal(bu.mV) + return 1. / (1. + bu.math.exp(-(V + 43.) / 17.)) + + def f_p_tau(self, V): + V = (V - self.V_sh).to_decimal(bu.mV) + return 1. / (bu.math.exp((V - 81.) / 25.6) + + bu.math.exp(-(V + 132) / 18.)) + 9.9 + + def f_q_inf(self, V): + V = (V - self.V_sh).to_decimal(bu.mV) + return 1. / (1. + bu.math.exp((V + 58.) / 10.6)) + + def f_q_tau(self, V): + V = (V - self.V_sh).to_decimal(bu.mV) + return bu.math.where( + V < -70, + 1. / (bu.math.exp((V - 1329.) / 200.) + + bu.math.exp(-(V + 130.) / 7.1)), + 8.9 + ) class IKNI_Ya1989(PotassiumChannel): - r"""A slow non-inactivating K+ current described by Yamada et al. (1989) [1]_. - - This slow potassium current can effectively account for spike-frequency adaptation. - - .. math:: - - \begin{aligned} - &I_{M}=\bar{g}_{M} p\left(V-E_{K}\right) \\ - &\frac{\mathrm{d} p}{\mathrm{~d} t}=\left(p_{\infty}(V)-p\right) / \tau_{p}(V) \\ - &p_{\infty}(V)=\frac{1}{1+\exp [-(V-V_{sh}+35) / 10]} \\ - &\tau_{p}(V)=\frac{\tau_{\max }}{3.3 \exp [(V-V_{sh}+35) / 20]+\exp [-(V-V_{sh}+35) / 20]} - \end{aligned} - - where :math:`\bar{g}_{M}` was :math:`0.004 \mathrm{mS} / \mathrm{cm}^{2}` and - :math:`\tau_{\max }=4 \mathrm{~s}`, unless stated otherwise. - - Parameters - ---------- - size: int, sequence of int - The geometry size. - name: Optional[str] - The object name. - g_max : bst.typing.ArrayLike, Callable - The maximal conductance density (:math:`mS/cm^2`). - V_sh : float, ArrayType, Callable, Initializer - The membrane potential shift. - phi_p : optional, float, ArrayType, Callable, Initializer - The temperature factor for channel :math:`p`. - tau_max: float, ArrayType, Callable, Initializer - The :math:`tau_{\max}` parameter. - - References - ---------- - .. [1] Yamada, Walter M. "Multiple channels and calcium dynamics." Methods in neuronal modeling (1989): 97-133. - - """ - __module__ = 'dendritex.channels' - - def __init__( - self, - size: Union[int, Sequence[int]], - g_max: Union[bst.typing.ArrayLike, Callable] = 0.004 * (bu.mS * bu.cm ** -2), - phi_p: Union[bst.typing.ArrayLike, Callable] = 1., - phi_q: Union[bst.typing.ArrayLike, Callable] = 1., - tau_max: Union[bst.typing.ArrayLike, Callable] = 4e3 * bu.ms, - V_sh: Union[bst.typing.ArrayLike, Callable] = 0. * bu.mV, - name: Optional[str] = None, - mode: Optional[bst.mixin.Mode] = None, - ): - super().__init__(size, name=name, mode=mode) - - # parameters - self.g_max = bst.init.param(g_max, self.varshape, allow_none=False) - self.tau_max = bst.init.param(tau_max, self.varshape, allow_none=False) - self.V_sh = bst.init.param(V_sh, self.varshape, allow_none=False) - self.phi_p = bst.init.param(phi_p, self.varshape, allow_none=False) - self.phi_q = bst.init.param(phi_q, self.varshape, allow_none=False) - - def compute_derivative(self, V, K: IonInfo): - self.p.derivative = self.phi_p * (self.f_p_inf(V) - self.p.value) / self.f_p_tau(V) - - def current(self, V, K: IonInfo): - return self.g_max * self.p.value * (K.E - V) - - def init_state(self, V, Ca: IonInfo, batch_size: int = None): - self.p = State4Integral(bst.init.param(bu.math.zeros, self.varshape, batch_size)) - - def reset_state(self, V, K: IonInfo, batch_size=None): - self.p.value = self.f_p_inf(V) - if isinstance(batch_size, int): - assert self.p.value.shape[0] == batch_size - - def f_p_inf(self, V): - V = (V - self.V_sh).to_decimal(bu.mV) - return 1. / (1. + bu.math.exp(-(V + 35.) / 10.)) - - def f_p_tau(self, V): - V = (V - self.V_sh).to_decimal(bu.mV) - temp = V + 35. - return self.tau_max / (3.3 * bu.math.exp(temp / 20.) + bu.math.exp(-temp / 20.)) + r"""A slow non-inactivating K+ current described by Yamada et al. (1989) [1]_. + This slow potassium current can effectively account for spike-frequency adaptation. -class IK_Leak(PotassiumChannel): - """The potassium leak channel current. - - Parameters - ---------- - g_max : float - The potassium leakage conductance which is modulated by both - acetylcholine and norepinephrine. - """ - __module__ = 'dendritex.channels' - - root_type = Potassium - - def __init__( - self, - size: Union[int, Sequence[int]], - g_max: Union[int, bst.typing.ArrayLike, Callable] = 0.005 * (bu.mS * bu.cm ** -2), - name: Optional[str] = None, - mode: Optional[bst.mixin.Mode] = None, - ): - super().__init__( - size=size, - name=name, - mode=mode - ) - self.g_max = bst.init.param(g_max, self.varshape) - - def reset_state(self, V, K: IonInfo, batch_size: int = None): - pass - - def compute_derivative(self, V, K: IonInfo): - pass - - def current(self, V, K: IonInfo): - return self.g_max * (K.E - V) + .. math:: + \begin{aligned} + &I_{M}=\bar{g}_{M} p\left(V-E_{K}\right) \\ + &\frac{\mathrm{d} p}{\mathrm{~d} t}=\left(p_{\infty}(V)-p\right) / \tau_{p}(V) \\ + &p_{\infty}(V)=\frac{1}{1+\exp [-(V-V_{sh}+35) / 10]} \\ + &\tau_{p}(V)=\frac{\tau_{\max }}{3.3 \exp [(V-V_{sh}+35) / 20]+\exp [-(V-V_{sh}+35) / 20]} + \end{aligned} -class IKv11_Ak2007(PotassiumChannel): - r""" - TITLE Voltage-gated low threshold potassium current from Kv1 subunits + where :math:`\bar{g}_{M}` was :math:`0.004 \mathrm{mS} / \mathrm{cm}^{2}` and + :math:`\tau_{\max }=4 \mathrm{~s}`, unless stated otherwise. + + Parameters + ---------- + size: int, sequence of int + The geometry size. + name: Optional[str] + The object name. + g_max : bst.typing.ArrayLike, Callable + The maximal conductance density (:math:`mS/cm^2`). + V_sh : float, ArrayType, Callable, Initializer + The membrane potential shift. + phi_p : optional, float, ArrayType, Callable, Initializer + The temperature factor for channel :math:`p`. + tau_max: float, ArrayType, Callable, Initializer + The :math:`tau_{\max}` parameter. + + References + ---------- + .. [1] Yamada, Walter M. "Multiple channels and calcium dynamics." Methods in neuronal modeling (1989): 97-133. + + """ + __module__ = 'dendritex.channels' + + def __init__( + self, + size: Union[int, Sequence[int]], + g_max: Union[bst.typing.ArrayLike, Callable] = 0.004 * (bu.mS * bu.cm ** -2), + phi_p: Union[bst.typing.ArrayLike, Callable] = 1., + phi_q: Union[bst.typing.ArrayLike, Callable] = 1., + tau_max: Union[bst.typing.ArrayLike, Callable] = 4e3 * bu.ms, + V_sh: Union[bst.typing.ArrayLike, Callable] = 0. * bu.mV, + name: Optional[str] = None, + ): + super().__init__(size=size, name=name, ) + + # parameters + self.g_max = bst.init.param(g_max, self.varshape, allow_none=False) + self.tau_max = bst.init.param(tau_max, self.varshape, allow_none=False) + self.V_sh = bst.init.param(V_sh, self.varshape, allow_none=False) + self.phi_p = bst.init.param(phi_p, self.varshape, allow_none=False) + self.phi_q = bst.init.param(phi_q, self.varshape, allow_none=False) + + def compute_derivative(self, V, K: IonInfo): + self.p.derivative = self.phi_p * (self.f_p_inf(V) - self.p.value) / self.f_p_tau(V) + + def current(self, V, K: IonInfo): + return self.g_max * self.p.value * (K.E - V) + + def init_state(self, V, Ca: IonInfo, batch_size: int = None): + self.p = State4Integral(bst.init.param(bu.math.zeros, self.varshape, batch_size)) + + def reset_state(self, V, K: IonInfo, batch_size=None): + self.p.value = self.f_p_inf(V) + if isinstance(batch_size, int): + assert self.p.value.shape[0] == batch_size + + def f_p_inf(self, V): + V = (V - self.V_sh).to_decimal(bu.mV) + return 1. / (1. + bu.math.exp(-(V + 35.) / 10.)) + + def f_p_tau(self, V): + V = (V - self.V_sh).to_decimal(bu.mV) + temp = V + 35. + return self.tau_max / (3.3 * bu.math.exp(temp / 20.) + bu.math.exp(-temp / 20.)) + + +class IK_Leak(PotassiumChannel): + """The potassium leak channel current. + + Parameters + ---------- + g_max : float + The potassium leakage conductance which is modulated by both + acetylcholine and norepinephrine. + """ + __module__ = 'dendritex.channels' - COMMENT - - NEURON implementation of a potassium channel from Kv1.1 subunits - Kinetical scheme: Hodgkin-Huxley m^4, no inactivation + root_type = Potassium - Experimental data taken from: - Human Kv1.1 expressed in xenopus oocytes: Zerr et al., J Neurosci 18, 2842, 2848, 1998 - Vhalf = -28.8 +- 2.3 mV; k = 8.1+- 0.9 mV + def __init__( + self, + size: Union[int, Sequence[int]], + g_max: Union[int, bst.typing.ArrayLike, Callable] = 0.005 * (bu.mS * bu.cm ** -2), + name: Optional[str] = None, + ): + super().__init__(size=size, name=name, ) + self.g_max = bst.init.param(g_max, self.varshape) - The voltage dependency of the rate constants was approximated by: + def reset_state(self, V, K: IonInfo, batch_size: int = None): + pass - alpha = ca * exp(-(v+cva)/cka) - beta = cb * exp(-(v+cvb)/ckb) + def compute_derivative(self, V, K: IonInfo): + pass - Parameters ca, cva, cka, cb, cvb, ckb - were determined from least square-fits to experimental data of G/Gmax(v) and tau(v). - Values are defined in the CONSTANT block. - Model includes calculation of Kv gating current + def current(self, V, K: IonInfo): + return self.g_max * (K.E - V) - Reference: Akemann et al., Biophys. J. (2009) 96: 3959-3976 - Laboratory for Neuronal Circuit Dynamics - RIKEN Brain Science Institute, Wako City, Japan - http://www.neurodynamics.brain.riken.jp +class IKv11_Ak2007(PotassiumChannel): + r""" + TITLE Voltage-gated low threshold potassium current from Kv1 subunits + + COMMENT + + NEURON implementation of a potassium channel from Kv1.1 subunits + Kinetical scheme: Hodgkin-Huxley m^4, no inactivation + + Experimental data taken from: + Human Kv1.1 expressed in xenopus oocytes: Zerr et al., J Neurosci 18, 2842, 2848, 1998 + Vhalf = -28.8 +- 2.3 mV; k = 8.1+- 0.9 mV + + The voltage dependency of the rate constants was approximated by: + + alpha = ca * exp(-(v+cva)/cka) + beta = cb * exp(-(v+cvb)/ckb) + + Parameters ca, cva, cka, cb, cvb, ckb + were determined from least square-fits to experimental data of G/Gmax(v) and tau(v). + Values are defined in the CONSTANT block. + Model includes calculation of Kv gating current + + Reference: Akemann et al., Biophys. J. (2009) 96: 3959-3976 + + Laboratory for Neuronal Circuit Dynamics + RIKEN Brain Science Institute, Wako City, Japan + http://www.neurodynamics.brain.riken.jp + + Date of Implementation: April 2007 + Contact: akemann@brain.riken.jp + """ + __module__ = 'dendritex.channels' + + def __init__( + self, + size: Union[int, Sequence[int]], + g_max: Union[bst.typing.ArrayLike, Callable] = 4. * (bu.mS / bu.cm ** 2), + gateCurrent: Union[bst.typing.ArrayLike, Callable] = 0., + gunit: Union[bst.typing.ArrayLike, Callable] = 16. * 1e-9 * bu.mS, + V_sh: Union[bst.typing.ArrayLike, Callable] = 0. * bu.mV, + T_base: bst.typing.ArrayLike = 2.7, + T: bst.typing.ArrayLike = 22, + name: Optional[str] = None, + ): + + super().__init__(size=size, name=name, ) + + # parameters + self.g_max = bst.init.param(g_max, self.varshape, allow_none=False) + self.gateCurrent = bst.init.param(gateCurrent, self.varshape, allow_none=False) + self.gunit = bst.init.param(gunit, self.varshape, allow_none=False) + self.T = bst.init.param(T, self.varshape, allow_none=False) + self.T_base = bst.init.param(T_base, self.varshape, allow_none=False) + self.phi = bst.init.param(T_base ** ((T - 22) / 10), self.varshape, allow_none=False) + self.V_sh = bst.init.param(V_sh, self.varshape, allow_none=False) + + self.e0 = 1.60217646e-19 * bu.coulomb + self.q10 = 2.7 + self.ca = 0.12889 + self.cva = 45 + self.cka = -33.90877 + self.cb = 0.12889 + self.cvb = 45 + self.ckb = 12.42101 + self.zn = 2.7978 + + def init_state(self, V, K: IonInfo, batch_size=None): + self.p = State4Integral(bst.init.param(bu.math.zeros, self.varshape, batch_size)) + + def reset_state(self, V, K: IonInfo, batch_size: int = None): + alpha = self.f_p_alpha(V) + beta = self.f_p_beta(V) + self.p.value = alpha / (alpha + beta) + if isinstance(batch_size, int): + assert self.p.value.shape[0] == batch_size + + def compute_derivative(self, V, K: IonInfo): + self.p.derivative = self.phi * ( + self.f_p_alpha(V) * (1. - self.p.value) - self.f_p_beta(V) * self.p.value) / bu.ms + + def current(self, V, K: IonInfo): + if self.gateCurrent == 0: + ik = self.g_max * self.p.value ** 4 * (K.E - V) + else: + ngateFlip = self.phi * (self.f_p_alpha(V) * (1. - self.p.value) - self.f_p_beta(V) * self.p.value) / bu.ms + igate = ( + 1e12) * self.g_max / self.gunit * 1e6 * self.e0 * 4 * self.zn * ngateFlip # NONSPECIFIC_CURRENT igate + + ik = -igate + self.g_max * self.p.value ** 4 * (K.E - V) + return ik + + def f_p_alpha(self, V): + V = (V - self.V_sh).to_decimal(bu.mV) + return self.ca * bu.math.exp(- (V + self.cva) / self.cka) + + def f_p_beta(self, V): + V = (V - self.V_sh).to_decimal(bu.mV) + return self.cb * bu.math.exp(-(V + self.cvb) / self.ckb) - Date of Implementation: April 2007 - Contact: akemann@brain.riken.jp - """ - __module__ = 'dendritex.channels' - def __init__( - self, - size: Union[int, Sequence[int]], - g_max: Union[bst.typing.ArrayLike, Callable] = 4. * (bu.mS / bu.cm ** 2), - gateCurrent: Union[bst.typing.ArrayLike, Callable] = 0., - gunit: Union[bst.typing.ArrayLike, Callable] = 16. * 1e-9 * bu.mS, - V_sh: Union[bst.typing.ArrayLike, Callable] = 0. * bu.mV, - T_base: bst.typing.ArrayLike = 2.7, - T: bst.typing.ArrayLike = 22, - name: Optional[str] = None, - mode: Optional[bst.mixin.Mode] = None, - ): +class IKv34_Ma2020(PotassiumChannel): + r""" + : HH TEA-sensitive Purkinje potassium current + : Created 8/5/02 - nwg + """ + __module__ = 'dendritex.channels' + + def __init__( + self, + size: Union[int, Sequence[int]], + g_max: Union[bst.typing.ArrayLike, Callable] = 4. * (bu.mS / bu.cm ** 2), + V_sh: Union[bst.typing.ArrayLike, Callable] = -11. * bu.mV, + T_base: bst.typing.ArrayLike = 3., + T: bst.typing.ArrayLike = 22, + name: Optional[str] = None, + ): + super().__init__(size=size, name=name, ) + # parameters + self.g_max = bst.init.param(g_max, self.varshape, allow_none=False) + + self.T = bst.init.param(T, self.varshape, allow_none=False) + self.T_base = bst.init.param(T_base, self.varshape, allow_none=False) + self.phi = bst.init.param(T_base ** ((T - 37) / 10), self.varshape, allow_none=False) + self.V_sh = bst.init.param(V_sh, self.varshape, allow_none=False) + + self.mivh = -24 + self.mik = 15.4 + self.mty0 = .00012851 + self.mtvh1 = 100.7 + self.mtk1 = 12.9 + self.mtvh2 = -56.0 + self.mtk2 = -23.1 + + self.hiy0 = .31 + self.hiA = .69 + self.hivh = -5.802 + self.hik = 11.2 + + def compute_derivative(self, V, K: IonInfo): + self.p.derivative = self.phi * (self.f_p_inf(V) - self.p.value) / self.f_p_tau(V) / bu.ms + self.q.derivative = self.phi * (self.f_q_inf(V) - self.q.value) / self.f_q_tau(V) / bu.ms + + def current(self, V, K: IonInfo): + return self.g_max * self.p.value ** 3 * self.q.value * (K.E - V) + + def init_state(self, V, K: IonInfo, batch_size: int = None): + self.p = State4Integral(bst.init.param(bu.math.zeros, self.varshape, batch_size)) + self.q = State4Integral(bst.init.param(bu.math.zeros, self.varshape, batch_size)) + + def reset_state(self, V, K: IonInfo, batch_size=None): + self.p.value = self.f_p_inf(V) + self.q.value = self.f_q_inf(V) + if isinstance(batch_size, int): + assert self.p.value.shape[0] == batch_size + assert self.q.value.shape[0] == batch_size + + def f_p_inf(self, V): + V = (V - self.V_sh).to_decimal(bu.mV) + return 1. / (1. + bu.math.exp(-(V - self.mivh) / self.mik)) + + def f_p_tau(self, V): + V = (V - self.V_sh).to_decimal(bu.mV) + mtau_func = bu.math.where( + V < -35, + (3.4225e-5 + .00498 * bu.math.exp(-V / -28.29)) * 3, + (self.mty0 + 1. / (bu.math.exp((V + self.mtvh1) / self.mtk1) + + bu.math.exp((V + self.mtvh2) / self.mtk2))) + ) + return 1000 * mtau_func + + def f_q_inf(self, V): + V = (V - self.V_sh).to_decimal(bu.mV) + return self.hiy0 + self.hiA / (1 + bu.math.exp((V - self.hivh) / self.hik)) + + def f_q_tau(self, V): + V = (V - self.V_sh).to_decimal(bu.mV) + htau_func = bu.math.where( + V > 0, + (.0012 + .0023 * bu.math.exp(-.141 * V)), + (1.2202e-05 + .012 * bu.math.exp(-((V - (-56.3)) / 49.6) ** 2)) + ) + return 1000 * htau_func +class IKv43_Ma2020(PotassiumChannel): + r""" + TITLE Cerebellum Granule Cell Model + + COMMENT + KA channel + + Author: E.D'Angelo, T.Nieus, A. Fontana + Last revised: Egidio 3.12.2003 + """ + __module__ = 'dendritex.channels' + + def __init__( + self, + size: Union[int, Sequence[int]], + g_max: Union[bst.typing.ArrayLike, Callable] = 3.2 * (bu.mS / bu.cm ** 2), + V_sh: Union[bst.typing.ArrayLike, Callable] = 0. * bu.mV, + T_base: bst.typing.ArrayLike = 3., + T: bst.typing.ArrayLike = 22, + name: Optional[str] = None, + ): + super().__init__(size=size, name=name, ) + + # parameters + self.g_max = bst.init.param(g_max, self.varshape, allow_none=False) + + self.T = bst.init.param(T, self.varshape, allow_none=False) + self.T_base = bst.init.param(T_base, self.varshape, allow_none=False) + self.phi = bst.init.param(T_base ** ((T - 25.5) / 10), self.varshape, allow_none=False) + self.V_sh = bst.init.param(V_sh, self.varshape, allow_none=False) + + self.Aalpha_a = 0.8147 + self.Kalpha_a = -23.32708 + self.V0alpha_a = -9.17203 + self.Abeta_a = 0.1655 + self.Kbeta_a = 19.47175 + self.V0beta_a = -18.27914 + + self.Aalpha_b = 0.0368 + self.Kalpha_b = 12.8433 + self.V0alpha_b = -111.33209 + self.Abeta_b = 0.0345 + self.Kbeta_b = -8.90123 + self.V0beta_b = -49.9537 + + self.V0_ainf = -38 + self.K_ainf = -17 + + self.V0_binf = -78.8 + self.K_binf = 8.4 + + def compute_derivative(self, V, K: IonInfo): + self.p.derivative = self.phi * (self.f_p_inf(V) - self.p.value) / self.f_p_tau(V) / bu.ms + self.q.derivative = self.phi * (self.f_q_inf(V) - self.q.value) / self.f_q_tau(V) / bu.ms + + def current(self, V, K: IonInfo): + return self.g_max * self.p.value ** 3 * self.q.value * (K.E - V) + + def init_state(self, V, K: IonInfo, batch_size: int = None): + self.p = State4Integral(bst.init.param(bu.math.zeros, self.varshape, batch_size)) + self.q = State4Integral(bst.init.param(bu.math.zeros, self.varshape, batch_size)) + + def reset_state(self, V, K: IonInfo, batch_size=None): + self.p.value = self.f_p_inf(V) + self.q.value = self.f_q_inf(V) + if isinstance(batch_size, int): + assert self.p.value.shape[0] == batch_size + assert self.q.value.shape[0] == batch_size + + def sigm(self, x, y): + return 1 / (bu.math.exp(x / y) + 1) + + def f_p_alpha(self, V): + V = (V - self.V_sh).to_decimal(bu.mV) + return self.Aalpha_a * self.sigm(V - self.V0alpha_a, self.Kalpha_a) + + def f_p_beta(self, V): + V = (V - self.V_sh).to_decimal(bu.mV) + return self.Abeta_a / (bu.math.exp((V - self.V0beta_a) / self.Kbeta_a)) + + def f_q_alpha(self, V): + V = (V - self.V_sh).to_decimal(bu.mV) + return self.Aalpha_b * self.sigm(V - self.V0alpha_b, self.Kalpha_b) + + def f_q_beta(self, V): + V = (V - self.V_sh).to_decimal(bu.mV) + return self.Abeta_b * self.sigm(V - self.V0beta_b, self.Kbeta_b) + + def f_p_inf(self, V): + V = (V - self.V_sh).to_decimal(bu.mV) + return 1 / (1 + bu.math.exp((V - self.V0_ainf) / self.K_ainf)) + + def f_p_tau(self, V): + return 1. / (self.f_p_alpha(V) + self.f_p_beta(V)) + + def f_q_inf(self, V): + V = (V - self.V_sh).to_decimal(bu.mV) + return 1 / (1 + bu.math.exp((V - self.V0_binf) / self.K_binf)) + + def f_q_tau(self, V): + return 1. / (self.f_q_alpha(V) + self.f_q_beta(V)) - super().__init__( - size, - name=name, - mode=mode - ) - # parameters - self.g_max = bst.init.param(g_max, self.varshape, allow_none=False) - self.gateCurrent = bst.init.param(gateCurrent, self.varshape, allow_none=False) - self.gunit = bst.init.param(gunit, self.varshape, allow_none=False) - self.T = bst.init.param(T, self.varshape, allow_none=False) - self.T_base = bst.init.param(T_base, self.varshape, allow_none=False) - self.phi = bst.init.param(T_base ** ((T - 22) / 10), self.varshape, allow_none=False) - self.V_sh = bst.init.param(V_sh, self.varshape, allow_none=False) +class IKM_Grc_Ma2020(PotassiumChannel): + r""" + TITLE Cerebellum Granule Cell Model - self.e0 = 1.60217646e-19 * bu.coulomb - self.q10 = 2.7 - self.ca = 0.12889 - self.cva = 45 - self.cka = -33.90877 - self.cb = 0.12889 - self.cvb = 45 - self.ckb = 12.42101 - self.zn = 2.7978 + COMMENT + KM channel + Author: A. Fontana + CoAuthor: T.Nieus Last revised: 20.11.99 + """ + __module__ = 'dendritex.channels' - def init_state(self, V, K: IonInfo, batch_size=None): - self.p = State4Integral(bst.init.param(bu.math.zeros, self.varshape, batch_size)) + def __init__( + self, + size: Union[int, Sequence[int]], + g_max: Union[bst.typing.ArrayLike, Callable] = 0.25 * (bu.mS / bu.cm ** 2), + V_sh: Union[bst.typing.ArrayLike, Callable] = 0. * bu.mV, + T_base: bst.typing.ArrayLike = 3., + T: bst.typing.ArrayLike = 22, + name: Optional[str] = None, + ): + super().__init__(size=size, name=name, ) - def reset_state(self, V, K: IonInfo, batch_size: int = None): - alpha = self.f_p_alpha(V) - beta = self.f_p_beta(V) - self.p.value = alpha / (alpha + beta) - if isinstance(batch_size, int): - assert self.p.value.shape[0] == batch_size + # parameters + self.g_max = bst.init.param(g_max, self.varshape, allow_none=False) - def compute_derivative(self, V, K: IonInfo): - self.p.derivative = self.phi * (self.f_p_alpha(V) * (1. - self.p.value) - self.f_p_beta(V) * self.p.value) / bu.ms + self.T = bst.init.param(T, self.varshape, allow_none=False) + self.T_base = bst.init.param(T_base, self.varshape, allow_none=False) + self.phi = bst.init.param(T_base ** ((T - 22) / 10), self.varshape, allow_none=False) + self.V_sh = bst.init.param(V_sh, self.varshape, allow_none=False) + self.ek = -84.69 * bu.mV + self.Aalpha_n = 0.0033 + self.Kalpha_n = 40 - def current(self, V, K: IonInfo): - if self.gateCurrent == 0: - ik = self.g_max * self.p.value ** 4 * (K.E - V) - else: - ngateFlip = self.phi * (self.f_p_alpha(V) * (1. - self.p.value) - self.f_p_beta(V) * self.p.value) / bu.ms - igate = (1e12) * self.g_max / self.gunit * 1e6 * self.e0 * 4 * self.zn * ngateFlip # NONSPECIFIC_CURRENT igate + self.V0alpha_n = -30 + self.Abeta_n = 0.0033 + self.Kbeta_n = -20 - ik = -igate + self.g_max * self.p.value ** 4 * (K.E - V) - return ik + self.V0beta_n = -30 + self.V0_ninf = -35 + self.B_ninf = 6 - def f_p_alpha(self, V): - V = (V - self.V_sh).to_decimal(bu.mV) - return self.ca * bu.math.exp(- (V + self.cva) / self.cka) + def compute_derivative(self, V, K: IonInfo): + self.p.derivative = self.phi * (self.f_p_inf(V) - self.p.value) / self.f_p_tau(V) / bu.ms - def f_p_beta(self, V): - V = (V - self.V_sh).to_decimal(bu.mV) - return self.cb * bu.math.exp(-(V + self.cvb) / self.ckb) + def current(self, V, K: IonInfo): + return self.g_max * self.p.value * (self.ek - V) + def init_state(self, V, K: IonInfo, batch_size: int = None): + self.p = State4Integral(bst.init.param(bu.math.zeros, self.varshape, batch_size)) -class IKv34_Ma2020(PotassiumChannel): - r""" - : HH TEA-sensitive Purkinje potassium current - : Created 8/5/02 - nwg - """ - __module__ = 'dendritex.channels' - - def __init__( - self, - size: Union[int, Sequence[int]], - g_max: Union[bst.typing.ArrayLike, Callable] = 4. * (bu.mS / bu.cm ** 2), - V_sh: Union[bst.typing.ArrayLike, Callable] = -11. * bu.mV, - T_base: bst.typing.ArrayLike = 3., - T: bst.typing.ArrayLike = 22, - name: Optional[str] = None, - mode: Optional[bst.mixin.Mode] = None, - ): - - super().__init__( - size, - name=name, - mode=mode - ) - - # parameters - self.g_max = bst.init.param(g_max, self.varshape, allow_none=False) - - self.T = bst.init.param(T, self.varshape, allow_none=False) - self.T_base = bst.init.param(T_base, self.varshape, allow_none=False) - self.phi = bst.init.param(T_base ** ((T - 37) / 10), self.varshape, allow_none=False) - self.V_sh = bst.init.param(V_sh, self.varshape, allow_none=False) - - self.mivh = -24 - self.mik = 15.4 - self.mty0 = .00012851 - self.mtvh1 = 100.7 - self.mtk1 = 12.9 - self.mtvh2 = -56.0 - self.mtk2 = -23.1 - - self.hiy0 = .31 - self.hiA = .69 - self.hivh = -5.802 - self.hik = 11.2 - - def compute_derivative(self, V, K: IonInfo): - self.p.derivative = self.phi * (self.f_p_inf(V) - self.p.value) / self.f_p_tau(V) / bu.ms - self.q.derivative = self.phi * (self.f_q_inf(V) - self.q.value) / self.f_q_tau(V) / bu.ms - - def current(self, V, K: IonInfo): - return self.g_max * self.p.value ** 3 * self.q.value * (K.E - V) - - def init_state(self, V, K: IonInfo, batch_size: int = None): - self.p = State4Integral(bst.init.param(bu.math.zeros, self.varshape, batch_size)) - self.q = State4Integral(bst.init.param(bu.math.zeros, self.varshape, batch_size)) - - def reset_state(self, V, K: IonInfo, batch_size=None): - self.p.value = self.f_p_inf(V) - self.q.value = self.f_q_inf(V) - if isinstance(batch_size, int): - assert self.p.value.shape[0] == batch_size - assert self.q.value.shape[0] == batch_size - - def f_p_inf(self, V): - V = (V - self.V_sh).to_decimal(bu.mV) - return 1. / (1. + bu.math.exp(-(V - self.mivh) / self.mik)) - - def f_p_tau(self, V): - V = (V - self.V_sh).to_decimal(bu.mV) - mtau_func = bu.math.where( - V < -35, - (3.4225e-5 + .00498 * bu.math.exp(-V / -28.29)) * 3, - (self.mty0 + 1. / (bu.math.exp((V + self.mtvh1) / self.mtk1) + - bu.math.exp((V + self.mtvh2) / self.mtk2))) - ) - return 1000 * mtau_func - - def f_q_inf(self, V): - V = (V - self.V_sh).to_decimal(bu.mV) - return self.hiy0 + self.hiA / (1 + bu.math.exp((V - self.hivh) / self.hik)) - - def f_q_tau(self, V): - V = (V - self.V_sh).to_decimal(bu.mV) - htau_func = bu.math.where( - V > 0, - (.0012 + .0023 * bu.math.exp(-.141 * V)), - (1.2202e-05 + .012 * bu.math.exp(-((V - (-56.3)) / 49.6) ** 2)) - ) - return 1000 * htau_func + def reset_state(self, V, K: IonInfo, batch_size=None): + self.p.value = self.f_p_inf(V) + if isinstance(batch_size, int): + assert self.p.value.shape[0] == batch_size -class IKv43_Ma2020(PotassiumChannel): - r""" - TITLE Cerebellum Granule Cell Model - - COMMENT - KA channel - - Author: E.D'Angelo, T.Nieus, A. Fontana - Last revised: Egidio 3.12.2003 - """ - __module__ = 'dendritex.channels' - - def __init__( - self, - size: Union[int, Sequence[int]], - g_max: Union[bst.typing.ArrayLike, Callable] = 3.2 * (bu.mS / bu.cm ** 2), - V_sh: Union[bst.typing.ArrayLike, Callable] = 0. * bu.mV, - T_base: bst.typing.ArrayLike = 3., - T: bst.typing.ArrayLike = 22, - name: Optional[str] = None, - mode: Optional[bst.mixin.Mode] = None, - ): - super().__init__(size, name=name, mode=mode) - - # parameters - self.g_max = bst.init.param(g_max, self.varshape, allow_none=False) - - self.T = bst.init.param(T, self.varshape, allow_none=False) - self.T_base = bst.init.param(T_base, self.varshape, allow_none=False) - self.phi = bst.init.param(T_base ** ((T - 25.5) / 10), self.varshape, allow_none=False) - self.V_sh = bst.init.param(V_sh, self.varshape, allow_none=False) - - self.Aalpha_a = 0.8147 - self.Kalpha_a = -23.32708 - self.V0alpha_a = -9.17203 - self.Abeta_a = 0.1655 - self.Kbeta_a = 19.47175 - self.V0beta_a = -18.27914 - - self.Aalpha_b = 0.0368 - self.Kalpha_b = 12.8433 - self.V0alpha_b = -111.33209 - self.Abeta_b = 0.0345 - self.Kbeta_b = -8.90123 - self.V0beta_b = -49.9537 - - self.V0_ainf = -38 - self.K_ainf = -17 - - self.V0_binf = -78.8 - self.K_binf = 8.4 - - def compute_derivative(self, V, K: IonInfo): - self.p.derivative = self.phi * (self.f_p_inf(V) - self.p.value) / self.f_p_tau(V) / bu.ms - self.q.derivative = self.phi * (self.f_q_inf(V) - self.q.value) / self.f_q_tau(V) / bu.ms - - def current(self, V, K: IonInfo): - return self.g_max * self.p.value ** 3 * self.q.value * (K.E - V) - - def init_state(self, V, K: IonInfo, batch_size: int = None): - self.p = State4Integral(bst.init.param(bu.math.zeros, self.varshape, batch_size)) - self.q = State4Integral(bst.init.param(bu.math.zeros, self.varshape, batch_size)) - - def reset_state(self, V, K: IonInfo, batch_size=None): - self.p.value = self.f_p_inf(V) - self.q.value = self.f_q_inf(V) - if isinstance(batch_size, int): - assert self.p.value.shape[0] == batch_size - assert self.q.value.shape[0] == batch_size - - def sigm(self, x, y): - return 1 / (bu.math.exp(x / y) + 1) - - def f_p_alpha(self, V): - V = (V - self.V_sh).to_decimal(bu.mV) - return self.Aalpha_a * self.sigm(V - self.V0alpha_a, self.Kalpha_a) - - def f_p_beta(self, V): - V = (V - self.V_sh).to_decimal(bu.mV) - return self.Abeta_a / (bu.math.exp((V - self.V0beta_a) / self.Kbeta_a)) - - def f_q_alpha(self, V): - V = (V - self.V_sh).to_decimal(bu.mV) - return self.Aalpha_b * self.sigm(V - self.V0alpha_b, self.Kalpha_b) - - def f_q_beta(self, V): - V = (V - self.V_sh).to_decimal(bu.mV) - return self.Abeta_b * self.sigm(V - self.V0beta_b, self.Kbeta_b) - - def f_p_inf(self, V): - V = (V - self.V_sh).to_decimal(bu.mV) - return 1 / (1 + bu.math.exp((V - self.V0_ainf) / self.K_ainf)) - - def f_p_tau(self, V): - return 1. / (self.f_p_alpha(V) + self.f_p_beta(V)) - - def f_q_inf(self, V): - V = (V - self.V_sh).to_decimal(bu.mV) - return 1 / (1 + bu.math.exp((V - self.V0_binf) / self.K_binf)) - - def f_q_tau(self, V): - return 1. / (self.f_q_alpha(V) + self.f_q_beta(V)) + def f_p_alpha(self, V): + V = (V - self.V_sh).to_decimal(bu.mV) + return self.Aalpha_n * bu.math.exp((V - self.V0alpha_n) / self.Kalpha_n) + def f_p_beta(self, V): + V = (V - self.V_sh).to_decimal(bu.mV) + return self.Abeta_n * bu.math.exp((V - self.V0beta_n) / self.Kbeta_n) + def f_p_inf(self, V): + V = (V - self.V_sh).to_decimal(bu.mV) + return 1 / (1 + bu.math.exp(-(V - self.V0_ninf) / self.B_ninf)) -class IKM_Grc_Ma2020(PotassiumChannel): - r""" - TITLE Cerebellum Granule Cell Model - - COMMENT - KM channel - - Author: A. Fontana - CoAuthor: T.Nieus Last revised: 20.11.99 - """ - __module__ = 'dendritex.channels' - - def __init__( - self, - size: Union[int, Sequence[int]], - g_max: Union[bst.typing.ArrayLike, Callable] = 0.25 * (bu.mS / bu.cm ** 2), - V_sh: Union[bst.typing.ArrayLike, Callable] = 0. * bu.mV, - T_base: bst.typing.ArrayLike = 3., - T: bst.typing.ArrayLike = 22, - name: Optional[str] = None, - mode: Optional[bst.mixin.Mode] = None, - ): - super().__init__(size, name=name, mode=mode) - - # parameters - self.g_max = bst.init.param(g_max, self.varshape, allow_none=False) - - self.T = bst.init.param(T, self.varshape, allow_none=False) - self.T_base = bst.init.param(T_base, self.varshape, allow_none=False) - self.phi = bst.init.param(T_base ** ((T - 22) / 10), self.varshape, allow_none=False) - self.V_sh = bst.init.param(V_sh, self.varshape, allow_none=False) - - self.ek = -84.69 * bu.mV - - self.Aalpha_n = 0.0033 - self.Kalpha_n = 40 - - self.V0alpha_n = -30 - self.Abeta_n = 0.0033 - self.Kbeta_n = -20 - - self.V0beta_n = -30 - self.V0_ninf = -35 - self.B_ninf = 6 - - - - def compute_derivative(self, V, K: IonInfo): - self.p.derivative = self.phi * (self.f_p_inf(V) - self.p.value) / self.f_p_tau(V) / bu.ms - - - def current(self, V, K: IonInfo): - return self.g_max * self.p.value * (self.ek - V) - - def init_state(self, V, K: IonInfo, batch_size: int = None): - self.p = State4Integral(bst.init.param(bu.math.zeros, self.varshape, batch_size)) - - def reset_state(self, V, K: IonInfo, batch_size=None): - self.p.value = self.f_p_inf(V) - if isinstance(batch_size, int): - assert self.p.value.shape[0] == batch_size - - - def f_p_alpha(self, V): - V = (V - self.V_sh).to_decimal(bu.mV) - return self.Aalpha_n*bu.math.exp((V-self.V0alpha_n)/self.Kalpha_n) - - def f_p_beta(self, V): - V = (V - self.V_sh).to_decimal(bu.mV) - return self.Abeta_n*bu.math.exp((V-self.V0beta_n)/self.Kbeta_n) - - def f_p_inf(self, V): - V = (V - self.V_sh).to_decimal(bu.mV) - return 1/(1+bu.math.exp(-(V-self.V0_ninf)/self.B_ninf)) - - def f_p_tau(self, V): - return 1. / (self.f_p_alpha(V) + self.f_p_beta(V)) + def f_p_tau(self, V): + return 1. / (self.f_p_alpha(V) + self.f_p_beta(V)) diff --git a/dendritex/channels/potassium_calcium.py b/dendritex/channels/potassium_calcium.py index 810f3e0..666e256 100644 --- a/dendritex/channels/potassium_calcium.py +++ b/dendritex/channels/potassium_calcium.py @@ -13,463 +13,447 @@ import brainunit as bu import jax -from .._base import IonInfo, Channel, State4Integral -from ..ions import Calcium, Potassium +from dendritex._base import IonInfo, Channel, State4Integral +from dendritex.ions import Calcium, Potassium __all__ = [ - 'IAHP_De1994', - 'IKca3_1_Ma2020', - 'IKca2_2_Ma2020', - 'IKca1_1_Ma2020', + 'IAHP_De1994', + 'IKca3_1_Ma2020', + 'IKca2_2_Ma2020', + 'IKca1_1_Ma2020', ] class KCaChannel(Channel): - __module__ = 'dendritex.channels' + __module__ = 'dendritex.channels' - root_type = bst.mixin.JointTypes[Potassium, Calcium] + root_type = bst.mixin.JointTypes[Potassium, Calcium] - def before_integral(self, V, K: IonInfo, Ca: IonInfo): - pass + def before_integral(self, V, K: IonInfo, Ca: IonInfo): + pass - def after_integral(self, V, K: IonInfo, Ca: IonInfo): - pass + def post_derivative(self, V, K: IonInfo, Ca: IonInfo): + pass - def compute_derivative(self, V, K: IonInfo, Ca: IonInfo): - pass + def compute_derivative(self, V, K: IonInfo, Ca: IonInfo): + pass - def current(self, V, K: IonInfo, Ca: IonInfo): - raise NotImplementedError + def current(self, V, K: IonInfo, Ca: IonInfo): + raise NotImplementedError - def init_state(self, V, K: IonInfo, Ca: IonInfo, batch_size: int = None): - pass + def init_state(self, V, K: IonInfo, Ca: IonInfo, batch_size: int = None): + pass - def reset_state(self, V, K: IonInfo, Ca: IonInfo, batch_size: int = None): - pass + def reset_state(self, V, K: IonInfo, Ca: IonInfo, batch_size: int = None): + pass class IAHP_De1994(KCaChannel): - r"""The calcium-dependent potassium current model proposed by (Destexhe, et al., 1994) [1]_. - - Both in vivo (Contreras et al. 1993; Mulle et al. 1986) and in - vitro recordings (Avanzini et al. 1989) show the presence of a - marked after-hyper-polarization (AHP) after each burst of the RE - cell. This slow AHP is mediated by a slow :math:`Ca^{2+}`-dependent K+ - current (Bal and McCormick 1993). (Destexhe, et al., 1994) adopted a - modified version of a model of :math:`I_{KCa}` introduced previously (Yamada et al. - 1989) that requires the binding of :math:`nCa^{2+}` to open the channel - - .. math:: - - (\text { closed })+n \mathrm{Ca}_{i}^{2+} \underset{\beta}{\stackrel{\alpha}{\rightleftharpoons}(\text { open }) - - where :math:`Ca_i^{2+}` is the intracellular calcium and :math:`\alpha` and - :math:`\beta` are rate constants. The ionic current is then given by - - .. math:: - - \begin{aligned} - I_{AHP} &= g_{\mathrm{max}} p^2 (V - E_K) \\ - {dp \over dt} &= \phi {p_{\infty}(V, [Ca^{2+}]_i) - p \over \tau_p(V, [Ca^{2+}]_i)} \\ - p_{\infty} &=\frac{\alpha[Ca^{2+}]_i^n}{\left(\alpha[Ca^{2+}]_i^n + \beta\right)} \\ - \tau_p &=\frac{1}{\left(\alpha[Ca^{2+}]_i +\beta\right)} - \end{aligned} - - where :math:`E` is the reversal potential, :math:`g_{max}` is the maximum conductance, - :math:`[Ca^{2+}]_i` is the intracellular Calcium concentration. - The values :math:`n=2, \alpha=48 \mathrm{~ms}^{-1} \mathrm{mM}^{-2}` and - :math:`\beta=0.03 \mathrm{~ms}^{-1}` yielded AHPs very similar to those RE cells - recorded in vivo and in vitro. - - Parameters - ---------- - g_max : float - The maximal conductance density (:math:`mS/cm^2`). - - References - ---------- - - .. [1] Destexhe, Alain, et al. "A model of spindle rhythmicity in the isolated - thalamic reticular nucleus." Journal of neurophysiology 72.2 (1994): 803-818. - - """ - __module__ = 'dendritex.channels' - - root_type = bst.mixin.JointTypes[Potassium, Calcium] - - def __init__( - self, - size: bst.typing.Size, - n: Union[bst.typing.ArrayLike, Callable] = 2, - g_max: Union[bst.typing.ArrayLike, Callable] = 10. * (bu.mS / bu.cm ** 2), - alpha: Union[bst.typing.ArrayLike, Callable] = 48., - beta: Union[bst.typing.ArrayLike, Callable] = 0.09, - phi: Union[bst.typing.ArrayLike, Callable] = 1., - name: Optional[str] = None, - mode: Optional[bst.mixin.Mode] = None, - ): - super().__init__( - size=size, - name=name, - mode=mode - ) - - # parameters - self.g_max = bst.init.param(g_max, self.varshape, allow_none=False) - self.n = bst.init.param(n, self.varshape, allow_none=False) - self.alpha = bst.init.param(alpha, self.varshape, allow_none=False) - self.beta = bst.init.param(beta, self.varshape, allow_none=False) - self.phi = bst.init.param(phi, self.varshape, allow_none=False) - - def compute_derivative(self, V, K: IonInfo, Ca: IonInfo): - C2 = self.alpha * bu.math.power(Ca.C / bu.mM, self.n) - C3 = C2 + self.beta - self.p.derivative = self.phi * (C2 / C3 - self.p.value) * C3 / bu.ms - - def current(self, V, K: IonInfo, Ca: IonInfo): - return self.g_max * self.p.value * self.p.value * (K.E - V) - - def init_state(self, V, K: IonInfo, Ca: IonInfo, batch_size=None): - self.p = State4Integral(bst.init.param(bu.math.zeros, self.varshape, batch_size)) - - def reset_state(self, V, K: IonInfo, Ca: IonInfo, batch_size=None): - C2 = self.alpha * bu.math.power(Ca.C / bu.mM, self.n) - C3 = C2 + self.beta - if batch_size is None: - self.p.value = bu.math.broadcast_to(C2 / C3, self.varshape) - else: - self.p.value = bu.math.broadcast_to(C2 / C3, (batch_size,) + self.varshape) - assert self.p.value.shape[0] == batch_size + r"""The calcium-dependent potassium current model proposed by (Destexhe, et al., 1994) [1]_. + + Both in vivo (Contreras et al. 1993; Mulle et al. 1986) and in + vitro recordings (Avanzini et al. 1989) show the presence of a + marked after-hyper-polarization (AHP) after each burst of the RE + cell. This slow AHP is mediated by a slow :math:`Ca^{2+}`-dependent K+ + current (Bal and McCormick 1993). (Destexhe, et al., 1994) adopted a + modified version of a model of :math:`I_{KCa}` introduced previously (Yamada et al. + 1989) that requires the binding of :math:`nCa^{2+}` to open the channel + + .. math:: + + (\text { closed })+n \mathrm{Ca}_{i}^{2+} \underset{\beta}{\stackrel{\alpha}{\rightleftharpoons}(\text { open }) + + where :math:`Ca_i^{2+}` is the intracellular calcium and :math:`\alpha` and + :math:`\beta` are rate constants. The ionic current is then given by + + .. math:: + + \begin{aligned} + I_{AHP} &= g_{\mathrm{max}} p^2 (V - E_K) \\ + {dp \over dt} &= \phi {p_{\infty}(V, [Ca^{2+}]_i) - p \over \tau_p(V, [Ca^{2+}]_i)} \\ + p_{\infty} &=\frac{\alpha[Ca^{2+}]_i^n}{\left(\alpha[Ca^{2+}]_i^n + \beta\right)} \\ + \tau_p &=\frac{1}{\left(\alpha[Ca^{2+}]_i +\beta\right)} + \end{aligned} + + where :math:`E` is the reversal potential, :math:`g_{max}` is the maximum conductance, + :math:`[Ca^{2+}]_i` is the intracellular Calcium concentration. + The values :math:`n=2, \alpha=48 \mathrm{~ms}^{-1} \mathrm{mM}^{-2}` and + :math:`\beta=0.03 \mathrm{~ms}^{-1}` yielded AHPs very similar to those RE cells + recorded in vivo and in vitro. + + Parameters + ---------- + g_max : float + The maximal conductance density (:math:`mS/cm^2`). + + References + ---------- + + .. [1] Destexhe, Alain, et al. "A model of spindle rhythmicity in the isolated + thalamic reticular nucleus." Journal of neurophysiology 72.2 (1994): 803-818. + + """ + __module__ = 'dendritex.channels' + + root_type = bst.mixin.JointTypes[Potassium, Calcium] + + def __init__( + self, + size: bst.typing.Size, + n: Union[bst.typing.ArrayLike, Callable] = 2, + g_max: Union[bst.typing.ArrayLike, Callable] = 10. * (bu.mS / bu.cm ** 2), + alpha: Union[bst.typing.ArrayLike, Callable] = 48., + beta: Union[bst.typing.ArrayLike, Callable] = 0.09, + phi: Union[bst.typing.ArrayLike, Callable] = 1., + name: Optional[str] = None, + ): + super().__init__(size=size, name=name, ) + + # parameters + self.g_max = bst.init.param(g_max, self.varshape, allow_none=False) + self.n = bst.init.param(n, self.varshape, allow_none=False) + self.alpha = bst.init.param(alpha, self.varshape, allow_none=False) + self.beta = bst.init.param(beta, self.varshape, allow_none=False) + self.phi = bst.init.param(phi, self.varshape, allow_none=False) + + def compute_derivative(self, V, K: IonInfo, Ca: IonInfo): + C2 = self.alpha * bu.math.power(Ca.C / bu.mM, self.n) + C3 = C2 + self.beta + self.p.derivative = self.phi * (C2 / C3 - self.p.value) * C3 / bu.ms + + def current(self, V, K: IonInfo, Ca: IonInfo): + return self.g_max * self.p.value * self.p.value * (K.E - V) + + def init_state(self, V, K: IonInfo, Ca: IonInfo, batch_size=None): + self.p = State4Integral(bst.init.param(bu.math.zeros, self.varshape, batch_size)) + + def reset_state(self, V, K: IonInfo, Ca: IonInfo, batch_size=None): + C2 = self.alpha * bu.math.power(Ca.C / bu.mM, self.n) + C3 = C2 + self.beta + if batch_size is None: + self.p.value = bu.math.broadcast_to(C2 / C3, self.varshape) + else: + self.p.value = bu.math.broadcast_to(C2 / C3, (batch_size,) + self.varshape) + assert self.p.value.shape[0] == batch_size class IKca3_1_Ma2020(KCaChannel): - r''' - TITLE Calcium dependent potassium channel - : Implemented in Rubin and Cleland (2006) J Neurophysiology - : Parameters from Bhalla and Bower (1993) J Neurophysiology - : Adapted from /usr/local/neuron/demo/release/nachan.mod - squid - : by Andrew Davison, The Babraham Institute [Brain Res Bulletin, 2000] - ''' - __module__ = 'dendritex.channels' - - root_type = bst.mixin.JointTypes[Potassium, Calcium] - - def __init__( - self, - size: bst.typing.Size, - g_max: Union[bst.typing.ArrayLike, Callable] = 120. * (bu.mS / bu.cm ** 2), - T_base: bst.typing.ArrayLike = 3., - T: bst.typing.ArrayLike = 22, - name: Optional[str] = None, - mode: Optional[bst.mixin.Mode] = None, - ): - super().__init__( - size=size, - name=name, - mode=mode - ) - - # parameters - self.T = bst.init.param(T, self.varshape, allow_none=False) - self.T_base = bst.init.param(T_base, self.varshape, allow_none=False) - self.phi = bst.init.param(T_base ** ((T - 37) / 10), self.varshape, allow_none=False) - self.g_max = bst.init.param(g_max, self.varshape, allow_none=False) - - self.p_beta = 0.05 - - def current(self, V, K: IonInfo, Ca: IonInfo): - return self.g_max * self.p.value * (K.E - V) - - def p_tau(self, V, Ca): - return 1 / (self.p_alpha(V, Ca) + self.p_beta) - - def p_inf(self, V, Ca): - return self.p_alpha(V, Ca) / (self.p_alpha(V, Ca) + self.p_beta) - - def p_alpha(self, V, Ca): - V = V / bu.mV - return self.p_vdep(V) * self.p_concdep(Ca) - - def p_vdep(self, V): - return bu.math.exp((V + 70.) / 27.) - - def p_concdep(self, Ca): - # concdep_1 = 500 * (0.015 - Ca.C / u.mM) / (u.math.exp((0.015 - Ca.C / u.mM) / 0.0013) - 1) - concdep_1 = 500 * 0.0013 / bu.math.exprel((0.015 - Ca.C / bu.mM) / 0.0013) - with jax.ensure_compile_time_eval(): - concdep_2 = 500 * 0.005 / (bu.math.exp(0.005 / 0.0013) - 1) - return bu.math.where(Ca.C / bu.mM < 0.01, concdep_1, concdep_2) - - def init_state(self, V, K: IonInfo, Ca: IonInfo, batch_size=None): - self.p = State4Integral(bst.init.param(bu.math.zeros, self.varshape, batch_size)) - self.reset_state(V, K, Ca) - - def reset_state(self, V, K: IonInfo, Ca: IonInfo, batch_size=None): - self.p.value = self.p_inf(V, Ca) - - def compute_derivative(self, V, K: IonInfo, Ca: IonInfo): - self.p.derivative = self.phi * (self.p_inf(V, Ca) - self.p.value) / self.p_tau(V, Ca) / bu.ms + r''' + TITLE Calcium dependent potassium channel + : Implemented in Rubin and Cleland (2006) J Neurophysiology + : Parameters from Bhalla and Bower (1993) J Neurophysiology + : Adapted from /usr/local/neuron/demo/release/nachan.mod - squid + : by Andrew Davison, The Babraham Institute [Brain Res Bulletin, 2000] + ''' + __module__ = 'dendritex.channels' + + root_type = bst.mixin.JointTypes[Potassium, Calcium] + + def __init__( + self, + size: bst.typing.Size, + g_max: Union[bst.typing.ArrayLike, Callable] = 120. * (bu.mS / bu.cm ** 2), + T_base: bst.typing.ArrayLike = 3., + T: bst.typing.ArrayLike = 22, + name: Optional[str] = None, + ): + super().__init__(size=size, name=name, ) + + # parameters + self.T = bst.init.param(T, self.varshape, allow_none=False) + self.T_base = bst.init.param(T_base, self.varshape, allow_none=False) + self.phi = bst.init.param(T_base ** ((T - 37) / 10), self.varshape, allow_none=False) + self.g_max = bst.init.param(g_max, self.varshape, allow_none=False) + + self.p_beta = 0.05 + + def current(self, V, K: IonInfo, Ca: IonInfo): + return self.g_max * self.p.value * (K.E - V) + + def p_tau(self, V, Ca): + return 1 / (self.p_alpha(V, Ca) + self.p_beta) + + def p_inf(self, V, Ca): + return self.p_alpha(V, Ca) / (self.p_alpha(V, Ca) + self.p_beta) + + def p_alpha(self, V, Ca): + V = V / bu.mV + return self.p_vdep(V) * self.p_concdep(Ca) + + def p_vdep(self, V): + return bu.math.exp((V + 70.) / 27.) + + def p_concdep(self, Ca): + # concdep_1 = 500 * (0.015 - Ca.C / u.mM) / (u.math.exp((0.015 - Ca.C / u.mM) / 0.0013) - 1) + concdep_1 = 500 * 0.0013 / bu.math.exprel((0.015 - Ca.C / bu.mM) / 0.0013) + with jax.ensure_compile_time_eval(): + concdep_2 = 500 * 0.005 / (bu.math.exp(0.005 / 0.0013) - 1) + return bu.math.where(Ca.C / bu.mM < 0.01, concdep_1, concdep_2) + + def init_state(self, V, K: IonInfo, Ca: IonInfo, batch_size=None): + self.p = State4Integral(bst.init.param(bu.math.zeros, self.varshape, batch_size)) + self.reset_state(V, K, Ca) + + def reset_state(self, V, K: IonInfo, Ca: IonInfo, batch_size=None): + self.p.value = self.p_inf(V, Ca) + + def compute_derivative(self, V, K: IonInfo, Ca: IonInfo): + self.p.derivative = self.phi * (self.p_inf(V, Ca) - self.p.value) / self.p_tau(V, Ca) / bu.ms class IKca2_2_Ma2020(KCaChannel): - r''' - TITLE SK2 multi-state model Cerebellum Golgi Cell Model - - COMMENT - - Author:Sergio Solinas, Lia Forti, Egidio DAngelo - Based on data from: Hirschberg, Maylie, Adelman, Marrion J Gen Physiol 1998 - Last revised: May 2007 - - Published in: - Sergio M. Solinas, Lia Forti, Elisabetta Cesana, - Jonathan Mapelli, Erik De Schutter and Egidio D`Angelo (2008) - Computational reconstruction of pacemaking and intrinsic - electroresponsiveness in cerebellar golgi cells - Frontiers in Cellular Neuroscience 2:2 - ''' - __module__ = 'dendritex.channels' - - root_type = bst.mixin.JointTypes[Potassium, Calcium] - - def __init__( - self, - size: bst.typing.Size, - g_max: Union[bst.typing.ArrayLike, Callable] = 38. * (bu.mS / bu.cm ** 2), - T_base: bst.typing.ArrayLike = 3., - diff: bst.typing.ArrayLike = 3., - T: bst.typing.ArrayLike = 22, - name: Optional[str] = None, - mode: Optional[bst.mixin.Mode] = None, - ): - super().__init__( - size=size, - name=name, - mode=mode - ) - - # parameters - self.T = bst.init.param(T, self.varshape, allow_none=False) - self.T_base = bst.init.param(T_base, self.varshape, allow_none=False) - self.phi = bst.init.param(T_base ** ((T - 23) / 10), self.varshape, allow_none=False) - self.g_max = bst.init.param(g_max, self.varshape, allow_none=False) - self.diff = bst.init.param(diff, self.varshape, allow_none=False) - - self.invc1 = 80e-3 # (/ms) - self.invc2 = 80e-3 # (/ms) - self.invc3 = 200e-3 # (/ms) - - self.invo1 = 1 # (/ms) - self.invo2 = 100e-3 # (/ms) - self.diro1 = 160e-3 # (/ms) - self.diro2 = 1.2 # (/ms) - - self.dirc2 = 200 # (/ms-mM) - self.dirc3 = 160 # (/ms-mM) - self.dirc4 = 80 # (/ms-mM) - - def init_state(self, V, K: IonInfo, Ca: IonInfo, batch_size=None): - - self.C1 = State4Integral(bst.init.param(bu.math.ones, self.varshape, batch_size)) - self.C2 = State4Integral(bst.init.param(bu.math.ones, self.varshape, batch_size)) - self.C3 = State4Integral(bst.init.param(bu.math.ones, self.varshape, batch_size)) - self.C4 = State4Integral(bst.init.param(bu.math.ones, self.varshape, batch_size)) - self.O1 = State4Integral(bst.init.param(bu.math.ones, self.varshape, batch_size)) - self.O2 = State4Integral(bst.init.param(bu.math.ones, self.varshape, batch_size)) - self.normalize_states([self.C1, self.C2, self.C3, self.C4, self.O1, self.O2]) - - def reset_state(self, V, K: IonInfo, Ca: IonInfo, batch_size=None): - self.normalize_states([self.C1, self.C2, self.C3, self.C4, self.O1, self.O2]) - - def current(self, V, K: IonInfo, Ca: IonInfo): - return self.g_max * (self.O1.value + self.O2.value) * (K.E - V) - - def before_integral(self, V, K: IonInfo, Ca: IonInfo): - self.normalize_states([self.C1, self.C2, self.C3, self.C4, self.O1, self.O2]) - - def normalize_states(self, states): - total = 0. - for state in states: - state.value = bu.math.maximum(state.value, 0) - total = total + state.value - for state in states: - state.value = state.value / total - - def compute_derivative(self, V, K: IonInfo, Ca: IonInfo): - - self.C1.derivative = (self.C2.value * self.invc1_t(Ca) - self.C1.value * self.dirc2_t_ca(Ca)) / bu.ms - self.C2.derivative = (self.C3.value * self.invc2_t(Ca) + self.C1.value * self.dirc2_t_ca(Ca) - self.C2.value * ( - self.invc1_t(Ca) + self.dirc3_t_ca(Ca))) / bu.ms - self.C3.derivative = (self.C4.value * self.invc3_t(Ca) + self.O1.value * self.invo1_t(Ca) - self.C3.value * ( - self.dirc4_t_ca(Ca) + self.diro1_t(Ca))) / bu.ms - self.C4.derivative = (self.C3.value * self.dirc4_t_ca(Ca) + self.O2.value * self.invo2_t(Ca) - self.C4.value * ( - self.invc3_t(Ca) + self.diro2_t(Ca))) / bu.ms - self.O1.derivative = (self.C3.value * self.diro1_t(Ca) - self.O1.value * self.invo1_t(Ca)) / bu.ms - self.O2.derivative = (self.C4.value * self.diro2_t(Ca) - self.O2.value * self.invo2_t(Ca)) / bu.ms - - dirc2_t_ca = lambda self, Ca: self.dirc2_t * (Ca.C / bu.mM) / self.diff - dirc3_t_ca = lambda self, Ca: self.dirc3_t * (Ca.C / bu.mM) / self.diff - dirc4_t_ca = lambda self, Ca: self.dirc4_t * (Ca.C / bu.mM) / self.diff - - invc1_t = lambda self, Ca: self.invc1 * self.phi - invc2_t = lambda self, Ca: self.invc2 * self.phi - invc3_t = lambda self, Ca: self.invc3 * self.phi - invo1_t = lambda self, Ca: self.invo1 * self.phi - invo2_t = lambda self, Ca: self.invo2 * self.phi - diro1_t = lambda self, Ca: self.diro1 * self.phi - diro2_t = lambda self, Ca: self.diro2 * self.phi - dirc2_t = lambda self, Ca: self.dirc2 * self.phi - dirc3_t = lambda self, Ca: self.dirc3 * self.phi - dirc4_t = lambda self, Ca: self.dirc4 * self.phi + r''' + TITLE SK2 multi-state model Cerebellum Golgi Cell Model + + COMMENT + + Author:Sergio Solinas, Lia Forti, Egidio DAngelo + Based on data from: Hirschberg, Maylie, Adelman, Marrion J Gen Physiol 1998 + Last revised: May 2007 + + Published in: + Sergio M. Solinas, Lia Forti, Elisabetta Cesana, + Jonathan Mapelli, Erik De Schutter and Egidio D`Angelo (2008) + Computational reconstruction of pacemaking and intrinsic + electroresponsiveness in cerebellar golgi cells + Frontiers in Cellular Neuroscience 2:2 + ''' + __module__ = 'dendritex.channels' + + root_type = bst.mixin.JointTypes[Potassium, Calcium] + + def __init__( + self, + size: bst.typing.Size, + g_max: Union[bst.typing.ArrayLike, Callable] = 38. * (bu.mS / bu.cm ** 2), + T_base: bst.typing.ArrayLike = 3., + diff: bst.typing.ArrayLike = 3., + T: bst.typing.ArrayLike = 22, + name: Optional[str] = None, + ): + super().__init__(size=size, name=name, ) + + # parameters + self.T = bst.init.param(T, self.varshape, allow_none=False) + self.T_base = bst.init.param(T_base, self.varshape, allow_none=False) + self.phi = bst.init.param(T_base ** ((T - 23) / 10), self.varshape, allow_none=False) + self.g_max = bst.init.param(g_max, self.varshape, allow_none=False) + self.diff = bst.init.param(diff, self.varshape, allow_none=False) + + self.invc1 = 80e-3 # (/ms) + self.invc2 = 80e-3 # (/ms) + self.invc3 = 200e-3 # (/ms) + + self.invo1 = 1 # (/ms) + self.invo2 = 100e-3 # (/ms) + self.diro1 = 160e-3 # (/ms) + self.diro2 = 1.2 # (/ms) + + self.dirc2 = 200 # (/ms-mM) + self.dirc3 = 160 # (/ms-mM) + self.dirc4 = 80 # (/ms-mM) + + def init_state(self, V, K: IonInfo, Ca: IonInfo, batch_size=None): + + self.C1 = State4Integral(bst.init.param(bu.math.ones, self.varshape, batch_size)) + self.C2 = State4Integral(bst.init.param(bu.math.ones, self.varshape, batch_size)) + self.C3 = State4Integral(bst.init.param(bu.math.ones, self.varshape, batch_size)) + self.C4 = State4Integral(bst.init.param(bu.math.ones, self.varshape, batch_size)) + self.O1 = State4Integral(bst.init.param(bu.math.ones, self.varshape, batch_size)) + self.O2 = State4Integral(bst.init.param(bu.math.ones, self.varshape, batch_size)) + self.normalize_states([self.C1, self.C2, self.C3, self.C4, self.O1, self.O2]) + + def reset_state(self, V, K: IonInfo, Ca: IonInfo, batch_size=None): + self.normalize_states([self.C1, self.C2, self.C3, self.C4, self.O1, self.O2]) + + def current(self, V, K: IonInfo, Ca: IonInfo): + return self.g_max * (self.O1.value + self.O2.value) * (K.E - V) + + def before_integral(self, V, K: IonInfo, Ca: IonInfo): + self.normalize_states([self.C1, self.C2, self.C3, self.C4, self.O1, self.O2]) + + def normalize_states(self, states): + total = 0. + for state in states: + state.value = bu.math.maximum(state.value, 0) + total = total + state.value + for state in states: + state.value = state.value / total + + def compute_derivative(self, V, K: IonInfo, Ca: IonInfo): + + self.C1.derivative = (self.C2.value * self.invc1_t(Ca) - self.C1.value * self.dirc2_t_ca(Ca)) / bu.ms + self.C2.derivative = (self.C3.value * self.invc2_t(Ca) + self.C1.value * self.dirc2_t_ca(Ca) - self.C2.value * ( + self.invc1_t(Ca) + self.dirc3_t_ca(Ca))) / bu.ms + self.C3.derivative = (self.C4.value * self.invc3_t(Ca) + self.O1.value * self.invo1_t(Ca) - self.C3.value * ( + self.dirc4_t_ca(Ca) + self.diro1_t(Ca))) / bu.ms + self.C4.derivative = (self.C3.value * self.dirc4_t_ca(Ca) + self.O2.value * self.invo2_t(Ca) - self.C4.value * ( + self.invc3_t(Ca) + self.diro2_t(Ca))) / bu.ms + self.O1.derivative = (self.C3.value * self.diro1_t(Ca) - self.O1.value * self.invo1_t(Ca)) / bu.ms + self.O2.derivative = (self.C4.value * self.diro2_t(Ca) - self.O2.value * self.invo2_t(Ca)) / bu.ms + + dirc2_t_ca = lambda self, Ca: self.dirc2_t * (Ca.C / bu.mM) / self.diff + dirc3_t_ca = lambda self, Ca: self.dirc3_t * (Ca.C / bu.mM) / self.diff + dirc4_t_ca = lambda self, Ca: self.dirc4_t * (Ca.C / bu.mM) / self.diff + + invc1_t = lambda self, Ca: self.invc1 * self.phi + invc2_t = lambda self, Ca: self.invc2 * self.phi + invc3_t = lambda self, Ca: self.invc3 * self.phi + invo1_t = lambda self, Ca: self.invo1 * self.phi + invo2_t = lambda self, Ca: self.invo2 * self.phi + diro1_t = lambda self, Ca: self.diro1 * self.phi + diro2_t = lambda self, Ca: self.diro2 * self.phi + dirc2_t = lambda self, Ca: self.dirc2 * self.phi + dirc3_t = lambda self, Ca: self.dirc3 * self.phi + dirc4_t = lambda self, Ca: self.dirc4 * self.phi class IKca1_1_Ma2020(KCaChannel): - r''' - TITLE Large conductance Ca2+ activated K+ channel mslo - - COMMENT - - Parameters from Cox et al. (1987) J Gen Physiol 110:257-81 (patch 1). - - Current Model Reference: Anwar H, Hong S, De Schutter E (2010) Controlling Ca2+-activated K+ channels with models of Ca2+ buffering in Purkinje cell. Cerebellum* - - *Article available as Open Access - - PubMed link: http://www.ncbi.nlm.nih.gov/pubmed/20981513 - - - Written by Sungho Hong, Okinawa Institute of Science and Technology, March 2009. - Contact: Sungho Hong (shhong@oist.jp) - ''' - __module__ = 'dendritex.channels' - - root_type = bst.mixin.JointTypes[Potassium, Calcium] - - def __init__( - self, - size: bst.typing.Size, - g_max: Union[bst.typing.ArrayLike, Callable] = 10. * (bu.mS / bu.cm ** 2), - T_base: bst.typing.ArrayLike = 3., - T: bst.typing.ArrayLike = 22., - name: Optional[str] = None, - mode: Optional[bst.mixin.Mode] = None, - ): - super().__init__( - size=size, - name=name, - mode=mode - ) - - # parameters - self.g_max = bst.init.param(g_max, self.varshape, allow_none=False) - self.T = bst.init.param(T, self.varshape, allow_none=False) - self.T_base = bst.init.param(T_base, self.varshape, allow_none=False) - self.phi = bst.init.param(T_base ** ((T - 23) / 10), self.varshape, allow_none=False) - - self.Qo = 0.73 - self.Qc = -0.67 - self.k1 = 1.0e3 - self.onoffrate = 1. - self.L0 = 1806 - self.Kc = 11.0e-3 - self.Ko = 1.1e-3 - - self.pf0 = 2.39e-3 - self.pf1 = 7.0e-3 - self.pf2 = 40e-3 - self.pf3 = 295e-3 - self.pf4 = 557e-3 - - self.pb0 = 3936e-3 - self.pb1 = 1152e-3 - self.pb2 = 659e-3 - self.pb3 = 486e-3 - self.pb4 = 92e-3 - - def init_state(self, V, K: IonInfo, Ca: IonInfo, batch_size=None): - - for i in range(5): - setattr(self, f'C{i}', State4Integral(bst.init.param(bu.math.ones, self.varshape, batch_size))) - - for i in range(5): - setattr(self, f'O{i}', State4Integral(bst.init.param(bu.math.ones, self.varshape, batch_size))) - - self.normalize_states([getattr(self, f'C{i}') for i in range(5)] + [getattr(self, f'O{i}') for i in range(5)]) - - def reset_state(self, V, K: IonInfo, Ca: IonInfo, batch_size=None): - self.normalize_states([getattr(self, f'C{i}') for i in range(5)] + [getattr(self, f'O{i}') for i in range(5)]) - - def current(self, V, K: IonInfo, Ca: IonInfo): - return self.g_max * (self.O1.value + self.O2.value) * (K.E - V) - - def before_integral(self, V, K: IonInfo, Ca: IonInfo): - self.normalize_states([getattr(self, f'C{i}') for i in range(5)] + [getattr(self, f'O{i}') for i in range(5)]) - - def normalize_states(self, states): - total = 0. - for state in states: - state.value = bu.math.maximum(state.value, 0) - total = total + state.value - for state in states: - state.value = state.value / total - - def compute_derivative(self, V, K: IonInfo, Ca: IonInfo): - - self.C0.derivative = (self.C1 * self.c10(Ca) + self.O0 * self.b0(V) - self.C0 * (self.c01(Ca) + self.f0(V))) / bu.ms - self.C1.derivative = (self.C0 * self.c01(Ca) + self.C2 * self.c21(Ca) + self.O1 * self.b1(V) - self.C1 * ( - self.c10(Ca) + self.c12(Ca) + self.f1(V))) / bu.ms - self.C2.derivative = (self.C1 * self.c12(Ca) + self.C3 * self.c32(Ca) + self.O2 * self.b2(V) - self.C2 * ( - self.c21(Ca) + self.c23(Ca) + self.f2(V))) / bu.ms - self.C3.derivative = (self.C2 * self.c23(Ca) + self.C4 * self.c43(Ca) + self.O3 * self.b3(V) - self.C3 * ( - self.c32(Ca) + self.c34(Ca) + self.f3(V))) / bu.ms - self.C4.derivative = (self.C3 * self.c34(Ca) + self.O4 * self.b4(V) - self.C4 * (self.c43(Ca) + self.f4(V))) / bu.ms - - self.O0.derivative = (self.O1 * self.o10(Ca) + self.C0 * self.f0(V) - self.O0 * (self.o01(Ca) + self.b0(V))) / bu.ms - self.O1.derivative = (self.O0 * self.o01(Ca) + self.O2 * self.o21(Ca) + self.C1 * self.f1(V) - self.O1 * ( - self.o10(Ca) + self.o12(Ca) + self.b1(V))) / bu.ms - self.O2.derivative = (self.O1 * self.o12(Ca) + self.O3 * self.o32(Ca) + self.C2 * self.f2(V) - self.O2 * ( - self.o21(Ca) + self.o23(Ca) + self.b2(V))) / bu.ms - self.O3.derivative = (self.O2 * self.o23(Ca) + self.O4 * self.o43(Ca) + self.C3 * self.f3(V) - self.O3 * ( - self.o32(Ca) + self.o34(Ca) + self.b3(V))) / bu.ms - self.O4.derivative = (self.O3 * self.o34(Ca) + self.C4 * self.f4(V) - self.O4 * (self.o43(Ca) + self.b4(V))) / bu.ms - - def current(self, V, K: IonInfo, Ca: IonInfo): - return self.g_max * (self.O0.value + self.O1.value + self.O2.value + self.O3.value + self.O4.value) * (K.E - V) - - c01 = lambda self, Ca: 4 * (Ca.C / bu.mM) * self.k1 * self.onoffrate * self.phi - c12 = lambda self, Ca: 3 * (Ca.C / bu.mM) * self.k1 * self.onoffrate * self.phi - c23 = lambda self, Ca: 2 * (Ca.C / bu.mM) * self.k1 * self.onoffrate * self.phi - c34 = lambda self, Ca: 1 * (Ca.C / bu.mM) * self.k1 * self.onoffrate * self.phi - - o01 = lambda self, Ca: 4 * (Ca.C / bu.mM) * self.k1 * self.onoffrate * self.phi - o12 = lambda self, Ca: 3 * (Ca.C / bu.mM) * self.k1 * self.onoffrate * self.phi - o23 = lambda self, Ca: 2 * (Ca.C / bu.mM) * self.k1 * self.onoffrate * self.phi - o34 = lambda self, Ca: 1 * (Ca.C / bu.mM) * self.k1 * self.onoffrate * self.phi - - c10 = lambda self, Ca: 1 * self.Kc * self.k1 * self.onoffrate * self.phi - c21 = lambda self, Ca: 2 * self.Kc * self.k1 * self.onoffrate * self.phi - c32 = lambda self, Ca: 3 * self.Kc * self.k1 * self.onoffrate * self.phi - c43 = lambda self, Ca: 4 * self.Kc * self.k1 * self.onoffrate * self.phi - - o10 = lambda self, Ca: 1 * self.Ko * self.k1 * self.onoffrate * self.phi - o21 = lambda self, Ca: 2 * self.Ko * self.k1 * self.onoffrate * self.phi - o32 = lambda self, Ca: 3 * self.Ko * self.k1 * self.onoffrate * self.phi - o43 = lambda self, Ca: 4 * self.Ko * self.k1 * self.onoffrate * self.phi - - alpha = lambda self, V: bu.math.exp( - (self.Qo * bu.faraday_constant * V) / (bu.gas_constant * (273.15 + self.T) * bu.kelvin)) - beta = lambda self, V: bu.math.exp( - (self.Qc * bu.faraday_constant * V) / (bu.gas_constant * (273.15 + self.T) * bu.kelvin)) - - f0 = lambda self, V: self.pf0 * self.alpha(V) * self.phi - f1 = lambda self, V: self.pf1 * self.alpha(V) * self.phi - f2 = lambda self, V: self.pf2 * self.alpha(V) * self.phi - f3 = lambda self, V: self.pf3 * self.alpha(V) * self.phi - f4 = lambda self, V: self.pf4 * self.alpha(V) * self.phi - - b0 = lambda self, V: self.pb0 * self.beta(V) * self.phi - b1 = lambda self, V: self.pb1 * self.beta(V) * self.phi - b2 = lambda self, V: self.pb2 * self.beta(V) * self.phi - b3 = lambda self, V: self.pb3 * self.beta(V) * self.phi - b4 = lambda self, V: self.pb4 * self.beta(V) * self.phi + r''' + TITLE Large conductance Ca2+ activated K+ channel mslo + + COMMENT + + Parameters from Cox et al. (1987) J Gen Physiol 110:257-81 (patch 1). + + Current Model Reference: Anwar H, Hong S, De Schutter E (2010) Controlling Ca2+-activated K+ channels with models of Ca2+ buffering in Purkinje cell. Cerebellum* + + *Article available as Open Access + + PubMed link: http://www.ncbi.nlm.nih.gov/pubmed/20981513 + + + Written by Sungho Hong, Okinawa Institute of Science and Technology, March 2009. + Contact: Sungho Hong (shhong@oist.jp) + ''' + __module__ = 'dendritex.channels' + + root_type = bst.mixin.JointTypes[Potassium, Calcium] + + def __init__( + self, + size: bst.typing.Size, + g_max: Union[bst.typing.ArrayLike, Callable] = 10. * (bu.mS / bu.cm ** 2), + T_base: bst.typing.ArrayLike = 3., + T: bst.typing.ArrayLike = 22., + name: Optional[str] = None, + ): + super().__init__(size=size, name=name, ) + + # parameters + self.g_max = bst.init.param(g_max, self.varshape, allow_none=False) + self.T = bst.init.param(T, self.varshape, allow_none=False) + self.T_base = bst.init.param(T_base, self.varshape, allow_none=False) + self.phi = bst.init.param(T_base ** ((T - 23) / 10), self.varshape, allow_none=False) + + self.Qo = 0.73 + self.Qc = -0.67 + self.k1 = 1.0e3 + self.onoffrate = 1. + self.L0 = 1806 + self.Kc = 11.0e-3 + self.Ko = 1.1e-3 + + self.pf0 = 2.39e-3 + self.pf1 = 7.0e-3 + self.pf2 = 40e-3 + self.pf3 = 295e-3 + self.pf4 = 557e-3 + + self.pb0 = 3936e-3 + self.pb1 = 1152e-3 + self.pb2 = 659e-3 + self.pb3 = 486e-3 + self.pb4 = 92e-3 + + def init_state(self, V, K: IonInfo, Ca: IonInfo, batch_size=None): + + for i in range(5): + setattr(self, f'C{i}', State4Integral(bst.init.param(bu.math.ones, self.varshape, batch_size))) + + for i in range(5): + setattr(self, f'O{i}', State4Integral(bst.init.param(bu.math.ones, self.varshape, batch_size))) + + self.normalize_states([getattr(self, f'C{i}') for i in range(5)] + [getattr(self, f'O{i}') for i in range(5)]) + + def reset_state(self, V, K: IonInfo, Ca: IonInfo, batch_size=None): + self.normalize_states([getattr(self, f'C{i}') for i in range(5)] + [getattr(self, f'O{i}') for i in range(5)]) + + def current(self, V, K: IonInfo, Ca: IonInfo): + return self.g_max * (self.O1.value + self.O2.value) * (K.E - V) + + def before_integral(self, V, K: IonInfo, Ca: IonInfo): + self.normalize_states([getattr(self, f'C{i}') for i in range(5)] + [getattr(self, f'O{i}') for i in range(5)]) + + def normalize_states(self, states): + total = 0. + for state in states: + state.value = bu.math.maximum(state.value, 0) + total = total + state.value + for state in states: + state.value = state.value / total + + def compute_derivative(self, V, K: IonInfo, Ca: IonInfo): + + self.C0.derivative = (self.C1 * self.c10(Ca) + self.O0 * self.b0(V) - self.C0 * ( + self.c01(Ca) + self.f0(V))) / bu.ms + self.C1.derivative = (self.C0 * self.c01(Ca) + self.C2 * self.c21(Ca) + self.O1 * self.b1(V) - self.C1 * ( + self.c10(Ca) + self.c12(Ca) + self.f1(V))) / bu.ms + self.C2.derivative = (self.C1 * self.c12(Ca) + self.C3 * self.c32(Ca) + self.O2 * self.b2(V) - self.C2 * ( + self.c21(Ca) + self.c23(Ca) + self.f2(V))) / bu.ms + self.C3.derivative = (self.C2 * self.c23(Ca) + self.C4 * self.c43(Ca) + self.O3 * self.b3(V) - self.C3 * ( + self.c32(Ca) + self.c34(Ca) + self.f3(V))) / bu.ms + self.C4.derivative = (self.C3 * self.c34(Ca) + self.O4 * self.b4(V) - self.C4 * ( + self.c43(Ca) + self.f4(V))) / bu.ms + + self.O0.derivative = (self.O1 * self.o10(Ca) + self.C0 * self.f0(V) - self.O0 * ( + self.o01(Ca) + self.b0(V))) / bu.ms + self.O1.derivative = (self.O0 * self.o01(Ca) + self.O2 * self.o21(Ca) + self.C1 * self.f1(V) - self.O1 * ( + self.o10(Ca) + self.o12(Ca) + self.b1(V))) / bu.ms + self.O2.derivative = (self.O1 * self.o12(Ca) + self.O3 * self.o32(Ca) + self.C2 * self.f2(V) - self.O2 * ( + self.o21(Ca) + self.o23(Ca) + self.b2(V))) / bu.ms + self.O3.derivative = (self.O2 * self.o23(Ca) + self.O4 * self.o43(Ca) + self.C3 * self.f3(V) - self.O3 * ( + self.o32(Ca) + self.o34(Ca) + self.b3(V))) / bu.ms + self.O4.derivative = (self.O3 * self.o34(Ca) + self.C4 * self.f4(V) - self.O4 * ( + self.o43(Ca) + self.b4(V))) / bu.ms + + def current(self, V, K: IonInfo, Ca: IonInfo): + return self.g_max * (self.O0.value + self.O1.value + self.O2.value + self.O3.value + self.O4.value) * (K.E - V) + + c01 = lambda self, Ca: 4 * (Ca.C / bu.mM) * self.k1 * self.onoffrate * self.phi + c12 = lambda self, Ca: 3 * (Ca.C / bu.mM) * self.k1 * self.onoffrate * self.phi + c23 = lambda self, Ca: 2 * (Ca.C / bu.mM) * self.k1 * self.onoffrate * self.phi + c34 = lambda self, Ca: 1 * (Ca.C / bu.mM) * self.k1 * self.onoffrate * self.phi + + o01 = lambda self, Ca: 4 * (Ca.C / bu.mM) * self.k1 * self.onoffrate * self.phi + o12 = lambda self, Ca: 3 * (Ca.C / bu.mM) * self.k1 * self.onoffrate * self.phi + o23 = lambda self, Ca: 2 * (Ca.C / bu.mM) * self.k1 * self.onoffrate * self.phi + o34 = lambda self, Ca: 1 * (Ca.C / bu.mM) * self.k1 * self.onoffrate * self.phi + + c10 = lambda self, Ca: 1 * self.Kc * self.k1 * self.onoffrate * self.phi + c21 = lambda self, Ca: 2 * self.Kc * self.k1 * self.onoffrate * self.phi + c32 = lambda self, Ca: 3 * self.Kc * self.k1 * self.onoffrate * self.phi + c43 = lambda self, Ca: 4 * self.Kc * self.k1 * self.onoffrate * self.phi + + o10 = lambda self, Ca: 1 * self.Ko * self.k1 * self.onoffrate * self.phi + o21 = lambda self, Ca: 2 * self.Ko * self.k1 * self.onoffrate * self.phi + o32 = lambda self, Ca: 3 * self.Ko * self.k1 * self.onoffrate * self.phi + o43 = lambda self, Ca: 4 * self.Ko * self.k1 * self.onoffrate * self.phi + + alpha = lambda self, V: bu.math.exp( + (self.Qo * bu.faraday_constant * V) / (bu.gas_constant * (273.15 + self.T) * bu.kelvin)) + beta = lambda self, V: bu.math.exp( + (self.Qc * bu.faraday_constant * V) / (bu.gas_constant * (273.15 + self.T) * bu.kelvin)) + + f0 = lambda self, V: self.pf0 * self.alpha(V) * self.phi + f1 = lambda self, V: self.pf1 * self.alpha(V) * self.phi + f2 = lambda self, V: self.pf2 * self.alpha(V) * self.phi + f3 = lambda self, V: self.pf3 * self.alpha(V) * self.phi + f4 = lambda self, V: self.pf4 * self.alpha(V) * self.phi + + b0 = lambda self, V: self.pb0 * self.beta(V) * self.phi + b1 = lambda self, V: self.pb1 * self.beta(V) * self.phi + b2 = lambda self, V: self.pb2 * self.beta(V) * self.phi + b3 = lambda self, V: self.pb3 * self.beta(V) * self.phi + b4 = lambda self, V: self.pb4 * self.beta(V) * self.phi diff --git a/dendritex/channels/sodium.py b/dendritex/channels/sodium.py index 8269f90..f4ae466 100644 --- a/dendritex/channels/sodium.py +++ b/dendritex/channels/sodium.py @@ -12,492 +12,494 @@ import brainstate as bst import brainunit as u -from .._base import Channel, IonInfo, State4Integral -from ..ions import Sodium +from dendritex._base import Channel, IonInfo, State4Integral +from dendritex.ions import Sodium __all__ = [ - 'SodiumChannel', - 'INa_Ba2002', - 'INa_TM1991', - 'INa_HH1952', - 'INa_Rsg', + 'SodiumChannel', + 'INa_Ba2002', + 'INa_TM1991', + 'INa_HH1952', + 'INa_Rsg', ] class SodiumChannel(Channel): - """Base class for sodium channel dynamics.""" - __module__ = 'dendritex.channels' + """Base class for sodium channel dynamics.""" + __module__ = 'dendritex.channels' - root_type = Sodium + root_type = Sodium - def before_integral(self, V, Na: IonInfo): - pass + def before_integral(self, V, Na: IonInfo): + pass - def after_integral(self, V, Na: IonInfo): - pass + def post_derivative(self, V, Na: IonInfo): + pass - def compute_derivative(self, V, Na: IonInfo): - pass + def compute_derivative(self, V, Na: IonInfo): + pass - def current(self, V, Na: IonInfo): - raise NotImplementedError + def current(self, V, Na: IonInfo): + raise NotImplementedError - def init_state(self, V, Na: IonInfo, batch_size: int = None): - pass + def init_state(self, V, Na: IonInfo, batch_size: int = None): + pass - def reset_state(self, V, Na: IonInfo, batch_size: int = None): - pass + def reset_state(self, V, Na: IonInfo, batch_size: int = None): + pass class INa_p3q_markov(SodiumChannel): - r""" - The sodium current model of :math:`p^3q` current which described with first-order Markov chain. - - The general model can be used to model the dynamics with: - - .. math:: - - \begin{aligned} - I_{\mathrm{Na}} &= g_{\mathrm{max}} * p^3 * q \\ - \frac{dp}{dt} &= \phi ( \alpha_p (1-p) - \beta_p p) \\ - \frac{dq}{dt} & = \phi ( \alpha_q (1-h) - \beta_q h) \\ - \end{aligned} - - where :math:`\phi` is a temperature-dependent factor. - - Parameters - ---------- - g_max : float, ArrayType, Callable, Initializer - The maximal conductance density (:math:`mS/cm^2`). - phi : float, ArrayType, Callable, Initializer - The temperature-dependent factor. - name: str - The name of the object. - - """ - - def __init__( - self, - size: bst.typing.Size, - g_max: Union[bst.typing.ArrayLike, Callable] = 90. * (u.mS / u.cm ** 2), - phi: Union[bst.typing.ArrayLike, Callable] = 1., - name: Optional[str] = None, - mode: Optional[bst.mixin.Mode] = None, - ): - super().__init__( - size=size, - name=name, - mode=mode - ) - - # parameters - self.phi = bst.init.param(phi, self.varshape, allow_none=False) - self.g_max = bst.init.param(g_max, self.varshape, allow_none=False) - - def init_state(self, V, Na: IonInfo, batch_size=None): - self.p = State4Integral(bst.init.param(u.math.zeros, self.varshape, batch_size)) - self.q = State4Integral(bst.init.param(u.math.zeros, self.varshape, batch_size)) - - def reset_state(self, V, Na: IonInfo, batch_size=None): - alpha = self.f_p_alpha(V) - beta = self.f_p_beta(V) - self.p.value = alpha / (alpha + beta) - alpha = self.f_q_alpha(V) - beta = self.f_q_beta(V) - self.q.value = alpha / (alpha + beta) - - def compute_derivative(self, V, Na: IonInfo): - p = self.p.value - q = self.q.value - self.p.derivative = self.phi * (self.f_p_alpha(V) * (1. - p) - self.f_p_beta(V) * p) / u.ms - self.q.derivative = self.phi * (self.f_q_alpha(V) * (1. - q) - self.f_q_beta(V) * q) / u.ms - - def current(self, V, Na: IonInfo): - return self.g_max * self.p.value ** 3 * self.q.value * (Na.E - V) - - def f_p_alpha(self, V): - raise NotImplementedError - - def f_p_beta(self, V): - raise NotImplementedError - - def f_q_alpha(self, V): - raise NotImplementedError - - def f_q_beta(self, V): - raise NotImplementedError + r""" + The sodium current model of :math:`p^3q` current which described with first-order Markov chain. + + The general model can be used to model the dynamics with: + + .. math:: + + \begin{aligned} + I_{\mathrm{Na}} &= g_{\mathrm{max}} * p^3 * q \\ + \frac{dp}{dt} &= \phi ( \alpha_p (1-p) - \beta_p p) \\ + \frac{dq}{dt} & = \phi ( \alpha_q (1-h) - \beta_q h) \\ + \end{aligned} + + where :math:`\phi` is a temperature-dependent factor. + + Parameters + ---------- + g_max : float, ArrayType, Callable, Initializer + The maximal conductance density (:math:`mS/cm^2`). + phi : float, ArrayType, Callable, Initializer + The temperature-dependent factor. + name: str + The name of the object. + + """ + + def __init__( + self, + size: bst.typing.Size, + g_max: Union[bst.typing.ArrayLike, Callable] = 90. * (u.mS / u.cm ** 2), + phi: Union[bst.typing.ArrayLike, Callable] = 1., + name: Optional[str] = None, + ): + super().__init__(size=size, name=name, ) + + # parameters + self.phi = bst.init.param(phi, self.varshape, allow_none=False) + self.g_max = bst.init.param(g_max, self.varshape, allow_none=False) + + def init_state(self, V, Na: IonInfo, batch_size=None): + self.p = State4Integral(bst.init.param(u.math.zeros, self.varshape, batch_size)) + self.q = State4Integral(bst.init.param(u.math.zeros, self.varshape, batch_size)) + + def reset_state(self, V, Na: IonInfo, batch_size=None): + alpha = self.f_p_alpha(V) + beta = self.f_p_beta(V) + self.p.value = alpha / (alpha + beta) + alpha = self.f_q_alpha(V) + beta = self.f_q_beta(V) + self.q.value = alpha / (alpha + beta) + + def compute_derivative(self, V, Na: IonInfo): + p = self.p.value + q = self.q.value + self.p.derivative = self.phi * (self.f_p_alpha(V) * (1. - p) - self.f_p_beta(V) * p) / u.ms + self.q.derivative = self.phi * (self.f_q_alpha(V) * (1. - q) - self.f_q_beta(V) * q) / u.ms + + def current(self, V, Na: IonInfo): + return self.g_max * self.p.value ** 3 * self.q.value * (Na.E - V) + + def f_p_alpha(self, V): + raise NotImplementedError + + def f_p_beta(self, V): + raise NotImplementedError + + def f_q_alpha(self, V): + raise NotImplementedError + + def f_q_beta(self, V): + raise NotImplementedError class INa_Ba2002(INa_p3q_markov): - r""" - The sodium current model. - - The sodium current model is adopted from (Bazhenov, et, al. 2002) [1]_. - It's dynamics is given by: - - .. math:: - - \begin{aligned} - I_{\mathrm{Na}} &= g_{\mathrm{max}} * p^3 * q \\ - \frac{dp}{dt} &= \phi ( \alpha_p (1-p) - \beta_p p) \\ - \alpha_{p} &=\frac{0.32\left(V-V_{sh}-13\right)}{1-\exp \left(-\left(V-V_{sh}-13\right) / 4\right)} \\ - \beta_{p} &=\frac{-0.28\left(V-V_{sh}-40\right)}{1-\exp \left(\left(V-V_{sh}-40\right) / 5\right)} \\ - \frac{dq}{dt} & = \phi ( \alpha_q (1-h) - \beta_q h) \\ - \alpha_q &=0.128 \exp \left(-\left(V-V_{sh}-17\right) / 18\right) \\ - \beta_q &= \frac{4}{1+\exp \left(-\left(V-V_{sh}-40\right) / 5\right)} - \end{aligned} - - where :math:`\phi` is a temperature-dependent factor, which is given by - :math:`\phi=3^{\frac{T-36}{10}}` (:math:`T` is the temperature in Celsius). - - Parameters - ---------- - g_max : float, ArrayType, Callable, Initializer - The maximal conductance density (:math:`mS/cm^2`). - T : float, ArrayType - The temperature (Celsius, :math:`^{\circ}C`). - V_sh : float, ArrayType, Callable, Initializer - The shift of the membrane potential to spike. - - References - ---------- - - .. [1] Bazhenov, Maxim, et al. "Model of thalamocortical slow-wave sleep oscillations - and transitions to activated states." Journal of neuroscience 22.19 (2002): 8691-8704. - - See Also - -------- - INa_TM1991 - """ - __module__ = 'dendritex.channels' - - def __init__( - self, - size: bst.typing.Size, - T: bst.typing.ArrayLike = 36., - g_max: Union[bst.typing.ArrayLike, Callable] = 90. * (u.mS / u.cm ** 2), - V_sh: Union[bst.typing.ArrayLike, Callable] = -50. * u.mV, - name: Optional[str] = None, - mode: Optional[bst.mixin.Mode] = None, - ): - super().__init__( - size, - name=name, - phi=3 ** ((T - 36) / 10), - g_max=g_max, - mode=mode - ) - self.T = bst.init.param(T, self.varshape, allow_none=False) - self.V_sh = bst.init.param(V_sh, self.varshape, allow_none=False) - - def f_p_alpha(self, V): - V = (V - self.V_sh).to_decimal(u.mV) - temp = V - 13. - return 0.32 * temp / (1. - u.math.exp(-temp / 4.)) - - def f_p_beta(self, V): - V = (V - self.V_sh).to_decimal(u.mV) - temp = V - 40. - return -0.28 * temp / (1. - u.math.exp(temp / 5.)) - - def f_q_alpha(self, V): - V = (V - self.V_sh).to_decimal(u.mV) - return 0.128 * u.math.exp(-(V - 17.) / 18.) - - def f_q_beta(self, V): - V = (V - self.V_sh).to_decimal(u.mV) - return 4. / (1. + u.math.exp(-(V - 40.) / 5.)) + r""" + The sodium current model. + + The sodium current model is adopted from (Bazhenov, et, al. 2002) [1]_. + It's dynamics is given by: + + .. math:: + + \begin{aligned} + I_{\mathrm{Na}} &= g_{\mathrm{max}} * p^3 * q \\ + \frac{dp}{dt} &= \phi ( \alpha_p (1-p) - \beta_p p) \\ + \alpha_{p} &=\frac{0.32\left(V-V_{sh}-13\right)}{1-\exp \left(-\left(V-V_{sh}-13\right) / 4\right)} \\ + \beta_{p} &=\frac{-0.28\left(V-V_{sh}-40\right)}{1-\exp \left(\left(V-V_{sh}-40\right) / 5\right)} \\ + \frac{dq}{dt} & = \phi ( \alpha_q (1-h) - \beta_q h) \\ + \alpha_q &=0.128 \exp \left(-\left(V-V_{sh}-17\right) / 18\right) \\ + \beta_q &= \frac{4}{1+\exp \left(-\left(V-V_{sh}-40\right) / 5\right)} + \end{aligned} + + where :math:`\phi` is a temperature-dependent factor, which is given by + :math:`\phi=3^{\frac{T-36}{10}}` (:math:`T` is the temperature in Celsius). + + Parameters + ---------- + g_max : float, ArrayType, Callable, Initializer + The maximal conductance density (:math:`mS/cm^2`). + T : float, ArrayType + The temperature (Celsius, :math:`^{\circ}C`). + V_sh : float, ArrayType, Callable, Initializer + The shift of the membrane potential to spike. + + References + ---------- + + .. [1] Bazhenov, Maxim, et al. "Model of thalamocortical slow-wave sleep oscillations + and transitions to activated states." Journal of neuroscience 22.19 (2002): 8691-8704. + + See Also + -------- + INa_TM1991 + """ + __module__ = 'dendritex.channels' + + def __init__( + self, + size: bst.typing.Size, + T: bst.typing.ArrayLike = 36., + g_max: Union[bst.typing.ArrayLike, Callable] = 90. * (u.mS / u.cm ** 2), + V_sh: Union[bst.typing.ArrayLike, Callable] = -50. * u.mV, + name: Optional[str] = None, + ): + super().__init__( + size, + name=name, + phi=3 ** ((T - 36) / 10), + g_max=g_max, + ) + self.T = bst.init.param(T, self.varshape, allow_none=False) + self.V_sh = bst.init.param(V_sh, self.varshape, allow_none=False) + + def f_p_alpha(self, V): + V = (V - self.V_sh).to_decimal(u.mV) + temp = V - 13. + return 0.32 * temp / (1. - u.math.exp(-temp / 4.)) + + def f_p_beta(self, V): + V = (V - self.V_sh).to_decimal(u.mV) + temp = V - 40. + return -0.28 * temp / (1. - u.math.exp(temp / 5.)) + + def f_q_alpha(self, V): + V = (V - self.V_sh).to_decimal(u.mV) + return 0.128 * u.math.exp(-(V - 17.) / 18.) + + def f_q_beta(self, V): + V = (V - self.V_sh).to_decimal(u.mV) + return 4. / (1. + u.math.exp(-(V - 40.) / 5.)) class INa_TM1991(INa_p3q_markov): - r""" - The sodium current model described by (Traub and Miles, 1991) [1]_. - - The dynamics of this sodium current model is given by: - - .. math:: - - \begin{split} - \begin{aligned} - I_{\mathrm{Na}} &= g_{\mathrm{max}} m^3 h \\ - \frac {dm} {dt} &= \phi(\alpha_m (1-x) - \beta_m) \\ - &\alpha_m(V) = 0.32 \frac{(13 - V + V_{sh})}{\exp((13 - V +V_{sh}) / 4) - 1.} \\ - &\beta_m(V) = 0.28 \frac{(V - V_{sh} - 40)}{(\exp((V - V_{sh} - 40) / 5) - 1)} \\ - \frac {dh} {dt} &= \phi(\alpha_h (1-x) - \beta_h) \\ - &\alpha_h(V) = 0.128 * \exp((17 - V + V_{sh}) / 18) \\ - &\beta_h(V) = 4. / (1 + \exp(-(V - V_{sh} - 40) / 5)) \\ - \end{aligned} - \end{split} - - where :math:`V_{sh}` is the membrane shift (default -63 mV), and - :math:`\phi` is the temperature-dependent factor (default 1.). - - Parameters - ---------- - size: int, tuple of int - The size of the simulation target. - name: str - The name of the object. - g_max : float, ArrayType, Callable, Initializer - The maximal conductance density (:math:`mS/cm^2`). - V_sh: float, ArrayType, Callable, Initializer - The membrane shift. - - References - ---------- - .. [1] Traub, Roger D., and Richard Miles. Neuronal networks of the hippocampus. - Vol. 777. Cambridge University Press, 1991. - - See Also - -------- - INa_Ba2002 - """ - __module__ = 'dendritex.channels' - - def __init__( - self, - size: bst.typing.Size, - g_max: Union[bst.typing.ArrayLike, Callable] = 120. * (u.mS / u.cm ** 2), - phi: Union[bst.typing.ArrayLike, Callable] = 1., - V_sh: Union[bst.typing.ArrayLike, Callable] = -63. * u.mV, - name: Optional[str] = None, - mode: Optional[bst.mixin.Mode] = None, - ): - super().__init__( - size, - name=name, - phi=phi, - g_max=g_max, - mode=mode - ) - self.V_sh = bst.init.param(V_sh, self.varshape, allow_none=False) - - def f_p_alpha(self, V): - V = (self.V_sh - V).to_decimal(u.mV) - temp = 13 + V - return 0.32 * temp / (u.math.exp(temp / 4) - 1.) - - def f_p_beta(self, V): - V = (V - self.V_sh).to_decimal(u.mV) - temp = V - 40 - return 0.28 * temp / (u.math.exp(temp / 5) - 1) - - def f_q_alpha(self, V): - V = (- V + self.V_sh).to_decimal(u.mV) - return 0.128 * u.math.exp((17 + V) / 18) - - def f_q_beta(self, V): - V = (V - self.V_sh).to_decimal(u.mV) - return 4. / (1 + u.math.exp(-(V - 40) / 5)) + r""" + The sodium current model described by (Traub and Miles, 1991) [1]_. + + The dynamics of this sodium current model is given by: + + .. math:: + + \begin{split} + \begin{aligned} + I_{\mathrm{Na}} &= g_{\mathrm{max}} m^3 h \\ + \frac {dm} {dt} &= \phi(\alpha_m (1-x) - \beta_m) \\ + &\alpha_m(V) = 0.32 \frac{(13 - V + V_{sh})}{\exp((13 - V +V_{sh}) / 4) - 1.} \\ + &\beta_m(V) = 0.28 \frac{(V - V_{sh} - 40)}{(\exp((V - V_{sh} - 40) / 5) - 1)} \\ + \frac {dh} {dt} &= \phi(\alpha_h (1-x) - \beta_h) \\ + &\alpha_h(V) = 0.128 * \exp((17 - V + V_{sh}) / 18) \\ + &\beta_h(V) = 4. / (1 + \exp(-(V - V_{sh} - 40) / 5)) \\ + \end{aligned} + \end{split} + + where :math:`V_{sh}` is the membrane shift (default -63 mV), and + :math:`\phi` is the temperature-dependent factor (default 1.). + + Parameters + ---------- + size: int, tuple of int + The size of the simulation target. + name: str + The name of the object. + g_max : float, ArrayType, Callable, Initializer + The maximal conductance density (:math:`mS/cm^2`). + V_sh: float, ArrayType, Callable, Initializer + The membrane shift. + + References + ---------- + .. [1] Traub, Roger D., and Richard Miles. Neuronal networks of the hippocampus. + Vol. 777. Cambridge University Press, 1991. + + See Also + -------- + INa_Ba2002 + """ + __module__ = 'dendritex.channels' + + def __init__( + self, + size: bst.typing.Size, + g_max: Union[bst.typing.ArrayLike, Callable] = 120. * (u.mS / u.cm ** 2), + phi: Union[bst.typing.ArrayLike, Callable] = 1., + V_sh: Union[bst.typing.ArrayLike, Callable] = -63. * u.mV, + name: Optional[str] = None, + ): + super().__init__( + size, + name=name, + phi=phi, + g_max=g_max, + ) + self.V_sh = bst.init.param(V_sh, self.varshape, allow_none=False) + + def f_p_alpha(self, V): + V = (self.V_sh - V).to_decimal(u.mV) + temp = 13 + V + return 0.32 * temp / (u.math.exp(temp / 4) - 1.) + + def f_p_beta(self, V): + V = (V - self.V_sh).to_decimal(u.mV) + temp = V - 40 + return 0.28 * temp / (u.math.exp(temp / 5) - 1) + + def f_q_alpha(self, V): + V = (- V + self.V_sh).to_decimal(u.mV) + return 0.128 * u.math.exp((17 + V) / 18) + + def f_q_beta(self, V): + V = (V - self.V_sh).to_decimal(u.mV) + return 4. / (1 + u.math.exp(-(V - 40) / 5)) class INa_HH1952(INa_p3q_markov): - r""" - The sodium current model described by Hodgkin–Huxley model [1]_. - - The dynamics of this sodium current model is given by: - - .. math:: - - \begin{split} - \begin{aligned} - I_{\mathrm{Na}} &= g_{\mathrm{max}} m^3 h \\ - \frac {dm} {dt} &= \phi (\alpha_m (1-x) - \beta_m) \\ - &\alpha_m(V) = \frac {0.1(V-V_{sh}-5)}{1-\exp(\frac{-(V -V_{sh} -5)} {10})} \\ - &\beta_m(V) = 4.0 \exp(\frac{-(V -V_{sh}+ 20)} {18}) \\ - \frac {dh} {dt} &= \phi (\alpha_h (1-x) - \beta_h) \\ - &\alpha_h(V) = 0.07 \exp(\frac{-(V-V_{sh}+20)}{20}) \\ - &\beta_h(V) = \frac 1 {1 + \exp(\frac{-(V -V_{sh}-10)} {10})} \\ - \end{aligned} - \end{split} - - where :math:`V_{sh}` is the membrane shift (default -45 mV), and - :math:`\phi` is the temperature-dependent factor (default 1.). - - Parameters - ---------- - size: int, tuple of int - The size of the simulation target. - name: str - The name of the object. - g_max : float, ArrayType, Callable, Initializer - The maximal conductance density (:math:`mS/cm^2`). - V_sh: float, ArrayType, Callable, Initializer - The membrane shift. - - References - ---------- - .. [1] Hodgkin, Alan L., and Andrew F. Huxley. "A quantitative description of - membrane current and its application to conduction and excitation in - nerve." The Journal of physiology 117.4 (1952): 500. - - See Also - -------- - IK_HH1952 - """ - __module__ = 'dendritex.channels' - - def __init__( - self, - size: bst.typing.Size, - g_max: Union[bst.typing.ArrayLike, Callable] = 120. * (u.mS / u.cm ** 2), - phi: Union[bst.typing.ArrayLike, Callable] = 1., - V_sh: Union[bst.typing.ArrayLike, Callable] = -45. * u.mV, - name: Optional[str] = None, - mode: Optional[bst.mixin.Mode] = None, - ): - super().__init__( - size, - name=name, - phi=phi, - g_max=g_max, - mode=mode - ) - self.V_sh = bst.init.param(V_sh, self.varshape, allow_none=False) - - def f_p_alpha(self, V): - temp = (V - self.V_sh).to_decimal(u.mV) - 5 - return 0.1 * temp / (1 - u.math.exp(-temp / 10)) - - def f_p_beta(self, V): - V = (V - self.V_sh).to_decimal(u.mV) - return 4.0 * u.math.exp(-(V + 20) / 18) - - def f_q_alpha(self, V): - V = (V - self.V_sh).to_decimal(u.mV) - return 0.07 * u.math.exp(-(V + 20) / 20.) - - def f_q_beta(self, V): - V = (V - self.V_sh).to_decimal(u.mV) - return 1 / (1 + u.math.exp(-(V - 10) / 10)) + r""" + The sodium current model described by Hodgkin–Huxley model [1]_. + + The dynamics of this sodium current model is given by: + + .. math:: + + \begin{split} + \begin{aligned} + I_{\mathrm{Na}} &= g_{\mathrm{max}} m^3 h \\ + \frac {dm} {dt} &= \phi (\alpha_m (1-x) - \beta_m) \\ + &\alpha_m(V) = \frac {0.1(V-V_{sh}-5)}{1-\exp(\frac{-(V -V_{sh} -5)} {10})} \\ + &\beta_m(V) = 4.0 \exp(\frac{-(V -V_{sh}+ 20)} {18}) \\ + \frac {dh} {dt} &= \phi (\alpha_h (1-x) - \beta_h) \\ + &\alpha_h(V) = 0.07 \exp(\frac{-(V-V_{sh}+20)}{20}) \\ + &\beta_h(V) = \frac 1 {1 + \exp(\frac{-(V -V_{sh}-10)} {10})} \\ + \end{aligned} + \end{split} + + where :math:`V_{sh}` is the membrane shift (default -45 mV), and + :math:`\phi` is the temperature-dependent factor (default 1.). + + Parameters + ---------- + size: int, tuple of int + The size of the simulation target. + name: str + The name of the object. + g_max : float, ArrayType, Callable, Initializer + The maximal conductance density (:math:`mS/cm^2`). + V_sh: float, ArrayType, Callable, Initializer + The membrane shift. + + References + ---------- + .. [1] Hodgkin, Alan L., and Andrew F. Huxley. "A quantitative description of + membrane current and its application to conduction and excitation in + nerve." The Journal of physiology 117.4 (1952): 500. + + See Also + -------- + IK_HH1952 + """ + __module__ = 'dendritex.channels' + + def __init__( + self, + size: bst.typing.Size, + g_max: Union[bst.typing.ArrayLike, Callable] = 120. * (u.mS / u.cm ** 2), + phi: Union[bst.typing.ArrayLike, Callable] = 1., + V_sh: Union[bst.typing.ArrayLike, Callable] = -45. * u.mV, + name: Optional[str] = None, + ): + super().__init__( + size, + name=name, + phi=phi, + g_max=g_max, + ) + self.V_sh = bst.init.param(V_sh, self.varshape, allow_none=False) + + def f_p_alpha(self, V): + temp = (V - self.V_sh).to_decimal(u.mV) - 5 + return 0.1 * temp / (1 - u.math.exp(-temp / 10)) + + def f_p_beta(self, V): + V = (V - self.V_sh).to_decimal(u.mV) + return 4.0 * u.math.exp(-(V + 20) / 18) + + def f_q_alpha(self, V): + V = (V - self.V_sh).to_decimal(u.mV) + return 0.07 * u.math.exp(-(V + 20) / 20.) + + def f_q_beta(self, V): + V = (V - self.V_sh).to_decimal(u.mV) + return 1 / (1 + u.math.exp(-(V - 10) / 10)) class INa_Rsg(SodiumChannel): - - __module__ = 'dendritex.channels' - - def __init__( - self, - size: bst.typing.Size, - T: bst.typing.ArrayLike = 22., - g_max: Union[bst.typing.ArrayLike, Callable] = 15. * (u.mS / u.cm ** 2), - name: Optional[str] = None, - mode: Optional[bst.mixin.Mode] = None, - ): - super().__init__( - size=size, - name=name, - mode=mode - ) - - self.phi = bst.init.param(2.7 ** ((T - 22) / 10), self.varshape, allow_none=False) - self.g_max = bst.init.param(g_max, self.varshape, allow_none=False) - - self.Con = 0.005 - self.Coff = 0.5 - self.Oon = 0.75 - self.Ooff = 0.005 - self.alpha = 150. - self.beta = 3. - self.gamma = 150. - self.delta = 40. - self.epsilon = 1.75 - self.zeta = 0.03 - - self.x1 = 20. - self.x2 = -20. - self.x3 = 1e12 - self.x4 = -1e12 - self.x5 = 1e12 - self.x6 = -25. - self.vshifta = 0. - self.vshifti = 0. - self.vshiftk = 0. - - self.alfac = (self.Oon / self.Con) ** (1 / 4) - self.btfac = (self.Ooff / self.Coff) ** (1 / 4) - - def init_state(self, V, Na: IonInfo, batch_size=None): - - self.C1 = State4Integral(bst.init.param(u.math.ones, self.varshape, batch_size)) - self.C2 = State4Integral(bst.init.param(u.math.ones, self.varshape, batch_size)) - self.C3 = State4Integral(bst.init.param(u.math.ones, self.varshape, batch_size)) - self.C4 = State4Integral(bst.init.param(u.math.ones, self.varshape, batch_size)) - self.C5 = State4Integral(bst.init.param(u.math.ones, self.varshape, batch_size)) - self.I1 = State4Integral(bst.init.param(u.math.ones, self.varshape, batch_size)) - self.I2 = State4Integral(bst.init.param(u.math.ones, self.varshape, batch_size)) - self.I3 = State4Integral(bst.init.param(u.math.ones, self.varshape, batch_size)) - self.I4 = State4Integral(bst.init.param(u.math.ones, self.varshape, batch_size)) - self.I5 = State4Integral(bst.init.param(u.math.ones, self.varshape, batch_size)) - self.O = State4Integral(bst.init.param(u.math.zeros, self.varshape, batch_size)) - self.B = State4Integral(bst.init.param(u.math.zeros, self.varshape, batch_size)) - self.I6 = State4Integral(bst.init.param(u.math.ones, self.varshape, batch_size)) - self.normalize_states([self.C1, self.C2, self.C3, self.C4, self.C5, self.I1, self.I2, self.I3, self.I4, self.I5, self.O, self.B, self.I6]) - def normalize_states(self, states): - total = 0. - for state in states: - state.value = u.math.maximum(state.value, 0) - total = total + state.value - for state in states: - state.value = state.value/total - - def before_integral(self, V, Na: IonInfo): - self.normalize_states([self.C1, self.C2, self.C3, self.C4, self.C5, self.I1, self.I2, self.I3, self.I4, self.I5, self.O, self.B, self.I6]) - - def compute_derivative(self, V, Na: IonInfo): - - self.C1.derivative = (self.I1.value * self.bi1(V) + self.C2.value * self.b01(V) - self.C1.value * (self.fi1(V) + self.f01(V))) / u.ms - self.C2.derivative = (self.C1.value * self.f01(V) + self.I2.value * self.bi2(V) + self.C3.value * self.b02(V) - self.C2.value * (self.b01(V) + self.fi2(V) + self.f02(V))) / u.ms - self.C3.derivative = (self.C2.value * self.f02(V) + self.I3.value * self.bi3(V) + self.C4.value * self.b03(V) - self.C3.value * (self.b02(V) + self.fi3(V) + self.f03(V))) / u.ms - self.C4.derivative = (self.C3.value * self.f03(V) + self.I4.value * self.bi4(V) + self.C5.value * self.b04(V) - self.C4.value * (self.b03(V) + self.fi4(V) + self.f04(V))) / u.ms - self.C5.derivative = (self.C4.value * self.f04(V) + self.I5.value * self.bi5(V) + self.O.value * self.b0O(V) - self.C5.value * (self.b04(V) + self.fi5(V) + self.f0O(V))) / u.ms - self.O.derivative = (self.C5.value * self.f0O(V) + self.B.value * self.bip(V) + self.I6.value * self.bin(V) - self.O.value * (self.b0O(V) + self.fip(V) + self.fin(V))) / u.ms - self.B.derivative = (self.O.value * self.fip(V) - self.B.value * self.bip(V)) / u.ms - self.I1.derivative = (self.C1.value * self.fi1(V) + self.I2.value * self.b11(V) - self.I1.value * (self.bi1(V) + self.f11(V))) / u.ms - self.I2.derivative = (self.I1.value * self.f11(V) + self.C2.value * self.fi2(V) + self.I3.value * self.b12(V) - self.I2.value * (self.b11(V) + self.bi2(V) + self.f12(V))) / u.ms - self.I3.derivative = (self.I2.value * self.f12(V) + self.C3.value * self.fi3(V) + self.I4.value * self.b13(V) - self.I3.value * (self.b12(V) + self.bi3(V) + self.f13(V))) / u.ms - self.I4.derivative = (self.I3.value * self.f13(V) + self.C4.value * self.fi4(V) + self.I5.value * self.b14(V) - self.I4.value * (self.b13(V) + self.bi4(V) + self.f14(V))) / u.ms - self.I5.derivative = (self.I4.value * self.f14(V) + self.C5.value * self.fi5(V) + self.I6.value * self.b1n(V) - self.I5.value * (self.b14(V) + self.bi5(V) + self.f1n(V))) / u.ms - self.I6.derivative = (self.I5.value * self.f1n(V) + self.O.value * self.fin(V) - self.I6.value * (self.b1n(V) + self.bin(V))) / u.ms - - def reset_state(self, V, Na: IonInfo, batch_size=None): - self.normalize_states([self.C1, self.C2, self.C3, self.C4, self.C5, self.I1, self.I2, self.I3, self.I4, self.I5, self.O, self.B, self.I6]) - - def current(self, V, Na: IonInfo): - return self.g_max * self.O.value * (Na.E - V) - - f01 = lambda self, V: 4 * self.alpha * u.math.exp((V / u.mV) / self.x1) * self.phi - f02 = lambda self, V: 3 * self.alpha * u.math.exp((V / u.mV) / self.x1) * self.phi - f03 = lambda self, V: 2 * self.alpha * u.math.exp((V / u.mV) / self.x1) * self.phi - f04 = lambda self, V: 1 * self.alpha * u.math.exp((V / u.mV) / self.x1) * self.phi - f0O = lambda self, V: self.gamma * self.phi - fip = lambda self, V: self.epsilon * self.phi - f11 = lambda self, V: 4 * self.alpha * self.alfac * u.math.exp((V / u.mV + self.vshifti) / self.x1) * self.phi - f12 = lambda self, V: 3 * self.alpha * self.alfac * u.math.exp((V / u.mV + self.vshifti) / self.x1) * self.phi - f13 = lambda self, V: 2 * self.alpha * self.alfac * u.math.exp((V / u.mV + self.vshifti) / self.x1) * self.phi - f14 = lambda self, V: 1 * self.alpha * self.alfac * u.math.exp((V / u.mV + self.vshifti) / self.x1) * self.phi - f1n = lambda self, V: self.gamma * self.phi - fi1 = lambda self, V: self.Con * self.phi - fi2 = lambda self, V: self.Con * self.alfac * self.phi - fi3 = lambda self, V: self.Con * self.alfac ** 2 * self.phi - fi4 = lambda self, V: self.Con * self.alfac ** 3 * self.phi - fi5 = lambda self, V: self.Con * self.alfac ** 4 * self.phi - fin = lambda self, V: self.Oon * self.phi - - b01 = lambda self, V: 1 * self.beta * u.math.exp((V / u.mV + self.vshifta) / (self.x2 + self.vshiftk)) * self.phi - b02 = lambda self, V: 2 * self.beta * u.math.exp((V / u.mV + self.vshifta) / (self.x2 + self.vshiftk)) * self.phi - b03 = lambda self, V: 3 * self.beta * u.math.exp((V / u.mV + self.vshifta) / (self.x2 + self.vshiftk)) * self.phi - b04 = lambda self, V: 4 * self.beta * u.math.exp((V / u.mV + self.vshifta) / (self.x2 + self.vshiftk)) * self.phi - b0O = lambda self, V: self.delta * self.phi - bip = lambda self, V: self.zeta * u.math.exp(V / u.mV / self.x6) * self.phi - b11 = lambda self, V: 1 * self.beta * self.btfac * u.math.exp((V / u.mV + self.vshifti) / self.x2) * self.phi - b12 = lambda self, V: 2 * self.beta * self.btfac * u.math.exp((V / u.mV + self.vshifti) / self.x2) * self.phi - b13 = lambda self, V: 3 * self.beta * self.btfac * u.math.exp((V / u.mV + self.vshifti) / self.x2) * self.phi - b14 = lambda self, V: 4 * self.beta * self.btfac * u.math.exp((V / u.mV + self.vshifti) / self.x2) * self.phi - b1n = lambda self, V: self.delta * self.phi - bi1 = lambda self, V: self.Coff * self.phi - bi2 = lambda self, V: self.Coff * self.btfac * self.phi - bi3 = lambda self, V: self.Coff * self.btfac ** 2 * self.phi - bi4 = lambda self, V: self.Coff * self.btfac ** 3 * self.phi - bi5 = lambda self, V: self.Coff * self.btfac ** 4 * self.phi - bin = lambda self, V: self.Ooff * self.phi \ No newline at end of file + __module__ = 'dendritex.channels' + + def __init__( + self, + size: bst.typing.Size, + T: bst.typing.ArrayLike = 22., + g_max: Union[bst.typing.ArrayLike, Callable] = 15. * (u.mS / u.cm ** 2), + name: Optional[str] = None, + ): + super().__init__(size=size, name=name, ) + + self.phi = bst.init.param(2.7 ** ((T - 22) / 10), self.varshape, allow_none=False) + self.g_max = bst.init.param(g_max, self.varshape, allow_none=False) + + self.Con = 0.005 + self.Coff = 0.5 + self.Oon = 0.75 + self.Ooff = 0.005 + self.alpha = 150. + self.beta = 3. + self.gamma = 150. + self.delta = 40. + self.epsilon = 1.75 + self.zeta = 0.03 + + self.x1 = 20. + self.x2 = -20. + self.x3 = 1e12 + self.x4 = -1e12 + self.x5 = 1e12 + self.x6 = -25. + self.vshifta = 0. + self.vshifti = 0. + self.vshiftk = 0. + + self.alfac = (self.Oon / self.Con) ** (1 / 4) + self.btfac = (self.Ooff / self.Coff) ** (1 / 4) + + def init_state(self, V, Na: IonInfo, batch_size=None): + + self.C1 = State4Integral(bst.init.param(u.math.ones, self.varshape, batch_size)) + self.C2 = State4Integral(bst.init.param(u.math.ones, self.varshape, batch_size)) + self.C3 = State4Integral(bst.init.param(u.math.ones, self.varshape, batch_size)) + self.C4 = State4Integral(bst.init.param(u.math.ones, self.varshape, batch_size)) + self.C5 = State4Integral(bst.init.param(u.math.ones, self.varshape, batch_size)) + self.I1 = State4Integral(bst.init.param(u.math.ones, self.varshape, batch_size)) + self.I2 = State4Integral(bst.init.param(u.math.ones, self.varshape, batch_size)) + self.I3 = State4Integral(bst.init.param(u.math.ones, self.varshape, batch_size)) + self.I4 = State4Integral(bst.init.param(u.math.ones, self.varshape, batch_size)) + self.I5 = State4Integral(bst.init.param(u.math.ones, self.varshape, batch_size)) + self.O = State4Integral(bst.init.param(u.math.zeros, self.varshape, batch_size)) + self.B = State4Integral(bst.init.param(u.math.zeros, self.varshape, batch_size)) + self.I6 = State4Integral(bst.init.param(u.math.ones, self.varshape, batch_size)) + self.normalize_states( + [self.C1, self.C2, self.C3, self.C4, self.C5, self.I1, self.I2, self.I3, self.I4, self.I5, self.O, self.B, + self.I6]) + + def normalize_states(self, states): + total = 0. + for state in states: + state.value = u.math.maximum(state.value, 0) + total = total + state.value + for state in states: + state.value = state.value / total + + def before_integral(self, V, Na: IonInfo): + self.normalize_states( + [self.C1, self.C2, self.C3, self.C4, self.C5, self.I1, self.I2, self.I3, self.I4, self.I5, self.O, self.B, + self.I6]) + + def compute_derivative(self, V, Na: IonInfo): + + self.C1.derivative = (self.I1.value * self.bi1(V) + self.C2.value * self.b01(V) - self.C1.value * ( + self.fi1(V) + self.f01(V))) / u.ms + self.C2.derivative = (self.C1.value * self.f01(V) + self.I2.value * self.bi2(V) + self.C3.value * self.b02( + V) - self.C2.value * (self.b01(V) + self.fi2(V) + self.f02(V))) / u.ms + self.C3.derivative = (self.C2.value * self.f02(V) + self.I3.value * self.bi3(V) + self.C4.value * self.b03( + V) - self.C3.value * (self.b02(V) + self.fi3(V) + self.f03(V))) / u.ms + self.C4.derivative = (self.C3.value * self.f03(V) + self.I4.value * self.bi4(V) + self.C5.value * self.b04( + V) - self.C4.value * (self.b03(V) + self.fi4(V) + self.f04(V))) / u.ms + self.C5.derivative = (self.C4.value * self.f04(V) + self.I5.value * self.bi5(V) + self.O.value * self.b0O( + V) - self.C5.value * (self.b04(V) + self.fi5(V) + self.f0O(V))) / u.ms + self.O.derivative = (self.C5.value * self.f0O(V) + self.B.value * self.bip(V) + self.I6.value * self.bin( + V) - self.O.value * (self.b0O(V) + self.fip(V) + self.fin(V))) / u.ms + self.B.derivative = (self.O.value * self.fip(V) - self.B.value * self.bip(V)) / u.ms + self.I1.derivative = (self.C1.value * self.fi1(V) + self.I2.value * self.b11(V) - self.I1.value * ( + self.bi1(V) + self.f11(V))) / u.ms + self.I2.derivative = (self.I1.value * self.f11(V) + self.C2.value * self.fi2(V) + self.I3.value * self.b12( + V) - self.I2.value * (self.b11(V) + self.bi2(V) + self.f12(V))) / u.ms + self.I3.derivative = (self.I2.value * self.f12(V) + self.C3.value * self.fi3(V) + self.I4.value * self.b13( + V) - self.I3.value * (self.b12(V) + self.bi3(V) + self.f13(V))) / u.ms + self.I4.derivative = (self.I3.value * self.f13(V) + self.C4.value * self.fi4(V) + self.I5.value * self.b14( + V) - self.I4.value * (self.b13(V) + self.bi4(V) + self.f14(V))) / u.ms + self.I5.derivative = (self.I4.value * self.f14(V) + self.C5.value * self.fi5(V) + self.I6.value * self.b1n( + V) - self.I5.value * (self.b14(V) + self.bi5(V) + self.f1n(V))) / u.ms + self.I6.derivative = (self.I5.value * self.f1n(V) + self.O.value * self.fin(V) - self.I6.value * ( + self.b1n(V) + self.bin(V))) / u.ms + + def reset_state(self, V, Na: IonInfo, batch_size=None): + self.normalize_states( + [self.C1, self.C2, self.C3, self.C4, self.C5, self.I1, self.I2, self.I3, self.I4, self.I5, self.O, self.B, + self.I6]) + + def current(self, V, Na: IonInfo): + return self.g_max * self.O.value * (Na.E - V) + + f01 = lambda self, V: 4 * self.alpha * u.math.exp((V / u.mV) / self.x1) * self.phi + f02 = lambda self, V: 3 * self.alpha * u.math.exp((V / u.mV) / self.x1) * self.phi + f03 = lambda self, V: 2 * self.alpha * u.math.exp((V / u.mV) / self.x1) * self.phi + f04 = lambda self, V: 1 * self.alpha * u.math.exp((V / u.mV) / self.x1) * self.phi + f0O = lambda self, V: self.gamma * self.phi + fip = lambda self, V: self.epsilon * self.phi + f11 = lambda self, V: 4 * self.alpha * self.alfac * u.math.exp((V / u.mV + self.vshifti) / self.x1) * self.phi + f12 = lambda self, V: 3 * self.alpha * self.alfac * u.math.exp((V / u.mV + self.vshifti) / self.x1) * self.phi + f13 = lambda self, V: 2 * self.alpha * self.alfac * u.math.exp((V / u.mV + self.vshifti) / self.x1) * self.phi + f14 = lambda self, V: 1 * self.alpha * self.alfac * u.math.exp((V / u.mV + self.vshifti) / self.x1) * self.phi + f1n = lambda self, V: self.gamma * self.phi + fi1 = lambda self, V: self.Con * self.phi + fi2 = lambda self, V: self.Con * self.alfac * self.phi + fi3 = lambda self, V: self.Con * self.alfac ** 2 * self.phi + fi4 = lambda self, V: self.Con * self.alfac ** 3 * self.phi + fi5 = lambda self, V: self.Con * self.alfac ** 4 * self.phi + fin = lambda self, V: self.Oon * self.phi + + b01 = lambda self, V: 1 * self.beta * u.math.exp((V / u.mV + self.vshifta) / (self.x2 + self.vshiftk)) * self.phi + b02 = lambda self, V: 2 * self.beta * u.math.exp((V / u.mV + self.vshifta) / (self.x2 + self.vshiftk)) * self.phi + b03 = lambda self, V: 3 * self.beta * u.math.exp((V / u.mV + self.vshifta) / (self.x2 + self.vshiftk)) * self.phi + b04 = lambda self, V: 4 * self.beta * u.math.exp((V / u.mV + self.vshifta) / (self.x2 + self.vshiftk)) * self.phi + b0O = lambda self, V: self.delta * self.phi + bip = lambda self, V: self.zeta * u.math.exp(V / u.mV / self.x6) * self.phi + b11 = lambda self, V: 1 * self.beta * self.btfac * u.math.exp((V / u.mV + self.vshifti) / self.x2) * self.phi + b12 = lambda self, V: 2 * self.beta * self.btfac * u.math.exp((V / u.mV + self.vshifti) / self.x2) * self.phi + b13 = lambda self, V: 3 * self.beta * self.btfac * u.math.exp((V / u.mV + self.vshifti) / self.x2) * self.phi + b14 = lambda self, V: 4 * self.beta * self.btfac * u.math.exp((V / u.mV + self.vshifti) / self.x2) * self.phi + b1n = lambda self, V: self.delta * self.phi + bi1 = lambda self, V: self.Coff * self.phi + bi2 = lambda self, V: self.Coff * self.btfac * self.phi + bi3 = lambda self, V: self.Coff * self.btfac ** 2 * self.phi + bi4 = lambda self, V: self.Coff * self.btfac ** 3 * self.phi + bi5 = lambda self, V: self.Coff * self.btfac ** 4 * self.phi + bin = lambda self, V: self.Ooff * self.phi diff --git a/dendritex/ions/calcium.py b/dendritex/ions/calcium.py index ad2b135..7798003 100644 --- a/dendritex/ions/calcium.py +++ b/dendritex/ions/calcium.py @@ -22,312 +22,289 @@ import brainstate as bst import brainunit as u -from .._base import Ion, Channel, HHTypedNeuron, State4Integral +from dendritex._base import Ion, Channel, HHTypedNeuron, State4Integral __all__ = [ - 'Calcium', - 'CalciumFixed', - 'CalciumDetailed', - 'CalciumFirstOrder', + 'Calcium', + 'CalciumFixed', + 'CalciumDetailed', + 'CalciumFirstOrder', ] class Calcium(Ion): - """Base class for modeling Calcium ion.""" - __module__ = 'dendritex.ions' + """Base class for modeling Calcium ion.""" + __module__ = 'dendritex.ions' - root_type = HHTypedNeuron + root_type = HHTypedNeuron class CalciumFixed(Calcium): - """Fixed Calcium dynamics. - - This calcium model has no dynamics. It holds fixed reversal - potential :math:`E` and concentration :math:`C`. - """ - __module__ = 'dendritex.ions' - - def __init__( - self, - size: bst.typing.Size, - E: Union[bst.typing.ArrayLike, Callable] = 120. * u.mV, - C: Union[bst.typing.ArrayLike, Callable] = 2.4e-4 * u.mM, - name: Optional[str] = None, - mode: Optional[bst.mixin.Mode] = None, - **channels - ): - super().__init__( - size, - name=name, - mode=mode, - **channels - ) - self.E = bst.init.param(E, self.varshape, allow_none=False) - self.C = bst.init.param(C, self.varshape, allow_none=False) - - def reset_state(self, V, batch_size=None): - ca_info = self.pack_info() - nodes = self.nodes(level=1, include_self=False).subset(Channel).values() - self.check_hierarchies(type(self), *tuple(nodes)) - for node in nodes: - node.reset_state(V, ca_info, batch_size=batch_size) + """Fixed Calcium dynamics. + + This calcium model has no dynamics. It holds fixed reversal + potential :math:`E` and concentration :math:`C`. + """ + __module__ = 'dendritex.ions' + + def __init__( + self, + size: bst.typing.Size, + E: Union[bst.typing.ArrayLike, Callable] = 120. * u.mV, + C: Union[bst.typing.ArrayLike, Callable] = 2.4e-4 * u.mM, + name: Optional[str] = None, + **channels + ): + super().__init__(size, name=name, **channels) + self.E = bst.init.param(E, self.varshape, allow_none=False) + self.C = bst.init.param(C, self.varshape, allow_none=False) + + def reset_state(self, V, batch_size=None): + ca_info = self.pack_info() + nodes = bst.graph.nodes(self, Channel, allowed_hierarchy=(1, 1)).values() + self.check_hierarchies(type(self), *tuple(nodes)) + for node in nodes: + node.reset_state(V, ca_info, batch_size=batch_size) class _CalciumDynamics(Calcium): - """Calcium ion flow with dynamics. - - Parameters - ---------- - size: int, tuple of int - The ion size. - C0: bst.typing.ArrayLike, Callable - The Calcium concentration outside of membrane. - T: bst.typing.ArrayLike, Callable - The temperature. - C_initializer: bst.typing.ArrayLike, Callable - The initializer for Calcium concentration. - name: str - The ion name. - """ - - def __init__( - self, - size: bst.typing.Size, - C0: Union[bst.typing.ArrayLike, Callable] = 2. * u.mM, - T: Union[bst.typing.ArrayLike, Callable] = u.celsius2kelvin(36.), - C_initializer: Union[bst.typing.ArrayLike, Callable] = bst.init.Constant(2.4e-4 * u.mM), - name: Optional[str] = None, - mode: Optional[bst.mixin.Mode] = None, - **channels - ): - super().__init__( - size, - name=name, - mode=mode, - **channels - ) - - # parameters - self.C0 = bst.init.param(C0, self.varshape, allow_none=False) - self.T = bst.init.param(T, self.varshape, allow_none=False) # temperature - self._constant = u.gas_constant * self.T / (2 * u.faraday_constant) - self._C_initializer = C_initializer - - def derivative(self, C, t, V): - raise NotImplementedError - - def init_state(self, V, batch_size=None): - # Calcium concentration - self.C = State4Integral(bst.init.param(self._C_initializer, self.varshape, batch_size)) - super().init_state(V, batch_size) - - def reset_state(self, V, batch_size=None): - self.C.value = bst.init.param(self._C_initializer, self.varshape, batch_size) - super().reset_state(V, batch_size) - - def compute_derivative(self, V): - ca_info = self.pack_info() - nodes = self.nodes(level=1, include_self=False).subset(Channel).values() - self.check_hierarchies(type(self), *tuple(nodes)) - for node in nodes: - node.compute_derivative(V, ca_info) - self.C.derivative = self.derivative(self.C.value, bst.environ.get('t'), V) - - @property - def E(self): - return self._reversal_potential(self.C.value) - - def _reversal_potential(self, C): - # The Nernst relation: - # - # E_{\mathrm{Ca}}=\frac{RT}{2F}\log\frac{[\mathrm{Ca}]_{\mathrm{o}}}{[\mathrm{Ca}]_{\mathrm{i}}} - return self._constant * u.math.log(self.C0 / C) + """Calcium ion flow with dynamics. + + Parameters + ---------- + size: int, tuple of int + The ion size. + C0: bst.typing.ArrayLike, Callable + The Calcium concentration outside of membrane. + T: bst.typing.ArrayLike, Callable + The temperature. + C_initializer: bst.typing.ArrayLike, Callable + The initializer for Calcium concentration. + name: str + The ion name. + """ + + def __init__( + self, + size: bst.typing.Size, + C0: Union[bst.typing.ArrayLike, Callable] = 2. * u.mM, + T: Union[bst.typing.ArrayLike, Callable] = u.celsius2kelvin(36.), + C_initializer: Union[bst.typing.ArrayLike, Callable] = bst.init.Constant(2.4e-4 * u.mM), + name: Optional[str] = None, + **channels + ): + super().__init__(size, name=name, **channels) + + # parameters + self.C0 = bst.init.param(C0, self.varshape, allow_none=False) + self.T = bst.init.param(T, self.varshape, allow_none=False) # temperature + self._constant = u.gas_constant * self.T / (2 * u.faraday_constant) + self._C_initializer = C_initializer + + def derivative(self, C, t, V): + raise NotImplementedError + + def init_state(self, V, batch_size=None): + # Calcium concentration + self.C = State4Integral(bst.init.param(self._C_initializer, self.varshape, batch_size)) + super().init_state(V, batch_size) + + def reset_state(self, V, batch_size=None): + self.C.value = bst.init.param(self._C_initializer, self.varshape, batch_size) + super().reset_state(V, batch_size) + + def compute_derivative(self, V): + ca_info = self.pack_info() + nodes = bst.graph.nodes(self, Channel, allowed_hierarchy=(1, 1)).values() + self.check_hierarchies(type(self), *tuple(nodes)) + for node in nodes: + node.compute_derivative(V, ca_info) + self.C.derivative = self.derivative(self.C.value, bst.environ.get('t'), V) + + @property + def E(self): + return self._reversal_potential(self.C.value) + + def _reversal_potential(self, C): + # The Nernst relation: + # + # E_{\mathrm{Ca}}=\frac{RT}{2F}\log\frac{[\mathrm{Ca}]_{\mathrm{o}}}{[\mathrm{Ca}]_{\mathrm{i}}} + return self._constant * u.math.log(self.C0 / C) class CalciumDetailed(_CalciumDynamics): - r"""Dynamical Calcium model proposed. + r"""Dynamical Calcium model proposed. - **1. The dynamics of intracellular** :math:`Ca^{2+}` + **1. The dynamics of intracellular** :math:`Ca^{2+}` - The dynamics of intracellular :math:`Ca^{2+}` were determined by two contributions [1]_ : + The dynamics of intracellular :math:`Ca^{2+}` were determined by two contributions [1]_ : - *(i) Influx of* :math:`Ca^{2+}` *due to Calcium currents* + *(i) Influx of* :math:`Ca^{2+}` *due to Calcium currents* - :math:`Ca^{2+}` ions enter through :math:`Ca^{2+}` channels and diffuse into the - interior of the cell. Only the :math:`Ca^{2+}` concentration in a thin shell beneath - the membrane was modeled. The influx of :math:`Ca^{2+}` into such a thin shell followed: + :math:`Ca^{2+}` ions enter through :math:`Ca^{2+}` channels and diffuse into the + interior of the cell. Only the :math:`Ca^{2+}` concentration in a thin shell beneath + the membrane was modeled. The influx of :math:`Ca^{2+}` into such a thin shell followed: - .. math:: + .. math:: - [Ca]_{i}=-\frac{I_{Ca}}{2 F d} + [Ca]_{i}=-\frac{I_{Ca}}{2 F d} - where :math:`F=96489\, \mathrm{C\, mol^{-1}}` is the Faraday constant, - :math:`d=1\, \mathrm{\mu m}` is the depth of the shell beneath the membrane, - :math:`I_T` in :math:`\mathrm{\mu A/cm^{2}}` and :math:`[Ca]_{i}` in millimolar, - and :math:`I_{Ca}` is the summation of all :math:`Ca^{2+}` currents. + where :math:`F=96489\, \mathrm{C\, mol^{-1}}` is the Faraday constant, + :math:`d=1\, \mathrm{\mu m}` is the depth of the shell beneath the membrane, + :math:`I_T` in :math:`\mathrm{\mu A/cm^{2}}` and :math:`[Ca]_{i}` in millimolar, + and :math:`I_{Ca}` is the summation of all :math:`Ca^{2+}` currents. - *(ii) Efflux of* :math:`Ca^{2+}` *due to an active pump* + *(ii) Efflux of* :math:`Ca^{2+}` *due to an active pump* - In a thin shell beneath the membrane, :math:`Ca^{2+}` retrieval usually consists of a - combination of several processes, such as binding to :math:`Ca^{2+}` buffers, calcium - efflux due to :math:`Ca^{2+}` ATPase pump activity and diffusion to neighboring shells. - Only the :math:`Ca^{2+}` pump was modeled here. We adopted the following kinetic scheme: - - .. math:: - - Ca _{i}^{2+}+ P \overset{c_1}{\underset{c_2}{\rightleftharpoons}} CaP \xrightarrow{c_3} P+ Ca _{0}^{2+} - - where P represents the :math:`Ca^{2+}` pump, CaP is an intermediate state, - :math:`Ca _{ o }^{2+}` is the extracellular :math:`Ca^{2+}` concentration, - and :math:`c_{1}, c_{2}` and :math:`c_{3}` are rate constants. :math:`Ca^{2+}` - ions have a high affinity for the pump :math:`P`, whereas extrusion of - :math:`Ca^{2+}` follows a slower process (Blaustein, 1988 ). Therefore, - :math:`c_{3}` is low compared to :math:`c_{1}` and :math:`c_{2}` and the - Michaelis-Menten approximation can be used for describing the kinetics of the pump. - According to such a scheme, the kinetic equation for the :math:`Ca^{2+}` pump is: - - .. math:: - - \frac{[Ca^{2+}]_{i}}{dt}=-\frac{K_{T}[Ca]_{i}}{[Ca]_{i}+K_{d}} - - where :math:`K_{T}=10^{-4}\, \mathrm{mM\, ms^{-1}}` is the product of :math:`c_{3}` - with the total concentration of :math:`P` and :math:`K_{d}=c_{2} / c_{1}=10^{-4}\, \mathrm{mM}` - is the dissociation constant, which can be interpreted here as the value of - :math:`[Ca]_{i}` at which the pump is half activated (if :math:`[Ca]_{i} \ll K_{d}` - then the efflux is negligible). - - **2.A simple first-order model** - - While, in (Bazhenov, et al., 1998) [2]_, the :math:`Ca^{2+}` dynamics is - described by a simple first-order model, - - .. math:: - - \frac{d\left[Ca^{2+}\right]_{i}}{d t}=-\frac{I_{Ca}}{z F d}+\frac{\left[Ca^{2+}\right]_{rest}-\left[C a^{2+}\right]_{i}}{\tau_{Ca}} - - where :math:`I_{Ca}` is the summation of all :math:`Ca ^{2+}` currents, :math:`d` - is the thickness of the perimembrane "shell" in which calcium is able to affect - membrane properties :math:`(1.\, \mathrm{\mu M})`, :math:`z=2` is the valence of the - :math:`Ca ^{2+}` ion, :math:`F` is the Faraday constant, and :math:`\tau_{C a}` is - the :math:`Ca ^{2+}` removal rate. The resting :math:`Ca ^{2+}` concentration was - set to be :math:`\left[ Ca ^{2+}\right]_{\text {rest}}=.05\, \mathrm{\mu M}` . - - **3. The reversal potential** - - The reversal potential of calcium :math:`Ca ^{2+}` is calculated according to the - Nernst equation: - - .. math:: - - E = k'{RT \over 2F} log{[Ca^{2+}]_0 \over [Ca^{2+}]_i} - - where :math:`R=8.31441 \, \mathrm{J} /(\mathrm{mol}^{\circ} \mathrm{K})`, - :math:`T=309.15^{\circ} \mathrm{K}`, - :math:`F=96,489 \mathrm{C} / \mathrm{mol}`, - and :math:`\left[\mathrm{Ca}^{2+}\right]_{0}=2 \mathrm{mM}`. - - Parameters - ---------- - d : float - The thickness of the peri-membrane "shell". - F : float - The Faraday constant. (:math:`C*mmol^{-1}`) - tau : float - The time constant of the :math:`Ca ^{2+}` removal rate. (ms) - C_rest : float - The resting :math:`Ca ^{2+}` concentration. - C0 : float - The :math:`Ca ^{2+}` concentration outside of the membrane. - R : float - The gas constant. (:math:` J*mol^{-1}*K^{-1}`) - - References - ---------- - - .. [1] Destexhe, Alain, Agnessa Babloyantz, and Terrence J. Sejnowski. - "Ionic mechanisms for intrinsic slow oscillations in thalamic - relay neurons." Biophysical journal 65, no. 4 (1993): 1538-1552. - .. [2] Bazhenov, Maxim, Igor Timofeev, Mircea Steriade, and Terrence J. - Sejnowski. "Cellular and network models for intrathalamic augmenting - responses during 10-Hz stimulation." Journal of neurophysiology 79, - no. 5 (1998): 2730-2748. - - """ - __module__ = 'dendritex.ions' - - def __init__( - self, - size: bst.typing.Size, - T: Union[bst.typing.ArrayLike, Callable] = u.celsius2kelvin(36.), - d: Union[bst.typing.ArrayLike, Callable] = 1. * u.um, - tau: Union[bst.typing.ArrayLike, Callable] = 5. * u.ms, - C_rest: Union[bst.typing.ArrayLike, Callable] = 2.4e-4 * u.mM, - C0: Union[bst.typing.ArrayLike, Callable] = 2. * u.mM, - C_initializer: Union[bst.typing.ArrayLike, Callable] = bst.init.Constant(2.4e-4 * u.mM), - name: Optional[str] = None, - mode: Optional[bst.mixin.Mode] = None, - **channels - ): - super().__init__( - size, - name=name, - T=T, - C0=C0, - C_initializer=C_initializer, - mode=mode, - **channels - ) - - # parameters - self.d = bst.init.param(d, self.varshape, allow_none=False) - self.tau = bst.init.param(tau, self.varshape, allow_none=False) - self.C_rest = bst.init.param(C_rest, self.varshape, allow_none=False) - - def derivative(self, C, t, V): - ICa = self.current(V, include_external=True) - drive = ICa / (2 * u.faraday_constant * self.d) - drive = u.math.maximum(drive, u.math.zeros_like(drive)) - return drive + (self.C_rest - C) / self.tau + In a thin shell beneath the membrane, :math:`Ca^{2+}` retrieval usually consists of a + combination of several processes, such as binding to :math:`Ca^{2+}` buffers, calcium + efflux due to :math:`Ca^{2+}` ATPase pump activity and diffusion to neighboring shells. + Only the :math:`Ca^{2+}` pump was modeled here. We adopted the following kinetic scheme: + + .. math:: + + Ca _{i}^{2+}+ P \overset{c_1}{\underset{c_2}{\rightleftharpoons}} CaP \xrightarrow{c_3} P+ Ca _{0}^{2+} + + where P represents the :math:`Ca^{2+}` pump, CaP is an intermediate state, + :math:`Ca _{ o }^{2+}` is the extracellular :math:`Ca^{2+}` concentration, + and :math:`c_{1}, c_{2}` and :math:`c_{3}` are rate constants. :math:`Ca^{2+}` + ions have a high affinity for the pump :math:`P`, whereas extrusion of + :math:`Ca^{2+}` follows a slower process (Blaustein, 1988 ). Therefore, + :math:`c_{3}` is low compared to :math:`c_{1}` and :math:`c_{2}` and the + Michaelis-Menten approximation can be used for describing the kinetics of the pump. + According to such a scheme, the kinetic equation for the :math:`Ca^{2+}` pump is: + + .. math:: + + \frac{[Ca^{2+}]_{i}}{dt}=-\frac{K_{T}[Ca]_{i}}{[Ca]_{i}+K_{d}} + + where :math:`K_{T}=10^{-4}\, \mathrm{mM\, ms^{-1}}` is the product of :math:`c_{3}` + with the total concentration of :math:`P` and :math:`K_{d}=c_{2} / c_{1}=10^{-4}\, \mathrm{mM}` + is the dissociation constant, which can be interpreted here as the value of + :math:`[Ca]_{i}` at which the pump is half activated (if :math:`[Ca]_{i} \ll K_{d}` + then the efflux is negligible). + + **2.A simple first-order model** + + While, in (Bazhenov, et al., 1998) [2]_, the :math:`Ca^{2+}` dynamics is + described by a simple first-order model, + + .. math:: + + \frac{d\left[Ca^{2+}\right]_{i}}{d t}=-\frac{I_{Ca}}{z F d}+\frac{\left[Ca^{2+}\right]_{rest}-\left[C a^{2+}\right]_{i}}{\tau_{Ca}} + + where :math:`I_{Ca}` is the summation of all :math:`Ca ^{2+}` currents, :math:`d` + is the thickness of the perimembrane "shell" in which calcium is able to affect + membrane properties :math:`(1.\, \mathrm{\mu M})`, :math:`z=2` is the valence of the + :math:`Ca ^{2+}` ion, :math:`F` is the Faraday constant, and :math:`\tau_{C a}` is + the :math:`Ca ^{2+}` removal rate. The resting :math:`Ca ^{2+}` concentration was + set to be :math:`\left[ Ca ^{2+}\right]_{\text {rest}}=.05\, \mathrm{\mu M}` . + + **3. The reversal potential** + + The reversal potential of calcium :math:`Ca ^{2+}` is calculated according to the + Nernst equation: + + .. math:: + + E = k'{RT \over 2F} log{[Ca^{2+}]_0 \over [Ca^{2+}]_i} + + where :math:`R=8.31441 \, \mathrm{J} /(\mathrm{mol}^{\circ} \mathrm{K})`, + :math:`T=309.15^{\circ} \mathrm{K}`, + :math:`F=96,489 \mathrm{C} / \mathrm{mol}`, + and :math:`\left[\mathrm{Ca}^{2+}\right]_{0}=2 \mathrm{mM}`. + + Parameters + ---------- + d : float + The thickness of the peri-membrane "shell". + F : float + The Faraday constant. (:math:`C*mmol^{-1}`) + tau : float + The time constant of the :math:`Ca ^{2+}` removal rate. (ms) + C_rest : float + The resting :math:`Ca ^{2+}` concentration. + C0 : float + The :math:`Ca ^{2+}` concentration outside of the membrane. + R : float + The gas constant. (:math:` J*mol^{-1}*K^{-1}`) + + References + ---------- + + .. [1] Destexhe, Alain, Agnessa Babloyantz, and Terrence J. Sejnowski. + "Ionic mechanisms for intrinsic slow oscillations in thalamic + relay neurons." Biophysical journal 65, no. 4 (1993): 1538-1552. + .. [2] Bazhenov, Maxim, Igor Timofeev, Mircea Steriade, and Terrence J. + Sejnowski. "Cellular and network models for intrathalamic augmenting + responses during 10-Hz stimulation." Journal of neurophysiology 79, + no. 5 (1998): 2730-2748. + + """ + __module__ = 'dendritex.ions' + + def __init__( + self, + size: bst.typing.Size, + T: Union[bst.typing.ArrayLike, Callable] = u.celsius2kelvin(36.), + d: Union[bst.typing.ArrayLike, Callable] = 1. * u.um, + tau: Union[bst.typing.ArrayLike, Callable] = 5. * u.ms, + C_rest: Union[bst.typing.ArrayLike, Callable] = 2.4e-4 * u.mM, + C0: Union[bst.typing.ArrayLike, Callable] = 2. * u.mM, + C_initializer: Union[bst.typing.ArrayLike, Callable] = bst.init.Constant(2.4e-4 * u.mM), + name: Optional[str] = None, + **channels + ): + super().__init__(size, name=name, T=T, C0=C0, C_initializer=C_initializer, **channels) + + # parameters + self.d = bst.init.param(d, self.varshape, allow_none=False) + self.tau = bst.init.param(tau, self.varshape, allow_none=False) + self.C_rest = bst.init.param(C_rest, self.varshape, allow_none=False) + + def derivative(self, C, t, V): + ICa = self.current(V, include_external=True) + drive = ICa / (2 * u.faraday_constant * self.d) + drive = u.math.maximum(drive, u.math.zeros_like(drive)) + return drive + (self.C_rest - C) / self.tau class CalciumFirstOrder(_CalciumDynamics): - r""" - The first-order calcium concentration model. - - .. math:: - - Ca' = -\alpha I_{Ca} + -\beta Ca - - """ - __module__ = 'dendritex.ions' - - def __init__( - self, - size: bst.typing.Size, - T: Union[bst.typing.ArrayLike, Callable] = u.celsius2kelvin(36.), - alpha: Union[bst.typing.ArrayLike, Callable] = 0.13, - beta: Union[bst.typing.ArrayLike, Callable] = 0.075, - C0: Union[bst.typing.ArrayLike, Callable] = 2. * u.mM, - C_initializer: Union[bst.typing.ArrayLike, Callable] = bst.init.Constant(2.4e-4 * u.mM), - name: Optional[str] = None, - mode: Optional[bst.mixin.Mode] = None, - **channels - ): - super().__init__( - size, - name=name, - T=T, - C0=C0, - C_initializer=C_initializer, - mode=mode, - **channels - ) - - # parameters - self.alpha = bst.init.param(alpha, self.varshape, allow_none=False) - self.beta = bst.init.param(beta, self.varshape, allow_none=False) - - def derivative(self, C, t, V): - ICa = self.current(V, include_external=True) - drive = u.math.maximum(self.alpha * ICa, 0. * u.mM) - return drive - self.beta * C + r""" + The first-order calcium concentration model. + + .. math:: + + Ca' = -\alpha I_{Ca} + -\beta Ca + + """ + __module__ = 'dendritex.ions' + + def __init__( + self, + size: bst.typing.Size, + T: Union[bst.typing.ArrayLike, Callable] = u.celsius2kelvin(36.), + alpha: Union[bst.typing.ArrayLike, Callable] = 0.13, + beta: Union[bst.typing.ArrayLike, Callable] = 0.075, + C0: Union[bst.typing.ArrayLike, Callable] = 2. * u.mM, + C_initializer: Union[bst.typing.ArrayLike, Callable] = bst.init.Constant(2.4e-4 * u.mM), + name: Optional[str] = None, + **channels + ): + super().__init__( + size, + name=name, + T=T, + C0=C0, + C_initializer=C_initializer, + **channels + ) + + # parameters + self.alpha = bst.init.param(alpha, self.varshape, allow_none=False) + self.beta = bst.init.param(beta, self.varshape, allow_none=False) + + def derivative(self, C, t, V): + ICa = self.current(V, include_external=True) + drive = u.math.maximum(self.alpha * ICa, 0. * u.mM) + return drive - self.beta * C diff --git a/dendritex/ions/potassium.py b/dendritex/ions/potassium.py index db49ed7..f52656f 100644 --- a/dendritex/ions/potassium.py +++ b/dendritex/ions/potassium.py @@ -21,48 +21,42 @@ import brainstate as bst import brainunit as bu -from .._base import Ion, Channel +from dendritex._base import Ion, Channel __all__ = [ - 'Potassium', - 'PotassiumFixed', + 'Potassium', + 'PotassiumFixed', ] class Potassium(Ion): - """Base class for modeling Potassium ion.""" - __module__ = 'dendritex.ions' + """Base class for modeling Potassium ion.""" + __module__ = 'dendritex.ions' class PotassiumFixed(Potassium): - """Fixed Sodium dynamics. - - This calcium model has no dynamics. It holds fixed reversal - potential :math:`E` and concentration :math:`C`. - """ - __module__ = 'dendritex.ions' - - def __init__( - self, - size: bst.typing.Size, - E: Union[bst.typing.ArrayLike, Callable] = -95. * bu.mV, - C: Union[bst.typing.ArrayLike, Callable] = 0.0400811 * bu.mM, - name: Optional[str] = None, - mode: Optional[bst.mixin.Mode] = None, - **channels - ): - super().__init__( - size, - name=name, - mode=mode, - **channels - ) - self.E = bst.init.param(E, self.varshape) - self.C = bst.init.param(C, self.varshape) - - def reset_state(self, V, batch_size=None): - nodes = self.nodes(level=1, include_self=False).subset(Channel).values() - self.check_hierarchies(type(self), *tuple(nodes)) - ion_info = self.pack_info() - for node in nodes: - node.reset_state(V, ion_info, batch_size) + """Fixed Sodium dynamics. + + This calcium model has no dynamics. It holds fixed reversal + potential :math:`E` and concentration :math:`C`. + """ + __module__ = 'dendritex.ions' + + def __init__( + self, + size: bst.typing.Size, + E: Union[bst.typing.ArrayLike, Callable] = -95. * bu.mV, + C: Union[bst.typing.ArrayLike, Callable] = 0.0400811 * bu.mM, + name: Optional[str] = None, + **channels + ): + super().__init__(size, name=name, **channels) + self.E = bst.init.param(E, self.varshape) + self.C = bst.init.param(C, self.varshape) + + def reset_state(self, V, batch_size=None): + nodes = bst.graph.nodes(self, Channel, allowed_hierarchy=(1, 1)).values() + self.check_hierarchies(type(self), *tuple(nodes)) + ion_info = self.pack_info() + for node in nodes: + node.reset_state(V, ion_info, batch_size) diff --git a/dendritex/ions/sodium.py b/dendritex/ions/sodium.py index e7c0b09..151dbe0 100644 --- a/dendritex/ions/sodium.py +++ b/dendritex/ions/sodium.py @@ -20,42 +20,36 @@ import brainstate as bst import brainunit as bu -from .._base import Ion +from dendritex._base import Ion __all__ = [ - 'Sodium', - 'SodiumFixed', + 'Sodium', + 'SodiumFixed', ] class Sodium(Ion): - """Base class for modeling Sodium ion.""" - __module__ = 'dendritex.ions' + """Base class for modeling Sodium ion.""" + __module__ = 'dendritex.ions' class SodiumFixed(Sodium): - """ - Fixed Sodium dynamics. - - This calcium model has no dynamics. It holds fixed reversal - potential :math:`E` and concentration :math:`C`. - """ - __module__ = 'dendritex.ions' - - def __init__( - self, - size: bst.typing.Size, - E: Union[bst.typing.ArrayLike, Callable] = 50. * bu.mV, - C: Union[bst.typing.ArrayLike, Callable] = 0.0400811 * bu.mM, - name: Optional[str] = None, - mode: Optional[bst.mixin.Mode] = None, - **channels - ): - super().__init__( - size, - name=name, - mode=mode, - **channels - ) - self.E = bst.init.param(E, self.varshape, allow_none=False) - self.C = bst.init.param(C, self.varshape, allow_none=False) + """ + Fixed Sodium dynamics. + + This calcium model has no dynamics. It holds fixed reversal + potential :math:`E` and concentration :math:`C`. + """ + __module__ = 'dendritex.ions' + + def __init__( + self, + size: bst.typing.Size, + E: Union[bst.typing.ArrayLike, Callable] = 50. * bu.mV, + C: Union[bst.typing.ArrayLike, Callable] = 0.0400811 * bu.mM, + name: Optional[str] = None, + **channels + ): + super().__init__(size, name=name, **channels) + self.E = bst.init.param(E, self.varshape, allow_none=False) + self.C = bst.init.param(C, self.varshape, allow_none=False) diff --git a/dendritex/neurons/multi_compartment.py b/dendritex/neurons/multi_compartment.py index 495061c..e45d58b 100644 --- a/dendritex/neurons/multi_compartment.py +++ b/dendritex/neurons/multi_compartment.py @@ -22,191 +22,203 @@ import jax import numpy as np -from .._base import HHTypedNeuron, State4Integral, IonChannel +from dendritex._base import HHTypedNeuron, State4Integral, IonChannel __all__ = [ - 'MultiCompartment', + 'MultiCompartment', ] def diffusive_coupling(potentials, coo_ids, resistances): - """ - Compute the diffusive coupling currents between neurons. - - :param potentials: The membrane potential of neurons. - :param coo_ids: The COO format of the adjacency matrix. - :param resistances: The weight/resistances of each connection. - :return: The output of the operator, which computes the diffusive coupling currents. - """ - # potential: [n,] - # The membrane potential of neurons. - # Should be a 1D array. - # coo_ids: [m, 2] - # The COO format of the adjacency matrix. - # Should be a 2D array. Each row is a pair of (i, j). - # Note that (i, j) indicates the connection from neuron i to neuron j, - # and also the connection from neuron j to i. - # resistances: [m] - # The weight of each connection. - # resistances[i] is the weight of the connection from coo_ids[i, 0] to coo_ids[i, 1], - # and also the connection from coo_ids[i, 1] to coo_ids[i, 0]. - # outs: [n] - # The output of the operator, which computes the summation of all differences of potentials. - # outs[i] = sum((potentials[i] - potentials[j]) / resistances[j] for j in neighbors of i) - - assert isinstance(potentials, bu.Quantity), 'The potentials should be a Quantity.' - assert isinstance(resistances, bu.Quantity), 'The conductance should be a Quantity.' - # assert potentials.ndim == 1, f'The potentials should be a 1D array. Got {potentials.shape}.' - assert resistances.shape[-1] == coo_ids.shape[0], ('The length of conductance should be equal ' - 'to the number of connections.') - assert coo_ids.ndim == 2, f'The coo_ids should be a 2D array. Got {coo_ids.shape}.' - assert resistances.ndim == 1, f'The conductance should be a 1D array. Got {resistances.shape}.' - - outs = bu.Quantity(bu.math.zeros(potentials.shape), unit=potentials.unit / resistances.unit) - pre_ids = coo_ids[:, 0] - post_ids = coo_ids[:, 1] - diff = (potentials[..., pre_ids] - potentials[..., post_ids]) / resistances - outs = outs.at[..., pre_ids].add(-diff) - outs = outs.at[..., post_ids].add(diff) - return outs + """ + Compute the diffusive coupling currents between neurons. + + :param potentials: The membrane potential of neurons. + :param coo_ids: The COO format of the adjacency matrix. + :param resistances: The weight/resistances of each connection. + :return: The output of the operator, which computes the diffusive coupling currents. + """ + # potential: [n,] + # The membrane potential of neurons. + # Should be a 1D array. + # coo_ids: [m, 2] + # The COO format of the adjacency matrix. + # Should be a 2D array. Each row is a pair of (i, j). + # Note that (i, j) indicates the connection from neuron i to neuron j, + # and also the connection from neuron j to i. + # resistances: [m] + # The weight of each connection. + # resistances[i] is the weight of the connection from coo_ids[i, 0] to coo_ids[i, 1], + # and also the connection from coo_ids[i, 1] to coo_ids[i, 0]. + # outs: [n] + # The output of the operator, which computes the summation of all differences of potentials. + # outs[i] = sum((potentials[i] - potentials[j]) / resistances[j] for j in neighbors of i) + + assert isinstance(potentials, bu.Quantity), 'The potentials should be a Quantity.' + assert isinstance(resistances, bu.Quantity), 'The conductance should be a Quantity.' + # assert potentials.ndim == 1, f'The potentials should be a 1D array. Got {potentials.shape}.' + assert resistances.shape[-1] == coo_ids.shape[0], ('The length of conductance should be equal ' + 'to the number of connections.') + assert coo_ids.ndim == 2, f'The coo_ids should be a 2D array. Got {coo_ids.shape}.' + assert resistances.ndim == 1, f'The conductance should be a 1D array. Got {resistances.shape}.' + + outs = bu.Quantity(bu.math.zeros(potentials.shape), unit=potentials.unit / resistances.unit) + pre_ids = coo_ids[:, 0] + post_ids = coo_ids[:, 1] + diff = (potentials[..., pre_ids] - potentials[..., post_ids]) / resistances + outs = outs.at[..., pre_ids].add(-diff) + outs = outs.at[..., post_ids].add(diff) + return outs def init_coupling_weight(n_compartment, connection, diam, L, Ra): - # weights = [] - # for i, j in connection: - # # R_{i,j}=\frac{R_{i}+R_{j}}{2} - # # =\frac{1}{2}(\frac{4R_{a}\cdot L_{i}}{\pi\cdot diam_{j}^{2}}+ - # # \frac{4R_{a}\cdot L_{j}}{\pi\cdot diam_{j}^{2}}) - # R_ij = 0.5 * (4 * Ra[i] * L[i] / (np.pi * diam[i] ** 2) + 4 * Ra[j] * L[j] / (np.pi * diam[j] ** 2)) - # weights.append(R_ij) - # return u.Quantity(weights) - - assert isinstance(connection, (np.ndarray, jax.Array)), 'The connection should be a numpy/jax array.' - pre_ids = connection[:, 0] - post_ids = connection[:, 1] - if Ra.size == 1: - Ra_pre = Ra - Ra_post = Ra - else: - assert Ra.shape[-1] == n_compartment, (f'The length of Ra should be equal to ' - f'the number of compartments. Got {Ra.shape}.') - Ra_pre = Ra[..., pre_ids] - Ra_post = Ra[..., post_ids] - if L.size == 1: - L_pre = L - L_post = L - else: - assert L.shape[-1] == n_compartment, (f'The length of L should be equal to ' - f'the number of compartments. Got {L.shape}.') - L_pre = L[..., pre_ids] - L_post = L[..., post_ids] - if diam.size == 1: - diam_pre = diam - diam_post = diam - else: - assert diam.shape[-1] == n_compartment, (f'The length of diam should be equal to the ' - f'number of compartments. Got {diam.shape}.') - diam_pre = diam[..., pre_ids] - diam_post = diam[..., post_ids] - - weights = 0.5 * ( - 4 * Ra_pre * L_pre / (np.pi * diam_pre ** 2) + - 4 * Ra_post * L_post / (np.pi * diam_post ** 2) - ) - return weights + # weights = [] + # for i, j in connection: + # # R_{i,j}=\frac{R_{i}+R_{j}}{2} + # # =\frac{1}{2}(\frac{4R_{a}\cdot L_{i}}{\pi\cdot diam_{j}^{2}}+ + # # \frac{4R_{a}\cdot L_{j}}{\pi\cdot diam_{j}^{2}}) + # R_ij = 0.5 * (4 * Ra[i] * L[i] / (np.pi * diam[i] ** 2) + 4 * Ra[j] * L[j] / (np.pi * diam[j] ** 2)) + # weights.append(R_ij) + # return u.Quantity(weights) + + assert isinstance(connection, (np.ndarray, jax.Array)), 'The connection should be a numpy/jax array.' + pre_ids = connection[:, 0] + post_ids = connection[:, 1] + if Ra.size == 1: + Ra_pre = Ra + Ra_post = Ra + else: + assert Ra.shape[-1] == n_compartment, (f'The length of Ra should be equal to ' + f'the number of compartments. Got {Ra.shape}.') + Ra_pre = Ra[..., pre_ids] + Ra_post = Ra[..., post_ids] + if L.size == 1: + L_pre = L + L_post = L + else: + assert L.shape[-1] == n_compartment, (f'The length of L should be equal to ' + f'the number of compartments. Got {L.shape}.') + L_pre = L[..., pre_ids] + L_post = L[..., post_ids] + if diam.size == 1: + diam_pre = diam + diam_post = diam + else: + assert diam.shape[-1] == n_compartment, (f'The length of diam should be equal to the ' + f'number of compartments. Got {diam.shape}.') + diam_pre = diam[..., pre_ids] + diam_post = diam[..., post_ids] + + weights = 0.5 * ( + 4 * Ra_pre * L_pre / (np.pi * diam_pre ** 2) + + 4 * Ra_post * L_post / (np.pi * diam_post ** 2) + ) + return weights class MultiCompartment(HHTypedNeuron): - __module__ = 'dendritex.neurons' - - def __init__( - self, - size: bst.typing.Size, - - # morphology parameters - connection: Sequence[Tuple[int, int]] | np.ndarray, - - # neuron parameters - Ra: bst.typing.ArrayLike = 100. * (bu.ohm * bu.cm), - cm: bst.typing.ArrayLike = 1.0 * (bu.uF / bu.cm ** 2), - diam: bst.typing.ArrayLike = 1. * bu.um, - L: bst.typing.ArrayLike = 10. * bu.um, - - # membrane potentials - V_th: Union[bst.typing.ArrayLike, Callable] = 0. * bu.mV, - V_initializer: Union[bst.typing.ArrayLike, Callable] = bst.init.Uniform(-70 * bu.mV, -60. * bu.mV), - spk_fun: Callable = bst.surrogate.ReluGrad(), - - # others - name: Optional[str] = None, - mode: Optional[bst.mixin.Mode] = None, - **ion_channels - ): - super().__init__(size, mode=mode, name=name, **ion_channels) - - # neuronal parameters - self.Ra = bst.init.param(Ra, self.varshape) - self.cm = bst.init.param(cm, self.varshape) - self.diam = bst.init.param(diam, self.varshape) - self.L = bst.init.param(L, self.varshape) - self.A = np.pi * self.diam * self.L # surface area - - # parameters for morphology - connection = np.asarray(connection) - assert connection.shape[1] == 2, 'The connection should be a sequence of tuples with two elements.' - self.connection = np.unique( - np.sort( - connection, - axis=1, # avoid off duplicated connections, for example (1, 2) vs (2, 1) - ), - axis=0 # avoid of duplicated connections, for example (1, 2) vs (1, 2) - ) - if self.connection.max() >= self.n_compartment: - raise ValueError('The connection should be within the range of compartments. ' - f'But we got {self.connection.max()} >= {self.n_compartment}.') - self.resistances = init_coupling_weight(self.n_compartment, connection, self.diam, self.L, self.Ra) - - # parameters for membrane potentials - self.V_th = V_th - self._V_initializer = V_initializer - self.spk_fun = spk_fun - - def init_state(self, batch_size=None): - self.V = State4Integral(bst.init.param(self._V_initializer, self.varshape, batch_size)) - super().init_state(batch_size) - - def reset_state(self, batch_size=None): - self.V.value = bst.init.param(self._V_initializer, self.varshape, batch_size) - super().reset_state(batch_size) - - def before_integral(self, *args): - channels = self.nodes(level=1, include_self=False).subset(IonChannel) - for node in channels.values(): - node.before_integral(self.V.value) - - def compute_derivative(self, I_ext=0. * bu.nA): - # [ Compute the derivative of membrane potential ] - # 1. external currents - I_ext = I_ext / self.A - # 1.axial currents - I_axial = diffusive_coupling(self.V.value, self.connection, self.resistances) / self.A - # 2. synapse currents - I_syn = self.sum_current_inputs(self.V.value, init=0. * bu.nA / bu.cm ** 2) - # 3. channel currents - I_channel = None - for ch in self.nodes(level=1, include_self=False).subset(IonChannel).values(): - I_channel = ch.current(self.V.value) if I_channel is None else (I_channel + ch.current(self.V.value)) - # 4. derivatives - self.V.derivative = (I_ext + I_axial + I_syn + I_channel) / self.cm - - # [ integrate dynamics of ion and ion channels ] - # check whether the children channels have the correct parents. - channels = self.nodes(level=1, include_self=False).subset(IonChannel) - for node in channels.values(): - node.compute_derivative(self.V.value) - - def after_integral(self, *args): - self.V.value = self.sum_delta_inputs(init=self.V.value) - channels = self.nodes(level=1, include_self=False).subset(IonChannel) - for node in channels.values(): - node.after_integral(self.V.value) + __module__ = 'dendritex.neurons' + + def __init__( + self, + size: bst.typing.Size, + + # morphology parameters + connection: Sequence[Tuple[int, int]] | np.ndarray, + + # neuron parameters + Ra: bst.typing.ArrayLike = 100. * (bu.ohm * bu.cm), + cm: bst.typing.ArrayLike = 1.0 * (bu.uF / bu.cm ** 2), + diam: bst.typing.ArrayLike = 1. * bu.um, + L: bst.typing.ArrayLike = 10. * bu.um, + + # membrane potentials + V_th: Union[bst.typing.ArrayLike, Callable] = 0. * bu.mV, + V_initializer: Union[bst.typing.ArrayLike, Callable] = bst.init.Uniform(-70 * bu.mV, -60. * bu.mV), + spk_fun: Callable = bst.surrogate.ReluGrad(), + + # others + name: Optional[str] = None, + **ion_channels + ): + super().__init__(size, name=name, **ion_channels) + + # neuronal parameters + self.Ra = bst.init.param(Ra, self.varshape) + self.cm = bst.init.param(cm, self.varshape) + self.diam = bst.init.param(diam, self.varshape) + self.L = bst.init.param(L, self.varshape) + self.A = np.pi * self.diam * self.L # surface area + + # parameters for morphology + connection = np.asarray(connection) + assert connection.shape[1] == 2, 'The connection should be a sequence of tuples with two elements.' + self.connection = np.unique( + np.sort( + connection, + axis=1, # avoid off duplicated connections, for example (1, 2) vs (2, 1) + ), + axis=0 # avoid of duplicated connections, for example (1, 2) vs (1, 2) + ) + if self.connection.max() >= self.n_compartment: + raise ValueError('The connection should be within the range of compartments. ' + f'But we got {self.connection.max()} >= {self.n_compartment}.') + self.resistances = init_coupling_weight(self.n_compartment, connection, self.diam, self.L, self.Ra) + + # parameters for membrane potentials + self.V_th = V_th + self.V_initializer = V_initializer + self.spk_fun = spk_fun + + def init_state(self, batch_size=None): + self.V = State4Integral(bst.init.param(self.V_initializer, self.varshape, batch_size)) + self._v_last_time = None + super().init_state(batch_size) + + def reset_state(self, batch_size=None): + self.V.value = bst.init.param(self.V_initializer, self.varshape, batch_size) + self._v_last_time = None + super().reset_state(batch_size) + + def before_integral(self, *args): + self._v_last_time = self.V.value + channels = self.nodes(allowed_hierarchy=(1, 1)).filter(IonChannel) + for node in channels.values(): + node.before_integral(self.V.value) + + def compute_derivative(self, I_ext=0. * bu.nA): + # [ Compute the derivative of membrane potential ] + # 1. external currents + I_ext = I_ext / self.A + # 1.axial currents + I_axial = diffusive_coupling(self.V.value, self.connection, self.resistances) / self.A + # 2. synapse currents + I_syn = self.sum_current_inputs(0. * bu.nA / bu.cm ** 2, self.V.value) + # 3. channel currents + I_channel = None + for ch in self.nodes(allowed_hierarchy=(1, 1)).filter(IonChannel).values(): + I_channel = ch.current(self.V.value) if I_channel is None else (I_channel + ch.current(self.V.value)) + # 4. derivatives + self.V.derivative = (I_ext + I_axial + I_syn + I_channel) / self.cm + + # [ integrate dynamics of ion and ion channels ] + # check whether the children channels have the correct parents. + channels = self.nodes(allowed_hierarchy=(1, 1)).filter(IonChannel) + for node in channels.values(): + node.compute_derivative(self.V.value) + + def post_derivative(self, *args): + self.V.value = self.sum_delta_inputs(init=self.V.value) + channels = self.nodes(allowed_hierarchy=(1, 1)).filter(IonChannel) + for node in channels.values(): + node.post_derivative(self.V.value) + + def update(self, *args): + return self.get_spike() + + def get_spike(self): + if not hasattr(self, '_v_last_time'): + raise ValueError("The membrane potential is not initialized.") + if self._v_last_time is None: + raise ValueError("The membrane potential is not initialized.") + return self.spk_fun(self.V.value - self.V_th) * self.spk_fun(self.V_th - self._v_last_time) diff --git a/dendritex/neurons/single_compartment.py b/dendritex/neurons/single_compartment.py index 7f95420..9b0601b 100644 --- a/dendritex/neurons/single_compartment.py +++ b/dendritex/neurons/single_compartment.py @@ -20,105 +20,112 @@ import brainstate as bst import brainunit as u -from .._base import HHTypedNeuron, IonChannel, State4Integral +from dendritex._base import HHTypedNeuron, IonChannel, State4Integral __all__ = [ - 'SingleCompartment', + 'SingleCompartment', ] class SingleCompartment(HHTypedNeuron): - r""" - Base class to model conductance-based neuron group. - - The standard formulation for a conductance-based model is given as - - .. math:: - - C_m {dV \over dt} = \sum_jg_j(E - V) + I_{ext} - - where :math:`g_j=\bar{g}_{j} M^x N^y` is the channel conductance, :math:`E` is the - reversal potential, :math:`M` is the activation variable, and :math:`N` is the - inactivation variable. - - :math:`M` and :math:`N` have the dynamics of - - .. math:: - - {dx \over dt} = \phi_x {x_\infty (V) - x \over \tau_x(V)} - - where :math:`x \in [M, N]`, :math:`\phi_x` is a temperature-dependent factor, - :math:`x_\infty` is the steady state, and :math:`\tau_x` is the time constant. - Equivalently, the above equation can be written as: - - .. math:: - - \frac{d x}{d t}=\phi_{x}\left(\alpha_{x}(1-x)-\beta_{x} x\right) - - where :math:`\alpha_{x}` and :math:`\beta_{x}` are rate constants. - - - Parameters - ---------- - size : int, sequence of int - The network size of this neuron group. - name : optional, str - The neuron group name. - """ - __module__ = 'dendritex.neurons' - - def __init__( - self, - size: bst.typing.Size, - C: Union[bst.typing.ArrayLike, Callable] = 1. * u.uF / u.cm ** 2, - V_th: Union[bst.typing.ArrayLike, Callable] = 0. * u.mV, - V_initializer: Union[bst.typing.ArrayLike, Callable] = bst.init.Uniform(-70 * u.mV, -60. * u.mV), - spk_fun: Callable = bst.surrogate.ReluGrad(), - name: Optional[str] = None, - mode: Optional[bst.mixin.Mode] = None, - **ion_channels - ): - super().__init__(size, mode=mode, name=name, **ion_channels) - - # parameters for neurons - assert self.n_compartment == 1, (f'Point-based neuron only supports single compartment. ' - f'But got {self.n_compartment} compartments.') - self.C = C - self.V_th = V_th - self._V_initializer = V_initializer - self.spk_fun = spk_fun - - def init_state(self, batch_size=None): - self.V = State4Integral(bst.init.param(self._V_initializer, self.varshape, batch_size)) - super().init_state(batch_size) - - def reset_state(self, batch_size=None): - self.V.value = bst.init.param(self._V_initializer, self.varshape, batch_size) - super().init_state(batch_size) - - def before_integral(self, *args): - channels = self.nodes(level=1, include_self=False).subset(IonChannel) - for node in channels.values(): - node.before_integral(self.V.value) - - def compute_derivative(self, x=0. * u.nA / u.cm ** 2): - # [ Compute the derivative of membrane potential ] - # 1. inputs + 2. synapses - x = self.sum_current_inputs(self.V.value, init=x) - # 3. channels - for ch in self.nodes(level=1, include_self=False).subset(IonChannel).values(): - x = x + ch.current(self.V.value) - # 4. derivatives - self.V.derivative = x / self.C - - # [ integrate dynamics of ion and ion channels ] - # check whether the children channels have the correct parents. - channels = self.nodes(level=1, include_self=False).subset(IonChannel) - for node in channels.values(): - node.compute_derivative(self.V.value) - - def after_integral(self, *args): - self.V.value = self.sum_delta_inputs(init=self.V.value) - channels = self.nodes(level=1, include_self=False).subset(IonChannel) - for node in channels.values(): - node.after_integral(self.V.value) + r""" + Base class to model conductance-based neuron group. + + The standard formulation for a conductance-based model is given as + + .. math:: + + C_m {dV \over dt} = \sum_jg_j(E - V) + I_{ext} + + where :math:`g_j=\bar{g}_{j} M^x N^y` is the channel conductance, :math:`E` is the + reversal potential, :math:`M` is the activation variable, and :math:`N` is the + inactivation variable. + + :math:`M` and :math:`N` have the dynamics of + + .. math:: + + {dx \over dt} = \phi_x {x_\infty (V) - x \over \tau_x(V)} + + where :math:`x \in [M, N]`, :math:`\phi_x` is a temperature-dependent factor, + :math:`x_\infty` is the steady state, and :math:`\tau_x` is the time constant. + Equivalently, the above equation can be written as: + + .. math:: + + \frac{d x}{d t}=\phi_{x}\left(\alpha_{x}(1-x)-\beta_{x} x\right) + + where :math:`\alpha_{x}` and :math:`\beta_{x}` are rate constants. + + + Parameters + ---------- + size : int, sequence of int + The network size of this neuron group. + name : optional, str + The neuron group name. + """ + __module__ = 'dendritex.neurons' + + def __init__( + self, + size: bst.typing.Size, + C: Union[bst.typing.ArrayLike, Callable] = 1. * u.uF / u.cm ** 2, + V_th: Union[bst.typing.ArrayLike, Callable] = 0. * u.mV, + V_initializer: Union[bst.typing.ArrayLike, Callable] = bst.init.Uniform(-70 * u.mV, -60. * u.mV), + spk_fun: Callable = bst.surrogate.ReluGrad(), + name: Optional[str] = None, + **ion_channels + ): + super().__init__(size, name=name, **ion_channels) + self.C = bst.init.param(C, self.varshape) + self.V_th = bst.init.param(V_th, self.varshape) + self.V_initializer = V_initializer + self.spk_fun = spk_fun + + def init_state(self, batch_size=None): + self.V = State4Integral(bst.init.param(self.V_initializer, self.varshape, batch_size)) + self._v_last_time = None + super().init_state(batch_size) + + def reset_state(self, batch_size=None): + self.V.value = bst.init.param(self.V_initializer, self.varshape, batch_size) + self._v_last_time = None + super().init_state(batch_size) + + def before_integral(self, *args): + self._v_last_time = self.V.value + for node in self.nodes(allowed_hierarchy=(1, 1)).filter(IonChannel).values(): + node.before_integral(self.V.value) + + def compute_derivative(self, x=0. * u.nA / u.cm ** 2): + # [ Compute the derivative of membrane potential ] + # 1. inputs + 2. synapses + x = self.sum_current_inputs(x, self.V.value) + + # 3. channels + for ch in self.nodes(allowed_hierarchy=(1, 1)).filter(IonChannel).values(): + x = x + ch.current(self.V.value) + + # 4. derivatives + self.V.derivative = x / self.C + + # [ integrate dynamics of ion and ion channels ] + # check whether the children channels have the correct parents. + for node in self.nodes(allowed_hierarchy=(1, 1)).filter(IonChannel).values(): + node.compute_derivative(self.V.value) + + def post_derivative(self, *args): + self.V.value = self.sum_delta_inputs(init=self.V.value) + for node in self.nodes(allowed_hierarchy=(1, 1)).filter(IonChannel).values(): + node.post_derivative(self.V.value) + + def update(self, *args): + return self.get_spike() + + def get_spike(self): + if not hasattr(self, '_v_last_time'): + raise ValueError("The membrane potential is not initialized.") + if self._v_last_time is None: + raise ValueError("The membrane potential is not initialized.") + return self.spk_fun(self.V.value - self.V_th) * self.spk_fun(self.V_th - self._v_last_time) diff --git a/examples/fitting_a_hh_neuron/main.py b/examples/fitting_a_hh_neuron/main.py index 220723f..17f9c04 100644 --- a/examples/fitting_a_hh_neuron/main.py +++ b/examples/fitting_a_hh_neuron/main.py @@ -38,148 +38,148 @@ class INa(dx.Channel): - root_type = dx.HHTypedNeuron - - def __init__( - self, - size: bst.typing.Size, - ENa: Union[bst.typing.ArrayLike, Callable] = 50. * u.mV, - gNa: Union[bst.typing.ArrayLike, Callable] = 120. * u.mS, - vth: Union[bst.typing.ArrayLike, Callable] = -63 * u.mV, - ): - super().__init__(size) - self.ENa = bst.init.param(ENa, self.varshape) - self.gNa = bst.init.param(gNa, self.varshape) - self.V_th = bst.init.param(vth, self.varshape) - - def init_state(self, V, batch_size=None): - self.m = dx.State4Integral(bst.init.param(u.math.zeros, self.varshape)) - self.h = dx.State4Integral(bst.init.param(u.math.zeros, self.varshape)) - - # m channel - m_alpha = lambda self, V: 0.32 * 4 / u.math.exprel((13. * u.mV - V + self.V_th).to_decimal(u.mV) / 4.) - m_beta = lambda self, V: 0.28 * 5 / u.math.exprel((V - self.V_th - 40. * u.mV).to_decimal(u.mV) / 5.) - m_inf = lambda self, V: self.m_alpha(V) / (self.m_alpha(V) + self.m_beta(V)) - - # h channel - h_alpha = lambda self, V: 0.128 * u.math.exprel((17. * u.mV - V + self.V_th).to_decimal(u.mV) / 18.) - h_beta = lambda self, V: 4. / (1 + u.math.exp((40. * u.mV - V + self.V_th).to_decimal(u.mV) / 5.)) - h_inf = lambda self, V: self.h_alpha(V) / (self.h_alpha(V) + self.h_beta(V)) - - def compute_derivative(self, V, *args, **kwargs): - m = self.m.value - h = self.h.value - self.m.derivative = (self.m_alpha(V) * (1 - m) - self.m_beta(V) * m) / u.ms - self.h.derivative = (self.h_alpha(V) * (1 - h) - self.h_beta(V) * h) / u.ms - - def current(self, V, *args, **kwargs): - m = self.m.value - h = self.h.value - return (self.gNa * m * m * m * h) * (self.ENa - V) + root_type = dx.HHTypedNeuron + + def __init__( + self, + size: bst.typing.Size, + ENa: Union[bst.typing.ArrayLike, Callable] = 50. * u.mV, + gNa: Union[bst.typing.ArrayLike, Callable] = 120. * u.mS, + vth: Union[bst.typing.ArrayLike, Callable] = -63 * u.mV, + ): + super().__init__(size) + self.ENa = bst.init.param(ENa, self.varshape) + self.gNa = bst.init.param(gNa, self.varshape) + self.V_th = bst.init.param(vth, self.varshape) + + def init_state(self, V, batch_size=None): + self.m = dx.State4Integral(bst.init.param(u.math.zeros, self.varshape)) + self.h = dx.State4Integral(bst.init.param(u.math.zeros, self.varshape)) + + # m channel + m_alpha = lambda self, V: 0.32 * 4 / u.math.exprel((13. * u.mV - V + self.V_th).to_decimal(u.mV) / 4.) + m_beta = lambda self, V: 0.28 * 5 / u.math.exprel((V - self.V_th - 40. * u.mV).to_decimal(u.mV) / 5.) + m_inf = lambda self, V: self.m_alpha(V) / (self.m_alpha(V) + self.m_beta(V)) + + # h channel + h_alpha = lambda self, V: 0.128 * u.math.exprel((17. * u.mV - V + self.V_th).to_decimal(u.mV) / 18.) + h_beta = lambda self, V: 4. / (1 + u.math.exp((40. * u.mV - V + self.V_th).to_decimal(u.mV) / 5.)) + h_inf = lambda self, V: self.h_alpha(V) / (self.h_alpha(V) + self.h_beta(V)) + + def compute_derivative(self, V, *args, **kwargs): + m = self.m.value + h = self.h.value + self.m.derivative = (self.m_alpha(V) * (1 - m) - self.m_beta(V) * m) / u.ms + self.h.derivative = (self.h_alpha(V) * (1 - h) - self.h_beta(V) * h) / u.ms + + def current(self, V, *args, **kwargs): + m = self.m.value + h = self.h.value + return (self.gNa * m * m * m * h) * (self.ENa - V) class IK(dx.Channel): - root_type = dx.HHTypedNeuron + root_type = dx.HHTypedNeuron - def __init__( - self, - size: bst.typing.Size, - EK: Union[bst.typing.ArrayLike, Callable] = -90. * u.mV, - gK: Union[bst.typing.ArrayLike, Callable] = 36. * u.mS, - vth: Union[bst.typing.ArrayLike, Callable] = -63 * u.mV, - ): - super().__init__(size) - self.EK = bst.init.param(EK, self.varshape) - self.gK = bst.init.param(gK, self.varshape) - self.V_th = bst.init.param(vth, self.varshape) + def __init__( + self, + size: bst.typing.Size, + EK: Union[bst.typing.ArrayLike, Callable] = -90. * u.mV, + gK: Union[bst.typing.ArrayLike, Callable] = 36. * u.mS, + vth: Union[bst.typing.ArrayLike, Callable] = -63 * u.mV, + ): + super().__init__(size) + self.EK = bst.init.param(EK, self.varshape) + self.gK = bst.init.param(gK, self.varshape) + self.V_th = bst.init.param(vth, self.varshape) - def init_state(self, V, batch_size=None): - self.n = dx.State4Integral(bst.init.param(u.math.zeros, self.varshape)) + def init_state(self, V, batch_size=None): + self.n = dx.State4Integral(bst.init.param(u.math.zeros, self.varshape)) - # n channel - n_alpha = lambda self, V: 0.032 * 5 / u.math.exprel((15. * u.mV - V + self.V_th).to_decimal(u.mV) / 5.) - n_beta = lambda self, V: .5 * u.math.exp((10. * u.mV - V + self.V_th).to_decimal(u.mV) / 40.) - n_inf = lambda self, V: self.n_alpha(V) / (self.n_alpha(V) + self.n_beta(V)) + # n channel + n_alpha = lambda self, V: 0.032 * 5 / u.math.exprel((15. * u.mV - V + self.V_th).to_decimal(u.mV) / 5.) + n_beta = lambda self, V: .5 * u.math.exp((10. * u.mV - V + self.V_th).to_decimal(u.mV) / 40.) + n_inf = lambda self, V: self.n_alpha(V) / (self.n_alpha(V) + self.n_beta(V)) - def compute_derivative(self, V, *args, **kwargs): - n = self.n.value - self.n.derivative = (self.n_alpha(V) * (1 - n) - self.n_beta(V) * n) / u.ms + def compute_derivative(self, V, *args, **kwargs): + n = self.n.value + self.n.derivative = (self.n_alpha(V) * (1 - n) - self.n_beta(V) * n) / u.ms - def current(self, V, *args, **kwargs): - n2 = self.n.value ** 2 - return (self.gK * n2 * n2) * (self.EK - V) + def current(self, V, *args, **kwargs): + n2 = self.n.value ** 2 + return (self.gK * n2 * n2) * (self.EK - V) class HH(dx.neurons.SingleCompartment): - def __init__( - self, - size, - v_initializer: Callable = bst.init.Uniform(-70 * u.mV, -60. * u.mV), - gL: Union[bst.typing.ArrayLike, Callable] = 0.003 * u.mS, - gNa: Union[bst.typing.ArrayLike, Callable] = 120. * u.mS, - gK: Union[bst.typing.ArrayLike, Callable] = 36. * u.mS, - C: Union[bst.typing.ArrayLike, Callable] = 1. * (u.uF / u.cm ** 2) - ): - super().__init__(size, V_initializer=v_initializer, C=C) - self.ina = INa(size, gNa=gNa) - self.ik = IK(size, gK=gK) - self.il = dx.channels.IL(size, g_max=gL, E=-65. * u.mV) + def __init__( + self, + size, + v_initializer: Callable = bst.init.Uniform(-70 * u.mV, -60. * u.mV), + gL: Union[bst.typing.ArrayLike, Callable] = 0.003 * u.mS, + gNa: Union[bst.typing.ArrayLike, Callable] = 120. * u.mS, + gK: Union[bst.typing.ArrayLike, Callable] = 36. * u.mS, + C: Union[bst.typing.ArrayLike, Callable] = 1. * (u.uF / u.cm ** 2) + ): + super().__init__(size, V_initializer=v_initializer, C=C) + self.ina = INa(size, gNa=gNa) + self.ik = IK(size, gK=gK) + self.il = dx.channels.IL(size, g_max=gL, E=-65. * u.mV) def visualize_target(voltages): - fig, gs = bts.visualize.get_figure(2, voltages.shape[1], 3, 4.5) - times = np.arange(voltages.shape[0]) * 0.01 - for i in range(voltages.shape[1]): - ax = fig.add_subplot(gs[0, i]) - ax.plot(times, voltages.mantissa[:, i], label='target') - plt.xlabel('Time [ms]') - plt.legend() - ax = plt.subplot(gs[1, i]) - ax.plot(times, inp_traces[i].mantissa) - plt.xlabel('Time [ms]') - plt.show() + fig, gs = bts.visualize.get_figure(2, voltages.shape[1], 3, 4.5) + times = np.arange(voltages.shape[0]) * 0.01 + for i in range(voltages.shape[1]): + ax = fig.add_subplot(gs[0, i]) + ax.plot(times, voltages.mantissa[:, i], label='target') + plt.xlabel('Time [ms]') + plt.legend() + ax = plt.subplot(gs[1, i]) + ax.plot(times, inp_traces[i].mantissa) + plt.xlabel('Time [ms]') + plt.show() def visualize(voltages, gl, g_na, g_kd, C): - # currents: [T, B] - # voltages: [T, B] - simulated_vs = simulate_model(gl, g_na, g_kd, C) - voltages = voltages.mantissa - simulated_vs = simulated_vs.mantissa - - fig, gs = bts.visualize.get_figure(2, simulated_vs.shape[1], 3, 4.5) - for i in range(simulated_vs.shape[1]): - ax = fig.add_subplot(gs[0, i]) - ax.plot(voltages[:, i], label='target') - ax.plot(simulated_vs[:, i], label='simulated') - plt.legend() - ax = plt.subplot(gs[1, i]) - ax.plot(inp_traces[i].mantissa) - plt.show() + # currents: [T, B] + # voltages: [T, B] + simulated_vs = simulate_model(gl, g_na, g_kd, C) + voltages = voltages.mantissa + simulated_vs = simulated_vs.mantissa + + fig, gs = bts.visualize.get_figure(2, simulated_vs.shape[1], 3, 4.5) + for i in range(simulated_vs.shape[1]): + ax = fig.add_subplot(gs[0, i]) + ax.plot(voltages[:, i], label='target') + ax.plot(simulated_vs[:, i], label='simulated') + plt.legend() + ax = plt.subplot(gs[1, i]) + ax.plot(inp_traces[i].mantissa) + plt.show() def simulate_model(gl, g_na, g_kd, C): - current = inp_traces.T - assert current.ndim == 2 # [T, B] - n_input = current.shape[1] - hh = HH((n_input, 1), gL=gl, gNa=g_na, gK=g_kd, C=C, v_initializer=bst.init.Constant(-65. * u.mV), ) - hh.init_state() + current = inp_traces.T + assert current.ndim == 2 # [T, B] + n_input = current.shape[1] + hh = HH((n_input, 1), gL=gl, gNa=g_na, gK=g_kd, C=C, v_initializer=bst.init.Constant(-65. * u.mV), ) + hh.init_state() - def step_fun(i, inp): - with bst.environ.context(i=i, t=bst.environ.get_dt() * i): - dx.rk4_step(hh, bst.environ.get('t'), inp) - return hh.V.value + def step_fun(i, inp): + with bst.environ.context(i=i, t=bst.environ.get_dt() * i): + dx.rk4_step(hh, bst.environ.get('t'), inp) + return hh.V.value - indices = np.arange(current.shape[0]) - current = u.math.expand_dims(current, axis=-1) # [T, B, 1] - return bst.transform.for_loop(step_fun, indices, current) # (T, B) + indices = np.arange(current.shape[0]) + current = u.math.expand_dims(current, axis=-1) # [T, B, 1] + return bst.compile.for_loop(step_fun, indices, current) # (T, B) -@bst.transform.jit +@bst.compile.jit def compare_potentials(param): - vs = simulate_model(param['gl'], param['g_na'], param['g_kd'], param['C']) # (T, B) - losses = bts.metric.squared_error(vs.mantissa, target_vs.mantissa) - return losses.mean() + vs = simulate_model(param['gl'], param['g_na'], param['g_kd'], param['C']) # (T, B) + losses = bts.metric.squared_error(vs.mantissa, target_vs.mantissa) + return losses.mean() # inp_traces: [B, T] @@ -195,61 +195,60 @@ def compare_potentials(param): def visualize_hh_input_and_output(): - # Load Input and Output Data - inp_traces = df_inp_traces.to_numpy() - inp_traces = inp_traces[:, 1:] * 1e9 + # Load Input and Output Data + inp_traces = df_inp_traces.to_numpy() + inp_traces = inp_traces[:, 1:] * 1e9 - out_traces = df_out_traces.to_numpy() - out_traces = out_traces[:, 1:] + out_traces = df_out_traces.to_numpy() + out_traces = out_traces[:, 1:] - indices = np.arange(inp_traces.shape[1]) * 0.01 + indices = np.arange(inp_traces.shape[1]) * 0.01 - fig, gs = bts.visualize.get_figure(3, 1, 1.2, 6.0) - ax = fig.add_subplot(gs[0, 0]) - ax.plot(indices, inp_traces.T) - plt.xticks([]) - plt.ylabel('Current [nA]', fontsize=13) + fig, gs = bts.visualize.get_figure(3, 1, 1.2, 6.0) + ax = fig.add_subplot(gs[0, 0]) + ax.plot(indices, inp_traces.T) + plt.xticks([]) + plt.ylabel('Current [nA]', fontsize=13) - ax2 = fig.add_subplot(gs[1:, 0]) - ax2.plot(indices, out_traces.T) - plt.ylabel('Potential [mV]', fontsize=13) - plt.xlabel('Time [ms]') + ax2 = fig.add_subplot(gs[1:, 0]) + ax2.plot(indices, out_traces.T) + plt.ylabel('Potential [mV]', fontsize=13) + plt.xlabel('Time [ms]') - fig.align_ylabels([ax, ax2]) - plt.show() + fig.align_ylabels([ax, ax2]) + plt.show() bounds = { - 'gl': [1e0, 1e2] * u.nS, - 'g_na': [1e0, 2e2] * u.uS, - 'g_kd': [1e0, 1e2] * u.uS, - 'C': [0.1, 2] * u.uF * u.cm ** -2 * area, + 'gl': [1e0, 1e2] * u.nS, + 'g_na': [1e0, 2e2] * u.uS, + 'g_kd': [1e0, 1e2] * u.uS, + 'C': [0.1, 2] * u.uF * u.cm ** -2 * area, } def fitting_by_others(method='DE', n_sample=200, n_iter=20): - print(f"Method: {method}, n_sample: {n_sample}") - - @jax.jit - @jax.vmap - @jax.jit - def loss_with_multiple_run(**params): - return compare_potentials(params) - - opt = bts.optim.NevergradOptimizer( - loss_with_multiple_run, - n_sample=n_sample, - bounds=bounds, - method=method, - ) - opt.initialize() - param = opt.minimize(n_iter) - loss = compare_potentials(param) - print(param) - print(loss) - visualize(target_vs, **param) - return param, loss + print(f"Method: {method}, n_sample: {n_sample}") + + @jax.jit + @jax.vmap + def loss_with_multiple_run(**params): + return compare_potentials(params) + + opt = bts.optim.NevergradOptimizer( + loss_with_multiple_run, + n_sample=n_sample, + bounds=bounds, + method=method, + ) + opt.initialize() + param = opt.minimize(n_iter) + loss = compare_potentials(param) + print(param) + print(loss) + visualize(target_vs, **param) + return param, loss if __name__ == '__main__': - fitting_by_others(n_sample=100) + fitting_by_others(n_sample=100) diff --git a/examples/golgi_model/golgi.ipynb b/examples/golgi_model/golgi.ipynb index cd6008e..a107c39 100644 --- a/examples/golgi_model/golgi.ipynb +++ b/examples/golgi_model/golgi.ipynb @@ -6,8 +6,8 @@ "metadata": {}, "outputs": [], "source": [ - "import sys\n", "import os\n", + "import sys\n", "\n", "current_dir = os.path.dirname(os.path.abspath('.'))\n", "project_root = os.path.abspath(os.path.join(current_dir, '..', '..'))\n", @@ -38,14 +38,14 @@ "loaded_params = np.load('golgi_morphology.npz')\n", "\n", "connection = loaded_params['connection']\n", - "L = loaded_params['L'] # um\n", - "diam = loaded_params['diam'] # um\n", - "Ra = loaded_params['Ra'] # ohm * cm\n", - "cm = loaded_params['cm'] # uF / cm ** 2\n", + "L = loaded_params['L'] # um\n", + "diam = loaded_params['diam'] # um\n", + "Ra = loaded_params['Ra'] # ohm * cm\n", + "cm = loaded_params['cm'] # uF / cm ** 2\n", "\n", "n_neuron = 1\n", "n_compartments = len(L)\n", - "size = (n_neuron,n_compartments)\n", + "size = (n_neuron, n_compartments)\n", "\n", "index_soma = loaded_params['index_soma']\n", "index_axon = loaded_params['index_axon']\n", @@ -55,12 +55,12 @@ "## conductvalues \n", "conductvalues = 1e3 * np.array([\n", "\n", - " 0.00499506303209, 0.01016375552607, 0.00247172479141, 0.00128859564935,\n", - " 3.690771983E-05, 0.0080938853146, 0.01226052748146, 0.01650689958385,\n", - " 0.00139885617712, 0.14927733727426, 0.00549507510519, 0.14910988921938,\n", - " 0.00406420380423, 0.01764345789036, 0.10177335775222, 0.0087689418803,\n", - " 3.407734319E-05, 0.0003371456442, 0.00030643090764, 0.17233663543619,\n", - " 0.00024381226198, 0.10008178886943, 0.00595046001148, 0.0115, 0.0091\n", + " 0.00499506303209, 0.01016375552607, 0.00247172479141, 0.00128859564935,\n", + " 3.690771983E-05, 0.0080938853146, 0.01226052748146, 0.01650689958385,\n", + " 0.00139885617712, 0.14927733727426, 0.00549507510519, 0.14910988921938,\n", + " 0.00406420380423, 0.01764345789036, 0.10177335775222, 0.0087689418803,\n", + " 3.407734319E-05, 0.0003371456442, 0.00030643090764, 0.17233663543619,\n", + " 0.00024381226198, 0.10008178886943, 0.00595046001148, 0.0115, 0.0091\n", "])\n", "\n", "## IL \n", @@ -91,19 +91,19 @@ "gcagrc[index_axon[0]] = conductvalues[22]\n", "\n", "## ICav23_Ma2020\n", - "gcav23 = np.zeros(n_compartments)\n", + "gcav23 = np.zeros(n_compartments)\n", "gcav23[index_dend_apical] = conductvalues[3]\n", "\n", "## ICav31_Ma2020 \n", - "gcav31 = np.zeros(n_compartments)\n", + "gcav31 = np.zeros(n_compartments)\n", "gcav31[index_soma] = conductvalues[16]\n", "gcav31[index_dend_apical] = conductvalues[4]\n", "\n", "## INa_Rsg\n", "gnarsg = np.zeros(n_compartments)\n", - "gnarsg [index_soma] = conductvalues[9]\n", + "gnarsg[index_soma] = conductvalues[9]\n", "gnarsg[index_dend_apical] = conductvalues[0]\n", - "gnarsg[index_dend_basal] = conductvalues[5]\n", + "gnarsg[index_dend_basal] = conductvalues[5]\n", "gnarsg[index_axon[0]] = conductvalues[19]\n", "gnarsg[index_axon[1:]] = 11.5\n", "\n", @@ -117,7 +117,7 @@ "\n", "## IKca3_1_Ma2020 \n", "gkca31 = np.zeros(n_compartments)\n", - "gkca31[index_soma] = conductvalues[14]" + "gkca31[index_soma] = conductvalues[14]" ] }, { @@ -127,9 +127,9 @@ "outputs": [], "source": [ "# single ion test\n", - "connection = ((1,2),(2,3))\n", - "gl = np.zeros(n_compartments) \n", - "g_test = np.zeros(n_compartments) \n", + "connection = ((1, 2), (2, 3))\n", + "gl = np.zeros(n_compartments)\n", + "g_test = np.zeros(n_compartments)\n", "g_test[index_soma] = 2.5e-4" ] }, @@ -140,27 +140,27 @@ "outputs": [], "source": [ "class Golgi(dx.neurons.MultiCompartment):\n", - " def __init__(self, size, connection, Ra, cm, diam, L, gl, g_test):\n", - " super().__init__(\n", - " size=size,\n", - " connection=connection,\n", - " Ra=Ra * u.ohm * u.cm,\n", - " cm=cm * u.uF / u.cm ** 2,\n", - " diam=diam * u.um,\n", - " L=L * u.um,\n", - " V_th=20. * u.mV,\n", - " V_initializer=bst.init.Constant(-55 * u.mV),\n", - " spk_fun=bst.surrogate.ReluGrad(),\n", - " )\n", - "\n", - " self.IL = dx.channels.IL(self.size, E=-55. * u.mV, g_max=gl * u.mS / (u.cm ** 2))\n", - " #self.Ih1 = dx.channels.Ih1_Ma2020(self.size, E=-20. * u.mV, g_max=g_test * u.mS / (u.cm ** 2))\n", - " self.k = dx.ions.PotassiumFixed(self.size, E=-80. * u.mV)\n", - " self.k.add_elem(dx.channels.IKM_Grc_Ma2020(self.size, g_max=g_test * u.mS / (u.cm ** 2)))\n", - " #self.ca = dx.ions.CalciumFixed(self.size, E=137.* u.mV, C =5e-5 * u.mM)\n", - " #self.ca.add_elem(dx.channels.ICav31_Ma2020(self.size, g_max=g_test * (u.cm / u.second)))\n", - " #self.kca = dx.MixIons(self.k, self.ca)\n", - " #self.kca.add_elem(dx.channels.IKca1_1_Ma2020(self.size, g_max=g_test * u.mS / (u.cm ** 2)))" + " def __init__(self, size, connection, Ra, cm, diam, L, gl, g_test):\n", + " super().__init__(\n", + " size=size,\n", + " connection=connection,\n", + " Ra=Ra * u.ohm * u.cm,\n", + " cm=cm * u.uF / u.cm ** 2,\n", + " diam=diam * u.um,\n", + " L=L * u.um,\n", + " V_th=20. * u.mV,\n", + " V_initializer=bst.init.Constant(-55 * u.mV),\n", + " spk_fun=bst.surrogate.ReluGrad(),\n", + " )\n", + "\n", + " self.IL = dx.channels.IL(self.size, E=-55. * u.mV, g_max=gl * u.mS / (u.cm ** 2))\n", + " #self.Ih1 = dx.channels.Ih1_Ma2020(self.size, E=-20. * u.mV, g_max=g_test * u.mS / (u.cm ** 2))\n", + " self.k = dx.ions.PotassiumFixed(self.size, E=-80. * u.mV)\n", + " self.k.add_elem(dx.channels.IKM_Grc_Ma2020(self.size, g_max=g_test * u.mS / (u.cm ** 2)))\n", + " #self.ca = dx.ions.CalciumFixed(self.size, E=137.* u.mV, C =5e-5 * u.mM)\n", + " #self.ca.add_elem(dx.channels.ICav31_Ma2020(self.size, g_max=g_test * (u.cm / u.second)))\n", + " #self.kca = dx.MixIons(self.k, self.ca)\n", + " #self.kca.add_elem(dx.channels.IKca1_1_Ma2020(self.size, g_max=g_test * u.mS / (u.cm ** 2)))" ] }, { @@ -194,94 +194,96 @@ "source": [ "@bst.transform.jit(static_argnums=6)\n", "def simulate(Ra, cm, diam, L, gl, gkv11, method='ieuler'):\n", - " cell = Golgi(size, connection, Ra, cm, diam, L, gl, gkv11)\n", - " cell.init_state()\n", - " cell.reset_state()\n", - "\n", - " def step(t, *args):\n", - " inp_a = np.full((n_neuron, n_compartments), 0.) * u.nA\n", - " inp_a[..., 0] = 0.002 * u.nA\n", - " inp_b = np.full((n_neuron, n_compartments), 0.) * u.nA\n", - " cell.compute_derivative(u.math.where(t < 100 * u.ms, inp_a, inp_b))\n", - "\n", - " def save(t, *args):\n", - " return cell.V.value\n", - "\n", - " with jax.ensure_compile_time_eval():\n", - " dt = 0.01 * u.ms\n", - " ts = u.math.arange(0. * u.ms, 200. * u.ms, dt)\n", - " ts, ys, steps = dx.diffrax_solve(\n", - " step, method, 0. * u.ms, 200. * u.ms, dt, ts,\n", - " savefn=save, atol=1e-5, # max_steps=200000,\n", - " )\n", + " cell = Golgi(size, connection, Ra, cm, diam, L, gl, gkv11)\n", + " cell.init_state()\n", + " cell.reset_state()\n", + "\n", + " def step(t, *args):\n", + " inp_a = np.full((n_neuron, n_compartments), 0.) * u.nA\n", + " inp_a[..., 0] = 0.002 * u.nA\n", + " inp_b = np.full((n_neuron, n_compartments), 0.) * u.nA\n", + " cell.compute_derivative(u.math.where(t < 100 * u.ms, inp_a, inp_b))\n", + "\n", + " def save(t, *args):\n", + " return cell.V.value\n", + "\n", + " with jax.ensure_compile_time_eval():\n", + " dt = 0.01 * u.ms\n", + " ts = u.math.arange(0. * u.ms, 200. * u.ms, dt)\n", + " ts, ys, steps = dx.diffrax_solve(\n", + " step, method, 0. * u.ms, 200. * u.ms, dt, ts,\n", + " savefn=save, atol=1e-5, # max_steps=200000,\n", + " )\n", "\n", - " return ts, ys, steps\n", + " return ts, ys, steps\n", "\n", "\n", "@bst.transform.jit\n", "def simulate2(Ra, cm, diam, L, gl, gkv11):\n", - " cell = Golgi(size, connection, Ra, cm, diam, L, gl, gkv11)\n", - " cell.init_state()\n", - " cell.reset_state()\n", - "\n", - " def step_run(t):\n", - " inp_a = np.full((n_neuron, n_compartments), 0.) * u.nA\n", - " inp_a[..., 0] = 0.002 * u.nA\n", - " inp_b = np.full((n_neuron, n_compartments), 0.) * u.nA\n", - " inp = u.math.where(t < 100 * u.ms, inp_a, inp_b)\n", - " dx.rk4_step(cell, t, inp)\n", - " return cell.V.value\n", - "\n", - " with jax.ensure_compile_time_eval():\n", - " dt = 0.001 * u.ms\n", - " ts = u.math.arange(0. * u.ms, 200. * u.ms, dt)\n", - " with bst.environ.context(dt=dt):\n", - " ys = bst.transform.for_loop(step_run, ts)\n", - " return ts, ys[::10], ts.size\n", + " cell = Golgi(size, connection, Ra, cm, diam, L, gl, gkv11)\n", + " cell.init_state()\n", + " cell.reset_state()\n", + "\n", + " def step_run(t):\n", + " inp_a = np.full((n_neuron, n_compartments), 0.) * u.nA\n", + " inp_a[..., 0] = 0.002 * u.nA\n", + " inp_b = np.full((n_neuron, n_compartments), 0.) * u.nA\n", + " inp = u.math.where(t < 100 * u.ms, inp_a, inp_b)\n", + " dx.rk4_step(cell, t, inp)\n", + " return cell.V.value\n", + "\n", + " with jax.ensure_compile_time_eval():\n", + " dt = 0.001 * u.ms\n", + " ts = u.math.arange(0. * u.ms, 200. * u.ms, dt)\n", + " with bst.environ.context(dt=dt):\n", + " ys = bst.transform.for_loop(step_run, ts)\n", + " return ts, ys[::10], ts.size\n", + "\n", "\n", "def visualize_a_simulate(Ra, cm, diam, L, gl, gkv11):\n", - " t0 = time.time()\n", - " ts, ys_kvaerno5, steps = simulate(Ra, cm, diam, L, gl, gkv11, 'kvaerno5')\n", - " print('kvaerno5', steps)\n", - " print('time', time.time() - t0)\n", - "\n", - " t0 = time.time()\n", - " ts, ys_ieuler, steps = simulate(Ra, cm, diam, L, gl, gkv11, 'ieuler')\n", - " print('ieuler', steps)\n", - " print('time', time.time() - t0)\n", - "\n", - " '''\n", - " t0 = time.time()\n", - " ts, ys_tsit5, steps = simulate(Ra, cm, diam, L, gl, gkv11, 'tsit5')\n", - " print('tsit5', steps)\n", - " print('time', time.time() - t0)\n", - "\n", - " t0 = time.time()\n", - " ts, ys_dopri5, steps = simulate(Ra, cm, diam, L, gl, gkv11, 'dopri5')\n", - " print('dopri5', steps)\n", - " print('time', time.time() - t0)\n", - " '''\n", - " t0 = time.time()\n", - " ts2, ys2_rk4, steps = simulate2(Ra, cm, diam, L, gl, gkv11)\n", - " print('rk4', steps)\n", - " print('time', time.time() - t0)\n", - " \n", - " def plot(data, ax, title):\n", - " ax.plot(ts.to_decimal(u.ms), u.math.squeeze(data.to_decimal(u.mV)))\n", - " plt.xlabel('Time [ms]')\n", - " plt.ylabel('Potential [mV]')\n", - " plt.title(title)\n", - "\n", - " fig, gs = bts.visualize.get_figure(4, 1, 6.0, 8.0)\n", - " plot(ys_kvaerno5 - ys_ieuler, fig.add_subplot(gs[0, 0]), 'kvaerno5 - ieuler')\n", - " #plot(ys_ieuler - ys_tsit5, fig.add_subplot(gs[0, 1]), 'ieuler - tsit5')\n", - " #plot(ys_tsit5 - ys_dopri5, fig.add_subplot(gs[0, 2]), 'tsit5 - dopri5')\n", - " #plot(ys_dopri5 - ys2_rk4, fig.add_subplot(gs[0, 3]), 'dopri5 - rk4')\n", - " plot(ys2_rk4 - ys_ieuler, fig.add_subplot(gs[1, 0]), 'rk4 - ieuler')\n", - " plot(ys_ieuler, fig.add_subplot(gs[2, 0]), 'ieuler')\n", - " plot(ys_kvaerno5, fig.add_subplot(gs[3, 0]), 'kvaerno5')\n", - "\n", - " plt.show()\n", + " t0 = time.time()\n", + " ts, ys_kvaerno5, steps = simulate(Ra, cm, diam, L, gl, gkv11, 'kvaerno5')\n", + " print('kvaerno5', steps)\n", + " print('time', time.time() - t0)\n", + "\n", + " t0 = time.time()\n", + " ts, ys_ieuler, steps = simulate(Ra, cm, diam, L, gl, gkv11, 'ieuler')\n", + " print('ieuler', steps)\n", + " print('time', time.time() - t0)\n", + "\n", + " '''\n", + " t0 = time.time()\n", + " ts, ys_tsit5, steps = simulate(Ra, cm, diam, L, gl, gkv11, 'tsit5')\n", + " print('tsit5', steps)\n", + " print('time', time.time() - t0)\n", + "\n", + " t0 = time.time()\n", + " ts, ys_dopri5, steps = simulate(Ra, cm, diam, L, gl, gkv11, 'dopri5')\n", + " print('dopri5', steps)\n", + " print('time', time.time() - t0)\n", + " '''\n", + " t0 = time.time()\n", + " ts2, ys2_rk4, steps = simulate2(Ra, cm, diam, L, gl, gkv11)\n", + " print('rk4', steps)\n", + " print('time', time.time() - t0)\n", + "\n", + " def plot(data, ax, title):\n", + " ax.plot(ts.to_decimal(u.ms), u.math.squeeze(data.to_decimal(u.mV)))\n", + " plt.xlabel('Time [ms]')\n", + " plt.ylabel('Potential [mV]')\n", + " plt.title(title)\n", + "\n", + " fig, gs = bts.visualize.get_figure(4, 1, 6.0, 8.0)\n", + " plot(ys_kvaerno5 - ys_ieuler, fig.add_subplot(gs[0, 0]), 'kvaerno5 - ieuler')\n", + " #plot(ys_ieuler - ys_tsit5, fig.add_subplot(gs[0, 1]), 'ieuler - tsit5')\n", + " #plot(ys_tsit5 - ys_dopri5, fig.add_subplot(gs[0, 2]), 'tsit5 - dopri5')\n", + " #plot(ys_dopri5 - ys2_rk4, fig.add_subplot(gs[0, 3]), 'dopri5 - rk4')\n", + " plot(ys2_rk4 - ys_ieuler, fig.add_subplot(gs[1, 0]), 'rk4 - ieuler')\n", + " plot(ys_ieuler, fig.add_subplot(gs[2, 0]), 'ieuler')\n", + " plot(ys_kvaerno5, fig.add_subplot(gs[3, 0]), 'kvaerno5')\n", + "\n", + " plt.show()\n", + "\n", "\n", "visualize_a_simulate(Ra, cm, diam, L, gl, g_test)" ] diff --git a/examples/golgi_model/golgi.py b/examples/golgi_model/golgi.py index 27ec1f5..6986798 100644 --- a/examples/golgi_model/golgi.py +++ b/examples/golgi_model/golgi.py @@ -14,6 +14,7 @@ # ============================================================================== import time + import brainstate as bst import braintools as bts import brainunit as u @@ -46,12 +47,12 @@ ## conductvalues conductvalues = 1e3 * np.array([ - 0.00499506303209, 0.01016375552607, 0.00247172479141, 0.00128859564935, - 3.690771983E-05, 0.0080938853146, 0.01226052748146, 0.01650689958385, - 0.00139885617712, 0.14927733727426, 0.00549507510519, 0.14910988921938, - 0.00406420380423, 0.01764345789036, 0.10177335775222, 0.0087689418803, - 3.407734319E-05, 0.0003371456442, 0.00030643090764, 0.17233663543619, - 0.00024381226198, 0.10008178886943, 0.00595046001148, 0.0115, 0.0091 + 0.00499506303209, 0.01016375552607, 0.00247172479141, 0.00128859564935, + 3.690771983E-05, 0.0080938853146, 0.01226052748146, 0.01650689958385, + 0.00139885617712, 0.14927733727426, 0.00549507510519, 0.14910988921938, + 0.00406420380423, 0.01764345789036, 0.10177335775222, 0.0087689418803, + 3.407734319E-05, 0.0003371456442, 0.00030643090764, 0.17233663543619, + 0.00024381226198, 0.10008178886943, 0.00595046001148, 0.0115, 0.0091 ]) ## IL @@ -82,19 +83,19 @@ gcagrc[index_axon[0]] = conductvalues[22] ## ICav23_Ma2020 -gcav23 = np.zeros(n_compartments) +gcav23 = np.zeros(n_compartments) gcav23[index_dend_apical] = conductvalues[3] ## ICav31_Ma2020 -gcav31 = np.zeros(n_compartments) +gcav31 = np.zeros(n_compartments) gcav31[index_soma] = conductvalues[16] gcav31[index_dend_apical] = conductvalues[4] ## INa_Rsg gnarsg = np.zeros(n_compartments) -gnarsg [index_soma] = conductvalues[9] +gnarsg[index_soma] = conductvalues[9] gnarsg[index_dend_apical] = conductvalues[0] -gnarsg[index_dend_basal] = conductvalues[5] +gnarsg[index_dend_basal] = conductvalues[5] gnarsg[index_axon[0]] = conductvalues[19] gnarsg[index_axon[1:]] = 11.5 @@ -108,131 +109,132 @@ ## IKca3_1_Ma2020 gkca31 = np.zeros(n_compartments) -gkca31[index_soma] = conductvalues[14] +gkca31[index_soma] = conductvalues[14] ## IKca2_2_Ma2020 gkca22 = np.zeros(n_compartments) -gkca31[index_dend_apical] = conductvalues[2] -gkca31[index_dend_basal] = conductvalues[7] +gkca31[index_dend_apical] = conductvalues[2] +gkca31[index_dend_basal] = conductvalues[7] ## IKca1_1_Ma2020 gkca11 = np.zeros(n_compartments) -gkca11[index_soma]= conductvalues[13] -gkca11[index_dend_apical]= conductvalues[1] -gkca11[index_dend_basal]= conductvalues[6] +gkca11[index_soma] = conductvalues[13] +gkca11[index_dend_apical] = conductvalues[1] +gkca11[index_dend_basal] = conductvalues[6] gkca11[index_axon[0]] = conductvalues[21] gkca11[index_axon[1:]] = conductvalues[13] ## IKM_Grc_Ma2020 -gkmgrc = np.zeros(n_compartments) +gkmgrc = np.zeros(n_compartments) gkmgrc[index_axon[0]] = conductvalues[20] -class Golgi(dx.neurons.MultiCompartment): - def __init__(self, size, connection, Ra, cm, diam, L, gl, gkv11): - super().__init__( - size=size, - connection=connection, - Ra=Ra * u.ohm * u.cm, - cm=cm * u.uF / u.cm ** 2, - diam=diam * u.um, - L=L * u.um, - V_th=20. * u.mV, - V_initializer=bst.init.Constant(-55 * u.mV), - spk_fun=bst.surrogate.ReluGrad(), - ) - - self.IL = dx.channels.IL(self.size, E=-55. * u.mV, g_max=gl * u.mS / (u.cm ** 2)) - self.k = dx.ions.PotassiumFixed(self.size, E=-80. * u.mV) - self.k.add_elem(dx.channels.IKv11_Ak2007(self.size, g_max=gkv11 * u.mS / (u.cm ** 2))) - -@bst.transform.jit(static_argnums=6) +class Golgi(dx.neurons.MultiCompartment): + def __init__(self, size, connection, Ra, cm, diam, L, gl, gkv11): + super().__init__( + size=size, + connection=connection, + Ra=Ra * u.ohm * u.cm, + cm=cm * u.uF / u.cm ** 2, + diam=diam * u.um, + L=L * u.um, + V_th=20. * u.mV, + V_initializer=bst.init.Constant(-55 * u.mV), + spk_fun=bst.surrogate.ReluGrad(), + ) + + self.IL = dx.channels.IL(self.size, E=-55. * u.mV, g_max=gl * u.mS / (u.cm ** 2)) + self.k = dx.ions.PotassiumFixed(self.size, E=-80. * u.mV) + self.k.add_elem(k=dx.channels.IKv11_Ak2007(self.size, g_max=gkv11 * u.mS / (u.cm ** 2))) + + +@bst.compile.jit(static_argnums=6) def simulate(Ra, cm, diam, L, gl, gkv11, method='ieuler'): - cell = Golgi(size, connection, Ra, cm, diam, L, gl, gkv11) - cell.init_state() - cell.reset_state() - - def step(t, *args): - inp_a = np.full((n_neuron, n_compartments), 0.) * u.nA - inp_a[..., 30] = 0.02 * u.nA - inp_b = np.full((n_neuron, n_compartments), 0.) * u.nA - cell.compute_derivative(u.math.where(t < 100 * u.ms, inp_a, inp_b)) - - def save(t, *args): - return cell.V.value - - with jax.ensure_compile_time_eval(): - dt = 0.01 * u.ms - ts = u.math.arange(0. * u.ms, 200. * u.ms, dt) - ts, ys, steps = dx.diffrax_solve( - step, method, 0. * u.ms, 200. * u.ms, dt, ts, - savefn=save, atol=1e-5, # max_steps=200000, - ) + cell = Golgi(size, connection, Ra, cm, diam, L, gl, gkv11) + cell.init_state() + cell.reset_state() + + def step(t, *args): + inp_a = np.full((n_neuron, n_compartments), 0.) * u.nA + inp_a[..., 30] = 0.02 * u.nA + inp_b = np.full((n_neuron, n_compartments), 0.) * u.nA + cell.compute_derivative(u.math.where(t < 100 * u.ms, inp_a, inp_b)) + + def save(t, *args): + return cell.V.value + + with jax.ensure_compile_time_eval(): + dt = 0.01 * u.ms + ts = u.math.arange(0. * u.ms, 200. * u.ms, dt) + ts, ys, steps = dx.diffrax_solve( + step, method, 0. * u.ms, 200. * u.ms, dt, ts, + savefn=save, atol=1e-5, # max_steps=200000, + ) - return ts, ys, steps + return ts, ys, steps -@bst.transform.jit +@bst.compile.jit def simulate2(Ra, cm, diam, L, gl, gkv11): - cell = Golgi(size, connection, Ra, cm, diam, L, gl, gkv11) - cell.init_state() - cell.reset_state() - - def step_run(t): - inp_a = np.full((n_neuron, n_compartments), 0.) * u.nA - inp_a[..., 30] = 0.02 * u.nA - inp_b = np.full((n_neuron, n_compartments), 0.) * u.nA - inp = u.math.where(t < 100 * u.ms, inp_a, inp_b) - dx.rk4_step(cell, t, inp) - return cell.V.value - - with jax.ensure_compile_time_eval(): - dt = 0.001 * u.ms - ts = u.math.arange(0. * u.ms, 200. * u.ms, dt) - with bst.environ.context(dt=dt): - ys = bst.transform.for_loop(step_run, ts) - return ts, ys[::10], ts.size + cell = Golgi(size, connection, Ra, cm, diam, L, gl, gkv11) + cell.init_state() + cell.reset_state() + + def step_run(t): + inp_a = np.full((n_neuron, n_compartments), 0.) * u.nA + inp_a[..., 30] = 0.02 * u.nA + inp_b = np.full((n_neuron, n_compartments), 0.) * u.nA + inp = u.math.where(t < 100 * u.ms, inp_a, inp_b) + dx.rk4_step(cell, t, inp) + return cell.V.value + + with jax.ensure_compile_time_eval(): + dt = 0.001 * u.ms + ts = u.math.arange(0. * u.ms, 200. * u.ms, dt) + with bst.environ.context(dt=dt): + ys = bst.compile.for_loop(step_run, ts) + return ts, ys[::10], ts.size + def visualize_a_simulate(Ra, cm, diam, L, gl, gkv11): - t0 = time.time() - ts, ys_kvaerno5, steps = simulate(Ra, cm, diam, L, gl, gkv11, 'kvaerno5') - print('kvaerno5', steps) - print('time', time.time() - t0) - - t0 = time.time() - ts, ys_ieuler, steps = simulate(Ra, cm, diam, L, gl, gkv11, 'ieuler') - print('ieuler', steps) - print('time', time.time() - t0) - - t0 = time.time() - ts, ys_tsit5, steps = simulate(Ra, cm, diam, L, gl, gkv11, 'tsit5') - print('tsit5', steps) - print('time', time.time() - t0) - - t0 = time.time() - ts, ys_dopri5, steps = simulate(Ra, cm, diam, L, gl, gkv11, 'dopri5') - print('dopri5', steps) - print('time', time.time() - t0) - - t0 = time.time() - ts2, ys2_rk4, steps = simulate2(Ra, cm, diam, L, gl, gkv11) - print('rk4', steps) - print('time', time.time() - t0) - - def plot(data, ax, title): - ax.plot(ts.to_decimal(u.ms), u.math.squeeze(data.to_decimal(u.mV))) - plt.xlabel('Time [ms]') - plt.ylabel('Potential [mV]') - plt.title(title) - - fig, gs = bts.visualize.get_figure(1, 5, 3.0, 4.0) - plot(ys_kvaerno5 - ys_ieuler, fig.add_subplot(gs[0, 0]), 'kvaerno5 - ieuler') - plot(ys_ieuler - ys_tsit5, fig.add_subplot(gs[0, 1]), 'ieuler - tsit5') - plot(ys_tsit5 - ys_dopri5, fig.add_subplot(gs[0, 2]), 'tsit5 - dopri5') - plot(ys_dopri5 - ys2_rk4, fig.add_subplot(gs[0, 3]), 'dopri5 - rk4') - plot(ys_ieuler, fig.add_subplot(gs[0, 4]), 'ieuler') - plt.show() + t0 = time.time() + ts, ys_kvaerno5, steps = simulate(Ra, cm, diam, L, gl, gkv11, 'kvaerno5') + print('kvaerno5', steps) + print('time', time.time() - t0) + + t0 = time.time() + ts, ys_ieuler, steps = simulate(Ra, cm, diam, L, gl, gkv11, 'ieuler') + print('ieuler', steps) + print('time', time.time() - t0) + + t0 = time.time() + ts, ys_tsit5, steps = simulate(Ra, cm, diam, L, gl, gkv11, 'tsit5') + print('tsit5', steps) + print('time', time.time() - t0) + + t0 = time.time() + ts, ys_dopri5, steps = simulate(Ra, cm, diam, L, gl, gkv11, 'dopri5') + print('dopri5', steps) + print('time', time.time() - t0) + + t0 = time.time() + ts2, ys2_rk4, steps = simulate2(Ra, cm, diam, L, gl, gkv11) + print('rk4', steps) + print('time', time.time() - t0) + + def plot(data, ax, title): + ax.plot(ts.to_decimal(u.ms), u.math.squeeze(data.to_decimal(u.mV))) + plt.xlabel('Time [ms]') + plt.ylabel('Potential [mV]') + plt.title(title) + + fig, gs = bts.visualize.get_figure(1, 5, 3.0, 4.0) + plot(ys_kvaerno5 - ys_ieuler, fig.add_subplot(gs[0, 0]), 'kvaerno5 - ieuler') + plot(ys_ieuler - ys_tsit5, fig.add_subplot(gs[0, 1]), 'ieuler - tsit5') + plot(ys_tsit5 - ys_dopri5, fig.add_subplot(gs[0, 2]), 'tsit5 - dopri5') + plot(ys_dopri5 - ys2_rk4, fig.add_subplot(gs[0, 3]), 'dopri5 - rk4') + plot(ys_ieuler, fig.add_subplot(gs[0, 4]), 'ieuler') + plt.show() visualize_a_simulate(Ra, cm, diam, L, gl, gkv11) - diff --git a/examples/golgi_model/golgi_mophology.py b/examples/golgi_model/golgi_mophology.py index 6ffd5c3..34c8ef0 100644 --- a/examples/golgi_model/golgi_mophology.py +++ b/examples/golgi_model/golgi_mophology.py @@ -14,14 +14,14 @@ # ============================================================================== import numpy as np + loaded_params = np.load('golgi_morphology.npz') connection = loaded_params['connection'] -L = loaded_params['L'] # um -diam = loaded_params['diam'] # um -Ra = loaded_params['Ra'] # ohm * cm -cm = loaded_params['cm'] # uF / cm ** 2 - +L = loaded_params['L'] # um +diam = loaded_params['diam'] # um +Ra = loaded_params['Ra'] # ohm * cm +cm = loaded_params['cm'] # uF / cm ** 2 index_soma = loaded_params['index_soma'] index_axon = loaded_params['index_axon'] diff --git a/examples/hh_neuron.py b/examples/hh_neuron.py index ae0826e..180d0df 100644 --- a/examples/hh_neuron.py +++ b/examples/hh_neuron.py @@ -14,39 +14,39 @@ # ============================================================================== import brainstate as bst -import brainunit as bu +import brainunit as u import matplotlib.pyplot as plt import dendritex as dx -bst.environ.set(dt=0.01 * bu.ms) +bst.environ.set(dt=0.01 * u.ms) class HH(dx.neurons.SingleCompartment): - def __init__(self, size): - super().__init__(size) + def __init__(self, size): + super().__init__(size) - self.na = dx.ions.SodiumFixed(size, E=50. * bu.mV) - self.na.add_elem(dx.channels.INa_HH1952(size)) + self.na = dx.ions.SodiumFixed(size, E=50. * u.mV) + self.na.add_elem(INa=dx.channels.INa_HH1952(size)) - self.k = dx.ions.PotassiumFixed(size, E=-77. * bu.mV) - self.k.add_elem(dx.channels.IK_HH1952(size)) + self.k = dx.ions.PotassiumFixed(size, E=-77. * u.mV) + self.k.add_elem(IK=dx.channels.IK_HH1952(size)) - self.IL = dx.channels.IL(size, E=-54.387 * bu.mV, g_max=0.03 * (bu.mS / bu.cm ** 2)) + self.IL = dx.channels.IL(size, E=-54.387 * u.mV, g_max=0.03 * (u.mS / u.cm ** 2)) - def step_fun(self, t): - # dx.euler_step(hh, t, 10 * u.nA) - # dx.rk2_step(hh, t, 10 * u.nA) - # dx.rk3_step(hh, t, 10 * u.nA) - dx.rk4_step(self, t, 10 * bu.nA) - return self.V.value + def step_fun(self, t): + # dx.euler_step(hh, t, 10 * u.nA) + # dx.rk2_step(hh, t, 10 * u.nA) + # dx.rk3_step(hh, t, 10 * u.nA) + dx.rk2_step(self, t, 10 * u.nA / u.cm ** 2) + return self.V.value hh = HH([1, 1]) hh.init_state() -times = bu.math.arange(10000) * bst.environ.get_dt() -vs = bst.transform.for_loop(hh.step_fun, times) +times = u.math.arange(10000) * bst.environ.get_dt() +vs = bst.compile.for_loop(hh.step_fun, times) -plt.plot(times.to_decimal(bu.ms), bu.math.squeeze(vs.to_decimal(bu.mV))) +plt.plot(times, u.math.squeeze(vs)) plt.show() diff --git a/examples/simple_dendrite_model.py b/examples/simple_dendrite_model.py index 76d8a1c..39b4de9 100644 --- a/examples/simple_dendrite_model.py +++ b/examples/simple_dendrite_model.py @@ -22,75 +22,74 @@ class INa(dx.channels.SodiumChannel): - def __init__(self, size, g_max): - super().__init__(size) + def __init__(self, size, g_max): + super().__init__(size) - self.g_max = bst.init.param(g_max, self.varshape) + self.g_max = bst.init.param(g_max, self.varshape) - def init_state(self, V, Na: dx.IonInfo, batch_size: int = None): - self.m = dx.State4Integral(self.m_inf(V)) - self.h = dx.State4Integral(self.h_inf(V)) + def init_state(self, V, Na: dx.IonInfo, batch_size: int = None): + self.m = dx.State4Integral(self.m_inf(V)) + self.h = dx.State4Integral(self.h_inf(V)) - def compute_derivative(self, V, Na: dx.IonInfo): - self.m.derivative = (self.m_alpha(V) * (1 - self.m.value) - self.m_beta(V) * self.m.value) / u.ms - self.h.derivative = (self.h_alpha(V) * (1 - self.h.value) - self.h_beta(V) * self.h.value) / u.ms + def compute_derivative(self, V, Na: dx.IonInfo): + self.m.derivative = (self.m_alpha(V) * (1 - self.m.value) - self.m_beta(V) * self.m.value) / u.ms + self.h.derivative = (self.h_alpha(V) * (1 - self.h.value) - self.h_beta(V) * self.h.value) / u.ms - def current(self, V, Na: dx.IonInfo): - return self.g_max * self.m.value ** 3 * self.h.value * (Na.E - V) + def current(self, V, Na: dx.IonInfo): + return self.g_max * self.m.value ** 3 * self.h.value * (Na.E - V) - # m channel - m_alpha = lambda self, V: 1. / u.math.exprel(-(V / u.mV + 40.) / 10.) # nan - m_beta = lambda self, V: 4. * u.math.exp(-(V / u.mV + 65.) / 18.) - m_inf = lambda self, V: self.m_alpha(V) / (self.m_alpha(V) + self.m_beta(V)) + # m channel + m_alpha = lambda self, V: 1. / u.math.exprel(-(V / u.mV + 40.) / 10.) # nan + m_beta = lambda self, V: 4. * u.math.exp(-(V / u.mV + 65.) / 18.) + m_inf = lambda self, V: self.m_alpha(V) / (self.m_alpha(V) + self.m_beta(V)) - # h channel - h_alpha = lambda self, V: 0.07 * u.math.exp(-(V / u.mV + 65.) / 20.) - h_beta = lambda self, V: 1. / (1. + u.math.exp(-(V / u.mV + 35.) / 10.)) - h_inf = lambda self, V: self.h_alpha(V) / (self.h_alpha(V) + self.h_beta(V)) + # h channel + h_alpha = lambda self, V: 0.07 * u.math.exp(-(V / u.mV + 65.) / 20.) + h_beta = lambda self, V: 1. / (1. + u.math.exp(-(V / u.mV + 35.) / 10.)) + h_inf = lambda self, V: self.h_alpha(V) / (self.h_alpha(V) + self.h_beta(V)) class IK(dx.channels.PotassiumChannel): - def __init__(self, size, g_max): - super().__init__(size) - self.g_max = bst.init.param(g_max, self.varshape) + def __init__(self, size, g_max): + super().__init__(size) + self.g_max = bst.init.param(g_max, self.varshape) - def init_state(self, V, K: dx.IonInfo, batch_size: int = None): - self.n = dx.State4Integral(self.n_inf(V)) + def init_state(self, V, K: dx.IonInfo, batch_size: int = None): + self.n = dx.State4Integral(self.n_inf(V)) - def compute_derivative(self, V, K: dx.IonInfo): - self.n.derivative = (self.n_alpha(V) * (1 - self.n.value) - self.n_beta(V) * self.n.value) / u.ms + def compute_derivative(self, V, K: dx.IonInfo): + self.n.derivative = (self.n_alpha(V) * (1 - self.n.value) - self.n_beta(V) * self.n.value) / u.ms - def current(self, V, K: dx.IonInfo): - return self.g_max * self.n.value ** 4 * (K.E - V) + def current(self, V, K: dx.IonInfo): + return self.g_max * self.n.value ** 4 * (K.E - V) - n_alpha = lambda self, V: 0.1 / u.math.exprel(-(V / u.mV + 55.) / 10.) - n_beta = lambda self, V: 0.125 * u.math.exp(-(V / u.mV + 65.) / 80.) - n_inf = lambda self, V: self.n_alpha(V) / (self.n_alpha(V) + self.n_beta(V)) + n_alpha = lambda self, V: 0.1 / u.math.exprel(-(V / u.mV + 55.) / 10.) + n_beta = lambda self, V: 0.125 * u.math.exp(-(V / u.mV + 65.) / 80.) + n_inf = lambda self, V: self.n_alpha(V) / (self.n_alpha(V) + self.n_beta(V)) class ThreeCompartmentHH(dx.neurons.MultiCompartment): - def __init__(self, n_neuron: int, g_na, g_k): - super().__init__( - size=(n_neuron, 3), - connection=((0, 1), (1, 2)), - Ra=100. * u.ohm * u.cm, - cm=1.0 * u.uF / u.cm ** 2, - diam=(12.6157, 1., 1.) * u.um, - L=(12.6157, 200., 400.) * u.um, - V_th=20. * u.mV, - V_initializer=bst.init.Constant(-65 * u.mV), - spk_fun=bst.surrogate.ReluGrad(), - ) - - self.IL = dx.channels.IL(self.size, E=(-54.3, -65., -65.) * u.mV, g_max=[0.0003, 0.001, 0.001] * s) - - self.na = dx.ions.SodiumFixed(self.size, E=50. * u.mV) - self.na.add_elem(INa(self.size, g_max=(g_na, 0., 0.) * s)) - - self.k = dx.ions.PotassiumFixed(self.size, E=-77. * u.mV) - self.k.add_elem(IK(self.size, g_max=(g_k, 0., 0.) * s)) - - def step_run(self, t, inp): - dx.rk4_step(self, t, inp) - return self.V.value, self.spike.value - + def __init__(self, n_neuron: int, g_na, g_k): + super().__init__( + size=(n_neuron, 3), + connection=((0, 1), (1, 2)), + Ra=100. * u.ohm * u.cm, + cm=1.0 * u.uF / u.cm ** 2, + diam=(12.6157, 1., 1.) * u.um, + L=(12.6157, 200., 400.) * u.um, + V_th=20. * u.mV, + V_initializer=bst.init.Constant(-65 * u.mV), + spk_fun=bst.surrogate.ReluGrad(), + ) + + self.IL = dx.channels.IL(self.size, E=(-54.3, -65., -65.) * u.mV, g_max=[0.0003, 0.001, 0.001] * s) + + self.na = dx.ions.SodiumFixed(self.size, E=50. * u.mV) + self.na.add_elem(INa=INa(self.size, g_max=(g_na, 0., 0.) * s)) + + self.k = dx.ions.PotassiumFixed(self.size, E=-77. * u.mV) + self.k.add_elem(IK=IK(self.size, g_max=(g_k, 0., 0.) * s)) + + def step_run(self, t, inp): + dx.rk4_step(self, t, inp) + return self.V.value, self.spike.value diff --git a/examples/simple_dendrite_model_fitting_by_adam.py b/examples/simple_dendrite_model_fitting_by_adam.py index 76861d4..7ad2dc7 100644 --- a/examples/simple_dendrite_model_fitting_by_adam.py +++ b/examples/simple_dendrite_model_fitting_by_adam.py @@ -27,114 +27,114 @@ def visualize_a_simulate(params, f_current, show=True, title=''): - saveat = u.math.arange(0., 100., 0.1) * u.ms - ts, vs, _ = solve_explicit_solver(params, f_current, saveat) + saveat = u.math.arange(0., 100., 0.1) * u.ms + ts, vs, _ = solve_explicit_solver(params, f_current, saveat) - fig, gs = bts.visualize.get_figure(1, 1, 3.0, 4.0) - ax = fig.add_subplot(gs[0, 0]) - plt.plot(ts.to_decimal(u.ms), u.math.squeeze(vs.to_decimal(u.mV))) - plt.xlabel('Time [ms]') - plt.ylabel('Potential [mV]') - if title: - plt.title(title) - if show: - plt.show() + fig, gs = bts.visualize.get_figure(1, 1, 3.0, 4.0) + ax = fig.add_subplot(gs[0, 0]) + plt.plot(ts.to_decimal(u.ms), u.math.squeeze(vs.to_decimal(u.mV))) + plt.xlabel('Time [ms]') + plt.ylabel('Potential [mV]') + if title: + plt.title(title) + if show: + plt.show() def fitting_example(): - t1 = 200 * u.ms - - # Step 1: generating input currents - saveat = u.math.arange(0., t1 / u.ms, 0.2) * u.ms - - def f_current(t, i_current, *args): - return jax.lax.switch( - i_current, - [ - lambda t: u.math.where(t < 50. * u.ms, - 0. * u.nA, - u.math.where(t < 100. * u.ms, 0.5 * u.nA, 0. * u.nA)), - lambda t: u.math.where(t < 60. * u.ms, - 0. * u.nA, - u.math.where(t < 160. * u.ms, 0.2 * u.nA, 0. * u.nA)), - lambda t: u.math.where(t < 80. * u.ms, - 0. * u.nA, - u.math.where(t < 160. * u.ms, 1.0 * u.nA, 0. * u.nA)), - lambda t: u.math.where(t < 100. * u.ms, - 0.2 * u.nA, - u.math.where(t < 150. * u.ms, 0.1 * u.nA, 0.3 * u.nA)), - ], # suppose there are 4 input currents - t - ) - - # Step 2: generating the target neuronal parameters - target_params = np.asarray([0.12, 0.036]) - - # Step 3: generating the target potentials - - def simulate_per_param(param): - indices = np.arange(4) - fun = functools.partial(solve_explicit_solver, param, f_current, saveat, t1) - _, simulated_vs, _ = jax.vmap(fun)((indices,)) # [n_input, T, n_compartment] - return simulated_vs - - target_vs = simulate_per_param(target_params) - - # Step 4: initialize a batch of parameters to optimize, - # these parameters are candidates to be optimized - bounds = [ - np.asarray([0.05, 0.01]), - np.asarray([0.2, 0.1]) - ] - n_batch = 8 - param_to_optimize = bst.ParamState(bst.random.uniform(bounds[0], bounds[1], (n_batch, bounds[0].size))) - - # Step 5: define the loss function and optimizers - - # calculate the loss for each parameter based on - # the mismatch between the 4 simulated and target potentials - def loss_per_param(param, step=10): - simulated_vs = simulate_per_param(param) # [n_input, T, n_compartment] - losses = bts.metric.squared_error(simulated_vs.mantissa[..., ::step, 0], target_vs.mantissa[..., ::step, 0]) - return losses.mean() - - # calculate the gradients and loss for each parameter - @jax.vmap - @jax.jit - def compute_grad(param): - grads, loss = bst.transform.grad(loss_per_param, argnums=0, return_value=True)(param) - return grads, loss - - # find the best loss and parameter in the batch - @bst.transform.jit - def best_loss_and_param(params, losses): - i_best = u.math.argmin(losses) - return losses[i_best], params[i_best] - - # define the optimizer - optimizer = bst.optim.Adam(lr=1e-3) - optimizer.register_trainable_weights({'param': param_to_optimize}) - - # Step 6: training - @bst.transform.jit - def train_step_per_epoch(): - grads, losses = compute_grad(param_to_optimize.value) - optimizer.update({'param': grads}) - return losses - - for i_epoch in range(1000): - losses = train_step_per_epoch() - best_loss, best_param = best_loss_and_param(param_to_optimize.value, losses) - if best_loss < 1e-5: - print(f'Epoch {i_epoch}, loss={losses.mean()}, best loss={best_loss}, best param={best_param}') - break - if i_epoch % 10 == 0: - print(f'Epoch {i_epoch}, loss={losses.mean()}, best loss={best_loss}, best param={best_param}') - - # Step 7: visualize the results - visualize_a_simulate(target_params, functools.partial(f_current, i_current=0), title='Target', show=False) - visualize_a_simulate(best_param, functools.partial(f_current, i_current=0), title='Fitted', show=True) + t1 = 200 * u.ms + + # Step 1: generating input currents + saveat = u.math.arange(0., t1 / u.ms, 0.2) * u.ms + + def f_current(t, i_current, *args): + return jax.lax.switch( + i_current, + [ + lambda t: u.math.where(t < 50. * u.ms, + 0. * u.nA, + u.math.where(t < 100. * u.ms, 0.5 * u.nA, 0. * u.nA)), + lambda t: u.math.where(t < 60. * u.ms, + 0. * u.nA, + u.math.where(t < 160. * u.ms, 0.2 * u.nA, 0. * u.nA)), + lambda t: u.math.where(t < 80. * u.ms, + 0. * u.nA, + u.math.where(t < 160. * u.ms, 1.0 * u.nA, 0. * u.nA)), + lambda t: u.math.where(t < 100. * u.ms, + 0.2 * u.nA, + u.math.where(t < 150. * u.ms, 0.1 * u.nA, 0.3 * u.nA)), + ], # suppose there are 4 input currents + t + ) + + # Step 2: generating the target neuronal parameters + target_params = np.asarray([0.12, 0.036]) + + # Step 3: generating the target potentials + + def simulate_per_param(param): + indices = np.arange(4) + fun = functools.partial(solve_explicit_solver, param, f_current, saveat, t1) + _, simulated_vs, _ = jax.vmap(fun)((indices,)) # [n_input, T, n_compartment] + return simulated_vs + + target_vs = simulate_per_param(target_params) + + # Step 4: initialize a batch of parameters to optimize, + # these parameters are candidates to be optimized + bounds = [ + np.asarray([0.05, 0.01]), + np.asarray([0.2, 0.1]) + ] + n_batch = 8 + param_to_optimize = bst.ParamState(bst.random.uniform(bounds[0], bounds[1], (n_batch, bounds[0].size))) + + # Step 5: define the loss function and optimizers + + # calculate the loss for each parameter based on + # the mismatch between the 4 simulated and target potentials + def loss_per_param(param, step=10): + simulated_vs = simulate_per_param(param) # [n_input, T, n_compartment] + losses = bts.metric.squared_error(simulated_vs.mantissa[..., ::step, 0], target_vs.mantissa[..., ::step, 0]) + return losses.mean() + + # calculate the gradients and loss for each parameter + @jax.vmap + @jax.jit + def compute_grad(param): + grads, loss = bst.augment.grad(loss_per_param, argnums=0, return_value=True)(param) + return grads, loss + + # find the best loss and parameter in the batch + @bst.compile.jit + def best_loss_and_param(params, losses): + i_best = u.math.argmin(losses) + return losses[i_best], params[i_best] + + # define the optimizer + optimizer = bst.optim.Adam(lr=1e-3) + optimizer.register_trainable_weights({'param': param_to_optimize}) + + # Step 6: training + @bst.compile.jit + def train_step_per_epoch(): + grads, losses = compute_grad(param_to_optimize.value) + optimizer.update({'param': grads}) + return losses + + for i_epoch in range(1000): + losses = train_step_per_epoch() + best_loss, best_param = best_loss_and_param(param_to_optimize.value, losses) + if best_loss < 1e-5: + print(f'Epoch {i_epoch}, loss={losses.mean()}, best loss={best_loss}, best param={best_param}') + break + if i_epoch % 10 == 0: + print(f'Epoch {i_epoch}, loss={losses.mean()}, best loss={best_loss}, best param={best_param}') + + # Step 7: visualize the results + visualize_a_simulate(target_params, functools.partial(f_current, i_current=0), title='Target', show=False) + visualize_a_simulate(best_param, functools.partial(f_current, i_current=0), title='Fitted', show=True) if __name__ == '__main__': - fitting_example() + fitting_example() diff --git a/examples/simple_dendrite_model_simulation.py b/examples/simple_dendrite_model_simulation.py index ab2aa3a..751a21c 100644 --- a/examples/simple_dendrite_model_simulation.py +++ b/examples/simple_dendrite_model_simulation.py @@ -34,22 +34,22 @@ def solve_explicit_solver( dt: u.Quantity = 0.01 * u.ms, method: str = 'tsit5' ): - hh = ThreeCompartmentHH(n_neuron=1, g_na=params[0], g_k=params[1]) - hh.init_state() + hh = ThreeCompartmentHH(n_neuron=1, g_na=params[0], g_k=params[1]) + hh.init_state() - def step(t, *args): - currents = f_current(t, *args) - hh.compute_derivative(currents) + def step(t, *args): + currents = f_current(t, *args) + hh.compute_derivative(currents) - def save(t, *args): - return hh.V.value + def save(t, *args): + return hh.V.value - ts, ys, steps = dx.diffrax_solve( - step, method, 0. * u.ms, t1, dt, - savefn=save, saveat=saveat, args=args, - adjoint=dfx.RecursiveCheckpointAdjoint(1000) - ) - return ts, ys, steps + ts, ys, steps = dx.diffrax_solve( + step, method, 0. * u.ms, t1, dt, + savefn=save, saveat=saveat, args=args, + adjoint=dfx.RecursiveCheckpointAdjoint(1000) + ) + return ts, ys, steps def adjoint_solve_explicit_solver( @@ -62,30 +62,30 @@ def adjoint_solve_explicit_solver( method: str = 'tsit5', max_steps: int = 100000 ): - hh = ThreeCompartmentHH(n_neuron=1, g_na=params[0], g_k=params[1]) - hh.init_state() + hh = ThreeCompartmentHH(n_neuron=1, g_na=params[0], g_k=params[1]) + hh.init_state() - def step(t, *args): - currents = f_current(t, *args) - hh.compute_derivative(currents) + def step(t, *args): + currents = f_current(t, *args) + hh.compute_derivative(currents) - ts, ys, steps = dx.diffrax_solve_adjoint( - step, method, 0. * u.ms, t1, dt, saveat=saveat, args=args, max_steps=max_steps - ) - return ts, ys[0], steps + ts, ys, steps = dx.diffrax_solve_adjoint( + step, method, 0. * u.ms, t1, dt, saveat=saveat, args=args, max_steps=max_steps + ) + return ts, ys[0], steps def simulate(): - g = np.asarray([0.12, 0.036, 0.0003, 0.001, 0.001]) - f_current = lambda t: np.asarray([0.2, 0., 0.]) * u.nA - saveat = u.math.arange(0., 100., 0.1) * u.ms - ts, ys, steps = solve_explicit_solver(g, f_current, saveat) - print(steps) - plt.plot(ts.to_decimal(u.ms), u.math.squeeze(ys).to_decimal(u.mV)) - plt.xlabel('Time [ms]') - plt.ylabel('Potential [mV]') - plt.show() + g = np.asarray([0.12, 0.036, 0.0003, 0.001, 0.001]) + f_current = lambda t: np.asarray([0.2, 0., 0.]) * u.nA + saveat = u.math.arange(0., 100., 0.1) * u.ms + ts, ys, steps = solve_explicit_solver(g, f_current, saveat) + print(steps) + plt.plot(ts.to_decimal(u.ms), u.math.squeeze(ys).to_decimal(u.mV)) + plt.xlabel('Time [ms]') + plt.ylabel('Potential [mV]') + plt.show() if __name__ == '__main__': - simulate() + simulate() diff --git a/examples/thalamus_single_compartment_neurons.py b/examples/thalamus_single_compartment_neurons.py index 3fdc271..efab159 100644 --- a/examples/thalamus_single_compartment_neurons.py +++ b/examples/thalamus_single_compartment_neurons.py @@ -31,174 +31,174 @@ class HTC(dx.neurons.SingleCompartment): - def __init__(self, size, gKL=0.01 * (u.mS / u.cm ** 2), V_initializer=bst.init.Constant(-65. * u.mV)): - super().__init__(size, V_initializer=V_initializer, V_th=20. * u.mV) + def __init__(self, size, gKL=0.01 * (u.mS / u.cm ** 2), V_initializer=bst.init.Constant(-65. * u.mV)): + super().__init__(size, V_initializer=V_initializer, V_th=20. * u.mV) - self.na = dx.ions.SodiumFixed(size, E=50. * u.mV) - self.na.add_elem(INa=dx.channels.INa_Ba2002(size, V_sh=-30 * u.mV)) + self.na = dx.ions.SodiumFixed(size, E=50. * u.mV) + self.na.add_elem(INa=dx.channels.INa_Ba2002(size, V_sh=-30 * u.mV)) - self.k = dx.ions.PotassiumFixed(size, E=-90. * u.mV) - self.k.add_elem(IKL=dx.channels.IK_Leak(size, g_max=gKL)) - self.k.add_elem(IDR=dx.channels.IKDR_Ba2002(size, V_sh=-30. * u.mV, phi=0.25)) + self.k = dx.ions.PotassiumFixed(size, E=-90. * u.mV) + self.k.add_elem(IKL=dx.channels.IK_Leak(size, g_max=gKL)) + self.k.add_elem(IDR=dx.channels.IKDR_Ba2002(size, V_sh=-30. * u.mV, phi=0.25)) - self.ca = dx.ions.CalciumDetailed(size, C_rest=5e-5 * u.mM, tau=10. * u.ms, d=0.5 * u.um) - self.ca.add_elem(ICaL=dx.channels.ICaL_IS2008(size, g_max=0.5 * (u.mS / u.cm ** 2))) - self.ca.add_elem(ICaN=dx.channels.ICaN_IS2008(size, g_max=0.5 * (u.mS / u.cm ** 2))) - self.ca.add_elem(ICaT=dx.channels.ICaT_HM1992(size, g_max=2.1 * (u.mS / u.cm ** 2))) - self.ca.add_elem(ICaHT=dx.channels.ICaHT_HM1992(size, g_max=3.0 * (u.mS / u.cm ** 2))) + self.ca = dx.ions.CalciumDetailed(size, C_rest=5e-5 * u.mM, tau=10. * u.ms, d=0.5 * u.um) + self.ca.add_elem(ICaL=dx.channels.ICaL_IS2008(size, g_max=0.5 * (u.mS / u.cm ** 2))) + self.ca.add_elem(ICaN=dx.channels.ICaN_IS2008(size, g_max=0.5 * (u.mS / u.cm ** 2))) + self.ca.add_elem(ICaT=dx.channels.ICaT_HM1992(size, g_max=2.1 * (u.mS / u.cm ** 2))) + self.ca.add_elem(ICaHT=dx.channels.ICaHT_HM1992(size, g_max=3.0 * (u.mS / u.cm ** 2))) - self.kca = dx.MixIons(self.k, self.ca) - self.kca.add_elem(IAHP=dx.channels.IAHP_De1994(size, g_max=0.3 * (u.mS / u.cm ** 2))) + self.kca = dx.MixIons(self.k, self.ca) + self.kca.add_elem(IAHP=dx.channels.IAHP_De1994(size, g_max=0.3 * (u.mS / u.cm ** 2))) - self.Ih = dx.channels.Ih_HM1992(size, g_max=0.01 * (u.mS / u.cm ** 2), E=-43 * u.mV) - self.IL = dx.channels.IL(size, g_max=0.0075 * (u.mS / u.cm ** 2), E=-70 * u.mV) + self.Ih = dx.channels.Ih_HM1992(size, g_max=0.01 * (u.mS / u.cm ** 2), E=-43 * u.mV) + self.IL = dx.channels.IL(size, g_max=0.0075 * (u.mS / u.cm ** 2), E=-70 * u.mV) - def compute_derivative(self, x=0. * u.nA): - return super().compute_derivative(x * (1e-3 / (2.9e-4 * u.cm ** 2))) + def compute_derivative(self, x=0. * u.nA): + return super().compute_derivative(x * (1e-3 / (2.9e-4 * u.cm ** 2))) class RTC(dx.neurons.SingleCompartment): - def __init__(self, size, gKL=0.01 * (u.mS / u.cm ** 2), V_initializer=bst.init.Constant(-65. * u.mV)): - super().__init__(size, V_initializer=V_initializer, V_th=20 * u.mV) + def __init__(self, size, gKL=0.01 * (u.mS / u.cm ** 2), V_initializer=bst.init.Constant(-65. * u.mV)): + super().__init__(size, V_initializer=V_initializer, V_th=20 * u.mV) - self.na = dx.ions.SodiumFixed(size) - self.na.add_elem(INa=dx.channels.INa_Ba2002(size, V_sh=-40 * u.mV)) + self.na = dx.ions.SodiumFixed(size) + self.na.add_elem(INa=dx.channels.INa_Ba2002(size, V_sh=-40 * u.mV)) - self.k = dx.ions.PotassiumFixed(size, E=-90. * u.mV) - self.k.add_elem(IDR=dx.channels.IKDR_Ba2002(size, V_sh=-40 * u.mV, phi=0.25)) - self.k.add_elem(IKL=dx.channels.IK_Leak(size, g_max=gKL)) + self.k = dx.ions.PotassiumFixed(size, E=-90. * u.mV) + self.k.add_elem(IDR=dx.channels.IKDR_Ba2002(size, V_sh=-40 * u.mV, phi=0.25)) + self.k.add_elem(IKL=dx.channels.IK_Leak(size, g_max=gKL)) - self.ca = dx.ions.CalciumDetailed(size, C_rest=5e-5 * u.mM, tau=10. * u.ms, d=0.5 * u.um) - self.ca.add_elem(ICaL=dx.channels.ICaL_IS2008(size, g_max=0.3 * (u.mS / u.cm ** 2))) - self.ca.add_elem(ICaN=dx.channels.ICaN_IS2008(size, g_max=0.6 * (u.mS / u.cm ** 2))) - self.ca.add_elem(ICaT=dx.channels.ICaT_HM1992(size, g_max=2.1 * (u.mS / u.cm ** 2))) - self.ca.add_elem(ICaHT=dx.channels.ICaHT_HM1992(size, g_max=0.6 * (u.mS / u.cm ** 2))) + self.ca = dx.ions.CalciumDetailed(size, C_rest=5e-5 * u.mM, tau=10. * u.ms, d=0.5 * u.um) + self.ca.add_elem(ICaL=dx.channels.ICaL_IS2008(size, g_max=0.3 * (u.mS / u.cm ** 2))) + self.ca.add_elem(ICaN=dx.channels.ICaN_IS2008(size, g_max=0.6 * (u.mS / u.cm ** 2))) + self.ca.add_elem(ICaT=dx.channels.ICaT_HM1992(size, g_max=2.1 * (u.mS / u.cm ** 2))) + self.ca.add_elem(ICaHT=dx.channels.ICaHT_HM1992(size, g_max=0.6 * (u.mS / u.cm ** 2))) - self.kca = dx.MixIons(self.k, self.ca) - self.kca.add_elem(IAHP=dx.channels.IAHP_De1994(size, g_max=0.1 * (u.mS / u.cm ** 2))) + self.kca = dx.MixIons(self.k, self.ca) + self.kca.add_elem(IAHP=dx.channels.IAHP_De1994(size, g_max=0.1 * (u.mS / u.cm ** 2))) - self.Ih = dx.channels.Ih_HM1992(size, g_max=0.01 * (u.mS / u.cm ** 2), E=-43 * u.mV) - self.IL = dx.channels.IL(size, g_max=0.0075 * (u.mS / u.cm ** 2), E=-70 * u.mV) + self.Ih = dx.channels.Ih_HM1992(size, g_max=0.01 * (u.mS / u.cm ** 2), E=-43 * u.mV) + self.IL = dx.channels.IL(size, g_max=0.0075 * (u.mS / u.cm ** 2), E=-70 * u.mV) - def compute_derivative(self, x=0. * u.nA): - return super().compute_derivative(x * (1e-3 / (2.9e-4 * u.cm ** 2))) + def compute_derivative(self, x=0. * u.nA): + return super().compute_derivative(x * (1e-3 / (2.9e-4 * u.cm ** 2))) class IN(dx.neurons.SingleCompartment): - def __init__(self, size, V_initializer=bst.init.Constant(-70. * u.mV)): - super().__init__(size, V_initializer=V_initializer, V_th=20. * u.mV) + def __init__(self, size, V_initializer=bst.init.Constant(-70. * u.mV)): + super().__init__(size, V_initializer=V_initializer, V_th=20. * u.mV) - self.na = dx.ions.SodiumFixed(size) - self.na.add_elem(INa=dx.channels.INa_Ba2002(size, V_sh=-30 * u.mV)) + self.na = dx.ions.SodiumFixed(size) + self.na.add_elem(INa=dx.channels.INa_Ba2002(size, V_sh=-30 * u.mV)) - self.k = dx.ions.PotassiumFixed(size, E=-90. * u.mV) - self.k.add_elem(IDR=dx.channels.IKDR_Ba2002(size, V_sh=-30 * u.mV, phi=0.25)) - self.k.add_elem(IKL=dx.channels.IK_Leak(size, g_max=0.01 * (u.mS / u.cm ** 2))) + self.k = dx.ions.PotassiumFixed(size, E=-90. * u.mV) + self.k.add_elem(IDR=dx.channels.IKDR_Ba2002(size, V_sh=-30 * u.mV, phi=0.25)) + self.k.add_elem(IKL=dx.channels.IK_Leak(size, g_max=0.01 * (u.mS / u.cm ** 2))) - self.ca = dx.ions.CalciumDetailed(size, C_rest=5e-5 * u.mM, tau=10. * u.ms, d=0.5 * u.um) - self.ca.add_elem(ICaN=dx.channels.ICaN_IS2008(size, g_max=0.1 * (u.mS / u.cm ** 2))) - self.ca.add_elem(ICaHT=dx.channels.ICaHT_HM1992(size, g_max=2.5 * (u.mS / u.cm ** 2))) + self.ca = dx.ions.CalciumDetailed(size, C_rest=5e-5 * u.mM, tau=10. * u.ms, d=0.5 * u.um) + self.ca.add_elem(ICaN=dx.channels.ICaN_IS2008(size, g_max=0.1 * (u.mS / u.cm ** 2))) + self.ca.add_elem(ICaHT=dx.channels.ICaHT_HM1992(size, g_max=2.5 * (u.mS / u.cm ** 2))) - self.kca = dx.MixIons(self.k, self.ca) - self.kca.add_elem(IAHP=dx.channels.IAHP_De1994(size, g_max=0.2 * (u.mS / u.cm ** 2))) + self.kca = dx.MixIons(self.k, self.ca) + self.kca.add_elem(IAHP=dx.channels.IAHP_De1994(size, g_max=0.2 * (u.mS / u.cm ** 2))) - self.IL = dx.channels.IL(size, g_max=0.0075 * (u.mS / u.cm ** 2), E=-60 * u.mV) - self.Ih = dx.channels.Ih_HM1992(size, g_max=0.05 * (u.mS / u.cm ** 2), E=-43 * u.mV) + self.IL = dx.channels.IL(size, g_max=0.0075 * (u.mS / u.cm ** 2), E=-60 * u.mV) + self.Ih = dx.channels.Ih_HM1992(size, g_max=0.05 * (u.mS / u.cm ** 2), E=-43 * u.mV) - def compute_derivative(self, x=0. * u.nA): - return super().compute_derivative(x * (1e-3 / (1.7e-4 * u.cm ** 2))) + def compute_derivative(self, x=0. * u.nA): + return super().compute_derivative(x * (1e-3 / (1.7e-4 * u.cm ** 2))) class TRN(dx.neurons.SingleCompartment): - def __init__(self, size, V_initializer=bst.init.Constant(-70. * u.mV), gl=0.0075): - super().__init__(size, V_initializer=V_initializer, V_th=20. * u.mV) + def __init__(self, size, V_initializer=bst.init.Constant(-70. * u.mV), gl=0.0075): + super().__init__(size, V_initializer=V_initializer, V_th=20. * u.mV) - self.na = dx.ions.SodiumFixed(size) - self.na.add_elem(INa=dx.channels.INa_Ba2002(size, V_sh=-40 * u.mV)) + self.na = dx.ions.SodiumFixed(size) + self.na.add_elem(INa=dx.channels.INa_Ba2002(size, V_sh=-40 * u.mV)) - self.k = dx.ions.PotassiumFixed(size, E=-90. * u.mV) - self.k.add_elem(IDR=dx.channels.IKDR_Ba2002(size, V_sh=-40 * u.mV)) - self.k.add_elem(IKL=dx.channels.IK_Leak(size, g_max=0.01 * (u.mS / u.cm ** 2))) + self.k = dx.ions.PotassiumFixed(size, E=-90. * u.mV) + self.k.add_elem(IDR=dx.channels.IKDR_Ba2002(size, V_sh=-40 * u.mV)) + self.k.add_elem(IKL=dx.channels.IK_Leak(size, g_max=0.01 * (u.mS / u.cm ** 2))) - self.ca = dx.ions.CalciumDetailed(size, C_rest=5e-5 * u.mM, tau=100. * u.ms, d=0.5 * u.um) - self.ca.add_elem(ICaN=dx.channels.ICaN_IS2008(size, g_max=0.2 * (u.mS / u.cm ** 2))) - self.ca.add_elem(ICaT=dx.channels.ICaT_HP1992(size, g_max=1.3 * (u.mS / u.cm ** 2))) + self.ca = dx.ions.CalciumDetailed(size, C_rest=5e-5 * u.mM, tau=100. * u.ms, d=0.5 * u.um) + self.ca.add_elem(ICaN=dx.channels.ICaN_IS2008(size, g_max=0.2 * (u.mS / u.cm ** 2))) + self.ca.add_elem(ICaT=dx.channels.ICaT_HP1992(size, g_max=1.3 * (u.mS / u.cm ** 2))) - self.kca = dx.MixIons(self.k, self.ca) - self.kca.add_elem(IAHP=dx.channels.IAHP_De1994(size, g_max=0.2 * (u.mS / u.cm ** 2))) + self.kca = dx.MixIons(self.k, self.ca) + self.kca.add_elem(IAHP=dx.channels.IAHP_De1994(size, g_max=0.2 * (u.mS / u.cm ** 2))) - # self.IL = dx.channels.IL(size, g_max=0.01 * (u.mS / u.cm ** 2), E=-60 * u.mV) - self.IL = dx.channels.IL(size, g_max=gl * (u.mS / u.cm ** 2), E=-60 * u.mV) + # self.IL = dx.channels.IL(size, g_max=0.01 * (u.mS / u.cm ** 2), E=-60 * u.mV) + self.IL = dx.channels.IL(size, g_max=gl * (u.mS / u.cm ** 2), E=-60 * u.mV) - def compute_derivative(self, x=0. * u.nA): - return super().compute_derivative(x * (1e-3 / (1.43e-4 * u.cm ** 2))) + def compute_derivative(self, x=0. * u.nA): + return super().compute_derivative(x * (1e-3 / (1.43e-4 * u.cm ** 2))) - def step_run(self, t, inp): - # dx.rk4_step(neu, t, inp) - dx.rk2_step(self, t, inp) - # dx.euler_step(neu, t, inp) - return self.V.value + def step_run(self, t, inp): + # dx.rk4_step(neu, t, inp) + dx.rk2_step(self, t, inp) + # dx.euler_step(neu, t, inp) + return self.V.value def try_trn_neuron(): - bst.environ.set(dt=0.01 * u.ms) + bst.environ.set(dt=0.01 * u.ms) - I = bts.input.section_input(values=[0, -0.05, 0], durations=[500, 200, 1000], dt=0.01) * u.uA - times = u.math.arange(I.shape[0]) * bst.environ.get_dt() + I = bts.input.section_input(values=[0, -0.05, 0], durations=[500, 200, 1000], dt=0.01) * u.uA + times = u.math.arange(I.shape[0]) * bst.environ.get_dt() - # neu = HTC([1, 1]) # [n_neuron, n_compartment] - # neu = IN([1, 1]) # [n_neuron, n_compartment] - # neu = RTC(1) # [n_neuron, n_compartment] - neu = TRN([1, 1], gl=0.0075) # [n_neuron, n_compartment] - neu.init_state() + # neu = HTC([1, 1]) # [n_neuron, n_compartment] + # neu = IN([1, 1]) # [n_neuron, n_compartment] + # neu = RTC(1) # [n_neuron, n_compartment] + neu = TRN([1, 1], gl=0.0075) # [n_neuron, n_compartment] + neu.init_state() - t0 = time.time() - vs = bst.transform.for_loop(neu.step_run, times, I) - t1 = time.time() - print(f"Elapsed time: {t1 - t0:.4f} s") + t0 = time.time() + vs = bst.compile.for_loop(neu.step_run, times, I) + t1 = time.time() + print(f"Elapsed time: {t1 - t0:.4f} s") - neu = TRN([1, 1], gl=0.00751) # [n_neuron, n_compartment] - neu.init_state() - vs2 = bst.transform.for_loop(neu.step_run, times, I) + neu = TRN([1, 1], gl=0.00751) # [n_neuron, n_compartment] + neu.init_state() + vs2 = bst.compile.for_loop(neu.step_run, times, I) - plt.plot(times.to_decimal(u.ms), u.math.squeeze(vs.to_decimal(u.mV))) - plt.plot(times.to_decimal(u.ms), u.math.squeeze(vs2.to_decimal(u.mV))) - plt.show() + plt.plot(times.to_decimal(u.ms), u.math.squeeze(vs.to_decimal(u.mV))) + plt.plot(times.to_decimal(u.ms), u.math.squeeze(vs2.to_decimal(u.mV))) + plt.show() def try_trn_neuron2(): - bst.environ.set(dt=0.01 * u.ms) + bst.environ.set(dt=0.01 * u.ms) - I = bts.input.section_input(values=[0, -0.05, 0], durations=[500, 200, 1000], dt=0.01) * u.uA - all_times = u.math.arange(I.shape[0]) * bst.environ.get_dt() + I = bts.input.section_input(values=[0, -0.05, 0], durations=[500, 200, 1000], dt=0.01) * u.uA + all_times = u.math.arange(I.shape[0]) * bst.environ.get_dt() - neu = TRN([1, 1], gl=0.0075) # [n_neuron, n_compartment] + neu = TRN([1, 1], gl=0.0075) # [n_neuron, n_compartment] - @bst.transform.jit - def run(): - neu.init_state() - vs = bst.transform.for_loop(neu.step_run, all_times, I) - return vs - - times = [] - t0 = time.time() - vs = run() - t1 = time.time() - print(f"Compilation time: {t1 - t0:.4f} s") - times.append(t1 - t0) + @bst.compile.jit + def run(): + neu.init_state() + vs = bst.compile.for_loop(neu.step_run, all_times, I) + return vs - for _ in range(5): + times = [] t0 = time.time() vs = run() t1 = time.time() - print(f"Running Time: {t1 - t0}") + print(f"Compilation time: {t1 - t0:.4f} s") times.append(t1 - t0) - print(times) + for _ in range(5): + t0 = time.time() + vs = run() + t1 = time.time() + print(f"Running Time: {t1 - t0}") + times.append(t1 - t0) + + print(times) - plt.plot(all_times.to_decimal(u.ms), u.math.squeeze(vs.to_decimal(u.mV))) - plt.show() + plt.plot(all_times.to_decimal(u.ms), u.math.squeeze(vs.to_decimal(u.mV))) + plt.show() if __name__ == '__main__': - try_trn_neuron2() + try_trn_neuron2() diff --git a/pyproject.toml b/pyproject.toml index 12c9107..a53a674 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta" [tool.setuptools.packages.find] exclude = [ - "docs*", "build*", "dist*", + "docs*", "build*", "dist*", "examples*", "dendritex.egg-info*", "dendritex/__pycache__*", "dendritex/__init__.py" ] @@ -48,7 +48,7 @@ dependencies = [ 'jax', 'jaxlib', 'numpy', - 'brainstate>=0.0.2', + 'brainstate>=0.1.0', 'brainunit>=0.0.2.post20240903', 'diffrax', ] diff --git a/setup.py b/setup.py index c182964..3c8e535 100644 --- a/setup.py +++ b/setup.py @@ -27,70 +27,70 @@ # version here = os.path.abspath(os.path.dirname(__file__)) with open(os.path.join(here, 'dendritex/', '__init__.py'), 'r') as f: - init_py = f.read() + init_py = f.read() version = re.search('__version__ = "(.*)"', init_py).groups()[0] print(version) if len(sys.argv) > 2 and sys.argv[2] == '--python-tag=py3': - version = version + version = version else: - version += '.post{}'.format(time.strftime("%Y%m%d", time.localtime())) + version += '.post{}'.format(time.strftime("%Y%m%d", time.localtime())) # obtain long description from README with io.open(os.path.join(here, 'README.md'), 'r', encoding='utf-8') as f: - README = f.read() + README = f.read() # installation packages packages = find_packages( - exclude=[ - "docs*", "build*", - "dist*", "dendritex.egg-info*", "dendritex/__pycache__*" - ] + exclude=[ + "docs*", "build*", "examples*", + "dist*", "dendritex.egg-info*", "dendritex/__pycache__*" + ] ) # setup setup( - name='dendritex', - version=version, - description='Dendrite Modeling in Python', - long_description=README, - long_description_content_type="text/markdown", - author='Dendritex Developers', - author_email='chao.brain@qq.com', - packages=packages, - python_requires='>=3.9', - install_requires=['numpy>=1.15', 'jax', 'brainunit>=0.0.2.post20240903', 'brainstate>=0.0.2', 'diffrax'], - url='https://github.com/chaoming0625/dendritex', - project_urls={ - "Bug Tracker": "https://github.com/chaoming0625/dendritex/issues", - "Documentation": "https://dendrite.readthedocs.io/", - "Source Code": "https://github.com/chaoming0625/dendritex", - }, - extras_require={ - 'cpu': ['jaxlib'], - 'cuda12': ['jaxlib[cuda12]', ], - 'tpu': ['jaxlib[tpu]'], - }, - keywords=( - 'dendritic computation, ' - 'dendritic modeling, ' - 'brain modeling, ' - 'neuron simulation' - ), - classifiers=[ - 'Natural Language :: English', - 'Operating System :: OS Independent', - 'Programming Language :: Python', - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.9', - 'Programming Language :: Python :: 3.10', - 'Programming Language :: Python :: 3.11', - 'Programming Language :: Python :: 3.12', - 'Intended Audience :: Science/Research', - 'License :: OSI Approved :: Apache Software License', - 'Topic :: Scientific/Engineering :: Bio-Informatics', - 'Topic :: Scientific/Engineering :: Mathematics', - 'Topic :: Scientific/Engineering :: Artificial Intelligence', - 'Topic :: Software Development :: Libraries', - ], - license='Apache-2.0 license', + name='dendritex', + version=version, + description='Dendrite Modeling in Python', + long_description=README, + long_description_content_type="text/markdown", + author='Dendritex Developers', + author_email='chao.brain@qq.com', + packages=packages, + python_requires='>=3.9', + install_requires=['numpy>=1.15', 'jax', 'brainunit>=0.0.2.post20240903', 'brainstate>=0.1.0', 'diffrax'], + url='https://github.com/chaoming0625/dendritex', + project_urls={ + "Bug Tracker": "https://github.com/chaoming0625/dendritex/issues", + "Documentation": "https://dendrite.readthedocs.io/", + "Source Code": "https://github.com/chaoming0625/dendritex", + }, + extras_require={ + 'cpu': ['jaxlib'], + 'cuda12': ['jaxlib[cuda12]', ], + 'tpu': ['jaxlib[tpu]'], + }, + keywords=( + 'dendritic computation, ' + 'dendritic modeling, ' + 'brain modeling, ' + 'neuron simulation' + ), + classifiers=[ + 'Natural Language :: English', + 'Operating System :: OS Independent', + 'Programming Language :: Python', + 'Programming Language :: Python :: 3', + 'Programming Language :: Python :: 3.9', + 'Programming Language :: Python :: 3.10', + 'Programming Language :: Python :: 3.11', + 'Programming Language :: Python :: 3.12', + 'Intended Audience :: Science/Research', + 'License :: OSI Approved :: Apache Software License', + 'Topic :: Scientific/Engineering :: Bio-Informatics', + 'Topic :: Scientific/Engineering :: Mathematics', + 'Topic :: Scientific/Engineering :: Artificial Intelligence', + 'Topic :: Software Development :: Libraries', + ], + license='Apache-2.0 license', )