【深度学习】实验16 使用CNN完成MNIST手写体识别(PyTorch)

阅读: 评论:0

【深度学习】实验16 使用CNN完成MNIST手写体识别(PyTorch)

【深度学习】实验16 使用CNN完成MNIST手写体识别(PyTorch)

文章目录

  • 使用CNN完成MNIST手写体识别(PyTorch)
    • 1. 导入PyTorch库
    • 2. 定义CNN类
    • 3. 下载数据集
    • 4. 训练模型
  • 附:系列文章

使用CNN完成MNIST手写体识别(PyTorch)

卷积神经网络(Convolutional Neural Network,简称CNN)是一种专门用于处理图像、语音、自然语言等数据的深度学习模型。CNN的特点是可以通过卷积运算提取出图像、语音等数据中的特征,从而实现对这些数据进行分类、识别等任务。

CNN的基本结构包括卷积层、池化层和全连接层。其中卷积层是CNN的核心部分,它可以通过卷积核(或滤波器)对输入数据进行卷积运算,从而提取出数据中的空间特征,如边缘、角等。卷积层的输出经过池化层的降采样处理,可以减少参数数量,提高模型的泛化能力。全连接层则将池化层输出的特征向量连接起来,通过权重矩阵进行分类、识别等任务。

CNN的训练过程通常采用反向传播算法来更新网络中的权重参数。反向传播算法可以根据损失函数的导数来逐层计算各层的误差,从而调整各层的权重参数,使得模型对训练数据的拟合效果更好。

CNN在图像识别、目标检测、人脸识别等领域都有广泛应用。其中经典的卷积神经网络模型包括LeNet、AlexNet、VGG、GoogLeNet和ResNet等。这些模型在不同的任务中都取得了很好的效果,为深度学习领域的发展做出了重要贡献。

总的来说,卷积神经网络是一种能够有效处理图像、语音等数据的深度学习模型,在计算机视觉、语音识别等领域具有广泛的应用前景。

1. 导入PyTorch库

import torch
import numpy as np
from torch import nn
from torch.utils.data import DataLoader
from torch.autograd import Variable
from torchvision.datasets import mnist
from torchvision import transforms
from torch import optim

2. 定义CNN类

# 定义CNN
class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.layer1 = nn.Sequential(nn.Conv2d(1, 16, kernel_size=3),nn.BatchNorm2d(16),nn.ReLU(inplace=True))self.layer2 = nn.Sequential(nn.Conv2d(16, 32, kernel_size=3),nn.BatchNorm2d(32),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2))self.layer3 = nn.Sequential(nn.Conv2d(32, 64, kernel_size=3),nn.BatchNorm2d(64),nn.ReLU(inplace=True))self.layer4 = nn.Sequential(nn.Conv2d(64, 128, kernel_size=3),nn.BatchNorm2d(128),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2,stride=2))self.fc = nn.Sequential(nn.Linear(128 * 4 * 4, 1024),nn.ReLU(inplace=True),nn.Linear(1024, 128),nn.ReLU(inplace=True),nn.Linear(128, 10))def forward(self, x):x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = self.layer4(x)x = x.view(x.size(0), -1)x = self.fc(x)return x
# 数据集转换
data_tf = transforms.Compose([transforms.ToTensor(),transforms.Normalize([0.5], [0.5])]
)

3. 下载数据集

# 使用内置函数下载mnist数据集
train_set = mnist.MNIST('./data', train=True, transform=data_tf, download=True)
test_set = mnist.MNIST('./data', train=False, transform=data_tf, download=True)
train_set, test_set
   (Dataset MNISTNumber of datapoints: 60000Root location: ./dataSplit: TrainStandardTransformTransform: Compose(ToTensor()Normalize(mean=[0.5], std=[0.5])), Dataset MNISTNumber of datapoints: 10000Root location: ./dataSplit: TestStandardTransformTransform: Compose(ToTensor()Normalize(mean=[0.5], std=[0.5])))
# 划分训练集与测试集
train_data = DataLoader(train_set, batch_size=100, shuffle=True)
test_data = DataLoader(test_set, batch_size=100, shuffle=False)
train_data, test_data
   (<torch.utils.data.dataloader.DataLoader at 0x7f43c81d6eb8>,<torch.utils.data.dataloader.DataLoader at 0x7f43c81d6e10>)

4. 训练模型

