行人重识别的代码复现
參考:https://github.com/layumi/Person_reID_baseline_pytorch/tree/master/tutorial
1、環境的安裝
系統的基礎環境:
- ubantu16.04
- CUDA9.0+Cudnn7.4.2
- Python3.7.4
- Anaconda 3
創建虛擬環境
conda create -n re_id python=3.7.4 source activate re_id安裝Pytorch
根據CUDA的版本來安裝:
 https://pytorch.org/get-started/previous-versions/
安裝 yacs
git clone https://github.com/rbgirshick/yacs cd yacs python setup.py install安裝其他依賴庫
pip install pretrainedmodels conda install matplotlib conda install future pip install torchvision pip install tensorboardX pip install tensorflow -i https://pypi.mirrors.ustc.edu.cn/simple conda install scipy conda install Cython2、開始
數據集和代碼的準備
數據集:Market-1501
 代碼:Practical-Baseline
2.1訓練
2.1.1:數據的準備(python prepare.py)
數據集分布如下:
├── Market/ │ ├── bounding_box_test/ /* Files for testing (candidate images pool) │ ├── bounding_box_train/ /* Files for training │ ├── gt_bbox/ /* Files for multiple query testing │ ├── gt_query/ /* We do not use it │ ├── query/ /* Files for testing (query images) │ ├── readme.txt 1. "bounding_box_test" – 19732張圖片,測試集,也是所謂的gallery參考圖像集;2. "bounding_box_train" – 12936張圖片,訓練集;3. "query" – 3368張query圖片,即要查詢的圖片,在 "bounding_box_test"中執行搜索;4. "gt_bbox" – 25259張圖片(人工標注),對應test和train數據集中1501個個體,用于區分"good"、“junk"和"distractors”;5. "gt_query" – 對于3368張query圖片的每一個,都有"good"和"junk"相關的圖像(包含相同個體),這個文件夾包含了"good"和"junk"圖像的索引,用在性能評估中。打開代碼prepare.py。 將第五行的地址改為你本地的地址,比如 \home\zzd\Download\Market,然后在終端中運行代碼。
 記得所有操作都在剛剛創建的虛擬環境下
運行后文件的改變如下:
├── Market/ │ ├── bounding_box_test/ /* Files for testing (candidate images pool) │ ├── bounding_box_train/ /* Files for training │ ├── gt_bbox/ /* Files for multiple query testing │ ├── gt_query/ /* We do not use it │ ├── query/ /* Files for testing (query images) │ ├── readme.txt │ ├── pytorch/ │ ├── train/ /* train │ ├── 0002 | ├── 0007 | ... │ ├── val/ /* val │ ├── train_all/ /* train+val │ ├── query/ /* query files │ ├── gallery/ /* gallery files2.1.2:搭建神經網絡模型(model.py)
我們可以使用預先訓練好的網絡結構,例如“ AlexNet”,“ VGG16”,“ ResNet”和“ DenseNet”。 通常,經過預訓練的網絡結構有助于保留更好的性能,因為它保留了ImageNet的優點[1].
在 pytorch中, 兩行代碼就可以導入模型:
from torchvision import models model = models.resnet50(pretrained=True)但是我們需要稍微調整一下網絡結構。 Market-1501中有751個類別(不同的人),與ImageNet中的1,000個類別所不同。 因此,在這里我們修正模型以使用分類器。
import torch import torch.nn as nn from torchvision import models# Define the ResNet50-based Model class ft_net(nn.Module):def __init__(self, class_num = 751):super(ft_net, self).__init__()#load the modelmodel_ft = models.resnet50(pretrained=True) # change avg pooling to global poolingmodel_ft.avgpool = nn.AdaptiveAvgPool2d((1,1))self.model = model_ftself.classifier = ClassBlock(2048, class_num) #define our classifier.def forward(self, x):x = self.model.conv1(x)x = self.model.bn1(x)x = self.model.relu(x)x = self.model.maxpool(x)x = self.model.layer1(x)x = self.model.layer2(x)x = self.model.layer3(x)x = self.model.layer4(x)x = self.model.avgpool(x)x = torch.squeeze(x)x = self.classifier(x) #use our classifier.return x為什么我們使用AdaptiveAvgPool2d? AvgPool2d和AdaptiveAvgPool2d有什么區別? 該模型現在有參數嗎? 如何在新的網絡層中初始化參數?
仔細看看model.py吧。
 這里我們不需要修改model.py,已經修改好了
2.1.3:開始訓練(train.py)
- 訓練方法 【ResNet-50】
- 訓練方法 【ResNet-50(alltricks)
2.1.4:開始測試(test.py)
- 測試Market-1501數據集
- 測試自己的數據集,并生成json格式的文件。
總結
以上是生活随笔為你收集整理的行人重识别的代码复现的全部內容,希望文章能夠幫你解決所遇到的問題。
 
                            
                        - 上一篇: vue3的语法
- 下一篇: beats耳机红白交替闪烁三次_beat
