Spiking Neural Networks: A SNNtorch Simulation Example

Spiking Neural Networks (SNNs) represent a fundamentally different approach to neural computation compared to traditional artificial neural networks. Instead of processing continuous values, SNNs communicate through discrete spikes—binary events that occur at specific moments in time.

Key Differences from Traditional Neural Networks

| Aspect | Artificial Neural Networks (ANNs) | Spiking Neural Networks (SNNs) |
|---|---|---|
| Information Encoding | Continuous activation values (0-1 or unbounded) | Binary spikes over time (discrete events) |
| Time Dimension | Feedforward processing; no temporal dimension | Temporal dynamics; information encoded in spike timing |
| Neuronal State | No internal state; stateless computation | Membrane potential; stateful recurrent dynamics |
| Energy Efficiency | Low (matrix multiplications every cycle) | Very high (only spikes trigger computation) |
| Biological Plausibility | Low (continuous activations) | High (mimics actual neuron behavior) |
💡 Key Insight: SNNs encode information in the timing and frequency of spikes, not in the magnitude of activations. This mirrors how real brains process information through spike timing and neural synchronization.

The Spike: A Brief Pulse

In SNNs, a spike is a discrete binary event (1 or 0) that occurs at a specific time step. Unlike traditional neurons with continuous outputs, a spiking neuron:

  • Integrates incoming inputs over time into a membrane potential
  • Compares this potential against a threshold
  • Fires a spike (emits a 1) when the threshold is exceeded
  • Resets its potential after firing to prepare for the next input

Membrane Potential Update (Leaky Integrate-and-Fire):

U(t) = β·U(t-1) + I(t)
S(t) = 1 if U(t) > Θ, else 0
U(t) → 0 if S(t) = 1

Where: U = membrane potential, β = decay factor, I = input current, Θ = threshold, S = spike output
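
To make the update concrete, here is a minimal sketch of this rule in plain PyTorch; the decay factor, threshold, and constant input are arbitrary example values:

import torch

beta, theta = 0.9, 1.0            # decay factor and firing threshold (example values)
U = torch.tensor(0.0)             # membrane potential starts at rest
I = torch.full((20,), 0.3)        # constant input current for 20 time steps

for t in range(len(I)):
    U = beta * U + I[t]           # U(t) = β·U(t-1) + I(t)
    S = (U > theta).float()       # S(t) = 1 if U(t) > Θ, else 0
    U = U * (1.0 - S)             # U(t) → 0 if S(t) = 1
    print(f"t={t:2d}  U={U.item():.3f}  S={int(S.item())}")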

Why Spiking Neural Networks Matter

⚡ Energy Efficiency

SNNs use 100-1000× less energy than ANNs on neuromorphic hardware. Computation only occurs when spikes are generated, avoiding wasteful operations.

🧠 Biological Realism

SNNs mirror actual brain architecture with spiking neurons, synaptic plasticity, and temporal dynamics. Ideal for brain-inspired computing.

⏱️ Temporal Processing

Natural handling of temporal data streams (audio, video, sensor signals). Spike timing preserves fine-grained temporal structure.

🛡️ Robustness

Recent research (2025) reports that SNNs achieve roughly twice the adversarial robustness of comparable ANNs on CIFAR-10.

Current State of SNNs (2025)

Recent Advances:

  • SpiNNaker-2: Neuromorphic hardware achieving 50× capacity increase with 22nm process and 3D integration
  • Training Methods: Surrogate gradient descent enables backpropagation through spike events, narrowing ANN-SNN accuracy gap
  • Accuracy Parity: SNNs now achieve comparable accuracy to ANNs on complex tasks (CIFAR-10, ImageNet) while consuming 100× less energy
  • Temporal Encoding: Better understanding of spike timing’s role in information encoding and robustness

ANNs vs SNNs: Performance Comparison

[Figure: ANN vs SNN performance comparison chart]

SNNs excel in energy efficiency (95/100) and temporal resolution (98/100) but give up raw processing speed (40/100 vs 95/100 for ANNs). They also achieve superior robustness (88/100 vs 55/100).

⚠️ Current Limitations: While SNNs show promise, they still face challenges in general-purpose computing. Accuracy lags behind ANNs on static datasets without temporal structure, training is more complex, and neuromorphic hardware remains expensive and limited.

Introduction to SNNtorch

SNNtorch is a Python library built on PyTorch that simplifies building, simulating, and training spiking neural networks. It provides pre-designed neuron models and utilities for working with temporal dynamics in neural computation.

Why SNNtorch?

🔌 PyTorch Integration

Built on PyTorch, enabling GPU acceleration, automatic differentiation, and use of standard deep learning tools

🧬 Multiple Neuron Models

Lapicque, Leaky, Synaptic, Alpha neurons—choose the model matching your problem’s biology and complexity

📚 Tutorials & Docs

Comprehensive documentation with 5+ tutorials covering fundamentals to advanced training techniques

⚙️ Easy to Use

Simple API that feels natural to PyTorch users. Initialize neurons, run forward pass, extract spikes

Installation

$ pip install snntorch
$ pip install torch torchvision
$ pip install matplotlib numpy
$ pip install jupyter # for interactive notebooks
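
A quick smoke test, assuming the install succeeded; the decay factor and input current below are arbitrary example values:

import torch
import snntorch as snn

# Single Leaky integrate-and-fire layer
lif = snn.Leaky(beta=0.9)

# Initialize the hidden state (membrane potential)
mem = lif.init_leaky()

# Drive the neuron with a constant input for a few time steps
for t in range(5):
    spk, mem = lif(torch.tensor([0.6]), mem)
    print(f"step {t}: spike={spk.item():.0f}, membrane={mem.item():.3f}")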

Core Concepts in SNNtorch

🧠 Neuron Object: Represents a single spiking neuron layer. Maintains membrane potential and fires spikes based on inputs.
⏱️ Time Steps: SNNs process data sequentially over T time steps. Each step receives input, updates the neuron state, and emits a spike output (1 or 0).
💾 Hidden State: Unlike ANNs, neurons have internal state (membrane potential) that persists across time steps and must be initialized/reset.
📊 Spike Recording: Collect spike outputs over time steps in a tensor. Spikes are binary: 1 (fired) or 0 (did not fire).
🎯 Surrogate Gradients: Technique to backpropagate through spikes using a smooth approximation. Enables gradient-based training of SNNs.

Neuron Models Available

| Model | Complexity | Use Case | Key Feature |
|---|---|---|---|
| Lapicque | High | Biophysical accuracy | RC circuit-based; R and C parameters map directly to physical values |
| Leaky | Low | Fast training, general tasks | Simplified 1st-order model; easy to tune |
| Synaptic | Medium | Realistic synaptic dynamics | 2nd-order model accounting for gradual synaptic current |
| Alpha | Medium | Smooth membrane tracking | 2nd-order model with alpha-function dynamics |
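
For reference, a rough sketch of how these models are instantiated; the decay and circuit values are illustrative, and the exact argument names (especially for Lapicque) should be checked against the SNNtorch documentation for your version:

import snntorch as snn

# Leaky: 1st-order LIF with a single membrane decay factor
lif = snn.Leaky(beta=0.9, threshold=1.0)

# Synaptic: 2nd-order model with separate synaptic-current (alpha) and membrane (beta) decay
syn = snn.Synaptic(alpha=0.8, beta=0.9)

# Alpha: 2nd-order model with alpha-function membrane dynamics (alpha > beta)
alp = snn.Alpha(alpha=0.9, beta=0.8)

# Lapicque: RC-circuit parameterization (resistance, capacitance, integration step)
lap = snn.Lapicque(R=5.0, C=5e-3, time_step=1e-3)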

The Leaky Integrate-and-Fire Neuron

The Leaky Integrate-and-Fire (LIF) neuron is the most widely used model in SNNtorch. It balances biological plausibility with computational efficiency.

How LIF Works

  1. Integrate: Sum weighted inputs and add to membrane potential
  2. Leak: Exponentially decay the membrane potential (mimics biological leakage)
  3. Fire: If membrane potential exceeds threshold, emit a spike
  4. Reset: Reset potential to resting value after spike

The LIF Equation

Leaky Integrate-and-Fire Dynamics:

U(t) = β·U(t-1) + W·x(t)
S(t) = 1 if U(t) ≥ Θ else 0
U(t) ← U_reset if S(t) = 1

Parameters:
β ∈ [0,1] = decay factor (controls “leakiness”)
W = synaptic weight matrix
x(t) = input at time t
Θ = firing threshold
U_reset = resting potential (usually 0)

Intuition

The parameter beta is crucial:

  • β close to 1: Strong memory, neuron “remembers” past inputs longer, easier to reach threshold
  • β close to 0: Fast decay, neuron forgets inputs quickly, requires stronger recent inputs to fire
  • β = 0.5: Common middle ground offering good balance
💡 Time Constant: The membrane potential decay is exponential: U(t) decays as β^t. The effective “time constant” is roughly τ = -1/ln(β). For example, β=0.5 means the potential halves every time step.
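
A quick check of that relationship for a few illustrative β values:

import math

# Effective time constant τ = -1 / ln(β) for a few illustrative decay factors
for beta in (0.5, 0.9, 0.95, 0.99):
    tau = -1.0 / math.log(beta)
    print(f"beta={beta:.2f}  ->  tau ≈ {tau:.1f} time steps")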

SNN Simulation Fundamentals

The Temporal Loop

Unlike feedforward ANNs, SNNs require a loop over time steps. Each iteration processes one time step of data and updates neuron states:

# Pseudocode structure
for time_step in range(num_steps):
    # Get input for this time step
    input_t = input_data[time_step]
    
    # Forward pass through neuron
    spike, membrane = neuron(input_t, membrane)
    
    # Record the spike
    spike_history.append(spike)

# Use accumulated spikes for classification/analysis

State Initialization

Before simulation, neuron membrane potentials must be initialized:

import snntorch as snn
import torch

batch_size = 32     # example sizes for illustration
num_neurons = 100

# Create a Leaky neuron layer
neuron = snn.Leaky(beta=0.9)

# Initialize membrane potential (method 1: explicit zeros)
mem = torch.zeros(batch_size, num_neurons)

# OR use the built-in initialization
mem = neuron.init_leaky()

# Now ready to simulate

Spike Recording Convention

Throughout a simulation, collect spike outputs in a list and stack them:

spike_rec = []  # Record spikes at each time step

# inputs: tensor of shape [num_steps, batch_size, num_neurons]
for step in range(num_steps):
    spk_out, mem = neuron(inputs[step], mem)
    spike_rec.append(spk_out)

# Stack into tensor: [num_steps, batch_size, num_neurons]
spikes_tensor = torch.stack(spike_rec, dim=0)

# Total spikes per neuron across all time steps
total_spikes = spikes_tensor.sum(dim=0)

Decoding Spikes for Classification

SNNs encode information in spike timing/frequency. To extract class predictions, sum spikes over the entire simulation:

# Sum spikes over time for each neuron (rate-coding)
spike_sums = spikes_tensor.sum(dim=0)  # [batch_size, num_output_neurons]

# Predicted class is the neuron that fired most
predicted_class = spike_sums.argmax(dim=1)

# Compare with targets for accuracy
accuracy = (predicted_class == targets).float().mean()
🎯 Rate Coding: In rate coding (common in SNNtorch), information is encoded in the average firing frequency (number of spikes over time). The neuron corresponding to the correct class should fire most frequently.
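
Rate coding is not the only read-out. A latency (time-to-first-spike) decode instead picks the class whose output neuron fires earliest; here is a minimal sketch in plain PyTorch, using a random toy spike tensor just to keep it runnable:

import torch

# Toy recording: [num_steps, batch_size, num_output_neurons], binary spikes
num_steps, batch_size, num_outputs = 25, 4, 10
spikes_tensor = (torch.rand(num_steps, batch_size, num_outputs) > 0.8).float()

# Index of the first spike per neuron; neurons that never fire get num_steps
first_spike = torch.where(
    spikes_tensor.any(dim=0),
    spikes_tensor.argmax(dim=0),   # argmax returns the index of the first 1
    torch.full((batch_size, num_outputs), num_steps),
)

# Predicted class = the output neuron that fired earliest
predicted_class = first_spike.argmin(dim=1)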

Simple SNN Example: Single Neuron Simulation

Let’s start with a basic example: simulate a single LIF neuron responding to step input.

#!/usr/bin/env python3
"""
Simple SNN Simulation with SNNtorch
Single Leaky Integrate-and-Fire neuron responding to current injection
"""

import snntorch as snn
import torch
import matplotlib.pyplot as plt

# Parameters
num_steps = 100
num_neurons = 1
threshold = 1.0
beta = 0.95  # High beta = strong memory

# Create neuron
lif = snn.Leaky(beta=beta, threshold=threshold)

# Initialize membrane potential
mem = torch.zeros(num_neurons)

# Create input: zeros for 30 steps, then a constant 0.5 current for 70 steps
input_current = torch.cat([
    torch.zeros(30, num_neurons),
    torch.ones(70, num_neurons) * 0.5
], dim=0)

# Records
mem_rec = []
spk_rec = []

# Simulate over time
print(f"Simulating {num_steps} time steps...")
print(f"Neuron threshold: {threshold}, decay beta: {beta}\n")

for step in range(num_steps):
    # Pass input current to neuron
    spk, mem = lif(input_current[step], mem)
    
    # Record membrane potential and spike
    mem_rec.append(mem.detach().clone())
    spk_rec.append(spk.detach().clone())
    
    if step < 35 or step > 68:  # Print first few and last few steps
        print(f"Step {step:3d}: I={input_current[step,0]:.2f}, "
              f"U={mem.item():.3f}, Spike={spk.item():.0f}")
    elif step == 35:
        print("...")

# Convert lists to tensors
mem_rec = torch.stack(mem_rec)
spk_rec = torch.stack(spk_rec)

# Print summary
total_spikes = spk_rec.sum().item()
print(f"\nTotal spikes fired: {int(total_spikes)}/{num_steps}")
print(f"Average firing rate: {100*total_spikes/num_steps:.1f}%")

# Visualization
fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(10, 8))

# Plot 1: Input current
ax1.plot(input_current.numpy(), linewidth=2, color='blue')
ax1.set_ylabel('Input Current')
ax1.set_title('Input Current Over Time')
ax1.grid(True, alpha=0.3)
ax1.set_ylim(-0.1, 0.6)

# Plot 2: Membrane potential
ax2.plot(mem_rec.numpy(), linewidth=2, color='orange')
ax2.axhline(y=threshold, color='red', linestyle='--', label=f'Threshold={threshold}')
ax2.set_ylabel('Membrane Potential (U)')
ax2.set_title('Neuron Membrane Potential Over Time')
ax2.legend()
ax2.grid(True, alpha=0.3)

# Plot 3: Spikes (raster plot)
spike_times = torch.nonzero(spk_rec).numpy()
if len(spike_times) > 0:
    ax3.scatter(spike_times[:, 0], spike_times[:, 1], color='red', s=100, marker='|')
ax3.set_xlabel('Time Step')
ax3.set_ylabel('Neuron')
ax3.set_title('Spike Raster Plot')
ax3.set_ylim(-0.5, 0.5)
ax3.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('simple_snn_simulation.png', dpi=150)
print("\n✓ Plot saved as 'simple_snn_simulation.png'")

plt.show()

✅ Key Insights from This Example:

  • The neuron integrates input over time; it doesn’t fire immediately when stimulus arrives
  • The membrane potential gradually rises as current is injected
  • Once threshold is exceeded, the neuron fires a spike
  • After firing, membrane potential resets, and the neuron can fire again
  • The beta parameter determines how quickly the membrane decays

Building a Network for MNIST Classification

Let’s scale up to a practical machine learning task: classifying handwritten digits with an SNN.

#!/usr/bin/env python3
"""
SNN for MNIST Classification using SNNtorch
3-layer network: input -> hidden -> output
"""

import snntorch as snn
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Hyperparameters
batch_size = 128
learning_rate = 1e-3
num_epochs = 1  # For demo; increase for better accuracy
num_steps = 25  # Simulation time steps
num_inputs = 784  # 28x28 pixels flattened
num_hidden = 256
num_outputs = 10

beta = 0.9  # Membrane decay factor

# Device setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load MNIST data
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

train_dataset = datasets.MNIST(root='./data', train=True, 
                               transform=transform, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, 
                              transform=transform, download=True)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Define SNN architecture
class SNNNet(nn.Module):
    def __init__(self):
        super().__init__()
        
        # Layers
        self.fc1 = nn.Linear(num_inputs, num_hidden)
        self.lif1 = snn.Leaky(beta=beta)
        
        self.fc2 = nn.Linear(num_hidden, num_outputs)
        self.lif2 = snn.Leaky(beta=beta)
    
    def forward(self, x):
        # x shape: [batch_size, 1, 28, 28]
        batch_size = x.size(0)
        
        # Flatten image
        x = x.view(batch_size, -1)  # [batch_size, 784]
        
        # Initialize membrane potentials
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        
        # Records
        spk2_rec = []
        mem2_rec = []
        
        # Simulate over time steps
        for step in range(num_steps):
            # Layer 1: input -> hidden
            cur1 = self.fc1(x)
            spk1, mem1 = self.lif1(cur1, mem1)
            
            # Layer 2: hidden -> output
            cur2 = self.fc2(spk1)
            spk2, mem2 = self.lif2(cur2, mem2)
            
            # Record output-layer spikes and membrane potential
            spk2_rec.append(spk2)
            mem2_rec.append(mem2)
        
        # Stack over time dimension
        spk2_rec = torch.stack(spk2_rec, dim=0)
        mem2_rec = torch.stack(mem2_rec, dim=0)
        
        return spk2_rec, mem2_rec

# Initialize network
net = SNNNet().to(device)
optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)
loss_fn = nn.CrossEntropyLoss()

# Training loop
def train_epoch():
    net.train()
    total_loss = 0
    correct = 0
    total = 0
    
    for batch_idx, (data, targets) in enumerate(train_loader):
        data, targets = data.to(device), targets.to(device)
        
        # Forward pass
        spk_rec, mem_rec = net(data)
        
        # Use membrane potential at each time step for loss
        loss = 0
        for step in range(num_steps):
            loss += loss_fn(mem_rec[step], targets)
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        
        # Accuracy using spike counts
        spk_count = spk_rec.sum(dim=0)  # Sum over time
        _, predicted = spk_count.max(1)
        total += targets.size(0)
        correct += (predicted == targets).sum().item()
        
        if batch_idx % 50 == 0:
            print(f"Batch {batch_idx}, Loss: {loss.item():.4f}, "
                  f"Accuracy: {100*correct/total:.1f}%")
    
    return total_loss / len(train_loader), 100 * correct / total

# Test loop
def test_epoch():
    net.eval()
    correct = 0
    total = 0
    
    with torch.no_grad():
        for data, targets in test_loader:
            data, targets = data.to(device), targets.to(device)
            
            spk_rec, mem_rec = net(data)
            spk_count = spk_rec.sum(dim=0)
            _, predicted = spk_count.max(1)
            
            total += targets.size(0)
            correct += (predicted == targets).sum().item()
    
    return 100 * correct / total

# Run training
print(f"Training SNN on MNIST for {num_epochs} epoch(s)...")
print(f"Network: {num_inputs} -> {num_hidden} -> {num_outputs}")
print(f"Simulation time: {num_steps} steps\n")

for epoch in range(num_epochs):
    train_loss, train_acc = train_epoch()
    test_acc = test_epoch()
    
    print(f"\nEpoch {epoch+1}")
    print(f"  Train Loss: {train_loss:.4f}")
    print(f"  Train Accuracy: {train_acc:.2f}%")
    print(f"  Test Accuracy: {test_acc:.2f}%")

print("\n✓ Training complete!")
💡 Expected Results: After 1 epoch, you should see ~70-80% test accuracy. After 5-10 epochs with proper tuning, SNNs can achieve 95%+ accuracy on MNIST while consuming 10-100× less energy on neuromorphic hardware compared to standard ANNs.

Training Spiking Neural Networks

The Challenge: Non-Differentiable Spikes

Spikes are binary outputs (0 or 1), making them non-differentiable—we can’t use standard backpropagation. SNNtorch solves this with surrogate gradients.

Surrogate Gradient Method

During the backward pass, we replace the discontinuous spike function with a smooth approximation (surrogate function). This allows gradients to flow while preserving the discrete nature of spikes in the forward pass:

Surrogate Gradient Trick:

Forward: S(t) = 1 if U(t) ≥ Θ else 0 (discrete, non-differentiable)
Backward: ∂S/∂U ≈ σ'(α(U-Θ)) (smooth surrogate, differentiable)

Common surrogate: sigmoid or fast sigmoid derivative
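
Conceptually, the trick can be written as a custom autograd function: a hard threshold in the forward pass and a smooth fast-sigmoid derivative in the backward pass. This is a simplified sketch of the idea, not SNNtorch's own implementation; in practice you would use the surrogate gradients that ship with the library. The threshold and slope values are arbitrary examples.

import torch

THRESHOLD = 1.0   # firing threshold Θ (example value)
SLOPE = 25.0      # steepness of the fast-sigmoid surrogate (example value)

class SurrogateSpike(torch.autograd.Function):
    """Heaviside spike in the forward pass, fast-sigmoid gradient in the backward pass."""

    @staticmethod
    def forward(ctx, mem):
        ctx.save_for_backward(mem)
        return (mem >= THRESHOLD).float()        # discrete, non-differentiable spike

    @staticmethod
    def backward(ctx, grad_output):
        (mem,) = ctx.saved_tensors
        # Derivative of the fast sigmoid: 1 / (1 + SLOPE·|U - Θ|)^2
        surrogate = 1.0 / (1.0 + SLOPE * (mem - THRESHOLD).abs()) ** 2
        return grad_output * surrogate

# Usage: spk = SurrogateSpike.apply(membrane_potential)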

Loss Functions for SNNs

Two main approaches:

  • Membrane Potential Loss (Recommended): Use membrane potential U at each time step with cross-entropy loss. Encourages correct class to have high membrane potential (high firing rate).
  • Spike Count Loss: Sum spikes over time and compare with target. Simpler but may be less effective.
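
Both can be sketched in a few lines of plain PyTorch; `mem_rec` and `spk_rec` stand for the stacked recordings from the temporal loop, and the toy tensors here exist only to make the snippet runnable (SNNtorch's functional module also ships ready-made spike-based losses):

import torch
import torch.nn as nn

loss_fn = nn.CrossEntropyLoss()

# Toy recordings: [num_steps, batch_size, num_classes]
num_steps, batch_size, num_classes = 25, 8, 10
mem_rec = torch.randn(num_steps, batch_size, num_classes)
spk_rec = (torch.rand(num_steps, batch_size, num_classes) > 0.8).float()
targets = torch.randint(0, num_classes, (batch_size,))

# 1) Membrane potential loss: cross-entropy on U(t) at every time step
mem_loss = sum(loss_fn(mem_rec[t], targets) for t in range(num_steps))

# 2) Spike-count loss: sum spikes over time and treat the counts as logits
count_loss = loss_fn(spk_rec.sum(dim=0), targets)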

Key Training Parameters

| Parameter | Typical Range | Impact |
|---|---|---|
| num_steps | 10-100 | Longer = more integration time; slower, but potentially better accuracy |
| beta (decay) | 0.5-0.99 | Higher = stronger memory, easier to reach threshold |
| learning_rate | 1e-4 to 1e-2 | SNNs often train better with a lower LR than ANNs |
| batch_size | 32-256 | Larger batches = more stable, less noisy gradients |

Best Practices

  • Start simple: Use Leaky neurons (easy to tune) before Synaptic/Alpha
  • Monitor both loss and accuracy: SNNs may have high loss but good accuracy due to spike-based metrics
  • Use membrane potential loss: Generally outperforms spike-count loss
  • Longer simulations during test: Use more time steps at test time for better accuracy (trades latency for precision)
  • Early stopping: Monitor test accuracy and stop when it plateaus
⚠️ Common Pitfall: SNNs train slower than ANNs. Don’t expect single-epoch convergence. Be patient and use learning rate scheduling (reduce LR as training progresses).
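
One simple way to follow that advice is a standard PyTorch scheduler; this sketch reuses the names from the MNIST example above, and the step size and decay factor are arbitrary:

# Reusing `optimizer`, `train_epoch`, and `num_epochs` from the MNIST example above
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

for epoch in range(num_epochs):
    train_loss, train_acc = train_epoch()
    scheduler.step()   # halve the learning rate every 5 epochs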

Real-World Applications of SNNs

Current and Emerging Use Cases

🎥 Event-Based Vision

Neuromorphic cameras (e.g., DVS) output spike events. SNNs naturally process this asynchronous spike data without frames, enabling ultra-low-latency, low-energy perception.

🤖 Robotics & Autonomous Systems

Edge computing on power-constrained robots. SNNs consume 100× less energy, enabling months of deployment vs hours for ANN-based systems.

🎵 Audio & Temporal Signal Processing

Natural temporal structure preserved in spikes. SNNs excel at speech recognition, audio classification, and time-series analysis.

🛡️ Adversarial Robustness

Research from 2025 reports SNNs achieving roughly twice the adversarial robustness of ANNs; their reliance on spike timing makes them more resistant to adversarial attacks.

Neuromorphic Hardware

SNNs are designed to run on specialized neuromorphic chips that dramatically improve energy efficiency:

  • Intel Loihi 2: 128 neuromorphic cores, 80 mW power for complex processing (vs watts for GPUs)
  • SpiNNaker-2: 50× capacity increase, 22nm process, real-time brain simulation
  • IBM TrueNorth: 1 million spiking neurons, 70mW max power consumption
  • BrainScaleS: Accelerated (1000×) analog simulation of spiking networks

Example Application: Gesture Recognition

SNNs are ideal for IMU (inertial measurement unit) sensors in wearables:

# Pseudocode: Gesture recognition on smartwatch
# Input: 9D IMU data (accelerometer, gyroscope, magnetometer)

# Continuous temporal stream from accelerometer
imu_stream = continuous_stream(9)  # [ax, ay, az, gx, gy, gz, mx, my, mz]

# Encode as spikes using Poisson or time-to-first-spike
spike_events = encode_spikes(imu_stream, threshold=0.5)

# Process with SNN (tiny network: 9 -> 64 -> 32 -> 6 classes)
snn_model = SNNNet(input_size=9, hidden=64, output=6)

# Inference on device: <1mW power, <5ms latency
prediction = snn_model(spike_events)

# Continuous gesture recognition
gesture_detected = argmax(prediction)
print(f"Gesture: {GESTURE_NAMES[gesture_detected]}")

# Advantages over ANN:
# - Energy: 100× lower (mW vs W)
# - Latency: <5ms vs 100ms
# - Streaming: Process events as they arrive, not batches
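
For the spike-encoding step, SNNtorch's spikegen module offers rate (Poisson-style) and latency (time-to-first-spike) generators. The sketch below uses random data purely to stay self-contained; the argument names follow the SNNtorch documentation, so double-check them against your installed version.

import torch
from snntorch import spikegen

# A normalized window of 9-axis IMU features in (0, 1] -- toy data for illustration
imu_window = 0.1 + 0.9 * torch.rand(1, 9)   # [batch_size, features]
num_steps = 50

# Rate coding: feature magnitude sets the per-step spike probability
rate_spikes = spikegen.rate(imu_window, num_steps=num_steps)        # [num_steps, 1, 9]

# Latency coding: larger values spike earlier in the window
latency_spikes = spikegen.latency(imu_window, num_steps=num_steps,
                                  normalize=True, linear=True)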

Limitations & Open Challenges

  • Static Data: SNNs underperform on datasets without temporal structure (unless artificially time-encoded)
  • Hardware Availability: Neuromorphic chips are expensive and limited in availability
  • Training Complexity: More difficult to train than ANNs; requires understanding of temporal dynamics
  • Accuracy Gap: On ImageNet-scale datasets, SNNs still lag ~5-10% behind ANNs (though gap is narrowing)

Next Steps

Spiking Neural Networks represent a paradigm shift in neural computation—from continuous activations to discrete, temporal spikes. With SNNtorch, you can now simulate, train, and deploy SNNs on standard hardware while maintaining the option to move to neuromorphic chips for production deployment.

✅ Key Takeaways:

  • SNNs encode information in spike timing/frequency, not activation magnitude
  • SNNtorch makes building SNNs as easy as PyTorch (which it’s built on)
  • Leaky Integrate-and-Fire neurons balance biology and practicality
  • Surrogate gradients enable training via standard backpropagation
  • SNNs can achieve on the order of 100× better energy efficiency on neuromorphic hardware, along with superior adversarial robustness
  • Best suited for temporal, event-driven, and edge-computing tasks

Your Next Steps

  1. Install SNNtorch: Run pip install snntorch torch
  2. Run the Simple Example: Copy the single-neuron simulation and visualize spikes
  3. Build MNIST Network: Train a 3-layer SNN on digit classification
  4. Experiment with Parameters: Try different beta, num_steps, neuron models
  5. Explore Advanced Neurons: Try Synaptic or Alpha neurons for better biology
  6. Convert ANN to SNN: Take a trained ANN and convert to SNN using SNNtorch utilities
  7. Apply to Your Data: Build an SNN for your temporal classification problem

Learning Resources

  • SNNtorch Documentation: https://snntorch.readthedocs.io (5+ comprehensive tutorials)
  • GitHub Repository: https://github.com/jeshraghian/snntorch (examples, papers)
  • Research Papers: Follow Jason K. Eshraghian’s work on SNNs and neuromorphic computing
  • Courses: Look for “neuromorphic computing” and “spiking neural networks” MOOCs

Advanced Topics to Explore

  • Neuromorphic hardware deployment (Intel Loihi, SpiNNaker)
  • Temporal attention mechanisms in SNNs
  • Spike-based reinforcement learning for robotics
  • Converting trained ANNs to SNNs (ANN-to-SNN translation)
  • Federated learning with SNNs on edge devices
  • Recurrent SNNs (reservoirs, liquid state machines)
🚀 The Future: As neuromorphic hardware becomes cheaper and more accessible, SNNs will likely dominate edge AI, robotics, and real-time signal processing. The brain processes information with spikes—AI is finally catching up.

Last Updated: January 2026
