当前位置:网站首页>深度学习 pytorch cifar10数据集训练「建议收藏」
深度学习 pytorch cifar10数据集训练「建议收藏」
2022-06-25 15:35:00 【全栈程序员站长】
大家好,又见面了,我是你们的朋友全栈君。
1.加载数据集,并对数据集进行增强,类型转换 官网cifar10数据集 附链接:https://www.cs.toronto.edu/~kriz/cifar.html
读取数据过程中,可以改变batch_size和num_workers来加快训练速度
transform=transforms.Compose([
#图像增强
transforms.Resize(120),
transforms.RandomHorizontalFlip(),
transforms.RandomCrop(96),
transforms.ColorJitter(brightness=0.5,contrast=0.5,hue=0.5),
#转变为tensor 正则化
transforms.ToTensor(),
transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5)) #正则化
])
trainset=tv.datasets.CIFAR10(
root=r'E:\桌面\资料\cv3\数据集\cifar-10-batches-py',
train=True,
download=True,
transform=transform
)
trainloader=data.DataLoader(
trainset,
batch_size=8,
shuffle=True, #乱序
num_workers=4,
)
testset=tv.datasets.CIFAR10(
root=r'E:\桌面\资料\cv3\数据集\cifar-10-batches-py',
train=False,
download=True,
transform=transform
)
testloader=data.DataLoader(
testset,
batch_size=2,
shuffle=False,
num_workers=2
)
net网络:
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1=nn.Conv2d(in_channels=3,out_channels=6,kernel_size=5)
self.conv2=nn.Conv2d(in_channels=6,out_channels=16,kernel_size=5)
self.max=nn.MaxPool2d(2,2)
self.q1=nn.Linear(16*441,120)
self.q2=nn.Linear(120,84)
self.q3=nn.Linear(84,10)
self.relu=nn.ReLU()
def forward(self,x):
x1=self.max(F.relu(self.conv1(x)))
x2=F.max_pool2d(self.relu(self.conv2(x1)),2)
x3=x2.view(x2.size()[0],-1)
x4=F.relu(self.q1(x3))
x5=F.relu(self.q2(x4))
x6=self.q3(x5)
return x6
训练模型
net=Net()
#损失函数
loss=nn.CrossEntropyLoss()
opt=optim.SGD(net.parameters(),lr=0.001)
for epoch in range(5):
running_loss=0.0
for i,data in enumerate(trainloader,0):
inputs,labels=data
inputs=inputs.cuda()
labels=labels.cuda()
inputs,labels=Variable(inputs),Variable(labels)
opt.zero_grad()
net.to(torch.device('cuda:0'))
h=net(inputs)
cost=loss(h,labels)
cost.backward()
opt.step()
running_loss+=cost.item()
if i%2000==1999:
print('[%d,%5d] loss:%.3f' %(epoch+1,i+1,running_loss/2000))
running_loss=0.0
torch.save(net.state_dict(),r'net.pth')
correct=0
total=0
for data in testloader:
images,labels=data
optputs=net(Variable(images.cuda()))
_,predicted=torch.max(optputs.cpu(),1)
total+=labels.size(0)
correct+=(predicted==labels).sum()
print("准确率: %d %%" %(100*correct/total))
接下来可以直接进行训练
在运行过程中会出现虚拟内存不够的情况,可以调整虚拟内存大小,解决这一问题。
发布者:全栈程序员栈长,转载请注明出处:https://javaforall.cn/152102.html原文链接:https://javaforall.cn
边栏推荐
- Joseph Ring - formula method (recursive formula)
- Pytest测试框架笔记
- Error com mysql. cj. jdbc. exceptions. Communicationsexception: solutions to communications link failure
- Is Guoxin golden sun reliable? Is it legal? Is it safe to open a stock account?
- JS中的==和===的区别(详解)
- Sword finger offer 04 Find in 2D array
- Sword finger offer 10- I. Fibonacci sequence
- Golang uses Mongo driver operation - increase (Advanced)
- Asynchronous processing of error prone points
- Report on Hezhou air32f103cbt6 development board
猜你喜欢
Rapport de la main - d'oeuvre du Conseil de développement de l'aecg air32f103cbt6
Postman usage notes, interface framework notes
Sword finger offer 05 Replace spaces
Arthas, a sharp tool for online diagnosis - several important commands
Asynchronous processing of error prone points
TFIDF and BM25
What is session? How is it different from cookies?
JSON module dictionary and string conversion
JMeter reading and writing excel requires jxl jar
异步处理容易出错的点
随机推荐
在打新债开户证券安全吗,需要什么准备
Could not connect to redis at 127.0.0.1:6379 in Windows
If a thread overflows heap memory or stack memory, will other threads continue to work
Pytest测试框架笔记
Sword finger offer 04 Find in 2D array
合宙Air32F103CBT6开发板上手报告
Source code analysis of nine routing strategies for distributed task scheduling platform XXL job
分享自己平时使用的socket多客户端通信的代码技术点和软件使用
不要再「外包」AI 模型了!最新研究发现:有些破坏机器学习模型安全的「后门」无法被检测到
Record the time to read the file (the system cannot find the specified path)
Sword finger offer 06 Print linked list from end to end
Kali modify IP address
Detailed description of crontab command format and summary of common writing methods
解决Visio和office365安装兼容问题
Cloning and importing DOM nodes
镁光256Gb NAND Flash芯片介绍
Download and installation tutorial of consumer
Globally unique key generation strategy - implementation principle of the sender
Go language modifies / removes multiple line breaks in strings
Using reentrantlock and synchronized to implement blocking queue