Pytorch---Four-Class Weather Classification with PyTorch Pretrained Models
2022-06-23 03:45:00 【水哥很水】
1. The dataset used in the code can be obtained from the following link
2. Runtime environment
Pytorch-gpu==1.7.1
Python==3.7
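The versions above are the ones the author used. A minimal way to check what is installed locally, and whether the GPU build of PyTorch can actually see a CUDA device, is the short script below. It only reads version strings and assumes nothing beyond `torch` being importable.

```python
import sys

import torch

# Compare these against the versions listed above
print("Python :", sys.version.split()[0])
print("PyTorch:", torch.__version__)
# The training script uses 'cuda' only when this is True, otherwise it runs on the CPU
print("CUDA available:", torch.cuda.is_available())
```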
3. The data-processing code is as follows
```python
# data_loader.py
import os

import torchvision
from torch.utils.data import DataLoader
from torchvision import transforms


def loader_data():
    BATCH_SIZE = 64
    # Training-time augmentation: rescale, random crop, flip, rotation, color jitter
    train_transform = transforms.Compose([
        transforms.Resize(224),                  # shorter side -> 224, aspect ratio kept
        transforms.RandomCrop(192),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(0.2),          # unit is degrees: only +/-0.2 degrees
        transforms.ColorJitter(brightness=0.5),
        transforms.ColorJitter(contrast=0.5),
        transforms.ToTensor(),                   # PIL image -> float tensor in [0, 1]
        transforms.Normalize(mean=[0.5, 0.5, 0.5],
                             std=[0.5, 0.5, 0.5])  # [0, 1] -> [-1, 1]
    ])
    # Test-time preprocessing: deterministic resize, no augmentation
    test_transform = transforms.Compose([
        transforms.Resize((192, 192)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5],
                             std=[0.5, 0.5, 0.5])
    ])
    # ImageFolder expects one subdirectory per class under each root
    train_ds = torchvision.datasets.ImageFolder(root=os.path.join('dataset', 'train_weather'),
                                                transform=train_transform)
    test_ds = torchvision.datasets.ImageFolder(root=os.path.join('dataset', 'test_weather'),
                                               transform=test_transform)
    train_dl = DataLoader(dataset=train_ds, batch_size=BATCH_SIZE, shuffle=True)
    test_dl = DataLoader(dataset=test_ds, batch_size=BATCH_SIZE)
    return train_dl, test_dl, test_ds.class_to_idx
```
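`ImageFolder` treats every subdirectory of `dataset/train_weather` and `dataset/test_weather` as one class and numbers the classes by sorting the folder names. The sketch below reproduces that mapping with the standard library only, so it can be checked without any images. The helper `class_to_idx_like_imagefolder` and the four folder names are hypothetical placeholders, not necessarily the names used in the real dataset.

```python
import os
import tempfile


def class_to_idx_like_imagefolder(root):
    """Hypothetical helper mimicking how ImageFolder builds class_to_idx:
    each immediate subdirectory is a class, indices follow sorted names."""
    classes = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
    return {name: idx for idx, name in enumerate(classes)}


# Build a throwaway layout like dataset/train_weather/<class_name>/
# (placeholder class names, created in a non-sorted order on purpose)
with tempfile.TemporaryDirectory() as tmp:
    root = os.path.join(tmp, "dataset", "train_weather")
    for name in ["sunrise", "cloudy", "shine", "rain"]:
        os.makedirs(os.path.join(root, name))
    mapping = class_to_idx_like_imagefolder(root)

print(mapping)  # {'cloudy': 0, 'rain': 1, 'shine': 2, 'sunrise': 3}
```

Because `loader_data` returns `test_ds.class_to_idx` while the training labels come from `train_ds`, both folders need the same set of class subdirectories for the indices to line up.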
4. The model-building code is as follows
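```python
# model_loader.py
import torch
import torchvision


def load_model():
    # VGG16 pretrained on ImageNet (torchvision >= 0.13 prefers weights= over pretrained=)
    model = torchvision.models.vgg16(pretrained=True)
    # Freeze the convolutional feature extractor
    for p in model.features.parameters():
        p.requires_grad = False
    # Replace the last classifier layer with a 4-class one; assigning
    # out_features alone would not resize the weights (still 1000 outputs)
    in_f = model.classifier[-1].in_features
    model.classifier[-1] = torch.nn.Linear(in_features=in_f, out_features=4)
    return model


def load_resnet18():
    model = torchvision.models.resnet18(pretrained=True)
    # Freeze the whole pretrained backbone
    for param in model.parameters():
        param.requires_grad = False
    # The new fc layer is created after freezing, so its parameters stay trainable
    in_f = model.fc.in_features
    model.fc = torch.nn.Linear(in_features=in_f, out_features=4)
    return model
```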
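Assigning `out_features = 4` to an existing `torch.nn.Linear` only changes an attribute: the weight matrix keeps its ImageNet shape, so the network still produces 1000 logits. The sketch below shows this with a plain `Linear` layer of the same size as VGG16's last classifier layer (4096 to 1000), so no pretrained weights need to be downloaded. It also counts the trainable parameters of a 512-to-4 head like the one `load_resnet18` attaches.

```python
import torch

x = torch.randn(2, 4096)  # a fake batch of 2 feature vectors

# Stand-in for VGG16's last classifier layer
stale = torch.nn.Linear(in_features=4096, out_features=1000)
stale.out_features = 4          # only the attribute changes
print(stale(x).shape)           # torch.Size([2, 1000]): weights are still 1000 x 4096

# Replacing the module is what actually changes the output size
head = torch.nn.Linear(in_features=stale.in_features, out_features=4)
print(head(x).shape)            # torch.Size([2, 4])

# ResNet18's fc takes 512 features; a 4-class head has 512 * 4 + 4 parameters
fc = torch.nn.Linear(512, 4)
print(sum(p.numel() for p in fc.parameters() if p.requires_grad))  # 2052
```

With the backbone frozen, these 2052 values are the only ones the ResNet18 optimizer updates.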
5. The model-training code is as follows
```python
import os

import numpy as np
import torch
import tqdm
from sklearn.metrics import accuracy_score
from torch.optim import lr_scheduler

from data_loader import loader_data
from model_loader import load_model, load_resnet18

# Load the data
train_dl, test_dl, class_to_idx = loader_data()
# Load the model
model = load_resnet18()
# Training configuration: only the new fc layer is optimized
optimizer = torch.optim.Adam(model.fc.parameters(), lr=0.0001)
# Multiply the learning rate by 0.9 every 5 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer=optimizer, step_size=5, gamma=0.9)
loss_fn = torch.nn.CrossEntropyLoss()
device = 'cuda' if torch.cuda.is_available() else 'cpu'
EPOCHS = 100
# Training
model = model.to(device)
for epoch in range(EPOCHS):
    # Training phase
    model.train()
    train_tqdm = tqdm.tqdm(iterable=train_dl, total=len(train_dl))
    train_tqdm.set_description_str('Train epoch {:2d}'.format(epoch))
    train_accuracy_sum = []
    train_loss_sum = []
    for images, labels in train_tqdm:
        images, labels = images.to(device), labels.to(device)
        pred = model(images)
        loss = loss_fn(pred, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Show the running training loss and accuracy
        train_loss_sum.append(loss.item())
        pred = torch.argmax(input=pred, dim=-1)
        train_accuracy_sum.append(accuracy_score(y_true=labels.cpu().numpy(), y_pred=pred.cpu().numpy()))
        train_tqdm.set_postfix_str(
            'loss is {:14f}, accuracy is {:14f}'.format(np.mean(train_loss_sum), np.mean(train_accuracy_sum)))
    train_tqdm.close()
    # Learning-rate schedule: step once per epoch, after the optimizer updates
    exp_lr_scheduler.step()
    # Validation phase: no gradients, eval-mode layers
    with torch.no_grad():
        model.eval()
        test_tqdm = tqdm.tqdm(iterable=test_dl, total=len(test_dl))
        test_tqdm.set_description_str('Val epoch {:2d}'.format(epoch))
        test_accuracy_sum = []
        test_loss_sum = []
        for images, labels in test_tqdm:
            images, labels = images.to(device), labels.to(device)
            pred = model(images)
            loss = loss_fn(pred, labels)
            # Show the running validation loss and accuracy
            test_loss_sum.append(loss.item())
            pred = torch.argmax(input=pred, dim=-1)
            test_accuracy_sum.append(accuracy_score(y_true=labels.cpu().numpy(), y_pred=pred.cpu().numpy()))
            test_tqdm.set_postfix_str(
                'loss is {:14f}, accuracy is {:14f}'.format(np.mean(test_loss_sum), np.mean(test_accuracy_sum)))
        test_tqdm.close()

# Save the model weights
os.makedirs('model_data', exist_ok=True)
torch.save(model.state_dict(), os.path.join('model_data', 'model.pth'))
```
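With `StepLR(step_size=5, gamma=0.9)` and one `exp_lr_scheduler.step()` per epoch, the learning rate used in epoch `e` is `lr * gamma ** (e // step_size)`. The small helper below (a hypothetical `step_lr`, not part of PyTorch) evaluates that formula so the schedule across the 100 epochs can be inspected without training.

```python
def step_lr(base_lr, epoch, step_size=5, gamma=0.9):
    """Hypothetical helper: learning rate StepLR gives in a given epoch
    when scheduler.step() is called once at the end of every epoch."""
    return base_lr * gamma ** (epoch // step_size)


base_lr = 1e-4
for epoch in (0, 4, 5, 10, 99):
    print(epoch, f"{step_lr(base_lr, epoch):.3e}")
# 0 1.000e-04
# 4 1.000e-04
# 5 9.000e-05
# 10 8.100e-05
# 99 1.351e-05
```

During training, `exp_lr_scheduler.get_last_lr()` returns the value actually in use, which should match these numbers. With these settings the rate only drops to about 13.5% of its starting value by the last epoch, so the decay is gentle.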
6. The model-prediction code is as follows
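```python
import os

import matplotlib
import matplotlib.pyplot as plt
import torch

from data_loader import loader_data
from model_loader import load_model, load_resnet18

# Load the data and take the first test batch
train_dl, test_dl, class_index = loader_data()
image, label = next(iter(test_dl))
# Invert class_to_idx so predicted indices map back to class names
new_class = dict((v, k) for k, v in class_index.items())
# Load the model; map_location lets GPU-trained weights load on a CPU-only machine
model = load_resnet18()
model_state_dict = torch.load(os.path.join('model_data', 'model.pth'), map_location='cpu')
model.load_state_dict(model_state_dict)
model.eval()
# Predict the whole batch and display one sample (the batch must hold at least 24 images)
index = 23
with torch.no_grad():
    pred = model(image)
pred = torch.argmax(input=pred, dim=-1)
# matplotlib.rc("font", family='Microsoft YaHei')  # only needed for Chinese text in the title
plt.axis('off')
plt.title('predict result: ' + new_class.get(pred[index].item())
          + ', label result: ' + new_class.get(label[index].item()))
# Undo Normalize(mean=0.5, std=0.5) so pixel values are back in [0, 1] for imshow
plt.imshow((image[index] * 0.5 + 0.5).permute(1, 2, 0).numpy())
plt.savefig('result.png')
plt.show()
```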
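The prediction script keeps only the `argmax` of the logits. If a confidence score is also wanted, the logits can be passed through a softmax first. The sketch below uses made-up logits and a hypothetical four-class mapping in place of the real model output and `class_to_idx`.

```python
import torch

# Hypothetical mapping in the same form ImageFolder's class_to_idx returns
class_to_idx = {'cloudy': 0, 'rain': 1, 'shine': 2, 'sunrise': 3}
idx_to_class = {v: k for k, v in class_to_idx.items()}

# Made-up logits for a batch of 2 images and 4 classes
logits = torch.tensor([[2.0, 0.5, 0.1, -1.0],
                       [0.0, 0.0, 3.0, 1.0]])
probs = torch.softmax(logits, dim=-1)           # each row now sums to 1
confidence, pred_idx = probs.max(dim=-1)        # top probability and its class index

for c, i in zip(confidence.tolist(), pred_idx.tolist()):
    print(f"{idx_to_class[i]}: {c:.2%}")
```

Softmax does not change which class has the largest value, so `pred_idx` is the same as `torch.argmax(logits, dim=-1)`; it only adds a probability that can be shown next to the label in the plot title.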
7. The results of running the code are shown below
