PyTorch 中的渐变通道

数据挖掘 神经网络 深度学习 火炬 反向传播 坡度
2022-02-21 13:34:48

我需要量化输入,但我需要这样做的方法(bucketize)是不可区分的。我当然可以分离张量,但是我失去了梯度流到更早的权重。我想这个问题很简单,你如何在必要时继续渐变的流动。例如,使用以下代码...

x = self.linear1(x)        
min, max = int(x.min()), int(x.max())        
bins = torch.linspace(min, max+1, 16)
x = torch.bucketize(x.detach(), bins) # forced to detach here
x = self.linear2(x)

我知道这是可能的,因为它是使用 VQVAE 完成的,并且梯度仍然可以很好地用于量化目的。但是我检查了几个版本的 VQVAE 代码,梯度如何通过 argmin 方法似乎不合逻辑。无论如何,它似乎都有效,这让我感到困惑。我对此感到很困惑。对于这方面的任何帮助,我将不胜感激。提前致谢。

1个回答

感谢@thanatoz 在评论中回答这个问题。事实上,这确实解决了我的问题。我也从其他人的 VQVAE 代码中意识到,还有另一种可能的解决方案,其中不需要 retrain_graph 参数。我的示例代码可以写成...

x = self.linear1(x)        
min, max = int(x.min()), int(x.max())        
bins = torch.linspace(min, max+1, 16)
x_buckets = torch.bucketize(x.detach(), bins) # forced to detach here
x = x + (x - x_buckets).detach() # Reintroduce gradients here. <------------
x = self.linear2(x)

答案在第五行,减去然后再加回来。然后可以重新引入渐变。