torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor
生活随笔
收集整理的這篇文章主要介紹了
torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor
沿dim指定的軸聚集值。
對于三維張量,輸出由以下公式指定:
out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0 out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1 out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2如果input是大小為(x0, x1…, xi?1, xi, xi+1, …, xn?1) 的n維張量并且dim = i,那么index必須是大小為(x0, x1…, xi?1, y, xi+1, …, xn?1) 的n維張量,并且 y >= 1,out和index具有相同的大小。
Parameters
- input (Tensor) – 輸入張量
- dim (int) – 要索引的軸
- index (LongTensor) – 要收集的元素的索引
- sparse_grad (bool,optional) – 如果為True,梯度w.r.t。input將是一個稀疏張量。
- out (Tensor, optional) – 目標張量
Example:
>>> t = torch.tensor([[1, 2], [3, 4]]) >>> torch.gather(dim=1, index=torch.tensor([[0, 0], [1, 0]])) tensor([[ 1, 1],[ 4, 3]])dim=1 時,就是按列進行索引,dim=0 時,就是按行進行索引。
然后按照index去交換元素的位置。
總結
以上是生活随笔為你收集整理的torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: torch.nn.functional.
- 下一篇: 2014/School_C_C++_A/