PyTorch: Classifying Four Types of Weather with PyTorch Pretrained Models
2022-06-23 03:45:00 【水哥很水】
1. The dataset used in the code can be obtained from the following link
2. Runtime environment
Pytorch-gpu==1.7.1
Python==3.7
3. The dataset-processing code is shown below (the later scripts import it as the module data_loader)
import torchvision
from torchvision import transforms
import os
from torch.utils.data import DataLoader


def loader_data():
    BATCH_SIZE = 64
    # Training-time augmentation: resize, random 192x192 crop, flip, small rotation, colour jitter
    train_transform = transforms.Compose([
        transforms.Resize(224),                  # shorter side -> 224 px
        transforms.RandomCrop(192),              # random 192x192 patch
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(0.2),          # angle in degrees, sampled from [-0.2, 0.2]
        transforms.ColorJitter(brightness=0.5),
        transforms.ColorJitter(contrast=0.5),
        transforms.ToTensor(),                   # HWC uint8 [0, 255] -> CHW float [0, 1]
        transforms.Normalize(mean=[0.5, 0.5, 0.5],
                             std=[0.5, 0.5, 0.5])  # [0, 1] -> [-1, 1]
    ])
    # Evaluation: deterministic resize to the same 192x192 input size, no augmentation
    test_transform = transforms.Compose([
        transforms.Resize((192, 192)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5],
                             std=[0.5, 0.5, 0.5])
    ])
    # ImageFolder expects one subfolder per class under each root directory
    train_ds = torchvision.datasets.ImageFolder(root=os.path.join('dataset', 'train_weather'),
                                                transform=train_transform)
    test_ds = torchvision.datasets.ImageFolder(root=os.path.join('dataset', 'test_weather'),
                                               transform=test_transform)
    train_dl = DataLoader(dataset=train_ds, batch_size=BATCH_SIZE, shuffle=True)
    test_dl = DataLoader(dataset=test_ds, batch_size=BATCH_SIZE)
    # class_to_idx maps each class folder name to its integer label
    return train_dl, test_dl, test_ds.class_to_idx
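ImageFolder derives the labels from the folder structure: each subfolder of `dataset/train_weather` and `dataset/test_weather` is one class, and the classes are numbered in sorted order of their folder names. The sketch below imitates that rule with the standard library only, so it runs without the dataset; `find_classes_sketch` is a hypothetical helper (not the torchvision function), and the four folder names are assumptions, since the post does not list the classes. It also builds the inverse mapping the same way the prediction script builds `new_class`.

```python
import os
import tempfile


def find_classes_sketch(root):
    """Hypothetical helper mimicking ImageFolder's labelling: sorted subfolder names -> 0, 1, 2, ..."""
    classes = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
    return {name: i for i, name in enumerate(classes)}


with tempfile.TemporaryDirectory() as root:
    # Assumed class folder names, created in deliberately unsorted order
    for name in ['sunrise', 'cloudy', 'shine', 'rain']:
        os.makedirs(os.path.join(root, name))
    class_to_idx = find_classes_sketch(root)

# Inverse mapping (label -> class name), built like new_class in the prediction script
idx_to_class = dict((v, k) for k, v in class_to_idx.items())

print(class_to_idx)     # {'cloudy': 0, 'rain': 1, 'shine': 2, 'sunrise': 3}
print(idx_to_class[2])  # shine
```

Because the labels depend only on the sorted folder names, `train_weather` and `test_weather` should contain identically named class folders; otherwise the same index could mean different classes in the two datasets.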
4. The model-building code is shown below (the later scripts import it as the module model_loader)
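import torch
import torchvision


def load_model():
    # VGG16 pretrained on ImageNet; freeze the convolutional feature extractor
    model = torchvision.models.vgg16(pretrained=True)
    for p in model.features.parameters():
        p.requires_grad = False
    # Replace the last classifier layer (4096 -> 1000) with a new 4-class layer.
    # Assigning model.classifier[-1].out_features = 4 would NOT resize the weight matrix.
    # VGG16 has no .fc attribute: to train it, give the optimizer model.classifier.parameters().
    model.classifier[-1] = torch.nn.Linear(in_features=model.classifier[-1].in_features, out_features=4)
    return model


def load_resnet18():
    # ResNet18 pretrained on ImageNet; freeze every existing layer
    model = torchvision.models.resnet18(pretrained=True)
    for param in model.parameters():
        param.requires_grad = False
    # The new fc layer is trainable by default, so it is the only part that gets updated
    in_f = model.fc.in_features
    model.fc = torch.nn.Linear(in_features=in_f, out_features=4)
    return model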
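Why `model.classifier[-1].out_features = 4` cannot change VGG16's output: on `torch.nn.Linear`, `out_features` is only a stored number, while the output size comes from the weight matrix. The sketch below uses a freshly initialized `Linear(4096, 1000)` as a stand-in for VGG16's last classifier layer (nothing is downloaded), then a tiny hypothetical network (sizes 16 -> 8 -> 4, chosen for illustration) to show that after freezing, only the newly created head remains trainable, which is the pattern `load_resnet18` relies on.

```python
import torch

x = torch.randn(2, 4096)  # a batch of 2 feature vectors

# Stand-in for VGG16's last classifier layer (4096 -> 1000), randomly initialized
head = torch.nn.Linear(in_features=4096, out_features=1000)

# Overwriting the attribute leaves the 1000 x 4096 weight matrix untouched
head.out_features = 4
shape_after_attr = tuple(head(x).shape)      # still 1000 logits per sample

# Replacing the layer really yields 4 logits per sample
head = torch.nn.Linear(in_features=head.in_features, out_features=4)
shape_after_replace = tuple(head(x).shape)

# Freezing pattern on a tiny hypothetical network: frozen body + fresh head
body = torch.nn.Linear(16, 8)
for p in body.parameters():
    p.requires_grad = False                  # plays the role of the frozen pretrained layers
net = torch.nn.Sequential(body, torch.nn.ReLU(), torch.nn.Linear(8, 4))
trainable = sum(p.numel() for p in net.parameters() if p.requires_grad)      # 8*4 + 4
frozen = sum(p.numel() for p in net.parameters() if not p.requires_grad)     # 16*8 + 8

print(shape_after_attr, shape_after_replace, trainable, frozen)  # (2, 1000) (2, 4) 36 136
```

For the real ResNet18 the same count gives 512 * 4 + 4 = 2052 trainable parameters, since its `fc.in_features` is 512.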
5. The training code is shown below
import torch
from data_loader import loader_data
from model_loader import load_model, load_resnet18
import numpy as np
import tqdm
import os
from sklearn.metrics import accuracy_score
from torch.optim import lr_scheduler

# Load the data
train_dl, test_dl, class_to_idx = loader_data()
# Load the model
model = load_resnet18()
# Training configuration: only the new fc layer is optimized
optimizer = torch.optim.Adam(model.fc.parameters(), lr=0.0001)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer=optimizer, step_size=5, gamma=0.9)
loss_fn = torch.nn.CrossEntropyLoss()
device = 'cuda' if torch.cuda.is_available() else 'cpu'
EPOCHS = 100
# Training
model = model.to(device)
for epoch in range(EPOCHS):
    # Training phase
    model.train()
    train_tqdm = tqdm.tqdm(iterable=train_dl, total=len(train_dl))
    train_tqdm.set_description_str('Train epoch {:2d}'.format(epoch))
    train_accuracy_sum = []
    train_loss_sum = []
    for images, labels in train_tqdm:
        images, labels = images.to(device), labels.to(device)
        pred = model(images)
        loss = loss_fn(pred, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Show the running training loss and accuracy (means over batches) in the progress bar
        train_loss_sum.append(loss.item())
        pred = torch.argmax(input=pred, dim=-1)
        train_accuracy_sum.append(accuracy_score(y_true=labels.cpu().numpy(), y_pred=pred.cpu().numpy()))
        train_tqdm.set_postfix_str(
            'loss is {:14f}, accuracy is {:14f}'.format(np.mean(train_loss_sum), np.mean(train_accuracy_sum)))
    train_tqdm.close()
    # Learning-rate decay: StepLR multiplies the rate by 0.9 every 5 epochs
    exp_lr_scheduler.step()
    # Validation phase: eval mode, no gradient tracking
    model.eval()
    with torch.no_grad():
        test_tqdm = tqdm.tqdm(iterable=test_dl, total=len(test_dl))
        test_tqdm.set_description_str('Val epoch {:2d}'.format(epoch))
        test_accuracy_sum = []
        test_loss_sum = []
        for images, labels in test_tqdm:
            images, labels = images.to(device), labels.to(device)
            pred = model(images)
            loss = loss_fn(pred, labels)
            # Show the running validation loss and accuracy
            test_loss_sum.append(loss.item())
            pred = torch.argmax(input=pred, dim=-1)
            test_accuracy_sum.append(accuracy_score(y_true=labels.cpu().numpy(), y_pred=pred.cpu().numpy()))
            test_tqdm.set_postfix_str(
                'loss is {:14f}, accuracy is {:14f}'.format(np.mean(test_loss_sum), np.mean(test_accuracy_sum)))
        test_tqdm.close()
# Save the model weights (create the output folder if needed)
os.makedirs('model_data', exist_ok=True)
torch.save(model.state_dict(), os.path.join('model_data', 'model.pth'))
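With `step_size=5` and `gamma=0.9`, StepLR multiplies the learning rate by 0.9 after every fifth call to `exp_lr_scheduler.step()`. Since the script calls it once per epoch, the rate used during epoch e should be 0.0001 * 0.9 ** (e // 5). The sketch below checks that closed form against the real scheduler, using a single dummy parameter (an assumption made so it runs instantly) in place of `model.fc.parameters()`.

```python
import torch
from torch.optim import lr_scheduler

# A dummy parameter stands in for model.fc.parameters()
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.Adam([param], lr=0.0001)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer=optimizer, step_size=5, gamma=0.9)

EPOCHS = 100
lrs = []
for epoch in range(EPOCHS):
    lrs.append(optimizer.param_groups[0]['lr'])  # learning rate in effect during this epoch
    optimizer.step()          # no gradients here, so Adam leaves the parameter unchanged
    exp_lr_scheduler.step()   # once per epoch, as in the training script

# Closed form of the same schedule
expected = [0.0001 * 0.9 ** (e // 5) for e in range(EPOCHS)]
print(lrs[0], lrs[5], lrs[-1])
```

By the last epoch (e = 99) the rate has decayed to 0.0001 * 0.9 ** 19, roughly 1.35e-05, so the second half of training runs with a much smaller step size than the first few epochs.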
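The progress bars report `np.mean` over per-batch accuracies, so every batch counts equally even when the last batch holds fewer than `BATCH_SIZE` images. The numbers below are made up purely to show the effect of that choice; a sample-weighted figure instead divides the total number of correct predictions by the total number of images.

```python
# Hypothetical epoch: two full batches of 64 and a final batch of 8 images,
# stored as (batch_size, number_of_correct_predictions)
batches = [(64, 48), (64, 48), (8, 0)]

# What the training script displays: the plain mean of per-batch accuracies
per_batch = [correct / size for size, correct in batches]
mean_of_batches = sum(per_batch) / len(per_batch)

# Sample-weighted accuracy: total correct predictions / total images seen
total_correct = sum(correct for _, correct in batches)
total_seen = sum(size for size, _ in batches)
weighted = total_correct / total_seen

print(mean_of_batches)      # 0.5
print(round(weighted, 4))   # 0.7059
```

In the training loop, one possible alternative is to keep two running counters, adding `(pred == labels).sum().item()` and `labels.size(0)` for each batch, and divide them when the epoch ends.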
6. The prediction code is shown below
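import os
import torch
from data_loader import loader_data
from model_loader import load_model, load_resnet18
import matplotlib.pyplot as plt
import matplotlib

# Load the data and take the first test batch
train_dl, test_dl, class_index = loader_data()
image, label = next(iter(test_dl))
# Invert class_to_idx: integer label -> class name
new_class = dict((v, k) for k, v in class_index.items())
# Load the model and the trained weights (map_location lets GPU-saved weights load on a CPU-only machine)
model = load_resnet18()
model_state_dict = torch.load(os.path.join('model_data', 'model.pth'), map_location='cpu')
model.load_state_dict(model_state_dict)
model.eval()
# Predict the whole batch, then display one sample (index must be smaller than the number of images in the batch)
index = 23
with torch.no_grad():
    pred = model(image)
pred = torch.argmax(input=pred, dim=-1)
# matplotlib.rc("font", family='Microsoft YaHei')  # uncomment to render Chinese characters in the title
plt.axis('off')
plt.title('predict result: ' + new_class.get(pred[index].item()) + ', label result: ' + new_class.get(
    label[index].item()))
# Undo Normalize(mean=0.5, std=0.5) so that imshow receives pixel values in [0, 1]
plt.imshow((image[index].permute(1, 2, 0) * 0.5 + 0.5).clamp(0, 1))
plt.savefig('result.png')
plt.show()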
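The test transform ends with `Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])`, which maps each pixel value x in [0, 1] to (x - 0.5) / 0.5, a value in [-1, 1]. `plt.imshow` expects floats in [0, 1] in HWC layout, so the image must be mapped back with x * 0.5 + 0.5 and permuted before display. The sketch below checks that round trip on a random tensor; `normalize` and `denormalize` are hypothetical helpers written for illustration, not torchvision functions.

```python
import torch


def normalize(x):
    # Same arithmetic as Normalize(mean=0.5, std=0.5) per channel: [0, 1] -> [-1, 1]
    return (x - 0.5) / 0.5


def denormalize(x):
    # Inverse mapping [-1, 1] -> [0, 1]; clamp guards against tiny rounding overshoots
    return (x * 0.5 + 0.5).clamp(0.0, 1.0)


torch.manual_seed(0)
img = torch.rand(3, 4, 4)                 # fake CHW image with values in [0, 1)
norm = normalize(img)                     # what the DataLoader hands to the model
hwc = denormalize(norm).permute(1, 2, 0)  # HWC layout expected by plt.imshow
print(norm.min().item() >= -1.0, hwc.shape)
```

Without the inverse step, imshow clips every negative value to 0, so the displayed image looks darker and washed out even when the prediction is correct.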
7. The results of running the code are shown below
