【深度学习】在PyTorch中使用 LSTM 自动编码器进行时间序列异常检测
寫在前面
環(huán)境準(zhǔn)備
本次數(shù)據(jù)集的格式.arff,需要用到arff2pandas模塊讀取。
#?!nvidia-smi #?!pip?install?-qq?arff2pandas #?!pip?install?-q?-U?watermark另外本次運(yùn)行環(huán)境可通過如下方法查看。
%reload_ext?watermark %watermark?-v?-p?numpy,pandas,torch,arff2pandasPython implementation: CPython Python version : 3.8.8 IPython version : 7.22.0numpy : 1.19.5 pandas : 1.2.4 torch : 1.9.1 arff2pandas: 1.0.1導(dǎo)入相關(guān)模塊
import?torchimport?copy import?numpy?as?np import?pandas?as?pd import?seaborn?as?sns from?pylab?import?rcParams import?matplotlib.pyplot?as?plt from?matplotlib?import?rc from?sklearn.model_selection?import?train_test_split from?torch?import?nn,?optim import?torch.nn.functional?as?F from?arff2pandas?import?a2p%matplotlib?inline %config?InlineBackend.figure_format='retina' sns.set(style='whitegrid',?palette='muted',?font_scale=1.2) HAPPY_COLORS_PALETTE?=?["#01BEFE",?"#FFDD00",?"#FF7D00",?"#FF006D",?"#ADFF02",?"#8F00FF"] sns.set_palette(sns.color_palette(HAPPY_COLORS_PALETTE)) rcParams['figure.figsize']?=?15,?8 RANDOM_SEED?=?42 np.random.seed(RANDOM_SEED) torch.manual_seed(RANDOM_SEED)<torch._C.Generator at 0x7fa5fcf70c50>本文核心內(nèi)容
本案例使用真實(shí)的心電圖 (ECG) 數(shù)據(jù)來檢測患者心跳的異常情況。我們將一起構(gòu)建一個(gè) LSTM 自動(dòng)編碼器,使用來自單個(gè)心臟病患者的真實(shí)心電圖數(shù)據(jù)對(duì)其進(jìn)行訓(xùn)練,并將在新的樣本中,使用訓(xùn)練好的模型對(duì)其進(jìn)行預(yù)測分類為正常或異常來來檢測異常心跳。
本案例主要圍繞以下幾大核心展開。
從時(shí)間序列數(shù)據(jù)中準(zhǔn)備用于異常檢測的數(shù)據(jù)集
使用 PyTorch 構(gòu)建 LSTM 自動(dòng)編碼器
訓(xùn)練和評(píng)估模型
設(shè)定異常檢測的閾值
將新的樣本分類為正常或異常
數(shù)據(jù)集
該數(shù)據(jù)集包含 5,000 個(gè)通過 ECG 獲得的時(shí)間序列樣本,樣本一共具有 140 個(gè)時(shí)間步長。每個(gè)序列對(duì)應(yīng)于一個(gè)患有充血性心力衰竭的患者的一次心跳。
心電圖(ECG 或 EKG)是一種通過測量心臟的電活動(dòng)來檢查心臟功能的測試。每一次心跳,都會(huì)有一個(gè)電脈沖(或電波)穿過您的心臟。這種波會(huì)導(dǎo)致肌肉擠壓并從心臟泵出血液。?來源[1]
我們有 5 種類型的心跳類別,他們分別是:
正常 (N)
室性早搏 (R-on-T PVC)
室性早搏 (PVC)
室上性早搏或異位搏動(dòng)(SP 或 EB)
未分類的搏動(dòng) (UB)。
假設(shè)心臟健康,典型心率為每分鐘 70 到 75 次,每個(gè)心動(dòng)周期或心跳大約需要 0.8 秒才能完成該周期。頻率:每分鐘 60–100 次(人類)持續(xù)時(shí)間:0.6–1 秒(人類)?來源[2]
如果你的設(shè)備安裝有 GPU,這將是非常好的,因?yàn)樗倪\(yùn)行速度更快,可以節(jié)約你寶貴的時(shí)間。
device?=?torch.device("cuda"?if?torch.cuda.is_available()?else?"cpu")如下圖所示,數(shù)據(jù)有多種格式,我們將加載.arff格式的文件到pandas數(shù)據(jù)幀中。
數(shù)據(jù)獲取方法:?在公眾號(hào)『機(jī)器學(xué)習(xí)研習(xí)院』中后臺(tái)消息框回復(fù)【heart】免費(fèi)獲取!
with?open('./data/ECG5000/ECG5000_TRAIN.arff')?as?f:train?=?a2p.load(f) with?open('./data/ECG5000/ECG5000_TEST.arff')?as?f:test?=?a2p.load(f)把訓(xùn)練和測試數(shù)據(jù)組合成一個(gè)單一的數(shù)據(jù)框。兩者的加成,將為我們提供更多數(shù)據(jù)來訓(xùn)練我們的自動(dòng)編碼器。
df?=?train.append(test) df?=?df.sample(frac=1.0) df.shape(5000, 141)看下數(shù)據(jù)集樣貌。
df.head()我們有5000個(gè)例子。每一行代表一個(gè)心跳記錄。我們重新命名所有的類。并將最后一列重命名為target,這樣在后面引用它將更為方便。
CLASS_NORMAL?=?1 class_names?=?['Normal','R?on?T','PVC','SP','UB']new_columns?=?list(df.columns) new_columns[-1]?=?'target' df.columns?=?new_columns探索性數(shù)據(jù)分析
通過函數(shù)value_counts()可以看看每個(gè)不同的心跳類分別有多少個(gè)樣本。
df.target.value_counts()1 2919 2 1767 4 194 3 96 5 24 Name: target, dtype: int64當(dāng)然,為了更加直觀,我們通過可視化方法將心跳類別通過sns.countplot()清晰展示出。
ax?=?sns.countplot(x="target",?data=df,order?=?df['target'].value_counts().index) ax.set_xticklabels(class_names);通過統(tǒng)計(jì)分析,我們發(fā)現(xiàn)普通類的樣本最多。這個(gè)結(jié)果是非常理想的,也是意料之中的(異常檢測中的異常往往是最少的),又因?yàn)槲覀冃枰褂眠@些正常類的數(shù)據(jù)來訓(xùn)練模型。
接下來,我們看一下每個(gè)類的平均時(shí)間序列(前面和后面做一個(gè)標(biāo)準(zhǔn)差平滑)。
首先定義一個(gè)輔助繪圖函數(shù)。
def?plot_time_series_class(data,?class_name,?ax,?n_steps=10):"""param?data:數(shù)據(jù)param?class_name:?不同心跳類名param?ax:畫布"""time_series_df?=?pd.DataFrame(data)#?平滑時(shí)間窗口smooth_path?=?time_series_df.rolling(n_steps).mean()#?路徑偏差path_deviation?=?2?*?time_series_df.rolling(n_steps).std()#?以正負(fù)偏差上下定義界限under_line?=?(smooth_path?-?path_deviation)[0]over_line?=?(smooth_path?+?path_deviation)[0]#?繪制平滑曲線ax.plot(smooth_path,?linewidth=2)ax.fill_between(path_deviation.index,under_line,over_line,alpha=.125)ax.set_title(class_name)根據(jù)上面的定義的輔助函數(shù),循環(huán)繪制每個(gè)心跳類的平滑曲線。
#?獲取所有不同心跳類別 classes?=?df.target.unique() #?定義畫布 fig,?axs?=?plt.subplots(nrows=len(classes)?//?3?+?1,ncols=3,sharey=True,figsize=(14,?8)) #?循環(huán)繪制曲線 for?i,?cls?in?enumerate(classes):ax?=?axs.flat[i]data?=?df[df.target?==?cls]?\.drop(labels='target',?axis=1)?\.mean(axis=0)?\.to_numpy()plot_time_series_class(data,?class_names[i],?ax)fig.delaxes(axs.flat[-1]) fig.tight_layout();根據(jù)上面五種心跳類的可視化結(jié)果看出,正常類具有與所有其他類明顯不同的特征,這也許就是我們構(gòu)建的模型能夠檢測出異常的關(guān)鍵所在。
LSTM 自動(dòng)編碼器
自動(dòng)編碼器是個(gè)啥
自編碼器模型架構(gòu)圖解自動(dòng)編碼器模型是一種神經(jīng)網(wǎng)絡(luò),旨在以無監(jiān)督的方式學(xué)習(xí)恒等函數(shù)以重建原始輸入,同時(shí)在此過程中壓縮數(shù)據(jù),從而發(fā)現(xiàn)更有效和壓縮的表示。
該網(wǎng)絡(luò)可以看作由兩部分組成:一個(gè)編碼器函數(shù)??和一個(gè)生成重構(gòu)的解碼器?
編碼器網(wǎng)絡(luò):將原始的高維輸入轉(zhuǎn)換為潛在的低維代碼。輸入尺寸大于輸出尺寸。
解碼器網(wǎng)絡(luò):解碼器網(wǎng)絡(luò)從代碼中恢復(fù)數(shù)據(jù),輸出層可能越來越大。
編碼器網(wǎng)絡(luò)本質(zhì)上完成了降維,就像我們?nèi)绾问褂弥鞒煞址治?#xff08;PCA)或矩陣分解(MF)一樣。此外,自動(dòng)編碼器針對(duì)代碼中的數(shù)據(jù)重構(gòu)進(jìn)行了顯式優(yōu)化。一個(gè)好的中間表示不僅可以捕獲潛在變量,而且有利于完整的解壓過程。
該模型包含由???參數(shù)化的編碼器函數(shù)??和由?θ?參數(shù)化的解碼器函數(shù)?。在瓶頸層為輸入x學(xué)習(xí)的低維代碼為??,重構(gòu)輸入為?θ?。
參數(shù) (θ,?) 一起學(xué)習(xí)以輸出與原始輸入相同的重構(gòu)數(shù)據(jù)樣本,θ?,或者換句話說,學(xué)習(xí)恒等函數(shù)。有多種指標(biāo)可以量化兩個(gè)向量之間的差異,例如激活函數(shù)為 sigmoid 時(shí)的交叉熵,或者像 MSE 損失一樣簡單:
θ?θ?
心電數(shù)據(jù)異常檢測
我們將使用正常的心跳作為模型的訓(xùn)練數(shù)據(jù),并記錄重構(gòu)損失。但首先需要準(zhǔn)備數(shù)據(jù)。
數(shù)據(jù)預(yù)處理
獲取所有正常的心跳并刪除目標(biāo)類的列。
normal_df?=?df[df.target?==?str(CLASS_NORMAL)].drop(labels='target',?axis=1) normal_df.shape(2919, 140)合并所有其他類并將它們標(biāo)記為異常。
anomaly_df?=?df[df.target?!=?str(CLASS_NORMAL)].drop(labels='target',?axis=1) anomaly_df.shape(2081, 140)將正常類樣本分為訓(xùn)練集、驗(yàn)證集和測試集。
train_df,?val_df?=?train_test_split(normal_df,test_size=0.15,random_state=RANDOM_SEED)val_df,?test_df?=?train_test_split(val_df,test_size=0.33,?random_state=RANDOM_SEED)需要將樣本轉(zhuǎn)換為張量,使用它們來訓(xùn)練自動(dòng)編碼器。為此編寫一個(gè)輔助函數(shù)來實(shí)現(xiàn)樣本數(shù)據(jù)類型的轉(zhuǎn)換,以便后續(xù)復(fù)用。
def?create_dataset(df):sequences?=?df.astype(np.float32).to_numpy().tolist()dataset?=?[torch.tensor(s).unsqueeze(1).float()?for?s?in?sequences]n_seq,?seq_len,?n_features?=?torch.stack(dataset).shapereturn?dataset,?seq_len,?n_features關(guān)于torch.unsqueeze()?和?torch.stack()?詳解可參見文末。
轉(zhuǎn)換示例:
每個(gè)時(shí)間序列將被轉(zhuǎn)換為形狀?序列長度?x *特征數(shù)量 *的二維張量 。在我們的例子中為140x1的二維張量。
接下來將所有需要用到的數(shù)據(jù)集進(jìn)行如上轉(zhuǎn)換。
#?_?表示不需要該項(xiàng) train_dataset,?seq_len,?n_features?=?create_dataset(train_df) val_dataset,?_,?_?=?create_dataset(val_df) test_normal_dataset,?_,?_?=?create_dataset(test_df) test_anomaly_dataset,?_,?_?=?create_dataset(anomaly_df)構(gòu)建 LSTM 自動(dòng)編碼器
自動(dòng)編碼器的工作是獲取一些輸入數(shù)據(jù),將其通過模型傳遞,并獲得輸入的重構(gòu),重構(gòu)應(yīng)該盡可能匹配輸入。
從某種意義上說,自動(dòng)編碼器試圖只學(xué)習(xí)數(shù)據(jù)中最重要的特征,這里使用幾個(gè) LSTM 層(即LSTM Autoencoder)來捕獲數(shù)據(jù)的時(shí)間依賴性。接下來我們一起看看如何將時(shí)間序列數(shù)據(jù)提供給自動(dòng)編碼器。
為了將序列分類為正常或異常,需要設(shè)定一個(gè)閾值,并規(guī)定高于該閾值時(shí),心跳是異常的。
重構(gòu)損失
當(dāng)訓(xùn)練一個(gè)自動(dòng)編碼器時(shí),模型目標(biāo)是盡可能地重構(gòu)輸入。這里的目標(biāo)是通過最小化損失函數(shù)來實(shí)現(xiàn)的(就像在監(jiān)督學(xué)習(xí)中一樣)。這里所使用的損失函數(shù)被稱為重構(gòu)損失。常用的重構(gòu)損失是交叉熵?fù)p失和均方誤差。
接下來將以GitHub[3]中的 LSTM Autoencoder為基礎(chǔ),并進(jìn)行一些小調(diào)整。因?yàn)槟P偷墓ぷ魇侵亟〞r(shí)間序列數(shù)據(jù),因此該模型需要從編碼器開始定義。
class?Encoder(nn.Module):"""定義一個(gè)編碼器的子類,繼承父類?nn.Modul"""def?__init__(self,?seq_len,?n_features,?embedding_dim=64):super(Encoder,?self).__init__()self.seq_len,?self.n_features?=?seq_len,?n_featuresself.embedding_dim,?self.hidden_dim?=?embedding_dim,?2?*?embedding_dim#?使用雙層LSTMself.rnn1?=?nn.LSTM(input_size=n_features,hidden_size=self.hidden_dim,num_layers=1,batch_first=True)self.rnn2?=?nn.LSTM(input_size=self.hidden_dim,hidden_size=embedding_dim,num_layers=1,batch_first=True)def?forward(self,?x):x?=?x.reshape((1,?self.seq_len,?self.n_features))x,?(_,?_)?=?self.rnn1(x)x,?(hidden_n,?_)?=?self.rnn2(x)return?hidden_n.reshape((self.n_features,?self.embedding_dim))編碼器使用兩個(gè)LSTM層壓縮時(shí)間序列數(shù)據(jù)輸入。
接下來,我們將使用Decoder對(duì)壓縮表示進(jìn)行解碼。
class?Decoder(nn.Module):"""定義一個(gè)解碼器的子類,繼承父類?nn.Modul"""def?__init__(self,?seq_len,?input_dim=64,?n_features=1):super(Decoder,?self).__init__()self.seq_len,?self.input_dim?=?seq_len,?input_dimself.hidden_dim,?self.n_features?=?2?*?input_dim,?n_featuresself.rnn1?=?nn.LSTM(input_size=input_dim,hidden_size=input_dim,num_layers=1,batch_first=True)self.rnn2?=?nn.LSTM(input_size=input_dim,hidden_size=self.hidden_dim,num_layers=1,batch_first=True)self.output_layer?=?nn.Linear(self.hidden_dim,?n_features)def?forward(self,?x):x?=?x.repeat(self.seq_len,?self.n_features)x?=?x.reshape((self.n_features,?self.seq_len,?self.input_dim))x,?(hidden_n,?cell_n)?=?self.rnn1(x)x,?(hidden_n,?cell_n)?=?self.rnn2(x)x?=?x.reshape((self.seq_len,?self.hidden_dim))return?self.output_layer(x)編碼器和解碼器均包含兩個(gè) LSTM 層和一個(gè)提供最終重建的輸出層。
這里將所有內(nèi)容包裝成一個(gè)易于使用的模塊了。
class?RecurrentAutoencoder(nn.Module):"""定義一個(gè)自動(dòng)編碼器的子類,繼承父類?nn.Module并且自動(dòng)編碼器通過編碼器和解碼器傳遞輸入"""def?__init__(self,?seq_len,?n_features,?embedding_dim=64):super(RecurrentAutoencoder,?self).__init__()self.encoder?=?Encoder(seq_len,?n_features,?embedding_dim).to(device)self.decoder?=?Decoder(seq_len,?embedding_dim,?n_features).to(device)def?forward(self,?x):x?=?self.encoder(x)x?=?self.decoder(x)return?x自動(dòng)編碼器類已經(jīng)定義好,接下來創(chuàng)建一個(gè)它的實(shí)例。
model?=?RecurrentAutoencoder(seq_len,?n_features,?128) model?=?model.to(device)訓(xùn)練模型
自動(dòng)編碼器模型已經(jīng)定義好。接下來需要訓(xùn)練模型。下面為訓(xùn)練過程編寫一個(gè)輔助函數(shù)train_model。
def?train_model(model,?train_dataset,?val_dataset,?n_epochs):optimizer?=?torch.optim.Adam(model.parameters(),?lr=1e-3)criterion?=?nn.L1Loss(reduction='sum').to(device)history?=?dict(train=[],?val=[])best_model_wts?=?copy.deepcopy(model.state_dict())best_loss?=?10000.0for?epoch?in?range(1,?n_epochs?+?1):model?=?model.train()train_losses?=?[]for?seq_true?in?train_dataset:optimizer.zero_grad()seq_true?=?seq_true.to(device)seq_pred?=?model(seq_true)loss?=?criterion(seq_pred,?seq_true)loss.backward()optimizer.step()train_losses.append(loss.item())val_losses?=?[]model?=?model.eval()with?torch.no_grad():for?seq_true?in?val_dataset:seq_true?=?seq_true.to(device)seq_pred?=?model(seq_true)loss?=?criterion(seq_pred,?seq_true)val_losses.append(loss.item())train_loss?=?np.mean(train_losses)val_loss?=?np.mean(val_losses)history['train'].append(train_loss)history['val'].append(val_loss)if?val_loss?<?best_loss:best_loss?=?val_lossbest_model_wts?=?copy.deepcopy(model.state_dict())print(f'Epoch?{epoch}:?train?loss?{train_loss}?val?loss?{val_loss}')model.load_state_dict(best_model_wts)return?model.eval(),?history在每個(gè)epoch中,訓(xùn)練過程為模型提供所有訓(xùn)練樣本,并評(píng)估驗(yàn)證集上的模型效果。注意,這里使用的批處理大小為1 ,即模型一次只能得到一個(gè)序列。另外還記錄了過程中的訓(xùn)練和驗(yàn)證集損失。
值得注意的是,重構(gòu)時(shí)做的是最小化L1損失,它測量的是 MAE(平均絕對(duì)誤差),似乎比 MSE(均方誤差)更好。
最后,我們將獲得具有最小驗(yàn)證誤差的模型,并使用該模型進(jìn)行接下來的異常檢測預(yù)。現(xiàn)在開始做一些訓(xùn)練。
#?這一步耗時(shí)較長 model,?history?=?train_model(model,?train_dataset,?val_dataset,?n_epochs=150 )繪制模型損失
繪制模型在訓(xùn)練和測試數(shù)據(jù)集上面的損失曲線。
ax?=?plt.figure().gca() ax.plot(history['train']) ax.plot(history['val']) plt.ylabel('Loss') plt.xlabel('Epoch') plt.legend(['train',?'test']) plt.title('Loss?over?training?epochs') plt.show();從可視化結(jié)果看出,我們所訓(xùn)練的模型收斂得很好。看起來我們可能需要一個(gè)更大的驗(yàn)證集來優(yōu)化模型,但本文就不做展開了,現(xiàn)在就這樣了。
保存模型
存儲(chǔ)模型以備后用。模型保存是必須要做的,他是保存和避免我們寶貴工作不被浪費(fèi)的重要步驟。
MODEL_PATH?=?'model.pth' torch.save(model,?MODEL_PATH)如果要下載和加載預(yù)訓(xùn)練模型,請(qǐng)取消注釋下一行。
#?model?=?torch.load('model.pth') #?model?=?model.to(device)設(shè)定閾值
有了訓(xùn)練好了的模型,可以看看訓(xùn)練集上的重構(gòu)誤差。同樣編寫一個(gè)輔助函數(shù)來使用模型預(yù)測結(jié)果。
def?predict(model,?dataset):predictions,?losses?=?[],?[]criterion?=?nn.L1Loss(reduction='sum').to(device)with?torch.no_grad():model?=?model.eval()for?seq_true?in?dataset:seq_true?=?seq_true.to(device)seq_pred?=?model(seq_true)loss?=?criterion(seq_pred,?seq_true)predictions.append(seq_pred.cpu().numpy().flatten())losses.append(loss.item())return?predictions,?losses該預(yù)測函數(shù)遍歷數(shù)據(jù)集中的每個(gè)樣本并記錄預(yù)測結(jié)果和損失。
_,?losses?=?predict(model,?train_dataset) sns.distplot(losses,?bins=50,?kde=True);從圖結(jié)果看,該閾值設(shè)定為26較為合適。
THRESHOLD?=?26評(píng)估
利用上面設(shè)定的閾值,我們可以將問題轉(zhuǎn)化為一個(gè)簡單的二分類任務(wù):
如果一個(gè)例子的重構(gòu)損失低于閾值,我們將其歸類為"正常"心跳
或者,如果損失高于閾值,我們會(huì)將其歸類為**"異常"**
正常心跳
我們檢查一下模型在正常心跳上的表現(xiàn)如何。這里使用新的測試集中的正常心跳。
predictions,?pred_losses?=?predict(model,?test_normal_dataset) sns.distplot(pred_losses,?bins=50,?kde=True);計(jì)算下模型預(yù)測正確的樣本有多少。
correct?=?sum(l?<=?THRESHOLD?for?l?in?pred_losses) print(f'Correct?normal?predictions:?{correct}/{len(test_normal_dataset)}')Correct normal predictions: 142/145異常心跳
我們對(duì)異常樣本執(zhí)行相同的操作,由于異常心跳和正常心跳的樣本數(shù)量不一致,因此需要獲得一個(gè)與正常心跳大小相同的子集,并對(duì)異常子集進(jìn)行模型的預(yù)測。
anomaly_dataset?=?test_anomaly_dataset[:len(test_normal_dataset)] predictions,?pred_losses?=?predict(model,?anomaly_dataset) sns.distplot(pred_losses,?bins=50,?kde=True);最后計(jì)算高于閾值的樣本數(shù)量,而這些樣本將被視為異常心跳數(shù)據(jù)。
correct?=?sum(l?>?THRESHOLD?for?l?in?pred_losses) print(f'Correct?anomaly?predictions:?{correct}/{len(anomaly_dataset)}')Correct anomaly predictions: 142/145由此可見,我們得到了很好的結(jié)果。在現(xiàn)實(shí)項(xiàng)目中,可以根據(jù)要容忍的錯(cuò)誤類型來調(diào)整閾值。在這種情況下,可能希望誤報(bào)(正常心跳被視為異常)多于漏報(bào)(異常被視為正常)。
樣本對(duì)比觀察
可以疊加真實(shí)的和重構(gòu)的時(shí)間序列值,看看它們有多接近。得到相比的結(jié)果,可以針對(duì)一些正常和異常情況進(jìn)行處理。
#?定義輔助函數(shù) def?plot_prediction(data,?model,?title,?ax):predictions,?pred_losses?=?predict(model,?[data])ax.plot(data,?label='true')ax.plot(predictions[0],?label='reconstructed')ax.set_title(f'{title}?(loss:?{np.around(pred_losses[0],?2)})')ax.legend() #?繪圖 fig,?axs?=?plt.subplots(nrows=2,ncols=6,sharey=True,sharex=True,figsize=(22,?8))for?i,?data?in?enumerate(test_normal_dataset[:6]):plot_prediction(data,?model,?title='Normal',?ax=axs[0,?i])for?i,?data?in?enumerate(test_anomaly_dataset[:6]):plot_prediction(data,?model,?title='Anomaly',?ax=axs[1,?i])fig.tight_layout();到目前為止,該實(shí)戰(zhàn)案例已經(jīng)告一段落了。在本案例中,我們一起學(xué)習(xí)了如何使用 PyTorch 創(chuàng)建 LSTM 自動(dòng)編碼器并使用它來檢測 ECG 數(shù)據(jù)中的心跳異常。
附
torch.unsqueeze 詳解
torch.unsqueeze(input,?dim,?out=None)返回一個(gè)新的張量,對(duì)輸入的既定位置插入維度 1
作用:擴(kuò)展維度
注意:?返回張量與輸入張量共享內(nèi)存,所以改變其中一個(gè)的內(nèi)容會(huì)改變另一個(gè)。
參數(shù):
tensor?(Tensor) – 輸入張量
dim?(int) – 插入維度的索引,如果dim為負(fù),則將會(huì)被轉(zhuǎn)化dim+input.dim()+1
out?(Tensor, optional) – 結(jié)果張量
例子:
x?=?torch.Tensor([1,?2,?3,?4]) torch.unsqueeze(x,?0)?? >>>?tensor([[1.,?2.,?3.,?4.]]) torch.unsqueeze(x,?1) >>>?tensor([[1.],[2.],[3.],?[4.]])torch.stack() 詳解
沿著一個(gè)新維度對(duì)輸入張量序列進(jìn)行連接。序列中所有的張量都應(yīng)該為相同形狀。
簡而言之:把多個(gè)二維的張量湊成一個(gè)三維的張量;多個(gè)三維的湊成一個(gè)四維的張量…以此類推,也就是在增加新的維度進(jìn)行堆疊。
outputs?=?torch.stack(inputs,?dim=0)?→?Tensor參數(shù):
inputs?(sequence of Tensors) - 待連接的張量序列。
注:python的序列數(shù)據(jù)只有l(wèi)ist和tuple。函數(shù)中的輸入inputs只允許是序列;且序列內(nèi)部的張量元素,必須shape相等。dim?(int) 新的維度, 必須在0到len(outputs)之間。注:len(outputs)是生成數(shù)據(jù)的維度大小,也就是outputs的維度值。dim是選擇生成的維度,必須滿足0<=dim<len(outputs);len(outputs)是輸出后的tensor的維度大小。
例子:
#?假設(shè)是時(shí)間步T1的輸出 T1?=?torch.tensor([[1,?2,?3],[4,?5,?6],[7,?8,?9]]) #?假設(shè)是時(shí)間步T2的輸出 T2?=?torch.tensor([[10,?20,?30],[40,?50,?60],[70,?80,?90]]) torch.stack((T1,T2),dim=1) >>>?tensor([[[?1,??2,??3], ...????????[10,?20,?30]], ... ...???????[[?4,??5,??6], ...????????[40,?50,?60]], ... ...???????[[?7,??8,??9], ...????????[70,?80,?90]]]) torch.stack((T1,T2),dim=0) >>>?tensor([[[?1,??2,??3], ...?????????[?4,??5,??6], ...?????????[?7,??8,??9]], ... ...????????[[10,?20,?30], ...?????????[40,?50,?60], ...?????????[70,?80,?90]]])參考資料
[1]?
來源:?https://www.heartandstroke.ca/heart/tests/electrocardiogram
[2]?來源:?https://en.wikipedia.org/wiki/Cardiac_cycle
[3]?GitHub:?https://github.com/shobrook/sequitur
[4]?參考原文:?https://curiousily.com/posts/time-series-anomaly-detection-using-lstm-autoencoder-with-pytorch-in-python/
[5]?Sequitur - Recurrent Autoencoder (RAE):?https://github.com/shobrook/sequitur
[6]?Towards Never-Ending Learning from Time Series Streams:?https://www.cs.ucr.edu/~eamonn/neverending.pdf
[7]?LSTM Autoencoder for Anomaly Detection:?https://towardsdatascience.com/lstm-autoencoder-for-anomaly-detection-e1f4f2ee7ccf
往期精彩回顧適合初學(xué)者入門人工智能的路線及資料下載中國大學(xué)慕課《機(jī)器學(xué)習(xí)》(黃海廣主講)機(jī)器學(xué)習(xí)及深度學(xué)習(xí)筆記等資料打印機(jī)器學(xué)習(xí)在線手冊深度學(xué)習(xí)筆記專輯《統(tǒng)計(jì)學(xué)習(xí)方法》的代碼復(fù)現(xiàn)專輯 AI基礎(chǔ)下載本站qq群955171419,加入微信群請(qǐng)掃碼:總結(jié)
以上是生活随笔為你收集整理的【深度学习】在PyTorch中使用 LSTM 自动编码器进行时间序列异常检测的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: Android如何回调编码后的音视频数据
- 下一篇: 别人家的孩子!高校博士实现Nature、