3D姿态估计(GAST)
1.下載數據集,只用到human3.6m數據集中的2D數據集和3D數據集。將數據集劃分為訓練集和測試集。(直接將2D數據集輸入進行訓練,得出的prediction和3D的groundtruth做比較)
2.在數據集中提取數據,得相機參數,2D關鍵點和3D關鍵點。
cameras_valid, poses_valid, poses_valid_2d = fetch(subjects_test, action_filter, dataset, keypoints, args.downsample)3.定義模型,UnchunkedGenerator等,將模型數據放入GPU。
主要看一下相對于videopose本程序做出了哪些改變。
首先就是引入了圖卷積,所以對于輸入的數據要增設他們的圖矩陣,在網絡設置的時候加上圖卷積這一模塊,與時間卷積共同構成,這就使得輸入的參數多設置了adj。圖卷積一共有兩個模塊,一個是local-attention,一個是global-attention。
local-attention中先定義SemCHGraphConv網絡框架,具體網上可以查到公式。以這個框架為基礎輸入不同的矩陣adj_sym和adj_con得到local_attention的架構。adj_sym是對“左右對稱的關節點”進行關聯的矩陣,adj_con是鄰接矩陣,一階和二階等體現在矩陣里。(論文中說對于手腕腳腕頭等關節點只有一個連接點所以*****沒仔細看,不太懂
相對比于SemGCN其實GAST并沒有改變多少,將原本的locallayer和globallayer在網絡中的順序變換了,并且加上了注意力,將關聯的節點以矩陣表達。
4.進行訓練,while epoch < args.epochs,訓練后求loss進行模型優化
5.最后進行測試,用測試集調用evaluate(),測試范化能力。
總結
以上是生活随笔為你收集整理的3D姿态估计(GAST)的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: [环境搭建八] 深度学习环境搭建--常见
- 下一篇: Python之汉诺塔