当前位置:网站首页>PyTorch --- Using PyTorch's pre-trained models for four-class weather classification
PyTorch --- Using PyTorch's pre-trained models for four-class weather classification
2022-06-23 04:20:00 【Brother Shui is very water】
1. The dataset used in the code can be obtained from the following link
Baidu Netdisk (extraction code: lala)
2. Runtime environment
Pytorch-gpu==1.7.1
Python==3.7
3. The dataset-processing code is as follows
import torchvision
from torchvision import transforms
import os
from torch.utils.data import DataLoader
def loader_data(batch_size=64, root='dataset'):
    """Build the train/test DataLoaders for the weather-classification dataset.

    Images are read with ``torchvision.datasets.ImageFolder`` from
    ``<root>/train_weather`` and ``<root>/test_weather`` (one sub-folder per
    class; ImageFolder derives the integer labels from the folder names).

    Args:
        batch_size: Images per batch for both loaders (default 64, the value
            that used to be hard-coded).
        root: Directory that contains the ``train_weather`` and
            ``test_weather`` folders (default ``'dataset'``).

    Returns:
        ``(train_dl, test_dl, class_to_idx)``. ``train_dl`` shuffles,
        ``test_dl`` does not. ``class_to_idx`` maps class-folder name to
        integer label.

    Raises:
        ValueError: if the train and test folders yield different class
            mappings. In that case test labels would not line up with what
            the model was trained on.
    """
    train_transform = transforms.Compose([
        # Resize the shorter side to 224, then take a random 192x192 crop.
        transforms.Resize(224),
        transforms.RandomCrop(192),
        transforms.RandomHorizontalFlip(),
        # NOTE(review): torchvision reads a single number as a range of
        # +/-0.2 *degrees*, which is practically a no-op. If a Keras-style
        # "fraction of a full turn" was intended, use degrees (e.g. 72).
        # The value is kept as-is to preserve current behaviour.
        transforms.RandomRotation(0.2),
        transforms.ColorJitter(brightness=0.5),
        transforms.ColorJitter(contrast=0.5),
        transforms.ToTensor(),
        # Maps the [0, 1] tensor produced by ToTensor to [-1, 1].
        # NOTE(review): torchvision's ImageNet-pretrained weights expect
        # mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]; consider
        # matching them (and the de-normalisation used for plotting).
        transforms.Normalize(mean=[0.5, 0.5, 0.5],
                             std=[0.5, 0.5, 0.5]),
    ])
    # Evaluation: deterministic resize straight to the 192x192 training crop size.
    test_transform = transforms.Compose([
        transforms.Resize((192, 192)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5],
                             std=[0.5, 0.5, 0.5]),
    ])
    train_ds = torchvision.datasets.ImageFolder(root=os.path.join(root, 'train_weather'),
                                                transform=train_transform)
    test_ds = torchvision.datasets.ImageFolder(root=os.path.join(root, 'test_weather'),
                                               transform=test_transform)
    # Labels are only meaningful if both splits map folder names identically.
    if train_ds.class_to_idx != test_ds.class_to_idx:
        raise ValueError('train/test class folders differ: {} vs {}'.format(
            train_ds.class_to_idx, test_ds.class_to_idx))
    train_dl = DataLoader(dataset=train_ds, batch_size=batch_size, shuffle=True)
    test_dl = DataLoader(dataset=test_ds, batch_size=batch_size)
    return train_dl, test_dl, test_ds.class_to_idx
4. The model-construction code is as follows
import torch
import torchvision
def load_model(num_classes=4):
    """Return an ImageNet-pretrained VGG16 re-headed for ``num_classes`` outputs.

    The convolutional ``features`` part is frozen (``requires_grad=False``).
    All ``classifier`` layers stay trainable, including the new final layer.

    Args:
        num_classes: Number of output logits (default 4, the four weather
            classes).

    Returns:
        The modified ``torchvision.models.vgg16`` instance.
    """
    model = torchvision.models.vgg16(pretrained=True)
    for p in model.features.parameters():
        p.requires_grad = False
    # Setting `model.classifier[-1].out_features = 4` would only change an
    # attribute: the layer's weight/bias would keep their 1000-class shape and
    # the model would still emit 1000 logits. Replace the layer instead.
    in_f = model.classifier[-1].in_features
    model.classifier[-1] = torch.nn.Linear(in_features=in_f, out_features=num_classes)
    return model
