深度学习模型计算准确率背后的解释

数据挖掘 Python 深度学习 准确性
2022-02-23 02:15:50

我正在尝试使用卷积神经网络对图像分割问题进行建模。我在 Github 中遇到了代码,我无法理解以下代码行用于计算准确性的含义 -

def new_test(loaders,model,criterion,use_cuda):
    for batch_idx, (data,target) in enumerate(loaders):
        output = model(data)
###Accuracy
    _, predicted = torch.max(output.data, 1)
    total_train += target.nelement()
    correct_train += predicted.eq(target.data).sum().item()

model(data)输出形状为B * N * H * W
B = Batch Size
N = 分割类数
H,W = 图像的高度,宽度的张量

1个回答
_, predicted = torch.max(output.data, 1)

在上面的第一行中,它用于查找沿维度 - 1(即沿每一行)的torch.max()所有案例。output.data此外,torch.max()返回(values, indices),因此只有索引存储在一个名为的变量中predicted,而值不存储。

total_train += target.nelement()

上面的第二行只是将训练集的总大小跟踪到一个total_train基于元素数量的变量中target

correct_train += predicted.eq(target.data).sum().item()

上面的第三行采用predicted变量,其中包含输出大于 1 时的所有索引,对照 进行检查,并对target.data所有相等的项求和。换句话说,它正在计算有多少预测值等于目标数据。