当前位置:网站首页>pytorch(网络模型训练)
pytorch(网络模型训练)
2022-06-26 05:30:00 【月屯】
网络模型训练
小插曲
区别
import torch
a=torch.tensor(5)
print(a)
print(a.item())

import torch
output=torch.tensor([[0.1,0.2],[0.05,0.4]])
print(output.argmax(1))# 为1选取每一行最大值的索引,为0选取每一列最大值的索引
preds=output.argmax(1)
target=torch.tensor([0,1])
print(preds==target)
print((preds==target).sum())

训练模型
import torch
from torch import nn
from torch.nn import Sequential, Conv2d, MaxPool2d, Linear, Flatten
# 搭建神经网络
class Dun(nn.Module):
def __init__(self):
super().__init__()
# 2.
self.model1 = Sequential(Conv2d(3, 32, 5, padding=2),
MaxPool2d(2),
Conv2d(32, 32, 5, padding=2),
MaxPool2d(2),
Conv2d(32, 64, 5, padding=2),
MaxPool2d(2),
Flatten(),
Linear(1024, 64),
Linear(64, 10))
def forward(self,x):
x=self.model1(x)
return x
if __name__=='__main__':
dun=Dun()
input=torch.ones((64,3,32,32))
print(dun(input).shape)
数据训练
import torchvision
# 准备数据集
from torch.utils.tensorboard import SummaryWriter
from model import *
from torch.utils.data import DataLoader
train_data=torchvision.datasets.CIFAR10(root="./data_set_train",train=True,transform=torchvision.transforms.ToTensor(),download=True)
test_data=torchvision.datasets.CIFAR10(root="./data_set_test",train=False,transform=torchvision.transforms.ToTensor(),download=True)
#长度
train_data_size=len(train_data)
test_data_size=len(test_data)
print("train_data_size:{}",format(train_data_size))
print("test_data_size:{}",format(test_data_size))
# 加载数据集
train_dataloader=DataLoader(train_data,batch_size=64)
test_dataloder=DataLoader(test_data,batch_size=64)
#创建网络模型
dun=Dun()
#损失函数
loss_fn=nn.CrossEntropyLoss()
# 优化器
learning_rate=1e-2
optimizerr=torch.optim.SGD(dun.parameters(),lr=learning_rate)
#设置训练网络参数
# 记录训练次数
total_train_step=0
# 记录测试次数
total_test_step=0
#训练次数
epoch=10
# 追加tensorboard
writer=SummaryWriter("./logs")
for i in range(epoch):
print("----------第{}轮训练------".format(i+1))
# 训练开始
dun.train()# 网络模型中,对dropout、BatchNorm层等起作用,进入训练状态
for data in train_dataloader:
img,target=data
output=dun(img)
loss=loss_fn(output,target)
#优化器优化
optimizerr.zero_grad()
loss.backward()
optimizerr.step()
total_train_step+=1
print("训练次数:{},loss:{}".format(total_train_step,loss.item()))
writer.add_scalar("train_loss",loss.item(),total_train_step)
# 测试步骤
total_test_loss=0
# 使用正确率判断模型的好坏
total_accuracy=0
dun.eval()# 网络模型中,对dropout、BatchNorm层等起作用,进入验证状态
with torch.no_grad():
for data in test_dataloder:
img,target=data
output=dun(img)
total_test_loss+=loss_fn(output,target).item()
accuracy=(output.argmax(1)==target).sum()
total_accuracy+=accuracy
print("整体测试集上的Loss:{}".format(total_test_loss))
writer.add_scalar("test_loss",total_test_loss,total_test_step)
print("整体测试集上的正确率:{}".format(total_accuracy/test_data_size))
writer.add_scalar("test_accuracy",total_accuracy/test_data_size,total_test_step)
total_test_step+=1
#保存模型
torch.save(dun,"dun{}.pth".format(i))
print("保存模型")
writer.close()
GPU 训练
第一种方式

