当前位置:网站首页>Dataset 和 Dataloader数据加载
Dataset 和 Dataloader数据加载
2022-07-25 09:26:00 【zzh1370894823】
初学pytorch, 一直分不清数据是如何加载的,分不清Dataset 和 Dataloader的联系。
utils包含Dataset和Dataloader两个类。自定义数据集需要继承这个类,并实现两个函数,一个是__len__,另一个是__getitem__,前者提供数据的大小,后者通过索引获取数据和标签。
__getitem__一次只能获取一个数据,所以需要通过Dataloader来定义一个新的迭代器,实现batch读取。
下面举一个直观的小例子来搞明白是怎么回事!
import torch
from torch.utils import data
import numpy as np
''' 数据集: label:data 0:[1, 2], 1:[3, 4], 0:[2, 1], 1:[3, 4], 2:[4, 5] '''
class TextDataset(data.Dataset): # 继承Dataset
def __init__(self):
self.Data = np.asarray([[1, 2], [3, 4], [2, 1], [3, 4], [4, 5]]) # 一些由2维向量表示的数据集
self.Label = np.asarray([0, 1, 0, 1, 2]) # 数据集对应的标签
def __getitem__(self, item):
text = torch.from_numpy(self.Data[item]) # 把numpy转化为Tensor
label = torch.tensor(self.Label[item])
return text, label
def __len__(self):
return len(self.Data)
# 获取数据集中数据
Test = TextDataset()
print(Test[3]) # 相当于调用getitem(3)
# 输出:
# (tensor([3, 4], dtype=torch.int32), tensor(1, dtype=torch.int32))
以上数据以tuple 返回,每次只返回一个样本,如果希望批量处理batch,需要用到DataLoader
test_loader = data.DataLoader(Test,batch_size=2,shuffle=False)
for i, traindata in enumerate(test_loader):
print("i:",i)
Data, Label = traindata
print("data:", Data) # 其中一个data包含2组数据,一个batch大小
print("label:", Label)
# 输出:
# i: 0
# data: tensor([[1, 2],
# [3, 4]], dtype=torch.int32)
# label: tensor([0, 1], dtype=torch.int32)
# i: 1
# data: tensor([[2, 1],
# [3, 4]], dtype=torch.int32)
# label: tensor([0, 1], dtype=torch.int32)
# i: 2
# data: tensor([[4, 5]], dtype=torch.int32)
# label: tensor([2], dtype=torch.int32)
其中一个data变成原来两组data的组成
相应的label也变成了原来对应的两个label的组成
.
参考于吴茂贵的python深度学习
边栏推荐
- VScode配置ROS开发环境:修改代码不生效问题原因及解决方法
- Introduction to testbench
- JS uses requestanimationframe to detect the FPS frame rate of the current animation in real time
- 入住阿里云MQTT物联网平台
- 手持振弦采集仪对振弦传感器激励方法和激励电压
- 字符串切片的用法
- mysql历史数据补充新数据
- Temperature, humidity and light intensity acquisition based on smart cloud platform
- Arm preliminaries
- Vant problem record
猜你喜欢

Fundamentals of C language

NLM5系列无线振弦传感采集仪的工作模式及休眠模式下状态

¥ 1-2 example 2.2 put the union of two sets into the linear table

JS uses requestanimationframe to detect the FPS frame rate of the current animation in real time

Swift simple implementation of to-do list

vscode插件开发

文件的上传功能

拷贝过来老的项目变成web项目

线程池的设计和原理
![[recommended collection] with these learning methods, I joined the world's top 500 - the](/img/95/e34473a1628521d4b07e56877fcff1.png)
[recommended collection] with these learning methods, I joined the world's top 500 - the "fantastic skills and extravagance" in the Internet age
随机推荐
ISP image signal processing
C函数不加括号的教训
FLASH read / write operation and flash upload file of esp8266
I2C也可总线取电!
Mlx90640 infrared thermal imager temperature measurement module development notes (I)
emmet语法速查 syntax基本语法部分
微信小程序跳转其他小程序
ADC introduction
See how a junior student of double non-2 (0 Internship) can get an offer from Alibaba and Tencent
framework打包合并脚本
MVC three-tier architecture understanding
How Android uses ADB command to view application local database
T5 paper summary
GCD详解
【成长必备】我为什么推荐你写博客?愿你多年以后成为你想成为的样子。
无线振弦采集仪应用工程安全监测
CCF 201512-4 delivery
手持振弦VH501TC采集仪传感器的连接与数据读取
Mlx90640 infrared thermal imaging sensor temperature measurement module development notes (II)
An ASP code that can return to the previous page and automatically refresh the page