
Neurons

btorch.models.neurons

Attributes

__all__ = ['LIF', 'ALIF', 'ELIF', 'GLIF3', 'Izhikevich']

Classes

ALIF

Bases: BaseNode

Adaptive leaky integrate-and-fire neuron with conductance adaptation.

The ALIF model extends standard LIF by adding a voltage-dependent potassium conductance (g_k) that creates spike-frequency adaptation. Each spike increases g_k by dg_k, which then decays exponentially.

Dynamics

dv/dt = (-g_leak * (v - E_leak) - g_k * (v - E_k) + x) / c_m
dg_k/dt = -g_k / tau_adapt

At spike: g_k += dg_k
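The interplay between the membrane and adaptation equations can be illustrated with a standalone scalar simulation (plain Python floats rather than the library's tensor backend; parameters mirror the documented defaults, except dg_k is set above zero so the adaptation is visible):

```python
import math

# Standalone scalar ALIF sketch. Parameters mirror the documented
# defaults, except dg_k > 0 so spike-frequency adaptation kicks in.
v_threshold, v_reset = 1.0, 0.0
c_m, g_leak, E_leak, E_k = 1.0, 1.0, 0.0, -70.0
tau_adapt, dg_k = 20.0, 0.01

dt = 0.1                  # ms
v, g_k = 0.0, 0.0
x = 2.0                   # constant input current
spike_times = []

for step in range(2000):  # 200 ms
    # membrane update (forward Euler here; the library uses exponential Euler)
    dv = (-g_leak * (v - E_leak) - g_k * (v - E_k) + x) / c_m
    v += dt * dv
    # adaptation conductance decays exponentially toward zero
    g_k *= math.exp(-dt / tau_adapt)
    if v >= v_threshold:
        spike_times.append(step * dt)
        v = v_reset       # hard reset
        g_k += dg_k       # each spike strengthens the adaptation conductance

# Adaptation lengthens the inter-spike intervals over time.
isis = [b - a for a, b in zip(spike_times, spike_times[1:])]
```

With the increment active, each spike raises g_k, which pulls the membrane toward E_k and progressively lengthens the inter-spike intervals, i.e. spike-frequency adaptation.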

Parameters:

- n_neuron (int | Sequence[int]): Number of neurons (int or tuple of dimensions). Required.
- v_threshold (float | Float[TensorLike, ' n_neuron']): Firing threshold (mV). Default: 1.0.
- v_reset (float | Float[TensorLike, ' n_neuron']): Reset voltage after spike (mV). Default: 0.0.
- c_m (float | Float[TensorLike, ' n_neuron']): Membrane capacitance (pF). Default: 1.0.
- g_leak (float | Float[TensorLike, ' n_neuron']): Leak conductance (nS). Default: 1.0.
- E_leak (float | Float[TensorLike, ' n_neuron']): Leak reversal potential (mV). Default: 0.0.
- E_k (float | Float[TensorLike, ' n_neuron']): Potassium reversal potential (mV). Default: -70.0.
- g_k_init (float | Float[TensorLike, ' n_neuron']): Initial adaptation conductance (nS). Default: 0.0.
- tau_adapt (float | Float[TensorLike, ' n_neuron']): Adaptation time constant (ms). Default: 20.0.
- dg_k (float | Float[TensorLike, ' n_neuron']): Adaptation increment per spike (nS). Default: 0.0.
- tau_ref (float | Float[TensorLike, ' n_neuron'] | None): Refractory period (ms). None disables the refractory mechanism. Default: None.
- trainable_param (set[str]): Set of parameter names to make trainable. Default: set().
- surrogate_function (Callable): Surrogate gradient function. Default: Sigmoid().
- detach_reset (bool): If True, detach reset signal. Default: False.
- hard_reset (bool): If True, use hard reset. Default: False.
- pre_spike_v (bool): If True, store pre-spike voltage. Default: False.
- step_mode (Literal['s']): Step mode. Default: 's'.
- backend (Literal['torch']): Backend implementation. Default: 'torch'.
- device: Device for tensors. Default: None.
- dtype: Data type for tensors. Default: None.

Attributes:

- v (Tensor): Membrane potential, shape (*batch, n_neuron).
- g_k (Tensor): Adaptation conductance, shape (*batch, n_neuron).
- refractory (Tensor | None): Refractory counter (if tau_ref specified).
- c_m, g_leak, E_leak, E_k (Tensor | Parameter): Neuron parameters.
- tau_adapt (Tensor | Parameter): Adaptation time constant.
- dg_k (Tensor | Parameter): Per-spike adaptation increment.
Source code in btorch/models/neurons/alif.py
class ALIF(BaseNode):
    """Adaptive leaky integrate-and-fire neuron with conductance adaptation.

    The ALIF model extends standard LIF by adding a voltage-dependent
    potassium conductance (g_k) that creates spike-frequency adaptation.
    Each spike increases g_k by dg_k, which then decays exponentially.

    Dynamics:
        dv/dt = (-g_leak * (v - E_leak) - g_k * (v - E_k) + x) / c_m
        dg_k/dt = -g_k / tau_adapt

        At spike: g_k += dg_k

    Args:
        n_neuron: Number of neurons (int or tuple of dimensions).
        v_threshold: Firing threshold (mV). Default: 1.0.
        v_reset: Reset voltage after spike (mV). Default: 0.0.
        c_m: Membrane capacitance (pF). Default: 1.0.
        g_leak: Leak conductance (nS). Default: 1.0.
        E_leak: Leak reversal potential (mV). Default: 0.0.
        E_k: Potassium reversal potential (mV). Default: -70.0.
        g_k_init: Initial adaptation conductance (nS). Default: 0.0.
        tau_adapt: Adaptation time constant (ms). Default: 20.0.
        dg_k: Adaptation increment per spike (nS). Default: 0.0.
        tau_ref: Refractory period (ms). None disables refractory.
            Default: None.
        trainable_param: Set of parameter names to make trainable.
        surrogate_function: Surrogate gradient function. Default: Sigmoid().
        detach_reset: If True, detach reset signal. Default: False.
        hard_reset: If True, use hard reset. Default: False.
        pre_spike_v: If True, store pre-spike voltage. Default: False.
        step_mode: Step mode. Default: "s".
        backend: Backend implementation. Default: "torch".
        device: Device for tensors. Default: None.
        dtype: Data type for tensors. Default: None.

    Attributes:
        v: Membrane potential, shape (*batch, n_neuron).
        g_k: Adaptation conductance, shape (*batch, n_neuron).
        refractory: Refractory counter (if tau_ref specified).
        c_m, g_leak, E_leak, E_k: Neuron parameters.
        tau_adapt: Adaptation time constant.
        dg_k: Per-spike adaptation increment.
    """

    g_k: torch.Tensor
    refractory: torch.Tensor | None

    c_m: torch.Tensor | torch.nn.Parameter
    g_leak: torch.Tensor | torch.nn.Parameter
    E_leak: torch.Tensor | torch.nn.Parameter
    E_k: torch.Tensor | torch.nn.Parameter
    tau_adapt: torch.Tensor | torch.nn.Parameter
    dg_k: torch.Tensor | torch.nn.Parameter
    tau_ref: torch.Tensor | torch.nn.Parameter | None

    def __init__(
        self,
        n_neuron: int | Sequence[int],
        v_threshold: float | Float[TensorLike, " n_neuron"] = 1.0,
        v_reset: float | Float[TensorLike, " n_neuron"] = 0.0,
        c_m: float | Float[TensorLike, " n_neuron"] = 1.0,
        g_leak: float | Float[TensorLike, " n_neuron"] = 1.0,
        E_leak: float | Float[TensorLike, " n_neuron"] = 0.0,
        E_k: float | Float[TensorLike, " n_neuron"] = -70.0,
        g_k_init: float | Float[TensorLike, " n_neuron"] = 0.0,
        tau_adapt: float | Float[TensorLike, " n_neuron"] = 20.0,
        dg_k: float | Float[TensorLike, " n_neuron"] = 0.0,
        tau_ref: float | Float[TensorLike, " n_neuron"] | None = None,
        trainable_param: set[str] = set(),
        surrogate_function: Callable = Sigmoid(),
        detach_reset: bool = False,
        hard_reset: bool = False,
        pre_spike_v: bool = False,
        step_mode: Literal["s"] = "s",
        backend: Literal["torch"] = "torch",
        device=None,
        dtype=None,
    ):
        super().__init__(
            n_neuron=n_neuron,
            v_threshold=v_threshold,
            v_reset=v_reset,
            trainable_param=trainable_param,
            surrogate_function=surrogate_function,
            detach_reset=detach_reset,
            hard_reset=hard_reset,
            pre_spike_v=pre_spike_v,
            step_mode=step_mode,
            backend=backend,
            device=device,
            dtype=dtype,
        )
        _factory_kwargs: dict[str, Any] = {"device": device, "dtype": dtype}
        self.def_param(
            "c_m",
            c_m,
            trainable_param=self.trainable_param,
            **_factory_kwargs,
        )
        self.def_param(
            "g_leak",
            g_leak,
            trainable_param=self.trainable_param,
            **_factory_kwargs,
        )
        self.def_param(
            "E_leak",
            E_leak,
            trainable_param=self.trainable_param,
            **_factory_kwargs,
        )
        self.def_param(
            "E_k",
            E_k,
            trainable_param=self.trainable_param,
            **_factory_kwargs,
        )
        self.def_param(
            "tau_adapt",
            tau_adapt,
            trainable_param=self.trainable_param,
            **_factory_kwargs,
        )
        self.def_param(
            "dg_k",
            dg_k,
            trainable_param=self.trainable_param,
            **_factory_kwargs,
        )
        self._use_refractory = tau_ref is not None
        if self._use_refractory:
            self.def_param(
                "tau_ref",
                tau_ref,
                trainable_param=self.trainable_param,
                **_factory_kwargs,
            )
            self.register_memory("refractory", 0.0, self.n_neuron)
        else:
            self.tau_ref = None

        self.register_memory("g_k", g_k_init, self.n_neuron)

    @property
    def v_rest(self):
        if self._v_rest is None:
            return self.v_reset
        return self._v_rest

    @v_rest.setter
    def v_rest(self, v_rest):
        if self._v_rest is not None:
            self._v_rest = v_rest

    @property
    def v_peak(self):
        return self.v_threshold

    @v_peak.setter
    def v_peak(self, value):
        self.v_threshold = value

    def dV(
        self,
        v: Float[Tensor, "*batch n_neuron"],
        g_k: Float[Tensor, "*batch n_neuron"],
        x: Float[Tensor, "*batch n_neuron"],
    ):
        leak_term = -self.g_leak * (v - self.E_leak)
        adapt_term = -g_k * (v - self.E_k)
        derivative = (leak_term + adapt_term + x) / self.c_m
        linear = (-self.g_leak - g_k) / self.c_m
        return derivative, linear

    def dgk(self, g_k: Float[Tensor, "*batch n_neuron"]):
        # Returns (derivative, linear coefficient) for exp_euler_step,
        # matching the convention used by dV above.
        derivative = -g_k / self.tau_adapt
        linear = -1.0 / self.tau_adapt
        return derivative, linear

    def neuronal_charge(self, x: Float[Tensor, "*batch n_neuron"]):
        dt = environ.get("dt")
        self.v = exp_euler_step(self.dV, self.v, self.g_k, x, dt=dt)

    def neuronal_adaptation(self):
        dt = environ.get("dt")
        self.g_k = exp_euler_step(self.dgk, self.g_k, dt=dt)

    def neuronal_fire(self):
        spike = self.surrogate_function(
            (self.v - self.v_threshold) / (self.v_threshold - self.v_reset)
        )
        if not self._use_refractory:
            return spike
        not_in_refractory = self.refractory == 0
        spike = spike * not_in_refractory.detach().to(self.v.dtype)
        return spike

    def neuronal_reset(self, spike: Float[Tensor, "*batch n"]):
        if self.detach_reset:
            spike_d = spike.detach()
        else:
            spike_d = spike

        if self.pre_spike_v:
            self.v_pre_spike = self.v.clone()

        if self.hard_reset:
            self.v = self.v - (self.v - self.v_reset) * spike_d
        else:
            self.v = self.v - (self.v_threshold - self.v_reset) * spike_d

        self.g_k = self.g_k + self.dg_k * spike_d

        if self._use_refractory:
            self.refractory = torch.relu(
                self.refractory + spike_d * self.tau_ref - environ.get("dt")
            )

    def extra_repr(self):
        g_k_init = self._memories_rv["g_k"].value
        parts = [
            f"c_m={self._format_repr_value(self.c_m)}",
            f"g_leak={self._format_repr_value(self.g_leak)}",
            f"E_leak={self._format_repr_value(self.E_leak)}",
            f"E_k={self._format_repr_value(self.E_k)}",
            f"tau_adapt={self._format_repr_value(self.tau_adapt)}",
            f"dg_k={self._format_repr_value(self.dg_k)}",
            f"g_k_init={self._format_repr_value(g_k_init)}",
            f"tau_ref={self._format_repr_value(self.tau_ref)}"
            if self._use_refractory
            else "tau_ref=None",
        ]
        base = super().extra_repr()
        if base:
            parts.append(base)
        return ", ".join(parts)

ELIF

Bases: ALIF

Exponential integrate-and-fire neuron with adaptation.

The ELIF model extends ALIF by adding an exponential term to the voltage dynamics, creating a sharp upswing when approaching threshold (the "initiation zone"). This captures the rapid depolarization seen in real neurons.

Dynamics

dv/dt = (g_leak * delta_T * exp((v - v_T) / delta_T) - g_leak * (v - E_leak) - g_k * (v - E_k) + x) / c_m
dg_k/dt = -g_k / tau_adapt

The exponential term creates a soft threshold effect where membrane potential accelerates as it approaches v_T.
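This acceleration can be checked numerically. The following standalone sketch evaluates the ELIF drift with the documented default parameters, holding g_k at zero and omitting input (a simplification for illustration):

```python
import math

# ELIF drift with g_k = 0 and no input, using the documented defaults.
g_leak, E_leak, c_m = 1.0, 0.0, 1.0
delta_T, v_T = 1.0, 0.0

def dv(v, x=0.0):
    exp_term = g_leak * delta_T * math.exp((v - v_T) / delta_T)
    return (exp_term - g_leak * (v - E_leak) + x) / c_m

# Above v_T the exponential term outgrows the linear leak, so the upward
# drift increases superlinearly with v.
drifts = [dv(v) for v in (0.0, 1.0, 2.0, 3.0)]
```

Below v_T the linear leak dominates; above it the exponential term grows faster than the leak, producing the sharp spike-initiation upswing described above.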

Parameters:

- n_neuron (int | Sequence[int]): Number of neurons (int or tuple of dimensions). Required.
- v_threshold (float | Float[TensorLike, ' n_neuron']): Firing threshold (mV). Default: 1.0.
- v_reset (float | Float[TensorLike, ' n_neuron']): Reset voltage after spike (mV). Default: 0.0.
- c_m (float | Float[TensorLike, ' n_neuron']): Membrane capacitance (pF). Default: 1.0.
- g_leak (float | Float[TensorLike, ' n_neuron']): Leak conductance (nS). Default: 1.0.
- E_leak (float | Float[TensorLike, ' n_neuron']): Leak reversal potential (mV). Default: 0.0.
- E_k (float | Float[TensorLike, ' n_neuron']): Potassium reversal potential (mV). Default: -70.0.
- g_k_init (float | Float[TensorLike, ' n_neuron']): Initial adaptation conductance (nS). Default: 0.0.
- tau_adapt (float | Float[TensorLike, ' n_neuron']): Adaptation time constant (ms). Default: 20.0.
- dg_k (float | Float[TensorLike, ' n_neuron']): Adaptation increment per spike (nS). Default: 0.0.
- tau_ref (float | Float[TensorLike, ' n_neuron'] | None): Refractory period (ms). Default: 0.0.
- delta_T (float | Float[TensorLike, ' n_neuron']): Slope factor for exponential term (mV). Default: 1.0.
- v_T (float | Float[TensorLike, ' n_neuron']): Soft threshold potential (mV). Default: 0.0.
- trainable_param (set[str]): Set of parameter names to make trainable. Default: set().
- surrogate_function (Callable): Surrogate gradient function. Default: Sigmoid().
- detach_reset (bool): If True, detach reset signal. Default: False.
- hard_reset (bool): If True, use hard reset. Default: False.
- pre_spike_v (bool): If True, store pre-spike voltage. Default: False.
- step_mode (Literal['s']): Step mode. Default: 's'.
- backend (Literal['torch']): Backend implementation. Default: 'torch'.
- device: Device for tensors. Default: None.
- dtype: Data type for tensors. Default: None.

Attributes:

- delta_T (Tensor | Parameter): Slope factor for exponential term.
- v_T (Tensor | Parameter): Soft threshold potential.
Source code in btorch/models/neurons/alif.py
class ELIF(ALIF):
    """Exponential integrate-and-fire neuron with adaptation.

    The ELIF model extends ALIF by adding an exponential term to the
    voltage dynamics, creating a sharp upswing when approaching threshold
    (the "initiation zone"). This captures the rapid depolarization seen
    in real neurons.

    Dynamics:
        dv/dt = (g_leak * delta_T * exp((v - v_T) / delta_T)
                 - g_leak * (v - E_leak) - g_k * (v - E_k) + x) / c_m
        dg_k/dt = -g_k / tau_adapt

    The exponential term creates a soft threshold effect where membrane
    potential accelerates as it approaches v_T.

    Args:
        n_neuron: Number of neurons (int or tuple of dimensions).
        v_threshold: Firing threshold (mV). Default: 1.0.
        v_reset: Reset voltage after spike (mV). Default: 0.0.
        c_m: Membrane capacitance (pF). Default: 1.0.
        g_leak: Leak conductance (nS). Default: 1.0.
        E_leak: Leak reversal potential (mV). Default: 0.0.
        E_k: Potassium reversal potential (mV). Default: -70.0.
        g_k_init: Initial adaptation conductance (nS). Default: 0.0.
        tau_adapt: Adaptation time constant (ms). Default: 20.0.
        dg_k: Adaptation increment per spike (nS). Default: 0.0.
        tau_ref: Refractory period (ms). Default: 0.0.
        delta_T: Slope factor for exponential term (mV). Default: 1.0.
        v_T: Soft threshold potential (mV). Default: 0.0.
        trainable_param: Set of parameter names to make trainable.
        surrogate_function: Surrogate gradient function. Default: Sigmoid().
        detach_reset: If True, detach reset signal. Default: False.
        hard_reset: If True, use hard reset. Default: False.
        pre_spike_v: If True, store pre-spike voltage. Default: False.
        step_mode: Step mode. Default: "s".
        backend: Backend implementation. Default: "torch".
        device: Device for tensors. Default: None.
        dtype: Data type for tensors. Default: None.

    Attributes:
        delta_T: Slope factor for exponential term.
        v_T: Soft threshold potential.
    """

    delta_T: torch.Tensor | torch.nn.Parameter
    v_T: torch.Tensor | torch.nn.Parameter

    def __init__(
        self,
        n_neuron: int | Sequence[int],
        v_threshold: float | Float[TensorLike, " n_neuron"] = 1.0,
        v_reset: float | Float[TensorLike, " n_neuron"] = 0.0,
        c_m: float | Float[TensorLike, " n_neuron"] = 1.0,
        g_leak: float | Float[TensorLike, " n_neuron"] = 1.0,
        E_leak: float | Float[TensorLike, " n_neuron"] = 0.0,
        E_k: float | Float[TensorLike, " n_neuron"] = -70.0,
        g_k_init: float | Float[TensorLike, " n_neuron"] = 0.0,
        tau_adapt: float | Float[TensorLike, " n_neuron"] = 20.0,
        dg_k: float | Float[TensorLike, " n_neuron"] = 0.0,
        tau_ref: float | Float[TensorLike, " n_neuron"] | None = 0.0,
        delta_T: float | Float[TensorLike, " n_neuron"] = 1.0,
        v_T: float | Float[TensorLike, " n_neuron"] = 0.0,
        trainable_param: set[str] = set(),
        surrogate_function: Callable = Sigmoid(),
        detach_reset: bool = False,
        hard_reset: bool = False,
        pre_spike_v: bool = False,
        step_mode: Literal["s"] = "s",
        backend: Literal["torch"] = "torch",
        device=None,
        dtype=None,
    ):
        super().__init__(
            n_neuron=n_neuron,
            v_threshold=v_threshold,
            v_reset=v_reset,
            c_m=c_m,
            g_leak=g_leak,
            E_leak=E_leak,
            E_k=E_k,
            g_k_init=g_k_init,
            tau_adapt=tau_adapt,
            dg_k=dg_k,
            tau_ref=tau_ref,
            trainable_param=trainable_param,
            surrogate_function=surrogate_function,
            detach_reset=detach_reset,
            hard_reset=hard_reset,
            pre_spike_v=pre_spike_v,
            step_mode=step_mode,
            backend=backend,
            device=device,
            dtype=dtype,
        )
        _factory_kwargs: dict[str, Any] = {"device": device, "dtype": dtype}
        self.def_param(
            "delta_T",
            delta_T,
            sizes=self.n_neuron,
            trainable_param=self.trainable_param,
            **_factory_kwargs,
        )
        self.def_param(
            "v_T",
            v_T,
            sizes=self.n_neuron,
            trainable_param=self.trainable_param,
            **_factory_kwargs,
        )

    def dV(
        self,
        v: Float[Tensor, "*batch n_neuron"],
        g_k: Float[Tensor, "*batch n_neuron"],
        x: Float[Tensor, "*batch n_neuron"],
    ):
        leak_term = -self.g_leak * (v - self.E_leak)
        adapt_term = -g_k * (v - self.E_k)
        exp_term = self.g_leak * self.delta_T * torch.exp((v - self.v_T) / self.delta_T)
        derivative = (leak_term + adapt_term + exp_term + x) / self.c_m
        linear = (-self.g_leak - g_k + exp_term / self.delta_T) / self.c_m
        return derivative, linear

    def neuronal_charge(self, x: Float[Tensor, "*batch n_neuron"]):
        dt = environ.get("dt")
        self.v = exp_euler_step(self.dV, self.v, self.g_k, x, dt=dt)

    def extra_repr(self):
        parts = [
            f"delta_T={self._format_repr_value(self.delta_T)}",
            f"v_T={self._format_repr_value(self.v_T)}",
        ]
        base = super().extra_repr()
        if base:
            parts.append(base)
        return ", ".join(parts)

GLIF3

Bases: BaseNode

GLIF3 model with after-spike currents and refractory period.

The GLIF3 model extends standard LIF by adding after-spike currents (ASC) that capture spike-frequency adaptation. Each spike adds asc_amps to the ASC vector, which then decays exponentially with time constants 1/k.

Dynamics

dV/dt = -(V - V_rest) / tau + (I_in + sum(I_asc)) / c_m
dI_asc/dt = -k * I_asc

At spike: I_asc += asc_amps
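A standalone scalar sketch of these dynamics (plain Python floats; asc_amps is made nonzero purely for illustration, and the rheobase expression is derived from the stated dV/dt, matching the quantities get_rheobase receives):

```python
import math

# GLIF3 scalar sketch using the documented defaults; asc_amps is made
# nonzero here for illustration (the default is [0.0]).
v_threshold, v_reset = -50.0, -70.0
v_rest = v_reset                 # v_rest falls back to v_reset when None
c_m, tau, tau_ref = 0.05, 20.0, 2.0
k = [0.2, 0.02]                  # two ASC decay rates (ms^-1)
asc_amps = [-0.1, 0.05]          # pA added to each ASC component per spike

# Rheobase implied by dV/dt = -(V - v_rest)/tau + I/c_m: the steady-state
# voltage v_rest + I*tau/c_m must reach v_threshold.
rheobase = (v_threshold - v_rest) * c_m / tau

dt = 0.1                         # ms
v, Iasc = v_rest, [0.0, 0.0]
x = 2.0 * rheobase               # drive at twice rheobase
refractory = 0.0
n_spikes = 0

for _ in range(1000):            # 100 ms
    dv = -(v - v_rest) / tau + (x + sum(Iasc)) / c_m
    v += dt * dv                 # forward Euler (the library uses exponential Euler)
    Iasc = [I * math.exp(-ki * dt) for I, ki in zip(Iasc, k)]
    if v >= v_threshold and refractory == 0.0:
        n_spikes += 1
        v = v_reset
        Iasc = [I + a for I, a in zip(Iasc, asc_amps)]  # at spike: I_asc += asc_amps
        refractory = tau_ref
    refractory = max(0.0, refractory - dt)
```

Driving the neuron at exactly rheobase only asymptotically approaches threshold; the factor of two guarantees firing, after which the after-spike currents modulate the subsequent rate.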

Parameters:

- n_neuron (int | Sequence[int]): Number of neurons (int or tuple of dimensions). Required.
- v_threshold (float | Float[TensorLike, ' n_neuron']): Firing threshold (mV). Default: -50.0.
- v_reset (float | Float[TensorLike, ' n_neuron']): Reset voltage after spike (mV). Default: -70.0.
- v_rest (None | float | Float[TensorLike, ' n_neuron']): Resting potential (mV). Defaults to v_reset if None. Default: None.
- c_m (float | Float[TensorLike, ' n_neuron']): Membrane capacitance (pF). Default: 0.05.
- tau (float | Float[TensorLike, ' n_neuron']): Membrane time constant (ms). Default: 20.0.
- k (float | Sequence[float] | Float[TensorLike, 'n_neuron {self.n_Iasc}']): ASC decay rates (ms^-1); can be a list for multiple ASC components. Default: [0.2].
- asc_amps (float | Sequence[float] | Float[TensorLike, 'n_neuron {self.n_Iasc}']): ASC amplitudes (pA) added at each spike. Default: [0.0].
- tau_ref (float | Float[TensorLike, ' n_neuron'] | None): Refractory period (ms). Default: 0.0.
- trainable_param (set[str]): Set of parameter names to make trainable. Default: set().
- surrogate_function (Callable): Surrogate gradient function. Default: ATan().
- detach_reset (bool): If True, detach reset signal. Default: False.
- hard_reset (bool): If True, use hard reset. Default: False.
- pre_spike_v (bool): If True, store pre-spike voltage. Default: False.
- step_mode (Literal['s']): Step mode. Default: 's'.
- backend (Literal['torch']): Backend implementation. Default: 'torch'.
- device: Device for tensors. Default: None.
- dtype: Data type for tensors. Default: None.

Attributes:

- v (Tensor): Membrane potential, shape (*batch, n_neuron).
- Iasc (Tensor): After-spike currents, shape (*batch, n_neuron, n_Iasc).
- refractory (Tensor | None): Refractory counter (if tau_ref > 0).
- c_m, tau, tau_ref (Tensor | Parameter): Neuron parameters.
- k (Tensor | Parameter): ASC decay rates, shape (n_neuron, n_Iasc) or (n_Iasc,).
- asc_amps (Tensor | Parameter): ASC amplitudes, shape (n_neuron, n_Iasc) or (n_Iasc,).
- n_Iasc (int): Number of ASC components.

References

Teeter et al., "Generalized leaky integrate-and-fire models classify multiple neuron types," Nature Communications, 2018.

Source code in btorch/models/neurons/glif.py
class GLIF3(BaseNode):
    """GLIF3 model with after-spike currents and refractory period.

    The GLIF3 model extends standard LIF by adding after-spike currents
    (ASC) that capture spike-frequency adaptation. Each spike adds
    asc_amps to the ASC vector, which then decays exponentially with
    time constants 1/k.

    Dynamics:
        dV/dt = -(V - V_rest) / tau + (I_in + sum(I_asc)) / c_m
        dI_asc/dt = -k * I_asc

        At spike: I_asc += asc_amps

    Args:
        n_neuron: Number of neurons (int or tuple of dimensions).
        v_threshold: Firing threshold (mV). Default: -50.0.
        v_reset: Reset voltage after spike (mV). Default: -70.0.
        v_rest: Resting potential (mV). Defaults to v_reset if None.
        c_m: Membrane capacitance (pF). Default: 0.05.
        tau: Membrane time constant (ms). Default: 20.0.
        k: ASC decay rates (ms^-1), can be list for multiple ASC components.
            Default: [0.2].
        asc_amps: ASC amplitudes (pA) added at each spike.
            Default: [0.0].
        tau_ref: Refractory period (ms). Default: 0.0.
        trainable_param: Set of parameter names to make trainable.
        surrogate_function: Surrogate gradient function. Default: ATan().
        detach_reset: If True, detach reset signal. Default: False.
        hard_reset: If True, use hard reset. Default: False.
        pre_spike_v: If True, store pre-spike voltage. Default: False.
        step_mode: Step mode. Default: "s".
        backend: Backend implementation. Default: "torch".
        device: Device for tensors. Default: None.
        dtype: Data type for tensors. Default: None.

    Attributes:
        v: Membrane potential, shape (*batch, n_neuron).
        Iasc: After-spike currents, shape (*batch, n_neuron, n_Iasc).
        refractory: Refractory counter (if tau_ref > 0).
        c_m, tau, tau_ref: Neuron parameters.
        k: ASC decay rates, shape (n_neuron, n_Iasc) or (n_Iasc,).
        asc_amps: ASC amplitudes, shape (n_neuron, n_Iasc) or (n_Iasc,).
        n_Iasc: Number of ASC components.

    References:
        Teeter et al., "Generalized leaky integrate-and-fire models
        classify multiple neuron types," Nature Communications, 2018.
    """

    # make mypy typing and autocompletion easier
    Iasc: torch.Tensor
    refractory: torch.Tensor | None

    c_m: torch.Tensor | torch.nn.Parameter
    tau: torch.Tensor | torch.nn.Parameter
    tau_ref: torch.Tensor | torch.nn.Parameter | None
    k: torch.Tensor | torch.nn.Parameter
    asc_amps: torch.Tensor | torch.nn.Parameter

    def __init__(
        self,
        n_neuron: int | Sequence[int],
        v_threshold: float | Float[TensorLike, " n_neuron"] = -50.0,  # mV
        v_reset: float | Float[TensorLike, " n_neuron"] = -70.0,  # mV
        v_rest: None | float | Float[TensorLike, " n_neuron"] = None,
        c_m: float | Float[TensorLike, " n_neuron"] = 0.05,  # 1/20 pfarad
        tau: float | Float[TensorLike, " n_neuron"] = 20.0,  # ms
        k: float | Sequence[float] | Float[TensorLike, "n_neuron {self.n_Iasc}"] = [
            0.2
        ],  # ms^-1
        asc_amps: float
        | Sequence[float]
        | Float[TensorLike, "n_neuron {self.n_Iasc}"] = [0.0],  # pA
        tau_ref: float | Float[TensorLike, " n_neuron"] | None = 0.0,  # ms
        trainable_param: set[str] = set(),
        surrogate_function: Callable = ATan(),
        detach_reset: bool = False,
        hard_reset: bool = False,
        pre_spike_v: bool = False,
        step_mode: Literal["s"] = "s",
        backend: Literal["torch"] = "torch",
        device=None,
        dtype=None,
    ):
        super().__init__(
            n_neuron=n_neuron,
            v_threshold=v_threshold,
            v_reset=v_reset,
            trainable_param=trainable_param,
            surrogate_function=surrogate_function,
            detach_reset=detach_reset,
            step_mode=step_mode,
            backend=backend,
            pre_spike_v=pre_spike_v,
        )
        _factory_kwargs: dict[str, Any] = {"device": device, "dtype": dtype}
        self.hard_reset = hard_reset
        self.def_param(
            "c_m",
            c_m,
            trainable_param=self.trainable_param,
            **_factory_kwargs,
        )
        self.def_param(
            "tau",
            tau,
            trainable_param=self.trainable_param,
            **_factory_kwargs,
        )
        self._use_refractory = tau_ref is not None
        if self._use_refractory:
            self.def_param(
                "tau_ref",
                tau_ref,
                trainable_param=self.trainable_param,
                **_factory_kwargs,
            )
            self.register_memory("refractory", 0.0, self.n_neuron)
        else:
            self.tau_ref = None

        # for compat
        if v_rest is not None:
            self.def_param(
                "_v_rest",
                v_rest,
                trainable_param=self.trainable_param,
                **_factory_kwargs,
            )
        else:
            self._v_rest = None

        # Handle after-spike currents.
        if isinstance(asc_amps, Number):
            asc_amps = [asc_amps]
        if isinstance(k, Number):
            k = [k]

        resolved_asc_sizes = self.def_param_resolve_sizes(
            k,
            asc_amps,
            sizes=self.n_neuron + (None,),
        )
        self.n_Iasc: int = resolved_asc_sizes[-1]

        self.def_param(
            "k",
            k,
            sizes=resolved_asc_sizes,
            trainable_param=self.trainable_param,
            normalize_to_sizes=True,
            **_factory_kwargs,
        )
        self.def_param(
            "asc_amps",
            asc_amps,
            sizes=resolved_asc_sizes,
            trainable_param=self.trainable_param,
            normalize_to_sizes=True,
            **_factory_kwargs,
        )

        self.register_memory(
            "Iasc",
            [0.0] * self.n_Iasc,
            self.n_neuron + (self.n_Iasc,),
        )

    @property
    def v_rest(self) -> torch.Tensor:
        """Resting potential (mV).

        For compatibility with GLIF4/GLIF5, falls back to v_reset if
        not explicitly set during initialization.

        Returns:
            Resting potential tensor.
        """
        if self._v_rest is None:
            return self.v_reset
        return self._v_rest

    @v_rest.setter
    def v_rest(self, v_rest: float | torch.Tensor):
        """Set resting potential.

        Args:
            v_rest: New resting potential value (mV).
        """
        if self._v_rest is not None:
            self._v_rest = v_rest

    def dIasc(self, Iasc: Float[Tensor, "*batch n_neuron {self.n_Iasc}"]):
        """Compute ASC derivative for exponential Euler integration.

        Args:
            Iasc: After-spike currents, shape (*batch, n_neuron, n_Iasc).

        Returns:
            Tuple of (derivative, linear_coefficient) for exp_euler_step.
        """
        return -self.k * Iasc, -self.k

    def dV(
        self,
        v: Float[Tensor, "*batch n_neuron"],
        Iasc: Float[Tensor, "*batch n_neuron {self.n_Iasc}"],
        x: Float[Tensor, "*batch n_neuron"],
    ):
        """Compute membrane potential derivative for exp Euler integration.

        Args:
            v: Membrane potential, shape (*batch, n_neuron).
            Iasc: After-spike currents, shape (*batch, n_neuron, n_Iasc).
            x: Input current, shape (*batch, n_neuron).

        Returns:
            Tuple of (derivative, linear_coefficient) for exp_euler_step.
        """
        Isum = x
        # torch.autocast will cast half precision to float32 for the sum op;
        # see https://docs.pytorch.org/docs/stable/amp.html#ops-that-can-autocast-to-float32
        # Iasc generally has fewer than 4 components, so overflow is not a concern.
        return (
            -(v - self.v_rest) / self.tau
            + (Isum + Iasc.sum(-1, dtype=Iasc.dtype)) / self.c_m,
            -1.0 / self.tau,
        )

    def neuronal_charge(self, x: Float[Tensor, "*batch n_neuron"]):
        v = exp_euler_step(self.dV, self.v, self.Iasc, x, dt=environ.get("dt"))
        self.v = v

    def neuronal_adaptation(self):
        self.Iasc = exp_euler_step(self.dIasc, self.Iasc, dt=environ.get("dt"))

    def neuronal_fire(self):
        # Check if voltage exceeds threshold and not in refractory period
        spike = self.surrogate_function(
            (self.v - self.v_threshold) / (self.v_threshold - self.v_reset)
        )
        if not self._use_refractory:
            return spike
        not_in_refractory = self.refractory == 0
        spike = spike * not_in_refractory.detach().to(self.v.dtype)
        return spike

    def neuronal_reset(self, spike: Float[Tensor, "*batch n"]):
        if self.detach_reset:
            spike_d = spike.detach()
        else:
            spike_d = spike

        if self.pre_spike_v:
            self.v_pre_spike = self.v.clone()

        if self.hard_reset:
            # hard reset
            self.v = self.v - (self.v - self.v_reset) * spike_d
        else:
            # soft reset
            self.v = self.v - (self.v_threshold - self.v_reset) * spike_d

        # Add after-spike currents
        self.Iasc = self.Iasc + self.asc_amps * spike_d[..., None]

        if self._use_refractory:
            # Set refractory period
            self.refractory = torch.relu(
                self.refractory + spike_d * self.tau_ref - environ.get("dt")
            )

    def get_rheobase(self):
        """Calculate rheobase current, the minimum constant input current
        required to make the neuron fire."""
        return get_rheobase(self.v_threshold, self.v_rest, self.c_m, self.tau)

    def extra_repr(self):
        parts = [
            f"c_m={self._format_repr_value(self.c_m)}",
            f"tau={self._format_repr_value(self.tau)}",
            f"tau_ref={self._format_repr_value(self.tau_ref)}"
            if self._use_refractory
            else "tau_ref=None",
            f"n_Iasc={self.n_Iasc}",
            f"k={self._format_repr_value(self.k)}",
            f"asc_amps={self._format_repr_value(self.asc_amps)}",
            "v_rest=auto"
            if self._v_rest is None
            else f"v_rest={self._format_repr_value(self._v_rest)}",
        ]
        base = super().extra_repr()
        if base:
            parts.append(base)
        return ", ".join(parts)

    # TODO: headache to define precise input-output shapes
    # TODO: shape handling not torch.compile friendly
    def _normalize_state_shapes(
        self,
        x: TensorLike | float,
        v0: TensorLike | float,
        Iasc0: TensorLike | float,
        dt: TensorLike | float,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        device = device or self.v_reset.device
        dtype = dtype or self.v_reset.dtype
        x, v0, Iasc0 = (
            torch.as_tensor(x, device=device, dtype=dtype),
            torch.as_tensor(v0, device=device, dtype=dtype),
            torch.as_tensor(Iasc0, device=device, dtype=dtype),
        )
        if isinstance(dt, float):
            dt = torch.tensor([dt], device=device, dtype=dtype)
        else:
            dt = torch.as_tensor(dt, device=device, dtype=dtype)

        shapes = (x.shape, v0.shape, Iasc0.shape[:-1])
        longest_shape = max(shapes, key=len)
        if dt.shape[0] != longest_shape[0]:
            dt = expand_trailing_dims(dt, longest_shape, broadcast_only=True)

        return x, v0, Iasc0, dt

    def forward_exact_no_spike(
        self,
        x: Float[Tensor, "*batch #neuron"] | Float[Tensor, "*batch"],
        v0: Float[Tensor, "*batch neuron"] | None = None,
        Iasc0: Float[Tensor, "*batch neuron {self.n_Iasc}"] | None = None,
        dt: float
        | Float[TensorLike, "#time *batch neuron"]
        | Float[TensorLike, "#*batch neuron"]
        | Float[TensorLike, "#time *batch"]
        | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        if dt is None:
            dt = environ.get("dt")

        update = (v0 is None) and (Iasc0 is None)
        if v0 is None:
            v0 = self.v
        if Iasc0 is None:
            Iasc0 = self.Iasc

        x, v0, Iasc0, dt = self._normalize_state_shapes(x, v0, Iasc0, dt)

        v_inf = self.v_reset + x * self.tau / self.c_m

        exp_m = torch.exp(-dt / self.tau)
        # (time, batch, neuron, n_Iasc)
        exp_asc = torch.exp(-dt[..., None] * self.k)

        Iasc = Iasc0 * exp_asc

        # degenerate case if tau=tau_asc=1/k
        Iasc_contrib = torch.where(
            torch.abs(self.k - 1 / self.tau[..., None]) > 1e-12,
            (Iasc0 / self.c_m[..., None])
            * (exp_asc - exp_m[..., None])
            / (1.0 / self.tau[..., None] - self.k),
            (Iasc0 / self.c_m[..., None]) * (dt * exp_m)[..., None],
        )
        v = v_inf + (v0 - v_inf) * exp_m + Iasc_contrib.sum(dim=-1)

        if update:
            self.v = v
            self.Iasc = Iasc
        return v, Iasc
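The closed-form update above can be checked against brute-force integration. Below is a standalone sketch in plain Python with illustrative values: no after-spike currents, and v_rest taken equal to v_reset as in the v_inf expression above. One exact exponential step matches many small Euler steps.

```python
import math

# Closed-form (no-spike) membrane update without after-spike currents:
# dv/dt = -(v - v_rest)/tau + x/c_m has the exact solution
# v(dt) = v_inf + (v0 - v_inf) * exp(-dt/tau), v_inf = v_rest + x*tau/c_m.
tau, c_m, v_rest = 20.0, 1.0, 0.0  # illustrative values, not library state
x, v0, dt = 0.5, 0.0, 5.0

v_inf = v_rest + x * tau / c_m
v_exact = v_inf + (v0 - v_inf) * math.exp(-dt / tau)

# Reference: many small forward-Euler steps converge to the same value.
n = 20_000
v, h = v0, dt / n
for _ in range(n):
    v += h * (-(v - v_rest) / tau + x / c_m)

assert abs(v - v_exact) < 1e-3
```

This is why `forward_exact_no_spike` can take a single large `dt` where the Euler-based `neuronal_charge` would need many small steps.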
Attributes
v_rest property writable

Resting potential (mV).

For compatibility with GLIF4/GLIF5, falls back to v_reset if not explicitly set during initialization.

Returns:

Type Description
Tensor

Resting potential tensor.

Functions
dIasc(Iasc)

Compute ASC derivative for exponential Euler integration.

Parameters:

Name Type Description Default
Iasc Float[Tensor, '*batch n_neuron {self.n_Iasc}']

After-spike currents, shape (*batch, n_neuron, n_Iasc).

required

Returns:

Type Description

Tuple of (derivative, linear_coefficient) for exp_euler_step.

Source code in btorch/models/neurons/glif.py
def dIasc(self, Iasc: Float[Tensor, "*batch n_neuron {self.n_Iasc}"]):
    """Compute ASC derivative for exponential Euler integration.

    Args:
        Iasc: After-spike currents, shape (*batch, n_neuron, n_Iasc).

    Returns:
        Tuple of (derivative, linear_coefficient) for exp_euler_step.
    """
    return -self.k * Iasc, -self.k
dV(v, Iasc, x)

Compute membrane potential derivative for exp Euler integration.

Parameters:

Name Type Description Default
v Float[Tensor, '*batch n_neuron']

Membrane potential, shape (*batch, n_neuron).

required
Iasc Float[Tensor, '*batch n_neuron {self.n_Iasc}']

After-spike currents, shape (*batch, n_neuron, n_Iasc).

required
x Float[Tensor, '*batch n_neuron']

Input current, shape (*batch, n_neuron).

required

Returns:

Type Description

Tuple of (derivative, linear_coefficient) for exp_euler_step.

Source code in btorch/models/neurons/glif.py
def dV(
    self,
    v: Float[Tensor, "*batch n_neuron"],
    Iasc: Float[Tensor, "*batch n_neuron {self.n_Iasc}"],
    x: Float[Tensor, "*batch n_neuron"],
):
    """Compute membrane potential derivative for exp Euler integration.

    Args:
        v: Membrane potential, shape (*batch, n_neuron).
        Iasc: After-spike currents, shape (*batch, n_neuron, n_Iasc).
        x: Input current, shape (*batch, n_neuron).

    Returns:
        Tuple of (derivative, linear_coefficient) for exp_euler_step.
    """
    Isum = x
    # torch.autocast will cast half to float32 for sum op
    # see https://docs.pytorch.org/docs/stable/amp.html#ops-that-can-autocast-to-float32
    # here Iasc generally has fewer than 4 modes, so overflow is not a concern
    return (
        -(v - self.v_rest) / self.tau
        + (Isum + Iasc.sum(-1, dtype=Iasc.dtype)) / self.c_m,
        -1.0 / self.tau,
    )
get_rheobase()

Calculate rheobase current, the minimum constant input current required to make the neuron fire.

Source code in btorch/models/neurons/glif.py
def get_rheobase(self):
    """Calculate rheobase current, the minimum constant input current
    required to make the neuron fire."""
    return get_rheobase(self.v_threshold, self.v_rest, self.c_m, self.tau)
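For the LIF-style dynamics above, setting dV/dt = 0 at v = v_threshold gives the standard rheobase I = c_m * (v_threshold - v_rest) / tau. A quick standalone check (the helper below is a hypothetical sketch, assuming the module-level `get_rheobase` implements this steady-state condition):

```python
# Rheobase from the steady-state condition dV/dt = 0 at v = v_threshold:
#   0 = -(v_threshold - v_rest)/tau + I/c_m
#   => I = c_m * (v_threshold - v_rest) / tau
def rheobase(v_threshold, v_rest, c_m, tau):
    return c_m * (v_threshold - v_rest) / tau

# With LIF-like defaults (v_threshold=1.0, v_rest=0.0, c_m=1.0, tau=20.0):
i_rheo = rheobase(1.0, 0.0, 1.0, 20.0)
assert abs(i_rheo - 0.05) < 1e-12
```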

Izhikevich

Bases: BaseNode

Izhikevich neuron with quadratic dynamics and recovery variable.

Efficient model reproducing diverse spiking patterns (tonic, bursting, etc.) via a 2D ODE system with quadratic nonlinearity.

Dynamics

dv/dt = (k*(v-v_rest)*(v-v_threshold) - u + I) / c_m
du/dt = a * (b*(v-v_rest) - u)

At spike: v=v_reset, u=u+d

Parameters:

Name Type Description Default
n_neuron int | Sequence[int]

Number of neurons.

required
v_threshold float | Float[TensorLike, ' n_neuron']

Threshold (mV). Default: 30.0.

30.0
v_reset float | Float[TensorLike, ' n_neuron']

Reset voltage (mV). Default: -65.0.

-65.0
v_rest float | Float[TensorLike, ' n_neuron']

Resting potential (mV). Default: -65.0.

-65.0
v_peak float | Float[TensorLike, ' n_neuron']

Spike cutoff (mV). Default: -40.0.

-40.0
c_m float | Float[TensorLike, ' n_neuron']

Capacitance (pF). Default: 100.0.

100.0
k float | Float[TensorLike, ' n_neuron']

Scaling factor (nS/mV). Default: 0.7.

0.7
a float | Float[TensorLike, ' n_neuron']

Recovery timescale (ms^-1). Default: 0.03.

0.03
b float | Float[TensorLike, ' n_neuron']

Recovery coupling (nS). Default: -2.0.

-2.0
d float | Float[TensorLike, ' n_neuron']

Recovery jump (pA). Default: 100.0.

100.0
trainable_param set[str]

Trainable parameters. Default: ().

set()
surrogate_function Callable

Surrogate for backprop. Default: Sigmoid().

Sigmoid()
detach_reset bool

Detach reset signal. Default: False.

False
hard_reset bool

Hard vs soft reset. Default: False.

False
pre_spike bool

Store pre-spike values. Default: False.

False
step_mode Literal['s']

Step mode. Default: "s".

's'
backend Literal['torch']

Backend. Default: "torch".

'torch'
device

Device. Default: None.

None
dtype

Dtype. Default: None.

None

Attributes:

Name Type Description
v Tensor

Membrane potential (*batch, n_neuron).

u Tensor

Recovery variable (*batch, n_neuron).

Reference

Izhikevich, IEEE Trans. Neural Networks, 2003.

Source code in btorch/models/neurons/izhikevich.py
class Izhikevich(BaseNode):
    """Izhikevich neuron with quadratic dynamics and recovery variable.

    Efficient model reproducing diverse spiking patterns (tonic, bursting,
    etc.) via a 2D ODE system with quadratic nonlinearity.

    Dynamics:
        dv/dt = (k*(v-v_rest)*(v-v_threshold) - u + I) / c_m
        du/dt = a * (b*(v-v_rest) - u)

    At spike: v=v_reset, u=u+d

    Args:
        n_neuron: Number of neurons.
        v_threshold: Threshold (mV). Default: 30.0.
        v_reset: Reset voltage (mV). Default: -65.0.
        v_rest: Resting potential (mV). Default: -65.0.
        v_peak: Spike cutoff (mV). Default: -40.0.
        c_m: Capacitance (pF). Default: 100.0.
        k: Scaling factor (nS/mV). Default: 0.7.
        a: Recovery timescale (ms^-1). Default: 0.03.
        b: Recovery coupling (nS). Default: -2.0.
        d: Recovery jump (pA). Default: 100.0.
        trainable_param: Trainable parameters. Default: ().
        surrogate_function: Surrogate for backprop. Default: Sigmoid().
        detach_reset: Detach reset signal. Default: False.
        hard_reset: Hard vs soft reset. Default: False.
        pre_spike: Store pre-spike values. Default: False.
        step_mode: Step mode. Default: "s".
        backend: Backend. Default: "torch".
        device: Device. Default: None.
        dtype: Dtype. Default: None.

    Attributes:
        v: Membrane potential (*batch, n_neuron).
        u: Recovery variable (*batch, n_neuron).

    Reference:
        Izhikevich, IEEE Trans. Neural Networks, 2003.
    """

    HIPPOCAMPOME_TO_ARGS = {
        "k": "k",
        "a": "a",
        "b": "b",
        "d": "d",
        "C": "c_m",
        "vr": "v_rest",
        "vt": "v_threshold",
        "vpeak": "v_peak",
        "vmin": "v_reset",
    }

    u: torch.Tensor
    u_pre_spike: torch.Tensor

    v_reset: torch.Tensor | torch.nn.Parameter
    v_rest: torch.Tensor | torch.nn.Parameter
    v_peak: torch.Tensor | torch.nn.Parameter
    c_m: torch.Tensor | torch.nn.Parameter
    k: torch.Tensor | torch.nn.Parameter
    a: torch.Tensor | torch.nn.Parameter
    b: torch.Tensor | torch.nn.Parameter
    d: torch.Tensor | torch.nn.Parameter

    def __init__(
        self,
        n_neuron: int | Sequence[int],
        v_threshold: float | Float[TensorLike, " n_neuron"] = 30.0,
        v_reset: float | Float[TensorLike, " n_neuron"] = -65.0,
        v_rest: float | Float[TensorLike, " n_neuron"] = -65.0,
        v_peak: float | Float[TensorLike, " n_neuron"] = -40.0,
        c_m: float | Float[TensorLike, " n_neuron"] = 100.0,
        k: float | Float[TensorLike, " n_neuron"] = 0.7,
        a: float | Float[TensorLike, " n_neuron"] = 0.03,
        b: float | Float[TensorLike, " n_neuron"] = -2.0,
        d: float | Float[TensorLike, " n_neuron"] = 100.0,
        trainable_param: set[str] = set(),
        surrogate_function: Callable = Sigmoid(),
        detach_reset: bool = False,
        hard_reset: bool = False,
        pre_spike: bool = False,
        step_mode: Literal["s"] = "s",
        backend: Literal["torch"] = "torch",
        device=None,
        dtype=None,
    ):
        super().__init__(
            n_neuron=n_neuron,
            v_threshold=v_threshold,
            v_reset=v_reset,
            trainable_param=trainable_param,
            surrogate_function=surrogate_function,
            detach_reset=detach_reset,
            hard_reset=hard_reset,
            pre_spike_v=pre_spike,
            step_mode=step_mode,
            backend=backend,
            device=device,
            dtype=dtype,
        )
        _factory_kwargs: dict[str, Any] = {"device": device, "dtype": dtype}
        self.def_param(
            "c_m",
            c_m,
            sizes=self.n_neuron,
            trainable_param=self.trainable_param,
            **_factory_kwargs,
        )
        self.def_param(
            "v_rest",
            v_rest,
            sizes=self.n_neuron,
            trainable_param=self.trainable_param,
            **_factory_kwargs,
        )
        self.def_param(
            "v_peak",
            v_peak,
            sizes=self.n_neuron,
            trainable_param=self.trainable_param,
            **_factory_kwargs,
        )
        self.def_param(
            "k",
            k,
            sizes=self.n_neuron,
            trainable_param=self.trainable_param,
            **_factory_kwargs,
        )
        self.def_param(
            "a",
            a,
            sizes=self.n_neuron,
            trainable_param=self.trainable_param,
            **_factory_kwargs,
        )
        self.def_param(
            "b",
            b,
            sizes=self.n_neuron,
            trainable_param=self.trainable_param,
            **_factory_kwargs,
        )
        self.def_param(
            "d",
            d,
            sizes=self.n_neuron,
            trainable_param=self.trainable_param,
            **_factory_kwargs,
        )

        self.register_memory("u", 0, self.n_neuron)
        if pre_spike:
            self.register_memory("u_pre_spike", None, self.n_neuron)

    @classmethod
    def from_hippocampome(
        cls,
        n_neuron: int | Sequence[int],
        k,
        a,
        b,
        d,
        C,
        vr,
        vt,
        vpeak,
        vmin,
        **kwargs,
    ):
        """
        Build an :class:`Izhikevich` neuron using parameter names from
        https://hippocampome.org.

        Parameter mapping (HippoCampome -> Izhikevich args):
        - k -> k (scaling factor)
        - a -> a (recovery time constant)
        - b -> b (recovery sensitivity)
        - d -> d (reset current)
        - C -> c_m (capacitance)
        - vr -> v_rest (resting potential)
        - vt -> v_threshold (instantaneous threshold)
        - vpeak -> v_peak (spike cutoff)
        - vmin -> v_reset (post-spike reset voltage)

        All values are expected in the same units as the canonical
        Izhikevich model (mV, pF, pA).
        """
        kwargs.setdefault("pre_spike", True)
        return cls(
            n_neuron,
            v_threshold=vt,
            v_reset=vmin,
            v_rest=vr,
            v_peak=vpeak,
            c_m=C,
            k=k,
            a=a,
            b=b,
            d=d,
            **kwargs,
        )

    @classmethod
    def from_canonical_quadratic(
        cls,
        n_neuron: int | Sequence[int],
        p1: float = 0.04,
        p2: float = 5.0,
        # TODO: p3: float = 0.0, adjust equation
        v_rest: float = -65.0,
        c_m: float = 1.0,
        v_peak: float = 30.0,
        **kwargs,
    ):
        """
        Instantiate using the canonical quadratic form
        ``dV/dt = p1*v^2 + p2*v + p3 - u + I``.

        The mapping assumes ``c_m`` acts as the membrane capacitance and that
        ``k/c_m`` equals ``p1``. The linear term enforces
        ``v_threshold = -p2/p1 - v_rest``. Remaining
        keyword arguments are passed directly to :class:`Izhikevich`.
        """
        k = p1 * c_m
        v_threshold = -p2 / p1 - v_rest
        # i_bias = p3 - p1 * v_rest * v_threshold

        return cls(
            n_neuron,
            v_threshold=v_threshold,
            v_reset=kwargs.pop("v_reset", v_rest),
            v_rest=v_rest,
            v_peak=v_peak,
            c_m=c_m,
            k=k,
            **kwargs,
        )

    def dV(
        self,
        v: Float[Tensor, "*batch n_neuron"],
        u: Float[Tensor, "*batch n_neuron"],
        x: Float[Tensor, "*batch n_neuron"],
    ):
        quadratic = self.k * (v - self.v_rest) * (v - self.v_threshold)
        return (x + quadratic - u) / self.c_m

    def dU(
        self,
        u: Float[Tensor, "*batch n_neuron"],
        v: Float[Tensor, "*batch n_neuron"],
    ):
        return self.a * (self.b * (v - self.v_rest) - u)

    def neuronal_charge(self, x: Float[Tensor, "*batch n_neuron"]):
        dt = environ.get("dt")
        self.v = euler_step(self.dV, self.v, self.u, x, dt=dt)

    def neuronal_adaptation(self):
        dt = environ.get("dt")
        self.u = euler_step(self.dU, self.u, self.v, dt=dt)

    def neuronal_fire(self):
        # TODO: confirm scaling with (self.v_threshold - self.v_reset)
        # or (self.v_peak - self.v_reset)
        spike = self.surrogate_function(
            (self.v - self.v_peak) / (self.v_threshold - self.v_reset)
        )
        return spike

    def neuronal_reset(self, spike: Float[Tensor, "*batch n"]):
        if self.detach_reset:
            spike_d = spike.detach()
        else:
            spike_d = spike

        if self.pre_spike_v:
            self.v_pre_spike = self.v.clone()
            self.u_pre_spike = self.u.clone()

        if self.hard_reset:
            self.v = self.v - (self.v - self.v_reset) * spike_d
        else:
            self.v = self.v - (self.v_peak - self.v_reset) * spike_d

        self.u = self.u + self.d * spike_d

    def extra_repr(self):
        parts = [
            f"c_m={self._format_repr_value(self.c_m)}",
            f"k={self._format_repr_value(self.k)}",
            f"a={self._format_repr_value(self.a)}",
            f"b={self._format_repr_value(self.b)}",
            f"d={self._format_repr_value(self.d)}",
            f"v_rest={self._format_repr_value(self.v_rest)}",
            f"v_peak={self._format_repr_value(self.v_peak)}",
        ]
        base = super().extra_repr()
        if base:
            parts.append(base)
        return ", ".join(parts)
Functions
from_canonical_quadratic(n_neuron, p1=0.04, p2=5.0, v_rest=-65.0, c_m=1.0, v_peak=30.0, **kwargs) classmethod

Instantiate using the canonical quadratic form dV/dt = p1*v^2 + p2*v + p3 - u + I.

The mapping assumes c_m acts as the membrane capacitance and that k/c_m equals p1. The linear term enforces v_threshold = -p2/p1 - v_rest. Remaining keyword arguments are passed directly to Izhikevich.

Source code in btorch/models/neurons/izhikevich.py
@classmethod
def from_canonical_quadratic(
    cls,
    n_neuron: int | Sequence[int],
    p1: float = 0.04,
    p2: float = 5.0,
    # TODO: p3: float = 0.0, adjust equation
    v_rest: float = -65.0,
    c_m: float = 1.0,
    v_peak: float = 30.0,
    **kwargs,
):
    """
    Instantiate using the canonical quadratic form
    ``dV/dt = p1*v^2 + p2*v + p3 - u + I``.

    The mapping assumes ``c_m`` acts as the membrane capacitance and that
    ``k/c_m`` equals ``p1``. The linear term enforces
    ``v_threshold = -p2/p1 - v_rest``. Remaining
    keyword arguments are passed directly to :class:`Izhikevich`.
    """
    k = p1 * c_m
    v_threshold = -p2 / p1 - v_rest
    # i_bias = p3 - p1 * v_rest * v_threshold

    return cls(
        n_neuron,
        v_threshold=v_threshold,
        v_reset=kwargs.pop("v_reset", v_rest),
        v_rest=v_rest,
        v_peak=v_peak,
        c_m=c_m,
        k=k,
        **kwargs,
    )
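The mapping can be verified by expanding the quadratic. The values below are the canonical 2003 parameters used as this method's defaults; the leftover constant term is the offset the commented-out `i_bias` line would absorb (p3 handling is still a TODO in the source):

```python
# Expanding (k/c_m)*(v - v_rest)*(v - v_threshold) gives
#   p1*v**2 - p1*(v_rest + v_threshold)*v + p1*v_rest*v_threshold,
# so p1 = k/c_m and, matching the linear term, v_threshold = -p2/p1 - v_rest.
p1, p2, v_rest, c_m = 0.04, 5.0, -65.0, 1.0  # canonical 2003 values
k = p1 * c_m
v_threshold = -p2 / p1 - v_rest  # -125 - (-65) = -60 mV

# The quadratic and linear parts agree at any test voltage; only the
# constant offset p1*v_rest*v_threshold is left unmatched.
v = -50.0
lhs = (k / c_m) * (v - v_rest) * (v - v_threshold)
rhs = p1 * v**2 + p2 * v + p1 * v_rest * v_threshold
assert abs(lhs - rhs) < 1e-9
```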
from_hippocampome(n_neuron, k, a, b, d, C, vr, vt, vpeak, vmin, **kwargs) classmethod

Build an Izhikevich neuron using parameter names from https://hippocampome.org.

Parameter mapping (HippoCampome -> Izhikevich args):

- k -> k (scaling factor)
- a -> a (recovery time constant)
- b -> b (recovery sensitivity)
- d -> d (reset current)
- C -> c_m (capacitance)
- vr -> v_rest (resting potential)
- vt -> v_threshold (instantaneous threshold)
- vpeak -> v_peak (spike cutoff)
- vmin -> v_reset (post-spike reset voltage)

All values are expected in the same units as the canonical Izhikevich model (mV, pF, pA).

Source code in btorch/models/neurons/izhikevich.py
@classmethod
def from_hippocampome(
    cls,
    n_neuron: int | Sequence[int],
    k,
    a,
    b,
    d,
    C,
    vr,
    vt,
    vpeak,
    vmin,
    **kwargs,
):
    """
    Build an :class:`Izhikevich` neuron using parameter names from
    https://hippocampome.org.

    Parameter mapping (HippoCampome -> Izhikevich args):
    - k -> k (scaling factor)
    - a -> a (recovery time constant)
    - b -> b (recovery sensitivity)
    - d -> d (reset current)
    - C -> c_m (capacitance)
    - vr -> v_rest (resting potential)
    - vt -> v_threshold (instantaneous threshold)
    - vpeak -> v_peak (spike cutoff)
    - vmin -> v_reset (post-spike reset voltage)

    All values are expected in the same units as the canonical
    Izhikevich model (mV, pF, pA).
    """
    kwargs.setdefault("pre_spike", True)
    return cls(
        n_neuron,
        v_threshold=vt,
        v_reset=vmin,
        v_rest=vr,
        v_peak=vpeak,
        c_m=C,
        k=k,
        a=a,
        b=b,
        d=d,
        **kwargs,
    )
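As a sanity check of the dynamics, the two ODEs and the reset rule can be stepped with plain Euler outside the class. The parameter values below are illustrative regular-spiking cortical values in the (k, v_rest, v_threshold) form (after Izhikevich, 2007), not library defaults:

```python
# Euler sketch of: C*dv/dt = k*(v-vr)*(v-vt) - u + I, du/dt = a*(b*(v-vr) - u)
# with the spike rule v >= vpeak -> v = vreset, u += d.
C, k, vr, vt = 100.0, 0.7, -60.0, -40.0      # pF, nS/mV, mV, mV
vpeak, vreset = 35.0, -50.0                  # mV
a, b, d, I = 0.03, -2.0, 100.0, 100.0        # ms^-1, nS, pA, pA

v, u, dt = vr, 0.0, 0.25                     # start at rest, dt in ms
spike_times = []
for step in range(2000):                     # 500 ms of simulated time
    v += dt * (k * (v - vr) * (v - vt) - u + I) / C
    u += dt * a * (b * (v - vr) - u)
    if v >= vpeak:                           # spike: reset v, bump recovery
        v = vreset
        u += d
        spike_times.append(step * dt)

assert len(spike_times) >= 2                 # tonic spiking under constant drive
```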

LIF

Bases: BaseNode

Leaky integrate-and-fire neuron with optional refractory period.

The LIF neuron integrates input current while leaking towards a resting potential. When the membrane potential exceeds a threshold, a spike is emitted and the potential is reset.

Dynamics

dV/dt = -(V - V_reset) / tau + I / c_m

If tau_ref is specified, a refractory period prevents spiking for tau_ref milliseconds after each spike.

Parameters:

Name Type Description Default
n_neuron int | Sequence[int]

Number of neurons (int or tuple of dimensions).

required
v_threshold float | Float[TensorLike, ' n_neuron']

Firing threshold (mV). Default: 1.0.

1.0
v_reset float | Float[TensorLike, ' n_neuron']

Reset voltage after spike (mV). Default: 0.0.

0.0
c_m float | Float[TensorLike, ' n_neuron']

Membrane capacitance. Default: 1.0.

1.0
tau float | Float[TensorLike, ' n_neuron']

Membrane time constant (ms). Default: 20.0.

20.0
tau_ref float | Float[TensorLike, ' n_neuron'] | None

Refractory period duration (ms). None disables refractory behavior. Default: None.

None
trainable_param set[str]

Set of parameter names to make trainable. Default: empty set.

set()
surrogate_function Callable

Surrogate gradient function for backpropagation. Default: Sigmoid().

Sigmoid()
detach_reset bool

If True, detach reset signal from computation graph. Default: False.

False
hard_reset bool

If True, reset to v_reset directly. If False, subtract (v_threshold - v_reset) from membrane potential (soft reset). Default: False.

False
pre_spike_v bool

If True, store pre-spike voltage in v_pre_spike buffer. Default: False.

False
step_mode Literal['s']

Step mode, currently only "s" (single step) supported. Default: "s".

's'
backend Literal['torch']

Backend implementation. Default: "torch".

'torch'
device

Device for tensors. Default: None.

None
dtype

Data type for tensors. Default: None.

None

Attributes:

Name Type Description
v Tensor

Membrane potential tensor, shape (*batch, n_neuron).

refractory Tensor | None

Refractory counter (if tau_ref specified).

c_m Tensor | Parameter

Membrane capacitance (parameter or buffer).

tau Tensor | Parameter

Time constant (parameter or buffer).

tau_ref Tensor | Parameter | None

Refractory period (parameter or buffer, or None).

Shape
  • Input: (*batch, n_neuron)
  • Output: (*batch, n_neuron) spike tensor (0 or 1)
Source code in btorch/models/neurons/lif.py
class LIF(BaseNode):
    """Leaky integrate-and-fire neuron with optional refractory period.

    The LIF neuron integrates input current while leaking towards a resting
    potential. When the membrane potential exceeds a threshold, a spike is
    emitted and the potential is reset.

    Dynamics:
        dV/dt = -(V - V_reset) / tau + I / c_m

        If tau_ref is specified, a refractory period prevents spiking for
        tau_ref milliseconds after each spike.

    Args:
        n_neuron: Number of neurons (int or tuple of dimensions).
        v_threshold: Firing threshold (mV). Default: 1.0.
        v_reset: Reset voltage after spike (mV). Default: 0.0.
        c_m: Membrane capacitance. Default: 1.0.
        tau: Membrane time constant (ms). Default: 20.0.
        tau_ref: Refractory period duration (ms). None disables refractory
            behavior. Default: None.
        trainable_param: Set of parameter names to make trainable.
            Default: empty set.
        surrogate_function: Surrogate gradient function for backpropagation.
            Default: Sigmoid().
        detach_reset: If True, detach reset signal from computation graph.
            Default: False.
        hard_reset: If True, reset to v_reset directly. If False, subtract
            (v_threshold - v_reset) from membrane potential (soft reset).
            Default: False.
        pre_spike_v: If True, store pre-spike voltage in v_pre_spike buffer.
            Default: False.
        step_mode: Step mode, currently only "s" (single step) supported.
            Default: "s".
        backend: Backend implementation. Default: "torch".
        device: Device for tensors. Default: None.
        dtype: Data type for tensors. Default: None.

    Attributes:
        v: Membrane potential tensor, shape (*batch, n_neuron).
        refractory: Refractory counter (if tau_ref specified).
        c_m: Membrane capacitance (parameter or buffer).
        tau: Time constant (parameter or buffer).
        tau_ref: Refractory period (parameter or buffer, or None).

    Shape:
        - Input: (*batch, n_neuron)
        - Output: (*batch, n_neuron) spike tensor (0 or 1)
    """

    refractory: torch.Tensor | None
    c_m: torch.Tensor | torch.nn.Parameter
    tau: torch.Tensor | torch.nn.Parameter
    tau_ref: torch.Tensor | torch.nn.Parameter | None

    def __init__(
        self,
        n_neuron: int | Sequence[int],
        v_threshold: float | Float[TensorLike, " n_neuron"] = 1.0,
        v_reset: float | Float[TensorLike, " n_neuron"] = 0.0,
        c_m: float | Float[TensorLike, " n_neuron"] = 1.0,
        tau: float | Float[TensorLike, " n_neuron"] = 20.0,
        tau_ref: float | Float[TensorLike, " n_neuron"] | None = None,
        trainable_param: set[str] = set(),
        surrogate_function: Callable = Sigmoid(),
        detach_reset: bool = False,
        hard_reset: bool = False,
        pre_spike_v: bool = False,
        step_mode: Literal["s"] = "s",
        backend: Literal["torch"] = "torch",
        device=None,
        dtype=None,
    ):
        super().__init__(
            n_neuron=n_neuron,
            v_threshold=v_threshold,
            v_reset=v_reset,
            trainable_param=trainable_param,
            surrogate_function=surrogate_function,
            detach_reset=detach_reset,
            hard_reset=hard_reset,
            pre_spike_v=pre_spike_v,
            step_mode=step_mode,
            backend=backend,
            device=device,
            dtype=dtype,
        )
        _factory_kwargs: dict[str, Any] = {"device": device, "dtype": dtype}
        self.def_param(
            "c_m",
            c_m,
            sizes=self.n_neuron,
            trainable_param=self.trainable_param,
            **_factory_kwargs,
        )
        self.def_param(
            "tau",
            tau,
            sizes=self.n_neuron,
            trainable_param=self.trainable_param,
            **_factory_kwargs,
        )
        self._use_refractory = tau_ref is not None
        if self._use_refractory:
            self.def_param(
                "tau_ref",
                tau_ref,
                sizes=self.n_neuron,
                trainable_param=self.trainable_param,
                **_factory_kwargs,
            )
            self.register_memory("refractory", 0.0, self.n_neuron)
        else:
            self.tau_ref = None

    def dV(
        self,
        v: Float[Tensor, "*batch n_neuron"],
        x: Float[Tensor, "*batch n_neuron"],
    ):
        derivative = -(v - self.v_reset) / self.tau + x / self.c_m
        return derivative

    def neuronal_charge(self, x: Float[Tensor, "*batch n_neuron"]):
        v = euler_step(self.dV, self.v, x, dt=environ.get("dt"))
        self.v = v

    def neuronal_adaptation(self):
        # LIF has no intrinsic adaptation other than the refractory counter.
        return None

    def neuronal_fire(self):
        spike = self.surrogate_function(
            (self.v - self.v_threshold) / (self.v_threshold - self.v_reset)
        )
        if not self._use_refractory:
            return spike
        not_in_refractory = self.refractory == 0
        spike = spike * not_in_refractory.detach().to(self.v.dtype)
        return spike

    def neuronal_reset(self, spike: Float[Tensor, "*batch n"]):
        if self.detach_reset:
            spike_d = spike.detach()
        else:
            spike_d = spike

        if self.pre_spike_v:
            self.v_pre_spike = self.v.clone()

        if self.hard_reset:
            self.v = self.v - (self.v - self.v_reset) * spike_d
        else:
            self.v = self.v - (self.v_threshold - self.v_reset) * spike_d

        if self._use_refractory:
            self.refractory = torch.relu(
                self.refractory + spike_d * self.tau_ref - environ.get("dt")
            )

    def extra_repr(self):
        parts = [
            f"c_m={self._format_repr_value(self.c_m)}",
            f"tau={self._format_repr_value(self.tau)}",
            f"tau_ref={self._format_repr_value(self.tau_ref)}"
            if self._use_refractory
            else "tau_ref=None",
        ]
        base = super().extra_repr()
        if base:
            parts.append(base)
        return ", ".join(parts)

btorch.models.neurons.alif

Adaptive leaky integrate-and-fire (ALIF) neuron models.

This module provides ALIF and ELIF (exponential LIF) neuron implementations with conductance-based adaptation mechanisms.

The ALIF neuron extends LIF by adding a potassium conductance (g_k) that increases with each spike, creating spike-frequency adaptation:

dv/dt = (-g_leak * (v - E_leak) - g_k * (v - E_k) + x) / c_m
dg_k/dt = -g_k / tau_adapt

where g_k increments by dg_k at each spike and decays exponentially, creating negative feedback that slows firing rate over time.
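A minimal plain-Python Euler sketch of this loop (illustrative values, hard reset, no refractory period) shows the effect: each spike increments g_k, the added outward current (E_k < v) slows the approach to threshold, and successive inter-spike intervals lengthen:

```python
# Euler sketch of the ALIF adaptation loop:
#   dv/dt   = (-g_leak*(v - E_leak) - g_k*(v - E_k) + x) / c_m
#   dg_k/dt = -g_k / tau_adapt,  with g_k += dg_k at each spike.
g_leak, E_leak, E_k, c_m = 1.0, 0.0, -70.0, 1.0  # illustrative values
tau_adapt, dg_k = 20.0, 0.005
v_th, v_reset, x = 1.0, 0.0, 2.0                 # constant suprathreshold drive

v, g_k, dt = 0.0, 0.0, 0.05
spike_times = []
for step in range(2000):                         # 100 ms of simulated time
    v += dt * (-g_leak * (v - E_leak) - g_k * (v - E_k) + x) / c_m
    g_k += dt * (-g_k / tau_adapt)
    if v >= v_th:
        v = v_reset
        g_k += dg_k                              # spike-triggered increment
        spike_times.append(step * dt)

isis = [b - a for a, b in zip(spike_times, spike_times[1:])]
assert len(isis) >= 2 and isis[1] > isis[0]      # spike-frequency adaptation
```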

Attributes

TensorLike = np.ndarray | torch.Tensor module-attribute

Classes

ALIF

Bases: BaseNode

Adaptive leaky integrate-and-fire neuron with conductance adaptation.

The ALIF model extends standard LIF by adding a voltage-dependent potassium conductance (g_k) that creates spike-frequency adaptation. Each spike increases g_k by dg_k, which then decays exponentially.

Dynamics

dv/dt = (-g_leak * (v - E_leak) - g_k * (v - E_k) + x) / c_m
dg_k/dt = -g_k / tau_adapt

At spike: g_k += dg_k

Parameters:

Name Type Description Default
n_neuron int | Sequence[int]

Number of neurons (int or tuple of dimensions).

required
v_threshold float | Float[TensorLike, ' n_neuron']

Firing threshold (mV). Default: 1.0.

1.0
v_reset float | Float[TensorLike, ' n_neuron']

Reset voltage after spike (mV). Default: 0.0.

0.0
c_m float | Float[TensorLike, ' n_neuron']

Membrane capacitance (pF). Default: 1.0.

1.0
g_leak float | Float[TensorLike, ' n_neuron']

Leak conductance (nS). Default: 1.0.

1.0
E_leak float | Float[TensorLike, ' n_neuron']

Leak reversal potential (mV). Default: 0.0.

0.0
E_k float | Float[TensorLike, ' n_neuron']

Potassium reversal potential (mV). Default: -70.0.

-70.0
g_k_init float | Float[TensorLike, ' n_neuron']

Initial adaptation conductance (nS). Default: 0.0.

0.0
tau_adapt float | Float[TensorLike, ' n_neuron']

Adaptation time constant (ms). Default: 20.0.

20.0
dg_k float | Float[TensorLike, ' n_neuron']

Adaptation increment per spike (nS). Default: 0.0.

0.0
tau_ref float | Float[TensorLike, ' n_neuron'] | None

Refractory period (ms). None disables refractory. Default: None.

None
trainable_param set[str]

Set of parameter names to make trainable.

set()
surrogate_function Callable

Surrogate gradient function. Default: Sigmoid().

Sigmoid()
detach_reset bool

If True, detach reset signal. Default: False.

False
hard_reset bool

If True, use hard reset. Default: False.

False
pre_spike_v bool

If True, store pre-spike voltage. Default: False.

False
step_mode Literal['s']

Step mode. Default: "s".

's'
backend Literal['torch']

Backend implementation. Default: "torch".

'torch'
device

Device for tensors. Default: None.

None
dtype

Data type for tensors. Default: None.

None

Attributes:

Name Type Description
v Tensor

Membrane potential, shape (*batch, n_neuron).

g_k Tensor

Adaptation conductance, shape (*batch, n_neuron).

refractory Tensor | None

Refractory counter (if tau_ref specified).

c_m, g_leak, E_leak, E_k

Neuron parameters.

tau_adapt Tensor | Parameter

Adaptation time constant.

dg_k Tensor | Parameter

Per-spike adaptation increment.

Source code in btorch/models/neurons/alif.py
class ALIF(BaseNode):
    """Adaptive leaky integrate-and-fire neuron with conductance adaptation.

    The ALIF model extends standard LIF by adding a voltage-dependent
    potassium conductance (g_k) that creates spike-frequency adaptation.
    Each spike increases g_k by dg_k, which then decays exponentially.

    Dynamics:
        dv/dt = (-g_leak * (v - E_leak) - g_k * (v - E_k) + x) / c_m
        dg_k/dt = -g_k / tau_adapt

        At spike: g_k += dg_k

    Args:
        n_neuron: Number of neurons (int or tuple of dimensions).
        v_threshold: Firing threshold (mV). Default: 1.0.
        v_reset: Reset voltage after spike (mV). Default: 0.0.
        c_m: Membrane capacitance (pF). Default: 1.0.
        g_leak: Leak conductance (nS). Default: 1.0.
        E_leak: Leak reversal potential (mV). Default: 0.0.
        E_k: Potassium reversal potential (mV). Default: -70.0.
        g_k_init: Initial adaptation conductance (nS). Default: 0.0.
        tau_adapt: Adaptation time constant (ms). Default: 20.0.
        dg_k: Adaptation increment per spike (nS). Default: 0.0.
        tau_ref: Refractory period (ms). None disables refractory.
            Default: None.
        trainable_param: Set of parameter names to make trainable.
        surrogate_function: Surrogate gradient function. Default: Sigmoid().
        detach_reset: If True, detach reset signal. Default: False.
        hard_reset: If True, use hard reset. Default: False.
        pre_spike_v: If True, store pre-spike voltage. Default: False.
        step_mode: Step mode. Default: "s".
        backend: Backend implementation. Default: "torch".
        device: Device for tensors. Default: None.
        dtype: Data type for tensors. Default: None.

    Attributes:
        v: Membrane potential, shape (*batch, n_neuron).
        g_k: Adaptation conductance, shape (*batch, n_neuron).
        refractory: Refractory counter (if tau_ref specified).
        c_m, g_leak, E_leak, E_k: Neuron parameters.
        tau_adapt: Adaptation time constant.
        dg_k: Per-spike adaptation increment.
    """

    g_k: torch.Tensor
    refractory: torch.Tensor | None

    c_m: torch.Tensor | torch.nn.Parameter
    g_leak: torch.Tensor | torch.nn.Parameter
    E_leak: torch.Tensor | torch.nn.Parameter
    E_k: torch.Tensor | torch.nn.Parameter
    tau_adapt: torch.Tensor | torch.nn.Parameter
    dg_k: torch.Tensor | torch.nn.Parameter
    tau_ref: torch.Tensor | torch.nn.Parameter | None

    def __init__(
        self,
        n_neuron: int | Sequence[int],
        v_threshold: float | Float[TensorLike, " n_neuron"] = 1.0,
        v_reset: float | Float[TensorLike, " n_neuron"] = 0.0,
        c_m: float | Float[TensorLike, " n_neuron"] = 1.0,
        g_leak: float | Float[TensorLike, " n_neuron"] = 1.0,
        E_leak: float | Float[TensorLike, " n_neuron"] = 0.0,
        E_k: float | Float[TensorLike, " n_neuron"] = -70.0,
        g_k_init: float | Float[TensorLike, " n_neuron"] = 0.0,
        tau_adapt: float | Float[TensorLike, " n_neuron"] = 20.0,
        dg_k: float | Float[TensorLike, " n_neuron"] = 0.0,
        tau_ref: float | Float[TensorLike, " n_neuron"] | None = None,
        trainable_param: set[str] = set(),
        surrogate_function: Callable = Sigmoid(),
        detach_reset: bool = False,
        hard_reset: bool = False,
        pre_spike_v: bool = False,
        step_mode: Literal["s"] = "s",
        backend: Literal["torch"] = "torch",
        device=None,
        dtype=None,
    ):
        super().__init__(
            n_neuron=n_neuron,
            v_threshold=v_threshold,
            v_reset=v_reset,
            trainable_param=trainable_param,
            surrogate_function=surrogate_function,
            detach_reset=detach_reset,
            hard_reset=hard_reset,
            pre_spike_v=pre_spike_v,
            step_mode=step_mode,
            backend=backend,
            device=device,
            dtype=dtype,
        )
        _factory_kwargs: dict[str, Any] = {"device": device, "dtype": dtype}
        self.def_param(
            "c_m",
            c_m,
            trainable_param=self.trainable_param,
            **_factory_kwargs,
        )
        self.def_param(
            "g_leak",
            g_leak,
            trainable_param=self.trainable_param,
            **_factory_kwargs,
        )
        self.def_param(
            "E_leak",
            E_leak,
            trainable_param=self.trainable_param,
            **_factory_kwargs,
        )
        self.def_param(
            "E_k",
            E_k,
            trainable_param=self.trainable_param,
            **_factory_kwargs,
        )
        self.def_param(
            "tau_adapt",
            tau_adapt,
            trainable_param=self.trainable_param,
            **_factory_kwargs,
        )
        self.def_param(
            "dg_k",
            dg_k,
            trainable_param=self.trainable_param,
            **_factory_kwargs,
        )
        self._use_refractory = tau_ref is not None
        if self._use_refractory:
            self.def_param(
                "tau_ref",
                tau_ref,
                trainable_param=self.trainable_param,
                **_factory_kwargs,
            )
            self.register_memory("refractory", 0.0, self.n_neuron)
        else:
            self.tau_ref = None

        self.register_memory("g_k", g_k_init, self.n_neuron)

    @property
    def v_rest(self):
        if self._v_rest is None:
            return self.v_reset
        return self._v_rest

    @v_rest.setter
    def v_rest(self, v_rest):
        if self._v_rest is not None:
            self._v_rest = v_rest

    @property
    def v_peak(self):
        return self.v_threshold

    @v_peak.setter
    def v_peak(self, value):
        self.v_threshold = value

    def dV(
        self,
        v: Float[Tensor, "*batch n_neuron"],
        g_k: Float[Tensor, "*batch n_neuron"],
        x: Float[Tensor, "*batch n_neuron"],
    ):
        leak_term = -self.g_leak * (v - self.E_leak)
        adapt_term = -g_k * (v - self.E_k)
        derivative = (leak_term + adapt_term + x) / self.c_m
        linear = (-self.g_leak - g_k) / self.c_m
        return derivative, linear

    def dgk(
        self, g_k: Float[Tensor, "*batch n_neuron"]
    ) -> tuple[Float[Tensor, "*batch n_neuron"], Float[Tensor, "*batch n_neuron"]]:
        derivative = -g_k / self.tau_adapt
        linear = -1.0 / self.tau_adapt
        return derivative, linear

    def neuronal_charge(self, x: Float[Tensor, "*batch n_neuron"]):
        dt = environ.get("dt")
        self.v = exp_euler_step(self.dV, self.v, self.g_k, x, dt=dt)

    def neuronal_adaptation(self):
        dt = environ.get("dt")
        self.g_k = exp_euler_step(self.dgk, self.g_k, dt=dt)

    def neuronal_fire(self):
        spike = self.surrogate_function(
            (self.v - self.v_threshold) / (self.v_threshold - self.v_reset)
        )
        if not self._use_refractory:
            return spike
        not_in_refractory = self.refractory == 0
        spike = spike * not_in_refractory.detach().to(self.v.dtype)
        return spike

    def neuronal_reset(self, spike: Float[Tensor, "*batch n"]):
        if self.detach_reset:
            spike_d = spike.detach()
        else:
            spike_d = spike

        if self.pre_spike_v:
            self.v_pre_spike = self.v.clone()

        if self.hard_reset:
            self.v = self.v - (self.v - self.v_reset) * spike_d
        else:
            self.v = self.v - (self.v_threshold - self.v_reset) * spike_d

        self.g_k = self.g_k + self.dg_k * spike_d

        if self._use_refractory:
            self.refractory = torch.relu(
                self.refractory + spike_d * self.tau_ref - environ.get("dt")
            )

    def extra_repr(self):
        g_k_init = self._memories_rv["g_k"].value
        parts = [
            f"c_m={self._format_repr_value(self.c_m)}",
            f"g_leak={self._format_repr_value(self.g_leak)}",
            f"E_leak={self._format_repr_value(self.E_leak)}",
            f"E_k={self._format_repr_value(self.E_k)}",
            f"tau_adapt={self._format_repr_value(self.tau_adapt)}",
            f"dg_k={self._format_repr_value(self.dg_k)}",
            f"g_k_init={self._format_repr_value(g_k_init)}",
            f"tau_ref={self._format_repr_value(self.tau_ref)}"
            if self._use_refractory
            else "tau_ref=None",
        ]
        base = super().extra_repr()
        if base:
            parts.append(base)
        return ", ".join(parts)

BaseNode

Bases: ParamBufferMixin, MemoryModule

Base class for differentiable spiking neurons.

Implements the spiking neuron lifecycle: charge -> adapt -> fire -> reset. Subclasses implement neuronal_charge() and neuronal_adaptation().
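The lifecycle can be illustrated with a toy node; the class and method names below are hypothetical, not the btorch API:

```python
# Toy illustration of the charge -> adapt -> fire -> reset lifecycle
# (hypothetical names, not the btorch API).
class ToyNode:
    def __init__(self, v_threshold=1.0, v_reset=0.0, tau=2.0):
        self.v = v_reset
        self.v_threshold = v_threshold
        self.v_reset = v_reset
        self.tau = tau

    def charge(self, x, dt=1.0):
        # leaky integration of the input
        self.v = self.v + dt * (-self.v / self.tau + x)

    def fire(self):
        return 1.0 if self.v >= self.v_threshold else 0.0

    def reset(self, spike):
        if spike:
            self.v = self.v_reset  # hard reset

    def step(self, x):
        self.charge(x)        # 1. charge
        # 2. adapt (this toy node has no adaptation state)
        spike = self.fire()   # 3. fire
        self.reset(spike)     # 4. reset
        return spike
```

Note that fire() reads the membrane potential after charging but before the reset, which is why the order of the four phases matters.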

Parameters:

Name Type Description Default
n_neuron int | Sequence[int]

Number of neurons (int or tuple).

required
v_threshold float | Float[TensorLike, ' n_neuron']

Firing threshold. Default: 1.0.

1.0
v_reset float | Float[TensorLike, ' n_neuron']

Reset voltage. Default: 0.0.

0.0
trainable_param set[str]

Trainable parameter names. Default: set().

set()
surrogate_function Callable

Surrogate for backprop. Default: Sigmoid().

Sigmoid()
detach_reset bool

Detach reset signal. Default: False.

False
hard_reset bool

Hard vs soft reset. Default: False.

False
pre_spike_v bool

Store pre-spike voltage. Default: False.

False
step_mode

"s" or "m". Default: "s".

's'
backend

Compute backend. Default: "torch".

'torch'
device

Tensor device. Default: None.

None
dtype

Tensor dtype. Default: None.

None
Source code in btorch/models/base.py
class BaseNode(ParamBufferMixin, MemoryModule):
    """Base class for differentiable spiking neurons.

    Implements the spiking neuron lifecycle: charge -> adapt -> fire -> reset.
    Subclasses implement neuronal_charge() and neuronal_adaptation().

    Args:
        n_neuron: Number of neurons (int or tuple).
        v_threshold: Firing threshold. Default: 1.0.
        v_reset: Reset voltage. Default: 0.0.
        trainable_param: Trainable parameter names. Default: set().
        surrogate_function: Surrogate for backprop. Default: Sigmoid().
        detach_reset: Detach reset signal. Default: False.
        hard_reset: Hard vs soft reset. Default: False.
        pre_spike_v: Store pre-spike voltage. Default: False.
        step_mode: "s" or "m". Default: "s".
        backend: Compute backend. Default: "torch".
        device: Tensor device. Default: None.
        dtype: Tensor dtype. Default: None.
    """

    n_neuron: tuple[int, ...]
    size: int
    v: torch.Tensor
    v_pre_spike: torch.Tensor
    v_threshold: torch.Tensor | torch.nn.Parameter
    v_reset: torch.Tensor | torch.nn.Parameter

    def __init__(
        self,
        n_neuron: int | Sequence[int],
        v_threshold: float | Float[TensorLike, " n_neuron"] = 1.0,
        v_reset: float | Float[TensorLike, " n_neuron"] = 0.0,
        trainable_param: set[str] = set(),
        surrogate_function: Callable = Sigmoid(),
        detach_reset: bool = False,
        hard_reset: bool = False,
        pre_spike_v: bool = False,
        step_mode="s",
        backend="torch",
        device=None,
        dtype=None,
    ):
        """Modified spikingjelly BaseNode.

        * :ref:`API in English <BaseNode.__init__-en>`

        This class is the base class of differentiable spiking neurons.
        """

        # override neuron.BaseNode's __init__ method to remove unnecessary checks
        # call neuron.BaseNode's parent MemoryModule directly
        super().__init__()

        self.n_neuron, self.size = normalize_n_neuron(n_neuron)
        self.register_memory("v", v_reset, self.n_neuron)
        self.pre_spike_v = pre_spike_v

        _factory_kwargs: dict[str, Any] = {"device": device, "dtype": dtype}
        if pre_spike_v:
            self.register_memory(
                "v_pre_spike", v_reset, self.n_neuron, persistent=False
            )

        self.trainable_param = set(trainable_param)
        self.def_param(
            "v_threshold",
            v_threshold,
            sizes=self.n_neuron,
            trainable_param=self.trainable_param,
            **_factory_kwargs,
        )
        self.def_param(
            "v_reset",
            v_reset,
            sizes=self.n_neuron,
            trainable_param=self.trainable_param,
            **_factory_kwargs,
        )

        self.detach_reset = detach_reset
        self.surrogate_function = surrogate_function
        self.hard_reset = hard_reset

        self.step_mode = step_mode
        self.backend = backend

    def extra_repr(self):
        parts = [
            f"n_neuron={self.n_neuron}",
            f"v_threshold={self._format_repr_value(self.v_threshold)}",
            f"v_reset={self._format_repr_value(self.v_reset)}",
            f"step_mode={self.step_mode}",
            f"backend={self.backend}",
            f"surrogate={self.surrogate_function.__class__.__name__}",
        ]
        if self.detach_reset:
            parts.append("detach_reset=True")
        if self.hard_reset:
            parts.append("hard_reset=True")
        if self.pre_spike_v:
            parts.append("pre_spike_v=True")
        mem_repr = super().extra_repr()
        if mem_repr:
            parts.append(mem_repr)
        return ", ".join(parts)

    @abstractmethod
    def neuronal_charge(self, x: torch.Tensor):
        """
         * :ref:`API in English <BaseNode.neuronal_charge-en>`

        .. _BaseNode.neuronal_charge-cn:

        定义神经元的充电差分方程。子类必须实现这个函数。

        * :ref:`中文API <BaseNode.neuronal_charge-cn>`

        .. _BaseNode.neuronal_charge-en:


        Define the charge difference equation.
        The sub-class must implement this function.
        """
        raise NotImplementedError

    def neuronal_fire(self):
        """
        * :ref:`API in English <BaseNode.neuronal_fire-en>`

        .. _BaseNode.neuronal_fire-cn:

        根据当前神经元的电压、阈值,计算输出脉冲。

        * :ref:`中文API <BaseNode.neuronal_fire-cn>`

        .. _BaseNode.neuronal_fire-en:


        Calculate out spikes of neurons by their current membrane potential
        and threshold voltage.
        """

        return self.surrogate_function(self.v - self.v_threshold)

    def neuronal_reset(self, spike):
        """
        * :ref:`API in English <BaseNode.neuronal_reset-en>`

        .. _BaseNode.neuronal_reset-cn:

        根据当前神经元释放的脉冲,对膜电位进行重置。

        * :ref:`中文API <BaseNode.neuronal_reset-cn>`

        .. _BaseNode.neuronal_reset-en:


        Reset the membrane potential according to neurons' output spikes.
        """
        if self.detach_reset:
            spike_d = spike.detach()
        else:
            spike_d = spike

        if self.pre_spike_v:
            self.v_pre_spike = self.v.clone()

        if self.hard_reset:
            # hard reset
            self.v = self.v - (self.v - self.v_reset) * spike_d
        else:
            # soft reset
            self.v = self.v - (self.v_threshold - self.v_reset) * spike_d

    def neuronal_adaptation(self):
        raise NotImplementedError()

    def single_step_forward(self, x: Float[Tensor, "*batch n_neuron"]):
        """
        * :ref:`API in English <BaseNode.single_step_forward-en>`
        """
        self.neuronal_charge(x)
        self.neuronal_adaptation()
        spike = self.neuronal_fire()
        self.neuronal_reset(spike)
        return spike

    def multi_step_forward(self, x_seq: Float[Tensor, "T *batch n_neuron"]):
        s_seq = []
        for t, x in enumerate(x_seq):
            s = self.single_step_forward(x)
            s_seq.append(s)

        return torch.stack(s_seq)
Functions
__init__(n_neuron, v_threshold=1.0, v_reset=0.0, trainable_param=set(), surrogate_function=Sigmoid(), detach_reset=False, hard_reset=False, pre_spike_v=False, step_mode='s', backend='torch', device=None, dtype=None)

Modified spikingjelly BaseNode.

This class is the base class of differentiable spiking neurons.

Source code in btorch/models/base.py
def __init__(
    self,
    n_neuron: int | Sequence[int],
    v_threshold: float | Float[TensorLike, " n_neuron"] = 1.0,
    v_reset: float | Float[TensorLike, " n_neuron"] = 0.0,
    trainable_param: set[str] = set(),
    surrogate_function: Callable = Sigmoid(),
    detach_reset: bool = False,
    hard_reset: bool = False,
    pre_spike_v: bool = False,
    step_mode="s",
    backend="torch",
    device=None,
    dtype=None,
):
    """Modified spikingjelly BaseNode.

    * :ref:`API in English <BaseNode.__init__-en>`

    This class is the base class of differentiable spiking neurons.
    """

    # override neuron.BaseNode's __init__ method to remove unnecessary checks
    # call neuron.BaseNode's parent MemoryModule directly
    super().__init__()

    self.n_neuron, self.size = normalize_n_neuron(n_neuron)
    self.register_memory("v", v_reset, self.n_neuron)
    self.pre_spike_v = pre_spike_v

    _factory_kwargs: dict[str, Any] = {"device": device, "dtype": dtype}
    if pre_spike_v:
        self.register_memory(
            "v_pre_spike", v_reset, self.n_neuron, persistent=False
        )

    self.trainable_param = set(trainable_param)
    self.def_param(
        "v_threshold",
        v_threshold,
        sizes=self.n_neuron,
        trainable_param=self.trainable_param,
        **_factory_kwargs,
    )
    self.def_param(
        "v_reset",
        v_reset,
        sizes=self.n_neuron,
        trainable_param=self.trainable_param,
        **_factory_kwargs,
    )

    self.detach_reset = detach_reset
    self.surrogate_function = surrogate_function
    self.hard_reset = hard_reset

    self.step_mode = step_mode
    self.backend = backend
neuronal_charge(x) abstractmethod
Define the charge difference equation. The sub-class must implement this function.

Source code in btorch/models/base.py
@abstractmethod
def neuronal_charge(self, x: torch.Tensor):
    """
     * :ref:`API in English <BaseNode.neuronal_charge-en>`

    .. _BaseNode.neuronal_charge-cn:

    定义神经元的充电差分方程。子类必须实现这个函数。

    * :ref:`中文API <BaseNode.neuronal_charge-cn>`

    .. _BaseNode.neuronal_charge-en:


    Define the charge difference equation.
    The sub-class must implement this function.
    """
    raise NotImplementedError
neuronal_fire()
Calculate output spikes of neurons from their current membrane potential and threshold voltage.

Source code in btorch/models/base.py
def neuronal_fire(self):
    """
    * :ref:`API in English <BaseNode.neuronal_fire-en>`

    .. _BaseNode.neuronal_fire-cn:

    根据当前神经元的电压、阈值,计算输出脉冲。

    * :ref:`中文API <BaseNode.neuronal_fire-cn>`

    .. _BaseNode.neuronal_fire-en:


    Calculate out spikes of neurons by their current membrane potential
    and threshold voltage.
    """

    return self.surrogate_function(self.v - self.v_threshold)
neuronal_reset(spike)
Reset the membrane potential according to neurons' output spikes.

Source code in btorch/models/base.py
def neuronal_reset(self, spike):
    """
    * :ref:`API in English <BaseNode.neuronal_reset-en>`

    .. _BaseNode.neuronal_reset-cn:

    根据当前神经元释放的脉冲,对膜电位进行重置。

    * :ref:`中文API <BaseNode.neuronal_reset-cn>`

    .. _BaseNode.neuronal_reset-en:


    Reset the membrane potential according to neurons' output spikes.
    """
    if self.detach_reset:
        spike_d = spike.detach()
    else:
        spike_d = spike

    if self.pre_spike_v:
        self.v_pre_spike = self.v.clone()

    if self.hard_reset:
        # hard reset
        self.v = self.v - (self.v - self.v_reset) * spike_d
    else:
        # soft reset
        self.v = self.v - (self.v_threshold - self.v_reset) * spike_d
single_step_forward(x)
Source code in btorch/models/base.py
def single_step_forward(self, x: Float[Tensor, "*batch n_neuron"]):
    """
    * :ref:`API in English <BaseNode.single_step_forward-en>`
    """
    self.neuronal_charge(x)
    self.neuronal_adaptation()
    spike = self.neuronal_fire()
    self.neuronal_reset(spike)
    return spike

ELIF

Bases: ALIF

Exponential integrate-and-fire neuron with adaptation.

The ELIF model extends ALIF by adding an exponential term to the voltage dynamics, creating a sharp upswing when approaching threshold (the "initiation zone"). This captures the rapid depolarization seen in real neurons.

Dynamics

dv/dt = (g_leak * delta_T * exp((v - v_T) / delta_T) - g_leak * (v - E_leak) - g_k * (v - E_k) + x) / c_m
dg_k/dt = -g_k / tau_adapt

The exponential term creates a soft threshold effect where membrane potential accelerates as it approaches v_T.
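A minimal sketch of that soft-threshold effect, in plain Python with illustrative parameters (not the library code): evaluating the derivative at increasing v shows the exponential term taking over near and above v_T.

```python
import math

# Hedged sketch of the ELIF derivative from the dynamics above
# (plain Python, not the library code; parameters are illustrative).
def elif_dv(v, g_k=0.0, x=0.0, c_m=1.0, g_leak=1.0,
            E_leak=0.0, E_k=-70.0, delta_T=1.0, v_T=0.0):
    exp_term = g_leak * delta_T * math.exp((v - v_T) / delta_T)
    return (exp_term - g_leak * (v - E_leak) - g_k * (v - E_k) + x) / c_m

# Above v_T the exponential term dominates and dv/dt grows sharply:
# this is the soft-threshold "initiation zone" effect.
slopes = [elif_dv(v) for v in (0.0, 1.0, 2.0, 3.0)]
```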

Parameters:

Name Type Description Default
n_neuron int | Sequence[int]

Number of neurons (int or tuple of dimensions).

required
v_threshold float | Float[TensorLike, ' n_neuron']

Firing threshold (mV). Default: 1.0.

1.0
v_reset float | Float[TensorLike, ' n_neuron']

Reset voltage after spike (mV). Default: 0.0.

0.0
c_m float | Float[TensorLike, ' n_neuron']

Membrane capacitance (pF). Default: 1.0.

1.0
g_leak float | Float[TensorLike, ' n_neuron']

Leak conductance (nS). Default: 1.0.

1.0
E_leak float | Float[TensorLike, ' n_neuron']

Leak reversal potential (mV). Default: 0.0.

0.0
E_k float | Float[TensorLike, ' n_neuron']

Potassium reversal potential (mV). Default: -70.0.

-70.0
g_k_init float | Float[TensorLike, ' n_neuron']

Initial adaptation conductance (nS). Default: 0.0.

0.0
tau_adapt float | Float[TensorLike, ' n_neuron']

Adaptation time constant (ms). Default: 20.0.

20.0
dg_k float | Float[TensorLike, ' n_neuron']

Adaptation increment per spike (nS). Default: 0.0.

0.0
tau_ref float | Float[TensorLike, ' n_neuron'] | None

Refractory period (ms). Default: 0.0.

0.0
delta_T float | Float[TensorLike, ' n_neuron']

Slope factor for exponential term (mV). Default: 1.0.

1.0
v_T float | Float[TensorLike, ' n_neuron']

Soft threshold potential (mV). Default: 0.0.

0.0
trainable_param set[str]

Set of parameter names to make trainable.

set()
surrogate_function Callable

Surrogate gradient function. Default: Sigmoid().

Sigmoid()
detach_reset bool

If True, detach reset signal. Default: False.

False
hard_reset bool

If True, use hard reset. Default: False.

False
pre_spike_v bool

If True, store pre-spike voltage. Default: False.

False
step_mode Literal['s']

Step mode. Default: "s".

's'
backend Literal['torch']

Backend implementation. Default: "torch".

'torch'
device

Device for tensors. Default: None.

None
dtype

Data type for tensors. Default: None.

None

Attributes:

Name Type Description
delta_T Tensor | Parameter

Slope factor for exponential term.

v_T Tensor | Parameter

Soft threshold potential.

Source code in btorch/models/neurons/alif.py
class ELIF(ALIF):
    """Exponential integrate-and-fire neuron with adaptation.

    The ELIF model extends ALIF by adding an exponential term to the
    voltage dynamics, creating a sharp upswing when approaching threshold
    (the "initiation zone"). This captures the rapid depolarization seen
    in real neurons.

    Dynamics:
        dv/dt = (g_leak * delta_T * exp((v - v_T) / delta_T)
                 - g_leak * (v - E_leak) - g_k * (v - E_k) + x) / c_m
        dg_k/dt = -g_k / tau_adapt

    The exponential term creates a soft threshold effect where membrane
    potential accelerates as it approaches v_T.

    Args:
        n_neuron: Number of neurons (int or tuple of dimensions).
        v_threshold: Firing threshold (mV). Default: 1.0.
        v_reset: Reset voltage after spike (mV). Default: 0.0.
        c_m: Membrane capacitance (pF). Default: 1.0.
        g_leak: Leak conductance (nS). Default: 1.0.
        E_leak: Leak reversal potential (mV). Default: 0.0.
        E_k: Potassium reversal potential (mV). Default: -70.0.
        g_k_init: Initial adaptation conductance (nS). Default: 0.0.
        tau_adapt: Adaptation time constant (ms). Default: 20.0.
        dg_k: Adaptation increment per spike (nS). Default: 0.0.
        tau_ref: Refractory period (ms). Default: 0.0.
        delta_T: Slope factor for exponential term (mV). Default: 1.0.
        v_T: Soft threshold potential (mV). Default: 0.0.
        trainable_param: Set of parameter names to make trainable.
        surrogate_function: Surrogate gradient function. Default: Sigmoid().
        detach_reset: If True, detach reset signal. Default: False.
        hard_reset: If True, use hard reset. Default: False.
        pre_spike_v: If True, store pre-spike voltage. Default: False.
        step_mode: Step mode. Default: "s".
        backend: Backend implementation. Default: "torch".
        device: Device for tensors. Default: None.
        dtype: Data type for tensors. Default: None.

    Attributes:
        delta_T: Slope factor for exponential term.
        v_T: Soft threshold potential.
    """

    delta_T: torch.Tensor | torch.nn.Parameter
    v_T: torch.Tensor | torch.nn.Parameter

    def __init__(
        self,
        n_neuron: int | Sequence[int],
        v_threshold: float | Float[TensorLike, " n_neuron"] = 1.0,
        v_reset: float | Float[TensorLike, " n_neuron"] = 0.0,
        c_m: float | Float[TensorLike, " n_neuron"] = 1.0,
        g_leak: float | Float[TensorLike, " n_neuron"] = 1.0,
        E_leak: float | Float[TensorLike, " n_neuron"] = 0.0,
        E_k: float | Float[TensorLike, " n_neuron"] = -70.0,
        g_k_init: float | Float[TensorLike, " n_neuron"] = 0.0,
        tau_adapt: float | Float[TensorLike, " n_neuron"] = 20.0,
        dg_k: float | Float[TensorLike, " n_neuron"] = 0.0,
        tau_ref: float | Float[TensorLike, " n_neuron"] | None = 0.0,
        delta_T: float | Float[TensorLike, " n_neuron"] = 1.0,
        v_T: float | Float[TensorLike, " n_neuron"] = 0.0,
        trainable_param: set[str] = set(),
        surrogate_function: Callable = Sigmoid(),
        detach_reset: bool = False,
        hard_reset: bool = False,
        pre_spike_v: bool = False,
        step_mode: Literal["s"] = "s",
        backend: Literal["torch"] = "torch",
        device=None,
        dtype=None,
    ):
        super().__init__(
            n_neuron=n_neuron,
            v_threshold=v_threshold,
            v_reset=v_reset,
            c_m=c_m,
            g_leak=g_leak,
            E_leak=E_leak,
            E_k=E_k,
            g_k_init=g_k_init,
            tau_adapt=tau_adapt,
            dg_k=dg_k,
            tau_ref=tau_ref,
            trainable_param=trainable_param,
            surrogate_function=surrogate_function,
            detach_reset=detach_reset,
            hard_reset=hard_reset,
            pre_spike_v=pre_spike_v,
            step_mode=step_mode,
            backend=backend,
            device=device,
            dtype=dtype,
        )
        _factory_kwargs: dict[str, Any] = {"device": device, "dtype": dtype}
        self.def_param(
            "delta_T",
            delta_T,
            sizes=self.n_neuron,
            trainable_param=self.trainable_param,
            **_factory_kwargs,
        )
        self.def_param(
            "v_T",
            v_T,
            sizes=self.n_neuron,
            trainable_param=self.trainable_param,
            **_factory_kwargs,
        )

    def dV(
        self,
        v: Float[Tensor, "*batch n_neuron"],
        g_k: Float[Tensor, "*batch n_neuron"],
        x: Float[Tensor, "*batch n_neuron"],
    ):
        leak_term = -self.g_leak * (v - self.E_leak)
        adapt_term = -g_k * (v - self.E_k)
        exp_term = self.g_leak * self.delta_T * torch.exp((v - self.v_T) / self.delta_T)
        derivative = (leak_term + adapt_term + exp_term + x) / self.c_m
        linear = (-self.g_leak - g_k + exp_term / self.delta_T) / self.c_m
        return derivative, linear

    def neuronal_charge(self, x: Float[Tensor, "*batch n_neuron"]):
        dt = environ.get("dt")
        self.v = exp_euler_step(self.dV, self.v, self.g_k, x, dt=dt)

    def extra_repr(self):
        parts = [
            f"delta_T={self._format_repr_value(self.delta_T)}",
            f"v_T={self._format_repr_value(self.v_T)}",
        ]
        base = super().extra_repr()
        if base:
            parts.append(base)
        return ", ".join(parts)

Sigmoid

Bases: SurrogateFunctionBase

Logistic surrogate derivative.

Source code in btorch/models/surrogate/sigmoid.py
class Sigmoid(SurrogateFunctionBase):
    """Logistic surrogate derivative."""

    def primitive(self, x: torch.Tensor) -> torch.Tensor:
        return _sigmoid_primitive(x, self.alpha)

    def derivative(
        self,
        x: torch.Tensor,
        grad_output: torch.Tensor,
        damping_factor: float = 1.0,
    ) -> torch.Tensor:
        return _sigmoid_derivative(x, grad_output, self.alpha, damping_factor)
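The surrogate-gradient idea behind this class can be sketched in plain Python (illustrative alpha; not the btorch autograd implementation): the forward pass is a hard step, while the backward pass substitutes the derivative of sigmoid(alpha * x) for the step's derivative, which is zero almost everywhere.

```python
import math

# Plain-Python sketch of the surrogate-gradient idea
# (illustrative alpha; not the btorch autograd implementation).
def sigmoid(x, alpha=4.0):
    return 1.0 / (1.0 + math.exp(-alpha * x))

def spike_forward(x):
    # Forward pass: hard step (spike / no spike).
    return 1.0 if x >= 0.0 else 0.0

def spike_backward(x, alpha=4.0):
    # Backward pass: derivative of sigmoid(alpha * x) stands in
    # for the non-differentiable step.
    s = sigmoid(x, alpha)
    return alpha * s * (1.0 - s)
```

The substitute gradient is largest near x = 0 (membrane potential at threshold) and fades away from it, so learning signals concentrate on neurons close to firing.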

Functions

exp_euler_step(f, *args, dt=1.0, linear=None)

One integration step applying the exponential Euler method.

.. math:: \frac{dx}{dt} = f(x) = Ax + B

where :math:`A` is the linear term and :math:`f(x)` is the derivative.

The update rule is:

.. math:: x_{n+1} = x_n + \frac{e^{dt A} - 1}{A} f(x_n) = e^{dt A} x_n + \frac{e^{dt A} - 1}{A} B

Source code in btorch/models/ode.py
def exp_euler_step(f: Callable, *args, dt=1.0, linear: Tensor | None = None):
    r"""One integration step applying the exponential Euler method.

    .. math::
        \frac{dx}{dt} = f(x) = Ax + B

    where :math:`A` is the linear term and :math:`f(x)` is the derivative.

    The update rule is:

    .. math::
        x_{n+1} = x_n + \frac{e^{dt A} - 1}{A} f(x_n) \\
        &= e^{dt A}x_n + \frac{e^{dt A} - 1}{A} B
    """
    out = f(*args)
    derivative, linear_from_f = _split_derivative_linear(out)
    if linear is None:
        linear = linear_from_f
    if linear is None:
        if torch.compiler.is_compiling():
            raise RuntimeError(
                "torch.compile cannot use vjp fallback here; return "
                "(derivative, linear) from f(*args) or pass linear=."
            )
        if len(args) > 1:
            _f = lambda x: _derivative_only(f(x, *args[1:]))
        else:
            _f = lambda x: _derivative_only(f(x))
        derivative, linear_f = vjp(_f, args[0])
        linear = linear_f(torch.ones_like(derivative))[0]
    return args[0] + torch.expm1(dt * linear) / linear * derivative
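A quick way to see why this update is used: for a purely linear ODE dx/dt = A*x + B, a single exponential Euler step reproduces the exact solution for any dt, which is why the scheme handles stiff leak terms well. The sketch below is standalone plain Python, not the library function.

```python
import math

# Standalone sketch of the exponential Euler update
# (not the library function above).
def exp_euler(x, derivative, linear, dt):
    return x + math.expm1(dt * linear) / linear * derivative

A, B = -0.5, 2.0           # dx/dt = A*x + B
x0, dt = 0.0, 1.0
x1 = exp_euler(x0, A * x0 + B, A, dt)

# Closed-form solution: x(t) = (x0 + B/A) * exp(A*t) - B/A
exact = (x0 + B / A) * math.exp(A * dt) - B / A
```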

btorch.models.neurons.glif

Generalized leaky integrate-and-fire (GLIF) neuron models.

This module implements the GLIF3 model from the Allen Institute [1], which extends standard LIF with after-spike currents (ASC) that capture spike-frequency adaptation and other slow currents.

The GLIF3 neuron follows

dV/dt = -(V - V_rest) / tau + (I_in + sum(I_asc)) / c_m
dI_asc/dt = -k * I_asc

where I_asc are after-spike currents that increment by asc_amps at each spike.
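A hedged sketch of one after-spike-current update in plain Python; k and asc_amp below are illustrative values, not fitted Allen Institute parameters.

```python
import math

# Hedged sketch of one after-spike-current (ASC) update
# (illustrative k and asc_amp, not fitted GLIF3 parameters).
def asc_step(i_asc, spiked, k=0.3, asc_amp=1.0, dt=1.0):
    # Exact decay of dI_asc/dt = -k * I_asc over one step.
    i_asc = i_asc * math.exp(-k * dt)
    if spiked:
        i_asc = i_asc + asc_amp  # increment by asc_amps at a spike
    return i_asc
```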

References

[1] Teeter et al., "Generalized leaky integrate-and-fire models classify multiple neuron types," Nat. Commun., 2018.

Attributes

TensorLike = np.ndarray | torch.Tensor module-attribute

Classes

ATan

Bases: SurrogateFunctionBase

Arctan surrogate matching SpikingJelly's alpha scaling.

Source code in btorch/models/surrogate/atan.py
class ATan(SurrogateFunctionBase):
    """Arctan surrogate matching SpikingJelly's alpha scaling."""

    def __init__(
        self, alpha: float = 2.0, damping_factor: float = 1.0, spiking: bool = True
    ):
        super().__init__(alpha=alpha, damping_factor=damping_factor, spiking=spiking)

    def primitive(self, x: torch.Tensor) -> torch.Tensor:
        return _atan_primitive(x, self.alpha)

    def derivative(
        self,
        x: torch.Tensor,
        grad_output: torch.Tensor,
        damping_factor: float = 1.0,
    ) -> torch.Tensor:
        return _atan_derivative(x, grad_output, self.alpha, damping_factor)
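For intuition, a scalar sketch of what an arctan surrogate computes, assuming SpikingJelly-style scaling where the backward pass is the derivative of (1/pi) * arctan(pi/2 * alpha * x) + 1/2 (hypothetical helper names):

```python
import math

def atan_forward(x):
    # forward pass: Heaviside step on the membrane-minus-threshold input
    return 1.0 if x >= 0.0 else 0.0

def atan_grad(x, alpha=2.0):
    # backward pass: derivative of (1/pi)*arctan(pi/2 * alpha * x) + 1/2,
    # a smooth bump peaking at alpha/2 when x = 0
    return (alpha / 2.0) / (1.0 + (math.pi / 2.0 * alpha * x) ** 2)
```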

BaseNode

Bases: ParamBufferMixin, MemoryModule

Base class for differentiable spiking neurons.

Implements the spiking neuron lifecycle: charge -> adapt -> fire -> reset. Subclasses implement neuronal_charge() and neuronal_adaptation().
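The lifecycle can be sketched with a toy, torch-free LIF whose step mirrors single_step_forward (hypothetical class; a real subclass would override BaseNode's hooks instead):

```python
class ToyLIF:
    def __init__(self, v_threshold=1.0, v_reset=0.0, tau=2.0):
        self.v_threshold, self.v_reset, self.tau = v_threshold, v_reset, tau
        self.v = v_reset

    def neuronal_charge(self, x):      # leaky integration toward the input
        self.v += (-self.v + x) / self.tau

    def neuronal_adaptation(self):     # plain LIF has no adaptation state
        pass

    def neuronal_fire(self):           # Heaviside in place of a surrogate
        return 1.0 if self.v >= self.v_threshold else 0.0

    def neuronal_reset(self, spike):   # hard reset on spike
        if spike:
            self.v = self.v_reset

    def step(self, x):                 # charge -> adapt -> fire -> reset
        self.neuronal_charge(x)
        self.neuronal_adaptation()
        spike = self.neuronal_fire()
        self.neuronal_reset(spike)
        return spike
```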

Parameters:

Name Type Description Default
n_neuron int | Sequence[int]

Number of neurons (int or tuple).

required
v_threshold float | Float[TensorLike, ' n_neuron']

Firing threshold. Default: 1.0.

1.0
v_reset float | Float[TensorLike, ' n_neuron']

Reset voltage. Default: 0.0.

0.0
trainable_param set[str]

Trainable parameter names. Default: ().

set()
surrogate_function Callable

Surrogate for backprop. Default: Sigmoid().

Sigmoid()
detach_reset bool

Detach reset signal. Default: False.

False
hard_reset bool

Hard vs soft reset. Default: False.

False
pre_spike_v bool

Store pre-spike voltage. Default: False.

False
step_mode

"s" or "m". Default: "s".

's'
backend

Compute backend. Default: "torch".

'torch'
device

Tensor device. Default: None.

None
dtype

Tensor dtype. Default: None.

None
Source code in btorch/models/base.py
class BaseNode(ParamBufferMixin, MemoryModule):
    """Base class for differentiable spiking neurons.

    Implements the spiking neuron lifecycle: charge -> adapt -> fire -> reset.
    Subclasses implement neuronal_charge() and neuronal_adaptation().

    Args:
        n_neuron: Number of neurons (int or tuple).
        v_threshold: Firing threshold. Default: 1.0.
        v_reset: Reset voltage. Default: 0.0.
        trainable_param: Trainable parameter names. Default: ().
        surrogate_function: Surrogate for backprop. Default: Sigmoid().
        detach_reset: Detach reset signal. Default: False.
        hard_reset: Hard vs soft reset. Default: False.
        pre_spike_v: Store pre-spike voltage. Default: False.
        step_mode: "s" or "m". Default: "s".
        backend: Compute backend. Default: "torch".
        device: Tensor device. Default: None.
        dtype: Tensor dtype. Default: None.
    """

    n_neuron: tuple[int, ...]
    size: int
    v: torch.Tensor
    v_pre_spike: torch.Tensor
    v_threshold: torch.Tensor | torch.nn.Parameter
    v_reset: torch.Tensor | torch.nn.Parameter

    def __init__(
        self,
        n_neuron: int | Sequence[int],
        v_threshold: float | Float[TensorLike, " n_neuron"] = 1.0,
        v_reset: float | Float[TensorLike, " n_neuron"] = 0.0,
        trainable_param: set[str] = set(),
        surrogate_function: Callable = Sigmoid(),
        detach_reset: bool = False,
        hard_reset: bool = False,
        pre_spike_v: bool = False,
        step_mode="s",
        backend="torch",
        device=None,
        dtype=None,
    ):
        """Modified spikingjelly BaseNode.

        * :ref:`API in English <BaseNode.__init__-en>`

        This class is the base class of differentiable spiking neurons.
        """

        # override neuron.BaseNode's __init__ method to remove unnecessary checks
        # call neuron.BaseNode's parent MemoryModule directly
        super().__init__()

        self.n_neuron, self.size = normalize_n_neuron(n_neuron)
        self.register_memory("v", v_reset, self.n_neuron)
        self.pre_spike_v = pre_spike_v

        _factory_kwargs: dict[str, Any] = {"device": device, "dtype": dtype}
        if pre_spike_v:
            self.register_memory(
                "v_pre_spike", v_reset, self.n_neuron, persistent=False
            )

        self.trainable_param = set(trainable_param)
        self.def_param(
            "v_threshold",
            v_threshold,
            sizes=self.n_neuron,
            trainable_param=self.trainable_param,
            **_factory_kwargs,
        )
        self.def_param(
            "v_reset",
            v_reset,
            sizes=self.n_neuron,
            trainable_param=self.trainable_param,
            **_factory_kwargs,
        )

        self.detach_reset = detach_reset
        self.surrogate_function = surrogate_function
        self.hard_reset = hard_reset

        self.step_mode = step_mode
        self.backend = backend

    def extra_repr(self):
        parts = [
            f"n_neuron={self.n_neuron}",
            f"v_threshold={self._format_repr_value(self.v_threshold)}",
            f"v_reset={self._format_repr_value(self.v_reset)}",
            f"step_mode={self.step_mode}",
            f"backend={self.backend}",
            f"surrogate={self.surrogate_function.__class__.__name__}",
        ]
        if self.detach_reset:
            parts.append("detach_reset=True")
        if self.hard_reset:
            parts.append("hard_reset=True")
        if self.pre_spike_v:
            parts.append("pre_spike_v=True")
        mem_repr = super().extra_repr()
        if mem_repr:
            parts.append(mem_repr)
        return ", ".join(parts)

    @abstractmethod
    def neuronal_charge(self, x: torch.Tensor):
        """
         * :ref:`API in English <BaseNode.neuronal_charge-en>`

        .. _BaseNode.neuronal_charge-cn:

        定义神经元的充电差分方程。子类必须实现这个函数。

        * :ref:`中文API <BaseNode.neuronal_charge-cn>`

        .. _BaseNode.neuronal_charge-en:


        Define the charge difference equation.
        The sub-class must implement this function.
        """
        raise NotImplementedError

    def neuronal_fire(self):
        """
        * :ref:`API in English <BaseNode.neuronal_fire-en>`

        .. _BaseNode.neuronal_fire-cn:

        根据当前神经元的电压、阈值,计算输出脉冲。

        * :ref:`中文API <BaseNode.neuronal_fire-cn>`

        .. _BaseNode.neuronal_fire-en:


        Calculate output spikes of neurons by their current membrane potential
        and threshold voltage.
        """

        return self.surrogate_function(self.v - self.v_threshold)

    def neuronal_reset(self, spike):
        """
        * :ref:`API in English <BaseNode.neuronal_reset-en>`

        .. _BaseNode.neuronal_reset-cn:

        根据当前神经元释放的脉冲,对膜电位进行重置。

        * :ref:`中文API <BaseNode.neuronal_reset-cn>`

        .. _BaseNode.neuronal_reset-en:


        Reset the membrane potential according to neurons' output spikes.
        """
        if self.detach_reset:
            spike_d = spike.detach()
        else:
            spike_d = spike

        if self.pre_spike_v:
            self.v_pre_spike = self.v.clone()

        if self.hard_reset:
            # hard reset
            self.v = self.v - (self.v - self.v_reset) * spike_d
        else:
            # soft reset
            self.v = self.v - (self.v_threshold - self.v_reset) * spike_d

    def neuronal_adaptation(self):
        raise NotImplementedError()

    def single_step_forward(self, x: Float[Tensor, "*batch n_neuron"]):
        """
        * :ref:`API in English <BaseNode.single_step_forward-en>`
        """
        self.neuronal_charge(x)
        self.neuronal_adaptation()
        spike = self.neuronal_fire()
        self.neuronal_reset(spike)
        return spike

    def multi_step_forward(self, x_seq: Float[Tensor, "T *batch n_neuron"]):
        s_seq = []
        for t, x in enumerate(x_seq):
            s = self.single_step_forward(x)
            s_seq.append(s)

        return torch.stack(s_seq)
Functions
__init__(n_neuron, v_threshold=1.0, v_reset=0.0, trainable_param=set(), surrogate_function=Sigmoid(), detach_reset=False, hard_reset=False, pre_spike_v=False, step_mode='s', backend='torch', device=None, dtype=None)

Modified spikingjelly BaseNode.

This class is the base class of differentiable spiking neurons.

Source code in btorch/models/base.py
def __init__(
    self,
    n_neuron: int | Sequence[int],
    v_threshold: float | Float[TensorLike, " n_neuron"] = 1.0,
    v_reset: float | Float[TensorLike, " n_neuron"] = 0.0,
    trainable_param: set[str] = set(),
    surrogate_function: Callable = Sigmoid(),
    detach_reset: bool = False,
    hard_reset: bool = False,
    pre_spike_v: bool = False,
    step_mode="s",
    backend="torch",
    device=None,
    dtype=None,
):
    """Modified spikingjelly BaseNode.

    * :ref:`API in English <BaseNode.__init__-en>`

    This class is the base class of differentiable spiking neurons.
    """

    # override neuron.BaseNode's __init__ method to remove unnecessary checks
    # call neuron.BaseNode's parent MemoryModule directly
    super().__init__()

    self.n_neuron, self.size = normalize_n_neuron(n_neuron)
    self.register_memory("v", v_reset, self.n_neuron)
    self.pre_spike_v = pre_spike_v

    _factory_kwargs: dict[str, Any] = {"device": device, "dtype": dtype}
    if pre_spike_v:
        self.register_memory(
            "v_pre_spike", v_reset, self.n_neuron, persistent=False
        )

    self.trainable_param = set(trainable_param)
    self.def_param(
        "v_threshold",
        v_threshold,
        sizes=self.n_neuron,
        trainable_param=self.trainable_param,
        **_factory_kwargs,
    )
    self.def_param(
        "v_reset",
        v_reset,
        sizes=self.n_neuron,
        trainable_param=self.trainable_param,
        **_factory_kwargs,
    )

    self.detach_reset = detach_reset
    self.surrogate_function = surrogate_function
    self.hard_reset = hard_reset

    self.step_mode = step_mode
    self.backend = backend
neuronal_charge(x) abstractmethod

Define the charge difference equation. Subclasses must implement this function.

Source code in btorch/models/base.py
@abstractmethod
def neuronal_charge(self, x: torch.Tensor):
    """
     * :ref:`API in English <BaseNode.neuronal_charge-en>`

    .. _BaseNode.neuronal_charge-cn:

    定义神经元的充电差分方程。子类必须实现这个函数。

    * :ref:`中文API <BaseNode.neuronal_charge-cn>`

    .. _BaseNode.neuronal_charge-en:


    Define the charge difference equation.
    The sub-class must implement this function.
    """
    raise NotImplementedError
neuronal_fire()

Calculate output spikes of neurons from their current membrane potential and threshold voltage.

Source code in btorch/models/base.py
def neuronal_fire(self):
    """
    * :ref:`API in English <BaseNode.neuronal_fire-en>`

    .. _BaseNode.neuronal_fire-cn:

    根据当前神经元的电压、阈值,计算输出脉冲。

    * :ref:`中文API <BaseNode.neuronal_fire-cn>`

    .. _BaseNode.neuronal_fire-en:


    Calculate output spikes of neurons by their current membrane potential
    and threshold voltage.
    """

    return self.surrogate_function(self.v - self.v_threshold)
neuronal_reset(spike)

Reset the membrane potential according to neurons' output spikes.

Source code in btorch/models/base.py
def neuronal_reset(self, spike):
    """
    * :ref:`API in English <BaseNode.neuronal_reset-en>`

    .. _BaseNode.neuronal_reset-cn:

    根据当前神经元释放的脉冲,对膜电位进行重置。

    * :ref:`中文API <BaseNode.neuronal_reset-cn>`

    .. _BaseNode.neuronal_reset-en:


    Reset the membrane potential according to neurons' output spikes.
    """
    if self.detach_reset:
        spike_d = spike.detach()
    else:
        spike_d = spike

    if self.pre_spike_v:
        self.v_pre_spike = self.v.clone()

    if self.hard_reset:
        # hard reset
        self.v = self.v - (self.v - self.v_reset) * spike_d
    else:
        # soft reset
        self.v = self.v - (self.v_threshold - self.v_reset) * spike_d
single_step_forward(x)

Run one simulation step: charge, adapt, fire, and reset; returns the output spikes.

Source code in btorch/models/base.py
def single_step_forward(self, x: Float[Tensor, "*batch n_neuron"]):
    """
    * :ref:`API in English <BaseNode.single_step_forward-en>`
    """
    self.neuronal_charge(x)
    self.neuronal_adaptation()
    spike = self.neuronal_fire()
    self.neuronal_reset(spike)
    return spike

GLIF3

Bases: BaseNode

GLIF3 model with after-spike currents and refractory period.

The GLIF3 model extends standard LIF by adding after-spike currents (ASC) that capture spike-frequency adaptation. Each spike adds asc_amps to the ASC vector, which then decays exponentially with time constants 1/k.

Dynamics

dV/dt = -(V - V_rest) / tau + (I_in + sum(I_asc)) / c_m
dI_asc/dt = -k * I_asc

At spike: I_asc += asc_amps

Parameters:

Name Type Description Default
n_neuron int | Sequence[int]

Number of neurons (int or tuple of dimensions).

required
v_threshold float | Float[TensorLike, ' n_neuron']

Firing threshold (mV). Default: -50.0.

-50.0
v_reset float | Float[TensorLike, ' n_neuron']

Reset voltage after spike (mV). Default: -70.0.

-70.0
v_rest None | float | Float[TensorLike, ' n_neuron']

Resting potential (mV). Defaults to v_reset if None.

None
c_m float | Float[TensorLike, ' n_neuron']

Membrane capacitance (pF). Default: 0.05.

0.05
tau float | Float[TensorLike, ' n_neuron']

Membrane time constant (ms). Default: 20.0.

20.0
k float | Sequence[float] | Float[TensorLike, 'n_neuron {self.n_Iasc}']

ASC decay rates (ms^-1); may be a list for multiple ASC components. Default: [0.2].

[0.2]
asc_amps float | Sequence[float] | Float[TensorLike, 'n_neuron {self.n_Iasc}']

ASC amplitudes (pA) added at each spike. Default: [0.0].

[0.0]
tau_ref float | Float[TensorLike, ' n_neuron'] | None

Refractory period (ms). Default: 0.0.

0.0
trainable_param set[str]

Set of parameter names to make trainable.

set()
surrogate_function Callable

Surrogate gradient function. Default: ATan().

ATan()
detach_reset bool

If True, detach reset signal. Default: False.

False
hard_reset bool

If True, use hard reset. Default: False.

False
pre_spike_v bool

If True, store pre-spike voltage. Default: False.

False
step_mode Literal['s']

Step mode. Default: "s".

's'
backend Literal['torch']

Backend implementation. Default: "torch".

'torch'
device

Device for tensors. Default: None.

None
dtype

Data type for tensors. Default: None.

None

Attributes:

Name Type Description
v Tensor

Membrane potential, shape (*batch, n_neuron).

Iasc Tensor

After-spike currents, shape (*batch, n_neuron, n_Iasc).

refractory Tensor | None

Refractory counter (if tau_ref > 0).

c_m, tau, tau_ref

Neuron parameters.

k Tensor | Parameter

ASC decay rates, shape (n_neuron, n_Iasc) or (n_Iasc,).

asc_amps Tensor | Parameter

ASC amplitudes, shape (n_neuron, n_Iasc) or (n_Iasc,).

n_Iasc int

Number of ASC components.

References

Teeter et al., "Generalized leaky integrate-and-fire models classify multiple neuron types," Nature Communications, 2018.

Source code in btorch/models/neurons/glif.py
class GLIF3(BaseNode):
    """GLIF3 model with after-spike currents and refractory period.

    The GLIF3 model extends standard LIF by adding after-spike currents
    (ASC) that capture spike-frequency adaptation. Each spike adds
    asc_amps to the ASC vector, which then decays exponentially with
    time constants 1/k.

    Dynamics:
        dV/dt = -(V - V_rest) / tau + (I_in + sum(I_asc)) / c_m
        dI_asc/dt = -k * I_asc

        At spike: I_asc += asc_amps

    Args:
        n_neuron: Number of neurons (int or tuple of dimensions).
        v_threshold: Firing threshold (mV). Default: -50.0.
        v_reset: Reset voltage after spike (mV). Default: -70.0.
        v_rest: Resting potential (mV). Defaults to v_reset if None.
        c_m: Membrane capacitance (pF). Default: 0.05.
        tau: Membrane time constant (ms). Default: 20.0.
        k: ASC decay rates (ms^-1), can be list for multiple ASC components.
            Default: [0.2].
        asc_amps: ASC amplitudes (pA) added at each spike.
            Default: [0.0].
        tau_ref: Refractory period (ms). Default: 0.0.
        trainable_param: Set of parameter names to make trainable.
        surrogate_function: Surrogate gradient function. Default: ATan().
        detach_reset: If True, detach reset signal. Default: False.
        hard_reset: If True, use hard reset. Default: False.
        pre_spike_v: If True, store pre-spike voltage. Default: False.
        step_mode: Step mode. Default: "s".
        backend: Backend implementation. Default: "torch".
        device: Device for tensors. Default: None.
        dtype: Data type for tensors. Default: None.

    Attributes:
        v: Membrane potential, shape (*batch, n_neuron).
        Iasc: After-spike currents, shape (*batch, n_neuron, n_Iasc).
        refractory: Refractory counter (if tau_ref > 0).
        c_m, tau, tau_ref: Neuron parameters.
        k: ASC decay rates, shape (n_neuron, n_Iasc) or (n_Iasc,).
        asc_amps: ASC amplitudes, shape (n_neuron, n_Iasc) or (n_Iasc,).
        n_Iasc: Number of ASC components.

    References:
        Teeter et al., "Generalized leaky integrate-and-fire models
        classify multiple neuron types," Nature Communications, 2018.
    """

    # make mypy typing and autocompletion easier
    Iasc: torch.Tensor
    refractory: torch.Tensor | None

    c_m: torch.Tensor | torch.nn.Parameter
    tau: torch.Tensor | torch.nn.Parameter
    tau_ref: torch.Tensor | torch.nn.Parameter | None
    k: torch.Tensor | torch.nn.Parameter
    asc_amps: torch.Tensor | torch.nn.Parameter

    def __init__(
        self,
        n_neuron: int | Sequence[int],
        v_threshold: float | Float[TensorLike, " n_neuron"] = -50.0,  # mV
        v_reset: float | Float[TensorLike, " n_neuron"] = -70.0,  # mV
        v_rest: None | float | Float[TensorLike, " n_neuron"] = None,
        c_m: float | Float[TensorLike, " n_neuron"] = 0.05,  # 1/20 pfarad
        tau: float | Float[TensorLike, " n_neuron"] = 20.0,  # ms
        k: float | Sequence[float] | Float[TensorLike, "n_neuron {self.n_Iasc}"] = [
            0.2
        ],  # ms^-1
        asc_amps: float
        | Sequence[float]
        | Float[TensorLike, "n_neuron {self.n_Iasc}"] = [0.0],  # pA
        tau_ref: float | Float[TensorLike, " n_neuron"] | None = 0.0,  # ms
        trainable_param: set[str] = set(),
        surrogate_function: Callable = ATan(),
        detach_reset: bool = False,
        hard_reset: bool = False,
        pre_spike_v: bool = False,
        step_mode: Literal["s"] = "s",
        backend: Literal["torch"] = "torch",
        device=None,
        dtype=None,
    ):
        super().__init__(
            n_neuron=n_neuron,
            v_threshold=v_threshold,
            v_reset=v_reset,
            trainable_param=trainable_param,
            surrogate_function=surrogate_function,
            detach_reset=detach_reset,
            step_mode=step_mode,
            backend=backend,
            pre_spike_v=pre_spike_v,
        )
        _factory_kwargs: dict[str, Any] = {"device": device, "dtype": dtype}
        self.hard_reset = hard_reset
        self.def_param(
            "c_m",
            c_m,
            trainable_param=self.trainable_param,
            **_factory_kwargs,
        )
        self.def_param(
            "tau",
            tau,
            trainable_param=self.trainable_param,
            **_factory_kwargs,
        )
        self._use_refractory = tau_ref is not None
        if self._use_refractory:
            self.def_param(
                "tau_ref",
                tau_ref,
                trainable_param=self.trainable_param,
                **_factory_kwargs,
            )
            self.register_memory("refractory", 0.0, self.n_neuron)
        else:
            self.tau_ref = None

        # for compat
        if v_rest is not None:
            self.def_param(
                "_v_rest",
                v_rest,
                trainable_param=self.trainable_param,
                **_factory_kwargs,
            )
        else:
            self._v_rest = None

        # Handle after-spike currents.
        if isinstance(asc_amps, Number):
            asc_amps = [asc_amps]
        if isinstance(k, Number):
            k = [k]

        resolved_asc_sizes = self.def_param_resolve_sizes(
            k,
            asc_amps,
            sizes=self.n_neuron + (None,),
        )
        self.n_Iasc: int = resolved_asc_sizes[-1]

        self.def_param(
            "k",
            k,
            sizes=resolved_asc_sizes,
            trainable_param=self.trainable_param,
            normalize_to_sizes=True,
            **_factory_kwargs,
        )
        self.def_param(
            "asc_amps",
            asc_amps,
            sizes=resolved_asc_sizes,
            trainable_param=self.trainable_param,
            normalize_to_sizes=True,
            **_factory_kwargs,
        )

        self.register_memory(
            "Iasc",
            [
                0.0,
            ]
            * self.n_Iasc,
            self.n_neuron + (self.n_Iasc,),
        )

    @property
    def v_rest(self) -> torch.Tensor:
        """Resting potential (mV).

        For compatibility with GLIF4/GLIF5, falls back to v_reset if
        not explicitly set during initialization.

        Returns:
            Resting potential tensor.
        """
        if self._v_rest is None:
            return self.v_reset
        return self._v_rest

    @v_rest.setter
    def v_rest(self, v_rest: float | torch.Tensor):
        """Set resting potential.

        Args:
            v_rest: New resting potential value (mV).
        """
        if self._v_rest is not None:
            self._v_rest = v_rest

    def dIasc(self, Iasc: Float[Tensor, "*batch n_neuron {self.n_Iasc}"]):
        """Compute ASC derivative for exponential Euler integration.

        Args:
            Iasc: After-spike currents, shape (*batch, n_neuron, n_Iasc).

        Returns:
            Tuple of (derivative, linear_coefficient) for exp_euler_step.
        """
        return -self.k * Iasc, -self.k

    def dV(
        self,
        v: Float[Tensor, "*batch n_neuron"],
        Iasc: Float[Tensor, "*batch n_neuron {self.n_Iasc}"],
        x: Float[Tensor, "*batch n_neuron"],
    ):
        """Compute membrane potential derivative for exp Euler integration.

        Args:
            v: Membrane potential, shape (*batch, n_neuron).
            Iasc: After-spike currents, shape (*batch, n_neuron, n_Iasc).
            x: Input current, shape (*batch, n_neuron).

        Returns:
            Tuple of (derivative, linear_coefficient) for exp_euler_step.
        """
        Isum = x
        # torch.autocast will cast half to float32 for sum op
        # see https://docs.pytorch.org/docs/stable/amp.html#ops-that-can-autocast-to-float32
        # here Iasc generally only have <4 modes, so no overflow guaranteed
        return (
            -(v - self.v_rest) / self.tau
            + (Isum + Iasc.sum(-1, dtype=Iasc.dtype)) / self.c_m,
            -1.0 / self.tau,
        )

    def neuronal_charge(self, x: Float[Tensor, "*batch n_neuron"]):
        v = exp_euler_step(self.dV, self.v, self.Iasc, x, dt=environ.get("dt"))
        self.v = v

    def neuronal_adaptation(self):
        self.Iasc = exp_euler_step(self.dIasc, self.Iasc, dt=environ.get("dt"))

    def neuronal_fire(self):
        # Check if voltage exceeds threshold and not in refractory period
        spike = self.surrogate_function(
            (self.v - self.v_threshold) / (self.v_threshold - self.v_reset)
        )
        if not self._use_refractory:
            return spike
        not_in_refractory = self.refractory == 0
        spike = spike * not_in_refractory.detach().to(self.v.dtype)
        return spike

    def neuronal_reset(self, spike: Float[Tensor, "*batch n"]):
        if self.detach_reset:
            spike_d = spike.detach()
        else:
            spike_d = spike

        if self.pre_spike_v:
            self.v_pre_spike = self.v.clone()

        if self.hard_reset:
            # hard reset
            self.v = self.v - (self.v - self.v_reset) * spike_d
        else:
            # soft reset
            self.v = self.v - (self.v_threshold - self.v_reset) * spike_d

        # Add after-spike currents
        self.Iasc = self.Iasc + self.asc_amps * spike_d[..., None]

        if self._use_refractory:
            # Set refractory period
            self.refractory = torch.relu(
                self.refractory + spike_d * self.tau_ref - environ.get("dt")
            )

    def get_rheobase(self):
        """Calculate rheobase current, the minimum constant input current
        required to make the neuron fire."""
        return get_rheobase(self.v_threshold, self.v_rest, self.c_m, self.tau)

    def extra_repr(self):
        parts = [
            f"c_m={self._format_repr_value(self.c_m)}",
            f"tau={self._format_repr_value(self.tau)}",
            f"tau_ref={self._format_repr_value(self.tau_ref)}"
            if self._use_refractory
            else "tau_ref=None",
            f"n_Iasc={self.n_Iasc}",
            f"k={self._format_repr_value(self.k)}",
            f"asc_amps={self._format_repr_value(self.asc_amps)}",
            "v_rest=auto"
            if self._v_rest is None
            else f"v_rest={self._format_repr_value(self._v_rest)}",
        ]
        base = super().extra_repr()
        if base:
            parts.append(base)
        return ", ".join(parts)

    # TODO: headache to define precise input-output shapes
    # TODO: shape handling not torch.compile friendly
    def _normalize_state_shapes(
        self,
        x: TensorLike | float,
        v0: TensorLike | float,
        Iasc0: TensorLike | float,
        dt: TensorLike | float,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        device = device or self.v_reset.device
        dtype = dtype or self.v_reset.dtype
        x, v0, Iasc0 = (
            torch.as_tensor(x, device=device, dtype=dtype),
            torch.as_tensor(v0, device=device, dtype=dtype),
            torch.as_tensor(Iasc0, device=device, dtype=dtype),
        )
        if isinstance(dt, float):
            dt = torch.tensor([dt], device=device, dtype=dtype)
        else:
            dt = torch.as_tensor(dt, device=device, dtype=dtype)

        shapes = (x.shape, v0.shape, Iasc0.shape[:-1])
        longest_shape = max(shapes, key=len)
        if dt.shape[0] != longest_shape[0]:
            dt = expand_trailing_dims(dt, longest_shape, broadcast_only=True)

        return x, v0, Iasc0, dt

    def forward_exact_no_spike(
        self,
        x: Float[Tensor, "*batch #neuron"] | Float[Tensor, "*batch"],
        v0: Float[Tensor, "*batch neuron"] | None = None,
        Iasc0: Float[Tensor, "*batch neuron {self.n_Iasc}"] | None = None,
        dt: float
        | Float[TensorLike, "#time *batch neuron"]
        | Float[TensorLike, "#*batch neuron"]
        | Float[TensorLike, "#time *batch"]
        | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        if dt is None:
            dt = environ.get("dt")

        update = (v0 is None) and (Iasc0 is None)
        if v0 is None:
            v0 = self.v
        if Iasc0 is None:
            Iasc0 = self.Iasc

        x, v0, Iasc0, dt = self._normalize_state_shapes(x, v0, Iasc0, dt)

        v_inf = self.v_reset + x * self.tau / self.c_m

        exp_m = torch.exp(-dt / self.tau)
        # (time, batch, neuron, n_Iasc)
        exp_asc = torch.exp(-dt[..., None] * self.k)

        Iasc = Iasc0 * exp_asc

        # degenerate case if tau=tau_asc=1/k
        Iasc_contrib = torch.where(
            torch.abs(self.k - 1 / self.tau[..., None]) > 1e-12,
            (Iasc0 / self.c_m[..., None])
            * (exp_asc - exp_m[..., None])
            / (1.0 / self.tau[..., None] - self.k),
            (Iasc0 / self.c_m[..., None]) * (dt * exp_m)[..., None],
        )
        v = v_inf + (v0 - v_inf) * exp_m + Iasc_contrib.sum(dim=-1)

        if update:
            self.v = v
            self.Iasc = Iasc
        return v, Iasc
Attributes
v_rest property writable

Resting potential (mV).

For compatibility with GLIF4/GLIF5, falls back to v_reset if not explicitly set during initialization.

Returns:

Type Description
Tensor

Resting potential tensor.

Functions
dIasc(Iasc)

Compute ASC derivative for exponential Euler integration.

Parameters:

Name Type Description Default
Iasc Float[Tensor, '*batch n_neuron {self.n_Iasc}']

After-spike currents, shape (*batch, n_neuron, n_Iasc).

required

Returns:

Type Description

Tuple of (derivative, linear_coefficient) for exp_euler_step.

Source code in btorch/models/neurons/glif.py
def dIasc(self, Iasc: Float[Tensor, "*batch n_neuron {self.n_Iasc}"]):
    """Compute ASC derivative for exponential Euler integration.

    Args:
        Iasc: After-spike currents, shape (*batch, n_neuron, n_Iasc).

    Returns:
        Tuple of (derivative, linear_coefficient) for exp_euler_step.
    """
    return -self.k * Iasc, -self.k
dV(v, Iasc, x)

Compute membrane potential derivative for exp Euler integration.

Parameters:

Name Type Description Default
v Float[Tensor, '*batch n_neuron']

Membrane potential, shape (*batch, n_neuron).

required
Iasc Float[Tensor, '*batch n_neuron {self.n_Iasc}']

After-spike currents, shape (*batch, n_neuron, n_Iasc).

required
x Float[Tensor, '*batch n_neuron']

Input current, shape (*batch, n_neuron).

required

Returns:

Type Description

Tuple of (derivative, linear_coefficient) for exp_euler_step.

Source code in btorch/models/neurons/glif.py
def dV(
    self,
    v: Float[Tensor, "*batch n_neuron"],
    Iasc: Float[Tensor, "*batch n_neuron {self.n_Iasc}"],
    x: Float[Tensor, "*batch n_neuron"],
):
    """Compute membrane potential derivative for exp Euler integration.

    Args:
        v: Membrane potential, shape (*batch, n_neuron).
        Iasc: After-spike currents, shape (*batch, n_neuron, n_Iasc).
        x: Input current, shape (*batch, n_neuron).

    Returns:
        Tuple of (derivative, linear_coefficient) for exp_euler_step.
    """
    Isum = x
    # torch.autocast will cast half to float32 for sum op
    # see https://docs.pytorch.org/docs/stable/amp.html#ops-that-can-autocast-to-float32
    # here Iasc generally only have <4 modes, so no overflow guaranteed
    return (
        -(v - self.v_rest) / self.tau
        + (Isum + Iasc.sum(-1, dtype=Iasc.dtype)) / self.c_m,
        -1.0 / self.tau,
    )
get_rheobase()

Calculate rheobase current, the minimum constant input current required to make the neuron fire.

Source code in btorch/models/neurons/glif.py
def get_rheobase(self):
    """Calculate rheobase current, the minimum constant input current
    required to make the neuron fire."""
    return get_rheobase(self.v_threshold, self.v_rest, self.c_m, self.tau)

Functions

exp_euler_step(f, *args, dt=1.0, linear=None)

One integration step applying the exponential Euler method.

.. math:: \frac{dx}{dt} = f(x) = Ax + B

where :math:`A` is the linear term and :math:`f(x)` is the derivative.

The update rule is:

.. math:: x_{n+1} = x_n + \frac{e^{dt A} - 1}{A} f(x_n) = e^{dt A} x_n + \frac{e^{dt A} - 1}{A} B

Source code in btorch/models/ode.py
def exp_euler_step(f: Callable, *args, dt=1.0, linear: Tensor | None = None):
    r"""One integration step applying the exponential Euler method.

    .. math::
        \frac{dx}{dt} = f(x) = Ax + B

    where :math:`A` is the linear term and :math:`f(x)` is the derivative.

    The update rule is:

    .. math::
        x_{n+1} = x_n + \frac{e^{dt A} - 1}{A} f(x_n)
                = e^{dt A} x_n + \frac{e^{dt A} - 1}{A} B
    """
    out = f(*args)
    derivative, linear_from_f = _split_derivative_linear(out)
    if linear is None:
        linear = linear_from_f
    if linear is None:
        if torch.compiler.is_compiling():
            raise RuntimeError(
                "torch.compile cannot use vjp fallback here; return "
                "(derivative, linear) from f(*args) or pass linear=."
            )
        if len(args) > 1:
            _f = lambda x: _derivative_only(f(x, *args[1:]))
        else:
            _f = lambda x: _derivative_only(f(x))
        derivative, linear_f = vjp(_f, args[0])
        linear = linear_f(torch.ones_like(derivative))[0]
    return args[0] + torch.expm1(dt * linear) / linear * derivative
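
For a linear ODE the exponential Euler update is exact, which a scalar sketch can verify (`exp_euler_step_scalar` is an illustrative stand-in for `exp_euler_step`, not part of btorch):

```python
import math

# Exponential-Euler update for dx/dt = A*x + B:
# x_{n+1} = x_n + (exp(dt*A) - 1)/A * f(x_n)
def exp_euler_step_scalar(x, A, B, dt):
    f = A * x + B
    return x + math.expm1(dt * A) / A * f

# Leaky membrane: dv/dt = -(v - v_rest)/tau, i.e. A = -1/tau, B = v_rest/tau
tau, v_rest, dt = 20.0, -65.0, 1.0
v = 0.0
for _ in range(5):
    v = exp_euler_step_scalar(v, -1.0 / tau, v_rest / tau, dt)

# For a linear ODE the scheme reproduces the analytic solution
t = 5 * dt
v_exact = v_rest + (0.0 - v_rest) * math.exp(-t / tau)
assert abs(v - v_exact) < 1e-6
```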

expand_trailing_dims(tensor, target_trailing_shape, match_full_shape=False, broadcast_only=False, view=True)

Expand or broadcast the trailing dimensions of tensor to match target_trailing_shape; a thin wrapper around expand_dims with position="trailing".

Source code in btorch/models/shape.py
def expand_trailing_dims(
    tensor: torch.Tensor,
    target_trailing_shape: int | tuple[int, ...],
    match_full_shape: bool = False,
    broadcast_only: bool = False,
    view=True,
) -> torch.Tensor:
    return expand_dims(
        tensor,
        target_trailing_shape,
        match_full_shape,
        position="trailing",
        view=view,
        broadcast_only=broadcast_only,
    )

get_rheobase(v_threshold, v_rest, c_m, tau)

Calculate rheobase current.

The rheobase is the minimum constant input current required to make the neuron fire. For GLIF models: I_rheobase = (v_threshold - v_rest) * c_m / tau

Parameters:

Name Type Description Default
v_threshold float | Tensor

Firing threshold (mV).

required
v_rest float | Tensor

Resting potential (mV).

required
c_m float | Tensor

Membrane capacitance (pF).

required
tau float | Tensor

Membrane time constant (ms).

required

Returns:

Type Description
float | Tensor

Rheobase current (pA).

Source code in btorch/models/neurons/glif.py
def get_rheobase(
    v_threshold: float | torch.Tensor,
    v_rest: float | torch.Tensor,
    c_m: float | torch.Tensor,
    tau: float | torch.Tensor,
) -> float | torch.Tensor:
    """Calculate rheobase current.

    The rheobase is the minimum constant input current required to make
    the neuron fire. For GLIF models:
        I_rheobase = (v_threshold - v_rest) * c_m / tau

    Args:
        v_threshold: Firing threshold (mV).
        v_rest: Resting potential (mV).
        c_m: Membrane capacitance (pF).
        tau: Membrane time constant (ms).

    Returns:
        Rheobase current (pA).
    """
    # For GLIF3, rheobase can be calculated as:
    # I_rheobase = (v_threshold - v_rest) * c_m / tau
    I_rheobase = (v_threshold - v_rest) * c_m / tau
    return I_rheobase
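
A quick numeric check of the formula (re-implemented standalone here, with illustrative LIF-scale values):

```python
def get_rheobase(v_threshold, v_rest, c_m, tau):
    # I_rheobase = (v_threshold - v_rest) * c_m / tau
    return (v_threshold - v_rest) * c_m / tau

# Steady state of dv/dt = -(v - v_rest)/tau + I/c_m is v_rest + I*tau/c_m;
# at I = rheobase this steady state sits exactly at threshold.
I = get_rheobase(v_threshold=-50.0, v_rest=-65.0, c_m=100.0, tau=20.0)
assert I == 75.0
v_inf = -65.0 + I * 20.0 / 100.0
assert v_inf == -50.0
```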

btorch.models.neurons.izhikevich

Izhikevich neuron model.

Efficient 2D model reproducing diverse cortical spiking patterns.

Attributes

TensorLike = np.ndarray | torch.Tensor module-attribute

Classes

BaseNode

Bases: ParamBufferMixin, MemoryModule

Base class for differentiable spiking neurons.

Implements the spiking neuron lifecycle: charge -> adapt -> fire -> reset. Subclasses implement neuronal_charge() and neuronal_adaptation().

Parameters:

Name Type Description Default
n_neuron int | Sequence[int]

Number of neurons (int or tuple).

required
v_threshold float | Float[TensorLike, ' n_neuron']

Firing threshold. Default: 1.0.

1.0
v_reset float | Float[TensorLike, ' n_neuron']

Reset voltage. Default: 0.0.

0.0
trainable_param set[str]

Trainable parameter names. Default: ().

set()
surrogate_function Callable

Surrogate for backprop. Default: Sigmoid().

Sigmoid()
detach_reset bool

Detach reset signal. Default: False.

False
hard_reset bool

Hard vs soft reset. Default: False.

False
pre_spike_v bool

Store pre-spike voltage. Default: False.

False
step_mode

"s" or "m". Default: "s".

's'
backend

Compute backend. Default: "torch".

'torch'
device

Tensor device. Default: None.

None
dtype

Tensor dtype. Default: None.

None
Source code in btorch/models/base.py
class BaseNode(ParamBufferMixin, MemoryModule):
    """Base class for differentiable spiking neurons.

    Implements the spiking neuron lifecycle: charge -> adapt -> fire -> reset.
    Subclasses implement neuronal_charge() and neuronal_adaptation().

    Args:
        n_neuron: Number of neurons (int or tuple).
        v_threshold: Firing threshold. Default: 1.0.
        v_reset: Reset voltage. Default: 0.0.
        trainable_param: Trainable parameter names. Default: ().
        surrogate_function: Surrogate for backprop. Default: Sigmoid().
        detach_reset: Detach reset signal. Default: False.
        hard_reset: Hard vs soft reset. Default: False.
        pre_spike_v: Store pre-spike voltage. Default: False.
        step_mode: "s" or "m". Default: "s".
        backend: Compute backend. Default: "torch".
        device: Tensor device. Default: None.
        dtype: Tensor dtype. Default: None.
    """

    n_neuron: tuple[int, ...]
    size: int
    v: torch.Tensor
    v_pre_spike: torch.Tensor
    v_threshold: torch.Tensor | torch.nn.Parameter
    v_reset: torch.Tensor | torch.nn.Parameter

    def __init__(
        self,
        n_neuron: int | Sequence[int],
        v_threshold: float | Float[TensorLike, " n_neuron"] = 1.0,
        v_reset: float | Float[TensorLike, " n_neuron"] = 0.0,
        trainable_param: set[str] = set(),
        surrogate_function: Callable = Sigmoid(),
        detach_reset: bool = False,
        hard_reset: bool = False,
        pre_spike_v: bool = False,
        step_mode="s",
        backend="torch",
        device=None,
        dtype=None,
    ):
        """Modified spikingjelly BaseNode.

        * :ref:`API in English <BaseNode.__init__-en>`

        This class is the base class of differentiable spiking neurons.
        """

        # override neuron.BaseNode's __init__ method to remove unnecessary checks
        # call neuron.BaseNode's parent MemoryModule directly
        super().__init__()

        self.n_neuron, self.size = normalize_n_neuron(n_neuron)
        self.register_memory("v", v_reset, self.n_neuron)
        self.pre_spike_v = pre_spike_v

        _factory_kwargs: dict[str, Any] = {"device": device, "dtype": dtype}
        if pre_spike_v:
            self.register_memory(
                "v_pre_spike", v_reset, self.n_neuron, persistent=False
            )

        self.trainable_param = set(trainable_param)
        self.def_param(
            "v_threshold",
            v_threshold,
            sizes=self.n_neuron,
            trainable_param=self.trainable_param,
            **_factory_kwargs,
        )
        self.def_param(
            "v_reset",
            v_reset,
            sizes=self.n_neuron,
            trainable_param=self.trainable_param,
            **_factory_kwargs,
        )

        self.detach_reset = detach_reset
        self.surrogate_function = surrogate_function
        self.hard_reset = hard_reset

        self.step_mode = step_mode
        self.backend = backend

    def extra_repr(self):
        parts = [
            f"n_neuron={self.n_neuron}",
            f"v_threshold={self._format_repr_value(self.v_threshold)}",
            f"v_reset={self._format_repr_value(self.v_reset)}",
            f"step_mode={self.step_mode}",
            f"backend={self.backend}",
            f"surrogate={self.surrogate_function.__class__.__name__}",
        ]
        if self.detach_reset:
            parts.append("detach_reset=True")
        if self.hard_reset:
            parts.append("hard_reset=True")
        if self.pre_spike_v:
            parts.append("pre_spike_v=True")
        mem_repr = super().extra_repr()
        if mem_repr:
            parts.append(mem_repr)
        return ", ".join(parts)

    @abstractmethod
    def neuronal_charge(self, x: torch.Tensor):
        """
         * :ref:`API in English <BaseNode.neuronal_charge-en>`

        .. _BaseNode.neuronal_charge-cn:

        定义神经元的充电差分方程。子类必须实现这个函数。

        * :ref:`中文API <BaseNode.neuronal_charge-cn>`

        .. _BaseNode.neuronal_charge-en:


        Define the charge difference equation.
        The sub-class must implement this function.
        """
        raise NotImplementedError

    def neuronal_fire(self):
        """
        * :ref:`API in English <BaseNode.neuronal_fire-en>`

        .. _BaseNode.neuronal_fire-cn:

        根据当前神经元的电压、阈值,计算输出脉冲。

        * :ref:`中文API <BaseNode.neuronal_fire-cn>`

        .. _BaseNode.neuronal_fire-en:


        Calculate out spikes of neurons by their current membrane potential
        and threshold voltage.
        """

        return self.surrogate_function(self.v - self.v_threshold)

    def neuronal_reset(self, spike):
        """
        * :ref:`API in English <BaseNode.neuronal_reset-en>`

        .. _BaseNode.neuronal_reset-cn:

        根据当前神经元释放的脉冲,对膜电位进行重置。

        * :ref:`中文API <BaseNode.neuronal_reset-cn>`

        .. _BaseNode.neuronal_reset-en:


        Reset the membrane potential according to neurons' output spikes.
        """
        if self.detach_reset:
            spike_d = spike.detach()
        else:
            spike_d = spike

        if self.pre_spike_v:
            self.v_pre_spike = self.v.clone()

        if self.hard_reset:
            # hard reset
            self.v = self.v - (self.v - self.v_reset) * spike_d
        else:
            # soft reset
            self.v = self.v - (self.v_threshold - self.v_reset) * spike_d

    def neuronal_adaptation(self):
        raise NotImplementedError()

    def single_step_forward(self, x: Float[Tensor, "*batch n_neuron"]):
        """
        * :ref:`API in English <BaseNode.single_step_forward-en>`
        """
        self.neuronal_charge(x)
        self.neuronal_adaptation()
        spike = self.neuronal_fire()
        self.neuronal_reset(spike)
        return spike

    def multi_step_forward(self, x_seq: Float[Tensor, "T *batch n_neuron"]):
        s_seq = []
        for t, x in enumerate(x_seq):
            s = self.single_step_forward(x)
            s_seq.append(s)

        return torch.stack(s_seq)
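
The hard/soft reset arithmetic in neuronal_reset can be checked with a scalar sketch (`reset` below is a hypothetical standalone helper mirroring the tensor code above):

```python
# Hard reset drives v to v_reset on a spike; soft reset subtracts a fixed amount.
def reset(v, spike, v_threshold, v_reset, hard):
    if hard:
        return v - (v - v_reset) * spike        # spike=1 -> v becomes v_reset
    return v - (v_threshold - v_reset) * spike  # spike=1 -> v drops by (thr - reset)

assert reset(1.4, 1.0, v_threshold=1.0, v_reset=0.0, hard=True) == 0.0
assert abs(reset(1.4, 1.0, v_threshold=1.0, v_reset=0.0, hard=False) - 0.4) < 1e-12
# No spike leaves the membrane potential untouched in both modes
assert reset(0.5, 0.0, v_threshold=1.0, v_reset=0.0, hard=True) == 0.5
```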
Functions
__init__(n_neuron, v_threshold=1.0, v_reset=0.0, trainable_param=set(), surrogate_function=Sigmoid(), detach_reset=False, hard_reset=False, pre_spike_v=False, step_mode='s', backend='torch', device=None, dtype=None)

Modified spikingjelly BaseNode.

This class is the base class of differentiable spiking neurons.

Source code in btorch/models/base.py
def __init__(
    self,
    n_neuron: int | Sequence[int],
    v_threshold: float | Float[TensorLike, " n_neuron"] = 1.0,
    v_reset: float | Float[TensorLike, " n_neuron"] = 0.0,
    trainable_param: set[str] = set(),
    surrogate_function: Callable = Sigmoid(),
    detach_reset: bool = False,
    hard_reset: bool = False,
    pre_spike_v: bool = False,
    step_mode="s",
    backend="torch",
    device=None,
    dtype=None,
):
    """Modified spikingjelly BaseNode.

    * :ref:`API in English <BaseNode.__init__-en>`

    This class is the base class of differentiable spiking neurons.
    """

    # override neuron.BaseNode's __init__ method to remove unnecessary checks
    # call neuron.BaseNode's parent MemoryModule directly
    super().__init__()

    self.n_neuron, self.size = normalize_n_neuron(n_neuron)
    self.register_memory("v", v_reset, self.n_neuron)
    self.pre_spike_v = pre_spike_v

    _factory_kwargs: dict[str, Any] = {"device": device, "dtype": dtype}
    if pre_spike_v:
        self.register_memory(
            "v_pre_spike", v_reset, self.n_neuron, persistent=False
        )

    self.trainable_param = set(trainable_param)
    self.def_param(
        "v_threshold",
        v_threshold,
        sizes=self.n_neuron,
        trainable_param=self.trainable_param,
        **_factory_kwargs,
    )
    self.def_param(
        "v_reset",
        v_reset,
        sizes=self.n_neuron,
        trainable_param=self.trainable_param,
        **_factory_kwargs,
    )

    self.detach_reset = detach_reset
    self.surrogate_function = surrogate_function
    self.hard_reset = hard_reset

    self.step_mode = step_mode
    self.backend = backend
neuronal_charge(x) abstractmethod

Define the charge difference equation. The sub-class must implement this function.

Source code in btorch/models/base.py
@abstractmethod
def neuronal_charge(self, x: torch.Tensor):
    """
     * :ref:`API in English <BaseNode.neuronal_charge-en>`

    .. _BaseNode.neuronal_charge-cn:

    定义神经元的充电差分方程。子类必须实现这个函数。

    * :ref:`中文API <BaseNode.neuronal_charge-cn>`

    .. _BaseNode.neuronal_charge-en:


    Define the charge difference equation.
    The sub-class must implement this function.
    """
    raise NotImplementedError
neuronal_fire()

Calculate output spikes from neurons' current membrane potential and threshold voltage.

Source code in btorch/models/base.py
def neuronal_fire(self):
    """
    * :ref:`API in English <BaseNode.neuronal_fire-en>`

    .. _BaseNode.neuronal_fire-cn:

    根据当前神经元的电压、阈值,计算输出脉冲。

    * :ref:`中文API <BaseNode.neuronal_fire-cn>`

    .. _BaseNode.neuronal_fire-en:


    Calculate out spikes of neurons by their current membrane potential
    and threshold voltage.
    """

    return self.surrogate_function(self.v - self.v_threshold)
neuronal_reset(spike)

Reset the membrane potential according to neurons' output spikes.

Source code in btorch/models/base.py
def neuronal_reset(self, spike):
    """
    * :ref:`API in English <BaseNode.neuronal_reset-en>`

    .. _BaseNode.neuronal_reset-cn:

    根据当前神经元释放的脉冲,对膜电位进行重置。

    * :ref:`中文API <BaseNode.neuronal_reset-cn>`

    .. _BaseNode.neuronal_reset-en:


    Reset the membrane potential according to neurons' output spikes.
    """
    if self.detach_reset:
        spike_d = spike.detach()
    else:
        spike_d = spike

    if self.pre_spike_v:
        self.v_pre_spike = self.v.clone()

    if self.hard_reset:
        # hard reset
        self.v = self.v - (self.v - self.v_reset) * spike_d
    else:
        # soft reset
        self.v = self.v - (self.v_threshold - self.v_reset) * spike_d
single_step_forward(x)

Run one simulation step: charge, adapt, fire, and reset; returns the spike tensor.

Source code in btorch/models/base.py
def single_step_forward(self, x: Float[Tensor, "*batch n_neuron"]):
    """
    * :ref:`API in English <BaseNode.single_step_forward-en>`
    """
    self.neuronal_charge(x)
    self.neuronal_adaptation()
    spike = self.neuronal_fire()
    self.neuronal_reset(spike)
    return spike

Izhikevich

Bases: BaseNode

Izhikevich neuron with quadratic dynamics and recovery variable.

Efficient model reproducing diverse spiking patterns (tonic, bursting, etc.) via a 2D ODE system with quadratic nonlinearity.

Dynamics

dv/dt = (k*(v - v_rest)*(v - v_threshold) - u + I) / c_m
du/dt = a * (b*(v - v_rest) - u)

At spike: v=v_reset, u=u+d
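
The 2D dynamics above can be sketched with plain forward Euler in pure Python (the parameter values are illustrative regular-spiking-style numbers, not this class's defaults):

```python
# Forward-Euler sketch of the quadratic Izhikevich dynamics (illustrative values)
c_m, k, a, b, d = 100.0, 0.7, 0.03, -2.0, 100.0
v_rest, v_threshold, v_reset, v_peak = -65.0, -40.0, -65.0, 35.0
dt, I = 1.0, 600.0
v, u, spikes = v_rest, 0.0, 0
for _ in range(200):
    dv = (k * (v - v_rest) * (v - v_threshold) - u + I) / c_m
    du = a * (b * (v - v_rest) - u)
    v, u = v + dt * dv, u + dt * du
    if v >= v_peak:  # spike: reset v, bump the recovery variable
        v, u, spikes = v_reset, u + d, spikes + 1

# Suprathreshold constant input produces repetitive firing
assert spikes > 0
```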

Parameters:

Name Type Description Default
n_neuron int | Sequence[int]

Number of neurons.

required
v_threshold float | Float[TensorLike, ' n_neuron']

Threshold (mV). Default: 30.0.

30.0
v_reset float | Float[TensorLike, ' n_neuron']

Reset voltage (mV). Default: -65.0.

-65.0
v_rest float | Float[TensorLike, ' n_neuron']

Resting potential (mV). Default: -65.0.

-65.0
v_peak float | Float[TensorLike, ' n_neuron']

Spike cutoff (mV). Default: -40.0.

-40.0
c_m float | Float[TensorLike, ' n_neuron']

Capacitance (pF). Default: 100.0.

100.0
k float | Float[TensorLike, ' n_neuron']

Scaling factor (nS/mV). Default: 0.7.

0.7
a float | Float[TensorLike, ' n_neuron']

Recovery timescale (ms^-1). Default: 0.03.

0.03
b float | Float[TensorLike, ' n_neuron']

Recovery coupling (nS). Default: -2.0.

-2.0
d float | Float[TensorLike, ' n_neuron']

Recovery jump (pA). Default: 100.0.

100.0
trainable_param set[str]

Trainable parameters. Default: ().

set()
surrogate_function Callable

Surrogate for backprop. Default: Sigmoid().

Sigmoid()
detach_reset bool

Detach reset signal. Default: False.

False
hard_reset bool

Hard vs soft reset. Default: False.

False
pre_spike bool

Store pre-spike values. Default: False.

False
step_mode Literal['s']

Step mode. Default: "s".

's'
backend Literal['torch']

Backend. Default: "torch".

'torch'
device

Device. Default: None.

None
dtype

Dtype. Default: None.

None

Attributes:

Name Type Description
v Tensor

Membrane potential (*batch, n_neuron).

u Tensor

Recovery variable (*batch, n_neuron).

Reference

Izhikevich, IEEE Trans. Neural Networks, 2003.

Source code in btorch/models/neurons/izhikevich.py
class Izhikevich(BaseNode):
    """Izhikevich neuron with quadratic dynamics and recovery variable.

    Efficient model reproducing diverse spiking patterns (tonic, bursting,
    etc.) via a 2D ODE system with quadratic nonlinearity.

    Dynamics:
        dv/dt = (k*(v-v_rest)*(v-v_threshold) - u + I) / c_m
        du/dt = a * (b*(v-v_rest) - u)

    At spike: v=v_reset, u=u+d

    Args:
        n_neuron: Number of neurons.
        v_threshold: Threshold (mV). Default: 30.0.
        v_reset: Reset voltage (mV). Default: -65.0.
        v_rest: Resting potential (mV). Default: -65.0.
        v_peak: Spike cutoff (mV). Default: -40.0.
        c_m: Capacitance (pF). Default: 100.0.
        k: Scaling factor (nS/mV). Default: 0.7.
        a: Recovery timescale (ms^-1). Default: 0.03.
        b: Recovery coupling (nS). Default: -2.0.
        d: Recovery jump (pA). Default: 100.0.
        trainable_param: Trainable parameters. Default: ().
        surrogate_function: Surrogate for backprop. Default: Sigmoid().
        detach_reset: Detach reset signal. Default: False.
        hard_reset: Hard vs soft reset. Default: False.
        pre_spike: Store pre-spike values. Default: False.
        step_mode: Step mode. Default: "s".
        backend: Backend. Default: "torch".
        device: Device. Default: None.
        dtype: Dtype. Default: None.

    Attributes:
        v: Membrane potential (*batch, n_neuron).
        u: Recovery variable (*batch, n_neuron).

    Reference:
        Izhikevich, IEEE Trans. Neural Networks, 2003.
    """

    HIPPOCAMPOME_TO_ARGS = {
        "k": "k",
        "a": "a",
        "b": "b",
        "d": "d",
        "C": "c_m",
        "vr": "v_rest",
        "vt": "v_threshold",
        "vpeak": "v_peak",
        "vmin": "v_reset",
    }

    u: torch.Tensor
    u_pre_spike: torch.Tensor

    v_reset: torch.Tensor | torch.nn.Parameter
    v_rest: torch.Tensor | torch.nn.Parameter
    v_peak: torch.Tensor | torch.nn.Parameter
    c_m: torch.Tensor | torch.nn.Parameter
    k: torch.Tensor | torch.nn.Parameter
    a: torch.Tensor | torch.nn.Parameter
    b: torch.Tensor | torch.nn.Parameter
    d: torch.Tensor | torch.nn.Parameter

    def __init__(
        self,
        n_neuron: int | Sequence[int],
        v_threshold: float | Float[TensorLike, " n_neuron"] = 30.0,
        v_reset: float | Float[TensorLike, " n_neuron"] = -65.0,
        v_rest: float | Float[TensorLike, " n_neuron"] = -65.0,
        v_peak: float | Float[TensorLike, " n_neuron"] = -40.0,
        c_m: float | Float[TensorLike, " n_neuron"] = 100.0,
        k: float | Float[TensorLike, " n_neuron"] = 0.7,
        a: float | Float[TensorLike, " n_neuron"] = 0.03,
        b: float | Float[TensorLike, " n_neuron"] = -2.0,
        d: float | Float[TensorLike, " n_neuron"] = 100.0,
        trainable_param: set[str] = set(),
        surrogate_function: Callable = Sigmoid(),
        detach_reset: bool = False,
        hard_reset: bool = False,
        pre_spike: bool = False,
        step_mode: Literal["s"] = "s",
        backend: Literal["torch"] = "torch",
        device=None,
        dtype=None,
    ):
        super().__init__(
            n_neuron=n_neuron,
            v_threshold=v_threshold,
            v_reset=v_reset,
            trainable_param=trainable_param,
            surrogate_function=surrogate_function,
            detach_reset=detach_reset,
            hard_reset=hard_reset,
            pre_spike_v=pre_spike,
            step_mode=step_mode,
            backend=backend,
            device=device,
            dtype=dtype,
        )
        _factory_kwargs: dict[str, Any] = {"device": device, "dtype": dtype}
        self.def_param(
            "c_m",
            c_m,
            sizes=self.n_neuron,
            trainable_param=self.trainable_param,
            **_factory_kwargs,
        )
        self.def_param(
            "v_rest",
            v_rest,
            sizes=self.n_neuron,
            trainable_param=self.trainable_param,
            **_factory_kwargs,
        )
        self.def_param(
            "v_peak",
            v_peak,
            sizes=self.n_neuron,
            trainable_param=self.trainable_param,
            **_factory_kwargs,
        )
        self.def_param(
            "k",
            k,
            sizes=self.n_neuron,
            trainable_param=self.trainable_param,
            **_factory_kwargs,
        )
        self.def_param(
            "a",
            a,
            sizes=self.n_neuron,
            trainable_param=self.trainable_param,
            **_factory_kwargs,
        )
        self.def_param(
            "b",
            b,
            sizes=self.n_neuron,
            trainable_param=self.trainable_param,
            **_factory_kwargs,
        )
        self.def_param(
            "d",
            d,
            sizes=self.n_neuron,
            trainable_param=self.trainable_param,
            **_factory_kwargs,
        )

        self.register_memory("u", 0, self.n_neuron)
        if pre_spike:
            self.register_memory("u_pre_spike", None, self.n_neuron)

    @classmethod
    def from_hippocampome(
        cls,
        n_neuron: int | Sequence[int],
        k,
        a,
        b,
        d,
        C,
        vr,
        vt,
        vpeak,
        vmin,
        **kwargs,
    ):
        """
        Build an :class:`Izhikevich` neuron using parameter names from
        https://hippocampome.org.

        Parameter mapping (HippoCampome -> Izhikevich args):
        - k -> k (scaling factor)
        - a -> a (recovery time constant)
        - b -> b (recovery sensitivity)
        - d -> d (reset current)
        - C -> c_m (capacitance)
        - vr -> v_rest (resting potential)
        - vt -> v_threshold (instantaneous threshold)
        - vpeak -> v_peak (spike cutoff)
        - vmin -> v_reset (post-spike reset voltage)

        All values are expected in the same units as the canonical
        Izhikevich model (mV, pF, pA).
        """
        kwargs.setdefault("pre_spike", True)
        return cls(
            n_neuron,
            v_threshold=vt,
            v_reset=vmin,
            v_rest=vr,
            v_peak=vpeak,
            c_m=C,
            k=k,
            a=a,
            b=b,
            d=d,
            **kwargs,
        )

    @classmethod
    def from_canonical_quadratic(
        cls,
        n_neuron: int | Sequence[int],
        p1: float = 0.04,
        p2: float = 5.0,
        # TODO: p3: float = 0.0, adjust equation
        v_rest: float = -65.0,
        c_m: float = 1.0,
        v_peak: float = 30.0,
        **kwargs,
    ):
        """
        Instantiate using the canonical quadratic form
        ``dV/dt = p1*v^2 + p2*v + p3 - u + I``.

        The mapping assumes ``c_m`` acts as the membrane capacitance and that
        ``k/c_m`` equals ``p1``. The linear term enforces
        ``v_threshold = -p2/p1 - v_rest``. Remaining
        keyword arguments are passed directly to :class:`Izhikevich`.
        """
        k = p1 * c_m
        v_threshold = -p2 / p1 - v_rest
        # i_bias = p3 - p1 * v_rest * v_threshold

        return cls(
            n_neuron,
            v_threshold=v_threshold,
            v_reset=kwargs.pop("v_reset", v_rest),
            v_rest=v_rest,
            v_peak=v_peak,
            c_m=c_m,
            k=k,
            **kwargs,
        )

    def dV(
        self,
        v: Float[Tensor, "*batch n_neuron"],
        u: Float[Tensor, "*batch n_neuron"],
        x: Float[Tensor, "*batch n_neuron"],
    ):
        quadratic = self.k * (v - self.v_rest) * (v - self.v_threshold)
        return (x + quadratic - u) / self.c_m

    def dU(
        self,
        u: Float[Tensor, "*batch n_neuron"],
        v: Float[Tensor, "*batch n_neuron"],
    ):
        return self.a * (self.b * (v - self.v_rest) - u)

    def neuronal_charge(self, x: Float[Tensor, "*batch n_neuron"]):
        dt = environ.get("dt")
        self.v = euler_step(self.dV, self.v, self.u, x, dt=dt)

    def neuronal_adaptation(self):
        dt = environ.get("dt")
        self.u = euler_step(self.dU, self.u, self.v, dt=dt)

    def neuronal_fire(self):
        # TODO: confirm scaling with (self.v_threshold - self.v_reset)
        # or (self.v_peak - self.v_reset)
        spike = self.surrogate_function(
            (self.v - self.v_peak) / (self.v_threshold - self.v_reset)
        )
        return spike

    def neuronal_reset(self, spike: Float[Tensor, "*batch n"]):
        if self.detach_reset:
            spike_d = spike.detach()
        else:
            spike_d = spike

        if self.pre_spike_v:
            self.v_pre_spike = self.v.clone()
            self.u_pre_spike = self.u.clone()

        if self.hard_reset:
            self.v = self.v - (self.v - self.v_reset) * spike_d
        else:
            self.v = self.v - (self.v_peak - self.v_reset) * spike_d

        self.u = self.u + self.d * spike_d

    def extra_repr(self):
        parts = [
            f"c_m={self._format_repr_value(self.c_m)}",
            f"k={self._format_repr_value(self.k)}",
            f"a={self._format_repr_value(self.a)}",
            f"b={self._format_repr_value(self.b)}",
            f"d={self._format_repr_value(self.d)}",
            f"v_rest={self._format_repr_value(self.v_rest)}",
            f"v_peak={self._format_repr_value(self.v_peak)}",
        ]
        base = super().extra_repr()
        if base:
            parts.append(base)
        return ", ".join(parts)
Functions
from_canonical_quadratic(n_neuron, p1=0.04, p2=5.0, v_rest=-65.0, c_m=1.0, v_peak=30.0, **kwargs) classmethod

Instantiate using the canonical quadratic form dV/dt = p1*v^2 + p2*v + p3 - u + I.

The mapping assumes c_m acts as the membrane capacitance and that k/c_m equals p1. The linear term enforces v_threshold = -p2/p1 - v_rest. Remaining keyword arguments are passed directly to :class:Izhikevich.

Source code in btorch/models/neurons/izhikevich.py
@classmethod
def from_canonical_quadratic(
    cls,
    n_neuron: int | Sequence[int],
    p1: float = 0.04,
    p2: float = 5.0,
    # TODO: p3: float = 0.0, adjust equation
    v_rest: float = -65.0,
    c_m: float = 1.0,
    v_peak: float = 30.0,
    **kwargs,
):
    """
    Instantiate using the canonical quadratic form
    ``dV/dt = p1*v^2 + p2*v + p3 - u + I``.

    The mapping assumes ``c_m`` acts as the membrane capacitance and that
    ``k/c_m`` equals ``p1``. The linear term enforces
    ``v_threshold = -p2/p1 - v_rest``. Remaining
    keyword arguments are passed directly to :class:`Izhikevich`.
    """
    k = p1 * c_m
    v_threshold = -p2 / p1 - v_rest
    # i_bias = p3 - p1 * v_rest * v_threshold

    return cls(
        n_neuron,
        v_threshold=v_threshold,
        v_reset=kwargs.pop("v_reset", v_rest),
        v_rest=v_rest,
        v_peak=v_peak,
        c_m=c_m,
        k=k,
        **kwargs,
    )
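
The parameter mapping can be sanity-checked numerically (standalone sketch; the constant term p3 / i_bias is omitted here, matching the TODO in the source):

```python
# Mapping the canonical form dv/dt = p1*v^2 + p2*v + ... - u + I onto
# k*(v - v_rest)*(v - v_threshold): p1 = k/c_m and p2 = -k*(v_rest + v_threshold)/c_m
p1, p2, c_m, v_rest = 0.04, 5.0, 1.0, -65.0
k = p1 * c_m
v_threshold = -p2 / p1 - v_rest
assert k == 0.04
assert v_threshold == -60.0

# Expanding k*(v - v_rest)*(v - v_threshold) recovers p1*v^2 + p2*v (+ a constant)
v = -55.0
lhs = k * (v - v_rest) * (v - v_threshold)
rhs = p1 * v**2 + p2 * v + k * v_rest * v_threshold
assert abs(lhs - rhs) < 1e-9
```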
from_hippocampome(n_neuron, k, a, b, d, C, vr, vt, vpeak, vmin, **kwargs) classmethod

Build an :class:Izhikevich neuron using parameter names from https://hippocampome.org.

Parameter mapping (HippoCampome -> Izhikevich args):

  • k -> k (scaling factor)
  • a -> a (recovery time constant)
  • b -> b (recovery sensitivity)
  • d -> d (reset current)
  • C -> c_m (capacitance)
  • vr -> v_rest (resting potential)
  • vt -> v_threshold (instantaneous threshold)
  • vpeak -> v_peak (spike cutoff)
  • vmin -> v_reset (post-spike reset voltage)

All values are expected in the same units as the canonical Izhikevich model (mV, pF, pA).

Source code in btorch/models/neurons/izhikevich.py
@classmethod
def from_hippocampome(
    cls,
    n_neuron: int | Sequence[int],
    k,
    a,
    b,
    d,
    C,
    vr,
    vt,
    vpeak,
    vmin,
    **kwargs,
):
    """
    Build an :class:`Izhikevich` neuron using parameter names from
    https://hippocampome.org.

    Parameter mapping (HippoCampome -> Izhikevich args):
    - k -> k (scaling factor)
    - a -> a (recovery time constant)
    - b -> b (recovery sensitivity)
    - d -> d (reset current)
    - C -> c_m (capacitance)
    - vr -> v_rest (resting potential)
    - vt -> v_threshold (instantaneous threshold)
    - vpeak -> v_peak (spike cutoff)
    - vmin -> v_reset (post-spike reset voltage)

    All values are expected in the same units as the canonical
    Izhikevich model (mV, pF, pA).
    """
    kwargs.setdefault("pre_spike", True)
    return cls(
        n_neuron,
        v_threshold=vt,
        v_reset=vmin,
        v_rest=vr,
        v_peak=vpeak,
        c_m=C,
        k=k,
        a=a,
        b=b,
        d=d,
        **kwargs,
    )

Sigmoid

Bases: SurrogateFunctionBase

Logistic surrogate derivative.

Source code in btorch/models/surrogate/sigmoid.py
class Sigmoid(SurrogateFunctionBase):
    """Logistic surrogate derivative."""

    def primitive(self, x: torch.Tensor) -> torch.Tensor:
        return _sigmoid_primitive(x, self.alpha)

    def derivative(
        self,
        x: torch.Tensor,
        grad_output: torch.Tensor,
        damping_factor: float = 1.0,
    ) -> torch.Tensor:
        return _sigmoid_derivative(x, grad_output, self.alpha, damping_factor)
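
The surrogate-gradient idea behind this class can be sketched in scalar form (the exact formula and alpha scaling below are assumptions for illustration, not the btorch _sigmoid_derivative implementation):

```python
import math

# Surrogate gradients: the forward pass uses a hard spike (Heaviside of v - thr),
# while the backward pass substitutes a smooth derivative such as a scaled
# sigmoid's, alpha * s * (1 - s).
def sigmoid_surrogate_grad(x, alpha=4.0):
    s = 1.0 / (1.0 + math.exp(-alpha * x))
    return alpha * s * (1.0 - s)

# The surrogate gradient peaks at the threshold crossing (x = 0, value alpha/4)
# and decays away from it, letting gradients flow through near-threshold neurons.
g0 = sigmoid_surrogate_grad(0.0)
assert abs(g0 - 1.0) < 1e-12
assert sigmoid_surrogate_grad(2.0) < g0
```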

Functions

euler_step(f, *args, dt=1.0)

One forward (explicit) Euler integration step: x_{n+1} = x_n + dt * f(x_n).

Source code in btorch/models/ode.py
def euler_step(f: Callable, *args, dt=1.0):
    """One forward (explicit) Euler step: x + dt * f(x)."""
    derivative = _derivative_only(f(*args))
    return args[0] + dt * derivative
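
Unlike the exponential variant, plain forward Euler only approximates even a linear ODE, with error that shrinks as dt does (standalone scalar sketch):

```python
import math

def euler_step(f, x, dt=1.0):
    # forward (explicit) Euler: x_{n+1} = x_n + dt * f(x_n)
    return x + dt * f(x)

# Decay ODE dx/dt = -x/tau, integrated for 20 ms with dt = 1 ms
tau, dt = 20.0, 1.0
x = 1.0
for _ in range(20):
    x = euler_step(lambda v: -v / tau, x, dt=dt)

exact = math.exp(-20 * dt / tau)  # analytic solution at t = 20
assert abs(x - exact) < 0.02      # small but nonzero discretization error
```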

btorch.models.neurons.lif

Leaky integrate-and-fire (LIF) neuron models.

This module provides LIF and IF (integrate-and-fire) neuron implementations with optional refractory periods. These are the simplest spiking neuron models, suitable for basic neuromorphic computing tasks and as building blocks for more complex networks.

The LIF neuron follows the dynamics

dV/dt = -(V - V_reset) / tau + I / c_m

where V is membrane potential, tau is the time constant, I is input current, and c_m is membrane capacitance.
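
The dynamics above admit an exact discrete update, sketched here standalone (`lif_step` is illustrative, not the btorch API):

```python
import math

# Exact discretization of dV/dt = -(V - V_reset)/tau + I/c_m over one step:
# v <- v_reset + (v - v_reset)*exp(-dt/tau) + (I*tau/c_m)*(1 - exp(-dt/tau))
def lif_step(v, I, v_reset=0.0, tau=20.0, c_m=1.0, dt=1.0):
    decay = math.exp(-dt / tau)
    return v_reset + (v - v_reset) * decay + (I * tau / c_m) * (1.0 - decay)

# With constant input, v converges to the steady state v_reset + I*tau/c_m
v = 0.0
for _ in range(500):
    v = lif_step(v, I=0.05)
assert abs(v - 1.0) < 1e-6
```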

Attributes

TensorLike = np.ndarray | torch.Tensor module-attribute

Classes

BaseNode

Bases: ParamBufferMixin, MemoryModule

Base class for differentiable spiking neurons.

Implements the spiking neuron lifecycle: charge -> adapt -> fire -> reset. Subclasses implement neuronal_charge() and neuronal_adaptation().

Parameters:

Name Type Description Default
n_neuron int | Sequence[int]

Number of neurons (int or tuple).

required
v_threshold float | Float[TensorLike, ' n_neuron']

Firing threshold. Default: 1.0.

1.0
v_reset float | Float[TensorLike, ' n_neuron']

Reset voltage. Default: 0.0.

0.0
trainable_param set[str]

Trainable parameter names. Default: ().

set()
surrogate_function Callable

Surrogate for backprop. Default: Sigmoid().

Sigmoid()
detach_reset bool

Detach reset signal. Default: False.

False
hard_reset bool

Hard vs soft reset. Default: False.

False
pre_spike_v bool

Store pre-spike voltage. Default: False.

False
step_mode

"s" or "m". Default: "s".

's'
backend

Compute backend. Default: "torch".

'torch'
device

Tensor device. Default: None.

None
dtype

Tensor dtype. Default: None.

None
Source code in btorch/models/base.py
class BaseNode(ParamBufferMixin, MemoryModule):
    """Base class for differentiable spiking neurons.

    Implements the spiking neuron lifecycle: charge -> adapt -> fire -> reset.
    Subclasses implement neuronal_charge() and neuronal_adaptation().

    Args:
        n_neuron: Number of neurons (int or tuple).
        v_threshold: Firing threshold. Default: 1.0.
        v_reset: Reset voltage. Default: 0.0.
        trainable_param: Trainable parameter names. Default: ().
        surrogate_function: Surrogate for backprop. Default: Sigmoid().
        detach_reset: Detach reset signal. Default: False.
        hard_reset: Hard vs soft reset. Default: False.
        pre_spike_v: Store pre-spike voltage. Default: False.
        step_mode: "s" or "m". Default: "s".
        backend: Compute backend. Default: "torch".
        device: Tensor device. Default: None.
        dtype: Tensor dtype. Default: None.
    """

    n_neuron: tuple[int, ...]
    size: int
    v: torch.Tensor
    v_pre_spike: torch.Tensor
    v_threshold: torch.Tensor | torch.nn.Parameter
    v_reset: torch.Tensor | torch.nn.Parameter

    def __init__(
        self,
        n_neuron: int | Sequence[int],
        v_threshold: float | Float[TensorLike, " n_neuron"] = 1.0,
        v_reset: float | Float[TensorLike, " n_neuron"] = 0.0,
        trainable_param: set[str] = set(),
        surrogate_function: Callable = Sigmoid(),
        detach_reset: bool = False,
        hard_reset: bool = False,
        pre_spike_v: bool = False,
        step_mode="s",
        backend="torch",
        device=None,
        dtype=None,
    ):
        """Modified spikingjelly BaseNode.

        * :ref:`API in English <BaseNode.__init__-en>`

        This class is the base class of differentiable spiking neurons.
        """

        # override neuron.BaseNode's __init__ method to remove unnecessary checks
        # call neuron.BaseNode's parent MemoryModule directly
        super().__init__()

        self.n_neuron, self.size = normalize_n_neuron(n_neuron)
        self.register_memory("v", v_reset, self.n_neuron)
        self.pre_spike_v = pre_spike_v

        _factory_kwargs: dict[str, Any] = {"device": device, "dtype": dtype}
        if pre_spike_v:
            self.register_memory(
                "v_pre_spike", v_reset, self.n_neuron, persistent=False
            )

        self.trainable_param = set(trainable_param)
        self.def_param(
            "v_threshold",
            v_threshold,
            sizes=self.n_neuron,
            trainable_param=self.trainable_param,
            **_factory_kwargs,
        )
        self.def_param(
            "v_reset",
            v_reset,
            sizes=self.n_neuron,
            trainable_param=self.trainable_param,
            **_factory_kwargs,
        )

        self.detach_reset = detach_reset
        self.surrogate_function = surrogate_function
        self.hard_reset = hard_reset

        self.step_mode = step_mode
        self.backend = backend

    def extra_repr(self):
        parts = [
            f"n_neuron={self.n_neuron}",
            f"v_threshold={self._format_repr_value(self.v_threshold)}",
            f"v_reset={self._format_repr_value(self.v_reset)}",
            f"step_mode={self.step_mode}",
            f"backend={self.backend}",
            f"surrogate={self.surrogate_function.__class__.__name__}",
        ]
        if self.detach_reset:
            parts.append("detach_reset=True")
        if self.hard_reset:
            parts.append("hard_reset=True")
        if self.pre_spike_v:
            parts.append("pre_spike_v=True")
        mem_repr = super().extra_repr()
        if mem_repr:
            parts.append(mem_repr)
        return ", ".join(parts)

    @abstractmethod
    def neuronal_charge(self, x: torch.Tensor):
        """
         * :ref:`API in English <BaseNode.neuronal_charge-en>`

        .. _BaseNode.neuronal_charge-cn:

        定义神经元的充电差分方程。子类必须实现这个函数。

        * :ref:`中文API <BaseNode.neuronal_charge-cn>`

        .. _BaseNode.neuronal_charge-en:


        Define the charge difference equation.
        The sub-class must implement this function.
        """
        raise NotImplementedError

    def neuronal_fire(self):
        """
        * :ref:`API in English <BaseNode.neuronal_fire-en>`

        .. _BaseNode.neuronal_fire-cn:

        根据当前神经元的电压、阈值,计算输出脉冲。

        * :ref:`中文API <BaseNode.neuronal_fire-cn>`

        .. _BaseNode.neuronal_fire-en:


        Calculate out spikes of neurons by their current membrane potential
        and threshold voltage.
        """

        return self.surrogate_function(self.v - self.v_threshold)

    def neuronal_reset(self, spike):
        """
        * :ref:`API in English <BaseNode.neuronal_reset-en>`

        .. _BaseNode.neuronal_reset-cn:

        根据当前神经元释放的脉冲,对膜电位进行重置。

        * :ref:`中文API <BaseNode.neuronal_reset-cn>`

        .. _BaseNode.neuronal_reset-en:


        Reset the membrane potential according to neurons' output spikes.
        """
        if self.detach_reset:
            spike_d = spike.detach()
        else:
            spike_d = spike

        if self.pre_spike_v:
            self.v_pre_spike = self.v.clone()

        if self.hard_reset:
            # hard reset
            self.v = self.v - (self.v - self.v_reset) * spike_d
        else:
            # soft reset
            self.v = self.v - (self.v_threshold - self.v_reset) * spike_d

    def neuronal_adaptation(self):
        raise NotImplementedError()

    def single_step_forward(self, x: Float[Tensor, "*batch n_neuron"]):
        """
        * :ref:`API in English <BaseNode.single_step_forward-en>`
        """
        self.neuronal_charge(x)
        self.neuronal_adaptation()
        spike = self.neuronal_fire()
        self.neuronal_reset(spike)
        return spike

    def multi_step_forward(self, x_seq: Float[Tensor, "T *batch n_neuron"]):
        s_seq = []
        for t, x in enumerate(x_seq):
            s = self.single_step_forward(x)
            s_seq.append(s)

        return torch.stack(s_seq)
Functions
__init__(n_neuron, v_threshold=1.0, v_reset=0.0, trainable_param=set(), surrogate_function=Sigmoid(), detach_reset=False, hard_reset=False, pre_spike_v=False, step_mode='s', backend='torch', device=None, dtype=None)

Modified spikingjelly BaseNode.

This class is the base class of differentiable spiking neurons.

Source code in btorch/models/base.py
def __init__(
    self,
    n_neuron: int | Sequence[int],
    v_threshold: float | Float[TensorLike, " n_neuron"] = 1.0,
    v_reset: float | Float[TensorLike, " n_neuron"] = 0.0,
    trainable_param: set[str] = set(),
    surrogate_function: Callable = Sigmoid(),
    detach_reset: bool = False,
    hard_reset: bool = False,
    pre_spike_v: bool = False,
    step_mode="s",
    backend="torch",
    device=None,
    dtype=None,
):
    """Modified spikingjelly BaseNode.

    * :ref:`API in English <BaseNode.__init__-en>`

    This class is the base class of differentiable spiking neurons.
    """

    # override neuron.BaseNode's __init__ method to remove unnecessary checks
    # call neuron.BaseNode's parent MemoryModule directly
    super().__init__()

    self.n_neuron, self.size = normalize_n_neuron(n_neuron)
    self.register_memory("v", v_reset, self.n_neuron)
    self.pre_spike_v = pre_spike_v

    _factory_kwargs: dict[str, Any] = {"device": device, "dtype": dtype}
    if pre_spike_v:
        self.register_memory(
            "v_pre_spike", v_reset, self.n_neuron, persistent=False
        )

    self.trainable_param = set(trainable_param)
    self.def_param(
        "v_threshold",
        v_threshold,
        sizes=self.n_neuron,
        trainable_param=self.trainable_param,
        **_factory_kwargs,
    )
    self.def_param(
        "v_reset",
        v_reset,
        sizes=self.n_neuron,
        trainable_param=self.trainable_param,
        **_factory_kwargs,
    )

    self.detach_reset = detach_reset
    self.surrogate_function = surrogate_function
    self.hard_reset = hard_reset

    self.step_mode = step_mode
    self.backend = backend
neuronal_charge(x) abstractmethod

Define the charge difference equation. The sub-class must implement this function.

Source code in btorch/models/base.py
@abstractmethod
def neuronal_charge(self, x: torch.Tensor):
    """
     * :ref:`API in English <BaseNode.neuronal_charge-en>`

    .. _BaseNode.neuronal_charge-cn:

    定义神经元的充电差分方程。子类必须实现这个函数。

    * :ref:`中文API <BaseNode.neuronal_charge-cn>`

    .. _BaseNode.neuronal_charge-en:


    Define the charge difference equation.
    The sub-class must implement this function.
    """
    raise NotImplementedError
neuronal_fire()

Calculate output spikes of neurons from their current membrane potential and threshold voltage.

Source code in btorch/models/base.py
def neuronal_fire(self):
    """
    * :ref:`API in English <BaseNode.neuronal_fire-en>`

    .. _BaseNode.neuronal_fire-cn:

    根据当前神经元的电压、阈值,计算输出脉冲。

    * :ref:`中文API <BaseNode.neuronal_fire-cn>`

    .. _BaseNode.neuronal_fire-en:


    Calculate out spikes of neurons by their current membrane potential
    and threshold voltage.
    """

    return self.surrogate_function(self.v - self.v_threshold)
neuronal_reset(spike)

Reset the membrane potential according to neurons' output spikes.

Source code in btorch/models/base.py
def neuronal_reset(self, spike):
    """
    * :ref:`API in English <BaseNode.neuronal_reset-en>`

    .. _BaseNode.neuronal_reset-cn:

    根据当前神经元释放的脉冲,对膜电位进行重置。

    * :ref:`中文API <BaseNode.neuronal_reset-cn>`

    .. _BaseNode.neuronal_reset-en:


    Reset the membrane potential according to neurons' output spikes.
    """
    if self.detach_reset:
        spike_d = spike.detach()
    else:
        spike_d = spike

    if self.pre_spike_v:
        self.v_pre_spike = self.v.clone()

    if self.hard_reset:
        # hard reset
        self.v = self.v - (self.v - self.v_reset) * spike_d
    else:
        # soft reset
        self.v = self.v - (self.v_threshold - self.v_reset) * spike_d
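The two branches above differ in what survives the spike: a hard reset discards any overshoot past threshold, while a soft reset subtracts a fixed span and preserves the overshoot. A scalar sketch of the two formulas (plain Python, not btorch's tensor code):

```python
def reset(v, spike, v_th=1.0, v_reset=0.0, hard=False):
    if hard:
        # Hard reset: v -> v_reset wherever spike == 1.
        return v - (v - v_reset) * spike
    # Soft reset: subtract the threshold-to-reset span, keeping overshoot.
    return v - (v_th - v_reset) * spike

v_over = 1.3  # membrane overshot threshold by 0.3
hard_v = reset(v_over, 1.0, hard=True)   # overshoot discarded -> 0.0
soft_v = reset(v_over, 1.0, hard=False)  # overshoot preserved -> ~0.3
```

Soft reset is often preferred with surrogate gradients because the residual voltage carries information about input strength across the spike.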
single_step_forward(x)

Run one step of the lifecycle: charge -> adapt -> fire -> reset.
Source code in btorch/models/base.py
def single_step_forward(self, x: Float[Tensor, "*batch n_neuron"]):
    """
    * :ref:`API in English <BaseNode.single_step_forward-en>`
    """
    self.neuronal_charge(x)
    self.neuronal_adaptation()
    spike = self.neuronal_fire()
    self.neuronal_reset(spike)
    return spike

IF

Bases: LIF

Integrate-and-fire neuron without leak.

Simplified variant of LIF that lacks the leak term, meaning the membrane potential integrates input current linearly without decay:

dV/dt = I / c_m

This model is useful for theoretical analysis and as a baseline, though it lacks biological realism due to unbounded integration.

Parameters:

Name Type Description Default
n_neuron int | Sequence[int]

Number of neurons (int or tuple of dimensions).

required
v_threshold float | Float[TensorLike, ' n_neuron']

Firing threshold. Default: 1.0.

1.0
v_reset float | Float[TensorLike, ' n_neuron']

Reset voltage after spike. Default: 0.0.

0.0
c_m float | Float[TensorLike, ' n_neuron']

Membrane capacitance. Default: 1.0.

1.0
tau float | Float[TensorLike, ' n_neuron']

Time constant (inherited from LIF but not used in dynamics).

20.0
tau_ref float | Float[TensorLike, ' n_neuron'] | None

Refractory period duration. Default: None.

None
trainable_param set[str]

Set of parameter names to make trainable.

set()
surrogate_function Callable

Surrogate gradient function. Default: Sigmoid().

Sigmoid()
detach_reset bool

If True, detach reset signal. Default: False.

False
hard_reset bool

If True, use hard reset. Default: False.

False
pre_spike_v bool

If True, store pre-spike voltage. Default: False.

False
step_mode Literal['s']

Step mode. Default: "s".

's'
backend Literal['torch']

Backend implementation. Default: "torch".

'torch'
device

Device for tensors. Default: None.

None
dtype

Data type for tensors. Default: None.

None
Source code in btorch/models/neurons/lif.py
class IF(LIF):
    """Integrate-and-fire neuron without leak.

    Simplified variant of LIF that lacks the leak term, meaning the membrane
    potential integrates input current linearly without decay:

        dV/dt = I / c_m

    This model is useful for theoretical analysis and as a baseline,
    though it lacks biological realism due to unbounded integration.

    Args:
        n_neuron: Number of neurons (int or tuple of dimensions).
        v_threshold: Firing threshold. Default: 1.0.
        v_reset: Reset voltage after spike. Default: 0.0.
        c_m: Membrane capacitance. Default: 1.0.
        tau: Time constant (inherited from LIF but not used in dynamics).
        tau_ref: Refractory period duration. Default: None.
        trainable_param: Set of parameter names to make trainable.
        surrogate_function: Surrogate gradient function. Default: Sigmoid().
        detach_reset: If True, detach reset signal. Default: False.
        hard_reset: If True, use hard reset. Default: False.
        pre_spike_v: If True, store pre-spike voltage. Default: False.
        step_mode: Step mode. Default: "s".
        backend: Backend implementation. Default: "torch".
        device: Device for tensors. Default: None.
        dtype: Data type for tensors. Default: None.
    """

    def dV(
        self,
        x: Float[Tensor, "*batch n_neuron"],
    ) -> Float[Tensor, "*batch n_neuron"]:
        """Compute membrane potential derivative (no leak term).

        Args:
            x: Input current, shape (*batch, n_neuron).

        Returns:
            dV/dt derivative, shape (*batch, n_neuron).
        """
        derivative = x / self.c_m
        return derivative

    def neuronal_charge(self, x: Float[Tensor, "*batch n_neuron"]):
        """Update membrane potential using Euler integration (no leak).

        Args:
            x: Input current, shape (*batch, n_neuron).
        """
        v = euler_step(self.dV, x, dt=environ.get("dt"))
        self.v = v
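Because there is no leak, the membrane grows linearly under constant input, so the first spike time is easy to predict: each Euler step adds dt * I / c_m to v, and threshold is first crossed after ceil(v_threshold * c_m / (I * dt)) steps. A sketch confirming this (plain Python, independent of btorch):

```python
import math

def if_steps_to_spike(v_th=1.0, c_m=1.0, I=0.25, dt=1.0):
    # Each Euler step adds dt * I / c_m to v, with no decay.
    return math.ceil(v_th * c_m / (I * dt))

# Simulate to confirm: count steps until v first reaches threshold.
v, steps = 0.0, 0
while v < 1.0:
    v += 1.0 * 0.25 / 1.0  # dt * I / c_m
    steps += 1
```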
Functions
dV(x)

Compute membrane potential derivative (no leak term).

Parameters:

Name Type Description Default
x Float[Tensor, '*batch n_neuron']

Input current, shape (*batch, n_neuron).

required

Returns:

Type Description
Float[Tensor, '*batch n_neuron']

dV/dt derivative, shape (*batch, n_neuron).

Source code in btorch/models/neurons/lif.py
def dV(
    self,
    x: Float[Tensor, "*batch n_neuron"],
) -> Float[Tensor, "*batch n_neuron"]:
    """Compute membrane potential derivative (no leak term).

    Args:
        x: Input current, shape (*batch, n_neuron).

    Returns:
        dV/dt derivative, shape (*batch, n_neuron).
    """
    derivative = x / self.c_m
    return derivative
neuronal_charge(x)

Update membrane potential using Euler integration (no leak).

Parameters:

Name Type Description Default
x Float[Tensor, '*batch n_neuron']

Input current, shape (*batch, n_neuron).

required
Source code in btorch/models/neurons/lif.py
def neuronal_charge(self, x: Float[Tensor, "*batch n_neuron"]):
    """Update membrane potential using Euler integration (no leak).

    Args:
        x: Input current, shape (*batch, n_neuron).
    """
    v = euler_step(self.dV, x, dt=environ.get("dt"))
    self.v = v

LIF

Bases: BaseNode

Leaky integrate-and-fire neuron with optional refractory period.

The LIF neuron integrates input current while leaking towards a resting potential. When the membrane potential exceeds a threshold, a spike is emitted and the potential is reset.

Dynamics

dV/dt = -(V - V_reset) / tau + I / c_m

If tau_ref is specified, a refractory period prevents spiking for tau_ref milliseconds after each spike.

Parameters:

Name Type Description Default
n_neuron int | Sequence[int]

Number of neurons (int or tuple of dimensions).

required
v_threshold float | Float[TensorLike, ' n_neuron']

Firing threshold (mV). Default: 1.0.

1.0
v_reset float | Float[TensorLike, ' n_neuron']

Reset voltage after spike (mV). Default: 0.0.

0.0
c_m float | Float[TensorLike, ' n_neuron']

Membrane capacitance. Default: 1.0.

1.0
tau float | Float[TensorLike, ' n_neuron']

Membrane time constant (ms). Default: 20.0.

20.0
tau_ref float | Float[TensorLike, ' n_neuron'] | None

Refractory period duration (ms). None disables refractory behavior. Default: None.

None
trainable_param set[str]

Set of parameter names to make trainable. Default: empty set.

set()
surrogate_function Callable

Surrogate gradient function for backpropagation. Default: Sigmoid().

Sigmoid()
detach_reset bool

If True, detach reset signal from computation graph. Default: False.

False
hard_reset bool

If True, reset to v_reset directly. If False, subtract (v_threshold - v_reset) from membrane potential (soft reset). Default: False.

False
pre_spike_v bool

If True, store pre-spike voltage in v_pre_spike buffer. Default: False.

False
step_mode Literal['s']

Step mode, currently only "s" (single step) supported. Default: "s".

's'
backend Literal['torch']

Backend implementation. Default: "torch".

'torch'
device

Device for tensors. Default: None.

None
dtype

Data type for tensors. Default: None.

None

Attributes:

Name Type Description
v Tensor

Membrane potential tensor, shape (*batch, n_neuron).

refractory Tensor | None

Refractory counter (if tau_ref specified).

c_m Tensor | Parameter

Membrane capacitance (parameter or buffer).

tau Tensor | Parameter

Time constant (parameter or buffer).

tau_ref Tensor | Parameter | None

Refractory period (parameter or buffer, or None).

Shape
  • Input: (*batch, n_neuron)
  • Output: (*batch, n_neuron) spike tensor (0 or 1)
Source code in btorch/models/neurons/lif.py
class LIF(BaseNode):
    """Leaky integrate-and-fire neuron with optional refractory period.

    The LIF neuron integrates input current while leaking towards a resting
    potential. When the membrane potential exceeds a threshold, a spike is
    emitted and the potential is reset.

    Dynamics:
        dV/dt = -(V - V_reset) / tau + I / c_m

        If tau_ref is specified, a refractory period prevents spiking for
        tau_ref milliseconds after each spike.

    Args:
        n_neuron: Number of neurons (int or tuple of dimensions).
        v_threshold: Firing threshold (mV). Default: 1.0.
        v_reset: Reset voltage after spike (mV). Default: 0.0.
        c_m: Membrane capacitance. Default: 1.0.
        tau: Membrane time constant (ms). Default: 20.0.
        tau_ref: Refractory period duration (ms). None disables refractory
            behavior. Default: None.
        trainable_param: Set of parameter names to make trainable.
            Default: empty set.
        surrogate_function: Surrogate gradient function for backpropagation.
            Default: Sigmoid().
        detach_reset: If True, detach reset signal from computation graph.
            Default: False.
        hard_reset: If True, reset to v_reset directly. If False, subtract
            (v_threshold - v_reset) from membrane potential (soft reset).
            Default: False.
        pre_spike_v: If True, store pre-spike voltage in v_pre_spike buffer.
            Default: False.
        step_mode: Step mode, currently only "s" (single step) supported.
            Default: "s".
        backend: Backend implementation. Default: "torch".
        device: Device for tensors. Default: None.
        dtype: Data type for tensors. Default: None.

    Attributes:
        v: Membrane potential tensor, shape (*batch, n_neuron).
        refractory: Refractory counter (if tau_ref specified).
        c_m: Membrane capacitance (parameter or buffer).
        tau: Time constant (parameter or buffer).
        tau_ref: Refractory period (parameter or buffer, or None).

    Shape:
        - Input: (*batch, n_neuron)
        - Output: (*batch, n_neuron) spike tensor (0 or 1)
    """

    refractory: torch.Tensor | None
    c_m: torch.Tensor | torch.nn.Parameter
    tau: torch.Tensor | torch.nn.Parameter
    tau_ref: torch.Tensor | torch.nn.Parameter | None

    def __init__(
        self,
        n_neuron: int | Sequence[int],
        v_threshold: float | Float[TensorLike, " n_neuron"] = 1.0,
        v_reset: float | Float[TensorLike, " n_neuron"] = 0.0,
        c_m: float | Float[TensorLike, " n_neuron"] = 1.0,
        tau: float | Float[TensorLike, " n_neuron"] = 20.0,
        tau_ref: float | Float[TensorLike, " n_neuron"] | None = None,
        trainable_param: set[str] = set(),
        surrogate_function: Callable = Sigmoid(),
        detach_reset: bool = False,
        hard_reset: bool = False,
        pre_spike_v: bool = False,
        step_mode: Literal["s"] = "s",
        backend: Literal["torch"] = "torch",
        device=None,
        dtype=None,
    ):
        super().__init__(
            n_neuron=n_neuron,
            v_threshold=v_threshold,
            v_reset=v_reset,
            trainable_param=trainable_param,
            surrogate_function=surrogate_function,
            detach_reset=detach_reset,
            hard_reset=hard_reset,
            pre_spike_v=pre_spike_v,
            step_mode=step_mode,
            backend=backend,
            device=device,
            dtype=dtype,
        )
        _factory_kwargs: dict[str, Any] = {"device": device, "dtype": dtype}
        self.def_param(
            "c_m",
            c_m,
            sizes=self.n_neuron,
            trainable_param=self.trainable_param,
            **_factory_kwargs,
        )
        self.def_param(
            "tau",
            tau,
            sizes=self.n_neuron,
            trainable_param=self.trainable_param,
            **_factory_kwargs,
        )
        self._use_refractory = tau_ref is not None
        if self._use_refractory:
            self.def_param(
                "tau_ref",
                tau_ref,
                sizes=self.n_neuron,
                trainable_param=self.trainable_param,
                **_factory_kwargs,
            )
            self.register_memory("refractory", 0.0, self.n_neuron)
        else:
            self.tau_ref = None

    def dV(
        self,
        v: Float[Tensor, "*batch n_neuron"],
        x: Float[Tensor, "*batch n_neuron"],
    ):
        derivative = -(v - self.v_reset) / self.tau + x / self.c_m
        return derivative

    def neuronal_charge(self, x: Float[Tensor, "*batch n_neuron"]):
        v = euler_step(self.dV, self.v, x, dt=environ.get("dt"))
        self.v = v

    def neuronal_adaptation(self):
        # LIF has no intrinsic adaptation other than the refractory counter.
        return None

    def neuronal_fire(self):
        spike = self.surrogate_function(
            (self.v - self.v_threshold) / (self.v_threshold - self.v_reset)
        )
        if not self._use_refractory:
            return spike
        not_in_refractory = self.refractory == 0
        spike = spike * not_in_refractory.detach().to(self.v.dtype)
        return spike

    def neuronal_reset(self, spike: Float[Tensor, "*batch n"]):
        if self.detach_reset:
            spike_d = spike.detach()
        else:
            spike_d = spike

        if self.pre_spike_v:
            self.v_pre_spike = self.v.clone()

        if self.hard_reset:
            self.v = self.v - (self.v - self.v_reset) * spike_d
        else:
            self.v = self.v - (self.v_threshold - self.v_reset) * spike_d

        if self._use_refractory:
            self.refractory = torch.relu(
                self.refractory + spike_d * self.tau_ref - environ.get("dt")
            )

    def extra_repr(self):
        parts = [
            f"c_m={self._format_repr_value(self.c_m)}",
            f"tau={self._format_repr_value(self.tau)}",
            f"tau_ref={self._format_repr_value(self.tau_ref)}"
            if self._use_refractory
            else "tau_ref=None",
        ]
        base = super().extra_repr()
        if base:
            parts.append(base)
        return ", ".join(parts)
