目录
数据集下载(pytorch预设数据集)
3个超参数,批数量,训练世代,学习率
根据超参数的多寡,可以使用一个config类来管理
定义模型,损失函数,优化器
训练前的准备,模型.train(),初始化正确率参数,初始化数据
前向传播,误差后向传导,优化器迭代,Loss显示
计算准确率
缩小学习率
# 图像预处理,因为VGG是使用224 * 224大小的图片,但是 CIFAR10 只有32 * 32, 为了能快点跑出结果,
# 我们将它们放大到96*96。
transform = transforms.Compose([transforms.Resize(96), # 缩放到 96 * 96 大小transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # 归一化
])# 是否下载数据集
DOWNLOAD = True# 下载 CIFAR10 dataset
train_dataset = torchvision.datasets.CIFAR10(root='./data/', train=True, transform=transform, download=DOWNLOAD)
test_dataset = torchvision.datasets.CIFAR10(root='./data/', train=False, transform=transform)# dataloadertrain_loader = DataLoader(dataset=train_dataset,batch_size=BATCH_SIZE,shuffle=True)
test_loader = DataLoader(dataset=test_dataset,batch_size=BATCH_SIZE,shuffle=False)
# 超参数
DOWNLOAD = True #数据集下载,这个根据自己的需求
BATCH_SIZE = 256
EPOCH = 5
learning_rate = 0.001# 是否使用GPU
use_gpu = True
# 超参数类,用于控制各种超参数
class Config(object):def __init__(self):self.lr = 0.005self.batch_size = 256self.use_gpu = torch.cuda.is_available()self.DOWNLOAD = Trueself.epoch_num = 5 # 因为只是demo,就跑了2个epoch,可以自己多加几次试试结果self.class_num = 10 # CIFAR10 共有10类
config = Config()
# 定义模型
alex = AlextNet(3, 10)
if use_gpu:alex = alex.cuda()# loss and optimizerloss_fn = nn.CrossEntropyLoss()optimizer = torch.optim.Adam(alex.parameters(), lr=learning_rate)
# Training
ain()for epoch in range(EPOCH):total = 0correct = 0for i, (images, labels) in enumerate(train_loader):images = Variable(images)labels = Variable(labels)if use_gpu:images = images.cuda()labels = labels.cuda()
# forward + backward + _grad()y_pred = alex(images)loss = loss_fn(y_pred, labels)loss.backward()optimizer.step()if (i + 1) % 100 == 0:print("Epoch [%d/%d], Iter [%d/%d] Loss: %.4f" % (epoch + 1, EPOCH, i + 1, 100, loss.data.iterm))
# 计算训练精确度_, predicted = torch.max(y_pred.data, 1)total += labels.size(0)correct += (predicted == labels.data).sum()print('Accuracy of the model on the train images: %d %%' % (100 * correct / total))
# Decaying Learning Rateif (epoch+1) % 2 == 0:learning_rate /= 3optimizer = torch.optim.Adam(alex.parameters(), lr=learning_rate)
本文发布于:2024-01-28 08:00:39,感谢您对本站的认可!
本文链接:https://www.4u4v.net/it/17064000425962.html
版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。
留言与评论(共有 0 条评论) |