我用 PyTorch 复现了 LeNet-5 神经网络(自定义数据集篇)!
大家好,我是紅色石頭!
在上三篇文章:
這可能是神經(jīng)網(wǎng)絡 LeNet-5 最詳細的解釋了!
我用 PyTorch 復現(xiàn)了 LeNet-5 神經(jīng)網(wǎng)絡(MNIST 手寫數(shù)據(jù)集篇)!
我用 PyTorch 復現(xiàn)了 LeNet-5 神經(jīng)網(wǎng)絡(CIFAR10 數(shù)據(jù)集篇)!
詳細介紹了卷積神經(jīng)網(wǎng)絡 LeNet-5 的理論部分和使用 PyTorch 復現(xiàn) LeNet-5 網(wǎng)絡來解決 MNIST 數(shù)據(jù)集和 CIFAR10 數(shù)據(jù)集。然而大多數(shù)實際應用中,我們需要自己構(gòu)建數(shù)據(jù)集,進行識別。因此,本文將講解一下如何使用 LeNet-5 訓練自己的數(shù)據(jù)。
正文開始!
三、用 LeNet-5 訓練自己的數(shù)據(jù)
下面使用 LeNet-5 網(wǎng)絡來訓練本地的數(shù)據(jù)并進行測試。數(shù)據(jù)集是本地的 LED 數(shù)字 0-9,尺寸為 28x28 單通道,跟 MNIST 數(shù)據(jù)集類似。訓練集 0-9 各 95 張,測試集 0~9 各 40 張。圖片樣例如圖所示:
3.1 數(shù)據(jù)預處理
制作圖片數(shù)據(jù)的索引
對于訓練集和測試集,要分別制作對應的圖片數(shù)據(jù)索引,即 train.txt 和 test.txt兩個文件,每個 txt 中包含每個圖片的目錄和對應類別 class。示意圖如下:
制作圖片數(shù)據(jù)索引的 python 腳本程序如下:
import ostrain_txt_path = os.path.join("data", "LEDNUM", "train.txt") train_dir = os.path.join("data", "LEDNUM", "train_data") valid_txt_path = os.path.join("data", "LEDNUM", "test.txt") valid_dir = os.path.join("data", "LEDNUM", "test_data")def gen_txt(txt_path, img_dir):f = open(txt_path, 'w')for root, s_dirs, _ in os.walk(img_dir, topdown=True): # 獲取 train文件下各文件夾名稱for sub_dir in s_dirs:i_dir = os.path.join(root, sub_dir) # 獲取各類的文件夾 絕對路徑img_list = os.listdir(i_dir) # 獲取類別文件夾下所有png圖片的路徑for i in range(len(img_list)):if not img_list[i].endswith('jpg'): # 若不是png文件,跳過continuelabel = img_list[i].split('_')[0]img_path = os.path.join(i_dir, img_list[i])line = img_path + ' ' + label + '\n'f.write(line)f.close()if __name__ == '__main__':gen_txt(train_txt_path, train_dir)gen_txt(valid_txt_path, valid_dir)運行腳本之后就在 ./data/LEDNUM/ 目錄下生成 train.txt 和 test.txt 兩個索引文件。
構(gòu)建Dataset子類
pytorch 加載自己的數(shù)據(jù)集,需要寫一個繼承自 torch.utils.data 中 Dataset 類,并修改其中的 __init__ 方法、__getitem__ 方法、__len__ 方法。默認加載的都是圖片,__init__ 的目的是得到一個包含數(shù)據(jù)和標簽的 list,每個元素能找到圖片位置和其對應標簽。然后用 __getitem__ 方法得到每個元素的圖像像素矩陣和標簽,返回 img 和 label。
from PIL import Image from torch.utils.data import Datasetclass MyDataset(Dataset):def __init__(self, txt_path, transform = None, target_transform = None):fh = open(txt_path, 'r')imgs = []for line in fh:line = line.rstrip()words = line.split()imgs.append((words[0], int(words[1])))self.imgs = imgs self.transform = transformself.target_transform = target_transformdef __getitem__(self, index):fn, label = self.imgs[index]#img = Image.open(fn).convert('RGB') img = Image.open(fn)if self.transform is not None:img = self.transform(img) return img, labeldef __len__(self):return len(self.imgs)getitem 是核心函數(shù)。self.imgs 是一個 list,self.imgs[index] 是一個 str,包含圖片路徑,圖片標簽,這些信息是從上面生成的txt文件中讀取;利用 Image.open 對圖片進行讀取,注意這里的 img 是單通道還是三通道的;self.transform(img) 對圖片進行處理,這個 transform 里邊可以實現(xiàn)減均值、除標準差、隨機裁剪、旋轉(zhuǎn)、翻轉(zhuǎn)、放射變換等操作。
當 Mydataset構(gòu) 建好,剩下的操作就交給 DataLoder,在 DataLoder 中,會觸發(fā) Mydataset 中的 getiterm 函數(shù)讀取一張圖片的數(shù)據(jù)和標簽,并拼接成一個 batch 返回,作為模型真正的輸入。
pipline_train = transforms.Compose([#隨機旋轉(zhuǎn)圖片transforms.RandomHorizontalFlip(),#將圖片尺寸resize到32x32transforms.Resize((32,32)),#將圖片轉(zhuǎn)化為Tensor格式transforms.ToTensor(),#正則化(當模型出現(xiàn)過擬合的情況時,用來降低模型的復雜度)transforms.Normalize((0.1307,),(0.3081,)) ]) pipline_test = transforms.Compose([#將圖片尺寸resize到32x32transforms.Resize((32,32)),transforms.ToTensor(),transforms.Normalize((0.1307,),(0.3081,)) ]) train_data = MyDataset('./data/LEDNUM/train.txt', transform=pipline_train) test_data = MyDataset('./data/LEDNUM/test.txt', transform=pipline_test)#train_data 和test_data包含多有的訓練與測試數(shù)據(jù),調(diào)用DataLoader批量加載 trainloader = torch.utils.data.DataLoader(dataset=train_data, batch_size=8, shuffle=True) testloader = torch.utils.data.DataLoader(dataset=test_data, batch_size=4, shuffle=False)3.2 搭建 LeNet-5 神經(jīng)網(wǎng)絡結(jié)構(gòu)
class LeNet(nn.Module):def __init__(self):super(LeNet, self).__init__()self.conv1 = nn.Conv2d(1, 6, 5) self.relu = nn.ReLU()self.maxpool1 = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(6, 16, 5)self.maxpool2 = nn.MaxPool2d(2, 2)self.fc1 = nn.Linear(16*5*5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = self.conv1(x)x = self.relu(x)x = self.maxpool1(x)x = self.conv2(x)x = self.maxpool2(x)x = x.view(-1, 16*5*5)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)output = F.log_softmax(x, dim=1)return output3.3 將定義好的網(wǎng)絡結(jié)構(gòu)搭載到 GPU/CPU,并定義優(yōu)化器
#創(chuàng)建模型,部署gpu device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = LeNet().to(device) #定義優(yōu)化器 optimizer = optim.Adam(model.parameters(), lr=0.001)3.4 定義訓練函數(shù)
def train_runner(model, device, trainloader, optimizer, epoch):#訓練模型, 啟用 BatchNormalization 和 Dropout, 將BatchNormalization和Dropout置為Truemodel.train()total = 0correct =0.0#enumerate迭代已加載的數(shù)據(jù)集,同時獲取數(shù)據(jù)和數(shù)據(jù)下標for i, data in enumerate(trainloader, 0):inputs, labels = data#把模型部署到device上inputs, labels = inputs.to(device), labels.to(device)#初始化梯度optimizer.zero_grad()#保存訓練結(jié)果outputs = model(inputs)#計算損失和#多分類情況通常使用cross_entropy(交叉熵損失函數(shù)), 而對于二分類問題, 通常使用sigmodloss = F.cross_entropy(outputs, labels)#獲取最大概率的預測結(jié)果#dim=1表示返回每一行的最大值對應的列下標predict = outputs.argmax(dim=1)total += labels.size(0)correct += (predict == labels).sum().item()#反向傳播loss.backward()#更新參數(shù)optimizer.step()if i % 100 == 0:#loss.item()表示當前l(fā)oss的數(shù)值print("Train Epoch{} \t Loss: {:.6f}, accuracy: {:.6f}%".format(epoch, loss.item(), 100*(correct/total)))Loss.append(loss.item())Accuracy.append(correct/total)return loss.item(), correct/total3.5 定義測試函數(shù)
def test_runner(model, device, testloader):#模型驗證, 必須要寫, 否則只要有輸入數(shù)據(jù), 即使不訓練, 它也會改變權(quán)值#因為調(diào)用eval()將不啟用 BatchNormalization 和 Dropout, BatchNormalization和Dropout置為Falsemodel.eval()#統(tǒng)計模型正確率, 設置初始值correct = 0.0test_loss = 0.0total = 0#torch.no_grad將不會計算梯度, 也不會進行反向傳播with torch.no_grad():for data, label in testloader:data, label = data.to(device), label.to(device)output = model(data)test_loss += F.cross_entropy(output, label).item()predict = output.argmax(dim=1)#計算正確數(shù)量total += label.size(0)correct += (predict == label).sum().item()#計算損失值print("test_avarage_loss: {:.6f}, accuracy: {:.6f}%".format(test_loss/total, 100*(correct/total)))3.6 運行
#調(diào)用 epoch = 5 Loss = [] Accuracy = [] for epoch in range(1, epoch+1):print("start_time",time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time())))loss, acc = train_runner(model, device, trainloader, optimizer, epoch)Loss.append(loss)Accuracy.append(acc)test_runner(model, device, testloader)print("end_time: ",time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time())),'\n')print('Finished Training') plt.subplot(2,1,1) plt.plot(Loss) plt.title('Loss') plt.show() plt.subplot(2,1,2) plt.plot(Accuracy) plt.title('Accuracy') plt.show()經(jīng)歷 5 次 epoch 的 loss 和 accuracy 曲線如下:
3.7 模型保存
torch.save(model, './models/model-mine.pth') #保存模型3.8 模型測試
下面使用上面訓練的模型對一張 LED 圖片進行測試。
from PIL import Image import numpy as npif __name__ == '__main__':device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')model = torch.load('./models/model-mine.pth') #加載模型model = model.to(device)model.eval() #把模型轉(zhuǎn)為test模式#讀取要預測的圖片# 讀取要預測的圖片img = Image.open("./images/test_led.jpg") # 讀取圖像#img.show()plt.imshow(img,cmap="gray") # 顯示圖片plt.axis('off') # 不顯示坐標軸plt.show()# 導入圖片,圖片擴展后為[1,1,32,32]trans = transforms.Compose([#將圖片尺寸resize到32x32transforms.Resize((32,32)),transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])img = trans(img)img = img.to(device)img = img.unsqueeze(0) #圖片擴展多一維,因為輸入到保存的模型中是4維的[batch_size,通道,長,寬],而普通圖片只有三維,[通道,長,寬]# 預測 output = model(img)prob = F.softmax(output,dim=1) #prob是10個分類的概率print("概率:",prob)value, predicted = torch.max(output.data, 1)predict = output.argmax(dim=1)print("預測類別:",predict.item())概率:tensor([[7.2506e-11, 7.0065e-18, 7.1749e-06, 7.4855e-13, 7.3532e-08, 8.5405e-17,2.5753e-15, 9.7887e-10, 2.7855e-05, 9.9996e-01]],grad_fn=<SoftmaxBackward>) 預測類別:9模型預測結(jié)果正確!
以上就是 PyTorch 構(gòu)建 LeNet-5 卷積神經(jīng)網(wǎng)絡并用它來識別自定義數(shù)據(jù)集的例子。全文的代碼都是可以順利運行的,建議大家自己跑一邊。
總結(jié):
是我們目前分別復現(xiàn)了 LeNet-5 來識別 MNIST、CIFAR10?和自定義數(shù)據(jù)集,基本上涵蓋了基于 PyToch 的 LeNet-5 實戰(zhàn)的所有內(nèi)容。希望對大家有所幫助!
所有完整的代碼我都放在 GitHub 上,GitHub地址為:
https://github.com/RedstoneWill/ObjectDetectionLearner/tree/main/LeNet-5
也可以點擊閱讀原文進入~
推薦閱讀
(點擊標題可跳轉(zhuǎn)閱讀)
干貨 | 公眾號歷史文章精選
我的深度學習入門路線
我的機器學習入門路線圖
重磅!
AI有道年度技術(shù)文章電子版PDF來啦!
掃描下方二維碼,添加?AI有道小助手微信,可申請入群,并獲得2020完整技術(shù)文章合集PDF(一定要備注:入群?+ 地點 + 學校/公司。例如:入群+上海+復旦。?
長按掃碼,申請入群
(添加人數(shù)較多,請耐心等待)
感謝你的分享,點贊,在看三連??
總結(jié)
以上是生活随笔為你收集整理的我用 PyTorch 复现了 LeNet-5 神经网络(自定义数据集篇)!的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 中国程序员的前景并非一片黑暗,教你如何拥
- 下一篇: 如果你还在徘徊在程序员的门口,那就赶紧来