def load_resnet18(num_classes=4):
    """Return an ImageNet-pretrained ResNet-18 with a new ``num_classes``-way head.

    Every pretrained parameter is frozen. ``model.fc`` is then replaced by a
    freshly initialised ``Linear`` layer, whose parameters are trainable by
    default, so only the head learns.

    Args:
        num_classes: Number of output logits (default 4, the four weather
            classes).

    Returns:
        The modified ``torchvision.models.resnet18`` instance.
    """
    model = torchvision.models.resnet18(pretrained=True)
    for param in model.parameters():
        param.requires_grad = False
    in_f = model.fc.in_features
    # Created after the freeze loop, so its parameters keep requires_grad=True.
    model.fc = torch.nn.Linear(in_features=in_f, out_features=num_classes)
    return model
5. The model-training code is as follows
import torch
from data_loader import loader_data
from model_loader import load_model, load_resnet18
import numpy as np
import tqdm
import os
from sklearn.metrics import accuracy_score
from torch.optim import lr_scheduler
# Data loading
train_dl, test_dl, class_to_idx = loader_data()
# Model loading: ResNet-18 with a frozen backbone and a new 4-way `fc` head.
model = load_resnet18()
# Training related configurations.
# Only the new `fc` head is optimised; the rest was frozen in load_resnet18().
optimizer = torch.optim.Adam(model.fc.parameters(), lr=0.0001)
# Multiply the learning rate by 0.9 every 5 epochs.
exp_lr_scheduler = lr_scheduler.StepLR(optimizer=optimizer, step_size=5, gamma=0.9)
loss_fn = torch.nn.CrossEntropyLoss()
device = 'cuda' if torch.cuda.is_available() else 'cpu'
EPOCHS = 100
# Training
model = model.to(device)
for epoch in range(EPOCHS):
    # ---- Training part ----
    model.train()
    train_tqdm = tqdm.tqdm(iterable=train_dl, total=len(train_dl))
    train_tqdm.set_description_str('Train epoch {:2d}'.format(epoch))
    # Running totals weighted by sample count. A plain mean of per-batch
    # values would give a short final batch the same weight as a full one.
    train_loss_total = 0.0
    train_correct = 0
    train_seen = 0
    for images, labels in train_tqdm:
        images, labels = images.to(device), labels.to(device)
        pred = model(images)
        loss = loss_fn(pred, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Show the training part
        batch_n = labels.size(0)
        # CrossEntropyLoss (default reduction) returns the batch mean, so
        # scale back up to a per-sample sum before accumulating.
        train_loss_total += loss.item() * batch_n
        train_correct += (torch.argmax(input=pred, dim=-1) == labels).sum().item()
        train_seen += batch_n
        train_tqdm.set_postfix_str(
            'loss is {:14f}, accuracy is {:14f}'.format(train_loss_total / train_seen,
                                                        train_correct / train_seen))
    train_tqdm.close()
    # Learning rate: advance the StepLR schedule once per epoch.
    exp_lr_scheduler.step()
    # ---- Verification part ----
    model.eval()
    with torch.no_grad():
        test_tqdm = tqdm.tqdm(iterable=test_dl, total=len(test_dl))
        test_tqdm.set_description_str('Val epoch {:2d}'.format(epoch))
        test_loss_total = 0.0
        test_correct = 0
        test_seen = 0
        for images, labels in test_tqdm:
            images, labels = images.to(device), labels.to(device)
            pred = model(images)
            loss = loss_fn(pred, labels)
            # Display the verification results
            batch_n = labels.size(0)
            test_loss_total += loss.item() * batch_n
            test_correct += (torch.argmax(input=pred, dim=-1) == labels).sum().item()
            test_seen += batch_n
            test_tqdm.set_postfix_str(
                'loss is {:14f}, accuracy is {:14f}'.format(test_loss_total / test_seen,
                                                            test_correct / test_seen))
        test_tqdm.close()
    # Save model: overwritten every epoch, so the latest weights survive an
    # interrupted run. After the final epoch the file holds the final weights.
    os.makedirs('model_data', exist_ok=True)
    torch.save(model.state_dict(), os.path.join('model_data', 'model.pth'))
6. The model-prediction code is as follows
import os
import torch
from data_loader import loader_data
from model_loader import load_model, load_resnet18
import matplotlib.pyplot as plt
import matplotlib
# Data loading: only the first batch of the (unshuffled) test loader is used.
train_dl, test_dl, class_index = loader_data()
image, label = next(iter(test_dl))
# Invert class_to_idx so integer labels map back to class-folder names.
new_class = {v: k for k, v in class_index.items()}
# Model loading
model = load_resnet18()
# map_location='cpu' lets weights saved during a CUDA training run load on a
# CPU-only machine; inference below runs on CPU.
model_state_dict = torch.load(os.path.join('model_data', 'model.pth'), map_location='cpu')
model.load_state_dict(model_state_dict)
model.eval()
# Predict the model
# Position inside the first test batch to visualise. It is clamped so that a
# test set smaller than 24 images does not raise IndexError.
index = min(23, len(label) - 1)
with torch.no_grad():
    pred = model(image)
pred = torch.argmax(input=pred, dim=-1)
# Optional: switch to a font that can render Chinese text in the title.
# matplotlib.rc("font", family='Microsoft YaHei')
plt.axis('off')
plt.title('predict result: ' + new_class[pred[index].item()] + ', label result: ' + new_class[
    label[index].item()],
)
# The loader applied Normalize(mean=0.5, std=0.5), i.e. pixel values are in
# [-1, 1]. Undo it so imshow receives the [0, 1] floats it expects, in HWC order.
plt.imshow((image[index] * 0.5 + 0.5).clamp(0, 1).permute(1, 2, 0))
plt.savefig('result.png')
plt.show()
7. The output of the code is as follows

边栏推荐
猜你喜欢

Full analysis of embedded software testing tool tpt18 update

Pytorch---使用Pytorch的预训练模型实现四种天气分类问题

冒泡排序法

软件项目管理 8.4.软件项目质量计划

Redis启动有问题

8 key indicators to measure technology debt in 2022

【owt】owt-client-native-p2p-e2e-test vs2017构建2 :测试单元构建及运行

SVG+JS智能家居监控网格布局

Insert sort directly

支持在 Kubernetes 运行,添加多种连接器,SeaTunnel 2.1.2 版本正式发布!
随机推荐
Ideal car × Oceanbase: when new forces of car building meet new forces of database
Common events for elements
基于HAProxy实现网页动静分离
理想汽车×OceanBase:当造车新势力遇上数据库新势力
[OWT] OWT client native P2P E2E test vs2017 build 3: no test unit comparison, manually generate vs projects
嵌入式软件测试工具TPT18更新全解析
电商如何借助小程序发力
如何处理大体积 XLSX/CSV/TXT 文件?
How to realize data transaction
Two ways to improve the writing efficiency of hard disk storage data
[leetcode] flip linked list II
Goframe framework: quick creation of static file download web service
虫子 STM32 高级定时器 (哈哈我说实话硬件定时器不能体现实力,实际上想把内核定时器发上来的,一想算了,慢慢来吧)
pyspark,有偿询问数据清洗和上传到数据库的问题
What is the difference between redistemplate and CacheManager operation redis
【owt】owt-client-native-p2p-e2e-test vs2017构建 4 : 第三方库的构建及链接p2pmfc.exe
[pycharm] ide Eval resetter
虫子 STM32 中断 (懂的都懂)
Halcon胶线检测—模板匹配、位姿变换、胶宽,胶连续性检测
Talk about memory model and memory order