Using PyTorch to build a flower classifier from a pretrained VGG19 model, reaching 97% accuracy
2022-07-23 05:37:00 [mandala-chen]
Project description
This article uses an openly available flower dataset and retrains a pretrained VGG19 model via transfer learning, building a flower classifier from the result.
Dataset
Information about the flower dataset can be found here. The dataset contains a separate folder for each of its 102 flower categories. Each flower class is labeled with a number, and each numbered directory contains many .jpg files.
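Because ImageFolder-style loaders infer class labels from directory names, it is worth sanity-checking this layout before training. The sketch below (a hypothetical helper, not from the original post) counts images per class directory; it builds a tiny throwaway tree to stay self-contained, but in practice you would point `root` at `flower_data/train`.

```python
import tempfile
from pathlib import Path

def count_images_per_class(root):
    """Map each class directory name to the number of .jpg files it holds."""
    root = Path(root)
    return {d.name: len(list(d.glob("*.jpg")))
            for d in sorted(root.iterdir()) if d.is_dir()}

# Demo on a throwaway directory that mimics flower_data/train.
with tempfile.TemporaryDirectory() as tmp:
    for cls, n in [("1", 3), ("2", 5)]:
        d = Path(tmp) / cls
        d.mkdir()
        for i in range(n):
            (d / f"img_{i}.jpg").touch()
    counts = count_images_per_class(tmp)
print(counts)  # {'1': 3, '2': 5}
```

A class with zero files, or a stray non-directory entry, shows up immediately in the returned dict.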
Experiment environment
the PyTorch library
the PIL library
If you want to train on a GPU, use an NVIDIA card and install CUDA.
With GPU training on my own machine (my GPU is a 1050), the run took only 91 minutes.
Import the libraries and check for an available GPU
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
import time
import json
import copy
import os
from collections import OrderedDict

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from PIL import Image

import torch
from torch import nn, optim
from torch.optim import lr_scheduler
import torch.nn.functional as F
import torchvision
from torchvision import datasets, models, transforms

# check if a GPU is available
train_on_gpu = torch.cuda.is_available()
if not train_on_gpu:
    print('Bummer! Training on CPU ...')
else:
    print('You are good to go! Training on GPU ...')

# use the GPU if one is present
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
Define the dataset location
data_dir = r'F:\资料\项目\image_classifier_pytorch-master\flower_data'
train_dir = os.path.join(data_dir, 'train')
valid_dir = os.path.join(data_dir, 'valid')
Load the datasets and preprocess the data
# Define the transforms for the training and validation sets
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomRotation(30),
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225])
    ]),
    'valid': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225])
    ])
}

# Load the datasets with ImageFolder
image_datasets = {
    x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
    for x in ['train', 'valid']}

# Using the image datasets and the transforms, define the dataloaders
batch_size = 64
dataloaders = {
    x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size,
                                   shuffle=True, num_workers=4)
    for x in ['train', 'valid']}

dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'valid']}
class_names = image_datasets['train'].classes

# Label mapping from class number to flower name
with open(r'F:\资料\项目\image_classifier_pytorch-master\cat_to_name.json', 'r') as f:
    cat_to_name = json.load(f)
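Decoding a prediction back to a flower name takes two lookups: the model's class index gives a folder name (a number string, via `class_names`), and `cat_to_name` turns that number into a readable name. A toy illustration with stand-in values (the real ones come from ImageFolder and cat_to_name.json; note that ImageFolder sorts folder names as strings, so '10' and '100' sit next to '1'):

```python
# Stand-in values; the real ones come from ImageFolder and cat_to_name.json.
class_names = ['1', '10', '100']   # folder names, sorted as strings
cat_to_name = {'1': 'pink primrose',
               '10': 'globe thistle',
               '100': 'blanket flower'}

pred_idx = 2                        # pretend the model predicted index 2
folder = class_names[pred_idx]      # index -> numbered folder name
name = cat_to_name[folder]          # folder name -> flower name
print(folder, '->', name)           # 100 -> blanket flower
```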
Inspect the data
# Run this to test the data loader
images, labels = next(iter(dataloaders['train']))
print(images.size())

rand_idx = np.random.randint(len(images))
print("label: {}, class: {}, name: {}".format(labels[rand_idx].item(),
      class_names[labels[rand_idx].item()],
      cat_to_name[class_names[labels[rand_idx].item()]]))
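The Normalize transform above shifts images out of the displayable [0, 1] range, so they look wrong if plotted directly. A minimal sketch of undoing it (a hypothetical `denormalize` helper, not part of the original post), assuming the same ImageNet mean and std used in `data_transforms`:

```python
import numpy as np

MEAN = np.array([0.485, 0.456, 0.406])
STD = np.array([0.229, 0.224, 0.225])

def denormalize(img):
    """Undo transforms.Normalize on an H x W x C float image, clipped to [0, 1]."""
    return np.clip(img * STD + MEAN, 0.0, 1.0)

# A normalized pixel of all zeros maps back to the ImageNet mean.
pixel = np.zeros((1, 1, 3))
print(denormalize(pixel)[0, 0])
```

To display a batch image, first convert the tensor with `images[idx].permute(1, 2, 0).numpy()`, then pass the result through `denormalize` before `plt.imshow`.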
Define the model
model_name = 'densenet'  # or 'vgg'
if model_name == 'densenet':
    model = models.densenet161(pretrained=True)
    num_in_features = 2208
    print(model)
elif model_name == 'vgg':
    model = models.vgg19(pretrained=True)
    num_in_features = 25088
    print(model.classifier)
else:
    print("Unknown model, please choose 'densenet' or 'vgg'")

# Freeze the feature-extractor parameters
for param in model.parameters():
    param.requires_grad = False

# Create the new classifier head
def build_classifier(num_in_features, hidden_layers, num_out_features):
    classifier = nn.Sequential()
    if hidden_layers is None:
        classifier.add_module('fc0', nn.Linear(num_in_features, num_out_features))
    else:
        layer_sizes = zip(hidden_layers[:-1], hidden_layers[1:])
        classifier.add_module('fc0', nn.Linear(num_in_features, hidden_layers[0]))
        classifier.add_module('relu0', nn.ReLU())
        classifier.add_module('drop0', nn.Dropout(.6))
        for i, (h1, h2) in enumerate(layer_sizes):
            classifier.add_module('fc'+str(i+1), nn.Linear(h1, h2))
            classifier.add_module('relu'+str(i+1), nn.ReLU())
            classifier.add_module('drop'+str(i+1), nn.Dropout(.5))
        classifier.add_module('output', nn.Linear(hidden_layers[-1], num_out_features))
    return classifier

hidden_layers = None  # e.g. [4096, 1024, 256] or [512, 256, 128]
classifier = build_classifier(num_in_features, hidden_layers, 102)
print(classifier)
# Only train the classifier parameters; the feature parameters are frozen
if model_name == 'densenet':
    model.classifier = classifier
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adadelta(model.parameters())
    # alternatives: optim.Adam(model.parameters(), lr=0.001)
    # optimizer = optim.SGD(model.parameters(), lr=0.0001, weight_decay=0.001, momentum=0.9)
    sched = lr_scheduler.StepLR(optimizer, step_size=4)
elif model_name == 'vgg':
    model.classifier = classifier
    # The classifier head ends in a raw Linear layer, so use CrossEntropyLoss;
    # nn.NLLLoss would require a LogSoftmax output layer.
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.classifier.parameters(), lr=0.0001)
    sched = lr_scheduler.StepLR(optimizer, step_size=4, gamma=0.1)
else:
    pass
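Transfer learning only works as intended if the backbone really is frozen. A quick sanity check (a hypothetical snippet, not from the original post) is to compare trainable versus total parameter counts; a tiny stand-in model is used here so the sketch runs without downloading pretrained weights:

```python
import torch.nn as nn

# Stand-in for a pretrained backbone plus a new classifier head.
backbone = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
head = nn.Linear(16, 4)
for p in backbone.parameters():
    p.requires_grad = False  # freeze, as done above for the VGG/DenseNet features

model_demo = nn.Sequential(backbone, head)
trainable = sum(p.numel() for p in model_demo.parameters() if p.requires_grad)
total = sum(p.numel() for p in model_demo.parameters())
print(f"trainable: {trainable} / total: {total}")  # trainable: 68 / total: 212
```

Run the same two sums on the real model after freezing; the trainable count should match the classifier head alone.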
def train_model(model, criterion, optimizer, sched, num_epochs=5):
    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch+1, num_epochs))
        print('-' * 10)
        # Each epoch has a training and a validation phase
        for phase in ['train', 'valid']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluation mode
            running_loss = 0.0
            running_corrects = 0
            # Iterate over the data
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)
                # zero the parameter gradients
                optimizer.zero_grad()
                # forward; track history only in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)
                    # backward + optimize only in the training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                # statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
            if phase == 'train':
                sched.step()  # step the LR scheduler once per epoch
            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
            # deep-copy the best model so far
            if phase == 'valid' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
        print()
    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:.4f}'.format(best_acc))
    # load the best model weights
    model.load_state_dict(best_model_wts)
    return model
Start training
epochs = 30
model.to(device)
model = train_model(model, criterion, optimizer, sched, epochs)
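After training it is worth persisting the result; the post does not show this step, so the following is a minimal sketch under stated assumptions. It saves the weights together with the class-to-index mapping (needed later to decode predictions via `cat_to_name`); a tiny stand-in model keeps the example self-contained:

```python
import os
import tempfile
import torch
import torch.nn as nn

model_demo = nn.Linear(4, 3)                # stand-in for the trained model
class_to_idx = {'1': 0, '10': 1, '100': 2}  # stand-in for image_datasets['train'].class_to_idx

path = os.path.join(tempfile.gettempdir(), 'checkpoint_demo.pth')
torch.save({'state_dict': model_demo.state_dict(),
            'class_to_idx': class_to_idx}, path)

# Reload: rebuild the architecture, then restore weights and the mapping.
restored = nn.Linear(4, 3)
ckpt = torch.load(path)
restored.load_state_dict(ckpt['state_dict'])
idx_to_class = {v: k for k, v in ckpt['class_to_idx'].items()}
print(idx_to_class[0])  # prints 1
```

For the real model you would save `model.state_dict()` and `image_datasets['train'].class_to_idx`, and rebuild the VGG19 plus classifier head before calling `load_state_dict`.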