# 调用卷积神经网络
net = CNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), 1e-1)
# 开始训练
losses = []
acces = []
eval_losses = []
eval_acces = []nums_epoch = 1print("开始训练......")
for epoch in range(nums_epoch):print("Test:" + str(epoch))train_loss = 0train_acc = 0net = ain()i = 0for img, label in train_data:i +=1print("第" + str(i) + "批训练")img = Variable(img)label =Variable(label)# 前向传播out = net(img)loss = criterion(out, label)# 反向传播_grad()loss.backward()optimizer.step()# 记录误差train_loss += loss.item()# 计算分类的准确率_, pred = out.max(1)num_correct = (pred ==label).sum().item()acc = num_correct / img.shape[0]# 记录准确率train_acc += acclosses.append(train_loss / len(train_data))acces.append(train_acc / len(train_data))eval_loss = 0eval_acc = 0# 测试集不训练for img, label in test_data:img = Variable(img)label = Variable(label)# 前向传播out = net(img)loss = criterion(out, label)# 记录误差eval_loss += loss.item()# 计算分类的准确率_, pred = out.max(1)num_correct = (pred == label).sum().item()acc = num_correct / img.shape[0]# 记录准确率eval_acc += acceval_losses.append(eval_loss / len(test_data))eval_acces.append(eval_acc / len(test_data))print('Epoch {}: nTrain Loss: {} Train Accuracy: {} nTest Loss: {} Test Accuarcy: {}'.format(epoch + 1, train_loss / len(train_data), train_acc / len(train_data),eval_loss / len(test_data), eval_acc / len(test_data)))
   开始训练......Test:0第1批训练第2批训练第3批训练第4批训练第5批训练第6批训练第7批训练第8批训练第9批训练第10批训练第11批训练第12批训练第13批训练第14批训练第15批训练第16批训练第17批训练第18批训练第19批训练第20批训练第21批训练第22批训练第23批训练第24批训练第25批训练第26批训练第27批训练第28批训练第29批训练第30批训练第31批训练第32批训练第33批训练第34批训练第35批训练第36批训练第37批训练第38批训练第39批训练第40批训练第41批训练第42批训练第43批训练第44批训练第45批训练第46批训练第47批训练第48批训练第49批训练第50批训练第51批训练第52批训练第53批训练第54批训练第55批训练第56批训练第57批训练第58批训练第59批训练第60批训练第61批训练第62批训练第63批训练第64批训练第65批训练第66批训练第67批训练第68批训练第69批训练第70批训练第71批训练第72批训练第73批训练第74批训练第75批训练第76批训练第77批训练第78批训练第79批训练第80批训练第81批训练第82批训练第83批训练第84批训练第85批训练第86批训练第87批训练第88批训练第89批训练第90批训练第91批训练第92批训练第93批训练第94批训练第95批训练第96批训练第97批训练第98批训练第99批训练第100批训练第101批训练第102批训练第103批训练第104批训练第105批训练第106批训练第107批训练第108批训练第109批训练第110批训练第111批训练第112批训练第113批训练第114批训练第115批训练第116批训练第117批训练第118批训练第119批训练第120批训练第121批训练第122批训练第123批训练第124批训练第125批训练第126批训练第127批训练第128批训练第129批训练第130批训练第131批训练第132批训练第133批训练第134批训练第135批训练第136批训练第137批训练第138批训练第139批训练第140批训练第141批训练第142批训练第143批训练第144批训练第145批训练第146批训练第147批训练第148批训练第149批训练第150批训练第151批训练第152批训练第153批训练第154批训练第155批训练第156批训练第157批训练第158批训练第159批训练第160批训练第161批训练第162批训练第163批训练第164批训练第165批训练第166批训练第167批训练第168批训练第169批训练第170批训练第171批训练第172批训练第173批训练第174批训练第175批训练第176批训练第177批训练第178批训练第179批训练第180批训练第181批训练第182批训练第183批训练第184批训练第185批训练第186批训练第187批训练第188批训练第189批训练第190批训练第191批训练第192批训练第193批训练第194批训练第195批训练第196批训练第197批训练第198批训练第199批训练第200批训练第201批训练第202批训练第203批训练第204批训练第205批训练第206批训练第207批训练第208批训练第209批训练第210批训练第211批训练第212批训练第213批训练第214批训练第215批训练第216批训练第217批训练第218批训练第219批训练第220批训练第221批训练第222批训练第223批训练第224批训练第225批训练第226批训练第227批训练第228批训练第229批训练第230批训练第231批训练第232批训练第233批训练第234批训练第235批训练第236批训练第237批训练第238批训练第239批训练第240批训练第241批训练第242批训练第243批训练第244批训练第245批训练第246批训练第247批训练第248批训练第249批训练第250批训练第251批训练第252批训练第253批训练第254批训练第255批训练第256批训练第257批训练第258批训练第259批训练第260批训练第261批训练第262批训练第263批训练第264批训练第265批训练第266批训练第267批训练第268批训练第269批训练第270批训练第271批训练第272批训练第273批训练第274批训练第275批训练第276批训练第277批训练第278批训练第279批训练第280批训练第281批训练第282批训练第283批训练第284批训练第285批训练第286批训练第287批训练第288批训练第289批训练第290批训练第291批训练第292批训练第293批训练第294批训练第295批训练第296批训练第297批训练第298批训练第299批训练第300批训练第301批训练第302批训练第303批训练第304批训练第305批训练第306批训练第307批训练第308批训练第309批训练第310批训练第311批训练第312批训练第313批训练第314批训练第315批训练第316批训练第317批训练第318批训练第319批训练第320批训练第321批训练第322批训练第323批训练第324批训练第325批训练第326批训练第327批训练第328批训练第329批训练第330批训练第331批训练第332批训练第333批训练第334批训练第335批训练第336批训练第337批训练第338批训练第339批训练第340批训练第341批训练第342批训练第343批训练第344批训练第345批训练第346批训练第347批训练第348批训练第349批训练第350批训练第351批训练第352批训练第353批训练第354批训练第355批训练第356批训练第357批训练第358批训练第359批训练第360批训练第361批训练第362批训练第363批训练第364批训练第365批训练第366批训练第367批训练第368批训练第369批训练第370批训练第371批训练第372批训练第373批训练第374批训练第375批训练第376批训练第377批训练第378批训练第379批训练第380批训练第381批训练第382批训练第383批训练第384批训练第385批训练第386批训练第387批训练第388批训练第389批训练第390批训练第391批训练第392批训练第393批训练第394批训练第395批训练第396批训练第397批训练第398批训练第399批训练第400批训练第401批训练第402批训练第403批训练第404批训练第405批训练第406批训练第407批训练第408批训练第409批训练第410批训练第411批训练第412批训练第413批训练第414批训练第415批训练第416批训练第417批训练第418批训练第419批训练第420批训练第421批训练第422批训练第423批训练第424批训练第425批训练第426批训练第427批训练第428批训练第429批训练第430批训练第431批训练第432批训练第433批训练第434批训练第435批训练第436批训练第437批训练第438批训练第439批训练第440批训练第441批训练第442批训练第443批训练第444批训练第445批训练第446批训练第447批训练第448批训练第449批训练第450批训练第451批训练第452批训练第453批训练第454批训练第455批训练第456批训练第457批训练第458批训练第459批训练第460批训练第461批训练第462批训练第463批训练第464批训练第465批训练第466批训练第467批训练第468批训练第469批训练第470批训练第471批训练第472批训练第473批训练第474批训练第475批训练第476批训练第477批训练第478批训练第479批训练第480批训练第481批训练第482批训练第483批训练第484批训练第485批训练第486批训练第487批训练第488批训练第489批训练第490批训练第491批训练第492批训练第493批训练第494批训练第495批训练第496批训练第497批训练第498批训练第499批训练第500批训练第501批训练第502批训练第503批训练第504批训练第505批训练第506批训练第507批训练第508批训练第509批训练第510批训练第511批训练第512批训练第513批训练第514批训练第515批训练第516批训练第517批训练第518批训练第519批训练第520批训练第521批训练第522批训练第523批训练第524批训练第525批训练第526批训练第527批训练第528批训练第529批训练第530批训练第531批训练第532批训练第533批训练第534批训练第535批训练第536批训练第537批训练第538批训练第539批训练第540批训练第541批训练第542批训练第543批训练第544批训练第545批训练第546批训练第547批训练第548批训练第549批训练第550批训练第551批训练第552批训练第553批训练第554批训练第555批训练第556批训练第557批训练第558批训练第559批训练第560批训练第561批训练第562批训练第563批训练第564批训练第565批训练第566批训练第567批训练第568批训练第569批训练第570批训练第571批训练第572批训练第573批训练第574批训练第575批训练第576批训练第577批训练第578批训练第579批训练第580批训练第581批训练第582批训练第583批训练第584批训练第585批训练第586批训练第587批训练第588批训练第589批训练第590批训练第591批训练第592批训练第593批训练第594批训练第595批训练第596批训练第597批训练第598批训练第599批训练第600批训练Epoch 1: Train Loss: 0.14750646080588922 Train Accuracy: 0.9542000000000053 Test Loss: 0.04495963536784984 Test Accuarcy: 0.9845999999999998

