DGL Tutorial, Part 5: Using Your Own Dataset
To build your own dataset, subclass dgl.data.DGLDataset and implement the following methods:
- __getitem__(self, i): returns the i-th example in the dataset
- __len__(self): returns the number of examples in the dataset
- process(self): loads the raw data from disk and processes it
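How these three methods fit together can be sketched without DGL installed. The stand-in below is a hypothetical simplification, not DGL's implementation: `ToyDataset` and `OneItemDataset` are made-up names, and the only behavior borrowed from `DGLDataset` is that the base-class constructor calls `process()`, so the data is ready as soon as the object is created.

```python
class ToyDataset:
    """Hypothetical stand-in for a dataset base class (not DGL code)."""

    def __init__(self, name):
        self.name = name
        # The base constructor triggers processing, so a subclass only
        # needs to call super().__init__(...) to get its data loaded.
        self.process()

    def process(self):
        raise NotImplementedError

    def __getitem__(self, i):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError


class OneItemDataset(ToyDataset):
    def __init__(self):
        super().__init__(name="one_item")

    def process(self):
        # A real dataset would read files from disk here.
        self.item = {"num_nodes": 3, "edges": [(0, 1), (1, 2)]}

    def __getitem__(self, i):
        if i != 0:
            raise IndexError("this dataset holds a single item")
        return self.item

    def __len__(self):
        return 1


dataset = OneItemDataset()
print(len(dataset), dataset[0]["edges"])
```

A dataset that holds a single graph reports a length of 1 and returns that graph at index 0. Raising `IndexError` for any other index is a choice made in this sketch only.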
This example uses a small dataset, Zachary's Karate Club network, which consists of:
- members.csv, which holds the attributes of each member
- interactions.csv, which holds the pairwise interactions between members
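As a rough picture of the two files, here is a minimal sketch of an assumed layout, parsed with the standard `csv` module. The column names Age, Club, Src, Dst, and Weight are the ones this tutorial's loading code reads. The `Id` column and every row value are made up for illustration, and the real files may contain other columns.

```python
import csv
import io

# Illustrative contents only: these rows are invented, not the real data.
members_csv = """Id,Club,Age
0,Mr. Hi,44
1,Mr. Hi,37
2,Officer,35
"""

interactions_csv = """Src,Dst,Weight
0,1,0.5
0,2,0.3
1,2,0.8
"""

members = list(csv.DictReader(io.StringIO(members_csv)))
interactions = list(csv.DictReader(io.StringIO(interactions_csv)))

# members.csv has one row per member.
ages = [int(row["Age"]) for row in members]
clubs = [row["Club"] for row in members]

# interactions.csv has one row per pair of members, with a weight.
edges = [(int(row["Src"]), int(row["Dst"])) for row in interactions]
weights = [float(row["Weight"]) for row in interactions]

print(ages, clubs)
print(edges, weights)
```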
We treat members as nodes and interactions as edges. Each member's age becomes the node feature, the club the member joined becomes the node label, and the interaction weight becomes the edge feature:
```python
import dgl
import pandas as pd
import torch
from dgl.data import DGLDataset


class KarateClubDataset(DGLDataset):
    def __init__(self):
        super().__init__(name='karate_club')

    def process(self):
        nodes_data = pd.read_csv('./karate/members.csv')
        edges_data = pd.read_csv('./karate/interactions.csv')
        node_features = torch.from_numpy(nodes_data['Age'].to_numpy())
        # Convert the Club column to a categorical type and use its integer
        # codes (0, 1) as the node labels.
        node_labels = torch.from_numpy(
            nodes_data['Club'].astype('category').cat.codes.to_numpy())
        edge_features = torch.from_numpy(edges_data['Weight'].to_numpy())
        edges_src = torch.from_numpy(edges_data['Src'].to_numpy())
        edges_dst = torch.from_numpy(edges_data['Dst'].to_numpy())

        self.graph = dgl.graph((edges_src, edges_dst), num_nodes=nodes_data.shape[0])
        self.graph.ndata['feat'] = node_features
        self.graph.ndata['label'] = node_labels
        self.graph.edata['weight'] = edge_features

        # If your dataset is a node classification dataset, you will need to assign
        # masks indicating whether a node belongs to training, validation, and test set.
        n_nodes = nodes_data.shape[0]
        n_train = int(n_nodes * 0.6)
        n_val = int(n_nodes * 0.2)
        train_mask = torch.zeros(n_nodes, dtype=torch.bool)
        val_mask = torch.zeros(n_nodes, dtype=torch.bool)
        test_mask = torch.zeros(n_nodes, dtype=torch.bool)
        train_mask[:n_train] = True
        val_mask[n_train:n_train + n_val] = True
        test_mask[n_train + n_val:] = True
        self.graph.ndata['train_mask'] = train_mask
        self.graph.ndata['val_mask'] = val_mask
        self.graph.ndata['test_mask'] = test_mask

    def __getitem__(self, i):
        # The dataset holds a single graph, so every index returns it.
        return self.graph

    def __len__(self):
        return 1


dataset = KarateClubDataset()
graph = dataset[0]

print(graph)
```