Tutorial 1: Building an RSNN¶
Author: btorch contributors
Based on: examples/rsnn.py
This tutorial walks through building a minimal Recurrent Spiking Neural Network (RSNN) with btorch.
What You Will Learn¶
- How to compose a neuron (
GLIF3), a synapse (AlphaPSC), and a recurrent wrapper (RecurrentNN). - How to initialize and reset network state.
- How to record internal state variables (voltage, current) during a forward pass.
- How
dtcontext governs ODE integration.
The Building Blocks¶
A minimal RSNN needs three things:
- Neuron model — defines how membrane voltage evolves and when spikes fire.
- Synapse model — transforms spikes into post-synaptic currents.
- Connection layer — defines the weight matrix between neurons.
import torch
import torch.nn as nn
from btorch.models import environ, functional, glif, rnn, synapse
from btorch.models.linear import DenseConn
class MinimalRSNN(nn.Module):
def __init__(self, num_input, num_hidden, num_output, device="cpu"):
super().__init__()
# 1. Input projection
self.fc_in = nn.Linear(num_input, num_hidden, bias=False, device=device)
# 2. Spiking neuron
neuron_module = glif.GLIF3(
n_neuron=num_hidden,
v_threshold=-45.0,
v_reset=-60.0,
c_m=2.0,
tau=20.0,
tau_ref=2.0,
k=[0.1, 0.2],
asc_amps=[1.0, -2.0],
step_mode="s", # single-step definition
backend="torch",
device=device,
)
# 3. Recurrent connection
conn = DenseConn(num_hidden, num_hidden, bias=None, device=device)
# 4. Synapse
psc_module = synapse.AlphaPSC(
n_neuron=num_hidden,
tau_syn=5.0,
linear=conn,
step_mode="s",
)
# 5. Recurrent wrapper (multi-step)
self.brain = rnn.RecurrentNN(
neuron=neuron_module,
synapse=psc_module,
step_mode="m",
update_state_names=("neuron.v", "synapse.psc"),
)
# 6. Output readout
self.fc_out = nn.Linear(num_hidden, num_output, bias=False, device=device)
def forward(self, x):
x = self.fc_in(x) # (T, Batch, num_input) -> (T, Batch, N)
spike, states = self.brain(x) # spike: (T, Batch, N)
rate = spike.mean(dim=0) # (Batch, N)
out = self.fc_out(rate) # (Batch, num_output)
return out
Initializing State¶
Before the first forward pass, call init_net_state to register and initialize memory buffers:
model = MinimalRSNN(num_input=20, num_hidden=64, num_output=5)
functional.init_net_state(model, batch_size=4, device="cpu")
Running a Forward Pass¶
Wrap the forward pass in an environ.context(dt=...) block:
environ.set(dt=1.0)
inputs = torch.rand((100, 4, 20)) # (T, Batch, input_dim)
functional.reset_net(model, batch_size=4)
with environ.context(dt=1.0):
out = model(inputs) # (Batch, num_output)
Inspecting Recorded States¶
update_state_names tells RecurrentNN which variables to save. The returned states dict uses dot notation:
with environ.context(dt=1.0):
spike, states = model.brain(inputs)
print(states["neuron.v"].shape) # (T, Batch, N)
print(states["synapse.psc"].shape) # (T, Batch, N)
You can unflatten the dict for easier access:
from btorch.utils.dict_utils import unflatten_dict
nested = unflatten_dict(states, dot=True)
nested["neuron"]["v"][:, 0, :] # voltage of batch 0
What Comes Next¶
In Tutorial 2: Training an SNN, we add a loss function, a training loop, and checkpointing.