附:系列文章

序号文章目录直达链接
1波士顿房价预测
2鸢尾花数据集分析
3特征处理
4交叉验证
5构造神经网络示例
6使用TensorFlow完成线性回归
7使用TensorFlow完成逻辑回归
8TensorBoard案例
9使用Keras完成线性回归
10使用Keras完成逻辑回归
11使用Keras预训练模型完成猫狗识别
12使用PyTorch训练模型
13使用Dropout抑制过拟合
14使用CNN完成MNIST手写体识别(TensorFlow)
15使用CNN完成MNIST手写体识别(Keras)
16使用CNN完成MNIST手写体识别(PyTorch)
17使用GAN生成手写数字样本
18自然语言处理

本文发布于:2024-01-28 17:58:50,感谢您对本站的认可!

本文链接:https://www.4u4v.net/it/17064359359227.html

版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。

标签:手写体   深度   CNN   MNIST   PyTorch
留言与评论(共有 0 条评论)
   
验证码:

Copyright ©2019-2022 Comsenz Inc.Powered by ©

网站地图1 网站地图2 网站地图3 网站地图4 网站地图5 网站地图6 网站地图7 网站地图8 网站地图9 网站地图10 网站地图11 网站地图12 网站地图13 网站地图14 网站地图15 网站地图16 网站地图17 网站地图18 网站地图19 网站地图20 网站地图21 网站地图22/a> 网站地图23