CNN model accuracy is oscillating strangely
Question
I have a CNN designed to be trained on the MNIST data set using federated averaging (the model is trained locally on a number of clients, for a number of local epochs on each client's local data, and the resulting models are then averaged). The accuracy of the global model changes very strangely with the number of clients: for example, with two clients I get 97% accuracy, but with 3 clients only 11%.
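(By "averaged" I mean one round of federated averaging: every client starts from the current global weights, trains locally, and the new global weights are the element-wise mean of the client weights, roughly

    w_global = (w_1 + w_2 + ... + w_K) / K

for K clients.)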
Because my data is distributed across the clients in an IID way, I would expect the accuracy to be lower with 1 client and then go up as the number of clients increases. Here is my implementation:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms

# Define the CNN architecture
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.relu2 = nn.ReLU()
        self.pool = nn.MaxPool2d(2)
        self.fc1 = nn.Linear(64 * 12 * 12, 128)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.relu1(self.conv1(x))
        x = self.pool(self.relu2(self.conv2(x)))
        x = x.view(-1, 64 * 12 * 12)
        x = self.relu3(self.fc1(x))
        x = self.fc2(x)
        return x

# Define the dataset class
class MNISTDataset(Dataset):
    def __init__(self, data, targets, transform=None):
        self.data = data
        self.targets = targets
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        x = self.data[index]
        y = self.targets[index]
        if self.transform:
            x = transforms.ToPILImage()(x)  # Convert tensor to PIL Image
            x = self.transform(x)
        return x, y

# Load MNIST dataset
train_dataset = datasets.MNIST(
    './data',
    train=True,
    download=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
    ])
)
test_dataset = datasets.MNIST(
    './data',
    train=False,
    download=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
    ])
)

# Define the number of clients
num_clients = 3

# Shuffle and distribute data between clients
data_per_client = len(train_dataset) // num_clients
client_datasets = []
for i in range(num_clients):
    start_index = i * data_per_client
    end_index = (i + 1) * data_per_client
    data = train_dataset.data[start_index:end_index]
    targets = train_dataset.targets[start_index:end_index]
    client_dataset = MNISTDataset(data, targets, transform=transforms.Compose([
        transforms.ToTensor(),
    ]))
    client_datasets.append(client_dataset)

# Define the federated learning parameters
num_epochs = 3
learning_rate = 0.01

# Initialize the global model
global_model = CNN()

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(global_model.parameters(), lr=learning_rate)

# Train the global model using federated averaging
for epoch in range(num_epochs):
    for client_dataset in client_datasets:
        # Create data loader for each client
        client_loader = DataLoader(client_dataset, batch_size=64, shuffle=True)

        # Initialize the local model
        local_model = CNN()
        local_model.load_state_dict(global_model.state_dict())

        # Define the optimizer for the local model
        local_optimizer = optim.Adam(local_model.parameters(), lr=learning_rate)

        # Train the local model
        for inputs, labels in client_loader:
            local_optimizer.zero_grad()
            outputs = local_model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            local_optimizer.step()

        # Update the global model using federated averaging
        for global_param, local_param in zip(global_model.parameters(), local_model.parameters()):
            global_param.data += local_param.data

    # Average the global model's parameters
    for global_param in global_model.parameters():
        global_param.data /= num_clients

# Evaluate the global model on testing data
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
total_correct = 0
total_samples = 0
with torch.no_grad():
    for inputs, labels in test_loader:
        outputs = global_model(inputs)
        _, predicted = torch.max(outputs, 1)
        total_samples += labels.size(0)
        total_correct += (predicted == labels).sum().item()

accuracy = 100.0 * total_correct / total_samples
print(f"Global Model Accuracy: {accuracy}%")
Answer 1
Score: 1
The issue with your implementation lies in the averaging step of the global model's parameters after training each client's local model. The global parameters are never reset before each client's local parameters are summed into them (global_param.data += local_param.data), so the running sum still contains the old global weights, and because the accumulation happens inside the client loop, every client after the first also initializes from these partially summed weights instead of the current global model. Dividing by num_clients at the end therefore does not produce the mean of the client models, and the error grows with the number of clients.
To fix the issue, do not sum the local parameters directly into the global parameters. Accumulate the trained client weights separately and load their element-wise mean into global_model only after all clients in the round have finished training, as in the sketch below.
With this modification, the global model's accuracy should be consistent across different numbers of clients, and the strange fluctuations you observed should be resolved.
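A minimal sketch of what the corrected training round could look like, reusing CNN, client_datasets, criterion, num_epochs and learning_rate from the question (the variable names client_states and averaged_state are illustrative):

import copy

for epoch in range(num_epochs):
    client_states = []
    for client_dataset in client_datasets:
        client_loader = DataLoader(client_dataset, batch_size=64, shuffle=True)

        # Every client starts from the current global weights
        local_model = CNN()
        local_model.load_state_dict(global_model.state_dict())
        local_optimizer = optim.Adam(local_model.parameters(), lr=learning_rate)

        # Local training on this client's data
        for inputs, labels in client_loader:
            local_optimizer.zero_grad()
            loss = criterion(local_model(inputs), labels)
            loss.backward()
            local_optimizer.step()

        # Keep a copy of the trained weights; the global model is not touched yet
        client_states.append(copy.deepcopy(local_model.state_dict()))

    # Federated averaging: element-wise mean of the client weights
    averaged_state = copy.deepcopy(client_states[0])
    for key in averaged_state:
        for state in client_states[1:]:
            averaged_state[key] = averaged_state[key] + state[key]
        averaged_state[key] = averaged_state[key] / len(client_states)
    global_model.load_state_dict(averaged_state)

Keeping the accumulation outside of global_model means each client's starting point and the final average no longer depend on the order in which the clients are processed.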