PyTorch Geometric库中虽然已经包含自带的数据集如 Cora 等,但有时我们也需要用户个人数据创建自己的数据集进行一些数据研究。当然博主也建议大家若是第一次使用PyTorch Geometric库可以先使用其自带的数据集进行理解,再创建自己的数据集做到灵活运用。
PyG为数据集提供了三个抽象类:Data、InMemoryDataset 和 Dataset。
为创建一个torch_geometric.data.InMemoryDataset,需要实现以下四个基本方法:
InMemoryDataset实现流程图如下:
完整代码如下:
import torch
from torch_geometric.data import InMemoryDataset, download_url
from torch_geometric.data import Data
from torch_geometric.data import DataLoader
import warnings
warnings.filterwarnings("ignore", category=Warning)
from torch_geometric.datasets import TUDataset
# 这里给出大家注释方便理解
# 程序只要第一次运行后,processed文件生成后就不会执行proces函数,而且只要不重写download()和process()方法,也会直接跳过下载和处理。
class MyOwnDataset(InMemoryDataset):def __init__(self, root, transform=None, pre_transform=None):super().__init__(root, transform, pre_transform)self.data, self.slices = torch.load(self.processed_paths[0])print(self.data) # 输出torch.load加载的数据集data# print(root) # MYdata# print(self.data) # Data(x=[3, 1], edge_index=[2, 4], y=[3])# print(self.slices) # defaultdict(<class 'dict'>, {'x': tensor([0, 3, 6]), 'edge_index': tensor([ 0, 4, 10]), 'y': tensor([0, 3, 6])})# print(self.processed_paths[0]) # MYdataprocesseddatas.pt# 返回数据集源文件名,告诉原始的数据集存放在哪个文件夹下面,如果数据集已经存放进去了,那么就会直接从raw文件夹中读取。@propertydef raw_file_names(self):# pass # 不能使用pass,会报join() argument must be str or bytes, not 'NoneType'错误return []# 首先寻找processed_paths[0]路径下的文件名也就是之前process方法保存的文件名@propertydef processed_file_names(self):return ['datas.pt']# 用于从网上下载数据集,下载原始数据到指定的文件夹下,自己的数据集可以跳过def download(self):pass# 生成数据集所用的方法,程序第一次运行才执行并生成processed文件夹的处理过后数据的文件,否则必须删除已经生成的processed文件夹中的所有文件才会重新执行此函数def process(self):# Read data into huge `Data` list.# Read data into huge `Data` list.# 这里用于构建dataedge_index1 = sor([[0, 1, 1, 2],[1, 0, 2, 1]], dtype=torch.long)edge_index2 = sor([[0, 1, 1, 2 ,0 ,1],[1, 0, 2, 1 ,0 ,1]], dtype=torch.long)# 节点及每个节点的特征:从0号节点开始X = sor([[-1], [0], [1]], dtype=torch.float)# 每个节点的标签:从0号节点开始-两类0,1Y = sor([0, 1, 0], dtype=torch.float)# 创建data数据data1 = Data(x=X, edge_index=edge_index1, y=Y)data2 = Data(x=X, edge_index=edge_index2, y=Y)# 将data放入datalistdata_list = [data1,data2]# data_list = data_list.append(data)if self.pre_filter is not None: # pre_filter函数可以在保存之前手动过滤掉数据对象。用例可能涉及数据对象属于特定类的限制。默认Nonedata_list = [data for data in data_list if self.pre_filter(data)]if self.pre_transform is not None: # pre_transform函数在将数据对象保存到磁盘之前应用转换(因此它最好用于只需执行一次的大量预计算),默认Nonedata_list = [self.pre_transform(data) for data in data_list]data, slices = llate(data_list) # 直接保存list可能很慢,所以使用collate函数转换成大的torch_geometric.data.Data对象# print(data)torch.save((data, slices), self.processed_paths[0])# 数据集对象操作
b = MyOwnDataset("E:pycharmprojectMYdata") # 创建数据集对象
data_loader = DataLoader(b, batch_size=1, shuffle=False) # 加载数据进行处理,每批次数据的数量为1
for data in data_loader:print(data) # 按批次输出数据
程序运行结果如下:
程序代码注意三事项:
相较于torch_geometric.data.InMemoryDataset基础上额外增加两个方法:
# 返回数据集中示例的数目
def len(self):return len(self.processed_file_names)
# 实现加载单个图的逻辑
def get(self, idx):data = torch.load(osp.join(self.processed_dir, 'data_{}.pt'.format(idx)))return data
本文发布于:2024-02-05 06:05:07,感谢您对本站的认可!
本文链接:https://www.4u4v.net/it/170725839163678.html
版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。
留言与评论(共有 0 条评论) |