深度学习项目实战:垃圾分类系统
簡介:
今天開啟深度學習另一板塊。就是計算機視覺方向,這里主要討論圖像分類任務--垃圾分類系統。其實這個項目早在19年的時候,我就寫好了一個版本了。之前使用的是python搭建深度學習網絡,然后前后端交互的采用的是java spring MVC來寫的。之前感覺還挺好的,但是使用起來還比較困難的。不光光需要有python的基礎,同時還需要有一定的java的基礎。尤其是搭建java的環境,還是很煩的。最近剛好有空,就給這個項目拿了過來優化了一下,本次優化主要涉及前后端界面交互的優化,另外一條就是在模型的識別性能上的優化,提高模型的識別速度。
展示:
下面是項目的初始化界面:
使用本系統的話也是比較簡單的,點擊選擇文件按鈕選擇需要識別的圖片數據。然后再點擊開始識別就可以識別了
識別結果如下:
實際的使用請看下面的視頻:
B站--深度學習項目實戰:垃圾分類系統
項目實現思路:
項目主要分為兩塊,第一塊是深度學習模塊,另一塊呢就是系統的使用界面了。
1、深度學習模塊
先說第一個模塊,也就是深度學習模塊,這塊的主體呢其實就是深度學習的網絡的搭建以及模型的訓練,還有就是模型的使用了。
深度學習網絡的我主要使用的是ResNet的網絡結構,使用這個網絡結構來實現四分類的垃圾分類的任務肯定是可以的。同時呢在訓練模型的時候,我這里又使用了一些調參的手法--遷移學習。為什么要使用遷移學習呢?由于ResNet在圖像任務上表現的是比較出色的,同時我們的任務也是圖像分類,所以呢是可以使用ResNet來進行遷移學習的。
下面是相關代碼:
`import torch
from torch import nn
from torch.nn import functional as F
class ResBlk(nn.Module):
def init(self, ch_in, ch_out, stride=1):
? ? ? ? super(ResBlk, self).init()
self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=stride, padding=1)
? ? ? ? self.bn1 = nn.BatchNorm2d(ch_out)
? ? ? ? self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)
? ? ? ? self.bn2 = nn.BatchNorm2d(ch_out)
self.extra = nn.Sequential()
? ? ? ? if ch_out != ch_in:
? ? ? ? ? ? self.extra = nn.Sequential(
? ? ? ? ? ? ? ? nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=stride),
? ? ? ? ? ? ? ? nn.BatchNorm2d(ch_out)
? ? ? ? ? ? )
def forward(self, x):
? ? ? ? out = F.relu(self.bn1(self.conv1(x)))
? ? ? ? out = self.bn2(self.conv2(out))
? ? ?
? ? ? ? out = self.extra(x) + out
? ? ? ? out = F.relu(out)
return out
class ResNet18(nn.Module):
def init(self, num_class):
? ? ? ? super(ResNet18, self).init()
self.conv1 = nn.Sequential(
? ? ? ? ? ? nn.Conv2d(3, 16, kernel_size=3, stride=3, padding=0),
? ? ? ? ? ? nn.BatchNorm2d(16)
? ? ? ? )
? ? ?
? ? ? ? self.blk1 = ResBlk(16, 32, stride=3)
? ? ? ? self.blk2 = ResBlk(32, 64, stride=3)
? ? ? ? self.blk3 = ResBlk(64, 128, stride=2)
? ? ? ? self.blk4 = ResBlk(128, 256, stride=2)
self.outlayer = nn.Linear(25633, num_class)
def forward(self, x):
? ? ? ? x = F.relu(self.conv1(x))
? ? ? ? x = self.blk1(x)
? ? ? ? x = self.blk2(x)
? ? ? ? x = self.blk3(x)
? ? ? ? x = self.blk4(x)
# print(x.shape)
? ? ? ? x = x.view(x.size(0), -1)
? ? ? ? x = self.outlayer(x)
return x
def main():
? ? blk = ResBlk(64, 128)
? ? tmp = torch.randn(2, 64, 224, 224)
? ? out = blk(tmp)
? ? print('block:', out.shape)
model = ResNet18(5)
? ? tmp = torch.randn(2, 3, 224, 224)
? ? out = model(tmp)
? ? print('resnet:', out.shape)
p = sum(map(lambda p:p.numel(), model.parameters()))
? ? print('parameters size:', p)
if name == 'main':
? ? main()`
下面是遷移學習的主要代碼:trained_model=resnet18(pretrained=True) ? ? model = nn.Sequential(*list(trained_model.children())[:-1], ? ? ? ? ? ? ? ? ? ? ? ? ? Flatten(), ? ? ? ? ? ? ? ? ? ? ? ? ? nn.Linear(512,4) ? ? ? ? ? ? ? ? ? ? ? ? ? ).to(device)
這部分代碼將預訓練模型的所有層(除了最后一層)復制到新模型中。Flatten()是將最后一層的輸出展平,以便可以輸入到全連接層(nn.Linear(512,4))。nn.Linear(512,4)是一個全連接層,有512個輸入節點和4個輸出節點,對應于任務中的類別數。
最后,.to(device)將模型移動到指定的設備上(例如GPU或CPU)。如果你沒有指定設備,那么默認會使用CPU。
之后呢設置batchsize、learning rate、優化器就可以進行模型的訓練了
參數設置如下:batchsz = 64 lr = 1e-4 epochs = 5
2、使用界面
接下來呢,就是關于使用界面的實現思路介紹了。使用界面就是為了方便對模型使用不是很了解的小伙伴使用的。如下所示,可以看到我們只需要點擊兩個按鈕就可以使用了。
這里的實現呢,主要采用的是Flask進行開發的,以前的版本是采用java的方式開的,使用起來不但笨重,同時模型識別的速度還比較的慢。最要命的是,搭建環境也是讓人頭疼的一件事。所以這次我給整個項目做了優化。主要就是提高模型的識別速度,同時讓使用者擁有良好的使用體驗。
系統主要架構如下圖所示:
其實比較簡單,其實也就4步:
第一步:就是給通過使用端選擇需要識別的圖片數據
第二步:給數據傳到指定目錄下,然后給模型識別使用
第三步:模型進行識別
第四步:給識別結果以網頁的方式進行展示,這里做的是四分類的任務,所以主要設計了四個網頁。還有一個就是出現意外狀況的test.html
我舉一個例子:比如我們輸入的圖片是廚房的垃圾圖片,那么模型識別以后給識別結果交給Flask代碼,Flask代碼會根據對應的識別結果給跳轉到kitch.html界面中,最后的結果如下所示,可以看到的有識別結果還有識別的圖片,以及對于相應的垃圾的分類的定義還有一些小貼士。
Flask的主要代碼如下:
`uploaded_file = request.files['file']
? ? file_name = uploaded_file.filename
? ? if not os.path.exists(UPLOAD_FOLDER):
? ? ? ? os.makedirs(UPLOAD_FOLDER)
# get file path
? ? file_path = os.path.join(UPLOAD_FOLDER, file_name)
# write image to UPLOAD_FOLDER
? ? with open(file_path, 'wb') as f:
? ? ? ? f.write(uploaded_file.read())`
下面的代碼主要就是獲取到form傳遞過來的圖片數據,然后整個代碼就會給數據上傳到指定的文件夾下面。
最后說明:
由于筆者能力有限,所以在描述的過程中難免會有不準確的地方,還請多多包含!
更多NLP和CV文章以及完整代碼請到"陶陶name"獲取。
項目實戰持續更新,大家加油!!!!
總結
以上是生活随笔為你收集整理的深度学习项目实战:垃圾分类系统的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: iTunes升级iOS出现未知错误300
- 下一篇: 橙光游戏《师姐该吃药了》师父攻略