当前位置:网站首页>Pytorch读入据集(典型数据集及自定义数据集两种模式)
Pytorch读入据集(典型数据集及自定义数据集两种模式)
2022-06-24 07:24:00 【Hydrion-Qlz】
数据读入
Pytorch的数据读入是通过DataSet+DataLoader的方式完成的,DataSet定义好数据的格式和数据变换形式,DataLoader通过iterative的方式不断读入批次数据
读入已有的数据集
Pytorch自身支持很多的数据集,可以直接通过对应的函数得到对应的DataSet,然后传入DataLoader中等待处理:
例如读入MNIST数据集
from torchvision import datasets
import torchvision.transforms as transforms
transform = transforms.Compose([transforms.RandomHorizontalFlip,
transforms.RandomCrop,
transforms.ToTensor])
train = datasets.MNIST(root="./datasets",
train=True,
transform=transform,
download=True)
val = datasets.MNIST(root="./datasets",
train=False,
transform=transform,
download=True)
读入自己的数据集
另外也可以通过实现DataSet类来读入自己的数据集,一般来说需要实现三个函数:
__init__
:用于向类中传入外部参数,同时定义样本集__getitem__
:用于逐个读取样本集合中的元素,可以进行一定的变换,并将返回训练/验证所需的数据__len__
:用户返回数据集的样本数
下面的例子是所有的图片存储在一个文件夹下面,同时在一个csv文件中保存有图片名称及其对应的标签
from PIL import Image
class CustomDataSet(Dataset):
def __init__(self, image_path, image_class, transform=None, device="cpu"):
self.image_path = image_path
self.image_class = image_class
self.transform = transform
self.device = device
def show_img(self, index):
plt.subplots(1, 1)
img = Image.open(self.image_path[index])
plt.imshow(img[2])
plt.show()
def __getitem__(self, index):
img = Image.open(self.image_path[index])
if img.mode != 'RGB':
raise ValueError("image:{} isn't RGB mode.".format(self.image_path[index]))
label = np.argmax(self.image_class[index])
label = torch.tensor(label).to(self.device)
if self.transform is not None:
img = self.transform(img)
return img.to(self.device), label
def __len__(self):
return len(self.image_path)
构建好DataSet之后就可以通过DataLoader读取自己的数据了
train_loader = DataLoader(train, batch_size, shuffle=True, drop_last=True)
val_loader = DataLoader(val, batch_size, shuffle=True, drop_last=False)
- shuffle:表示在加载的时候打乱顺序
- drop_last:丢弃掉最后不够一个batch的数据
全部设置完成之后就可以通过下面的函数不断的读取数据集了
for X, y in train_loader:
pass
边栏推荐
- 【LeetCode】541. 反转字符串 II
- [force deduction 10 days SQL introduction] Day3
- “不平凡的代理初始值设定不受支持”,出现的原因及解决方法
- [team management] 25 tips for testing team performance management
- Database to query the quantity of books lent in this month. If it is higher than 10, it will display "more than 10 books lent in this month". Otherwise, it will display "less than 10 books lent in thi
- 【牛客】HJ1 字符串最后一个单词的长度
- 110. 平衡二叉树-递归法
- 520. detect capital letters
- 从华为WeAutomate数字机器人论坛,看政企领域的“政务新智理”
- 工具类
猜你喜欢
Prompt code when MySQL inserts Chinese data due to character set problems: 1366
【E325: ATTENTION】vim编辑时报错
Qingcloud based "real estate integration" cloud solution
MySQL | view notes on Master Kong MySQL from introduction to advanced
基于QingCloud的 “房地一体” 云解决方案
Background management of uniapp hot update
数据中台:数据采集和抽取的技术栈详解
The form image uploaded in chorme cannot view the binary image information of the request body
【PyTorch基础教程30】DSSM双塔模型代码解析
玄铁E906移植----番外0:玄铁C906仿真环境搭建
随机推荐
Qingcloud based R & D cloud solution for geographic information enterprises
华为路由器:ipsec技术
“论解不了数独所以选择做个数独游戏这件事”
How does the tunnel mobile inspection track robot monitor 24 hours to ensure the safety of tunnel construction?
One article explains in detail | those things about growth
数据中台:数据中台技术架构详解
Opencv maximum filtering (not limited to images)
MySQL | 存储《康师傅MySQL从入门到高级》笔记
110. balanced binary tree recursive method
数据中台:民生银行的数据中台实践方案
every()、map()、forEarch()方法。数组里面有对象的情况
Data midrange: detailed explanation of the technical stack of data acquisition and extraction
MyCAT读写分离与MySQL主从同步
Qingcloud based "real estate integration" cloud solution
[force deduction 10 days SQL introduction] Day3
2138. 将字符串拆分为若干长度为 k 的组
第七章 操作位和位串(三)
“不平凡的代理初始值设定不受支持”,出现的原因及解决方法
Idea another line shortcut
1844. 将所有数字用字符替换