CNN model accuracy is oscillating strangely
I have a CNN that is designed to be trained on the MNIST data set using federated averaging (the model is trained locally on a number of clients for a number of local epochs on local client's data and then averaged). The accuracy of the model is changing very strangely with the number of clients, for example, on two clients I have a global model accuracy of 97% but on 3 clients only 11%.
Because my data is distributed across the clients in an IID way, I would expect the accuracy to be low on 1 clients, but then goes up with the number of clients, here is my implementation:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
# Define the CNN architecture
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
self.relu1 = nn.ReLU()
self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
self.relu2 = nn.ReLU()
self.pool = nn.MaxPool2d(2)
self.fc1 = nn.Linear(64 * 12 * 12, 128)
self.relu3 = nn.ReLU()
self.fc2 = nn.Linear(128, 10)
def forward(self, x):
x = self.relu1(self.conv1(x))
x = self.pool(self.relu2(self.conv2(x)))
x = x.view(-1, 64 * 12 * 12)
x = self.relu3(self.fc1(x))
x = self.fc2(x)
return x
# Define the dataset class
class MNISTDataset(Dataset):
def __init__(self, data, targets, transform=None):
self.data = data
self.targets = targets
self.transform = transform
def __len__(self):
return len(self.data)
def __getitem__(self, index):
x = self.data[index]
y = self.targets[index]
if self.transform:
x = transforms.ToPILImage()(x) # Convert tensor to PIL Image
x = self.transform(x)
return x, y
# Load MNIST dataset
train_dataset = datasets.MNIST(
test_dataset = datasets.MNIST(
# Define the number of clients
num_clients = 3
# Shuffle and distribute data between clients
data_per_client = len(train_dataset) // num_clients
client_datasets = []
for i in range(num_clients):
start_index = i * data_per_client
end_index = (i + 1) * data_per_client
data = train_dataset.data[start_index:end_index]
targets = train_dataset.targets[start_index:end_index]
client_dataset = MNISTDataset(data, targets, transform=transforms.Compose([
# Define the federated learning parameters
num_epochs = 3
learning_rate = 0.01
# Initialize the global model
global_model = CNN()
# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(global_model.parameters(), lr=learning_rate)
# Train the global model using federated averaging
for epoch in range(num_epochs):
for client_dataset in client_datasets:
# Create data loader for each client
client_loader = DataLoader(client_dataset, batch_size=64, shuffle=True)
# Initialize the local model
local_model = CNN()
# Define the optimizer for the local model
local_optimizer = optim.Adam(local_model.parameters(), lr=learning_rate)
# Train the local model
for inputs, labels in client_loader:
outputs = local_model(inputs)
loss = criterion(outputs, labels)
# Update the global model using federated averaging
for global_param, local_param in zip(global_model.parameters(), local_model.parameters()):
global_param.data += local_param.data
# Average the global model's parameters
for global_param in global_model.parameters():
global_param.data /= num_clients
# Evaluate the global model on testing data
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
total_correct = 0
total_samples = 0
with torch.no_grad():
for inputs, labels in test_loader:
outputs = global_model(inputs)
_, predicted = torch.max(outputs, 1)
total_samples += labels.size(0)
total_correct += (predicted == labels).sum().item()
accuracy = 100.0 * total_correct / total_samples
print(f"Global Model Accuracy: {accuracy}%")
得分: 1
for global_param in global_model.parameters():
global_param.data /= num_clients
The issue with your implementation lies in the averaging step of the global model's parameters after training each client's local model. Instead of directly summing the local parameters to the global parameters, you need to divide the sum by the number of clients.
To fix the issue, modify the averaging step as follows:
for global_param in global_model.parameters():
global_param.data /= num_clients
With this modification, the global model's accuracy should be more consistent across different numbers of clients, and the strange fluctuations you observed should be resolved.