pytorch TV loss代码分析

阅读: 评论:0

pytorch TV loss代码分析

pytorch TV loss代码分析

代码来自:

版权声明:本文为CSDN博主「Kindle君」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
原文链接:

本文在代码中加入了部分注释

TV loss在图像上就是求每一个像素和横向下一个像素的差的平方,加上纵向下一个像素的差的平方。然后开β/2次根。

import torch
 as nn
from torch.autograd import Variableclass TVLoss(nn.Module):def __init__(self,TVLoss_weight=1):super(TVLoss,self).__init__()self.TVLoss_weight = TVLoss_weightdef forward(self,x):batch_size = x.size()[0]h_x = x.size()[2]w_x = x.size()[3]count_h = self._tensor_size(x[:,:,1:,:])  #算出总共求了多少次差count_w = self._tensor_size(x[:,:,:,1:])h_tv = torch.pow((x[:,:,1:,:]-x[:,:,:h_x-1,:]),2).sum()  # x[:,:,1:,:]-x[:,:,:h_x-1,:]就是对原图进行错位,分成两张像素位置差1的图片,第一张图片# 从像素点1开始(原图从0开始),到最后一个像素点,第二张图片从像素点0开始,到倒数第二个            # 像素点,这样就实现了对原图进行错位,分成两张图的操作,做差之后就是原图中每个像素点与相# 邻的下一个像素点的差。w_tv = torch.pow((x[:,:,:,1:]-x[:,:,:,:w_x-1]),2).sum()return self.TVLoss_weight*2*(h_tv/count_h+w_tv/count_w)/batch_sizedef _tensor_size(self,t):return t.size()[1]*t.size()[2]*t.size()[3]def main():# x = Variable(torch.FloatTensor([[[1,2],[2,3]],[[1,2],[2,3]]]).view(1,2,2,2), requires_grad=True)# x = Variable(torch.FloatTensor([[[3,1],[4,3]],[[3,1],[4,3]]]).view(1,2,2,2), requires_grad=True)# x = Variable(torch.FloatTensor([[[1,1,1], [2,2,2],[3,3,3]],[[1,1,1], [2,2,2],[3,3,3]]]).view(1, 2, 3, 3), requires_grad=True)x = Variable(torch.FloatTensor([[[1, 2, 3], [2, 3, 4], [3, 4, 5]], [[1, 2, 3], [2, 3, 4], [3, 4, 5]]]).view(1, 2, 3, 3),requires_grad=True)addition = TVLoss()z = addition(x)print xprint z.dataz.backward()adif __name__ == '__main__':

本文发布于:2024-01-29 13:19:18,感谢您对本站的认可!

本文链接:https://www.4u4v.net/it/170650556115555.html

版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。

标签:代码   pytorch   TV   loss
留言与评论(共有 0 条评论)
   
验证码:

Copyright ©2019-2022 Comsenz Inc.Powered by ©

网站地图1 网站地图2 网站地图3 网站地图4 网站地图5 网站地图6 网站地图7 网站地图8 网站地图9 网站地图10 网站地图11 网站地图12 网站地图13 网站地图14 网站地图15 网站地图16 网站地图17 网站地图18 网站地图19 网站地图20 网站地图21 网站地图22/a> 网站地图23