当前位置:网站首页>Pytorch读入据集(典型数据集及自定义数据集两种模式)
Pytorch读入据集(典型数据集及自定义数据集两种模式)
2022-06-24 07:24:00 【Hydrion-Qlz】
数据读入
Pytorch的数据读入是通过DataSet+DataLoader的方式完成的,DataSet定义好数据的格式和数据变换形式,DataLoader通过iterative的方式不断读入批次数据
读入已有的数据集
Pytorch自身支持很多的数据集,可以直接通过对应的函数得到对应的DataSet,然后传入DataLoader中等待处理:
例如读入MNIST数据集
from torchvision import datasets
import torchvision.transforms as transforms
transform = transforms.Compose([transforms.RandomHorizontalFlip,
transforms.RandomCrop,
transforms.ToTensor])
train = datasets.MNIST(root="./datasets",
train=True,
transform=transform,
download=True)
val = datasets.MNIST(root="./datasets",
train=False,
transform=transform,
download=True)
读入自己的数据集
另外也可以通过实现DataSet类来读入自己的数据集,一般来说需要实现三个函数:
__init__:用于向类中传入外部参数,同时定义样本集__getitem__:用于逐个读取样本集合中的元素,可以进行一定的变换,并将返回训练/验证所需的数据__len__:用户返回数据集的样本数
下面的例子是所有的图片存储在一个文件夹下面,同时在一个csv文件中保存有图片名称及其对应的标签
from PIL import Image
class CustomDataSet(Dataset):
def __init__(self, image_path, image_class, transform=None, device="cpu"):
self.image_path = image_path
self.image_class = image_class
self.transform = transform
self.device = device
def show_img(self, index):
plt.subplots(1, 1)
img = Image.open(self.image_path[index])
plt.imshow(img[2])
plt.show()
def __getitem__(self, index):
img = Image.open(self.image_path[index])
if img.mode != 'RGB':
raise ValueError("image:{} isn't RGB mode.".format(self.image_path[index]))
label = np.argmax(self.image_class[index])
label = torch.tensor(label).to(self.device)
if self.transform is not None:
img = self.transform(img)
return img.to(self.device), label
def __len__(self):
return len(self.image_path)
构建好DataSet之后就可以通过DataLoader读取自己的数据了
train_loader = DataLoader(train, batch_size, shuffle=True, drop_last=True)
val_loader = DataLoader(val, batch_size, shuffle=True, drop_last=False)
- shuffle:表示在加载的时候打乱顺序
- drop_last:丢弃掉最后不够一个batch的数据
全部设置完成之后就可以通过下面的函数不断的读取数据集了
for X, y in train_loader:
pass
边栏推荐
- 【PyTorch基础教程30】DSSM双塔模型代码解析
- Prompt code when MySQL inserts Chinese data due to character set problems: 1366
- Telnet port login method with user name for liunx server
- [team management] 25 tips for testing team performance management
- Qingcloud based R & D cloud solution for geographic information enterprises
- leetcode 1642. Furthest building you can reach
- K8s deployment of highly available PostgreSQL Cluster -- the road to building a dream
- 2022-06-23:给定一个非负数组,任意选择数字,使累加和最大且为7的倍数,返回最大累加和。 n比较大,10的5次方。 来自美团。3.26笔试。
- 4274. 后缀表达式
- [pytorch basic tutorial 30] code analysis of DSSM twin tower model
猜你喜欢

什么是图神经网络?图神经网络有什么用?

WebRTC系列-网络传输之5选择最优connection切换

K8s deployment of highly available PostgreSQL Cluster -- the road to building a dream

À propos de ETL il suffit de lire cet article, trois minutes pour vous faire comprendre ce qu'est ETL
![[noi Simulation Competition] geiguo and time chicken (structure)](/img/4c/ed1b5bc2bed653c49b8b7922ce1674.png)
[noi Simulation Competition] geiguo and time chicken (structure)

数云发布2022美妆行业全域消费者数字化经营白皮书:全域增长破解营销难题

A tip to read on Medium for free

【牛客】把字符串转换成整数

Centos7 installation of jdk8, mysql5.7 and Navicat connection to virtual machine MySQL and solutions (solutions to MySQL download errors are attached)

關於ETL看這篇文章就够了,三分鐘讓你明白什麼是ETL
随机推荐
What is the future development trend of Business Intelligence BI
MySQL | view notes on Master Kong MySQL from introduction to advanced
MySQL | 存储《康师傅MySQL从入门到高级》笔记
Solution: Nan occurs in loss during model training
【使用 PicGo+腾讯云对象存储COS 作为图床】
Liunx change the port number of vsftpd
解决:模型训练时loss出现nan
基于单片机开发的酒精浓度测试仪方案
[10 day SQL introduction] Day2
陆奇:我现在最看好这四大技术趋势
Double pointer analog
4274. 后缀表达式
关于 GIN 的路由树
Data middle office: middle office practice and summary
【MySQL从入门到精通】【高级篇】(一)字符集的修改与底层原理
【PyTorch基础教程30】DSSM双塔模型代码解析
I heard that you are still spending money to buy ppt templates from the Internet?
基于QingCloud的 “房地一体” 云解决方案
leetcode——错误的集合
华为路由器:GRE技术