当前位置:网站首页>Pytorch (network model training)
Pytorch (network model training)
2022-06-26 05:40:00 【Yuetun】
Table of contents title
Network model training
episode
Difference between a tensor and its `.item()` value
import torch

# A 0-dimensional tensor wraps a single number.
a = torch.tensor(5)

# Printing the tensor shows the tensor wrapper around the value ...
print(a)

# ... while .item() extracts the plain Python number stored inside it.
scalar_value = a.item()
print(scalar_value)
import torch

# Two samples (rows), two class scores each (columns).
output = torch.tensor([[0.1, 0.2], [0.05, 0.4]])

# argmax(1) returns, for every row, the index of its largest value;
# argmax(0) would return, for every column, the index of its largest value.
preds = output.argmax(1)
print(preds)

# Ground-truth class index of each sample.
target = torch.tensor([0, 1])

# Element-wise comparison marks which predictions are correct;
# summing the boolean tensor counts the correct predictions.
is_correct = preds == target
print(is_correct)
print(is_correct.sum())
Training models
import torch
from torch import nn
from torch.nn import Sequential, Conv2d, MaxPool2d, Linear, Flatten
# Building neural networks
class Dun(nn.Module):
    """Small convolutional classifier producing 10 scores per image.

    The stack is three ``Conv2d(kernel=5, padding=2)`` + ``MaxPool2d(2)``
    stages followed by ``Flatten`` and two ``Linear`` layers.  Each pooling
    stage halves the spatial size, and the first ``Linear`` expects 1024
    features (64 channels * 4 * 4), so the input must be ``(N, 3, 32, 32)``.
    """

    def __init__(self):
        super().__init__()
        # The attribute name `model1` is kept: whole-model checkpoints saved
        # with torch.save(model, ...) depend on it.
        self.model1 = Sequential(
            Conv2d(3, 32, 5, padding=2),   # (N, 3, 32, 32)  -> (N, 32, 32, 32)
            MaxPool2d(2),                  #                 -> (N, 32, 16, 16)
            Conv2d(32, 32, 5, padding=2),  #                 -> (N, 32, 16, 16)
            MaxPool2d(2),                  #                 -> (N, 32, 8, 8)
            Conv2d(32, 64, 5, padding=2),  #                 -> (N, 64, 8, 8)
            MaxPool2d(2),                  #                 -> (N, 64, 4, 4)
            Flatten(),                     #                 -> (N, 1024)
            Linear(1024, 64),
            Linear(64, 10),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the sequential stack and return ``(N, 10)`` scores."""
        return self.model1(x)
if __name__ == '__main__':
    # Smoke test: push one dummy batch through the network and print the
    # output shape (one row of 10 scores per image is expected).
    dun = Dun()
    # Named `dummy_batch` rather than `input` so the builtin is not shadowed.
    dummy_batch = torch.ones((64, 3, 32, 32))
    print(dun(dummy_batch).shape)
Data training
import torchvision
# Prepare the dataset
from torch.utils.tensorboard import SummaryWriter
from model import *
from torch.utils.data import DataLoader
# CIFAR10 train / test splits, downloaded on first use and converted to
# tensors by ToTensor().
train_data = torchvision.datasets.CIFAR10(
    root="./data_set_train",
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=True,
)
test_data = torchvision.datasets.CIFAR10(
    root="./data_set_test",
    train=False,
    transform=torchvision.transforms.ToTensor(),
    download=True,
)

# Number of samples in each split.
train_data_size = len(train_data)
test_data_size = len(test_data)
# Fix: the original wrote `"...{}",format(x)` (comma instead of dot), which
# passes two arguments to print and leaves the "{}" placeholder unfilled.
print("train_data_size:{}".format(train_data_size))
print("test_data_size:{}".format(test_data_size))
# Wrap both datasets in loaders that yield batches of 64 samples.
train_dataloader = DataLoader(dataset=train_data, batch_size=64)
test_dataloder = DataLoader(dataset=test_data, batch_size=64)

# Network to be trained.
dun = Dun()

# Classification loss.
loss_fn = nn.CrossEntropyLoss()

# Plain SGD over all parameters of the network.
learning_rate = 0.01
optimizerr = torch.optim.SGD(params=dun.parameters(), lr=learning_rate)

# Bookkeeping for the training run:
#   total_train_step - number of optimizer steps taken so far
#   total_test_step  - number of evaluation passes done so far
total_train_step = total_test_step = 0

# Number of passes over the training set.
epoch = 10

# TensorBoard writer; event files go to ./logs.
writer = SummaryWriter(log_dir="./logs")
for i in range(epoch):
    print("---------- The first {} Round training ------".format(i+1))

    # ---- Training phase ----
    # train() switches layers such as Dropout / BatchNorm to training
    # behaviour; it is harmless for networks that contain none of them.
    dun.train()
    for img, target in train_dataloader:
        output = dun(img)
        loss = loss_fn(output, target)

        # Clear old gradients, back-propagate, then update the weights.
        optimizerr.zero_grad()
        loss.backward()
        optimizerr.step()

        total_train_step += 1
        print(" Training times :{},loss:{}".format(total_train_step, loss.item()))
        writer.add_scalar("train_loss", loss.item(), total_train_step)

    # ---- Evaluation phase ----
    total_test_loss = 0
    # Number of correctly classified test samples; used to judge the model.
    total_accuracy = 0
    dun.eval()  # evaluation behaviour for Dropout / BatchNorm style layers
    with torch.no_grad():  # no gradients are needed while evaluating
        for img, target in test_dataloder:
            output = dun(img)
            total_test_loss += loss_fn(output, target).item()
            # .item() keeps the running count a plain Python int, so the
            # accuracy below is a float and not a 0-d tensor.
            total_accuracy += (output.argmax(1) == target).sum().item()

    test_accuracy = total_accuracy / test_data_size
    print(" On the overall test set Loss:{}".format(total_test_loss))
    writer.add_scalar("test_loss", total_test_loss, total_test_step)
    print(" Accuracy on the overall test set :{}".format(test_accuracy))
    writer.add_scalar("test_accuracy", test_accuracy, total_test_step)
    total_test_step += 1

    # Save the whole model object after every epoch (one file per epoch).
    # NOTE(review): this pickles the full module, so loading it later needs
    # the `Dun` class to be importable; saving dun.state_dict() is the more
    # portable option but would change how the checkpoint must be loaded.
    torch.save(dun, "dun{}.pth".format(i))
    print(" Save the model ")

writer.close()
GPU Training
The first way
Call the `.cuda()` method on the three parts above (the model, the loss function and the data), taking the training code above as an example
# Model
if torch.cuda.is_available():  # only move to the GPU when CUDA is usable
    dun = dun.cuda()
# Loss function
if torch.cuda.is_available():
    loss_fn = loss_fn.cuda()
# Data (both the training batches and the test batches)
if torch.cuda.is_available():
    img = img.cuda()
    target = target.cuda()
The second way:
# Choose the device to train on: "cuda" when a GPU is usable, else "cpu".
# With several graphics cards a specific one can be named, e.g. "cuda:0".
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
Replace the `.cuda()` calls from the first way with:
# Move the model onto the chosen device.
dun = dun.to(device=device)
# The loss function and the data batches are moved with .to(device) in the same way.
View GPU information
Complete model validation
Look at the dataset CIFAR10 Categories
import torchvision
from PIL import Image
import torch
from torch import nn
from torch.nn import Sequential, Conv2d, MaxPool2d, Linear, Flatten
# Building neural networks
class Dun(nn.Module):
    """Small convolutional classifier producing 10 scores per image.

    The stack is three ``Conv2d(kernel=5, padding=2)`` + ``MaxPool2d(2)``
    stages followed by ``Flatten`` and two ``Linear`` layers.  Each pooling
    stage halves the spatial size, and the first ``Linear`` expects 1024
    features (64 channels * 4 * 4), so the input must be ``(N, 3, 32, 32)``.
    """

    def __init__(self):
        super().__init__()
        # The attribute name `model1` is kept: whole-model checkpoints saved
        # with torch.save(model, ...) depend on it.
        self.model1 = Sequential(
            Conv2d(3, 32, 5, padding=2),   # (N, 3, 32, 32)  -> (N, 32, 32, 32)
            MaxPool2d(2),                  #                 -> (N, 32, 16, 16)
            Conv2d(32, 32, 5, padding=2),  #                 -> (N, 32, 16, 16)
            MaxPool2d(2),                  #                 -> (N, 32, 8, 8)
            Conv2d(32, 64, 5, padding=2),  #                 -> (N, 64, 8, 8)
            MaxPool2d(2),                  #                 -> (N, 64, 4, 4)
            Flatten(),                     #                 -> (N, 1024)
            Linear(1024, 64),
            Linear(64, 10),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the sequential stack and return ``(N, 10)`` scores."""
        return self.model1(x)
image_path = "./img/1.png"
image = Image.open(image_path)
print(image)
# PNG files may carry an alpha channel (RGBA) or a palette; the network and
# the reshape below need exactly 3 channels, so force RGB first.
image = image.convert("RGB")

# Resize to the 32x32 input the network expects and convert to a tensor.
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((32, 32)),
    torchvision.transforms.ToTensor(),
])
image = transform(image)
print(image.shape)

