PyTorch project practice: FashionMNIST fashion classification
2022-07-25 12:21:00 【Alexa2077】
One. The Fashion-MNIST fashion classification workflow based on PyTorch
The main code and text in this article come from the DataWhale team's "Thorough PyTorch" (深入浅出PyTorch) course.
Reference link :https://datawhalechina.github.io/thorough-pytorch
1. Task introduction
Task: classify fashion images into 10 categories using the FashionMNIST dataset. The figure below shows several sample images; each image corresponds to one sample.
Dataset: FashionMNIST comes pre-split into training and test sets. The training set contains 60,000 images and the test set 10,000 images. Each image is a single-channel grayscale image of size 28*28 pixels and belongs to one of 10 classes.
2. Classification workflow
1- Imports and hyperparameter configuration: the basic workflow is similar to the one in the previous section on PyTorch's main modules. Note that Windows users may need to set num_workers to 0.
# Imports
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
# Configure the GPU; there are two approaches
## Option 1: use os.environ
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
## Option 2: define a "device" object and later move models/tensors onto the GPU with .to(device) (see the short sketch after this configuration block)
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
## Configure other hyperparameters such as batch_size, num_workers, learning rate, and epochs
batch_size = 256
num_workers = 4   # Windows users should set this to 0, otherwise multiprocessing errors may occur
lr = 1e-4
epochs = 20
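As a side note on option 2: the training code later in this article moves data with .cuda(), but with the device object defined above the equivalent calls look like the minimal, self-contained sketch below (illustrative only, not part of the original tutorial).

x = torch.randn(2, 1, 28, 28)             # a dummy batch of two 28x28 grayscale images
x = x.to(device)                           # moved to the GPU if one is available, otherwise stays on CPU
layer = nn.Conv2d(1, 32, 5).to(device)     # modules are moved the same way
print(layer(x).device)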
2- Data loading: there are two approaches:
- Download and use the built-in dataset provided by torchvision.
- Download the data in csv format from the web, read it in, and convert it to the expected format.
## Approach 1: use the dataset shipped with torchvision; downloading may take a while
from torchvision import datasets
train_data = datasets.FashionMNIST(root='./', train=True, download=True, transform=data_transform)
test_data = datasets.FashionMNIST(root='./', train=False, download=True, transform=data_transform)
## Approach 2: read the csv data and build a custom Dataset class
# csv data download link: https://www.kaggle.com/zalando-research/fashionmnist
class FMDataset(Dataset):
    def __init__(self, df, transform=None):
        self.df = df
        self.transform = transform
        # first column is the label; the remaining 784 columns are pixel values
        self.images = df.iloc[:, 1:].values.astype(np.uint8)
        self.labels = df.iloc[:, 0].values

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx].reshape(28, 28, 1)
        label = int(self.labels[idx])
        if self.transform is not None:
            image = self.transform(image)
        else:
            image = torch.tensor(image / 255., dtype=torch.float)
        label = torch.tensor(label, dtype=torch.long)
        return image, label
train_df = pd.read_csv("./FashionMNIST/fashion-mnist_train.csv")
test_df = pd.read_csv("./FashionMNIST/fashion-mnist_test.csv")
train_data = FMDataset(train_df, data_transform)
test_data = FMDataset(test_df, data_transform)
The first approach only works for common benchmark datasets such as MNIST and CIFAR10, for which PyTorch provides official downloads. It is best suited to quick experiments (for example, checking whether an idea works on MNIST).
The second approach requires building your own Dataset class, which is essential when applying PyTorch to your own work.
3- Data preprocessing: after the data is read in, it needs to be converted into the format the model expects.
For example, the images need to be resized to a consistent size so they can be fed into the network, and the data needs to be converted to Tensors, and so on. These transformations can be done conveniently with the torchvision package, PyTorch's official image-processing toolbox; the built-in dataset approach mentioned above also relies on it.
# First define the data transforms
from torchvision import transforms
image_size = 28
data_transform = transforms.Compose([
    transforms.ToPILImage(),   # this step depends on how the data is read later; not needed with the built-in dataset
    transforms.Resize(image_size),
    transforms.ToTensor()
])
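As a quick sanity check (a minimal sketch added here, not in the original tutorial), the transform turns a raw 28x28x1 uint8 array into a 1x28x28 float tensor scaled to [0, 1]:

dummy = np.random.randint(0, 256, size=(28, 28, 1), dtype=np.uint8)   # a fake grayscale image
out = data_transform(dummy)
print(out.shape, out.dtype, out.min().item(), out.max().item())
# expected: torch.Size([1, 28, 28]) torch.float32, values in [0, 1]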
4- After building the training and test datasets, define DataLoader objects so that data can be loaded in batches during training and testing.
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=num_workers, drop_last=True)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=num_workers)
After loading, we can do some visualization, mainly to verify that the data was read correctly.
import matplotlib.pyplot as plt
image, label = next(iter(train_loader))
print(image.shape, label.shape)
plt.imshow(image[0][0], cmap="gray")
Printing the shapes here lets us check that the input is correct. The output is as follows:
torch.Size([256, 1, 28, 28])
torch.Size([256])
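To make the visualization easier to read, the integer label can also be mapped to the ten FashionMNIST class names (an optional sketch; the list below follows the standard FashionMNIST label ordering and is not part of the original code):

# Standard FashionMNIST class names, indexed by label 0-9
classes = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
           "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"]
plt.imshow(image[0][0], cmap="gray")
plt.title(classes[label[0].item()])
plt.show()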
5- Model design: build a CNN and put it on the GPU for training.
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2),
            nn.Dropout(0.3),
            nn.Conv2d(32, 64, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2),
            nn.Dropout(0.3)
        )
        self.fc = nn.Sequential(
            nn.Linear(64*4*4, 512),
            nn.ReLU(),
            nn.Linear(512, 10)
        )

    def forward(self, x):
        x = self.conv(x)
        x = x.view(-1, 64*4*4)   # flatten the 64 x 4 x 4 feature maps
        x = self.fc(x)
        # x = nn.functional.normalize(x)
        return x

model = Net()
model = model.cuda()
# model = nn.DataParallel(model).cuda()   # multi-GPU training; covered later in the course
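The 64*4*4 input size of the fully connected layer follows from the convolutional stack: 28 -> 24 after the first 5x5 convolution (no padding), -> 12 after max pooling, -> 8 after the second 5x5 convolution, -> 4 after the second max pooling, with 64 output channels. A quick check of this (a sketch added here, not in the original tutorial):

# Pass a dummy batch through the convolutional part to confirm the feature-map size
with torch.no_grad():
    dummy = torch.zeros(1, 1, 28, 28).cuda()
    print(model.conv(dummy).shape)   # expected: torch.Size([1, 64, 4, 4])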
6- Set the loss function and optimizer: we use the CrossEntropyLoss provided by torch.nn. It takes integer class labels directly (no one-hot encoding needed) and expects the model to output raw logits, so make sure the labels start from 0 and that the model does not add a softmax layer. This also shows that the parts of a PyTorch training pipeline are not independent; they need to be considered together. We use the Adam optimizer.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
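A tiny self-contained illustration of the point above (a sketch, not from the original article): CrossEntropyLoss takes raw logits of shape (batch, num_classes) and integer labels of shape (batch,):

logits = torch.randn(4, 10)            # fake model outputs for a batch of 4 samples and 10 classes
targets = torch.tensor([0, 3, 9, 1])   # integer class labels, starting from 0
print(nn.CrossEntropyLoss()(logits, targets))   # log-softmax is applied internally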
7- Training and validation: each is wrapped in a function for convenient reuse.
For training:
def train(epoch):
    model.train()
    train_loss = 0
    for data, label in train_loader:
        data, label = data.cuda(), label.cuda()
        optimizer.zero_grad()   # zero the gradients so they do not accumulate across batches
        output = model(data)
        loss = criterion(output, label)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()*data.size(0)
    train_loss = train_loss/len(train_loader.dataset)
    print('Epoch: {} \tTraining Loss: {:.6f}'.format(epoch, train_loss))
For validation:
def val(epoch):
    model.eval()   # switch to evaluation mode
    val_loss = 0
    gt_labels = []
    pred_labels = []
    with torch.no_grad():   # no gradient computation during validation
        for data, label in test_loader:
            data, label = data.cuda(), label.cuda()
            output = model(data)
            preds = torch.argmax(output, 1)
            gt_labels.append(label.cpu().data.numpy())
            pred_labels.append(preds.cpu().data.numpy())
            loss = criterion(output, label)   # the loss is computed but not backpropagated
            val_loss += loss.item()*data.size(0)
    val_loss = val_loss/len(test_loader.dataset)
    gt_labels, pred_labels = np.concatenate(gt_labels), np.concatenate(pred_labels)
    acc = np.sum(gt_labels==pred_labels)/len(pred_labels)
    print('Epoch: {} \tValidation Loss: {:.6f}, Accuracy: {:6f}'.format(epoch, val_loss, acc))
for epoch in range(1, epochs+1):
    train(epoch)
    val(epoch)
Result: the validation accuracy is about 92%.
8- Model saving: after training, torch.save can save either the model parameters (state_dict) or the whole model; checkpoints can also be saved during training.
save_path = "./FahionModel.pkl"
torch.save(model, save_path)
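The code above saves the whole model object. A minimal sketch of the other option mentioned above, saving and restoring only the parameters via the state_dict (the file name here is illustrative):

# Save only the parameters (state_dict), which is the more portable option
torch.save(model.state_dict(), "./FashionModel_state_dict.pkl")

# To restore: rebuild the architecture first, then load the parameters
model2 = Net().cuda()
model2.load_state_dict(torch.load("./FashionModel_state_dict.pkl"))
model2.eval()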
Two. A second hands-on PyTorch project
1. Project practice 2
Following the fashion-classification example above, find a project and practice on your own.
Article link: to be filled in!