本博客记录一下在Caltech256数据集中遇到的一些坑。
Caltech256主页,以及博主提供的免费数据集下载链接。
踩坑1:
waring:torchvision提供的标签信息与实际下载的数据集不匹配,数据集加载出错
标签文件与数据集不匹配,标签文件长达30609,但实际图片为30607张(官方数据)。判断了下所有文件是否存在,出现两个奇葩index找不到文件。
train_set = datasets.Caltech256(root=data_folder,download=True,# train=True,transform=train_transform)
for i in range(len(train_set.index)):if not ists(os.path.join(root,"256_ObjectCategories",train_set.categories[train_set.y[i]],"{:03d}_{:04d}.jpg".format(train_set.y[i] + 1, train_set.index[i]))):print(i)
解决方案:
train_set.index = train_set.index[:6307] + train_set.index[6308:22619] + train_set.index[22620:]train_set.y = train_set.y[:6307] + train_set.y[6308:22619] + train_set.y[22620:]
踩坑2:
waring:总类别为257,使用错误的类别将会导致预测概率为负,损失计算为NaN等无厘头错误。
n_cls = 256 #错误类别
n_cls = 257 #正确类别
踩坑3:
waring:部分图片为单通道,直接使用transforms进行标准化将会报错。
顺便把标准化参数也在这公布一下吧,同学们以后就不用自己再算一遍了。
# 三通道均值和标准差分别为:
0.5520 0.5336 0.5050
0.2353 0.2345 0.2372
解决方案:
通过路径读取图片后先将img转换为tensor,并判断tensor是不是单通道,如果单通道则直接复制为3通道。确保所有img为3通道之后再进行标准化和resize等操作。
img = Image.open(img_path) # 读取该图片img = transforms.Compose([transforms.ToTensor()])(img)ansform is not None:if img.shape[0] != 3:img = peat(3, 1, 1)img = ansform(img)
暂时记录到这啦
本文发布于:2024-02-05 01:49:18,感谢您对本站的认可!
本文链接:https://www.4u4v.net/it/170721289161917.html
版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。
留言与评论(共有 0 条评论) |