pytorch dataloader collate

阅读: 评论:0

pytorch dataloader collate

pytorch dataloader collate

pytorch dataset collect_fn 在复杂label情况下的使用技巧

  • 在如目标检测等任务的训练过程中,label可能出现shape不一致的情况,以目标检测为例,一个batch的不同图片,可能有不同数量的bounding box,如果将bbox以(n,5)的shape的张量形式返回,n的数量就不统一,在使用默认的collect_fn时,pytorch的dataloader就会报错,这时就需要专门写一个collect_fn,如下:
import torch.utils.data as data
from PIL import Image
import osclass my_dataloader_with_det_label(data.Dataset):def __init__(self, images_path):ain_list = [os.path.join(images_path, i) for i in os.listdir(images_path)]self.size = 256self.data_list = ain_listprint("Total training examples:", ain_list))def __getitem__(self, index):# get imagedata_path = self.data_list[index]data = Image.open(data_lowlight_path).convert('RGB')data = size((self.size,self.size), Image.ANTIALIAS)data = (np.asarray(data_lowlight)/255.0) data = torch.from_numpy(data_lowlight).float()# get labellabel_path = place('images', 'labels')label_path = label_path[:-3] + 'txt'with open(label_path) as f:l = [x.split() for x ad().strip().splitlines() if len(x)]nl = len(l)label = np.zeros((nl, 6))label[:, 1:] = np.array(l, dtype=np.float32)return data.permute(2,0,1), torch.from_numpy(label).float()def __len__(self):return len(self.data_list)@staticmethoddef collate_fn(batch):img, label = zip(*batch)  # transposedfor i, l in enumerate(label):l[:, 0] = i  # add target image index for build_targets()return {'img': torch.stack(img, 0), 'label': torch.cat(label, 0)}train_dataset = my_dataloader_with_det_label(images_path)	
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size&#ain_batch_size, shuffle=True, num_workers=config.num_workers, pin_memory=True, collate_fn=llate_fn)
  • 解释一下,collate_fn之所以这么写,是因为dataloader在返回一个batch的数据时,首先会调用 batch_size 次 dataset的__getitem__()方法,然后将 batch_size 个返回值作为一个列表送入collate_fn 函数,以打包成一个batch的数据返回出来。如果没有重写collate_fn函数,会用默认的collate_fn方法打包,这时label自动concatenate时会出现维度不匹配的错误。
  • 因此这里就用到了label 6列中的第一列。可能你会好奇为什么bbox明明只需要5个数据,label却是一个 n × 6 ntimes6 n×6的矩阵。第2列到第6列装的就是label本身的5个数据,而第一列装的是一个batch中各个图片的索引。这里可以注意到img用的是stack方法,而label用的是cat方法,这样一来img就变成了(bs,c,h,w)而label却变成了(k,6),少了batch size的一维。这里k其实是一个batch的所有图片的所有bbox的总数,把不同图片的所有bbox都按顺序混在了一起,这样一来就不会出现维度不匹配的问题,但就会需要第一列来区分出哪些行是对batch中第一张图片的标注,哪些行是对第二张的标注,等等。然后在算loss的时候再根据label的第一维去对应出来就可以了。
  • 这个collate_fn其实是pytorch版的yolov3的库中借鉴来的,源码的地址在:

本文发布于:2024-01-30 02:55:02,感谢您对本站的认可!

本文链接:https://www.4u4v.net/it/170655450418733.html

版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。

标签:pytorch   dataloader   collate
留言与评论(共有 0 条评论)
   
验证码:

Copyright ©2019-2022 Comsenz Inc.Powered by ©

网站地图1 网站地图2 网站地图3 网站地图4 网站地图5 网站地图6 网站地图7 网站地图8 网站地图9 网站地图10 网站地图11 网站地图12 网站地图13 网站地图14 网站地图15 网站地图16 网站地图17 网站地图18 网站地图19 网站地图20 网站地图21 网站地图22/a> 网站地图23