pytorch中的conv2d函数

数据挖掘 Python 卷积 火炬
2022-02-14 22:02:53

我正在尝试使用 Pytorch 中的函数torch.conv2d但无法得到我理解的结果......

这是一个简单的示例,其中内核 ( filt) 与输入 ( ) 的大小相同,im以解释我在寻找什么。

import pytorch

filt = torch.rand(3, 3)
im = torch.rand(3, 3)

我想计算一个没有填充的简单卷积,所以结果应该是一个标量(即 1x1 张量)。

我试过这个conv2d

# I have to convert image and kernel to 4 dimensions tensors to use conv2d
im_torch = im.reshape((im_height, filt_height, 1, 1))
filt_torch = filt.reshape((filt_height, im_height, 1, 1))
out = torch.nn.functional.conv2d(im_torch, filt_torch, stride=1, padding=0)
print(out)

但结果不是我所期望的:

tensor([[[[0.6067]], [[0.3564]], [[0.5397]]],
    [[[0.2557]], [[0.0493]], [[0.2562]]],
    [[[0.6067]], [[0.3564]], [[0.5397]]]])

为了了解我想要什么,我想重现 scipyconvolve2d行为:

import scipy.signal
out_scipy = scipy.signal.convolve2d(im.detach().numpy(), filt.detach().numpy(), 'valid')
print(out_scipy)

打印:

array([[1.195723]], dtype=float32)
2个回答

要获得等效结果:

im_torch = im.reshape((1, 1, 3, 3))
filt_torch = filt.reshape((1, 1, 3, 3))

out = torch.nn.functional.conv2d(filt_torch, im_torch, stride=1, padding=0)
print(out)


import scipy.signal


out_scipy = scipy.signal.convolve2d(im.detach().numpy(), filt.detach().numpy()[::-1, ::-1], 'valid')
print(out_scipy)

有一个整形问题(https://pytorch.org/docs/stable/nn.functional.html),并且在torch中过滤器的应用与在scipy中的处理方式有所不同:https://discuss.pytorch .org/t/functional-conv2d-produces-different-results-vs-scipy-convolved2d/17762

好的,我没有找到我的问题的确切答案(即如何使用 conv2d),但我找到了另一种方法。

首先,我了解到我正在寻找的称为有效互相关,它实际上是[Conv2d][1]该类实现的操作。

因此我的解决方案使用Conv2d类而不是conv2d函数。

import pytorch

img = torch.rand(3, 3)

model = torch.nn.Conv2d(in_channels=1, out_channels=1, kernel_size=(3, 3), stride=1, padding=0, bias=False)

res = conv_mdl(img)
print(res.shape)

打印我想要的标量:

torch.Size([1, 1, 1, 1])

PS:我还检查了结果是正确的,而不仅仅是维度。