How do I compute multiple per-sample gradients efficiently?


Question

I am trying to compute multiple loss gradients efficiently (without a for loop) in PyTorch. Given:

```python
import torch
from torch import nn

class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Sequential(
            nn.Linear(input_size, 16, bias=False),
            nn.Linear(16, output_size, bias=False),
        )

    def forward(self, x):
        return self.linear(x)

device = "cpu"
input_size = 2
output_size = 2
x = torch.randn(10, 1, input_size).to(device)
y = torch.randn(10, 1, output_size).to(device)
model = NeuralNetwork().to(device)
loss_fn = nn.MSELoss()

def loss_grad(x, label):
    y = model(x)
    loss = loss_fn(y, label)
    grads = torch.autograd.grad(loss, model.parameters(), retain_graph=True)
    return grads
```

The following works, but uses a for loop:

```python
# works but is inefficient
def compute_for():
    grads = [loss_grad(x[i], y[i]) for i in range(x.shape[0])]
    print(grads)

compute_for()
```

For efficiency, I tried using torch.vmap instead:

```python
# potentially more efficient, but doesn't work
def compute_vmap():
    grads = torch.vmap(loss_grad)(x, y)
    print(grads)

compute_vmap()
```

I was expecting it to compute, for each element of x, y, the gradient of the loss with respect to the parameters. Instead, I get an error:

```
RuntimeError: element 0 of tensors does not require grad
```

As I understand it, this means that the per-sample elements taken from the tensor x do not individually require grad.

How can I modify this code so that it computes all the gradients? Or is there another way to do this?

  1. <details>
  2. <summary>英文:</summary>
  3. I am trying to compute multiple loss gradients efficiently (without a for loop) in PyTorch. Given:
  4. ```python
  5. import torch
  6. from torch import nn
  7. class NeuralNetwork(nn.Module):
  8. def __init__(self):
  9. super().__init__()
  10. self.linear = nn.Sequential(
  11. nn.Linear(input_size, 16, bias=False),
  12. nn.Linear(16, output_size, bias=False),
  13. )
  14. def forward(self, x):
  15. return self.linear(x)
  16. device = &quot;cpu&quot;
  17. input_size = 2
  18. output_size = 2
  19. x = torch.randn(10, 1, input_size).to(device)
  20. y = torch.randn(10, 1, output_size).to(device)
  21. model = NeuralNetwork().to(device)
  22. loss_fn = nn.MSELoss()
  23. def loss_grad(x, label):
  24. y = model(x)
  25. loss = loss_fn(y, label)
  26. grads = torch.autograd.grad(loss, model.parameters(), retain_graph=True)
  27. return grads

The following works, but uses a for loop:

  1. # inefficient but works
  2. def compute_for():
  3. grads = [loss_grad(x[i], y[i]) for i in range(x.shape[0])]
  4. print(grads)
  5. compute_for()

For efficiency, I tried using torch.vmap instead:

  1. # potentially more efficient but doesn&#39;t work
  2. def compute_vmap():
  3. grads = torch.vmap(loss_grad)(x, y)
  4. print(grads)
  5. compute_vmap()

I was expecting it to compute the gradients of the losses w.r.t. the parameters for each element in x, y. Instead, I get an error:

  1. RuntimeError: element 0 of tensors does not require grad

As I understand, this means that elements from the tensor x will be computed and they don't individually require grad.

How can I modify this code so that it computes all gradients? Or is there another method to do that?

Answer 1

Score: 0

The per-sample gradients can be computed using vmap, as shown in the relevant tutorial:

```python
from torch.func import functional_call, vmap, grad

def compute_loss(params, buffers, sample, target):
    # treat a single (sample, target) pair as a batch of size 1
    batch = sample.unsqueeze(0)
    targets = target.unsqueeze(0)
    # run the model statelessly with the given parameters and buffers
    predictions = functional_call(model, (params, buffers), (batch,))
    loss = loss_fn(predictions, targets)
    return loss

params = {k: v.detach() for k, v in model.named_parameters()}
buffers = {k: v.detach() for k, v in model.named_buffers()}

# gradient of the single-sample loss w.r.t. the params dict,
# vectorized over the leading dimension of x and y (params/buffers are shared)
ft_compute_grad = grad(compute_loss)
ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None, None, 0, 0))
ft_per_sample_grads = ft_compute_sample_grad(params, buffers, x, y)
print(ft_per_sample_grads)
```

These match the gradients computed individually for each pair (x[i], y[i]).
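For instance, one way to sanity-check this claim (a minimal sketch reusing the loss_grad helper and the x, y tensors from the question; the tolerance value is an arbitrary choice) is to compare the vmapped results against the for-loop gradients entry by entry:

```python
# Sketch: compare the vmapped per-sample gradients with the for-loop version.
# ft_per_sample_grads is a dict {param_name: tensor of shape (batch, *param_shape)},
# while loss_grad returns a tuple of gradients in model.parameters() order.
loop_grads = [loss_grad(x[i], y[i]) for i in range(x.shape[0])]
for i, per_sample in enumerate(loop_grads):
    for (name, _), g in zip(model.named_parameters(), per_sample):
        # atol is a loose tolerance for floating-point differences
        assert torch.allclose(ft_per_sample_grads[name][i], g, atol=1e-6)
print("per-sample gradients match")
```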

