Organizing data for torchvision.datasets.ImageFolder and using it
2022-07-23 22:33:00 【Kakananan】
In a typical neural-network training pipeline we need to build a Dataset and a DataLoader to feed data to the model during training. Usually we define our own Dataset class and override its __getitem__ and __len__ methods to build the Dataset. For a simple image classification task, however, there is no need to write a custom Dataset class: calling torchvision.datasets.ImageFolder builds the Dataset for us, which is very convenient.
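For reference, a hand-written Dataset usually looks like the minimal sketch below (illustrative only: MyImageDataset, img_paths and img_labels are placeholder names, not part of the original post); ImageFolder spares us exactly this boilerplate.

from PIL import Image
from torch.utils.data import Dataset

class MyImageDataset(Dataset):
    """A minimal custom Dataset sketch; paths and labels are placeholders."""
    def __init__(self, img_paths, img_labels, transform=None):
        self.img_paths = img_paths      # list of image file paths
        self.img_labels = img_labels    # list of integer class labels
        self.transform = transform      # optional preprocessing function

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        img = Image.open(self.img_paths[idx]).convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img, self.img_labels[idx]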
The code below follows Li Mu's Dive into Deep Learning (《动手学深度学习》).
1. Organizing the dataset into the format the function expects
Since we are calling an API, the dataset has to be organized the way the API expects. torchvision.datasets.ImageFolder requires the dataset to be laid out as follows:
A generic data loader where the images are arranged in this way:
root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png
···
root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png
Here dog and cat are the image labels. Under the root directory we create one folder per class and save each class's images inside its folder. If our images fall into three classes, apple, banana and orange, we create three folders named after the three labels and store each label's images in the corresponding folder.
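For instance, the empty class folders could be created like this (a trivial sketch; 'root' stands for your own dataset root):

import os

for label in ['apple', 'banana', 'orange']:
    os.makedirs(os.path.join('root', label), exist_ok=True)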
However, the dataset we get is not always in exactly this format, and reorganizing it by hand is very time-consuming. Below I take one dataset as an example and show how to reorganize it into the required format with a program; datasets in other formats can be handled in a similar way.
1.1 Original dataset format
The original dataset contains:
- test: images to be classified, 5 in total
- train: training images, 1000 in total
- trainLabels.csv: the label of each image in the training set
1.2 Organizing the dataset
import collections
import math
import os
import shutil
import pandas as pd
import torch
import torchvision
from torch import nn
from d2l import torch as d2l
''' Download the dataset '''
#@save
d2l.DATA_HUB['cifar10_tiny'] = (d2l.DATA_URL + 'kaggle_cifar10_tiny.zip',
                                '2068874e4b9a9f0fb07ebe0ad2b29754449ccacd')
data_dir = d2l.download_extract('cifar10_tiny')  # path of the extracted root directory
First, we use the following function to read the labels in the CSV file. It returns a dictionary that maps the file name (without its extension) to its label.
#@save
def read_csv_labels(fname):
    """Read fname and return a dictionary mapping file name to label"""
    with open(fname, 'r') as f:
        # Skip the header line (column names)
        lines = f.readlines()[1:]
    tokens = [l.rstrip().split(',') for l in lines]
    return dict(((name, label) for name, label in tokens))
labels = read_csv_labels(os.path.join(data_dir, 'trainLabels.csv'))
print('# training examples:', len(labels))
print('# classes:', len(set(labels.values())))
'''
Output:
# training examples: 1000
# classes: 10
where labels looks like:
{'1': 'frog', '2': 'truck', '3': 'truck', '4': 'deer', '5': 'automobile', '6': 'automobile', ...}
'''
Next, we define the reorg_train_valid function to split a validation set off the original training set. Its parameter valid_ratio is the ratio of the number of validation examples to the number of original training examples. More concretely, let n be the number of images in the class with the fewest examples and let r be the ratio. To ensure every class has the same number of validation images, the validation set takes max(⌊nr⌋, 1) images from each class, so the total size of the validation set is max(⌊nr⌋, 1) times the number of classes.
In this example the smallest class has 85 images and the ratio is 0.1, so the final validation set contains ⌊85×0.1⌋×10 = 80 images.
#@save
def copyfile(filename, target_dir):
    """Copy the file into the target directory"""
    os.makedirs(target_dir, exist_ok=True)
    shutil.copy(filename, target_dir)
#@save
def reorg_train_valid(data_dir, labels, valid_ratio):
    """Split the validation set off the original training set"""
    # The number of examples in the class with the fewest examples
    n = collections.Counter(labels.values()).most_common()[-1][1]
    # The number of examples per class in the validation set
    n_valid_per_label = max(1, math.floor(n * valid_ratio))
    label_count = {}
    # Iterate over all images in the training set
    for train_file in os.listdir(os.path.join(data_dir, 'train')):
        # Look up the label of this image
        label = labels[train_file.split('.')[0]]
        # Full path of the image
        fname = os.path.join(data_dir, 'train', train_file)
        # Copy the image into the folder of its label under train_valid
        copyfile(fname, os.path.join(data_dir, 'train_valid_test',
                                     'train_valid', label))
        # If this label's quota in the validation set is not yet full,
        # copy the image into the validation set under its label
        if label not in label_count or label_count[label] < n_valid_per_label:
            copyfile(fname, os.path.join(data_dir, 'train_valid_test',
                                         'valid', label))
            label_count[label] = label_count.get(label, 0) + 1
        # Otherwise copy it into the training set under its label
        else:
            copyfile(fname, os.path.join(data_dir, 'train_valid_test',
                                         'train', label))
    return n_valid_per_label
Running reorg_train_valid creates one large folder, train_valid_test, with three subfolders:
- train_valid: all the images, i.e. 1000
- valid: the validation images, i.e. 80
- train: the training images, i.e. 920
Inside each of these three folders there is one subfolder per class.
The reorg_test function below organizes the test set so that it can be read in the same way at prediction time.
#@save
def reorg_test(data_dir):
    """Organize the test set for easy reading at prediction time"""
    # Put every test image into an 'unknown' folder
    for test_file in os.listdir(os.path.join(data_dir, 'test')):
        copyfile(os.path.join(data_dir, 'test', test_file),
                 os.path.join(data_dir, 'train_valid_test', 'test',
                              'unknown'))
The test images are saved under a folder named unknown: ImageFolder expects at least one class subfolder, so this placeholder class lets us read the unlabeled test set the same way as the other splits.
Finally, we use a single function to call the previously defined read_csv_labels, reorg_train_valid and reorg_test.
def reorg_cifar10_data(data_dir, valid_ratio):
    """Takes the path of the root directory and the validation ratio"""
    labels = read_csv_labels(os.path.join(data_dir, 'trainLabels.csv'))
    reorg_train_valid(data_dir, labels, valid_ratio)
    reorg_test(data_dir)
valid_ratio = 0.1
reorg_cifar10_data(data_dir, valid_ratio)
After it runs, a folder named train_valid_test appears under the root directory, containing the four folders test, train, train_valid and valid.
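As a quick sanity check (a small sketch, not part of the original post), we can count the images in each split; given the numbers above we expect 920, 80, 1000 and 5:

# Count the images in each split, walking through all class subfolders
for folder in ['train', 'valid', 'train_valid', 'test']:
    path = os.path.join(data_dir, 'train_valid_test', folder)
    total = sum(len(files) for _, _, files in os.walk(path))
    print(folder, total)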
2. Calling torchvision.datasets.ImageFolder to build the Dataset
torchvision.datasets.ImageFolder(root, transform, target_transform, loader)
- root: the root directory of the images, i.e. the parent directory of the per-class folders. In the example above, the root of the training set is ./train_valid_test/train
- transform: a function that preprocesses an image; it takes the original image as input and returns a transformed image
- target_transform: a function that preprocesses the target; it takes the target as input and returns its transformation. If this argument is not passed, the target is not converted and the class indices 0, 1, 2, ... are returned
- loader: how the images are loaded; the default loader is usually fine
In addition, the dataset object built by this API has the following member variables:
- self.classes: a list of the class names
- self.class_to_idx: the index of each class, corresponding to the target returned when no conversion is applied
- self.imgs: a list of (image path, class index) tuples, similar to what __getitem__ returns in a custom Dataset class
Building the Dataset for the training set, the validation set and the test set looks like this:
""" Image augmentation; modify it to suit your own needs """
transform_train = torchvision.transforms.Compose([
    # Scale the image up to a 40x40-pixel square
    torchvision.transforms.Resize(40),
    # Randomly crop a square whose area is 0.64 to 1 times the
    # original area, then scale it to a 32x32-pixel square
    torchvision.transforms.RandomResizedCrop(32, scale=(0.64, 1.0),
                                             ratio=(1.0, 1.0)),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor(),
    # Normalize each channel of the image
    torchvision.transforms.Normalize([0.4914, 0.4822, 0.4465],
                                     [0.2023, 0.1994, 0.2010])])
transform_test = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize([0.4914, 0.4822, 0.4465],
                                     [0.2023, 0.1994, 0.2010])])
""" Build the Datasets """
train_ds, train_valid_ds = [torchvision.datasets.ImageFolder(
    os.path.join(data_dir, 'train_valid_test', folder),
    transform=transform_train) for folder in ['train', 'train_valid']]
valid_ds, test_ds = [torchvision.datasets.ImageFolder(
    os.path.join(data_dir, 'train_valid_test', folder),
    transform=transform_test) for folder in ['valid', 'test']]
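The member variables listed earlier can now be inspected on train_ds (a quick sketch; the printed values naturally depend on your data):

# Inspect the ImageFolder member variables
print(train_ds.classes)        # list of class names, e.g. ['airplane', ...]
print(train_ds.class_to_idx)   # e.g. {'airplane': 0, 'automobile': 1, ...}
print(train_ds.imgs[0])        # (image path, class index) of the first sample
print(len(train_ds))           # 920 training images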
During training we apply all the image augmentation operations defined above. When the validation set is used for model evaluation during hyperparameter tuning, no randomness from image augmentation should be introduced. Before the final prediction, we train the model on the combination of the training set and the validation set (train_valid) to make full use of all the labeled data.
With the Datasets in place, we can go on to construct the DataLoaders:
batch_size = 32  # not defined in the excerpt above; 32 is what d2l uses for this tiny dataset
train_iter, train_valid_iter = [torch.utils.data.DataLoader(
    dataset, batch_size, shuffle=True, drop_last=True)
    for dataset in (train_ds, train_valid_ds)]
valid_iter = torch.utils.data.DataLoader(valid_ds, batch_size, shuffle=False,
                                         drop_last=True)
test_iter = torch.utils.data.DataLoader(test_ds, batch_size, shuffle=False,
                                        drop_last=False)
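Fetching one batch is an easy way to confirm everything is wired up (a sketch; the shapes follow from the 32x32 transforms and batch size above):

# Grab a single batch and check its shape
X, y = next(iter(train_iter))
print(X.shape)  # torch.Size([32, 3, 32, 32])
print(y.shape)  # torch.Size([32])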
Finally, take the DataLoaders and happily go train your model ~