torch_geometric Notes: nn.ChebNet
1 Theory

For the theory behind ChebNet, see Section 4.1.2.1.1 of the paper-translation post "Deep Learning on Traffic Prediction: Methods, Analysis and Future Directions" (UQI-LIUWJ的博客-CSDN博客).
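For quick reference, here is the recurrence the layer implements (this summary follows the PyG documentation for `ChebConv` and matches the source comments in Section 5):

$$
\mathbf{X}' = \sum_{k=1}^{K} \mathbf{Z}^{(k)} \cdot \mathbf{\Theta}^{(k)}
$$

where

$$
\mathbf{Z}^{(1)} = \mathbf{X}, \qquad
\mathbf{Z}^{(2)} = \hat{\mathbf{L}} \cdot \mathbf{X}, \qquad
\mathbf{Z}^{(k)} = 2 \cdot \hat{\mathbf{L}} \cdot \mathbf{Z}^{(k-1)} - \mathbf{Z}^{(k-2)},
$$

and $\hat{\mathbf{L}} = \frac{2\mathbf{L}}{\lambda_{\max}} - \mathbf{I}$ is the normalized graph Laplacian rescaled so that its eigenvalues lie in $[-1, 1]$.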
2 Class signature
```python
class ChebConv(in_channels: int, out_channels: int, K: int,
               normalization: Optional[str] = 'sym', bias: bool = True,
               **kwargs)
```

3 Parameters
| Parameter | Description |
| --- | --- |
| in_channels (int) | Number of input channels per sample |
| out_channels (int) | Number of output channels per sample. (In the ChebConv source, the result of each Chebyshev-order convolution is passed through a fully connected layer; that FC layer is what converts the dimension and re-weights each order's output.) |
| K (int) | Order of the Chebyshev polynomial approximation |
| normalization (str, optional) | Normalization scheme for the graph Laplacian; defaults to 'sym'. If the normalization is non-symmetric, lambda_max must be passed to the forward() method; it should be a tensor of shape [batch_size]. lambda_max can be precomputed with torch_geometric.transforms.LaplacianLambdaMax (see the sketch after this table). |
| bias (bool) | Defaults to True; if False, the layer learns no additive bias |
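A minimal sketch of precomputing `lambda_max` with the `LaplacianLambdaMax` transform (the toy graph is hypothetical; the transform and its `normalization` argument are part of torch_geometric):

```python
import torch
from torch_geometric.data import Data
from torch_geometric.transforms import LaplacianLambdaMax

# hypothetical toy graph: an undirected 4-node path
data = Data(x=torch.randn(4, 1),
            edge_index=torch.tensor([[0, 1, 1, 2, 2, 3],
                                     [1, 0, 2, 1, 3, 2]]))

# precompute the largest Laplacian eigenvalue; the normalization passed
# here should match the one used by the ChebConv layer later
data = LaplacianLambdaMax(normalization='rw')(data)
print(data.lambda_max)  # largest eigenvalue of the random-walk Laplacian
```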
4 The forward function
```python
forward(x, edge_index, edge_weight: Optional[torch.Tensor] = None,
        batch: Optional[torch.Tensor] = None,
        lambda_max: Optional[torch.Tensor] = None)
```

Note: `batch` here is the graph-assignment vector described in Section 2 of torch_geometric筆記:數據集 ENZYMES & Minibatches_UQI-LIUWJ的博客-CSDN博客; a batched-call sketch follows.
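A sketch of a batched call with non-symmetric normalization, where `lambda_max` is mandatory (the toy graphs are hypothetical; batch collation should turn the per-graph `lambda_max` floats into a `[batch_size]` tensor, which is the shape `forward()` expects):

```python
import torch
from torch_geometric.data import Data, Batch
from torch_geometric.nn import ChebConv
from torch_geometric.transforms import LaplacianLambdaMax

# two copies of a 4-node path graph, each with lambda_max precomputed
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
                           [1, 0, 2, 1, 3, 2]])
transform = LaplacianLambdaMax(normalization='rw')
graphs = [transform(Data(x=torch.randn(4, 1), edge_index=edge_index))
          for _ in range(2)]
batch = Batch.from_data_list(graphs)

# 'rw' is non-symmetric, so lambda_max must be supplied
conv = ChebConv(1, 16, K=3, normalization='rw')
out = conv(batch.x, batch.edge_index,
           batch=batch.batch,            # node-to-graph assignment vector
           lambda_max=batch.lambda_max)  # one lambda_max per graph
print(out.shape)  # torch.Size([8, 16])
```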
5 Source code
The handling here is elegant: the rescaled, normalized Laplacian is effectively treated as the adjacency matrix of a new graph. The toy sketch below illustrates this before the full source.
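A toy illustration of that trick (hypothetical graph, with `lambda_max` assumed to be the default value of 2.0): running the same steps as `__norm__` below and densifying the result recovers $\hat{L} = 2L/\lambda_{\max} - I$:

```python
import torch
from torch_geometric.utils import get_laplacian, add_self_loops, to_dense_adj

# hypothetical unweighted 3-node path graph
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])

# the same steps __norm__ performs, assuming lambda_max = 2.0
lap_index, lap_weight = get_laplacian(edge_index, normalization='sym')
lap_weight = 2.0 * lap_weight / 2.0  # scale every entry by 2 / lambda_max
lap_index, lap_weight = add_self_loops(lap_index, lap_weight, fill_value=-1.)
# add a weight of -1 on every diagonal entry (the "- I" term)

# summing duplicate entries densely gives L_hat = 2L/lambda_max - I
print(to_dense_adj(lap_index, edge_attr=lap_weight)[0])
```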
```python
from typing import Optional

import torch
from torch.nn import Parameter

from torch_geometric.typing import OptTensor
from torch_geometric.nn.inits import zeros
from torch_geometric.utils import get_laplacian
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import remove_self_loops, add_self_loops


class ChebConv(MessagePassing):
    def __init__(self, in_channels: int, out_channels: int, K: int,
                 normalization: Optional[str] = 'sym', bias: bool = True,
                 **kwargs):
        kwargs.setdefault('aggr', 'add')
        super(ChebConv, self).__init__(**kwargs)
        # set the aggregation scheme ('add': the Chebyshev terms are summed)

        assert K > 0
        assert normalization in [None, 'sym', 'rw'], 'Invalid normalization'
        # two assertions: the Chebyshev order must be positive, and the
        # normalization must be one of these three options

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.normalization = normalization
        self.lins = torch.nn.ModuleList([
            Linear(in_channels, out_channels, bias=False,
                   weight_initializer='glorot') for _ in range(K)
        ])
        # one fully connected layer per Chebyshev order, applied to that
        # order's output to convert its dimension

        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        # initialize the parameters
        for lin in self.lins:
            lin.reset_parameters()
        zeros(self.bias)

    def __norm__(self, edge_index, num_nodes: Optional[int],
                 edge_weight: OptTensor, normalization: Optional[str],
                 lambda_max, dtype: Optional[int] = None,
                 batch: OptTensor = None):
        # the elegant part: build a "new graph" whose adjacency matrix is
        # the rescaled, normalized Laplacian
        edge_index, edge_weight = remove_self_loops(edge_index, edge_weight)
        # remove self-loops

        edge_index, edge_weight = get_laplacian(edge_index, edge_weight,
                                                normalization, dtype,
                                                num_nodes)
        # compute the graph Laplacian

        if batch is not None and lambda_max.numel() > 1:
            lambda_max = lambda_max[batch[edge_index[0]]]

        edge_weight = (2.0 * edge_weight) / lambda_max
        edge_weight.masked_fill_(edge_weight == float('inf'), 0)
        # every edge whose original weight was nonzero is scaled by
        # 2 / lambda_max

        edge_index, edge_weight = add_self_loops(edge_index, edge_weight,
                                                 fill_value=-1.,
                                                 num_nodes=num_nodes)
        # the rescaled Laplacian also subtracts I, so every self-loop
        # weight is decreased by one
        assert edge_weight is not None

        return edge_index, edge_weight
        # return the "new graph" that uses the Laplacian as its adjacency matrix

    def forward(self, x, edge_index, edge_weight: OptTensor = None,
                batch: OptTensor = None, lambda_max: OptTensor = None):
        """"""
        if self.normalization != 'sym' and lambda_max is None:
            raise ValueError('You need to pass `lambda_max` to `forward()` in '
                             'case the normalization is non-symmetric.')

        if lambda_max is None:
            lambda_max = torch.tensor(2.0, dtype=x.dtype, device=x.device)
        if not isinstance(lambda_max, torch.Tensor):
            lambda_max = torch.tensor(lambda_max, dtype=x.dtype,
                                      device=x.device)
        assert lambda_max is not None

        edge_index, norm = self.__norm__(edge_index, x.size(self.node_dim),
                                         edge_weight, self.normalization,
                                         lambda_max, dtype=x.dtype,
                                         batch=batch)
        # get the "new graph" whose adjacency matrix is the rescaled Laplacian

        Tx_0 = x
        # Z_1 = X
        out = self.lins[0](Tx_0)

        # propagate_type: (x: Tensor, norm: Tensor)
        if len(self.lins) > 1:
            Tx_1 = self.propagate(edge_index, x=x, norm=norm, size=None)
            # each propagate() computes, for every node, (Laplacian edge
            # weight * neighbor feature) over all incident edges and then
            # sums them [aggr='add']
            out = out + self.lins[1](Tx_1)
            # Z_2 = L_hat @ X

        for lin in self.lins[2:]:
            Tx_2 = self.propagate(edge_index, x=Tx_1, norm=norm, size=None)
            # Tx_2 = L_hat @ Z_{k-1}
            Tx_2 = 2. * Tx_2 - Tx_0
            # Z_k = 2 * L_hat @ Z_{k-1} - Z_{k-2}
            out = out + lin.forward(Tx_2)
            Tx_0, Tx_1 = Tx_1, Tx_2

        if self.bias is not None:
            out += self.bias

        return out

    def message(self, x_j, norm):
        return norm.view(-1, 1) * x_j
        # simply: edge weight * neighboring node's feature

    def __repr__(self):
        return '{}({}, {}, K={}, normalization={})'.format(
            self.__class__.__name__, self.in_channels, self.out_channels,
            len(self.lins), self.normalization)
```

6 Example
```python
from torch_geometric.nn import ChebConv

data
# Batch(x=[9893, 1], edge_index=[2, 34637], y=[9893, 1], batch=[9893], ptr=[2])

conv1 = ChebConv(1, 32, 2)
x = conv1(data.x, data.edge_index)

type(x)   # torch.Tensor
x.shape   # torch.Size([9893, 32]): each of the 9893 nodes now has a
          # 32-dimensional feature
```
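One more toy check (hypothetical, not from the original post): with `K=1` the sum stops at $Z^{(1)} = X$, so no `propagate()` call is reached and the layer reduces to a plain per-node linear map:

```python
import torch
from torch_geometric.nn import ChebConv

conv = ChebConv(4, 8, K=1, bias=False)
x = torch.randn(3, 4)
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])

# with K=1, forward() returns exactly lins[0](x)
assert torch.allclose(conv(x, edge_index), conv.lins[0](x))
```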