PyTorch Study Notes: Advanced_CNN (Using Inception_Module) for MNIST Classification (with Comments and Results)
2022-07-25 15:28:00 【whut_L】
Program Code
import torch
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim
batch_size = 64
transform = transforms.Compose([
    transforms.ToTensor(),                         # convert a (H, W, C) image into a (C, H, W) tensor and scale every value to [0, 1]
    transforms.Normalize((0.1307, ), (0.3081, ))   # per-channel standardization
])
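# Added note (not in the original post): 0.1307 and 0.3081 are the commonly used
# global mean and standard deviation of the MNIST training pixels, so Normalize
# standardizes each image to roughly zero mean and unit variance.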
train_dataset = datasets.MNIST(root = '../Pycharm/dataset/mnist/', train = True, download = True, transform = transform)
train_loader = DataLoader(train_dataset, shuffle = True, batch_size = batch_size)
test_dataset = datasets.MNIST(root = '../Pycharm/dataset/mnist/', train = False, download = True, transform = transform)
test_loader = DataLoader(test_dataset, shuffle = False, batch_size = batch_size)
class InceptionA(torch.nn.Module):
    def __init__(self, in_channels):
        super(InceptionA, self).__init__()
        self.branch1x1 = torch.nn.Conv2d(in_channels, 16, kernel_size = 1)        # 1x1 convolution
        self.branch5x5_1 = torch.nn.Conv2d(in_channels, 16, kernel_size = 1)      # 1x1 convolution first
        self.branch5x5_2 = torch.nn.Conv2d(16, 24, kernel_size = 5, padding = 2)  # then 5x5 convolution; padding = 2 keeps the spatial size unchanged
        self.branch3x3_1 = torch.nn.Conv2d(in_channels, 16, kernel_size = 1)      # 1x1 convolution first
        self.branch3x3_2 = torch.nn.Conv2d(16, 24, kernel_size = 3, padding = 1)  # then 3x3 convolution
        self.branch3x3_3 = torch.nn.Conv2d(24, 24, kernel_size = 3, padding = 1)  # another 3x3 convolution; note the input/output channel sizes
        self.branch_pool = torch.nn.Conv2d(in_channels, 24, kernel_size = 1)      # 1x1 convolution applied after pooling

    def forward(self, x):
        branch1x1 = self.branch1x1(x)                                             # branch 1
        branch5x5 = self.branch5x5_1(x)                                           # branch 2
        branch5x5 = self.branch5x5_2(branch5x5)
        branch3x3 = self.branch3x3_1(x)                                           # branch 3
        branch3x3 = self.branch3x3_2(branch3x3)
        branch3x3 = self.branch3x3_3(branch3x3)
        branch_pool = F.avg_pool2d(x, kernel_size = 3, stride = 1, padding = 1)   # average pooling
        branch_pool = self.branch_pool(branch_pool)                               # branch 4
        outputs = [branch1x1, branch5x5, branch3x3, branch_pool]
        return torch.cat(outputs, dim = 1)                                        # concatenate along the channel dimension
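# Added note (not in the original post): every branch preserves the spatial size,
# and the concatenation yields 16 + 24 + 24 + 24 = 88 output channels. This is why
# conv2 in Net below expects 88 input channels. A quick sanity check:
#   InceptionA(10)(torch.randn(1, 10, 12, 12)).shape  ->  torch.Size([1, 88, 12, 12])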
class Net(torch.nn.Module):
    def __init__(self):                                        # constructor
        super(Net, self).__init__()
        self.conv1 = torch.nn.Conv2d(1, 10, kernel_size = 5)   # convolution layer 1
        self.conv2 = torch.nn.Conv2d(88, 20, kernel_size = 5)  # convolution layer 2
        self.incep1 = InceptionA(in_channels = 10)
        self.incep2 = InceptionA(in_channels = 20)
        self.mp = torch.nn.MaxPool2d(2)
        self.fc = torch.nn.Linear(1408, 10)

    def forward(self, x):
        in_size = x.size(0)
        x = F.relu(self.mp(self.conv1(x)))   # convolution, pooling, activation
        x = self.incep1(x)
        x = F.relu(self.mp(self.conv2(x)))   # convolution, pooling, activation
        x = self.incep2(x)
        x = x.view(in_size, -1)              # reshape (flatten)
        x = self.fc(x)                       # fully connected layer
        return x
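# Added shape trace (not in the original post), for a 1x28x28 MNIST image:
#   conv1 (k=5): 1x28x28  -> 10x24x24,  maxpool -> 10x12x12
#   incep1     : 10x12x12 -> 88x12x12
#   conv2 (k=5): 88x12x12 -> 20x8x8,    maxpool -> 20x4x4
#   incep2     : 20x4x4   -> 88x4x4
#   flatten    : 88 * 4 * 4 = 1408, matching the in_features of self.fc.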
model = Net()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # use the GPU if available, otherwise the CPU
model.to(device)
criterion = torch.nn.CrossEntropyLoss()                                  # cross-entropy loss
optimizer = optim.SGD(model.parameters(), lr = 0.01, momentum = 0.5)     # SGD optimizer; lr is the learning rate, momentum the momentum factor
def train(epoch):
    running_loss = 0.0
    for batch_idx, data in enumerate(train_loader, 0):  # iterate over batches; 0 starts the index at 0, yielding (index, batch)
        inputs, target = data                            # features and labels
        inputs, target = inputs.to(device), target.to(device)
        optimizer.zero_grad()                            # reset gradients
        # forward + backward + update
        outputs = model(inputs)
        loss = criterion(outputs, target)                # compute the loss
        loss.backward()                                  # back-propagate the gradients
        optimizer.step()                                 # update the parameters
        running_loss += loss.item()                      # .item() converts the scalar tensor to a Python float
        if batch_idx % 300 == 299:
            print('[%d, %5d] loss: %.3f' % (epoch + 1, batch_idx + 1, running_loss / 300))
            running_loss = 0.0
def test():
    correct = 0
    total = 0
    with torch.no_grad():                                    # no gradients are tracked during evaluation
        for data in test_loader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, dim = 1)  # predicted holds the index of the largest value in each row
            total += labels.size(0)                          # total number of samples
            correct += (predicted == labels).sum().item()    # number of correctly predicted samples
    print('Accuracy on test set: %d %%' % (100 * correct / total))  # accuracy
def main():
    for epoch in range(10):
        train(epoch)
        test()

main()
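The post ends with the training and evaluation loop above. As an optional extension (not part of the original code), the sketch below shows one way to persist the trained weights with torch.save and reload them later for inference; the file name mnist_inception.pth is an arbitrary choice for illustration.

# Optional extension (assumption: saving to 'mnist_inception.pth' in the working directory).
torch.save(model.state_dict(), 'mnist_inception.pth')

# Later, rebuild the network and load the saved weights for inference.
restored = Net()
restored.load_state_dict(torch.load('mnist_inception.pth', map_location = device))
restored.to(device)
restored.eval()  # switch to evaluation mode before predicting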
Inception Module Block Diagram

Execution Results
