Examples -------- You can build a simple convolutional spiking neural network, trained on the MNIST dataset, as follows:: import torch import torch.nn as nn import torch.optim as optim import torchvision import torchvision.transforms as transforms from snngrow.base.neuron.LIFNode import LIFNode from tqdm import tqdm # Define the CSNN model class CNN(nn.Module): def __init__(self, T): super(CNN, self).__init__() self.T = T self.csnn = nn.Sequential( nn.Conv2d(1, 32, kernel_size=3), LIFNode(), nn.MaxPool2d(kernel_size=2), nn.Conv2d(32, 64, kernel_size=3), LIFNode(), nn.Flatten(), nn.Linear(36864, 128), nn.Linear(128, 10) ) def forward(self, x): x_seq = [] for _ in range(self.T): x_seq.append(self.csnn(x)) out = torch.stack(x_seq).mean(0) return out def main(): # Load the MNIST dataset transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,)) ]) trainset = torchvision.datasets.MNIST(root='./datas', train=True, download=True, transform=transform) trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True) testset = torchvision.datasets.MNIST(root='./datas', train=False, download=True, transform=transform) testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False) # Create an instance of the CNN model model = CNN(T=4) # model to GPU device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") model.to(device) # Define the loss function and optimizer criterion = nn.CrossEntropyLoss() optimizer = optim.Adam(model.parameters(), lr=0.001) # Train the model for epoch in range(10): running_loss = 0.0 model.train() for data in tqdm(trainloader): inputs, labels = data optimizer.zero_grad() outputs = model(inputs.to(device)) loss = criterion(outputs.cpu(), labels) loss.backward() optimizer.step() running_loss += loss.item() for m in model.modules(): if hasattr(m, 'reset'): m.reset() # Test the model correct = 0 total = 0 with torch.no_grad(): model.eval() for data in testloader: images, labels = data outputs = model(images.to(device)) _, predicted = torch.max(outputs.detach().cpu().data, 1) total += labels.size(0) correct += (predicted == labels).sum().item() for m in model.modules(): if hasattr(m, 'reset'): m.reset() print(f'Epoch: {epoch + 1}, Loss: {running_loss / 100}, Accuracy on the test set: {(correct / total) * 100}%') running_loss = 0.0 print('Training finished.')