pytorch-LSTM的输入和输出尺寸
LSTM的輸入和輸出尺寸
CLASS torch.nn.LSTM(*args, **kwargs)Applies a multi-layer long short-term memory (LSTM) RNN to an input sequence.
For each element in the input sequence, each layer computes the following function:
對于一個輸入序列實現多層長短期記憶的RNN網絡,對于輸入序列中的每一個元素,LSTM的每一層進行如下計算:
it=σ(Wiixt+bii+Whiht?1+bhi)ft=σ(Wifxt+bif+Whfht?1+bhf)gt=tanh?(Wigxt+big+Whght?1+bhg)ot=σ(Wioxt+bio+Whoht?1+bho)ct=ft⊙ct?1+it⊙gtht=ot⊙tanh?(ct)i_t = \sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{t-1} + b_{hi}) \\ f_t = \sigma(W_{if} x_t + b_{if} + W_{hf} h_{t-1} + b_{hf}) \\ g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{t-1} + b_{hg}) \\ o_t = \sigma(W_{io} x_t + b_{io} + W_{ho} h_{t-1} + b_{ho}) \\ c_t = f_t \odot c_{t-1} + i_t \odot g_t \\ h_t = o_t \odot \tanh(c_t) \\ it?=σ(Wii?xt?+bii?+Whi?ht?1?+bhi?)ft?=σ(Wif?xt?+bif?+Whf?ht?1?+bhf?)gt?=tanh(Wig?xt?+big?+Whg?ht?1?+bhg?)ot?=σ(Wio?xt?+bio?+Who?ht?1?+bho?)ct?=ft?⊙ct?1?+it?⊙gt?ht?=ot?⊙tanh(ct?)
其中:
- ht:h_t:ht?:時間步t的隱藏狀態
- ct:c_t:ct?:時間步t的細胞狀態
- xt:x_t:xt?:時間步t的輸入
- ht?1:h_{t-1}:ht?1?:時間步t-1的隱藏狀態或者初始化的隱藏狀態(時間步0)
- it、ft、gt:i_t、f_t、g_t:it?、ft?、gt?:分別是輸入門,遺忘門,單元門和輸出門
- σ:\sigma:σ:sigmoid函數
- ⊙:\odot:⊙:Hadamard積
其中的參數:
input_size :輸入的維度hidden_size:h的維度num_layers:堆疊LSTM的層數,默認值為1bias:偏置 ,默認值:Truebatch_first: 如果是True,則input為(batch, seq, input_size)。默認值為:False(seq_len, batch, input_size)bidirectional :是否雙向傳播,默認值為False輸入
Inputs: input, (h_0, c_0)-
Input輸入維度是(seq_len, batch, input_size),即(句子中字的數量,批量大小,每個字向量的長度)
-
h_0 的維度(num_layers * num_directions, batch, hidden_size),即(層數?*?LSTM方向數量(單向或者雙向),批量大小,隱藏向量維度)
-
c_0 的維度 (num_layers * num_directions, batch, hidden_size),即(層數?*?LSTM方向數量,隱藏向量維度)
-
If (h_0, c_0) is not provided, both h_0 and c_0 default to zero,h_0和c_0的默認參數都是全0.
輸出
Outputs: output, (h_n, c_n)- output 輸出維度 (seq_len, batch, num_directions * hidden_size),即(句子中字的數量,批量大小,LSTM方向數量?*?隱藏向量維度)
- h_n 維度 (num_layers * num_directions, batch, hidden_size)
- c_n 維度 (num_layers * num_directions, batch, hidden_size)
舉個例子
- num_layers = 1
輸出:
output size:torch.Size([5, 50, 20]) hidden size:torch.Size([2, 50, 20]) cell size:torch.Size([2, 50, 20])- bidirecrtional = True
輸出:
output size:torch.Size([5, 50, 40]) hidden size:torch.Size([2, 50, 20]) cell size:torch.Size([2, 50, 20])總結
以上是生活随笔為你收集整理的pytorch-LSTM的输入和输出尺寸的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 再看那个用代码把人类送上月球的女人——M
- 下一篇: 10大反直觉的数学结论