当前位置:网站首页>《PyTorch深度学习实践》第十课(卷积神经网络CNN)
《PyTorch深度学习实践》第十课(卷积神经网络CNN)
2022-08-05 05:40:00 【falldeep】
b站刘二视频,地址:
《PyTorch深度学习实践》完结合集_哔哩哔哩_bilibili
CNN模型
卷积运算

卷积核运算
黄色方块为fliter(卷积核n * 3* 3),要想输出通道数为m,需要m个卷积核


import torch
in_channel, out_channel = 5, 10 #输入通道数,输出通道数(图层数)
width, height = 100, 100 #输入一张图层的大小
kernel_size = 3 #卷积核的大小(3 * 3)
batch_size = 1
input = torch.randn(batch_size, in_channel, width, height)
# B N W H
conv_layer = torch.nn.Conv2d(in_channel, out_channel, kernel_size=kernel_size)
# N M (3 * 3)
output = conv_layer(input)
print(input.shape)
print(output.shape)
print(conv_layer.weight.shape)
#输出结果
# torch.Size([1, 5, 100, 100])
# batch大小 通道数 一个图层的大小
# torch.Size([1, 10, 98, 98])
# torch.Size([10, 5, 3, 3])
#10个卷积核 每个卷积核有5个通道 卷积核大小为3 * 3
padding
保持输出图像大小不变,进行零填充
stride
跳一格扫描

maxpooling最大池化层

网络整体


作业,手写MNIST识别
import torch
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
import matplotlib.pyplot as plt
BATCH_SIZE = 64
transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_set = datasets.MNIST(download=False, root='mnist', train=True, transform=transforms)
train_loader = DataLoader(dataset=train_set, batch_size=BATCH_SIZE, shuffle=True)
test_set = datasets.MNIST(download=False, root='mnist', train=False, transform=transforms)
test_loader = DataLoader(dataset=test_set, batch_size=BATCH_SIZE, shuffle=False)
class Net(torch.nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = torch.nn.Conv2d(1, 10, kernel_size=5)
self.conv2 = torch.nn.Conv2d(10, 20, kernel_size=5)
self.pooling = torch.nn.MaxPool2d(2)
self.fc = torch.nn.Linear(320, 10)
def forward(self, x):
batch_size = x.size(0)
x = F.relu(self.pooling(self.conv1(x)))
x = F.relu(self.pooling(self.conv2(x)))
x = x.view(batch_size, -1)
x = self.fc(x)
return x
modle = Net()
criteration = torch.nn.CrossEntropyLoss()
optimizor = torch.optim.SGD(modle.parameters(), lr=0.01, momentum=0.5)
def train():
sum = 0
for i, data in enumerate(train_loader, 0):
inputs, lables = data
y_pred = modle(inputs)
loss = criteration(y_pred, lables)
sum += loss
optimizor.zero_grad()
loss.backward()
optimizor.step()
if(i % 300 == 299):
sum /= 300
loss_lst.append(sum)
sum = 0
def test():
correct = 0
totall = 0
with torch.no_grad():
for i, data in enumerate(test_loader, 0):
inputs, lables = data
y_pred = modle(inputs)
_, predicted = torch.max(y_pred, dim=-1)
correct += (lables == predicted).sum().item()
totall += lables.size(0)
acc_lst.append(correct / totall * 100)
if __name__ == '__main__':
loss_lst = []
acc_lst = []
for epoch in range(10):
train()
test()
num_lst = [i for i in range(len(loss_lst))]
plt.plot(num_lst, loss_lst)
plt.xlabel("i")
plt.ylabel("loss")
plt.show()
num_lst = [i for i in range(len(acc_lst))]
plt.plot(num_lst, acc_lst)
plt.xlabel("epoch")
plt.ylabel("acc")
plt.show()
边栏推荐
猜你喜欢

摆脱极域软件的限制

系统基础-学习笔记(一些命令记录)

Jenkins详细配置

DevOps-了解学习

Collision, character controller, Cloth components (cloth), joints in the Unity physics engine

NAT experiment

el-progress implements different colors of the progress bar
![In-depth analysis if according to data authority @datascope (annotation + AOP + dynamic sql splicing) [step by step, with analysis process]](/img/b5/03f55bb9058c08a48eae368233376c.png)
In-depth analysis if according to data authority @datascope (annotation + AOP + dynamic sql splicing) [step by step, with analysis process]

八大排序之堆排序

Some basic method records of commonly used languages in LeetCode
随机推荐
js判断文字是否超过区域
Error correction notes for the book Image Processing, Analysis and Machine Vision
In-depth analysis if according to data authority @datascope (annotation + AOP + dynamic sql splicing) [step by step, with analysis process]
LaTeX笔记
scikit-image图像处理笔记
Network Troubleshooting Basics - Study Notes
Media query, rem mobile terminal adaptation
The cocos interview answers you are looking for are all here!
config.js related configuration summary
设置文本向两边居中展示
花花省V5淘宝客APP源码无加密社交电商自营商城系统带抖音接口
What is Alibaba Cloud Express Beauty Station?
VLAN is introduced with the experiment
微信小程序仿input组件、虚拟键盘
D41_buffer pool
LaTeX使用frame制作PPT图片没有标号
字体样式及其分类
Nacos配置服务的源码解析(全)
网络协议基础-学习笔记
DevOps流程demo(实操记录)
