当前位置:网站首页>PyG教程(4):自定义数据集
PyG教程(4):自定义数据集
2022-06-21 06:24:00 【斯曦巍峨】
一.前言
在PyG中,除了直接使用它自带的benchmark数据集外,用户还可以自定义数据集,其方式与Pytorch类似,需要继承数据集类。PyG中提供了两个数据集抽象类:
torch_geometric.data.Dataset:用于构建大型数据集(非内存数据集);torch_geometric.data.InMemoryDataset:用于构建内存数据集(小数据集),继承自Dataset。
下面是对其的详细介绍。
二.内存数据集
2.1 创建说明
在PyG中要构建自己的内存数据集需要先继承InMemoryDataset类,并实现如下方法:
raw_file_names():返回原始数据集的文件名列表,若self.raw_dir中没有该列表中的文件,则会通过download()进行下载;processed_file_names():返回process()方法处理后的文件名列表,若self.processed_dir中没有确实该列表中的文件,则需要通过process()方法进行处理;download():下载原始数据集到self.raw_dir中;process():处理原始数据集,并保存到processed_dir中。
在前两个方法中,若只有单个文件,则直接返回文件字符串即可,不一定要返回list对象。
另外,上面的self.raw_dir和self.processed_dir其实是两个方法,其源码为:
# 加上@property,可以使得方法像属性一样被调用
@property
def raw_dir(self) -> str:
return osp.join(self.root, 'raw')
@property
def processed_dir(self) -> str:
return osp.join(self.root, 'processed')
从源码可以看出,self.raw_dir和self.processed_dir是给定保存路径root下的原始数据文件夹和处理后的数据文件夹的路径。
2.2 创建演示
本文以SNAP数据集中的一个社交网络Facebook为例,来演示如何创建一个InMemoryDataset数据集FaceBook,该数据集包含4039个节点、88234条边。利用Gephi对该网络进行可视化如下:

根据3.1节中的说明,下面是自定义FaceBook类的源码:
import os
import pandas as pd
import torch
from torch_geometric.data import Data
from torch_geometric.data import InMemoryDataset, download_url, extract_gz
class FaceBook(InMemoryDataset):
url = "https://snap.stanford.edu/data/facebook_combined.txt.gz"
def __init__(self,
root,
transform=None,
pre_transform=None,
pre_filter=None):
super().__init__(root, transform, pre_transform, pre_filter)
self.data, self.slices = torch.load(self.processed_paths[0])
@property
def raw_file_names(self):
return ["facebook_combined.txt"]
@property
def processed_file_names(self):
return "data.pt"
def download(self):
path = download_url(self.url, self.raw_dir)
extract_gz(path, self.raw_dir)
def process(self):
# 加载原始数据文件
path = os.path.join(self.raw_dir, "facebook_combined.txt")
edges = pd.read_csv(path, header=None,
delimiter=" ").values.reshape(2, -1)
# 构建Data对象
edge_index = torch.from_numpy(edges)
g = Data(edge_index=edge_index, num_nodes=4039)
data, slices = self.collate([g])
torch.save((data, slices), self.processed_paths[0])
if __name__ == "__main__":
dataset = FaceBook(root="tmp")
data = dataset[0]
print(data.num_edges, data.num_nodes)
# 88234 4039
需要注意的是
download和process只在第一次调用时会调用,之后会直接加载处理好的数据集。- 以上4个方法并不都是需要的,例如如果你本地已经有了数据集,就不需要重写
download()函数来下载原始数据集。
三.大型数据集
对于大型图数据集,需要继承Dataset类,除了InMemoryDataset中需要重写的4个方法外,还需重写如下方法:
len(): 返回数据集中实例的数量;get():加载单个图的逻辑。
由于自定义大型数据集与InMemoryDataset类似,具体演示略。
四.结语
参考资料:
自定义数据集是一项重要的事情,尤其是当你本地有些数据需要转换为PyG中标准的图数据集的时候。
边栏推荐
- 第8期:云原生—— 大学生职场小白该如何学
- 827. 最大人工岛 并查集
- Contos7 installing SVN server
- Which is better for children's consumption type serious diseases at present? Are there any recommended children's products
- Pycharm的快捷键Button 4 Click是什么?
- FPGA - 7 Series FPGA selectio -02- introduction to source language
- Direct attack on the Internet layoffs in 2022: flowers on the ground, chicken feathers on the ground
- scikit-learn中的Scaler
- Unity hidden directories and hidden files
- 超参数和模型参数
猜你喜欢

TypeError: iter() returned non-iterator of type ‘xxx‘

小程序【第一期】

Aurora 8b10b IP use - 02 - IP function design skills

5254. 卖木头块 动态规划

超参数和模型参数

Aurora8b10b IP use-04-ip routine application example

FPGA - 7 Series FPGA selectio -03- ILOGIC of logic resources

第8期:云原生—— 大学生职场小白该如何学

That's great. MySQL's summary is too comprehensive

FPGA - 7 Series FPGA selectio -02- introduction to source language
随机推荐
Which of the children's critical illness insurance companies has the highest cost performance in 2022?
FPGA - 7系列 FPGA SelectIO -04- 逻辑资源之IDELAY和IDELAYCTRL
Cache cache (notes on principles of computer composition)
[JDBC from starting to Real combat] JDBC Basic clearance tutoriel (Summary of the first part)
Chapter 1: overview of database system (final review of database)
我的高考经历与总结
Judge whether a tree is a complete binary tree
That's great. MySQL's summary is too comprehensive
[is the network you are familiar with really safe?] Wanziwen
如何通过JDBC访问MySQL数据库?手把手实现登录界面(图解+完整代码)
Latest analysis on operation of refrigeration and air conditioning equipment in 2022 and examination question bank for operation of refrigeration and air conditioning equipment
tf. compat. v1.pad
【利用MSF工具内网复现MS08-067】
如何限制内网网速
NOP法破解简易登录系统
端口占用解决
leetcode 410. Maximum value of split array - (Day30)
Dual tone search: array is incremented first and then decremented
IP - 射频数据转换器 -04- API使用指南 - ADC状态指示函数
My college entrance examination experience and summary