Social GAN: Socially Acceptable Trajectories with Generative Adversarial Networks 中文翻译
Social GAN: Socially Acceptable Trajectories with Generative Adversarial Networks 中文翻譯
如有異議,請多指教,非專業(yè)人員,僅供參考
摘要
理解人類的運動行為對于自主移動平臺(如自動駕駛汽車和社交機器人)在以人類為中心(human-centric)的環(huán)境中導(dǎo)航至關(guān)重要。這是一項具有挑戰(zhàn)性的任務(wù),因為人類的運動本質(zhì)上是多模態(tài)的:根據(jù)人的歷史行動軌跡,在接下來的步驟中人類可以有很多條路作為選擇。我們通過結(jié)合序列預(yù)測和生成對抗網(wǎng)絡(luò)來解決這個問題:使用一個周期性的序列到序列(sequence-to-sequence)模型觀察運動歷史并預(yù)測未來的行為,使用一個新穎的池化機制來聚集人們之間的信息。我們通過對抗訓(xùn)練來預(yù)測可信的未來行為,并使用新型的多樣性損失函數(shù)來鼓勵多樣化預(yù)測。通過對幾個數(shù)據(jù)集的實驗,我們證明了我們的方法在準(zhǔn)確性(accuracy)、多樣性(variety)、避免碰撞(collision avoidance)和計算復(fù)雜度(computational complexity)方面優(yōu)于先前的工作。
1.介紹
預(yù)測行人的運動行為對于自動駕駛汽車或社交機器人等與人類共享同一生態(tài)系統(tǒng)的自主移動平臺來說至關(guān)重要。人類能夠有效地處理復(fù)雜的社交(social interaction),而這些機器也應(yīng)該能夠做到這一點。為此,一個具體而重要的任務(wù)是:給定行人的觀察運動軌跡( 舉例:過去3.2秒的坐標(biāo)),預(yù)測所有可能的未來軌跡,見 <圖1>
由于擁擠場景中人類運動的固有特性,因此預(yù)測人類行為是一項挑戰(zhàn):
圖1 說明兩個行人想避開對方的場景。有許多可能的方法可以避免潛在的沖突。我們提出了一個方法,給定相同的可觀察過去路徑(observed past),在擁擠的場景中預(yù)測出多個接近人類真實行為(socially acceptable)的輸出。
1.人際互動(Interpersonal)
每個人的行動都取決于周圍的人。人類有一種天生的能力,能夠在人群中解讀他人的行為。聯(lián)合建模這些相關(guān)性(dependencies)是一個挑戰(zhàn)。
2.社會可以接受性(socially acceptable)
有些軌跡在物理上是可能的,但在社會中這種行為是不可能發(fā)生的。行人的行為是受社會規(guī)則約束的,比如:讓路(yielding right-of-way)或尊重個人空間。而將這些行為形式化并非易事。
3.多模式(multimodal)
考慮到歷史行為,沒有單一正確的未來預(yù)測。多種可能的發(fā)展軌跡是合理的,也是符合社會規(guī)律的。
軌跡預(yù)測的探索工作已經(jīng)解決了上述一些挑戰(zhàn)。interpersonal 的相關(guān)問題已經(jīng)被基于手工特征 (hand-crafted feature) [2,7,41,46]的傳統(tǒng)方法完全解決。最近,基于遞歸神經(jīng)網(wǎng)絡(luò)(RNNs)的數(shù)據(jù)驅(qū)動技術(shù)(data-driven techniques)[1,28,12,4]重新探討了社會可接受性。最后,在給定靜態(tài)場景[28, 24](例如,十字路口應(yīng)該走哪條街)的情況下,研究了有關(guān)該問題多模態(tài)(multimodal)的方面。Robicquet等人[38]證實了在行人在面對不同的場景(溫和的或者激進的場景之下)會采取多重的導(dǎo)航(navigation styles)方式。因此,預(yù)測任務(wù)需要輸出不同的可能結(jié)果。
雖然現(xiàn)有的方法在應(yīng)對這些具體挑戰(zhàn)方面取得了很大的進展,但它們存在兩個局限性:
i) 首先,在做出預(yù)測時,他們模擬每個人周邊的一個區(qū)域。因此,它們在對場景中所有人之間的交互進行建模的時候不具備很高的計算效率(computationally efficient fashion)。
ii) 第二,他們傾向于學(xué)習(xí)“平均行為”,因為通常使用的損失函數(shù),可以最小化預(yù)測輸出和ground truth之間的歐氏距離。相反,我們的目標(biāo)是學(xué)習(xí)多種“表現(xiàn)優(yōu)異的行為”,即:,多重的可以被社會所接受的行為軌跡。
為了解決以往工作的局限性,我們建議利用生成模型(recent progress in generative models)的最新進展。生成式對抗網(wǎng)絡(luò)(GANs)是近年來發(fā)展起來的一種用于解決:難以處理的概率計算和行為推斷的逼近困難等問題[14]的網(wǎng)絡(luò)。雖然他們已經(jīng)被用來產(chǎn)生逼真的信號,如圖像[34],我們提出:給定一個可觀察的過去狀態(tài),使用他們(GAN)來產(chǎn)生多個“社會可接受(socially acceptable)”的軌跡。一個網(wǎng)絡(luò)(生成器)生成候選對象,另一個(鑒別器)對它們進行評估。對抗性損失(adversarial loss) 使我們的預(yù)測模型能夠超越 L2損失 的限制,并有可能了解那些能夠欺騙鑒別者的“良好行為”的分布。在我們的工作中,這些行為被稱為“在擁擠的場景中的 ‘ 社會可接受 ’ 運動軌跡”。
我們提出的GAN是一個RNN編碼器-解碼器生成器和一個基于RNN的編碼器鑒別器(encoder discriminator),具有以下兩個新穎之處:
i) 我們引入了多樣性損失(variety loss),這鼓勵了GAN的生成網(wǎng)絡(luò)擴展其分布并覆蓋可能的路徑空間,同時與觀察到的輸入保持一致(being consistent with the observed input)。
ii) 我們提出了一個新的池化機制(pooling mechanism),它學(xué)習(xí)一個“全局”池化向量,為所有參與場景的人編碼細(xì)致的線索。
通過對幾個公開的真實世界人群數(shù)據(jù)集的實驗,我們展示了最先進的準(zhǔn)確性、速度,并證明我們的模型有能力產(chǎn)生各種“社會可接受”的軌跡
2.相關(guān)工作
預(yù)測人類行為的研究可以分(can be grouped as)為學(xué)習(xí)預(yù)測人與空間的相互作用或人與人的相互作用。前者學(xué)習(xí)場景特定的動作模式(scene-specific motion patterns) [3, 9, 18, 21, 24, 33, 49],后者模擬場景的動態(tài)內(nèi)容,即行人之間如何相互影響(dynamic content of scenes)。
我們工作的重點是后者: 學(xué)習(xí)預(yù)測人與人之間的互動。我們討論了這方面的現(xiàn)有工作,以及RNN在序列預(yù)測和生成模型方面的相關(guān)工作。
- 人與人的交互(human-human interaction)
從宏觀模型的人群視角(macroscopic models)和微觀模型(microscopic models)的個體視角(我們工作的重點)對人類行為進行了研究。微觀模型的一個例子是 Helbing和Molnar [17] 對行人行為進行了建模,引力(attractive force)引導(dǎo)他們朝著目標(biāo)前進,排斥力(repulsive force)鼓勵他們避免碰撞。在過去的幾十年里,這種方法經(jīng)常被重新使用 [5, 6, 25, 26, 30, 31, 36, 46]。經(jīng)濟學(xué)中流行的工具也被投入使用,如Antonini等人的離散選擇框架(discrete choice framework) [2]。Treuille等人 [42]采用連續(xù)體動力學(xué),Wang等人 [44],Tay等人 [41]使用高斯過程。這些函數(shù)也被用于研究固定群組(stationary groups) [35, 47]。然而,所有這些方法都使用基于相對距離和特定規(guī)則的手工制作的能量勢(hand crafted energy potentials based on relative distances and specific rules)。相比之下,在過去的兩年中,基于RNNs的數(shù)據(jù)驅(qū)動方法(data-driven methods)已經(jīng)超越上述傳統(tǒng)方法。
- 用于序列預(yù)測的RNNs(RNNs for sequence prediction)
遞歸神經(jīng)網(wǎng)絡(luò)是一類豐富的動態(tài)模型,它將前饋網(wǎng)絡(luò)擴展到多個領(lǐng)域進行序列生成如語音識別 [7, 8, 15],機器翻譯 [8],為圖像添加字幕 [20, 43, 45, 39]但這些應(yīng)用缺乏高層次的時空結(jié)構(gòu) [29],人們多次嘗試使用多個網(wǎng)絡(luò)來捕獲復(fù)雜的交互 [1, 10, 40],Alahi等人的 [1]使用一個social pooling 層,模擬附近的行人。在本文的其余部分中,我們證明了使用多層感知器(multi-layer perceptron)(MLP)再進行最大池化(max pooling)在計算上更有效,并且與來自 [1] 的social pooling 表現(xiàn)相近或更好。Lee等人 [28]介紹了一個采用變分自編碼器(variational autoencoder)(VAE)的RNN編解碼框架從而進行軌跡預(yù)測的。然而,他們并沒有在擁擠的場景中模擬人與人之間的互動。
- 生成模型(Generative Modeling)
生成模型如變分自編碼器 [23]是通過 最大化訓(xùn)練數(shù)據(jù)似然下界(maximize the lower bound of training data likelihood) 來訓(xùn)練的。Goodfellow等人提出了另一種方法,生成對抗網(wǎng)絡(luò)(GANs) [14],其中訓(xùn)練過程是**生成模型(generative model)和判別模型(discriminative model)**之間的極小極大博弈(minimax game)。這就克服了逼近難以計算的概率的運算困難。生成模型在超分辨率 [27]、圖像到圖像轉(zhuǎn)換 [19]和圖像合成 [16, 34, 48]等任務(wù)中顯示出良好的結(jié)果,這些任務(wù)對于給定的輸入具有多個可能的輸出。然而,它們在序列生成問題中的應(yīng)用,如自然語言處理,已經(jīng)滯后了,因為從這些生成的輸出中進行采樣并將其提供給鑒別器是一個不可微(non-differentiable)的操作。
3.理論
當(dāng)人類再人群中進行路徑規(guī)劃的時候會本能地考慮到自己身邊的人的狀態(tài)。我們計劃我們的路徑,牢記我們的目標(biāo),同時也考慮周圍人的運動,如他們的運動方向,速度等。然而,在這種情況下,通常存在多個可能的選項。我們需要的模型不僅能夠理解這些復(fù)雜的人類交互行為(human interactions),而且還能夠捕捉各種選項(capture the variety of options)。 目前的方法側(cè)重于預(yù)測未來的平均軌跡,所以他們通過最小化到ground truth的 L2 距離,而我們希望預(yù)測多個“良好”軌跡。在這個部分,我們首先提出了我們的基于編解碼結(jié)構(gòu)的GAN 來解決這些問題,接下來我們展示了我們最新的池化層結(jié)構(gòu),這是一個模擬人與人之間交互(human-human interaction)的池化層,最終我們引入了我們的多類損失(variety loss)來促使網(wǎng)絡(luò)基于給定的觀測序列可以產(chǎn)生多種不同的未來軌跡。
3.1定義問題(problem definition)
我們的目標(biāo)是聯(lián)合推理(reason)和預(yù)測(predict)一個場景中所涉及的所有對象(agent)的未來軌跡。我們假設(shè)我們接收到的輸入是場景中所有人的軌跡 X = X 1 , X 2 , … , X n , X = X_1, X_2,…,X_n, X=X1?,X2?,…,Xn?, 同時 預(yù)測未來軌跡 Y ^ = Y ^ 1 , Y ^ 2 , . . . . Y ^ n \hat {Y} =\hat Y_1,\hat Y_2,....\hat Y_n Y^=Y^1?,Y^2?,....Y^n? ,行人的輸入軌跡 i i i 定義為: X i = ( x i t , y i t ) X_i=(x_i^t,y_i^t) Xi?=(xit?,yit?),在時間步長 t = t o b s + 1 , . . . . , t p r e d t=t_{obs}+1,....,t_{pred} t=tobs?+1,....,tpred?,我們用 Y i ^ \hat {Y_i} Yi?^? 表示預(yù)測。
3.2 生成對抗網(wǎng)絡(luò)(generative adversial networks)
生成式對抗網(wǎng)絡(luò)(GAN)由兩個相互對立訓(xùn)練的神經(jīng)網(wǎng)絡(luò)組成 [14]。進行對抗訓(xùn)練的兩個模型是:一個是捕獲數(shù)據(jù)分布的生成模型 G G G,一個是估計樣本來自訓(xùn)練數(shù)據(jù)而不是 G G G 的概率的判別模型 D D D。生成器 G G G 以潛在變量 z z z 作為輸入,輸出樣本 G ( z ) G(z) G(z) 。鑒別器 D D D 把 x x x 作為輸入,輸出的 D ( x ) D(x) D(x) 代表著 x x x 真實的概率。訓(xùn)練過程類似于一個兩個人的最小最大的博弈,目標(biāo)函數(shù)如下:
m i n G m a x D V ( G , D ) = E x ? p d a t a ( x ) [ log ? D ( x ) ] + E z ? p ( z ) [ log ? ( 1 ? D ( G ( z ) ) ) ] min_{G}\ max_{D}V(G,D)=\mathbb E_{x-p_{data(x)}}[\log D(x)]+\mathbb E_{z-p_{(z)}}[\log(1-D(G(z)))] minG??maxD?V(G,D)=Ex?pdata(x)??[logD(x)]+Ez?p(z)??[log(1?D(G(z)))] (公式1)
GANs可以通過向生成器和鑒別器提供額外的輸入 c c c 來用于條件模型,生成 G ( z , c ) G(z, c) G(z,c) 和 D ( x , c ) D(x, c) D(x,c) [13, 32]
3.3 社交意識 GAN(social-aware GAN)
正如第1節(jié)所討論的,軌跡預(yù)測是一個多模態(tài)問題,生成模型可以與時間序列(timeseries)數(shù)據(jù)一起使用,以模擬可能的未來。我們在設(shè)計SGAN時利用了這一觀點,它使用GANs解決了問題的多種模態(tài)(參見圖2)
圖2 系統(tǒng)總覽 我們的模型由三個關(guān)鍵部分組成: 生成器 ( G ) (G) (G) , 池化模塊,鑒別器 ( D ) (D) (D) 。 G G G 作為過去的軌跡 X i X_i Xi? 的輸入,并且對人 i i i 的過去行為編碼為 H i t H_i^t Hit? ,池化模塊作為所有 H i t o b s H_i^{t_{obs}} Hitobs?? 的輸入并且最終對于每一個場景中的人 i i i 輸出一個池化向量 P i P_i Pi?,解碼器根據(jù) H i t o b s H_i^{t_{obs}} Hitobs?? 和 P i P_i Pi? 生成未來的軌跡。 D D D 將 T r e a l T_{real} Treal? 或 T f a k e T_{fake} Tfake? 作為輸入,并將它們歸類為社會可接受的或不可接受的(PM參見圖3)。
我們的模型由三個關(guān)鍵部分組成:生成器 ( G ) (G) (G) , 池化模塊 P M PM PM 和鑒別器 ( D ) (D) (D), G G G是G是基于編碼器-解碼器框架,我們通過 P M PM PM 來鏈接編碼器和解碼器的隱藏狀態(tài)。對 G G G輸入 X i X_i Xi?可以輸出軌跡預(yù)測 Y i ^ \hat {Y_i} Yi?^?。 D D D將包含輸入 X i X_i Xi?和預(yù)測結(jié)果輸出 Y i ^ \hat {Y_i} Yi?^?的整個序列輸入到鑒別器中,然后將他們分類為真或者是假。
【生成器】
我們首先將每個人的位置嵌入到一個單層的MLP中來獲取一個固定長度的向量 e i t e_i^t eit?。這些嵌入部分在 t 時刻作為編碼器的LSTM單元的輸入,引入如下遞歸式:
e i t = ? ( x i t , y i t , W e e ) e_i^t=\phi(x_i^t,y_i^t,W_{ee}) eit?=?(xit?,yit?,Wee?) (公式2)
h e i t = L S T M ( h e i t ? 1 , e i t ; W e n c o d e r ) h_{ei}^t=LSTM(h_{ei}^{t-1},e_i^t;W_{encoder}) heit?=LSTM(heit?1?,eit?;Wencoder?)
其中 ? \phi ? 是一個由ReLU非線性單元的嵌入函數(shù), W e e W_{ee} Wee?是一個嵌入的權(quán)重。LSTM權(quán)重 W e n c o d e r W_{encoder} Wencoder? 在一個場景中所有人共享。
簡單的對于每一個人使用一個LSTM不能獲取人與人之間的交互行為,編碼器學(xué)習(xí)每個單元的狀態(tài)并存儲它們的運動歷史。然而,正如Alahi等人所示 [1]。我們需要一個緊湊的表示,它結(jié)合了來自不同編碼器的信息,以有效地推理有關(guān)社會互動。在我們的方法中,我們通過一個池模塊(PM)對人與人的交互進行建模。 t o b s t_{obs} tobs?之后,我們把場景中所有人的隱藏狀態(tài)集合起來,得到每個人的池化張量(pooled tensor) p i p_i pi?。通常情況下,GANs用輸入的噪聲來產(chǎn)生樣例。我們的目標(biāo)是創(chuàng)造出與過去一致的未來情景。為了實現(xiàn)這一點,我們通過初始化的隱藏狀態(tài)來設(shè)定生成輸出軌跡的條件,如下:
c i t = γ ( P i , h e i t ; W c ) c_i^t=\gamma(P_i,h_{ei}^t;W_c) cit?=γ(Pi?,heit?;Wc?) (公式3)
h d i t = [ c i t , z ] h_{di}^t=[c_i^t,z] hdit?=[cit?,z]
γ ( ? ) \gamma(·) γ(?) 是一個擁有非線性單元ReLU和嵌入權(quán)重 W c W_c Wc? 的多層感知機(multi-layer perceptron)(MLP)。在軌道預(yù)測方面,我們在兩個重要方面偏離(deviate)了之前的工作:
-
之前的工作 [1] 使用隱藏狀態(tài)來預(yù)測二元高斯分布的參數(shù)。然而,這如何在不可微的情況下,通過采樣的反向傳播,在訓(xùn)練過程中引入了困難,我們通過直接預(yù)測坐標(biāo) ( x ^ i t , y ^ i t ) (\hat x_i^t,\hat y_i^t) (x^it?,y^?it?)來避免這種情況。
-
“社會”語境一般是作為輸入提供的LSTM細(xì)胞 [1, 28] 相反,我們只提供一次池化上下文(pooled context)作為解碼器的輸入。這也為我們提供了在特定時間步長的情況下選擇池的能力,并且與S-LSTM [1] 相比,速度提高了16倍(見表2)。
初始化上述解碼器狀態(tài)后,我們可以得到如下預(yù)測:
e i t = ? ( x i t ? 1 , y i t ? 1 , W e d ) e_i^t=\phi(x_i^{t-1},y_i^{t-1},W_{ed}) eit?=?(xit?1?,yit?1?,Wed?)
P i = P M ( h d 1 t ? 1 , . . . , h d n t ) P_i=PM(h_{d_1}^{t-1},...,h_{d_n}^{t}) Pi?=PM(hd1?t?1?,...,hdn?t?) (公式4)
h d i t = L S T M ( γ ( P i , h d i t ? 1 ) , e i t ; W d e c o d e r ) h_{di}^t=LSTM(\gamma (P_i,h_{d_i}^{t-1}),e_i^t;W_{decoder}) hdit?=LSTM(γ(Pi?,hdi?t?1?),eit?;Wdecoder?)
( x ^ i t , y ^ i t ) = γ ( h d i t ) (\hat x_i^t,\hat y_i^t)=\gamma(h_{d_i}^t) (x^it?,y^?it?)=γ(hdi?t?)
其中 ? ( . ) \phi(.) ?(.)是擁有非線性單元ReLU和嵌入權(quán)重 W e d W_{ed} Wed?的嵌入函數(shù)。 W d e c o d e r W_{decoder} Wdecoder?表示的是LSTM的權(quán)重, γ \gamma γ表示的是多層向量感知機(MLP)
【鑒別器】
鑒別器由一個單獨的編碼器組成。具體地說,它取輸入 T r e a l = [ X i , Y i ] 或 T f a k e = [ X i , Y ^ i ] T_{real} = [X_i,Y_i]或T_{fake}= [X_i, \hat Y_i] Treal?=[Xi?,Yi?]或Tfake?=[Xi?,Y^i?]并且分類真/假的。我們在編碼器的最后隱藏狀態(tài)上應(yīng)用一個多層向量感知機(MLP)來獲得一個分類的分?jǐn)?shù)。理想情況下,“鑒別者”將學(xué)習(xí)微妙的社會互動規(guī)則,并將社會不能接受的軌跡歸類為“假軌跡”。
【損失】
除了對抗性損失(adversarial loss)外,我們還將L2損失應(yīng)用于預(yù)測軌跡,該軌跡測量生成的樣本與實際groundtruth真實值之間的差距。
3.4. 池化模塊
為了在多人之間共同推理,我們需要一種機制來在LSTMs之間共享信息,然而,有幾個挑戰(zhàn)的方法應(yīng)該解決:
-
一個場景中可能有很多人。我們需要一個緊湊表示,從所有人那里收集信息。
-
分散的人-人互動。本地信息并不總是足夠的,遠(yuǎn)處的行人可能會互相影響。因此,網(wǎng)絡(luò)需要對全局配置建模。
圖3 我們通過圖中紅色的人來比較我們的池化機制(紅色虛線箭頭)和社交池化(social-pooling)[1](紅色虛線格)之間的差距。 我們的方法計算紅色的人和所有其他人之間的相對位置; 這些位置與每個人的隱藏狀態(tài)連接,由MLP(多層感知機)獨立處理,然后匯集元素以計算紅人的池化向量 P 1 P_1 P1?。 社交池只考慮網(wǎng)格內(nèi)的人,并且不能模擬所有人之間的交互。
社交池化 [1]通過提出一個基于網(wǎng)格的池化方案來解決第一個問題。然而,這個人工制作的解決方案速度很慢,并且不能捕獲全局上下文。Qi等 [37] 表明,在輸入點集的變換元素上應(yīng)用一個學(xué)習(xí)的對稱函數(shù)可以實現(xiàn)上述性質(zhì)。如圖2所示,這可以通過通過一個多層向量感知機(MLP)和一個對稱函數(shù)(我們使用Max-Pooling)傳遞輸入坐標(biāo)來實現(xiàn)。 池化向量 P i P_i Pi? 需要總結(jié)一個人的所有做決定需要的信息。由于我們使用相對坐標(biāo)來表示平移不變性,所以我們用每個人相對于person i i i 的相對位置來增加池模塊的輸入。
3.5. 鼓勵產(chǎn)生多樣性樣本
軌跡預(yù)測是一個具有挑戰(zhàn)性的問題,因為考慮到過去有限的歷史,一個模型必須對多個可能的結(jié)果進行推理。到目前為止所描述的方法產(chǎn)生了良好的預(yù)測,但是這些預(yù)測試圖在可能有多個輸出的情況下產(chǎn)生“平均”預(yù)測。此外,我們發(fā)現(xiàn)輸出對噪音的變化不是很敏感,有無噪聲產(chǎn)生的預(yù)測非常相近。
我們提出了一個多樣性損失函數(shù)來鼓勵網(wǎng)絡(luò)產(chǎn)生不同的樣本。對于每個場景,我們通過從 N ( 0 , 1 ) N(0,1) N(0,1) 中隨機采樣 z z z 并根據(jù) L 2 L2 L2 意義上的“最佳”預(yù)測,生成 k k k 個可能的輸出預(yù)測。
L v a r i e t y = m i n k ∣ ∣ Y i ? Y i ^ ( k ) ∣ ∣ 2 \frak L_{variety}=min_{k}||Y_i- {\hat{Y_i}}^{(k)}||_2 Lvariety?=mink?∣∣Yi??Yi?^?(k)∣∣2? (公式5)
其中 k k k 是超參數(shù)
通過僅考慮最佳軌跡,這種損失促使網(wǎng)絡(luò)進行“減小錯誤的兩方面預(yù)測(hedge the bet)”并覆蓋符合過去軌跡的輸出空間。 這個損失在結(jié)構(gòu)上類似于“最小化N(MoN)損失[11]”,但據(jù)我們所知,這并未在GAN的背景下用于鼓勵生成樣本的多樣性。
表1 跨數(shù)據(jù)集的所有方法的定量結(jié)果。我們報告了兩個誤差指標(biāo)平均位移誤差(ADE)和最終位移誤差(FDE), t p r e d = 8 t_{pred}= 8 tpred?=8 和 t p r e d = 12 t_{pred}= 12 tpred?=12 (8 / 12) 單位是米。我們的方法始終優(yōu)于最先進的 S-LSTM方法,尤其適用于長期預(yù)測(圖中的值越低越好)。
3.6 實驗細(xì)節(jié)
我們在解碼器和編碼器模型中使用 L S T M LSTM LSTM。 編碼器隱藏狀態(tài)的大小為 16 16 16,解碼器為 32 32 32 。我們將輸入坐標(biāo)嵌入為 16 16 16 維向量。 我們使用 A d a m [ 22 ] Adam [22] Adam[22]優(yōu)化器訓(xùn)練每批次數(shù)量為 64 64 64的發(fā)生器和鑒別器,迭代 200 200 200 次,初始學(xué)習(xí)率為 0.001 0.001 0.001。
4.實驗
在本節(jié)中,我們在兩個公開可用的數(shù)據(jù)集上評估我們的方法:ETH [36]和UCY [25]。 這些數(shù)據(jù)集由具有豐富的人類交互場景的真實世界人類軌跡組成。 我們將所有數(shù)據(jù)轉(zhuǎn)換為真實世界坐標(biāo)并進行插值以達(dá)到每 0.4 0.4 0.4 秒獲取一個值。 總共有 5 5 5 組數(shù)據(jù)(ETH-2, UCY-3),有 4 4 4 個不同的場景,由擁擠的環(huán)境中的 1536 1536 1536 名行人組成具有挑戰(zhàn)性的場景,如:群體行為,人們相互交叉,避免碰撞以及群體聚集和散開。
【評估指標(biāo)】
類似于先前的工作[1,28]我們使用兩個誤差指標(biāo):
① 平均位移誤差(ADE):在所有預(yù)測時間步長上, g r o u n d ? t r u t h ground-truth ground?truth 標(biāo)簽與我們預(yù)測之間的平均 L 2 L2 L2 距離。
② 最終位移誤差(FDE):在預(yù)測周期 T p r e d T_{pred} Tpred? 結(jié)束時“預(yù)測的最終目的地”與“真實最終目的地”之間的距離。
【Baseline(基線)】
我們與以下基線進行比較:
① 線性: 線性回歸量,通過最小化最小平方誤差來估計線性參數(shù)。
② LSTM: 沒有池化機制的簡單LSTM。
③ S-LSTM: Alahi等人提出的方法[1]。 每個人都通過LSTM建模,隱藏狀態(tài)在每個時間步使用“社交池(social-pooling)”層進行合并。
我們也在不同的控制設(shè)置下對我們的模型進行“切除研究”(ablation research)。我們在章節(jié)中稱我們的完整方法為 SGAN-kVP-N,其中 kV 表示模型是否使用多樣化損失進行了訓(xùn)練(k = 1基本上表示沒有使用多樣化損失),P 表示使用我們提出的池化模塊。在測試時,我們從模型中多次采樣,選擇 L2 意義下的最佳預(yù)測進行定量評估。N 是我們在測試期間從模型中采樣的時間。
【評估方法】
我們遵循與[1]類似的評估方法。 我們使用“留一法”(leave-one-out),使用4組訓(xùn)練并測試剩下的一組。 我們觀察8個步驟(3.2秒)的軌跡并顯示8個(3.2秒)和12個(4.8秒)時間步長的預(yù)測結(jié)果。
4.1定量評估
我們將兩個指標(biāo)ADE和FDE的方法與表1中的不同基線進行比較。正如預(yù)期的那樣,線性模型只能對直線路徑進行建模,并且在預(yù)測時間較長時( t p r e d = 12 t_{pred} = 12 tpred?=12)尤其糟糕。 LSTM和S-LSTM都比線性基線表現(xiàn)更好,因為它們可以模擬更復(fù)雜的軌跡。 然而,在我們的實驗中,S-LSTM并不優(yōu)于LSTM。 我們盡力重現(xiàn)論文的結(jié)果。 [1]在合成數(shù)據(jù)集上訓(xùn)練模型,然后在真實數(shù)據(jù)集上進行微調(diào)。 我們不使用合成數(shù)據(jù)來訓(xùn)練任何可能導(dǎo)致性能下降的模型。
圖4 品種損失的影響。 對于SGAN-1V-N,我們訓(xùn)練單個模型,在訓(xùn)練期間為每個序列繪制一個樣本,在測試期間繪制 N個樣本。 對于SGAN-NV-N,我們在訓(xùn)練和測試過程中使用 N個樣本訓(xùn)練多個模型以減少變種。 多樣性損失的訓(xùn)練顯性提高了準(zhǔn)確性。
總結(jié)
以上是生活随笔為你收集整理的Social GAN: Socially Acceptable Trajectories with Generative Adversarial Networks 中文翻译的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 数据分析系列之挖掘建模
- 下一篇: 帮帮忙谢谢