Build the first neural network with PyTorch and optimize it
2022-06-28 08:36:00 【Sol-itude】
I've been learning PyTorch, and this time I built a neural network following a tutorial, on the most classic dataset, CIFAR10. Let's look at the architecture first.
The input is a 3-channel 32×32 image. It passes through 3 convolutions and 3 max-pooling layers, then 1 flatten and 2 linear layers, producing 10 outputs.
The code is as follows:
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear

class NetWork(nn.Module):
    def __init__(self):
        super(NetWork, self).__init__()
        self.conv1 = Conv2d(3, 32, 5, padding=2)    # 3x32x32 -> 32x32x32
        self.maxpool1 = MaxPool2d(2)                # -> 32x16x16
        self.conv2 = Conv2d(32, 32, 5, padding=2)   # -> 32x16x16
        self.maxpool2 = MaxPool2d(2)                # -> 32x8x8
        self.conv3 = Conv2d(32, 64, 5, padding=2)   # -> 64x8x8
        self.maxpool3 = MaxPool2d(2)                # -> 64x4x4
        self.flatten = Flatten()
        self.linear1 = Linear(1024, 64)             # 1024 = 64*4*4
        self.linear2 = Linear(64, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.maxpool1(x)
        x = self.conv2(x)
        x = self.maxpool2(x)
        x = self.conv3(x)
        x = self.maxpool3(x)
        x = self.flatten(x)
        x = self.linear1(x)
        x = self.linear2(x)
        return x

network = NetWork()
print(network)
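As a quick sanity check on that 1024 = 64*4*4 comment, here is a short sketch of my own that traces the tensor shape through each layer (it assumes the NetWork class and the network object above are already defined):

import torch

x = torch.ones((1, 3, 32, 32))   # one dummy 3-channel 32x32 image
for layer in [network.conv1, network.maxpool1, network.conv2, network.maxpool2,
              network.conv3, network.maxpool3, network.flatten]:
    x = layer(x)
    print(layer.__class__.__name__, tuple(x.shape))
# each MaxPool2d(2) halves the spatial size: 32 -> 16 -> 8 -> 4,
# so Flatten yields 64*4*4 = 1024 features per image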
We can also take a look at the network in TensorBoard. Remember the imports:
import torch
from torch.utils.tensorboard import SummaryWriter

input = torch.ones((64, 3, 32, 32))   # a dummy batch of 64 images
output = network(input)
writer = SummaryWriter("logs_seq")
writer.add_graph(network, input)      # record the model graph for TensorBoard
writer.close()
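To view the graph, run tensorboard --logdir=logs_seq in a terminal and open the URL it prints in a browser.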
In TensorBoard it looks like this (screenshots omitted): open the NetWork node and zoom in to see each layer.
A neural network's predictions have error, so we use gradient descent to reduce it. The code is as follows (this time the layers are wrapped in an nn.Sequential instead of being listed one by one):
import torch
import torchvision
from torch import nn
from torch.nn import Sequential, Conv2d, MaxPool2d, Flatten, Linear
from torch.utils.data import DataLoader

dataset = torchvision.datasets.CIFAR10("./dataset2", train=False,
                                       transform=torchvision.transforms.ToTensor(),
                                       download=True)
dataloader = DataLoader(dataset, batch_size=1)

class NetWork(nn.Module):
    def __init__(self):
        super(NetWork, self).__init__()
        # the same layers as before, now wrapped in a single Sequential
        self.model1 = Sequential(
            Conv2d(3, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 64, 5, padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024, 64),   # 1024 = 64*4*4
            Linear(64, 10)
        )

    def forward(self, x):
        x = self.model1(x)
        return x

loss = nn.CrossEntropyLoss()
network = NetWork()
optim = torch.optim.SGD(network.parameters(), lr=0.01)  # gradient descent as the optimizer

for epoch in range(20):  # loop 20 times
    running_loss = 0.0
    for data in dataloader:
        imgs, targets = data
        outputs = network(imgs)
        result_loss = loss(outputs, targets)
        optim.zero_grad()       # zero the gradients from the previous step
        result_loss.backward()  # backpropagate to compute new gradients
        optim.step()            # update the parameters
        running_loss = running_loss + result_loss  # accumulates a tensor, hence the grad_fn below
    print(running_loss)
My computer's GPU is an RTX 2060, which is on the older side; three epochs took about a minute, and it was so slow that I didn't let it finish.
Output:
tensor(18733.7539, grad_fn=<AddBackward0>)
tensor(16142.7451, grad_fn=<AddBackward0>)
tensor(15420.9199, grad_fn=<AddBackward0>)
You can see the loss shrinking each epoch. (The totals look large because, with batch_size=1, the loop sums the loss over all 10,000 test images, so the first epoch averages roughly 1.87 per image.) But for a real application 20 epochs is too few; when my new computer arrives I'll run 100 epochs.
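A side note of my own: the loop above accumulates result_loss as a live tensor (hence the grad_fn in the printout) and never actually moves anything onto the GPU. Here is a minimal sketch of the same loop, assuming a CUDA-capable machine and the NetWork class and dataloader defined earlier, that addresses both:

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
network = NetWork().to(device)   # move the model onto the GPU when available
optim = torch.optim.SGD(network.parameters(), lr=0.01)
loss = nn.CrossEntropyLoss()

for epoch in range(20):
    running_loss = 0.0
    for imgs, targets in dataloader:
        imgs, targets = imgs.to(device), targets.to(device)  # move each batch too
        outputs = network(imgs)
        result_loss = loss(outputs, targets)
        optim.zero_grad()
        result_loss.backward()
        optim.step()
        running_loss += result_loss.item()  # .item() detaches, so no graph piles up
    print(running_loss)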