莫凡Pytorch学习笔记(五)

阅读: 评论:0

莫凡Pytorch学习笔记(五)

莫凡Pytorch学习笔记(五)

Pytorch模型保存与提取

本篇笔记主要对应于莫凡Pytorch中的3.4节。主要讲了如何使用Pytorch保存和提取我们的神经网络。

我们将通过两种方式展示模型的保存和提取。
第一种保存方式是保存整个模型,在重新提取时直接加载整个模型。第二种保存方法是只保存模型的参数,这种方式只保存了参数,而不会保存模型的结构等信息。

两种方式各有优缺点。保存完整模型不需要知道网络的结构,一次性保存一次性读入。缺点是模型比较大时耗时较长,保存的文件也大。而只保存参数的方式存储快捷,保存的文件也小一些,但缺点是丢失了网络的结构信息,恢复模型时需要提前建立一个特定结构的网络再读入参数。

以下使用代码展示。

数据生成与展示

import torch
functional as F
import matplotlib.pyplot as plt

这里还是生成一组带有噪声的 y = x 2 y=x^{2} y=x2数据进行回归拟合。

# torch.manual_seed(1)    # reproducible# fake data
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)  # x data (tensor), shape=(100, 1)
y = x.pow(2) + 0.2*torch.rand(x.size())  # noisy y data (tensor), shape=(100, 1)

基本网络搭建与保存

我们使用nn.Sequential模块来快速搭建一个网络完成回归操作。这里使用两种方式进行保存。

def save():# save net1net1 = Linear(1, 10),ReLU(),Linear(10, 1))optimizer = torch.optim.SGD(net1.parameters(), lr=0.5)loss_func = MSELoss()for step in range(100):prediction = net1(x)loss = loss_func(prediction, _grad()loss.backward()optimizer.step()# plot resultplt.figure(1, figsize=(10, 3))plt.subplot(131)plt.title('Net1')plt.scatter(x.data.numpy(), y.data.numpy())plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)plt.savefig("./img/05_save.png")torch.save(net1, 'net.pkl')                        # entire networktorch.save(net1.state_dict(), 'net_params.pkl')    # parameters

在这个save函数中,我们首先使用nn.Sequential模块构建了一个基础的二层神经网络。然后对其进行训练。展示训练结果。之后使用两种方式进行保存。

第一种方式直接保存整个网络,代码为

torch.save(net1, 'net.pkl')                        # entire network

第二种方式只保存网络参数,代码为

torch.save(net1.state_dict(), 'net_params.pkl')    # parameters

对保存的模型进行提取恢复

这里我们为两种不同存储方式保存的模型分别定义恢复提取的函数
首先是对整个网络的提取。直接使用torch.load就可以。

def restore_net():# 提取神经网络net2 = torch.load('net.pkl')prediction = net2(x)# plot resultplt.subplot(132)plt.title('Net2')plt.scatter(x.data.numpy(), y.data.numpy())plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)plt.savefig("./img/05_res_net.png")

而对于参数的读取,我们首先需要先搭建好一个与之前保存的模型相同架构的网络。然后使用这个网络的load_state_dict方法进行参数读取和恢复。

def restore_params():# 提取神经网络net3 = Linear(1, 10),ReLU(),Linear(10, 1))net3.load_state_dict(torch.load('net_params.pkl'))prediction = net3(x)# plot resultplt.subplot(133)plt.title('Net3')plt.scatter(x.data.numpy(), y.data.numpy())plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)plt.savefig("./img/05_res_para.png")plt.show()

对比不同提取方法的效果

接下来我们对比一下这两种方法的提取效果

# save net1
save()# restore entire net (may slow)
restore_net()# restore only the net parameters
restore_params()

最后,得到的展示输出如下:

这里Net1即我们训练好的网络,我们使用两种方式保存了Net1。使用第一种方式存储和提取的结果为Net2,使用第二种方式存储和提取的结果为Net3。通过对比可以看出,这三个网络一模一样,证明不同的存储提取方式的效果是相同的,不会有差异。

参考

  1. 莫凡Python:Pytorch动态神经网络,/

本文发布于:2024-02-03 03:52:27,感谢您对本站的认可!

本文链接:https://www.4u4v.net/it/170690354748468.html

版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。

标签:学习笔记   Pytorch
留言与评论(共有 0 条评论)
   
验证码:

Copyright ©2019-2022 Comsenz Inc.Powered by ©

网站地图1 网站地图2 网站地图3 网站地图4 网站地图5 网站地图6 网站地图7 网站地图8 网站地图9 网站地图10 网站地图11 网站地图12 网站地图13 网站地图14 网站地图15 网站地图16 网站地图17 网站地图18 网站地图19 网站地图20 网站地图21 网站地图22/a> 网站地图23