RuntimeError: expected scalar type Double but found Float in PyTorch code

Question

import nibabel as nib
import numpy as np
import torch

def encoder_block(inp, max_pool, in_channels):
    conv = torch.nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=3, padding='same')(inp.double())
    relu = torch.nn.ReLU()(conv)
    conv = torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding='same')(relu)
    relu = torch.nn.ReLU()(conv)
    if max_pool:
        return torch.nn.MaxPool2d(2, 2)(relu)
    return relu

# fpath is the path to the local NIfTI file
test_load = nib.load(fpath).get_fdata()
test_numpy = test_load[:, :, 0].reshape(1, 1, 256, 256).astype(np.double)
tens = torch.DoubleTensor(test_numpy)
out = encoder_block(tens, True, 1)  # raises: RuntimeError: expected scalar type Double but found Float

This code should load a NIfTI file from my local storage, convert it to a NumPy array, and then run a few convolutions on a 2D slice as a basic test for now.

The error happens in the first Conv2d and says RuntimeError: expected scalar type Double but found Float. I'm not sure what else I can do to convert my data to float.
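To see where the mismatch comes from, here is a minimal sketch that skips the NIfTI file. It assumes random data can stand in for nib.load(fpath).get_fdata() (both are float64), and it compares the input dtype with the dtype of a freshly created layer's weights:

```python
import numpy as np
import torch

# Random stand-in for the NIfTI slice; np.random.rand returns float64,
# just like nibabel's get_fdata() does by default.
image = np.random.rand(256, 256).reshape(1, 1, 256, 256)
inp = torch.from_numpy(image)  # from_numpy keeps the NumPy dtype (float64)

conv = torch.nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, padding="same")

print(inp.dtype)                  # torch.float64
print(conv.weight.dtype)          # torch.float32, PyTorch's default dtype
print(torch.get_default_dtype())  # torch.float32

mismatch_error = None
try:
    conv(inp)  # float64 input meets float32 weight and bias
except RuntimeError as err:
    mismatch_error = err  # exact wording varies between PyTorch versions
print("RuntimeError:", mismatch_error)
```

The input is already double, so converting it further cannot help. What differs is the layer: its parameters are float32 unless you say otherwise.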

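Answer 1

Score: 1

This is simply a dtype mismatch. PyTorch's default floating-point dtype is torch.float32, so each nn.Conv2d creates its weight and bias as float32. Your input tensor, however, is torch.float64: get_fdata() already returns float64 by default, and .astype(np.double), torch.DoubleTensor and .double() all keep it that way. You therefore also have to tell the convolutional layers the correct dtype using the dtype keyword argument. A working example of the code looks like: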

import nibabel as nib
import numpy as np
import torch

def encoder_block(inp, max_pool, in_channels):
    # dtype=torch.float64 makes the layer's weight and bias match the float64 input
    conv = torch.nn.Conv2d(in_channels=in_channels,
                           out_channels=64,
                           kernel_size=3,
                           padding='same',
                           dtype=torch.float64)(inp)
    relu = torch.nn.ReLU()(conv)
    conv = torch.nn.Conv2d(in_channels=64,
                           out_channels=64,
                           kernel_size=3,
                           padding='same',
                           dtype=torch.float64)(relu)
    relu = torch.nn.ReLU()(conv)
    if max_pool:
        return torch.nn.MaxPool2d(2, 2)(relu)
    return relu

test_load = nib.load(fpath).get_fdata()  # float64 by default
# tested with the next line
# test_load = np.random.rand(256, 256, 1)
test_numpy = test_load[:, :, 0].reshape(1, 1, 256, 256)
tens = torch.DoubleTensor(test_numpy)  # float64 tensor
out = encoder_block(tens, True, 1)
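A note on the design: the function above constructs new Conv2d layers, with freshly initialized random weights, on every call. That is fine for a quick shape test but not for training. The following is an alternative sketch that assumes you are free to restructure the code. It builds the block once, then either converts all of its parameters with Module.double() or casts the input to float32 with Tensor.float(). The EncoderBlock class name is hypothetical:

```python
import numpy as np
import torch
from torch import nn


class EncoderBlock(nn.Module):
    """Hypothetical module version of encoder_block: two 3x3 convs with ReLU, optional 2x2 max pool."""

    def __init__(self, in_channels: int, max_pool: bool):
        super().__init__()
        layers = [
            nn.Conv2d(in_channels, 64, kernel_size=3, padding="same"),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, padding="same"),
            nn.ReLU(),
        ]
        if max_pool:
            layers.append(nn.MaxPool2d(2, 2))
        self.body = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.body(x)


# Random float64 stand-in for the NIfTI slice (assumption for illustration).
tens = torch.from_numpy(np.random.rand(1, 1, 256, 256))

# Option 1: keep the float64 input and convert every parameter to float64.
block64 = EncoderBlock(in_channels=1, max_pool=True).double()
out64 = block64(tens)

# Option 2: keep the default float32 parameters and cast the input instead.
block32 = EncoderBlock(in_channels=1, max_pool=True)
out32 = block32(tens.float())

print(out64.dtype, tuple(out64.shape))  # torch.float64 (1, 64, 128, 128)
print(out32.dtype, tuple(out32.shape))  # torch.float32 (1, 64, 128, 128)
```

Casting the input with .float() is usually the cheaper choice. float32 needs half the memory of float64, and it is the dtype most PyTorch layers and GPUs are tuned for.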

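Answer 2

Score: 1

> I'm not sure what else I can do to convert my data to float.

get_fdata() takes a dtype parameter, so you can load the data as float32 directly. Then drop the .astype(np.double), torch.DoubleTensor and .double() conversions so the tensor stays float32 and matches the layers' default float32 weights.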

test_load = nib.load(fpath).get_fdata(dtype=np.float32)
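With the data loaded as float32, the rest of the pipeline can use PyTorch's defaults and needs no dtype arguments. The following is a minimal sketch. It assumes np.random.rand(...).astype(np.float32) can stand in for nib.load(fpath).get_fdata(dtype=np.float32), so it runs without a file:

```python
import numpy as np
import torch

# Stand-in for nib.load(fpath).get_fdata(dtype=np.float32).
volume = np.random.rand(256, 256, 1).astype(np.float32)

# torch.from_numpy keeps float32; there is no .astype(np.double),
# torch.DoubleTensor or .double() anywhere in this path.
tens = torch.from_numpy(volume[:, :, 0].reshape(1, 1, 256, 256))

conv = torch.nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, padding="same")
pool = torch.nn.MaxPool2d(2, 2)
out = pool(torch.relu(conv(tens)))

print(tens.dtype, conv.weight.dtype)  # torch.float32 torch.float32
print(tuple(out.shape))               # (1, 64, 128, 128)
```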

huangapple
  • Posted on February 27, 2023 at 14:42:24
  • When reposting, please keep the link to this article: https://go.coder-hub.com/75577409.html