点云网络的论文理解(二)- PointNet的pytorch复现
1.了解PointNet
為了更好的復現這個東西我們需要先了解這個東西,先把原文給出的圖片放在這里,之后我們再一點點理解。
1.1點云的特點
1.1.1無序性:也就是說這個點的先后順序和實際上是什么無關
你不管這些點加入集合的順序如何,最后的最后他們組成的圖形還是那么個圖形,也就是說這些東西的順序是完全沒有必要的。
所以我們必須使用對稱的函數:
也就是說,這個函數必須要滿足,你怎么調換函數變量的輸入順序,函數的計算結果也都不發生變化也就是下圖:
所以,我們看一下我們有哪些函數滿足這個特點:顯然max是滿足這個條件的,所以我們可以使用下面的方式來提取點云的特征,但是這樣做的話,損失也太大了,所以我們不能如此使用。
所以我們可以使用一個全連接層來擴大維度,這樣結果的特征損失就不那么大了。
這樣其實我們就完成了一個簡單的PointNet。這個東西我們一般將其叫做PointNet(vanilla)
好了說到這里,我們再來仔細看一下原文的圖片:
這里其實就是多個全連接層并排放在一起,這樣就能達到擴大特征的目的。
1.1.2理論證明
論文中其實有給出理論的證明,大致的意思是:任意一個在Hausdorff空間上連續的函數,都可以被這樣的PointNet(vanilla)無限的逼近。但是目前,還沒看懂,大家可以自己看一下。
1.2 旋轉無關性 也就是說這個一個兔子,你轉來轉去,他也還是一個兔子
點云的旋轉不變性指的是,給予一個點云一個旋轉,所有的x , y , z 坐標都變了,但是代表的還是同一個物體
因此對于普通的PointNet(vanilla),如果先后輸入同一個但是經過不同旋轉角度的物體,它可能不能很好地將其識別出來。在論文中的方法是新引入了一個T-Net網絡去學習點云的旋轉,將物體校準,剩下來的PointNet(vanilla)只需要對校準后的物體進行分類或者分割即可。
我理解這里是作了這里特殊需要的數據增強,這里需要傳入一定量增強之后的數據,因此文章提出來一種新的網絡T-Net,因此對于普通的PointNet(vanilla),如果先后輸入同一個但是經過不同旋轉角度的物體,它可能不能很好地將其識別出來。在論文中的方法是新引入了一個T-Net網絡去學習點云的旋轉,將物體校準,剩下來的PointNet(vanilla)只需要對校準后的物體進行分類或者分割即可。
所謂的T-Net也就是下面這個原來的圖片中所展示的部分。
由圖可以看出,由于點云的旋轉非常的簡單,只需要對一個N × D 的點云矩陣乘以一個D × D的旋轉矩陣即可,因此對輸入點云學習一個3 × 3 的矩陣,即可將其矯正;我們可以看到這樣的話,其實是對原來的物體進行一個仿射變換,也就是原來的情況將被仿射變換一次。
同樣的將點云映射到K維的冗余空間后,再對K維的點云特征做一次校對,只不過這次校對需要引入一個正則化懲罰項,希望其盡可能接近于一個正交矩陣。
因為我們維度變多之后,可能會出現某些權重很大某些權重很小的情況,我們可以將這樣的情況理解成我們把某個因素過度看重了,而沒有足夠重視一些其他因素(其他正常大小的參數在和一個巨大的參數比起來相對就小了),這可能會引起過擬合的問題,所以我們有時候在損失函數中加入一個參數的平方作為一個需要優化的因素來保證參數整體的大小都不太大。
正則化可以參見:
好了,也就是這個小模塊就是我們旋轉需要的模塊。
1.2整個網絡的理解
再看一次整體圖
之后我們開始逐個塊理解:
第一個部分:
下面這個部分是為了旋轉原來的圖片,將這個圖片轉的正過來,其實這里不一定真的可以旋轉過來,這里只是讓其向著正確的方向靠攏。
第二個部分:
這里我們是使用線性全連接層來擴大我們的特征,主要是防止之后maxpooling的時候損失的太多了
第三個部分
這里主要是對新得到的高緯度的信息再進行一次矯正。再嘗試將其轉正。
第四個部分
這個部分,就是對每個點坐標,再進行一次擴展,讓每個點的維度更高。還是使用線性全連接層。
第五個部分
這里就是直接得到了一個全局的特征集合,之后再接上一個網絡,讓輸出順利完成就是了。(這里后面的網絡怎么設計就看你具體是目標識別、目標檢測、語義分割等具體的哪個了。)
但是上面這個并沒有畫明白怎么得到的全局特征。還得看下面那個圖:
這里先得看這個n×1088,其實這個一部分是來自于第三部分的輸出,一部分來自于第四部分的輸出,其實是構成一個小的跳連接,之后再使用線性全連接,逐漸得到你需要的點的特征。
這個東西全部讀完之后其實挺神奇的,這個從始至終都是對單個點進行操作的。
2.實現PointNet
2.0我們先引入需要的包
import torch import torch.nn as nn import numpy as np import torch.nn.functional as F from torch.autograd import Variable2.1實現一個T-Net:
''' 這里是實現一個T-Net。 這個輸入應當是batchsize*3*n_pts(batchsize是點數、n_pts這里的情況是我們xyz這個東西到底需要多少num_feature來進行表示) 輸出是一個batchsize*3*3 ''' class T_Net(nn.Module):def __init__(self):super(T_Net, self).__init__()# 這里需要注意的是上文提到的MLP均由卷積結構完成# 比如說將3維映射到64維,其利用64個1x3的卷積核self.conv1 = torch.nn.Conv1d(3, 64, 1)self.conv2 = torch.nn.Conv1d(64, 128, 1)self.conv3 = torch.nn.Conv1d(128, 1024, 1)self.fc1 = nn.Linear(1024, 512)self.fc2 = nn.Linear(512, 256)self.fc3 = nn.Linear(256, 9)#因為relu沒有參數,所以我們定義一個就行self.relu = nn.ReLU()self.bn1 = nn.BatchNorm1d(64)self.bn2 = nn.BatchNorm1d(128)self.bn3 = nn.BatchNorm1d(1024)self.bn4 = nn.BatchNorm1d(512)self.bn5 = nn.BatchNorm1d(256)def forward(self, x):#老操作了,我們使用size取出來一個batch_sizebatchsize = x.size()[0]#下面的卷積其實是一個仿射變換x = F.relu(self.bn1(self.conv1(x)))x = F.relu(self.bn2(self.conv2(x)))x = F.relu(self.bn3(self.conv3(x)))#這里的2指定的是最后一個維度x = torch.max(x, 2, keepdim=True)[0]x = x.view(-1, 1024)x = F.relu(self.bn4(self.fc1(x)))x = F.relu(self.bn5(self.fc2(x)))x = self.fc3(x)#這里是對對角線的數據進行加強iden = Variable(torch.from_numpy(np.array([1,0,0,0,1,0,0,0,1]).astype(np.float32)))iden = iden.repeat(batchsize).view(-1,9)iden = iden.cuda()x = x + idenx = x.view(-1, 3, 3) # 輸出為Batch*3*3的張量return x實現主體部分
''' 這里是實現一個PointNet的核心部分。 這個輸入應當是batchsize*3*1(batchsize是點數),因為一開始過的是一個T_Net - ''' class PointNetEncoder(nn.Module):def __init__(self, global_feat = True):super(PointNetEncoder, self).__init__()self.tnet = T_Net()#這里其實是用一個卷積實現了一個全連接的情況self.conv1 = torch.nn.Conv1d(3, 64, 1)self.conv2 = torch.nn.Conv1d(64, 128, 1)self.conv3 = torch.nn.Conv1d(128, 1024, 1)self.bn1 = nn.BatchNorm1d(64)self.bn2 = nn.BatchNorm1d(128)self.bn3 = nn.BatchNorm1d(1024)self.global_feat = global_featdef forward(self, x):'''生成全局特征'''n_pts = x.size()[2]#這里是我們旋轉之后得到的結果trans = self.tnet(x)#這個是我們進行矩陣乘法之前常用的操作x = x.transpose(2,1)x = torch.bmm(x, trans) # batch matrix multiply 即乘以T-Net的結果#當然乘過之后還得換回來x = x.transpose(2,1)x = self.conv1(x)x = F.relu(self.bn1(x))#這里的pointfeat主要目的就是給一會的跳連接使用的pointfeat = xx_skip = self.conv2(x)x = F.relu(self.bn2(x_skip))x = self.bn3(self.conv3(x))x = torch.max(x, 2, keepdim=True)[0]x = x.view(-1, 1024)if self.global_feat:return x, transelse:x = x.view(-1, 1024, 1).repeat(1, 1, n_pts)return torch.cat([x, pointfeat], 1), trans如果需求為傳統分類任務
class PointNetCls(nn.Module):def __init__(self, k = 2):super(PointNetCls, self).__init__()self.k = kself.feat = PointNetEncoder(global_feat=False)self.conv1 = torch.nn.Conv1d(1088, 512, 1)self.conv2 = torch.nn.Conv1d(512, 256, 1)self.conv3 = torch.nn.Conv1d(256, 128, 1)self.conv4 = torch.nn.Conv1d(128, self.k, 1)self.bn1 = nn.BatchNorm1d(512)self.bn2 = nn.BatchNorm1d(256)self.bn3 = nn.BatchNorm1d(128)def forward(self, x):'''分類網絡'''batchsize = x.size()[0]n_pts = x.size()[2]x, trans = self.feat(x)x = F.relu(self.bn1(self.conv1(x)))x = F.relu(self.bn2(self.conv2(x)))x = F.relu(self.bn3(self.conv3(x)))x = self.conv4(x)x = x.transpose(2,1).contiguous()x = F.log_softmax(x.view(-1,self.k), dim=-1)x = x.view(batchsize, n_pts, self.k)return x如果需求為語義分割任務
class PointNetPartSeg(nn.Module):def __init__(self,num_class):super(PointNetPartSeg, self).__init__()self.k = num_classself.feat = PointNetEncoder(global_feat=False)self.conv1 = torch.nn.Conv1d(1088, 512, 1)self.conv2 = torch.nn.Conv1d(512, 256, 1)self.conv3 = torch.nn.Conv1d(256, 128, 1)self.conv4 = torch.nn.Conv1d(128, self.k, 1)self.bn1 = nn.BatchNorm1d(512)self.bn1_1 = nn.BatchNorm1d(1024)self.bn2 = nn.BatchNorm1d(256)self.bn3 = nn.BatchNorm1d(128)def forward(self, x):'''分割網絡'''batchsize = x.size()[0]n_pts = x.size()[2]x, trans = self.feat(x)x = F.relu(self.bn1(self.conv1(x)))x = F.relu(self.bn2(self.conv2(x)))x = F.relu(self.bn3(self.conv3(x)))x = self.conv4(x)x = x.transpose(2,1).contiguous()x = F.log_softmax(x.view(-1,self.k), dim=-1)x = x.view(batchsize, n_pts, self.k)return x, trans 《新程序員》:云原生和全面數字化實踐50位技術專家共同創作,文字、視頻、音頻交互閱讀總結
以上是生活随笔為你收集整理的点云网络的论文理解(二)- PointNet的pytorch复现的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 什么是微调?什么是模型迁移?
- 下一篇: 点云网络的论文理解(一)-点云网络的提出