为什么这个深度学习卷积模型不能泛化?

huangapple go评论55阅读模式
英文:

Why is this deep learning convolutional model not generalizing?

问题

我正在使用PyTorch训练一个卷积神经网络,用于处理3D医学光栅图像(.nrrd文件),以从非常嘈杂的超声图像中获得估计的体积测量值。

我拥有大约30位患者的200个个体光栅图像,并对它们进行了各种变换和在所有3轴上的噪声,总共扩充到了5000个图像。所有光栅图像在使用前都被调整为128x128x128。

我正在进行6折交叉验证,确保验证集由训练集中完全不同的患者组成。我认为这有助于检查模型是否真的具有泛化能力,能够估计未见过的患者的光栅图像。

问题是,模型无法泛化或学习。请看我进行的两次测试运行的结果(每次运行耗时10小时):

第一次训练失败

第二次训练失败

所使用的架构只包括6个卷积层,后面跟着2个全连接层,没有太复杂的结构。这可能是什么原因?是不是我的模型没有足够的数据来学习?

我尝试降低学习率和增加权重衰减,但没有成功。我还没有尝试使用其他损失函数和优化器(当前使用的是均方误差损失和Adam优化器)。

以下是代码部分,供您参考:

class RasterNet(nn.Module):
    def __init__(self):
        super(RasterNet, self).__init__()

        self.conv0 = nn.Sequential( # 128x128x128 -> 256x32x32
            nn.Conv2d(128, 256, kernel_size=7, stride=2, padding=3),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        # ...(以下是其他卷积层的定义)

        self.linear = nn.Sequential(
            nn.Linear(8192, 4096),
            nn.ReLU(),
            nn.Linear(4096, 1)
        )

    def forward(self, base):
        # ...(前向传播的定义)
英文:

I am training a convolutional network using pytorch that works on 3D medical raster images (.nrrd files) to get estimated volume measurements from very noisy ultrasound images.

I have around 200 individual raster images of 30 patients, and have augmented them to over 5000 applying all kind of transforms and noise in all 3 axis (chosen randomly). All the rasters are resized to 128x128x128 before being used.

I am doing 6-fold cross validation, where I make sure that the validation set is composed of entirely different patients from those in the training set. I think this helps see if the model is actually generalizing and is capable of estimating rasters of unseen patients.

Problem is, the model is failing to generalize or learn at all. See the results I get for 2 test runs I have made (10 hours processing each):

First Training Failure

Second Training Failure

The architecture used is just 6 convolutional layers followed by 2 densely connected ones, nothing too fancy. What could be causing this? Could it be I don't have enough data for my model to learn?

I tried lowering the learning rate and raising weight decay, no luck. I haven't tried using other criterions and optimizers (currently using MSE Loss and Adam).

*Edit: Added code:

class RasterNet(nn.Module):
    def __init__(self):
        super(RasterNet, self).__init__()

        self.conv0 = nn.Sequential( # 128x128x128 -> 256x32x32
            nn.Conv2d(128, 256, kernel_size=7, stride=2, padding=3),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        self.conv1 = nn.Sequential( # 256x32x32 -> 512x16x16
            nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        self.conv2 = nn.Sequential( # 512x16x16 -> 1024x8x8
            nn.Conv2d(512, 1024, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(1024),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        self.conv3 = nn.Sequential( # 1024x8x8 -> 2048x4x4
            nn.Conv2d(1024, 2048, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(2048),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        self.conv4 = nn.Sequential( # 2048x4x4 -> 4096x2x2
            nn.Conv2d(2048, 4096, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(4096),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        self.conv5 = nn.Sequential( # 4096x2x2 -> 8192x1x1
            nn.Conv2d(4096, 8192, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(8192),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        self.linear = nn.Sequential(
            nn.Linear(8192, 4096),
            nn.ReLU(),
            nn.Linear(4096, 1)
        )

    def forward(self, base):
        base = base.squeeze().float().to(dml)

        # View from y axis (Coronal, as this is the clearest view)
        base = torch.transpose(base, 2, 1)

        x = self.conv0(base)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        x = x.view(x.size(0), -1)
        return self.linear(x)

答案1

得分: 0

以下是翻译好的部分:

首先,您的训练损失收敛到一个较低的值,但验证损失较高,意味着您的模型过度拟合了训练分布。这可能意味着:

  1. 您的模型架构不足以从低级(像素/体素)信息中有意义地提取高级信息,而是学到了广泛的训练集偏差项,从而使损失相对较低。这可能表明您的验证和训练拆分来自不同的分布,或者您的损失函数选择不当。
  2. 您的模型太富有表现力(高方差),以至于可以学到确切的训练示例(经典的过拟合)

其次,神经网络训练的一个几乎普遍的技巧是在运行时使用数据增强。这意味着,与其在训练之前生成一组增强的图像,您可以生成一组应用数据变换的增强函数,随机应用这些变换。这组函数用于在每个训练周期中转换数据批次,使模型永远不会看到完全相同的数据示例两次。

第三,这种模型架构相对简单(比第一个现代深度卷积神经网络AlexNet简单)。通过制定更深层次的架构并使用残差层(请参阅ResNet)来处理梯度消失问题,可以获得更高性能。如果您使用这种架构来执行此任务,我会感到有些惊讶。

验证损失平均而言可能会高于训练损失是正常的。可能您的模型在某种程度上有所学习,但与(可能过度拟合的)训练曲线相比,损失曲线相对平缓。我建议还计算每个时期的验证准确性并跨时期报告此值。您应该看到训练准确性增加,可能还会看到验证准确性增加。

请注意,交叉验证并不完全用于确定模型是否可以泛化到未见过的患者。这是验证集的目的。相反,交叉验证确保训练-验证性能在多个数据分区上有效,不仅仅是选择“容易”的验证集的结果。

纯粹出于速度/简单起见,我建议首先在没有交叉验证的情况下训练模型(即使用单一的训练-测试分区)。一旦在整个数据集上实现了良好的性能,您可以使用k折交叉验证重新训练,以确保上述情况,但这应该加快您的调试周期。

英文:

Ok a few notes which are not an "answer" per se but are too extended for comments:

First, the fact that your training loss converges to a low value, but your validation loss is high, means that your model is overfit to the training distribution. This could mean:

  1. Your model architecture is not expressive enough to meaningfully distill high-level information from low-level (pixel/voxel) information so instead learns training-set wide bias terms that bring the loss relatively low. This could indicate that your validation and training split are from different distributions, or else that your loss function is not well-chosen for the task.
  2. Your model is too expressive (high variance) such that it can learn the exact training examples (classic overfitting)

Second, an almost-ubiquitous trick for NN training is to use at-runtime data augmentation. This means that, rather then generating a set of augmented images before training, you instead generate a set of augmenting functions which apply data transformations randomly. This set of functions is used to transform the data batch at each training epoch, such that the model never sees exactly the same data example twice.

Third, this model architecture is relatively simplistic (simpler than AlexNet, the first modern deep CNN.) Far greater performance has been achieved by making much deeper architectures and using residual layers to (see ResNet) to deal with the vanishing gradient problem. I'd be somewhat surprised if you could achieve good performance on this task with this architecture.

It is normal for the validation loss to be higher on average than the training loss. It is possible that your model is learning to some extent but the loss curve is relatively shallow when compared to the (likely overfit) training curve. I suggest also computing epoch-wide validation accuracy and reporting this value across epochs. You should see training accuracy increase, and possibly validation accuracy as well.

Do note that cross-validation is not quite exactly meant to determine whether the model generalizes to unseen patients. That is the purpose of the validation set. Instead, cross-validation ensures that the training - validation performance is valid across multiple data partitions, and isn't simply the result of selecting an "easy" validation set.

Purely for speed/simplicity, I recommend training the model first without cross-validation (i.e. use a single training-testing partition. Once you achieve good performance on the whole dataset, you can retrain with k-fold to ensure the above, but this should make your debug cycles a bit faster.

huangapple
  • 本文由 发表于 2023年2月7日 04:48:35
  • 转载请务必保留本文链接:https://go.coder-hub.com/75366429.html
匿名

发表评论

匿名网友

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen:

确定