替换 pytorch 张量中的值

数据挖掘 火炬
2022-03-04 02:26:30
t=tensor([0.1 0.2 0.3 0.4 0.5 0.6])

现在我需要修改这个现有的张量 t 如下:

t=tensor([0.1 0.2 0.7 0.4 0.8 0.6])

我尝试如下:

t=tensor([0.1 0.2 0.3 0.4 0.5 0.6])
a=tensor([0.1 0.2 0.7 0.4 0.8 0.6])
index=range(len(a))
t.index_copy_(0,index,a)

但它仍然没有更新我如何修改pytorch中的张量?

1个回答

试试这个:

t = tensor([0.1, 0.2, 0.3, 0.4, 0.5, 0.6])
t.index_copy_(0, tensor([2, 4]), tensor([0.7, 0.8]))

参考:

torch.Tensor.index_copy_