pytorch __init__() got an unexpected keyword argument 'train'

data-mining python pytorch
2022-03-04 06:15:01

I'm using this code in Colab. The function below gives me an error when I call it.

import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.utils.data as data
import torchvision
from torchvision import transforms
def load_images(image_size=32, batch_size=64, root=""):
    transform = transforms.Compose([
        transforms.Resize(32),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    train_set = torchvision.datasets.ImageFolder(root=root, train=True, transform=transform)
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=2)

    return train_loader

datotes = load_images(28, 64, "/content/datitos/")

Error

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-85-a809f71788e0> in <module>()
----> 1 datotes=load_images(28,64,"/content/datitos/")

<ipython-input-84-0b4989f589a9> in load_images(image_size, batch_size, root)
     15     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
     16 
---> 17     train_set = torchvision.datasets.ImageFolder(root=root,  train=True,transform=transform)
     18     train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=2)
     19 

TypeError: __init__() got an unexpected keyword argument 'train'
1 Answer

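The error means that the ImageFolder class does not recognize the train argument. A train flag exists on torchvision's built-in datasets such as MNIST and CIFAR10, which ship with a fixed train/test split; ImageFolder simply loads whatever images it finds under root, so there is no split for it to select.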
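To see where the message comes from, here is a minimal, standard-library-only sketch. FolderDataset is a hypothetical stand-in whose constructor, like ImageFolder's, has no train parameter, and accepts_kwarg is a hypothetical helper that uses inspect.signature to check whether a callable accepts a given keyword before you pass it.

```python
import inspect


class FolderDataset:
    """Hypothetical stand-in: like ImageFolder, its constructor has no `train` parameter."""

    def __init__(self, root, transform=None):
        self.root = root
        self.transform = transform


def accepts_kwarg(func, name):
    """Return True if `func` can be called with the keyword argument `name`."""
    for param in inspect.signature(func).parameters.values():
        if param.kind == inspect.Parameter.VAR_KEYWORD:
            return True  # a **kwargs catch-all accepts any keyword
        if param.name == name and param.kind in (
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            inspect.Parameter.KEYWORD_ONLY,
        ):
            return True
    return False


# Passing an unknown keyword reproduces the TypeError from the question.
message = ""
try:
    FolderDataset(root="/content/datitos/", train=True)
except TypeError as err:
    message = str(err)

print(message)
print(accepts_kwarg(FolderDataset, "train"))      # False
print(accepts_kwarg(FolderDataset, "transform"))  # True
```

With torchvision installed, accepts_kwarg(torchvision.datasets.ImageFolder, "train") should likewise return False, whereas torchvision.datasets.CIFAR10 does accept train.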

ImageFolder() has no train parameter, so removing it fixes the error. Instead of...

train_set = torchvision.datasets.ImageFolder(root=root,  train=True,transform=transform)

...you should write...

train_set = torchvision.datasets.ImageFolder(root=root, transform=transform)

For more information, see the torchvision documentation for ImageFolder.
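Since ImageFolder has no train switch, a train/validation split is usually expressed through directories: ImageFolder treats each subdirectory of root as one class and numbers the classes in sorted folder-name order. The sketch below mimics that class-discovery step with the standard library only (find_classes here is a hypothetical re-implementation for illustration, not torchvision's own code) and assumes a layout with separate train/ and val/ folders.

```python
import os
import tempfile


def find_classes(root):
    """Hypothetical re-implementation of ImageFolder's labelling rule:
    every subdirectory of `root` is one class, indexed in sorted name order."""
    with os.scandir(root) as entries:
        classes = sorted(entry.name for entry in entries if entry.is_dir())
    if not classes:
        raise FileNotFoundError(f"no class folders found in {root!r}")
    return classes, {name: idx for idx, name in enumerate(classes)}


with tempfile.TemporaryDirectory() as base:
    # Assumed layout: <base>/<split>/<class>/<image file>
    for split in ("train", "val"):
        for label in ("dog", "cat"):
            class_dir = os.path.join(base, split, label)
            os.makedirs(class_dir)
            # Empty placeholder file standing in for a real image.
            open(os.path.join(class_dir, "0001.png"), "wb").close()

    train_classes, train_map = find_classes(os.path.join(base, "train"))
    val_classes, val_map = find_classes(os.path.join(base, "val"))
    # Pointing root at the parent folder turns the split folders into classes.
    parent_classes, _ = find_classes(base)

print(train_map)       # {'cat': 0, 'dog': 1}
print(parent_classes)  # ['train', 'val']
```

With real images in place, each split would get its own dataset, e.g. ImageFolder(root="/content/datitos/train", transform=transform) and the same for val; the folder names are an assumption about how the data is organized. If all images live in a single class-folder tree, torch.utils.data.random_split can divide one ImageFolder dataset into train and validation subsets instead.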