lstm需要优化的参数_使用PyTorch手写代码从头构建LSTM,更深入的理解其工作原理...
這是一個造輪子的過程,但是從頭構(gòu)建LSTM能夠使我們對體系結(jié)構(gòu)進行更加了解,并將我們的研究帶入下一個層次。
LSTM單元是遞歸神經(jīng)網(wǎng)絡深度學習研究領域中最有趣的結(jié)構(gòu)之一:它不僅使模型能夠從長序列中學習,而且還為長、短期記憶創(chuàng)建了一個數(shù)值抽象,可以在需要時相互替換。
在這篇文章中,我們不僅將介紹LSTM單元的體系結(jié)構(gòu),還將通過PyTorch手工實現(xiàn)它。
最后但最不重要的是,我們將展示如何對我們的實現(xiàn)做一些小的調(diào)整,以實現(xiàn)一些新的想法,這些想法確實出現(xiàn)在LSTM研究領域,如peephole。
LSTM體系結(jié)構(gòu)
LSTM被稱為門結(jié)構(gòu):一些數(shù)學運算的組合,這些運算使信息流動或從計算圖的那里保留下來。因此,它能夠“決定”其長期和短期記憶,并輸出對序列數(shù)據(jù)的可靠預測:
LSTM單元中的預測序列。注意,它不僅會傳遞預測值,而且還會傳遞一個c,c是長期記憶的代表
遺忘門
遺忘門(forget gate)是輸入信息與候選者一起操作的門,作為長期記憶。請注意,在輸入、隱藏狀態(tài)和偏差的第一個線性組合上,應用一個sigmoid函數(shù):
sigmoid將遺忘門的輸出“縮放”到0-1之間,然后,通過將其與候選者相乘,我們可以將其設置為0,表示長期記憶中的“遺忘”,或者將其設置為更大的數(shù)字,表示我們從長期記憶中記住的“多少”。
新型長時記憶的輸入門及其解決方案
輸入門是將包含在輸入和隱藏狀態(tài)中的信息組合起來,然后與候選和部分候選c''u t一起操作的地方:
在這些操作中,決定了多少新信息將被引入到內(nèi)存中,如何改變——這就是為什么我們使用tanh函數(shù)(從-1到1)。我們將短期記憶和長期記憶中的部分候選組合起來,并將其設置為候選。
單元的輸出門和隱藏狀態(tài)(輸出)
之后,我們可以收集ot作為LSTM單元的輸出門,然后將其乘以候選單元(長期存儲器)的tanh,后者已經(jīng)用正確的操作進行了更新。網(wǎng)絡輸出為ht。
LSTM單元方程
在PyTorch上實現(xiàn)
import math
import torch
import torch.nn as nn
我們現(xiàn)在將通過繼承nn.Module,然后還將引用其參數(shù)和權(quán)重初始化,如下所示(請注意,其形狀由網(wǎng)絡的輸入大小和輸出大小決定):
class NaiveCustomLSTM(nn.Module):
def __init__(self, input_sz: int, hidden_sz: int):
super().__init__()
self.input_size = input_sz
self.hidden_size = hidden_sz
#i_t
self.U_i = nn.Parameter(torch.Tensor(input_sz, hidden_sz))
self.V_i = nn.Parameter(torch.Tensor(hidden_sz, hidden_sz))
self.b_i = nn.Parameter(torch.Tensor(hidden_sz))
#f_t
self.U_f = nn.Parameter(torch.Tensor(input_sz, hidden_sz))
self.V_f = nn.Parameter(torch.Tensor(hidden_sz, hidden_sz))
self.b_f = nn.Parameter(torch.Tensor(hidden_sz))
#c_t
self.U_c = nn.Parameter(torch.Tensor(input_sz, hidden_sz))
self.V_c = nn.Parameter(torch.Tensor(hidden_sz, hidden_sz))
self.b_c = nn.Parameter(torch.Tensor(hidden_sz))
#o_t
self.U_o = nn.Parameter(torch.Tensor(input_sz, hidden_sz))
self.V_o = nn.Parameter(torch.Tensor(hidden_sz, hidden_sz))
self.b_o = nn.Parameter(torch.Tensor(hidden_sz))
self.init_weights()
要了解每個操作的形狀,請看:
矩陣的輸入形狀是(批量大小、序列長度、特征長度),因此將序列的每個元素相乘的權(quán)重矩陣必須具有該形狀(特征長度、輸出長度)。
序列上每個元素的隱藏狀態(tài)(也稱為輸出)都具有形狀(批大小、輸出大小),這將在序列處理結(jié)束時產(chǎn)生輸出形狀(批大小、序列長度、輸出大小)。-因此,將其相乘的權(quán)重矩陣必須具有與單元格的參數(shù)hiddensz相對應的形狀(outputsize,output_size)。
這里是權(quán)重初始化,我們將其用作PyTorch默認值中的權(quán)重初始化nn.Module:
def init_weights(self):
stdv = 1.0 / math.sqrt(self.hidden_size)
for weight in self.parameters():
weight.data.uniform_(-stdv, stdv)
前饋操作
前饋操作接收initstates參數(shù),該參數(shù)是上面方程的(ht,ct)參數(shù)的元組,如果不引入,則設置為零。然后,我們對每個保留(ht,c_t)的序列元素執(zhí)行LSTM方程的前饋,并將其作為序列下一個元素的狀態(tài)引入。
最后,我們返回預測和最后一個狀態(tài)元組。讓我們看看它是如何發(fā)生的:
def forward(self,x,init_states=None):
"""
assumes x.shape represents (batch_size, sequence_size, input_size)
"""
bs, seq_sz, _ = x.size()
hidden_seq = []
if init_states is None:
h_t, c_t = (
torch.zeros(bs, self.hidden_size).to(x.device),
torch.zeros(bs, self.hidden_size).to(x.device),
)
else:
h_t, c_t = init_states
for t in range(seq_sz):
x_t = x[:, t, :]
i_t = torch.sigmoid(x_t @ self.U_i + h_t @ self.V_i + self.b_i)
f_t = torch.sigmoid(x_t @ self.U_f + h_t @ self.V_f + self.b_f)
g_t = torch.tanh(x_t @ self.U_c + h_t @ self.V_c + self.b_c)
o_t = torch.sigmoid(x_t @ self.U_o + h_t @ self.V_o + self.b_o)
c_t = f_t * c_t + i_t * g_t
h_t = o_t * torch.tanh(c_t)
hidden_seq.append(h_t.unsqueeze(0))
#reshape hidden_seq p/ retornar
hidden_seq = torch.cat(hidden_seq, dim=0)
hidden_seq = hidden_seq.transpose(0, 1).contiguous()
return hidden_seq, (h_t, c_t)
優(yōu)化版本
這個LSTM在運算上是正確的,但在計算時間上沒有進行優(yōu)化:我們分別執(zhí)行8個矩陣乘法,這比矢量化的方式慢得多。我們現(xiàn)在將演示如何通過將其減少到2個矩陣乘法來完成,這將使它更快。
為此,我們設置了兩個矩陣U和V,它們的權(quán)重包含在4個矩陣乘法上。然后,我們對已經(jīng)通過線性組合+偏置操作的矩陣執(zhí)行選通操作。
通過矢量化操作,LSTM單元的方程式為:
class CustomLSTM(nn.Module):
def __init__(self, input_sz, hidden_sz):
super().__init__()
self.input_sz = input_sz
self.hidden_size = hidden_sz
self.W = nn.Parameter(torch.Tensor(input_sz, hidden_sz * 4))
self.U = nn.Parameter(torch.Tensor(hidden_sz, hidden_sz * 4))
self.bias = nn.Parameter(torch.Tensor(hidden_sz * 4))
self.init_weights()
def init_weights(self):
stdv = 1.0 / math.sqrt(self.hidden_size)
for weight in self.parameters():
weight.data.uniform_(-stdv, stdv)
def forward(self, x,
init_states=None):
"""Assumes x is of shape (batch, sequence, feature)"""
bs, seq_sz, _ = x.size()
hidden_seq = []
if init_states is None:
h_t, c_t = (torch.zeros(bs, self.hidden_size).to(x.device),
torch.zeros(bs, self.hidden_size).to(x.device))
else:
h_t, c_t = init_states
HS = self.hidden_size
for t in range(seq_sz):
x_t = x[:, t, :]
# batch the computations into a single matrix multiplication
gates = x_t @ self.W + h_t @ self.U + self.bias
i_t, f_t, g_t, o_t = (
torch.sigmoid(gates[:, :HS]), # input
torch.sigmoid(gates[:, HS:HS*2]), # forget
torch.tanh(gates[:, HS*2:HS*3]),
torch.sigmoid(gates[:, HS*3:]), # output
)
c_t = f_t * c_t + i_t * g_t
h_t = o_t * torch.tanh(c_t)
hidden_seq.append(h_t.unsqueeze(0))
hidden_seq = torch.cat(hidden_seq, dim=0)
# reshape from shape (sequence, batch, feature) to (batch, sequence, feature)
hidden_seq = hidden_seq.transpose(0, 1).contiguous()
return hidden_seq, (h_t, c_t)
最后但并非最不重要的是,我們可以展示如何優(yōu)化,以使用LSTM peephole connections。
LSTM peephole
LSTM peephole對其前饋操作進行了細微調(diào)整,從而將其更改為優(yōu)化的情況:
如果LSTM實現(xiàn)得很好并經(jīng)過優(yōu)化,我們可以添加peephole選項,并對其進行一些小的調(diào)整:
class CustomLSTM(nn.Module):
def __init__(self, input_sz, hidden_sz, peephole=False):
super().__init__()
self.input_sz = input_sz
self.hidden_size = hidden_sz
self.peephole = peephole
self.W = nn.Parameter(torch.Tensor(input_sz, hidden_sz * 4))
self.U = nn.Parameter(torch.Tensor(hidden_sz, hidden_sz * 4))
self.bias = nn.Parameter(torch.Tensor(hidden_sz * 4))
self.init_weights()
def init_weights(self):
stdv = 1.0 / math.sqrt(self.hidden_size)
for weight in self.parameters():
weight.data.uniform_(-stdv, stdv)
def forward(self, x,
init_states=None):
"""Assumes x is of shape (batch, sequence, feature)"""
bs, seq_sz, _ = x.size()
hidden_seq = []
if init_states is None:
h_t, c_t = (torch.zeros(bs, self.hidden_size).to(x.device),
torch.zeros(bs, self.hidden_size).to(x.device))
else:
h_t, c_t = init_states
HS = self.hidden_size
for t in range(seq_sz):
x_t = x[:, t, :]
# batch the computations into a single matrix multiplication
if self.peephole:
gates = x_t @ U + c_t @ V + bias
else:
gates = x_t @ U + h_t @ V + bias
g_t = torch.tanh(gates[:, HS*2:HS*3])
i_t, f_t, o_t = (
torch.sigmoid(gates[:, :HS]), # input
torch.sigmoid(gates[:, HS:HS*2]), # forget
torch.sigmoid(gates[:, HS*3:]), # output
)
if self.peephole:
c_t = f_t * c_t + i_t * torch.sigmoid(x_t @ U + bias)[:, HS*2:HS*3]
h_t = torch.tanh(o_t * c_t)
else:
c_t = f_t * c_t + i_t * g_t
h_t = o_t * torch.tanh(c_t)
hidden_seq.append(h_t.unsqueeze(0))
hidden_seq = torch.cat(hidden_seq, dim=0)
# reshape from shape (sequence, batch, feature) to (batch, sequence, feature)
hidden_seq = hidden_seq.transpose(0, 1).contiguous()
return hidden_seq, (h_t, c_t)
我們的LSTM就這樣結(jié)束了。如果有興趣大家可以將他與torch LSTM內(nèi)置層進行比較。
代碼:https://github.com/piEsposito/pytorch-lstm-by-hand
作者:Piero Esposito
總結(jié)
以上是生活随笔為你收集整理的lstm需要优化的参数_使用PyTorch手写代码从头构建LSTM,更深入的理解其工作原理...的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 地址栏 输入 参数 刷新参数丢失_小米1
- 下一篇: 从Gaussian 09的Hartree