Organizing data for torchvision.datasets.ImageFolder and using it
2022-07-23 22:33:00 【Kakananan】
In a typical neural network training pipeline, we need to build a Dataset and a DataLoader to read data during model training. Usually we define our own Dataset class and override the __getitem__ and __len__ methods to build the Dataset. For a simple image classification task, however, there is no need to define a custom Dataset class: calling torchvision.datasets.ImageFolder builds the Dataset for you, which is very convenient.
The code in this post is adapted from Mu Li's Dive into Deep Learning (动手学深度学习).
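For reference, a minimal custom Dataset of the kind just described might look like the sketch below (my own illustration; the lists of image paths and integer labels are assumed to be prepared elsewhere):

import torch
from PIL import Image

class MyImageDataset(torch.utils.data.Dataset):
    """A minimal custom Dataset: stores (path, label) pairs and loads images on demand"""
    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths  # list of image file paths (assumed given)
        self.labels = labels            # list of integer labels (assumed given)
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img = Image.open(self.image_paths[idx]).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, self.labels[idx]

ImageFolder spares us all of this boilerplate when the images are already organized by class on disk.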
1. Organize the dataset into the format the function expects
Since we are calling an API, the dataset must be organized the way the API expects. torchvision.datasets.ImageFolder requires the dataset to be laid out as follows:
A generic data loader where the images are arranged in this way:
root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png
···
root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png
Here dog and cat are the image labels. Under the root directory we create one folder per class and save each class's images in its folder. If, say, our images fall into the three classes apple, banana, and orange, we create three folders named after the three labels and store the corresponding images inside, as shown below.
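For the hypothetical apple/banana/orange case (file names made up for illustration), the layout would be:

root/apple/0001.png
root/apple/0002.png
···
root/banana/0001.png
···
root/orange/0001.png
···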
However, the dataset we get is often not in exactly this format, and reorganizing it by hand is very time-consuming. Below I take one dataset as an example and show how to reorganize it into the required format with a program; datasets in other formats can be handled the same way.
1.1 Raw dataset format

- test: images to be classified, 5 in total
- train: training images, 1000 in total
- trainLabels.csv: the class of each image in the training set
1.2 Organizing the dataset
import collections
import math
import os
import shutil
import pandas as pd
import torch
import torchvision
from torch import nn
from d2l import torch as d2l
'''Download the dataset'''
#@save
d2l.DATA_HUB['cifar10_tiny'] = (d2l.DATA_URL + 'kaggle_cifar10_tiny.zip',
                                '2068874e4b9a9f0fb07ebe0ad2b29754449ccacd')
data_dir = d2l.download_extract('cifar10_tiny')  # returns the path of the dataset's root directory
First, we use the following function to read the labels in the CSV file. It returns a dictionary that maps each file name, without its extension, to its label.
#@save
def read_csv_labels(fname):
    """Read fname and return a dictionary mapping file name to label"""
    with open(fname, 'r') as f:
        # Skip the header line (column names)
        lines = f.readlines()[1:]
    tokens = [l.rstrip().split(',') for l in lines]
    return dict(((name, label) for name, label in tokens))
labels = read_csv_labels(os.path.join(data_dir, 'trainLabels.csv'))
print('# training examples:', len(labels))
print('# classes:', len(set(labels.values())))
'''
Output:
# training examples: 1000
# classes: 10

where labels looks like:
{'1': 'frog', '2': 'truck', '3': 'truck', '4': 'deer', '5': 'automobile', '6': 'automobile', ···}
'''
Next, we define the reorg_train_valid function to split a validation set out of the original training set. Its valid_ratio argument is the ratio of the number of validation examples to the number of examples in the original training set. More concretely, let n be the number of images in the class with the fewest examples and r the ratio. To ensure that every class is equally represented in the validation set, the split takes max(⌊nr⌋, 1) images from each class, so the total size of the validation set is max(⌊nr⌋, 1) multiplied by the number of classes.
In this example the smallest class has 85 images and the ratio is 0.1, so the final validation set has ⌊85×0.1⌋×10 = 80 images.
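As a quick check of this arithmetic (my own addition, just verifying the formula):

import math

n, r, num_classes = 85, 0.1, 10
n_valid_per_label = max(1, math.floor(n * r))  # images taken per class
print(n_valid_per_label * num_classes)         # 80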
#@save
def copyfile(filename, target_dir):
    """Copy a file into the target directory"""
    os.makedirs(target_dir, exist_ok=True)
    shutil.copy(filename, target_dir)

#@save
def reorg_train_valid(data_dir, labels, valid_ratio):
    """Split a validation set out of the original training set"""
    # Number of examples in the class with the fewest examples
    n = collections.Counter(labels.values()).most_common()[-1][1]
    # Number of examples per class for the validation set
    n_valid_per_label = max(1, math.floor(n * valid_ratio))
    label_count = {}
    # Iterate over all images in the training set
    for train_file in os.listdir(os.path.join(data_dir, 'train')):
        # Look up the label of this image
        label = labels[train_file.split('.')[0]]
        # Full path of the image
        fname = os.path.join(data_dir, 'train', train_file)
        # Copy the image into the folder for its label under train_valid
        copyfile(fname, os.path.join(data_dir, 'train_valid_test',
                                     'train_valid', label))
        # If this label's quota in the validation set is not yet full,
        # copy the image into the validation set folder for its label
        if label not in label_count or label_count[label] < n_valid_per_label:
            copyfile(fname, os.path.join(data_dir, 'train_valid_test',
                                         'valid', label))
            label_count[label] = label_count.get(label, 0) + 1
        # Otherwise copy it into the training set folder for its label
        else:
            copyfile(fname, os.path.join(data_dir, 'train_valid_test',
                                         'train', label))
    return n_valid_per_label
Running reorg_train_valid generates a top-level folder train_valid_test containing three folders:
- train_valid: all images, i.e. 1000
- valid: the validation images, i.e. 80
- train: the training images, i.e. 920
Inside each of these folders there is one subfolder per class.
The reorg_test function below organizes the test set so that it can be read the same way at prediction time.
#@save
def reorg_test(data_dir):
    """Organize the test set for reading at prediction time"""
    # Save every test image into the 'unknown' folder
    for test_file in os.listdir(os.path.join(data_dir, 'test')):
        copyfile(os.path.join(data_dir, 'test', test_file),
                 os.path.join(data_dir, 'train_valid_test', 'test',
                              'unknown'))
The test images are all saved under a single folder named unknown, since ImageFolder requires at least one class subfolder even though the test labels are not known.
Finally, we use one function to call the previously defined read_csv_labels, reorg_train_valid, and reorg_test.
""" Enter the address and ratio of the root directory """
def reorg_cifar10_data(data_dir, valid_ratio):
labels = read_csv_labels(os.path.join(data_dir, 'trainLabels.csv'))
reorg_train_valid(data_dir, labels, valid_ratio)
reorg_test(data_dir)
valid_ratio = 0.1
reorg_cifar10_data(data_dir, valid_ratio)
After it runs, a folder named train_valid_test is created, containing the four folders test, train, train_valid, and valid.
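To verify the result, here is a small sanity check (my own addition, not from the tutorial) that counts the images in each split with os.walk:

import os

splits_root = os.path.join(data_dir, 'train_valid_test')
for split in ['train', 'valid', 'train_valid', 'test']:
    # sum the file counts over all class subfolders of this split
    total = sum(len(files) for _, _, files in os.walk(os.path.join(splits_root, split)))
    print(split, total)  # expect 920, 80, 1000, 5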

2. Calling torchvision.datasets.ImageFolder to build the Dataset
torchvision.datasets.ImageFolder(root, transform, target_transform, loader)
- root: the root directory of the images, i.e. the parent directory of the per-class folders. In the example above, the root directory of the training set is ./train_valid_test/train
- transform: a preprocessing operation (function) applied to each image; it takes the original image as input and returns a transformed image
- target_transform: a preprocessing operation applied to each label; it takes the target as input and returns its transformed version. If this argument is not passed, the targets are returned unchanged as the sequential indices 0, 1, 2, ...
- loader: how images are loaded from disk; the default loader is usually fine (a sketch using the last two arguments follows this list)
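As an illustration of target_transform and loader, here is a minimal sketch (my own example, not from the tutorial; the one-hot encoding and the custom loader are arbitrary choices):

import os
from PIL import Image
import torch
import torchvision

def pil_rgb_loader(path):
    # custom loader: open with PIL and force 3 RGB channels
    return Image.open(path).convert('RGB')

ds = torchvision.datasets.ImageFolder(
    os.path.join(data_dir, 'train_valid_test', 'train'),
    transform=torchvision.transforms.ToTensor(),
    # turn the integer class index into a one-hot vector
    target_transform=lambda t: torch.nn.functional.one_hot(
        torch.tensor(t), num_classes=10),
    loader=pil_rgb_loader)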
In addition, the returned dataset object has the following member variables:
- self.classes: a list of the class names
- self.class_to_idx: the index assigned to each class, which is also what is returned as the target when no transform is applied
- self.imgs: a list of (image-path, class-index) tuples, similar to what the __getitem__ of a custom Dataset class would be indexing
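These can be inspected on any ImageFolder dataset, for example (the values in the comments are only what one would expect for this CIFAR-10 subset; exact paths depend on data_dir):

ds = torchvision.datasets.ImageFolder(
    os.path.join(data_dir, 'train_valid_test', 'train'))
print(ds.classes)       # e.g. ['airplane', 'automobile', ..., 'truck']
print(ds.class_to_idx)  # e.g. {'airplane': 0, 'automobile': 1, ...}
print(ds.imgs[:2])      # [(image-path, class-index), (image-path, class-index)]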
Building the Datasets for the training set, validation set, and test set looks like this:
""" Image enlargement , You can modify it according to your own needs """
transform_train = torchvision.transforms.Compose([
# Enlarge the image to 40 Pixel square
torchvision.transforms.Resize(40),
# Randomly cut a height and width of 40 A square image of pixels ,
# Generate an area that is the area of the original image 0.64 To 1 Times the small square ,
# Then scale it to a height and width of 32 Pixel square
torchvision.transforms.RandomResizedCrop(32, scale=(0.64, 1.0),
ratio=(1.0, 1.0)),
torchvision.transforms.RandomHorizontalFlip(),
torchvision.transforms.ToTensor(),
# Standardize each channel of the image
torchvision.transforms.Normalize([0.4914, 0.4822, 0.4465],
[0.2023, 0.1994, 0.2010])])
transform_test = torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize([0.4914, 0.4822, 0.4465],
[0.2023, 0.1994, 0.2010])])
""" structure Dataset"""
train_ds, train_valid_ds = [torchvision.datasets.ImageFolder(
os.path.join(data_dir, 'train_valid_test', folder),
transform=transform_train) for folder in ['train', 'train_valid']]
valid_ds, test_ds = [torchvision.datasets.ImageFolder(
os.path.join(data_dir, 'train_valid_test', folder),
transform=transform_test) for folder in ['valid', 'test']]
During training we apply all the image augmentation operations defined above. When the validation set is used for model evaluation during hyperparameter tuning, no randomness from image augmentation should be introduced. Before the final prediction, we train on the combined training and validation sets (train_valid) to make full use of all the labeled data.
With the Datasets built, we can go on to construct the DataLoaders:
batch_size = 32  # not defined in the original snippet; adjust as needed

train_iter, train_valid_iter = [torch.utils.data.DataLoader(
    dataset, batch_size, shuffle=True, drop_last=True)
    for dataset in (train_ds, train_valid_ds)]

valid_iter = torch.utils.data.DataLoader(valid_ds, batch_size, shuffle=False,
                                         drop_last=True)

test_iter = torch.utils.data.DataLoader(test_ds, batch_size, shuffle=False,
                                        drop_last=False)
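As a quick sanity check (my addition), pull one batch from the training DataLoader and look at its shape:

X, y = next(iter(train_iter))
print(X.shape)  # torch.Size([32, 3, 32, 32]) with batch_size = 32
print(y.shape)  # torch.Size([32])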
Finally, take the DataLoaders and happily go train your model ~