当前位置:网站首页>PyTorch — Using PyTorch's Pre-trained Models to Solve a Four-Class Weather Classification Problem
PyTorch — Using PyTorch's Pre-trained Models to Solve a Four-Class Weather Classification Problem
2022-06-23 04:20:00 【Brother Shui is very water】
1. Dataset: the dataset used in the code can be obtained from the following link
Baidu Netdisk (extraction code: lala)
2. Runtime environment
Pytorch-gpu==1.7.1
Python==3.7
3. Dataset processing code
import torchvision
from torchvision import transforms
import os
from torch.utils.data import DataLoader
def loader_data(batch_size=64, root='dataset'):
    """Build the train and test DataLoaders for the four-class weather dataset.

    Images are read with ``torchvision.datasets.ImageFolder`` from
    ``<root>/train_weather`` and ``<root>/test_weather``. Each sub-folder is
    one class.

    Args:
        batch_size: number of samples per batch for both loaders
            (defaults to 64, the value previously hard-coded).
        root: directory containing the ``train_weather`` and
            ``test_weather`` folders (defaults to ``'dataset'``).

    Returns:
        A tuple ``(train_dl, test_dl, class_to_idx)``:
        - the shuffled training loader;
        - the unshuffled test loader;
        - the test set's class-name -> label-index mapping.
    """
    # Training-time augmentation.
    # NOTE(review): RandomRotation(0.2) rotates by at most +/-0.2 degrees,
    # which is almost a no-op. Keras' RandomRotation(0.2) means a fraction of
    # 2*pi, so confirm which was intended.
    # NOTE(review): training resizes the shorter side to 224 and crops
    # 192x192, while the test transform squashes the whole image to 192x192.
    # Confirm the scale/aspect mismatch is intended.
    train_transform = transforms.Compose([
        transforms.Resize(224),
        transforms.RandomCrop(192),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(0.2),
        transforms.ColorJitter(brightness=0.5),
        transforms.ColorJitter(contrast=0.5),
        transforms.ToTensor(),
        # Maps [0, 1] pixels to [-1, 1].
        # NOTE(review): ImageNet-pretrained backbones are usually fed
        # ImageNet mean/std. Kept as-is so that saved checkpoints stay valid.
        transforms.Normalize(mean=[0.5, 0.5, 0.5],
                             std=[0.5, 0.5, 0.5]),
    ])
    # Deterministic preprocessing for evaluation: same normalisation, no augmentation.
    test_transform = transforms.Compose([
        transforms.Resize((192, 192)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5],
                             std=[0.5, 0.5, 0.5]),
    ])
    train_ds = torchvision.datasets.ImageFolder(root=os.path.join(root, 'train_weather'),
                                                transform=train_transform)
    test_ds = torchvision.datasets.ImageFolder(root=os.path.join(root, 'test_weather'),
                                               transform=test_transform)
    train_dl = DataLoader(dataset=train_ds, batch_size=batch_size, shuffle=True)
    test_dl = DataLoader(dataset=test_ds, batch_size=batch_size)
    # The mapping comes from the test folder. This assumes train_weather and
    # test_weather contain the same class folder names (TODO confirm).
    return train_dl, test_dl, test_ds.class_to_idx
4. Model construction code
import torch
import torchvision
def load_model():
    """Return an ImageNet-pretrained VGG16 adapted to 4 output classes.

    The convolutional ``features`` stack is frozen. The classifier stays
    trainable, and its last Linear layer is replaced so that the network
    emits 4 logits.
    """
    model = torchvision.models.vgg16(pretrained=True)
    # Freeze only the convolutional feature extractor.
    for p in model.features.parameters():
        p.requires_grad = False
    # Assigning `out_features = 4` on the existing layer only changes an
    # attribute: its weight would still produce 1000 outputs. Build a new
    # Linear with the same input width instead.
    in_f = model.classifier[-1].in_features
    model.classifier[-1] = torch.nn.Linear(in_features=in_f, out_features=4)
    return model
