Transforming an image classification model into a hierarchical model

Question

I am using ResNet50 to extract features from images. How can I remove the network's classification head? I also need intermediate features, so how can I split the network into a hierarchy of stages whose outputs I can access, as in the following sketch?

    import torch
    import torch.nn as nn
    from torchvision.models import resnet50, ResNet50_Weights

    class model(nn.Module):
        def __init__(self, pretrained=True):
            super(model, self).__init__()
            self.featureExtractor = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)

        def forward(self, x):
            # Desired interface: featureExtractor_1..4 are the sub-networks
            # I would like to have; they do not exist yet.
            x1 = self.featureExtractor_1(x)   # number of feature maps: 256
            x2 = self.featureExtractor_2(x1)  # number of feature maps: 512
            x3 = self.featureExtractor_3(x2)  # number of feature maps: 1024
            x4 = self.featureExtractor_4(x3)  # number of feature maps: 2048
            return x1, x2, x3, x4

I know how to extract the network's intermediate features with forward hooks, but I do not know how to split the network into such a hierarchy. Any ideas?

Answer 1

Score: 1

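You can slice the list returned by `children()`. For torchvision's ResNet50 the children are, in order, conv1, bn1, relu, maxpool, layer1, layer2, layer3, layer4, avgpool and fc. The first five form the first block (stem plus layer1), the next three are one block each, and avgpool and fc (the classification head) are left out of the blocks: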

    import torch
    import torch.nn as nn
    from torchvision.models import resnet50, ResNet50_Weights

    class HierarchicalResNet50(nn.Module):
        def __init__(self):
            super().__init__()
            # Local variable: the full network, including its head, is not stored.
            backbone = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
            children = list(backbone.children())
            self.block1 = nn.Sequential(*children[:5])   # conv1, bn1, relu, maxpool, layer1
            self.block2 = nn.Sequential(*children[5:6])  # layer2
            self.block3 = nn.Sequential(*children[6:7])  # layer3
            self.block4 = nn.Sequential(*children[7:8])  # layer4

        def forward(self, x):
            x1 = self.block1(x)
            x2 = self.block2(x1)
            x3 = self.block3(x2)
            x4 = self.block4(x3)
            return x1, x2, x3, x4

    x = torch.randn(1, 3, 256, 256)
    net = HierarchicalResNet50()  # distinct name, so the class is not shadowed
    x1, x2, x3, x4 = net(x)
    print(x1.shape, x2.shape, x3.shape, x4.shape)

This prints:

    torch.Size([1, 256, 64, 64]) torch.Size([1, 512, 32, 32]) torch.Size([1, 1024, 16, 16]) torch.Size([1, 2048, 8, 8])

huangapple
  • Posted on June 27, 2023, 20:13:13
  • When reposting, please keep the link to this article: https://go.coder-hub.com/76564741.html