Store intermediate values of pytorch module
Question
I am trying to plot attention maps for a ViT. I know that I can do something like
h_attn = model.blocks[-1].attn.register_forward_hook(get_activations('attention'))
to register a hook that captures the output of the ViT's attention layer. The attention layer has the following forward structure:
def forward(self, x):
    B, N, C = x.shape  # batch, tokens, embedding dim
    qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
    q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
    attn = (q @ k.transpose(-2, -1)) * self.scale  # (B, num_heads, N, N)
    attn = attn.softmax(dim=-1)
    attn = self.attn_drop(attn)
    x = (attn @ v).transpose(1, 2).reshape(B, N, C)
    x = self.proj(x)
    x = self.proj_drop(x)
    return x
Can I somehow attach the hook such that I get the attn value and not the return value of forward (e.g., by using some kind of dummy module)?
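For reference, get_activations above is assumed to be the usual forward-hook factory pattern; a minimal sketch (the activations dictionary name is hypothetical) would be:

activations = {}

def get_activations(name):
    # Standard hook factory: a forward hook receives (module, inputs, output)
    # and here stores the module's output under `name`.
    def hook(module, inputs, output):
        activations[name] = output.detach()
    return hook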
Answer 1
Score: 0
In this case, what you want is to capture an intermediate value inside a module's forward method, specifically the attn tensor. Hooks fire when a module's forward or backward method is called, but they only see the module's inputs and outputs, so they do not directly let you capture intermediate values.
However, you can create a subclass of the attention module whose forward method stores the attn value as an attribute, which you can access later.
Here's an example of such a subclass:
# `Attention` is the ViT attention module being subclassed
# (e.g. timm.models.vision_transformer.Attention).
class AttentionWithStoredAttn(Attention):

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        self.stored_attn = attn.detach().cpu()  # store the attn tensor for later access
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
Then you can replace the original attention module in your model with this new one. After each forward pass, you can access the stored attn tensor via the module's stored_attn attribute, as sketched below.
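A minimal swap-in sketch, assuming a timm ViT: the model name, dim=768, and qkv_bias=True are assumptions for vit_base_patch16_224, so match the constructor arguments to your model config and timm version.

import torch
import timm

model = timm.create_model('vit_base_patch16_224', pretrained=True)
model.eval()  # disable dropout so the stored map is deterministic

old_attn = model.blocks[-1].attn
# Constructor arguments are assumed for vit_base; adjust to your model.
new_attn = AttentionWithStoredAttn(dim=768, num_heads=old_attn.num_heads, qkv_bias=True)
new_attn.load_state_dict(old_attn.state_dict())  # reuse the pretrained weights
model.blocks[-1].attn = new_attn

with torch.no_grad():
    _ = model(torch.randn(1, 3, 224, 224))

print(model.blocks[-1].attn.stored_attn.shape)  # e.g. (1, num_heads, 197, 197)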
This will increase your memory usage because you are storing the attn tensor. Also, remember to call model.eval() before inference to disable dropout and other training-specific behavior. Note that the stored attention map is overwritten on every forward pass, so retrieve or save it between passes as needed; a variant that keeps a history is sketched after this paragraph.
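If you do need the maps from several forward passes, a minimal variant building on the class above (the stored_attns list is a hypothetical name) is to append instead of overwrite:

class AttentionWithAttnHistory(AttentionWithStoredAttn):
    def forward(self, x):
        out = super().forward(x)  # parent stores the latest map in self.stored_attn
        if not hasattr(self, 'stored_attns'):
            self.stored_attns = []
        self.stored_attns.append(self.stored_attn)  # clear this list periodically to bound memory
        return out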
Comments