Implementing an LSTM Module by Hand with PyTorch
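The class below reproduces, gate by gate, the standard LSTM cell equations that `nn.LSTM` documents:

$$
\begin{aligned}
i_t &= \sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{t-1} + b_{hi}) \\
f_t &= \sigma(W_{if} x_t + b_{if} + W_{hf} h_{t-1} + b_{hf}) \\
g_t &= \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{t-1} + b_{hg}) \\
o_t &= \sigma(W_{io} x_t + b_{io} + W_{ho} h_{t-1} + b_{ho}) \\
c_t &= f_t \odot c_{t-1} + i_t \odot g_t \\
h_t &= o_t \odot \tanh(c_t)
\end{aligned}
$$

where $\sigma$ is the sigmoid function and $\odot$ is element-wise multiplication. Each weight is stored as a separate `(hidden_size, input_size)` or `(hidden_size, hidden_size)` matrix, and the input and hidden state are transposed into column vectors so every gate is a plain matrix product.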
```python
import math
from typing import Optional, Tuple

import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import Parameter, init


class NaiveLSTM(nn.Module):
    """Naive LSTM like nn.LSTM."""

    def __init__(self, input_size: int, hidden_size: int):
        super(NaiveLSTM, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size

        # input gate
        self.w_ii = Parameter(Tensor(hidden_size, input_size))
        self.w_hi = Parameter(Tensor(hidden_size, hidden_size))
        self.b_ii = Parameter(Tensor(hidden_size, 1))
        self.b_hi = Parameter(Tensor(hidden_size, 1))

        # forget gate
        self.w_if = Parameter(Tensor(hidden_size, input_size))
        self.w_hf = Parameter(Tensor(hidden_size, hidden_size))
        self.b_if = Parameter(Tensor(hidden_size, 1))
        self.b_hf = Parameter(Tensor(hidden_size, 1))

        # output gate
        self.w_io = Parameter(Tensor(hidden_size, input_size))
        self.w_ho = Parameter(Tensor(hidden_size, hidden_size))
        self.b_io = Parameter(Tensor(hidden_size, 1))
        self.b_ho = Parameter(Tensor(hidden_size, 1))

        # cell (candidate state)
        self.w_ig = Parameter(Tensor(hidden_size, input_size))
        self.w_hg = Parameter(Tensor(hidden_size, hidden_size))
        self.b_ig = Parameter(Tensor(hidden_size, 1))
        self.b_hg = Parameter(Tensor(hidden_size, 1))

        self.reset_weights()

    def reset_weights(self):
        """Initialize all weights uniformly in [-1/sqrt(hidden_size), 1/sqrt(hidden_size)]."""
        stdv = 1.0 / math.sqrt(self.hidden_size)
        for weight in self.parameters():
            init.uniform_(weight, -stdv, stdv)

    def forward(self, inputs: Tensor,
                state: Optional[Tuple[Tensor, Tensor]]) \
            -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        """Forward pass.

        Args:
            inputs: [1, 1, input_size]
            state: ([1, 1, hidden_size], [1, 1, hidden_size]) or None
        """
        if state is None:
            h_t = torch.zeros(1, self.hidden_size).t()
            c_t = torch.zeros(1, self.hidden_size).t()
        else:
            (h, c) = state
            h_t = h.squeeze(0).t()  # [hidden_size, 1]
            c_t = c.squeeze(0).t()  # [hidden_size, 1]

        hidden_seq = []
        seq_size = 1  # this toy example processes a single time step
        for t in range(seq_size):
            x = inputs[:, t, :].t()  # [input_size, 1]
            # input gate
            i = torch.sigmoid(self.w_ii @ x + self.b_ii + self.w_hi @ h_t + self.b_hi)
            # forget gate
            f = torch.sigmoid(self.w_if @ x + self.b_if + self.w_hf @ h_t + self.b_hf)
            # cell candidate
            g = torch.tanh(self.w_ig @ x + self.b_ig + self.w_hg @ h_t + self.b_hg)
            # output gate
            o = torch.sigmoid(self.w_io @ x + self.b_io + self.w_ho @ h_t + self.b_ho)

            c_next = f * c_t + i * g
            h_next = o * torch.tanh(c_next)
            c_next_t = c_next.t().unsqueeze(0)
            h_next_t = h_next.t().unsqueeze(0)
            hidden_seq.append(h_next_t)

        hidden_seq = torch.cat(hidden_seq, dim=0)
        return hidden_seq, (h_next_t, c_next_t)


def reset_weights(model):
    """Set every parameter to a constant so the two models are directly comparable."""
    for weight in model.parameters():
        init.constant_(weight, 0.5)


### test
inputs = torch.ones(1, 1, 10)
h0 = torch.ones(1, 1, 20)
c0 = torch.ones(1, 1, 20)
print(h0.shape, h0)
print(c0.shape, c0)
print(inputs.shape, inputs)

# test NaiveLSTM with input_size=10, hidden_size=20
naive_lstm = NaiveLSTM(10, 20)
reset_weights(naive_lstm)
output1, (hn1, cn1) = naive_lstm(inputs, (h0, c0))
print(hn1.shape, cn1.shape, output1.shape)
print(hn1)
print(cn1)
print(output1)

# use the official nn.LSTM with input_size=10, hidden_size=20
lstm = nn.LSTM(10, 20)
reset_weights(lstm)
output2, (hn2, cn2) = lstm(inputs, (h0, c0))
print(hn2.shape, cn2.shape, output2.shape)
print(hn2)
print(cn2)
print(output2)
```
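Since both models are reset to the same constant weights, the two runs should print identical numbers. A minimal sanity check that can be appended to the script above (a sketch added here, not part of the original post):

```python
# With every parameter set to 0.5, the hand-written cell and nn.LSTM
# compute the same gates, so their outputs should agree numerically.
assert torch.allclose(output1, output2, atol=1e-6)
assert torch.allclose(hn1, hn2, atol=1e-6)
assert torch.allclose(cn1, cn2, atol=1e-6)
print("NaiveLSTM matches nn.LSTM")
```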