将以上的三部分调用cuda方法,以上面训练数据的代码为例
# 模型
if torch.cuda.is_available():# 判断是否可以使用gpu
dun=dun.cuda()
#损失函数
if torch.cuda.is_available():
loss_fn=loss_fn.cuda()
# 数据(包含训练和测试的)
if torch.cuda.is_available():
img = img.cuda()
target = target.cuda()
方式二:
# 定义训练的设备
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")# 参数分为cpu和cuda,当显卡多个时cuda:0
将方式一的代码换成
dun=dun.to(device)
# 其他数据、loss类似
查看GPU信息

完整模型验证
查看数据集CIFAR10的类别
import torchvision
from PIL import Image
import torch
from torch import nn
from torch.nn import Sequential, Conv2d, MaxPool2d, Linear, Flatten
# 搭建神经网络
class Dun(nn.Module):
def __init__(self):
super().__init__()
# 2.
self.model1 = Sequential(Conv2d(3, 32, 5, padding=2),
MaxPool2d(2),
Conv2d(32, 32, 5, padding=2),
MaxPool2d(2),
Conv2d(32, 64, 5, padding=2),
MaxPool2d(2),
Flatten(),
Linear(1024, 64),
Linear(64, 10))
def forward(self,x):
x=self.model1(x)
return x
image_path="./img/1.png"
image=Image.open(image_path)
print(image)
# 类型转换
transform=torchvision.transforms.Compose([torchvision.transforms.Resize((32,32)),torchvision.transforms.ToTensor()])
image=transform(image)
print(image.shape)
# 加载网络模型注意加载的模型和现在验证的要么使用cpu要么gpu一致,否则需要map——location映射本地的cpu
model=torch.load("dun0.pth",map_location=torch.device("cpu"))
print(model)
# 类型转换
image=torch.reshape(image,(1,3,32,32))
model.eval()# 模型转换测试类型
# 执行模型
with torch.no_grad():
output=model(image)
print(output)
print(output.argmax(1))
边栏推荐
- Henkel database custom operator '~~‘
- 生命原来如此脆弱
- 自定义WebSerivce作为代理解决SilverLight跨域调用WebService问题
- Ribbon负载均衡服务调用
- 1212312321
- cartographer_backend_constraint
- 程序人生
- PHP 2D / multidimensional arrays are sorted in ascending and descending order according to the specified key values
- Serious hazard warning! Log4j execution vulnerability is exposed!
- cartographer_pose_graph_2d
猜你喜欢

Ribbon负载均衡服务调用

As promised: Mars, the mobile terminal IM network layer cross platform component library used by wechat, has been officially open source

Uni app ceiling fixed style

Install the tp6.0 framework under windows, picture and text. Thinkphp6.0 installation tutorial

Tp5.0 framework PDO connection MySQL error: too many connections solution

Implementation of IM message delivery guarantee mechanism (II): ensure reliable delivery of offline messages

uni-app吸顶固定样式

PHP 2D / multidimensional arrays are sorted in ascending and descending order according to the specified key values

Introduction to GUI programming to game practice (I)

12 multithreading
随机推荐
ZigBee explain in simple terms lesson 2 hardware related and IO operation
Installation and deployment of alluxio
递归遍历目录结构和树状展现
LeetCode_二叉搜索树_简单_108.将有序数组转换为二叉搜索树
Introduction to GUI programming to game practice (I)
[leetcode] 713: subarray with product less than k
基于SDN的DDoS攻击缓解
一段不离不弃的爱情
CMakeLists.txt Template
Setting pseudo static under fastadmin Apache
【上采样方式-OpenCV插值】
[arm] add desktop application for buildreoot of rk3568 development board
ZigBee learning in simple terms Lecture 1
Two step processing of string regular matching to get JSON list
How does P2P technology reduce the bandwidth of live video by 75%?
Fedora alicloud source
Chapter 9 setting up structured logging (I)
[red team] what preparations should be made to join the red team?
Implementation of IM message delivery guarantee mechanism (II): ensure reliable delivery of offline messages
Use jedis to monitor redis stream to realize message queue function