在 PyTorch 中微调 RetinaNet 模型

数据挖掘 微调 火炬视觉
2022-03-03 02:39:38

我想改进 torchvision 中可用的预训练 RetinaNet 模型,以便创建我自己的对象检测。

我正在尝试在此链接上复制 FastRCNN 所做的工作: https ://pytorch.org/tutorials/intermediate/torchvision_tutorial.html#finetuning-from-a-pretrained-model

我所做的如下:

model = model = torchvision.models.detection.retinanet_resnet50_fpn(pretrained=True)
num_classes = 2 

# get number of input features and anchor boxed for the classifier
in_features = model.head.classification_head.conv[0].in_channels
num_anchors = model.head.classification_head.num_anchors

# replace the pre-trained head with a new one
model.head = RetinaNetHead(in_features, num_anchors, num_classes)

模型已声明,训练不会中断。然而,性能是如此糟糕,以至于一个非常愚蠢的检测都不起作用。

我的问题是,我写的代码可以重新训练 RetinaNet 模型吗?

1个回答

我也在尝试做类似的事情。下面的代码应该可以工作。在 COCO 数据集上加载预训练的权重后,我们需要用我们自己的替换分类器层。

num_classes = # num of objects to identify + background class

model = torchvision.models.detection.retinanet_resnet50_fpn(pretrained=True)
# replace classification layer 
in_features = model.head.classification_head.conv[0].in_channels
num_anchors = model.head.classification_head.num_anchors
model.head.classification_head.num_classes = num_classes

cls_logits = torch.nn.Conv2d(out_channels, num_anchors * num_classes, kernel_size = 3, stride=1, padding=1)
torch.nn.init.normal_(cls_logits.weight, std=0.01)  # as per pytorch code
torch.nn.init.constant_(cls_logits.bias, -math.log((1 - 0.01) / 0.01))  # as per pytorcch code 
# assign cls head to model
model.head.classification_head.cls_logits = cls_logits

回归框网络不需要更改,因为每个空间位置的锚框数量不会像在 COCO 数据集上预训练模型时那样发生变化。与 Faster R-CNN 相比,该代码有点麻烦。希望看到一个更优雅的解决方案。

希望这可以帮助。