collate

阅读: 评论:0

collate

collate

collate_fn 参数

当继承Dataset类自定义类时,__getitem__方法一般返回一组类似于(image,label)的一个样本,在创建DataLoader类的对象时,collate_fn函数会将batch_size个样本整理成一个batch样本,便于批量训练。

default_collate(batch)中的参数就是这里的 [self.dataset[i] for i in indices],indices是从所有样本的索引中选取的batch_size个索引,表示本次批量获取这些样本进行训练。self.dataset[i]就是自定义Dataset子类中__getitem__返回的结果。默认的函数default_collate(batch) 只能对大小相同image的batch_size个image整理,如[(img0, label0), (img1, label1),(img2, label2), ] 整理成([img0,img1,img2,], [label0,label1,label2,]), 这里要求多个img的size相同。所以在我们的图像大小不同时,需要自定义函数callate_fn来将batch个图像整理成统一大小的,若读取的数据有(img, box, label)这种你也需要自定义,因为默认只能处理(img,label)。当然你可以提前将数据集全部整理成统一大小的。

以下是文字识别时,文本行图像长度不一,需要自定义整理。

class AlignCollate(object):"""将数据整理成batch"""def __init__(self, imgH=32, imgW=100, keep_ratio_with_pad=False):self.imgH = imgHself.imgW = imgWself.keep_ratio_with_pad = keep_ratio_with_paddef __call__(self, batch):# 有可能__getitem__返回的图像是None, 所以需要过滤掉batch = filter(lambda x: x is not None, batch)images, labels = zip(*batch)if self.keep_ratio_with_pad:  # same concept with 'Rosetta' paperresized_max_w = self.imgWinput_channel = 3 if images[0].mode == 'RGB' else 1transform = NormalizePAD((input_channel, self.imgH, resized_max_w))resized_images = []for image in images:w, h = image.sizeratio = w / float(h)# 图片的宽度大于设定的输入il(self.imgH * ratio) > self.imgW:resized_w = self.imgWelse:resized_w = il(self.imgH * ratio)resized_image = size((resized_w, self.imgH), Image.BICUBIC)resized_images.append(transform(resized_image))# resized_image.save('./image_test/%d_test.jpg' % w)image_tensors = torch.cat([t.unsqueeze(0) for t in resized_images], 0)else:transform = ResizeNormalize((self.imgW, self.imgH))image_tensors = [transform(image) for image in images]image_tensors = torch.cat([t.unsqueeze(0) for t in image_tensors], 0)return image_tensors, labels

再看个在做目标检测时自定义collate_fn函数,给每个图像添加索引

def collate_fn(self, batch):paths, imgs, targets = list(zip(*batch))# Remove empty placeholder targets  # 有可能__getitem__返回的图像是None, 所以需要过滤掉targets = [boxes for boxes in targets if boxes is not None]# Add sample index to targets# boxes是每张图像上的目标框,但是每个图片上目标框数量不一样呢,所以需要给这些框添加上索引,对应到是哪个图像上的框。for i, boxes in enumerate(targets):boxes[:, 0] = itargets = torch.cat(targets, 0)# Selects new image size every tenth batchif self.multiscale and self.batch_count % 10 == 0:self.img_size = random.choice(range(self.min_size, self.max_size + 1, 32))# Resize images to input shape# 每个图像大小不同呢,所以resize到统一大小imgs = torch.stack([resize(img, self.img_size) for img in imgs])self.batch_count += 1return paths, imgs, targets

其实也可以自定义collate_fn同时,结合使用默认的default_collate

from torch.utils.data.dataloader import default_collate  # 导入这个函数def collate_fn(batch):"""params:batch :是一个列表,列表的长度是 batch_size列表的每一个元素是 (x,y) 这样的元组tuple,元祖的两个元素分别是x,y大致的格式如下 [(x1,y1),(x2,y2),(x3,y3)...(xn,yn)]returns:整理之后的新的batch"""# 这一部分是对 batch 进行重新 “校对、整理”的代码return default_collate(batch) #返回校对之后的batch,一般就直接推荐使用default_collate进行包装,因为它里面有很多功能,比如将numpy转化成tensor等操作,这是必须的。

Reference:

一文读懂Dataset, DataLoader及collate_fn, Sampler等参数

本文发布于:2024-01-27 12:06:30,感谢您对本站的认可!

本文链接:https://www.4u4v.net/it/17063283971309.html

版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。

标签:collate
留言与评论(共有 0 条评论)
   
验证码:

Copyright ©2019-2022 Comsenz Inc.Powered by ©

网站地图1 网站地图2 网站地图3 网站地图4 网站地图5 网站地图6 网站地图7 网站地图8 网站地图9 网站地图10 网站地图11 网站地图12 网站地图13 网站地图14 网站地图15 网站地图16 网站地图17 网站地图18 网站地图19 网站地图20 网站地图21 网站地图22/a> 网站地图23