我需要量化输入,但我需要这样做的方法(bucketize)是不可区分的。我当然可以分离张量,但是我失去了梯度流到更早的权重。我想这个问题很简单,你如何在必要时继续渐变的流动。例如,使用以下代码...
x = self.linear1(x)
min, max = int(x.min()), int(x.max())
bins = torch.linspace(min, max+1, 16)
x = torch.bucketize(x.detach(), bins) # forced to detach here
x = self.linear2(x)
我知道这是可能的,因为它是使用 VQVAE 完成的,并且梯度仍然可以很好地用于量化目的。但是我检查了几个版本的 VQVAE 代码,梯度如何通过 argmin 方法似乎不合逻辑。无论如何,它似乎都有效,这让我感到困惑。我对此感到很困惑。对于这方面的任何帮助,我将不胜感激。提前致谢。