def load_resnet18():
    """Return an ImageNet-pretrained ResNet-18 with a fresh 4-class head.

    Every pretrained parameter is frozen. The replacement ``fc`` layer is
    created afterwards, so it is the only trainable part of the network.
    """
    backbone = torchvision.models.resnet18(pretrained=True)
    # Turn off gradients for all pretrained weights in one call.
    backbone.requires_grad_(False)
    # The new head keeps the original input width and outputs 4 logits.
    backbone.fc = torch.nn.Linear(in_features=backbone.fc.in_features, out_features=4)
    return backbone
5. Model training code
import torch
from data_loader import loader_data
from model_loader import load_model, load_resnet18
import numpy as np
import tqdm
import os
from sklearn.metrics import accuracy_score
from torch.optim import lr_scheduler
# Training script: train the new head of a frozen, pretrained ResNet-18 on
# the weather dataset. Reports per-epoch train/validation loss and accuracy,
# then saves the final weights to model_data/model.pth.

# Data loading
train_dl, test_dl, class_to_idx = loader_data()
# Model loading. load_resnet18() freezes the backbone, so only `fc` trains.
model = load_resnet18()
# Training related configurations.
# Only the new head's parameters are handed to the optimizer.
optimizer = torch.optim.Adam(model.fc.parameters(), lr=0.0001)
# Multiply the learning rate by 0.9 every 5 epochs.
exp_lr_scheduler = lr_scheduler.StepLR(optimizer=optimizer, step_size=5, gamma=0.9)
loss_fn = torch.nn.CrossEntropyLoss()
device = 'cuda' if torch.cuda.is_available() else 'cpu'
EPOCHS = 100
MODEL_DIR = 'model_data'


def _run_epoch(loader, description, training):
    """Run one pass over ``loader``; return ``(mean_loss, accuracy)``.

    In training mode the optimizer is stepped after every batch. Otherwise
    the model runs in eval mode with gradients disabled.

    Loss and accuracy are weighted by batch size. A smaller final batch
    therefore counts in proportion to its size, unlike a plain mean of
    per-batch values.
    """
    model.train(training)
    loss_total = 0.0
    correct = 0
    seen = 0
    progress = tqdm.tqdm(iterable=loader, total=len(loader))
    progress.set_description_str(description)
    with torch.set_grad_enabled(training):
        for images, labels in progress:
            images, labels = images.to(device), labels.to(device)
            logits = model(images)
            loss = loss_fn(logits, labels)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            # Accumulate sample-weighted statistics for the progress bar.
            batch_size = labels.size(0)
            loss_total += loss.item() * batch_size
            correct += (torch.argmax(logits, dim=-1) == labels).sum().item()
            seen += batch_size
            progress.set_postfix_str(
                'loss is {:14f}, accuracy is {:14f}'.format(loss_total / seen, correct / seen))
    progress.close()
    return loss_total / max(seen, 1), correct / max(seen, 1)


# Training
model = model.to(device)
for epoch in range(EPOCHS):
    _run_epoch(train_dl, 'Train epoch {:2d}'.format(epoch), training=True)
    # Advance the learning-rate schedule once per epoch, after the optimizer steps.
    exp_lr_scheduler.step()
    # Verification part
    _run_epoch(test_dl, 'Val epoch {:2d}'.format(epoch), training=False)

# Save model. exist_ok avoids the check-then-create race of exists()+mkdir().
os.makedirs(MODEL_DIR, exist_ok=True)
torch.save(model.state_dict(), os.path.join(MODEL_DIR, 'model.pth'))
6. Model prediction code
import os
import torch
from data_loader import loader_data
from model_loader import load_model, load_resnet18
import matplotlib.pyplot as plt
import matplotlib
# Prediction script: restore the trained ResNet-18 weights, classify the
# first test batch, and plot one sample with its predicted and true class.

