本篇笔记主要对应于莫凡Pytorch中的3.1节。主要讲了如何使用Pytorch搭建一个回归模型的神经网络。
在Pytorch中自定义一个神经网络时,我们需要继承Module来书写自己的神经网络。在继承该类时,必须重新实现__init__构造函数和forward这两个方法。这里有一些注意点:
一般把网络中具有可学习参数的层(如全连接层、卷积层等)放在构造函数__init__()中,当然我也可以吧不具有参数的层也放在里面;
一般把不具有可学习参数的层(如ReLU、dropout、BatchNormanation层)可放在构造函数中,也可不放在构造函数中,如果不放在构造函数__init__里面,则在forward方法里面可以使用nn.functional来代替;
forward方法是必须要重写的,它是实现模型的功能,实现各个层之间的连接关系的核心。
接下来我们来自己搭建一个回归模型的神经网络。
这里生成一组 y = x 2 y=x^2 y=x2的数据,并加入一些随机噪声。
import torch
functional as F
import matplotlib.pyplot as plt
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1) # x data (tensor), shape=(100, 1)
y = x.pow(2) + 0.2*torch.rand(x.size()) # noisy y data (tensor), shape=(100, 1)
plt.scatter(x.data.numpy(), y.data.numpy())
plt.show()
我们自定义一个类来完成回归操作
class Module):def __init__(self, n_feature, n_hidden, n_output):# 分别表示feature个数、隐藏层神经元数个数、输出值数目super(Net, self).__init__()self.hidden = Linear(n_feature, n_hidden)self.predict = Linear(n_hidden, n_output)def forward(self, x):# x 是输入数据x = F.relu(self.hidden(x))y = self.predict(x)return y
这是一个两层的神经网络,其包含一个隐藏层即self.hidden,之后便连接一个输出层self.predict。在前向传播时,网络对隐层的输出进行了Relu操作。
网络搭建完成后,我们可以打印输出一下这个网络的基本结构
net = Net(1, 10, 1)
print(net)
得到输出如下
Net((hidden): Linear(in_features=1, out_features=10, bias=True)(predict): Linear(in_features=10, out_features=1, bias=True)
)
接下来我们设置网络的优化器和损失函数。
优化方法设置为随机梯度下降法,学习率设置为0.5。
一般回归问题使用最小均方误差作为损失函数。
optimizer = torch.optim.SGD(net.parameters(), lr=0.5)
loss_func = MSELoss() # 回归问题采用MSE
最后,我们展示输出并可视化中间过程。
plt.ion()
for step in range(100):prediction = net(x)loss = loss_func(prediction, _grad() # 首先将所有参数的梯度降为0(因为每次计算梯度后这个值都会保留,不清零就会导致不正确)loss.backward() # 进行反向传递,计算出计算图中所有节点的梯度optimizer.step() # 计算完成后,使用optimizer优化这些梯度if step % 20 == 0:# plot and show learning processplt.cla()plt.scatter(x.data.numpy(), y.data.numpy())plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=(0.5, 0, 'Loss=%.4f' % loss.data)plt.savefig("./img/02_"+str(step)+".png")plt.pause(0.1)plt.ioff()
plt.show()
可以看到随着训练的进行,loss逐渐降低,模型拟合的效果越来越好。
本文发布于:2024-02-03 03:51:24,感谢您对本站的认可!
本文链接:https://www.4u4v.net/it/170690348348463.html
版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。
留言与评论(共有 0 条评论) |