当前位置:网站首页>Pytorch学习笔记-Advanced_CNN(Using Inception_Module)实现Mnist数据集分类-(注释及结果)
Pytorch学习笔记-Advanced_CNN(Using Inception_Module)实现Mnist数据集分类-(注释及结果)
2022-07-25 15:28:00 【whut_L】
目录
程序代码
import torch
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim
batch_size = 64
transform = transforms.Compose([
transforms.ToTensor(), #将shape为(H, W, C)的img转为shape为(C, H, W)的tensor,将每一个数值归一化到[0,1]
transforms.Normalize((0.1307, ), (0.3081, )) #按通道进行数据标准化
])
train_dataset = datasets.MNIST(root = '../Pycharm/dataset/mnist/', train = True, download = True, transform = transform)
train_loader = DataLoader(train_dataset, shuffle = True, batch_size = batch_size)
test_dataset = datasets.MNIST(root = '../Pycharm/dataset/mnist/', train = False, download = True, transform = transform)
test_loader = DataLoader(test_dataset, shuffle = False, batch_size = batch_size)
class InceptionA(torch.nn.Module):
def __init__(self, in_channels):
super(InceptionA, self).__init__()
self.branch1x1 = torch.nn.Conv2d(in_channels, 16, kernel_size = 1) # 1x1卷积
self.branch5x5_1 = torch.nn.Conv2d(in_channels, 16, kernel_size = 1) # 先1x1卷积
self.branch5x5_2 = torch.nn.Conv2d(16, 24, kernel_size = 5, padding = 2) # 再5x5卷积 padding = 2是为了保证图像尺寸不变
self.branch3x3_1 = torch.nn.Conv2d(in_channels, 16, kernel_size = 1) # 先1x1卷积
self.branch3x3_2 = torch.nn.Conv2d(16, 24, kernel_size = 3, padding = 1) # 再3x3卷积
self.branch3x3_3 = torch.nn.Conv2d(24, 24, kernel_size = 3, padding = 1) # # 再3x3卷积 注意输入输出维度大小
self.branch_pool = torch.nn.Conv2d(in_channels, 24, kernel_size = 1) # 池化后再1x1卷积
def forward(self, x):
branch1x1 = self.branch1x1(x) # Module1
branch5x5 = self.branch5x5_1(x) # Module2
branch5x5 = self.branch5x5_2(branch5x5)
branch3x3 = self.branch3x3_1(x) # Module3
branch3x3 = self.branch3x3_2(branch3x3)
branch3x3 = self.branch3x3_3(branch3x3)
branch_pool = F.avg_pool2d(x, kernel_size = 3, stride = 1, padding = 1) # 平均池化
branch_pool = self.branch_pool(branch_pool) # Module4
outputs = [branch1x1, branch5x5, branch3x3, branch_pool]
return torch.cat(outputs, dim = 1)
class Net(torch.nn.Module):
def __init__(self): # 构造函数
super(Net, self).__init__()
self.conv1 = torch.nn.Conv2d(1, 10, kernel_size = 5) # 卷积层1
self.conv2 = torch.nn.Conv2d(88, 20, kernel_size = 5) # 卷积层2
self.incep1 = InceptionA(in_channels = 10)
self.incep2 = InceptionA(in_channels = 20)
self.mp = torch.nn.MaxPool2d(2)
self.fc = torch.nn.Linear(1408, 10)
def forward(self, x):
in_size = x.size(0)
x = F.relu(self.mp(self.conv1(x))) # 卷积、池化、激活函数
x = self.incep1(x)
x = F.relu(self.mp(self.conv2(x))) # 卷积、池化、激活函数
x = self.incep2(x)
x = x.view(in_size, -1) # reshape
x = self.fc(x) # 全连接层
return x
model = Net()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # 调用GPU或CPU
model.to(device)
criterion = torch.nn.CrossEntropyLoss() # 计算交叉熵损失
optimizer = optim.SGD(model.parameters(), lr = 0.01, momentum = 0.5) #构建优化器,lr为学习率,momentum为冲量因子
def train(epoch):
running_loss = 0.0
for batch_idx, data in enumerate(train_loader, 0): # 遍历函数,0表示从第0个元素开始,返回数据下标和数值
inputs, target = data #特征,标签
inputs, target = inputs.to(device), target.to(device)
optimizer.zero_grad() #梯度归零
# forward + backward + updata
outputs = model(inputs)
loss = criterion(outputs, target) #计算损失
loss.backward() #反向传播梯度值
optimizer.step() #更新参数
running_loss += loss.item() #得到元素张量的一个元素值,将张量转换成浮点数
if batch_idx % 300 == 299:
print('[%d, %5d] loss: %.3f' % (epoch + 1, batch_idx + 1, running_loss / 300))
running_loss = 0.0
def test():
correct = 0
total = 0
with torch.no_grad(): #数据不计算梯度
for data in test_loader:
images, labels = data
images, labels = images.to(device), labels.to(device)
outputs = model(images)
_, predicted = torch.max(outputs.data, dim = 1) #predicted为tensor每行最大值的索引
total += labels.size(0) # 总样本
correct += (predicted == labels).sum().item() #预测准确的样本数
print('Accuracy on test set: %d %%' % (100 * correct / total)) #准确率
def main():
for epoch in range(10):
train(epoch)
test()
main()
Inception框图

执行结果

边栏推荐
- How to finally generate a file from saveastextfile in spark
- ML - 自然语言处理 - 关键技术
- Take you to create your first C program (recommended Collection)
- Tasks, micro tasks, queues and scheduling (animation shows each step of the call)
- 盒子躲避鼠标
- 本地缓存--Ehcache
- Week303 of leetcode
- 带你创建你的第一个C#程序(建议收藏)
- Image cropper example
- The number of query results of maxcompute SQL is limited to 1W
猜你喜欢

获取键盘按下的键位对应ask码

ML - 自然语言处理 - 自然语言处理简介

Box avoiding mouse

window系统黑窗口redis报错20Creating Server TCP listening socket *:6379: listen: Unknown error19-07-28

Solve the timeout of dbeaver SQL client connection Phoenix query

matlab 如何保存所有运行后的数据

ML - natural language processing - Key Technologies

Delayed loading source code analysis:

How to solve the problem of scanf compilation error in Visual Studio

ML - 自然语言处理 - 基础知识
随机推荐
CF888G-巧妙字典树+暴力分治(异或最小生成树)
CGO is realy Cool!
ML - natural language processing - Introduction to natural language processing
ML - natural language processing - Key Technologies
C language function review (pass value and address [binary search], recursion [factorial, Hanoi Tower, etc.))
ML - 自然语言处理 - 基础知识
Week303 of leetcode
ML - natural language processing - Basics
Simulate setinterval timer with setTimeout
图论及概念
《图书馆管理系统——“借书还书”模块》项目研发阶段性总结
Implementation of asynchronous FIFO
JVM parameter configuration details
记一次Yarn Required executor memeory is above the max threshold(8192MB) of this cluster!
Remember that spark foreachpartition once led to oom
Get the ask code corresponding to the key pressed by the keyboard
Spark提交参数--files的使用
Singleton mode 3-- singleton mode
Spark AQE
ML - 语音 - 高级语音模型