原文鏈接:
- https://blog.csdn.net/xinjieyuan/article/details/105205326
- https://blog.csdn.net/xinjieyuan/article/details/105208352
在pytorch中,常見的拼接函數主要是兩個,分別是:
stack()cat()
torch.stack()
函數的意義:使用stack可以保留兩個信息:[1. 序列] 和 [2. 張量矩陣] 信息,屬于【擴張再拼接】的函數。
形象的理解:假如數據都是二維矩陣(平面),它可以把這些一個個平面按第三維(例如:時間序列)壓成一個三維的立方體,而立方體的長度就是時間序列長度。該函數常出現在自然語言處理(NLP)和圖像卷積神經網絡(CV)中。
1 stack()
官方解釋:沿著一個新維度對輸入張量序列進行連接。 序列中所有的張量都應該為相同形狀。
淺顯說法:把多個2維的張量湊成一個3維的張量;多個3維的湊成一個4維的張量…以此類推,也就是在增加新的維度進行堆疊。
outputs
= torch
.stack
(inputs
, dim
=?
) → Tensor
參數
inputs : 待連接的張量序列。
注:python的序列數據只有list和tuple。
dim : 新的維度, 必須在0到len(outputs)之間。
注:len(outputs)是生成數據的維度大小,也就是outputs的維度值。
2 重點
函數中的輸入inputs只允許是序列;且序列內部的張量元素,必須shape相等
----舉例:[tensor_1, tensor_2,…]或者(tensor_1, tensor_2,…),且必須tensor_1.shape == tensor_2.shape
dim是選擇生成的維度,必須滿足0<=dim<len(outputs);len(outputs)是輸出后的tensor的維度大小
不懂的看例子,再回過頭看就懂了。
3 例子
準備2個tensor數據,每個的shape都是[3,3]
T1
= torch
.tensor
([[1, 2, 3],[4, 5, 6],[7, 8, 9]])
T2
= torch
.tensor
([[10, 20, 30],[40, 50, 60],[70, 80, 90]])
T1
:tensor
([[1, 2, 3],[4, 5, 6],[7, 8, 9]])
T2
:tensor
([[10, 20, 30],[40, 50, 60],[70, 80, 90]])
測試stack函數
R0
= torch
.stack
((T1
, T2
), dim
=0)
print("R0:\n", R0
)
print("R0.shape:\n", R0
.shape
)
"""
R0:tensor([[[ 1, 2, 3],[ 4, 5, 6],[ 7, 8, 9]],[[10, 20, 30],[40, 50, 60],[70, 80, 90]]])
R0.shape:torch.Size([2, 3, 3])
"""R1
= torch
.stack
((T1
, T2
), dim
=1)
print("R1:\n", R1
)
print("R1.shape:\n", R1
.shape
)
"""
R1:tensor([[[ 1, 2, 3],[10, 20, 30]],[[ 4, 5, 6],[40, 50, 60]],[[ 7, 8, 9],[70, 80, 90]]])
R1.shape:torch.Size([3, 2, 3])"""R2
= torch
.stack
((T1
, T2
), dim
=2)
print("R2:\n", R2
)
print("R2.shape:\n", R2
.shape
)
"""
R2:tensor([[[ 1, 10],[ 2, 20],[ 3, 30]],[[ 4, 40],[ 5, 50],[ 6, 60]],[[ 7, 70],[ 8, 80],[ 9, 90]]])
R2.shape:torch.Size([3, 3, 2])"""R3
= torch
.stack
((T1
, T2
), dim
=3)
print("R3:\n", R3
)
print("R3.shape:\n", R3
.shape
)
"""
IndexError: Dimension out of range (expected to be in range of [-3, 2], but got 3)
"""
可以復制代碼運行試試:拼接后的tensor形狀,會根據不同的dim發生變化。
4 重點
函數作用:
函數stack()對序列數據內部的張量進行擴維拼接,指定維度由程序員選擇、大小是生成后數據的維度區間。
存在意義:
在自然語言處理和卷及神經網絡中, 通常為了保留–[序列(先后)信息] 和 [張量的矩陣信息] 才會使用stack。
函數存在意義?》》》
手寫過RNN的同學,知道在循環神經網絡中輸出數據是:一個list,該列表插入了seq_len個形狀是[batch_size, output_size]的tensor,不利于計算,需要使用stack進行拼接,保留–[1.seq_len這個時間步]和–[2.張量屬性[batch_size, output_size]]。
torch.cat()
一般torch.cat()是為了把函數torch.stack()得到tensor進行拼接而存在的。torch.cat() 和python中的內置函數cat(), 在使用和目的上,是沒有區別的,區別在于前者操作對象是tensor。
1 cat()
函數目的: 在給定維度上對輸入的張量序列seq 進行連接操作。
outputs
= torch
.cat
(inputs
, dim
=0) → Tensor
參數
- inputs : 待連接的張量序列,可以是任意相同Tensor類型的python 序列。
- dim : 選擇的擴維, 必須在0到len(inputs[0])之間,沿著此維連接張量序列。
2 重點
- 輸入數據必須是序列,序列中數據是任意相同的shape的同類型tensor
- 維度不可以超過輸入數據的任一個張量的維度
3 例子
準備數據,每個的shape都是[2,3]
x1
= torch
.tensor
([[11, 21, 31], [21, 31, 41]], dtype
=torch
.int)
print("x1:\n", x1
)
print("x1.shape:\n", x1
.shape
)
'''
x1:tensor([[11, 21, 31],[21, 31, 41]], dtype=torch.int32)
x1.shape:torch.Size([2, 3])
'''
x2
= torch
.tensor
([[12, 22, 32], [22, 32, 42]])
print("x2:\n", x2
)
print("x2.shape:\n", x2
.shape
)
'''
x2:tensor([[12, 22, 32],[22, 32, 42]])
x2.shape:torch.Size([2, 3])
'''
合成inputs
inputs
= [x1
, x2
]
print("inputs:\n", inputs
)
'''
inputs:[tensor([[11, 21, 31],[21, 31, 41]], dtype=torch.int32), tensor([[12, 22, 32],[22, 32, 42]])]
'''
查看結果, 測試不同的dim拼接結果
R0
= torch
.cat
(inputs
, dim
=0)
print("R0:\n", R0
)
print("R0.shape:\n", R0
.shape
)
'''
R0:tensor([[11, 21, 31],[21, 31, 41],[12, 22, 32],[22, 32, 42]])
R0.shape:torch.Size([4, 3])
'''R1
= torch
.cat
(inputs
, dim
=1)
print("R1:\n", R1
)
print("R1.shape:\n", R1
.shape
)
'''
R1:tensor([[11, 21, 31, 12, 22, 32],[21, 31, 41, 22, 32, 42]])
R1.shape:torch.Size([2, 6])
'''R2
= torch
.cat
(inputs
, dim
=2)
print("R2:\n", R2
)
print("R2.shape:\n", R2
.shape
)
'''
IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 2)
'''
總結
以上是生活随笔為你收集整理的pytorch拼接函数:torch.stack()和torch.cat()--详解及例子的全部內容,希望文章能夠幫你解決所遇到的問題。
如果覺得生活随笔網站內容還不錯,歡迎將生活随笔推薦給好友。