咳咳,先讲点废话,释放一下喜忧参半的心情,快速读取核心内容,请看目录!
如果你去查列表的相关信息,相信大量文字会显示——“列表是序列的一种,属最常用的Python数据类型,它支持字符,数字,字符串甚至可以包含列表(嵌套),他的数据项不需要具有相同的类型······”
所以,当你在for循环中得到一个一个的tensor,而且这些tensor对你来说都有用,该怎么把它存起来?
这个问题困扰了我很久,列表虽然好用,但是根据网上对列表的介绍。我真的不知道列表里面能不能存储tensor,直到修改不同batchsize的时候······各种报错,我要吐了😢
神经网络中,大多输入都是input = torch.randn([16, 3, 224, 224])
,对batchsize有一定的要求。然而总有一些时候你想要的是torch.randn([3, 224, 224])
,而且,16个
batch的内容,你都想要,那就试试用列表把它们 “装起来” 吧。
各位看官请瞅瞅代码!
input = torch.randn([16, 3, 224, 224])bs, nc, h, w = input.shape
a = [] # 用了存储每个batch里面的张量
for i in range(0, bs):img = input[i] # 得到torch.Size([3, 224, 224])# 此处是你的各种操作# result = result.unsqueeze(0).unsqueeze(0),按需升维# 例如result => torch.Size([1, 1, 224, 224])a.append(result)# 此时,想要验证列表存储张量的可行性,就可以如下
print(a)
print(type(a)) # <class 'list'>
print(type(a[0])) # <class 'torch.Tensor'>
注意:使用unsqueeze()、squeeze()方法的时候,一定要注意,你要升或者降的那个维度数值为1。例如:input = torch.randn([16, 3, 224, 224]),想要得到torch.Size([3, 224, 224]),使用input.unsqueeze(0)是行不通的!
如果成功的取出16个torch.Size([3, 224, 224])
,并且对其一一操作后,如何再恢复成torch.Size([16, 3, 224, 224])
?这里我用的是at()方法,参考1+2+3+···+100的for循环算法。
不懂就查系列 ☞ torch.stack, torch.cat, torch.stack.max/mean/sum维度变换详解
# 接上面那一段代码,a还是存储张量的列表在这里插入代码片out = at((a[0], a[1]), dim=0)
for j in range(2, len(a)):out = at((out,a[j]),dim=0)
print(out.shape)
return out
顺便,给大家安利一下截图软件!!真非常非常非常好用,可以把截屏的图片贴在界面上。
本文发布于:2024-01-29 04:49:07,感谢您对本站的认可!
本文链接:https://www.4u4v.net/it/170647495012806.html
版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。
留言与评论(共有 0 条评论) |