How to create an efficient conditional layer in PyTorch?

Question

I have a ResNet50 model that outputs a class prediction (1, 2, or 3). Based on the classifier's output, I want to make a second prediction, using the class prediction to select which follow-up model to run.
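
For context, one way to turn the classifier's output into the 1, 2, or 3 labels that the code below consumes is an argmax over its logits. This is a minimal sketch, assuming the classifier returns logits of shape [batch_size, 3]. A small `nn.Linear` is a hypothetical stand-in for the ResNet50 so the example runs on its own, and `+ 1` shifts PyTorch's 0-based class indices to the question's 1-based labels.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for the ResNet50 classifier: 8 input features -> 3 class logits.
classifier = nn.Linear(8, 3)
features = torch.randn(64, 8)  # made-up batch of 64 samples

logits = classifier(features)          # shape [64, 3]
labels = logits.argmax(dim=1) + 1      # 0-based class index -> label in {1, 2, 3}

# Lay the labels out as a [1, batch_size] float tensor, the input format SimpleModel expects.
x = labels.unsqueeze(0).float()
print(x.shape)  # torch.Size([1, 64])
```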

This is what I have so far.

import torch
import torch.nn as nn

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.model1 = torch.nn.Linear(1, 1, bias=False)
        torch.nn.init.ones_(self.model1.weight)

        self.model2 = torch.nn.Linear(1, 1, bias=False)
        torch.nn.init.ones_(self.model2.weight)

        self.model3 = torch.nn.Linear(1, 1, bias=False)
        torch.nn.init.ones_(self.model3.weight)

    def forward(self, x):
        
        # Get batch_size (values are laid out along dim 1: x has shape [1, batch_size])
        batch_size = x.size(1)
        output = torch.zeros(batch_size, 1, device=x.device)
        
        # Loop over every value in batch
        for i in range(batch_size):
            value = x[:, i]
            if value == 1:
                output[i] = self.model1(value)
            elif value == 2:
                output[i] = self.model2(value)
            else:
                output[i] = self.model3(value)

        return output


model = SimpleModel()

output = model(torch.tensor([[1, 2, 3]], dtype=torch.float32))
print(output)

My concern is that each iteration of the loop computes a separate forward pass for a single value, which seems very inefficient. What happens if I increase the batch size to 64? Will the forward passes be computed in parallel?
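
One way to see what the loop does is to count sub-model calls with forward hooks. This minimal sketch (not part of the original question) reproduces the question's `SimpleModel` with the missing `nn` import added, registers `register_forward_hook` counters on `model1`, `model2`, and `model3`, and feeds it a made-up batch of 64 labels. The call counts always add up to 64: the Python loop issues one tiny forward pass per sample, one after another, so the samples are never batched together. On a GPU, each `if value == 1` check also has to copy a value back to the host before Python can branch.

```python
import torch
import torch.nn as nn


class SimpleModel(nn.Module):
    """The question's loop-based model (all weights initialised to 1)."""

    def __init__(self):
        super().__init__()
        self.model1 = nn.Linear(1, 1, bias=False)
        self.model2 = nn.Linear(1, 1, bias=False)
        self.model3 = nn.Linear(1, 1, bias=False)
        for m in (self.model1, self.model2, self.model3):
            nn.init.ones_(m.weight)

    def forward(self, x):
        batch_size = x.size(1)
        output = torch.zeros(batch_size, 1, device=x.device)
        for i in range(batch_size):
            value = x[:, i]
            if value == 1:
                output[i] = self.model1(value)
            elif value == 2:
                output[i] = self.model2(value)
            else:
                output[i] = self.model3(value)
        return output


model = SimpleModel()
calls = {"model1": 0, "model2": 0, "model3": 0}


def make_counter(name):
    # Forward hook that counts how often the named sub-model is invoked.
    def hook(module, inputs, output):
        calls[name] += 1
    return hook


for name in calls:
    getattr(model, name).register_forward_hook(make_counter(name))

# Made-up batch of 64 labels in {1, 2, 3}, laid out as [1, 64] like the question's input.
x = torch.randint(1, 4, (1, 64)).float()
model(x)
print(calls, sum(calls.values()))  # the counts always add up to 64
```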

Any thoughts/ideas would be appreciated.

Answer 1

Score: 1

You can do it as follows. By using boolean masks as the conditions, the code runs each of the three models only once per batch, with no per-sample for loop:

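    def forward(self, x):

        # Get batch_size (x has shape [1, batch_size], as in the question)
        batch_size = x.size(1)
        output = torch.zeros(batch_size, 1, device=x.device)

        # Compute one mask for each condition; the third mask covers every
        # remaining value, mirroring the question's else branch
        value_mask_1 = (x == 1)
        value_mask_2 = (x == 2)
        value_mask_3 = ~(value_mask_1 | value_mask_2)

        # Then run each model once, on the items selected by its mask, and
        # assign its outputs to the corresponding positions in output.
        # unsqueeze(-1) turns the k selected values into the [k, 1] input that
        # nn.Linear(1, 1) expects; view(-1) flattens the [k, 1] result to match
        # the k output slots picked by the mask.
        output[value_mask_1.view_as(output)] = self.model1(x[value_mask_1].unsqueeze(-1)).view(-1)
        output[value_mask_2.view_as(output)] = self.model2(x[value_mask_2].unsqueeze(-1)).view(-1)
        output[value_mask_3.view_as(output)] = self.model3(x[value_mask_3].unsqueeze(-1)).view(-1)

        return output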
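
The masking idea, shown in isolation on made-up values (a minimal sketch, not from the original answer): comparing a tensor with a scalar gives a boolean mask, indexing with that mask gathers the matching entries into a 1-D tensor in row-major order, and assigning through a mask with the same element layout writes values back into those positions in the same order.

```python
import torch

x = torch.tensor([[1.0, 2.0, 1.0, 3.0]])  # made-up labels, shape [1, 4]
mask = x == 1                             # boolean mask, True at positions 0 and 2

selected = x[mask]                        # 1-D tensor with the two matching entries
print(selected)                           # tensor([1., 1.])

out = torch.zeros(4, 1)
# view_as reshapes the [1, 4] mask to [4, 1] so it lines up with `out`;
# the assignment fills the True slots in the same order the values were gathered.
out[mask.view_as(out)] = selected * 10
print(out.view(-1).tolist())              # [10.0, 0.0, 10.0, 0.0]
```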

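To check the masked approach end to end, here is a self-contained sketch that puts the question's loop and a masked forward side by side in one hypothetical module (`LoopVsMask`, `forward_loop`, and `forward_masked` are illustrative names, not from the original post). It uses random weights instead of ones, so that routing a value to the wrong sub-model would change the result, and compares the two on a made-up batch of 64 labels.

```python
import torch
import torch.nn as nn


class LoopVsMask(nn.Module):
    """Hypothetical module holding both routing strategies for comparison."""

    def __init__(self):
        super().__init__()
        # Default random init (not ones) so a wrong routing would show up.
        self.model1 = nn.Linear(1, 1, bias=False)
        self.model2 = nn.Linear(1, 1, bias=False)
        self.model3 = nn.Linear(1, 1, bias=False)

    def forward_loop(self, x):
        # The question's approach: one forward pass per sample.
        batch_size = x.size(1)
        output = torch.zeros(batch_size, 1, device=x.device)
        for i in range(batch_size):
            value = x[:, i]
            if value == 1:
                output[i] = self.model1(value)
            elif value == 2:
                output[i] = self.model2(value)
            else:
                output[i] = self.model3(value)
        return output

    def forward_masked(self, x):
        # The masked approach: one forward pass per sub-model, whatever the batch size.
        batch_size = x.size(1)
        output = torch.zeros(batch_size, 1, device=x.device)
        mask1 = x == 1
        mask2 = x == 2
        mask3 = ~(mask1 | mask2)
        for mask, model in ((mask1, self.model1), (mask2, self.model2), (mask3, self.model3)):
            output[mask.view_as(output)] = model(x[mask].unsqueeze(-1)).view(-1)
        return output


torch.manual_seed(0)
model = LoopVsMask()
x = torch.randint(1, 4, (1, 64)).float()  # 64 made-up labels in {1, 2, 3}, shape [1, 64]

with torch.no_grad():
    loop_out = model.forward_loop(x)
    masked_out = model.forward_masked(x)

print(torch.allclose(loop_out, masked_out))
```

The number of sub-model calls in the masked version stays at three no matter how large the batch is. Boolean-mask indexing does produce data-dependent shapes, and on a GPU it typically forces a device-to-host synchronization to learn how many elements were selected, but that cost is paid once per mask rather than once per sample.
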
huangapple
  • This article was published on 2023-05-29 02:29:40
  • When reposting, please keep the link to this article: https://go.coder-hub.com/76353017.html