Transforming an image classification model into a hierarchical model

Question

I am using ResNet50 to extract features from images. How can I remove the network's classification head? Also, because I need the intermediate features, how can I split the network into a hierarchical form so that I can access the intermediate stages as follows:

import torch
from torchvision.models import resnet50, ResNet50_Weights
import torch.nn as nn

class model(nn.Module):
    def __init__(self, pretrained=True):
        super(model, self).__init__()

        self.featureExtractor = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)

    def forward(self, x):
        # featureExtractor_1 ... featureExtractor_4 are the sub-networks I would like to obtain
        x1 = self.featureExtractor_1(x)   # (number of feature maps: 256)
        x2 = self.featureExtractor_2(x1)  # (number of feature maps: 512)
        x3 = self.featureExtractor_3(x2)  # (number of feature maps: 1024)
        x4 = self.featureExtractor_4(x3)  # (number of feature maps: 2048)

        return x1, x2, x3, x4

Although I know how to extract the intermediate features of the network using hooks, I do not know how to split the network into such a hierarchy.

Any idea?

Answer 1

Score: 1

You can do as follows:

import torch
from torchvision.models import resnet50, ResNet50_Weights
import torch.nn as nn

class ResNet50Features(nn.Module):
    def __init__(self):
        super().__init__()

        # Only the first eight children are kept below, so avgpool and fc
        # (the classification head) are not part of this module.
        backbone = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
        children = list(backbone.children())
        self.block1 = nn.Sequential(*children[:5])   # conv1, bn1, relu, maxpool, layer1
        self.block2 = nn.Sequential(*children[5:6])  # layer2
        self.block3 = nn.Sequential(*children[6:7])  # layer3
        self.block4 = nn.Sequential(*children[7:8])  # layer4

    def forward(self, x):
        x1 = self.block1(x)   # 256 feature maps
        x2 = self.block2(x1)  # 512 feature maps
        x3 = self.block3(x2)  # 1024 feature maps
        x4 = self.block4(x3)  # 2048 feature maps
        return x1, x2, x3, x4

x = torch.randn(1, 3, 256, 256)
model = ResNet50Features().eval()  # eval() so BatchNorm uses its running statistics
with torch.no_grad():
    x1, x2, x3, x4 = model(x)

print(x1.shape, x2.shape, x3.shape, x4.shape)

Which gives:

torch.Size([1, 256, 64, 64]) torch.Size([1, 512, 32, 32]) torch.Size([1, 1024, 16, 16]) torch.Size([1, 2048, 8, 8])

huangapple
  • Posted on June 27, 2023 20:13:13
  • Link to this article: https://go.coder-hub.com/76564741.html