# Load the trained network.  The checkpoint and the code that runs it must
# agree on the device (both CPU or both GPU); map_location remaps a model
# saved on the GPU onto the local CPU.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from
# a trusted source.  Recent PyTorch releases default to weights_only=True,
# which presumably rejects a whole-model pickle like this one - confirm
# against the installed version.
model = torch.load("dun0.pth", map_location=torch.device("cpu"))
print(model)

# Add the batch dimension: (3, 32, 32) -> (1, 3, 32, 32).
image = torch.reshape(image, (1, 3, 32, 32))
model.eval()  # switch the model to evaluation mode

# Run the model without tracking gradients.
with torch.no_grad():
    output = model(image)
print(output)
print(output.argmax(1))  # index of the highest-scoring class
边栏推荐
- 操作符的优先级、结合性、是否控制求值顺序【详解】
- The use of loops in SQL syntax
- Sql语法中循环的使用
- Navicat如何将当前连接信息复用另一台电脑
- Life is so fragile
- 无线网络存在的安全问题及现代化解决方案
- Uni app ceiling fixed style
- Introduction to lcm32f037 series of MCU chip for motor
- The model defined (modified) in pytoch loads some required pre training model parameters and freezes them
- data = self._ data_ queue. get(timeout=timeout)
猜你喜欢
SDN based DDoS attack mitigation
When was the autowiredannotationbeanpostprocessor instantiated?
Installation and deployment of alluxio
[arm] add desktop application for buildreoot of rk3568 development board
[C language] deep analysis of data storage in memory
【ARM】讯为rk3568开发板buildroot添加桌面应用
使用Jenkins执行TestNg+Selenium+Jsoup自动化测试和生成ExtentReport测试报告
小小面试题之GET和POST的区别
12 multithreading
How Navicat reuses the current connection information to another computer
随机推荐
Source code of findcontrol
Serious hazard warning! Log4j execution vulnerability is exposed!
uni-app吸顶固定样式
劣币驱逐良币的思考
C XX management system
Use jedis to monitor redis stream to realize message queue function
RIA ideas
1212312321
Wechat team sharing: technical decryption behind wechat's 100 million daily real-time audio and video chats
机器学习 05:非线性支持向量机
Internship May 29, 2019
Talk 5 wireless communication
组合模式、透明方式和安全方式
Yunqi lab recommends experience scenarios this week, free cloud learning
skimage. morphology. medial_ axis
redis探索之布隆过滤器
What management systems (Updates) for things like this
旧情书
Could not get unknown property ‘*‘ for SigningConfig container of type org.gradle.api.internal
Leetcode114. 二叉树展开为链表