Store intermediate values of a PyTorch module

Question

I am trying to plot attention maps for a ViT. I know that I can do something like
h_attn = model.blocks[-1].attn.register_forward_hook(get_activations('attention'))
to register a hook that captures the output of some nn.Module in my model. The ViT's attention layer has the following forward structure:

def forward(self, x):
    B, N, C = x.shape
    qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
    q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)

    attn = (q @ k.transpose(-2, -1)) * self.scale
    attn = attn.softmax(dim=-1)
    attn = self.attn_drop(attn)
    
    x = (attn @ v).transpose(1, 2).reshape(B, N, C)
    x = self.proj(x)
    x = self.proj_drop(x)
    
    return x

Can I somehow attach the hook such that I get the attn value and not the return value of forward (e.g., by using some kind of dummy module)?
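
For context, the get_activations helper used above is not defined in the question; it is typically a small hook-factory along these lines (a minimal sketch, with the activations dictionary chosen here for illustration):

activations = {}

def get_activations(name):
    # A forward hook receives (module, inputs, output); keep a detached copy of the output.
    def hook(module, inputs, output):
        activations[name] = output.detach()
    return hook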


Answer 1

Score: 0


In this case, what you want to do is capture an intermediate result computed inside a module's forward method, specifically the attn tensor. Forward hooks fire when a module's forward method is called, but they only see the module's inputs and its return value, so they do not directly let you capture intermediate values.

However, you can create a new subclass of the attention module where you modify the forward method to store the attn value as an attribute, which you can access later.

Here's an example of how you can create such a subclass:

# Attention is the base attention class being subclassed; for timm ViTs it can be
# imported as below (adjust the import to match your ViT implementation).
from timm.models.vision_transformer import Attention

class AttentionWithStoredAttn(Attention):

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
    
        self.stored_attn = attn.detach().cpu()  # Store the attention map, shape (B, num_heads, N, N), detached and moved to CPU
    
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
    
        return x

Then you can replace the original attention modules in your model with this new one, as sketched below. After each forward pass, you can access the stored attn tensor via the stored_attn attribute of the module.
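
One low-effort way to do the swap, shown here as a minimal sketch, is to reassign the class of each existing attention module in place. This keeps the pretrained weights and only changes which forward method is used (it assumes a timm-style ViT where every transformer block exposes its attention layer as block.attn):

# Reassign the class in place: parameters and buffers stay untouched,
# only the forward method is swapped. Assumes block.attn is an Attention instance.
for block in model.blocks:
    block.attn.__class__ = AttentionWithStoredAttn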

This will increase your memory usage because you are storing the attn tensor. Also, remember to call model.eval() before inference to disable dropout and other training-specific operations. Note that the stored attention map is overwritten on every forward pass, so if you run multiple passes (e.g. over several batches), retrieve or save the map after each one as needed.
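
Putting it together, an inference loop for pulling out the maps might look like this (a sketch: img stands for a preprocessed image batch and is not part of the original answer; the shape comments assume the timm-style layout above):

import torch

model.eval()           # disable dropout and other training-only behaviour
with torch.no_grad():  # gradients are not needed for visualization
    _ = model(img)     # img: (B, 3, H, W) preprocessed images

# stored_attn has shape (B, num_heads, N, N), where N = 1 + number of patches
# (class token first in a standard timm ViT).
attn = model.blocks[-1].attn.stored_attn
cls_to_patches = attn[:, :, 0, 1:].mean(dim=1)  # class-token attention to patches, averaged over heads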
