add a neuron fitting example
chaoming0625 committed Sep 8, 2024
1 parent d44bd67 commit 06ed3cb
Showing 6 changed files with 292 additions and 18 deletions.
8 changes: 4 additions & 4 deletions dendritex/_base.py
@@ -310,19 +310,19 @@ def current(self, *args, **kwargs):
         raise NotImplementedError
 
     def before_integral(self, *args, **kwargs):
-        raise NotImplementedError
+        pass
 
     def compute_derivative(self, *args, **kwargs):
         raise NotImplementedError
 
     def after_integral(self, *args, **kwargs):
-        raise NotImplementedError
+        pass
 
     def reset_state(self, *args, **kwargs):
-        raise NotImplementedError
+        pass
 
     def init_state(self, *args, **kwargs):
-        raise NotImplementedError
+        pass
 
 
 class IonInfo(NamedTuple):
16 changes: 6 additions & 10 deletions dendritex/neurons/single_compartment.py
@@ -18,7 +18,7 @@
 from typing import Union, Optional, Callable
 
 import brainstate as bst
-import brainunit as bu
+import brainunit as u
 
 from .._base import HHTypedNeuron, IonChannel, State4Integral
 
@@ -70,10 +70,9 @@ class SingleCompartment(HHTypedNeuron):
     def __init__(
         self,
         size: bst.typing.Size,
-        C: Union[bst.typing.ArrayLike, Callable] = 1. * bu.uF / bu.cm ** 2,
-        A: Union[bst.typing.ArrayLike, Callable] = 1e-3 * bu.cm ** 2,
-        V_th: Union[bst.typing.ArrayLike, Callable] = 0. * bu.mV,
-        V_initializer: Union[bst.typing.ArrayLike, Callable] = bst.init.Uniform(-70 * bu.mV, -60. * bu.mV),
+        C: Union[bst.typing.ArrayLike, Callable] = 1. * u.uF / u.cm ** 2,
+        V_th: Union[bst.typing.ArrayLike, Callable] = 0. * u.mV,
+        V_initializer: Union[bst.typing.ArrayLike, Callable] = bst.init.Uniform(-70 * u.mV, -60. * u.mV),
         spk_fun: Callable = bst.surrogate.ReluGrad(),
         name: Optional[str] = None,
         mode: Optional[bst.mixin.Mode] = None,
@@ -85,7 +84,6 @@ def __init__(
         assert self.n_compartment == 1, (f'Point-based neuron only supports single compartment. '
                                          f'But got {self.n_compartment} compartments.')
         self.C = C
-        self.A = A
         self.V_th = V_th
         self._V_initializer = V_initializer
         self.spk_fun = spk_fun
@@ -103,11 +101,9 @@ def before_integral(self, *args):
         for node in channels.values():
             node.before_integral(self.V.value)
 
-    def compute_derivative(self, x=0.):
+    def compute_derivative(self, x=0. * u.nA / u.cm ** 2):
         # [ Compute the derivative of membrane potential ]
-        # 1. inputs
-        x = x * (1e-3 / self.A)
-        # 2. synapses
+        # 1. inputs + 2. synapses
         x = self.sum_current_inputs(self.V.value, init=x)
         # 3. channels
         for ch in self.nodes(level=1, include_self=False).subset(IonChannel).values():
254 changes: 254 additions & 0 deletions examples/fitting_a_hh_neuron/main.py
@@ -0,0 +1,254 @@
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================


from typing import Union, Callable

import brainstate as bst
import braintools as bts
import brainunit as u
import dendritex as dx
import jax
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

bst.environ.set(dt=0.01 * u.ms)

# Load Input and Output Data
df_inp_traces = pd.read_csv('neuron_data/input_traces_hh.csv')
df_out_traces = pd.read_csv('neuron_data/output_traces_hh.csv')

area = 20000 * u.um ** 2  # membrane area of the neuron
inp_traces = df_inp_traces.to_numpy()[:, 1:] * 1e9 * u.nA  # input currents (raw values in amperes; x1e9 -> nA)
mem_traces = df_out_traces.to_numpy()[:, 1:] * u.mV  # target membrane potentials


class INa(dx.Channel):
    root_type = dx.HHTypedNeuron

    def __init__(
        self,
        size: bst.typing.Size,
        ENa: Union[bst.typing.ArrayLike, Callable] = 50. * u.mV,
        gNa: Union[bst.typing.ArrayLike, Callable] = 120. * u.mS,
        vth: Union[bst.typing.ArrayLike, Callable] = -63 * u.mV,
    ):
        super().__init__(size)
        self.ENa = bst.init.param(ENa, self.varshape)
        self.gNa = bst.init.param(gNa, self.varshape)
        self.V_th = bst.init.param(vth, self.varshape)

    def init_state(self, V, batch_size=None):
        self.m = dx.State4Integral(bst.init.param(u.math.zeros, self.varshape))
        self.h = dx.State4Integral(bst.init.param(u.math.zeros, self.varshape))

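    # Gating rate constants in the Traub-Miles HH form. `u.math.exprel(x)`
    # computes (exp(x) - 1) / x in a numerically stable way, so the alpha/beta
    # expressions below stay finite at their removable singularities.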
    # m channel
    m_alpha = lambda self, V: 0.32 * 4 / u.math.exprel((13. * u.mV - V + self.V_th).to_decimal(u.mV) / 4.)
    m_beta = lambda self, V: 0.28 * 5 / u.math.exprel((V - self.V_th - 40. * u.mV).to_decimal(u.mV) / 5.)
    m_inf = lambda self, V: self.m_alpha(V) / (self.m_alpha(V) + self.m_beta(V))

    # h channel
    h_alpha = lambda self, V: 0.128 * u.math.exp((17. * u.mV - V + self.V_th).to_decimal(u.mV) / 18.)
    h_beta = lambda self, V: 4. / (1 + u.math.exp((40. * u.mV - V + self.V_th).to_decimal(u.mV) / 5.))
    h_inf = lambda self, V: self.h_alpha(V) / (self.h_alpha(V) + self.h_beta(V))

    def compute_derivative(self, V, *args, **kwargs):
        m = self.m.value
        h = self.h.value
        self.m.derivative = (self.m_alpha(V) * (1 - m) - self.m_beta(V) * m) / u.ms
        self.h.derivative = (self.h_alpha(V) * (1 - h) - self.h_beta(V) * h) / u.ms

    def current(self, V, *args, **kwargs):
        m = self.m.value
        h = self.h.value
        return (self.gNa * m * m * m * h) * (self.ENa - V)


class IK(dx.Channel):
    root_type = dx.HHTypedNeuron

    def __init__(
        self,
        size: bst.typing.Size,
        EK: Union[bst.typing.ArrayLike, Callable] = -90. * u.mV,
        gK: Union[bst.typing.ArrayLike, Callable] = 36. * u.mS,
        vth: Union[bst.typing.ArrayLike, Callable] = -63 * u.mV,
    ):
        super().__init__(size)
        self.EK = bst.init.param(EK, self.varshape)
        self.gK = bst.init.param(gK, self.varshape)
        self.V_th = bst.init.param(vth, self.varshape)

    def init_state(self, V, batch_size=None):
        self.n = dx.State4Integral(bst.init.param(u.math.zeros, self.varshape))

    # n channel
    n_alpha = lambda self, V: 0.032 * 5 / u.math.exprel((15. * u.mV - V + self.V_th).to_decimal(u.mV) / 5.)
    n_beta = lambda self, V: .5 * u.math.exp((10. * u.mV - V + self.V_th).to_decimal(u.mV) / 40.)
    n_inf = lambda self, V: self.n_alpha(V) / (self.n_alpha(V) + self.n_beta(V))

    def compute_derivative(self, V, *args, **kwargs):
        n = self.n.value
        self.n.derivative = (self.n_alpha(V) * (1 - n) - self.n_beta(V) * n) / u.ms

    def current(self, V, *args, **kwargs):
        n2 = self.n.value ** 2
        return (self.gK * n2 * n2) * (self.EK - V)


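# A complete HH cell: a single compartment equipped with the INa and IK channels
# above plus a passive leak current. gL, gNa, gK, and C are the free parameters
# to be fitted.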
class HH(dx.neurons.SingleCompartment):
    def __init__(
        self,
        size,
        v_initializer: Callable = bst.init.Uniform(-70 * u.mV, -60. * u.mV),
        gL: Union[bst.typing.ArrayLike, Callable] = 0.003 * u.mS,
        gNa: Union[bst.typing.ArrayLike, Callable] = 120. * u.mS,
        gK: Union[bst.typing.ArrayLike, Callable] = 36. * u.mS,
        C: Union[bst.typing.ArrayLike, Callable] = 1. * (u.uF / u.cm ** 2),
    ):
        super().__init__(size, V_initializer=v_initializer, C=C)
        self.ina = INa(size, gNa=gNa)
        self.ik = IK(size, gK=gK)
        self.il = dx.channels.IL(size, g_max=gL, E=-65. * u.mV)


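# Plot each target voltage trace (top row) together with the injected current
# that produced it (bottom row).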
def visualize_target(voltages):
    fig, gs = bts.visualize.get_figure(2, voltages.shape[1], 3, 4.5)
    times = np.arange(voltages.shape[0]) * 0.01
    for i in range(voltages.shape[1]):
        ax = fig.add_subplot(gs[0, i])
        ax.plot(times, voltages.mantissa[:, i], label='target')
        plt.xlabel('Time [ms]')
        plt.legend()
        ax = plt.subplot(gs[1, i])
        ax.plot(times, inp_traces[i].mantissa)
        plt.xlabel('Time [ms]')
    plt.show()


def visualize(voltages, gl, g_na, g_kd, C):
    # currents: [T, B]
    # voltages: [T, B]
    simulated_vs = simulate_model(gl, g_na, g_kd, C)
    voltages = voltages.mantissa
    simulated_vs = simulated_vs.mantissa

    fig, gs = bts.visualize.get_figure(2, simulated_vs.shape[1], 3, 4.5)
    for i in range(simulated_vs.shape[1]):
        ax = fig.add_subplot(gs[0, i])
        ax.plot(voltages[:, i], label='target')
        ax.plot(simulated_vs[:, i], label='simulated')
        plt.legend()
        ax = plt.subplot(gs[1, i])
        ax.plot(inp_traces[i].mantissa)
    plt.show()


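# Run the HH model under all input currents at once: one neuron per input trace
# along the batch dimension, integrated with RK4 over every time step via
# `bst.transform.for_loop`, which stacks and returns the voltage history.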
def simulate_model(gl, g_na, g_kd, C):
    current = inp_traces.T
    assert current.ndim == 2  # [T, B]
    n_input = current.shape[1]
    hh = HH((n_input, 1), gL=gl, gNa=g_na, gK=g_kd, C=C, v_initializer=bst.init.Constant(-65. * u.mV))
    hh.init_state()

    def step_fun(i, inp):
        with bst.environ.context(i=i, t=bst.environ.get_dt() * i):
            dx.rk4_step(hh, bst.environ.get('t'), inp)
        return hh.V.value

    indices = np.arange(current.shape[0])
    current = u.math.expand_dims(current, axis=-1)  # [T, B, 1]
    return bst.transform.for_loop(step_fun, indices, current)  # [T, B, 1]


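# Fitting loss: mean squared error (in mV^2) between the simulated and target
# membrane potentials, jit-compiled so the optimizer can evaluate it cheaply.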
@bst.transform.jit
def compare_potentials(param):
    vs = simulate_model(param['gl'], param['g_na'], param['g_kd'], param['C'])  # [T, B, 1]
    losses = bts.metric.squared_error(vs.mantissa, target_vs.mantissa)
    return losses.mean()


# inp_traces: [B, T]
# target_vs: [T, B, 1]
# target_vs = simulate_model(50.37195496 * u.nsiemens,
#                            28.01317016 * u.usiemens,
#                            8.15968937 * u.usiemens,
#                            195.2173985 * u.pfarad)
target_vs = u.math.expand_dims(mem_traces.T, axis=-1)  # [T, B, 1]


# visualize_target(target_vs)


def visualize_hh_input_and_output():
    # Load Input and Output Data
    inp_traces = df_inp_traces.to_numpy()
    inp_traces = inp_traces[:, 1:] * 1e9

    out_traces = df_out_traces.to_numpy()
    out_traces = out_traces[:, 1:]

    indices = np.arange(inp_traces.shape[1]) * 0.01

    fig, gs = bts.visualize.get_figure(3, 1, 1.2, 6.0)
    ax = fig.add_subplot(gs[0, 0])
    ax.plot(indices, inp_traces.T)
    plt.xticks([])
    plt.ylabel('Current [nA]', fontsize=13)

    ax2 = fig.add_subplot(gs[1:, 0])
    ax2.plot(indices, out_traces.T)
    plt.ylabel('Potential [mV]', fontsize=13)
    plt.xlabel('Time [ms]')

    fig.align_ylabels([ax, ax2])
    plt.show()


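# Search ranges for the four free parameters. The capacitance range is given as
# a specific capacitance (uF/cm^2) scaled by the assumed membrane area.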
bounds = {
    'gl': [1e0, 1e2] * u.nS,
    'g_na': [1e1, 2e2] * u.uS,
    'g_kd': [1e1, 1e2] * u.uS,
    'C': [0.1, 2] * u.uF * u.cm ** -2 * area,
}


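# Gradient-free fitting with Nevergrad ('DE' = differential evolution).
# `jax.vmap` lets one jitted call evaluate the loss of an entire population of
# candidate parameter sets.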
def fitting_by_others(method='DE', n_sample=200):
    print(f"Method: {method}, n_sample: {n_sample}")

    @jax.jit
    @jax.vmap
    def loss_with_multiple_run(**params):
        return compare_potentials(params)

    opt = bts.optim.NevergradOptimizer(
        loss_with_multiple_run,
        n_sample=n_sample,
        bounds=bounds,
        method=method,
    )
    opt.initialize()
    param = opt.minimize(10)
    loss = compare_potentials(param)
    print(param)
    print(loss)
    visualize(target_vs, **param)
    return param, loss


if __name__ == '__main__':
    fitting_by_others(n_sample=100)