How do I compute multiple per-sample gradients efficiently?
Question
I am trying to compute multiple loss gradients efficiently (without a for loop) in PyTorch. Given:

```python
import torch
from torch import nn

class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Sequential(
            nn.Linear(input_size, 16, bias=False),
            nn.Linear(16, output_size, bias=False),
        )

    def forward(self, x):
        return self.linear(x)

device = "cpu"
input_size = 2
output_size = 2
x = torch.randn(10, 1, input_size).to(device)
y = torch.randn(10, 1, output_size).to(device)
model = NeuralNetwork().to(device)
loss_fn = nn.MSELoss()

def loss_grad(x, label):
    y = model(x)
    loss = loss_fn(y, label)
    grads = torch.autograd.grad(loss, model.parameters(), retain_graph=True)
    return grads
```

The following works, but uses a for loop:

```python
# Inefficient but works
def compute_for():
    grads = [loss_grad(x[i], y[i]) for i in range(x.shape[0])]
    print(grads)

compute_for()
```

For efficiency, I tried using `torch.vmap` instead:

```python
# Potentially more efficient, but doesn't work
def compute_vmap():
    grads = torch.vmap(loss_grad)(x, y)
    print(grads)

compute_vmap()
```

I was expecting it to compute the gradients of the loss with respect to the parameters for each element in `x`, `y`. Instead, I get an error:

```
RuntimeError: element 0 of tensors does not require grad
```

As I understand it, this means that elements from the tensor `x` will be computed and they don't individually require grad.

How can I modify this code so that it computes all gradients? Or is there another method to do that?
Answer 1
Score: 0
The per-sample gradients may be computed using `vmap`, as shown in the relevant tutorial:
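```python
from torch.func import functional_call, vmap, grad

def compute_loss(params, buffers, sample, target):
    # vmap strips the mapped (first) dimension, so re-add a batch dimension of size 1
    batch = sample.unsqueeze(0)
    targets = target.unsqueeze(0)
    # Run the module with the given parameters/buffers instead of its own
    predictions = functional_call(model, (params, buffers), (batch,))
    loss = loss_fn(predictions, targets)
    return loss

params = {k: v.detach() for k, v in model.named_parameters()}
buffers = {k: v.detach() for k, v in model.named_buffers()}

# grad() returns a function computing the gradient of compute_loss w.r.t. its first argument (params)
ft_compute_grad = grad(compute_loss)
# Map over dim 0 of sample and target; share params and buffers across samples (None)
ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None, None, 0, 0))
ft_per_sample_grads = ft_compute_sample_grad(params, buffers, x, y)
print(ft_per_sample_grads)
```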
These match the gradients computed individually for each pair `(x[i], y[i])`.
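One way to check that claim is to run both versions side by side. The following is a self-contained sketch, assuming PyTorch 2.0 or later (where `torch.func` is available); it rebuilds the question's model with a fixed seed, computes the gradients once with a per-sample `torch.autograd.grad` loop and once with the answer's `vmap(grad(...))` pipeline, and reports the largest element-wise difference. The names `loop_grads`, `per_sample` and `max_diff` are introduced here for illustration only.

```python
import torch
from torch import nn
from torch.func import functional_call, grad, vmap

torch.manual_seed(0)  # fixed seed so the check is reproducible

input_size, output_size, n = 2, 2, 10

# Same model as in the question
class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Sequential(
            nn.Linear(input_size, 16, bias=False),
            nn.Linear(16, output_size, bias=False),
        )

    def forward(self, x):
        return self.linear(x)

x = torch.randn(n, 1, input_size)
y = torch.randn(n, 1, output_size)
model = NeuralNetwork()
loss_fn = nn.MSELoss()

# Reference: one torch.autograd.grad call per sample, as in the question's loop
def loss_grad(xi, yi):
    loss = loss_fn(model(xi), yi)
    return torch.autograd.grad(loss, model.parameters())

loop_grads = [loss_grad(x[i], y[i]) for i in range(n)]

# Vectorized: the answer's torch.func pipeline
def compute_loss(params, buffers, sample, target):
    predictions = functional_call(model, (params, buffers), (sample.unsqueeze(0),))
    return loss_fn(predictions, target.unsqueeze(0))

params = {k: v.detach() for k, v in model.named_parameters()}
buffers = {k: v.detach() for k, v in model.named_buffers()}
per_sample = vmap(grad(compute_loss), in_dims=(None, None, 0, 0))(params, buffers, x, y)

# Every entry gains a leading dimension of size n (one gradient per sample):
# linear.0.weight -> (10, 16, 2), linear.1.weight -> (10, 2, 16)
for name, g in per_sample.items():
    print(name, tuple(g.shape))

# Largest element-wise gap between the two methods; parameters() and
# named_parameters() iterate in the same order, so index j lines up with name
names = [k for k, _ in model.named_parameters()]
max_diff = max(
    (per_sample[name][i] - loop_grads[i][j]).abs().max().item()
    for i in range(n)
    for j, name in enumerate(names)
)
print("max abs difference:", max_diff)
```

The printed difference should be at floating-point round-off level. Because `vmap` stacks results along a new leading dimension, `per_sample[name][i]` is the gradient of parameter `name` for the pair `(x[i], y[i])`.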