# Data loading. Only the first batch of the (unshuffled) test loader is used.
train_dl, test_dl, class_index = loader_data()
image, label = next(iter(test_dl))
# Invert class_to_idx so that a predicted index maps back to a class name.
new_class = {v: k for k, v in class_index.items()}
# Model loading. map_location='cpu' lets weights saved during a GPU run be
# restored on a CPU-only machine. The batch above is on the CPU too.
model = load_resnet18()
model_state_dict = torch.load(os.path.join('model_data', 'model.pth'), map_location='cpu')
model.load_state_dict(model_state_dict)
model.eval()
# Predict the model
index = 23
if not 0 <= index < len(label):
    raise IndexError('index {} is out of range for a test batch of {} images'.format(
        index, len(label)))
with torch.no_grad():
    pred = model(image)
    pred = torch.argmax(input=pred, dim=-1)
plt.axis('off')
plt.title('predict result: ' + new_class.get(pred[index].item()) + ', label result: ' + new_class.get(
    label[index].item()),
)
# The test transform normalises with mean=0.5 and std=0.5, mapping pixels to
# [-1, 1]. Undo that so imshow gets [0, 1] values instead of clipping them.
plt.imshow((image[index] * 0.5 + 0.5).clamp(0, 1).permute(1, 2, 0))
plt.savefig('result.png')
plt.show()
7. Results

边栏推荐
- Using jhipster to build microservice architecture
- pyspark,有偿询问数据清洗和上传到数据库的问题
- Three ways to export excel from pages
- 最新编程语言排行榜
- 怎样能在小程序中实现视频通话及互动直播功能?
- 京东云分布式数据库StarDB荣获中国信通院 “稳定性实践先锋”
- 深度学习 TensorFlow入门
- 华为联机对战服务玩家快速匹配后,不同玩家收到的同一房间内玩家列表不同
- 【LeetCode】23. Merge K ascending linked lists
- [advanced binary tree] AVLTree - balanced binary search tree
猜你喜欢

在word里,如何让页码从指定页开始编号

Talk about memory model and memory order

Review the SQL row column conversion, and the performance has been improved

给你的AppImage创建桌面快捷方式

MySQL data recovery (.Ibdata1, bin log)

Twitter与Shopify合作 将商家产品引入Twitter购物当中

摆烂LuoGu刷题记

IDEA-导入模块
![[OWT] OWT client native P2P E2E test vs2017 construction 4: Construction and link of third-party databases p2pmfc exe](/img/cd/7f896a0f05523a07b5dd04a8737879.png)
[OWT] OWT client native P2P E2E test vs2017 construction 4: Construction and link of third-party databases p2pmfc exe

What if the self incrementing IDs of online MySQL are exhausted?
随机推荐
mysql如何删除表的一行数据
【owt】owt-client-native-p2p-e2e-test vs2017构建 3 : 无 测试单元对比, 手动生成vs项目
【owt】owt-client-native-p2p-e2e-test vs2017构建 4 : 第三方库的构建及链接p2pmfc.exe
在 KubeSphere 上部署 Apache Pulsar
P1347 排序(topo)
Pytorch---Pytorch进行自定义Dataset
炫酷鼠标跟随动画js插件5种
Create a desktop shortcut to your appimage
深度学习 TensorFlow入门
pyspark,有偿询问数据清洗和上传到数据库的问题
photoshop PS 查看像素坐标、像素颜色、像素HSB颜色
众昂矿业:新能源新材料产业链对萤石需求大增
[leetcode] sum of two numbers II
Using jhipster to build microservice architecture
折半查找法
深度学习 简介
【LeetCode】23. Merge K ascending linked lists
For patch rollback, please check the cbpersistent log
两招提升硬盘存储数据的写入效率
【二叉樹進階】AVLTree - 平衡二叉搜索樹