【Pytorch神经网络实战案例】18 最大化深度互信信息模型DIM实现搜索最相关与最不相关的图片
圖片搜索器分為圖片的特征提取和匹配兩部分,其中圖片的特征提取是關(guān)鍵。將使用一種基于無(wú)監(jiān)督模型的提取特征的方法實(shí)現(xiàn)特征提取,即最大化深度互信息(DeepInfoMax,DIM)方法。
1 最大深度互信信息模型DIM簡(jiǎn)介
在DIM模型中,結(jié)合了自編碼和對(duì)抗神經(jīng)網(wǎng)絡(luò),損失函數(shù)使用了MINE與f-GAN方法的結(jié)合。在此之上,DM模型又從全局損失、局部損失和先驗(yàn)損失3個(gè)損失出發(fā)進(jìn)行訓(xùn)練。
1.1 DIM模型原理
性能好的編碼器應(yīng)該能夠提取出樣本中最獨(dú)特、具體的信息,而不是單純地追求過(guò)小的重構(gòu)誤差。而樣本的獨(dú)特信息可以使用互信息(MutualInformation,MI)來(lái)衡量。
因此,在DIM模型中,編碼器的目標(biāo)函數(shù)不是最小化輸入與輸出的MSE,而是最大化輸入與輸出的互信息。
1.2 DIM模型的主要思想
DIM模型中的互信息解決方案主要來(lái)自MINE方法,即計(jì)算輸入樣本與編碼器輸出的轉(zhuǎn)征向量之間的互信息,通過(guò)最大化互信息來(lái)實(shí)現(xiàn)模型的訓(xùn)練。
1.2.1 DIM模型在無(wú)監(jiān)督訓(xùn)練中的兩種約束。
在實(shí)現(xiàn)時(shí),DlM模型使用了3個(gè)判別器,分別從局部互信息最大化、全局互信息最大化和先驗(yàn)分布匹配最小化的3個(gè)角度對(duì)編碼器的輸出結(jié)果進(jìn)行約束。(論文arXv:1808.06670,2018)
1.3 局部與全局互信息最大化約束的原理
許多表示學(xué)習(xí)只使用已探索過(guò)的數(shù)據(jù)空間(稱為像素級(jí)別),當(dāng)一小部分?jǐn)?shù)據(jù)十分關(guān)心語(yǔ)義級(jí)別時(shí),表明該表示學(xué)習(xí)將不利于訓(xùn)練。
? ? 對(duì)于圖片,它的相關(guān)性更多體現(xiàn)在局部。圖片的識(shí)別、分類等應(yīng)該是一個(gè)從局部到整體的過(guò)程,即全局特征更適合用于重構(gòu),局部特征更適合用于下游的分類任務(wù)。
局部特征可以理解為卷積后得到的特征圖,全局特征可以理解為對(duì)特征圖進(jìn)行編碼得到的特征向量。
DIM模型從局部和全局兩個(gè)角度出發(fā)對(duì)輸入和輸出執(zhí)行互信息計(jì)算。
1.4 先驗(yàn)分布匹配最小化約束原理
先驗(yàn)匹配的目的是對(duì)編碼器生成向量形式進(jìn)行約束,使其更接近高斯分布。
DIM模型的編碼器的主要思想是:在對(duì)輸入數(shù)據(jù)編碼成特征向量的同時(shí),還希望這個(gè)特征向量服從于標(biāo)準(zhǔn)的高斯分布。這種做法使編碼空間更加規(guī)整,甚至有利于解耦特征,便于后續(xù)學(xué)習(xí),與變分自編碼中編碼器的使命是一樣的。
因此,在DIM模型中引入變分自編碼神經(jīng)網(wǎng)絡(luò)的原理,將高斯分布當(dāng)作先驗(yàn)分布,對(duì)編碼器輸出的向量進(jìn)行約束。
2 DIM模型的結(jié)構(gòu)
2.1 DIM模型結(jié)構(gòu)圖
DIM模型的結(jié)構(gòu)DIM模型由4個(gè)子模型構(gòu)成:1個(gè)編碼器、3個(gè)判別器。其中解碼器的作用主要是對(duì)圖進(jìn)行特征提取,3個(gè)判器需分別從局部、全局、先驗(yàn)匹配3個(gè)角度對(duì)編碼器的輸出結(jié)果進(jìn)行約束。
2.2 DlM模型的特殊之處
? ? 在DlM模型的實(shí)際實(shí)現(xiàn)過(guò)程中,沒有直接對(duì)原始的輸入數(shù)據(jù)與編碼器輸出的特征數(shù)據(jù)執(zhí)行最大化互信息計(jì)算,而使用了編碼器中間過(guò)程中的特征圖與最終的特征數(shù)據(jù)執(zhí)行互信息計(jì)算。
? ? 根據(jù)MINE方法,利用神經(jīng)網(wǎng)絡(luò)計(jì)算互信息的方法可以換算成計(jì)算兩個(gè)數(shù)據(jù)集的聯(lián)合分布和邊緣分布間的散度,即將判別器處理特征圖和特征數(shù)據(jù)的結(jié)果當(dāng)作聯(lián)合分布,將亂序后的特征圖和特征數(shù)據(jù)輸入判別器得到邊緣分布。
DIM模型打亂特征圖的批次順序后與編碼器輸出的提示特征向量一起作為判別器的輸入,即令輸入判別器的特征圖與特征向量各自獨(dú)立(破壞特征圖與特征向量間的對(duì)應(yīng)關(guān)系),詳見互信息神經(jīng)估計(jì)的原理介紹。
2.3 全局判別器模型
如圖8-29,全局判別器的輸入值有兩個(gè):特征圖和特征數(shù)據(jù)y。在計(jì)算互信息的過(guò)程中,聯(lián)合分布的特征圖和特征數(shù)據(jù)y都來(lái)自編碼神經(jīng)網(wǎng)絡(luò)的輸出。計(jì)算邊緣分布的特征圖是由改變特征圖的批次順序得來(lái)的,特征數(shù)據(jù)y來(lái)自編碼神經(jīng)網(wǎng)絡(luò)的輸出,如圖8-30所示。
在全局判別器中,具體的處理步驟如下。
(1)使用卷積層對(duì)特征圖進(jìn)行處理,得到全局特征。
(2)將該全局特征與特征數(shù)據(jù)y用torch.cat()函數(shù)連接起來(lái)。
(3)將連接后的結(jié)果輸入全連接網(wǎng)絡(luò)(對(duì)兩個(gè)全局特征進(jìn)行判定),最終輸出判別結(jié)果(一維向量)。
2.4 局部判別器模型
如圖8-29所示,局部判別器的輸入值是一個(gè)特殊的合成向量:將編碼器輸出的特征數(shù)據(jù)y按照特征圖的尺寸復(fù)制成m×m份。令特征圖中的每個(gè)像素都與編碼器輸出的全局特征數(shù)據(jù)ν相連。這樣,判別器所做的事情就變成對(duì)每個(gè)像素與全局特征向量之間的互信息進(jìn)行計(jì)算。因此,該判別器稱為局部判別器。
在局部判別器中,計(jì)算互信息的聯(lián)合分布和邊緣分布方式與全局判別器一致,如圖8-31所示,在局部判別器中主要使用了1×1的卷積操作(步長(zhǎng)也為1)。因?yàn)檫@種卷積操作不會(huì)改變特征圖的尺寸(只是通道數(shù)的變換),所以判別器的最終輸出也是大小為m×m的值。
局部判別器通過(guò)執(zhí)行多層的1×1卷積操作,將通道數(shù)最終變成1,并作為最終的判別結(jié)果。該過(guò)程可以理解為,同時(shí)對(duì)每個(gè)像素與全局特征計(jì)算互信息。
2.5 先驗(yàn)判別器模型
先驗(yàn)判別器模型主要是輔助編碼器生成的向量趨近于高斯分布,其做法與普通的對(duì)抗神經(jīng)網(wǎng)絡(luò)一致。先驗(yàn)判別器模型輸出的結(jié)果只有0或1:令判別器對(duì)高斯分布采樣的數(shù)據(jù)判定為真(1),對(duì)編碼器輸出的特征向量判定為假(0),如圖8-32所示。
先驗(yàn)判別器模型如圖8-32所示,先驗(yàn)判別器模型的輸入只有一個(gè)特征向量。其結(jié)構(gòu)主要使用了全連接神經(jīng)網(wǎng)絡(luò),最終會(huì)輸出“真”或“假”的判定結(jié)果。
2.6 損失函數(shù)
? ? 在DIM模型中,將MINE方法中的KL散度換成JS散度來(lái)作為互信息的度量。這樣做的原因是:JS散度是有上界的,而KL散度是沒有上界的。相比之下,JS散度更適合在最大化任務(wù)中使用,因?yàn)樗谟?jì)算時(shí)不會(huì)產(chǎn)生特別大的數(shù),并且JS散度的梯度又是無(wú)偏的。
在f-GAN中可以找到JS散度的計(jì)算公式,見式(8-46)(其原理在式(8-46)下面的提示部分進(jìn)行了闡述)。
?先驗(yàn)判別器的損失函數(shù)非常簡(jiǎn)單、與原始的GAN模型(參見的論文編號(hào)為anXiv:1406.2661,2014)中的損失函數(shù)一致,對(duì)這3個(gè)判別器各自損失函數(shù)的計(jì)算結(jié)果加權(quán)求和,便得到整個(gè)DM模型的損失函數(shù)。
3 實(shí)戰(zhàn)案例簡(jiǎn)介與代碼實(shí)現(xiàn)(訓(xùn)練模型代碼實(shí)現(xiàn))
使用最大化深度互信息模型提取圖片信息,并用提取出來(lái)的低維特征制作圖片搜索器。
3.1 CIFAR數(shù)據(jù)集
? ? 本例使用的數(shù)據(jù)集是ClFAR,它與Fashion-MNIST數(shù)據(jù)集類似,也是一些圖片。ClFAR比Fashion-MNIST更為復(fù)雜,而且由彩色圖像組成,相比之下,與實(shí)際場(chǎng)景中接觸的樣本更為接近。
3.1.1?CIFAR數(shù)據(jù)集的組成
CIFAR數(shù)據(jù)集的版本因?yàn)槠鸪醯臄?shù)據(jù)集共將數(shù)據(jù)分為10類,分別為飛機(jī)、汽車、鳥、貓、鹿、狗、青蛙、馬、船、卡車,所以ClFAR的數(shù)據(jù)集常以CIFAR-10命名,其中包含60000張32像素×32像素的彩色圖像(包含50000張訓(xùn)練圖片、10000張測(cè)試圖片),沒有任何類型重疊的情況。因?yàn)槭遣噬珗D像,所以這個(gè)數(shù)據(jù)集是三通道的,具有R、G、B這3個(gè)通道。
CIFAR又推出了一個(gè)分類更多的版本:ClFAR-100,從名字也可以看出,其將數(shù)據(jù)分為100類。它將圖片分得更細(xì),當(dāng)然,這對(duì)神經(jīng)網(wǎng)絡(luò)圖像識(shí)別是更大的挑戰(zhàn),有了這數(shù)據(jù),我們可以把精力全部投入在網(wǎng)絡(luò)優(yōu)化上。
?3.2 獲取數(shù)據(jù)集
ClFAR數(shù)據(jù)集是已經(jīng)打包好的文件,分為Python、二進(jìn)制bin文件包,方便不同的程序讀取,本次使用的數(shù)據(jù)集是ClFAR-10版本中的Python文件包,對(duì)應(yīng)的文件名稱為“cifar-10-pyhon.tar.gz”。該文件可以在官網(wǎng)上手動(dòng)下載,也可以使用與獲取Fashion-MNIST類似的方法,通過(guò)PyTorch的內(nèi)嵌代碼進(jìn)行下載。
3.3 加載并顯示CIFAR數(shù)據(jù)集------DIM_CIRFAR_train.py(第1部分)
import torch from torch import nn import torch.nn.functional as F import torchvision from torch.optim import Adam from torchvision.transforms import ToTensor from torch.utils.data import DataLoader from torchvision.datasets.cifar import CIFAR10 from matplotlib import pyplot as plt import numpy as np from tqdm import tqdm from pathlib import Path from torchvision.transforms import ToPILImage import os os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"# 1.1 獲取數(shù)據(jù)集并顯示數(shù)據(jù)集 # 指定運(yùn)算設(shè)備 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') print(device) # 加載數(shù)據(jù)集 batch_size = 512 data_dir = r'./cifar10/' # 將CIFAR10數(shù)據(jù)集下載到本地:共有三份文件,標(biāo)簽說(shuō)明文件batches.meta,訓(xùn)練樣本集data_batch_x(一共五個(gè),包含10000條訓(xùn)練樣本),測(cè)試樣本test.batch train_dataset = CIFAR10(data_dir,download=True,transform=ToTensor()) train_loader = DataLoader(train_dataset,batch_size=batch_size,shuffle=True,drop_last=True,pin_memory=torch.cuda.is_available()) print("訓(xùn)練樣本個(gè)數(shù):",len(train_dataset)) # 定義函數(shù)用于顯示圖片 def imshowrow(imgs,nrow):plt.figure(dpi=200) # figsize=(9,4)# ToPILImage()調(diào)用PyTorch的內(nèi)部轉(zhuǎn)換接口,實(shí)現(xiàn)張量===>PLImage類型圖片的轉(zhuǎn)換。# 該接口主要實(shí)現(xiàn)。(1)將張量的每個(gè)元素乘以255。(2)將張量的數(shù)據(jù)類型由FloatTensor轉(zhuǎn)化成uint8。(3)將張量轉(zhuǎn)化成NumPy的ndarray類型。(4)對(duì)ndarray對(duì)象執(zhí)行transpose(1,2,0)的操作。(5)利用Image下的fromarray()函數(shù),將ndarray對(duì)象轉(zhuǎn)化成PILImage形式。(6)輸出PILImage。_img = ToPILImage()(torchvision.utils.make_grid(imgs,nrow=nrow)) # 傳入PLlmage()接口的是由torchvision.utis.make_grid接口返回的張量對(duì)象plt.axis('off')plt.imshow(_img)plt.show()# 定義標(biāo)簽與對(duì)應(yīng)的字符 classes = ('airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck') # 獲取一部分樣本用于顯示 sample = iter(train_loader) images,labels = sample.next() print("樣本形狀:",np.shape(images)) print('樣本標(biāo)簽:',','.join('%2d:%-5s' % (labels[j],classes[labels[j]]) for j in range(len(images[:10])))) imshowrow(images[:10],nrow=10)輸出:
3.5?定義DIM模型------DIM_CIRFAR_train.py(第2部分)
# 1.2 定義DIM模型 class Encoder(nn.Module): # 通過(guò)多個(gè)卷積層對(duì)輸入數(shù)據(jù)進(jìn)行編碼,生成64維特征向量def __init__(self):super().__init__()self.c0 = nn.Conv2d(3, 64, kernel_size=4, stride=1) # 輸出尺寸29self.c1 = nn.Conv2d(64, 128, kernel_size=4, stride=1) # 輸出尺寸26self.c2 = nn.Conv2d(128, 256, kernel_size=4, stride=1) # 輸出尺寸23self.c3 = nn.Conv2d(256, 512, kernel_size=4, stride=1) # 輸出尺寸20self.l1 = nn.Linear(512*20*20, 64)# 定義BN層self.b1 = nn.BatchNorm2d(128)self.b2 = nn.BatchNorm2d(256)self.b3 = nn.BatchNorm2d(512)def forward(self, x):h = F.relu(self.c0(x))features = F.relu(self.b1(self.c1(h)))#輸出形狀[b 128 26 26]h = F.relu(self.b2(self.c2(features)))h = F.relu(self.b3(self.c3(h)))encoded = self.l1(h.view(x.shape[0], -1))# 輸出形狀[b 64]return encoded, featuresclass DeepInfoMaxLoss(nn.Module): # 實(shí)現(xiàn)全局、局部、先驗(yàn)判別器模型的結(jié)構(gòu)設(shè)計(jì),合并每個(gè)判別器的損失函數(shù),得到總的損失函數(shù)def __init__(self,alpha=0.5,beta=1.0,gamma=0.1):super().__init__()# 初始化損失函數(shù)的加權(quán)參數(shù)self.alpha = alphaself.beta = betaself.gamma = gamma# 定義局部判別模型self.local_d = nn.Sequential(nn.Conv2d(192,512,kernel_size=1),nn.ReLU(True),nn.Conv2d(512,512,kernel_size=1),nn.ReLU(True),nn.Conv2d(512,1,kernel_size=1))# 定義先驗(yàn)判別器模型self.prior_d = nn.Sequential(nn.Linear(64,1000),nn.ReLU(True),nn.Linear(1000,200),nn.ReLU(True),nn.Linear(200,1),nn.Sigmoid() # 在定義先驗(yàn)判別器模型的結(jié)構(gòu)時(shí),最后一層的激活函數(shù)用Sigmoid函數(shù)。這是原始GAN模型的標(biāo)準(zhǔn)用法(可以控制輸出值的范圍為0-1),是與損失函數(shù)配套使用的。)# 定義全局判別器模型self.global_d_M = nn.Sequential(nn.Conv2d(128,64,kernel_size=3), # 輸出形狀[b,64,24,24]nn.ReLU(True),nn.Conv2d(64,32,kernel_size=3), # 輸出形狀 [b,32,32,22]nn.Flatten(),)self.global_d_fc = nn.Sequential(nn.Linear(32*22*22+64,512),nn.ReLU(True),nn.Linear(512,512),nn.ReLU(True),nn.Linear(512,1))def GlobalD(self, y, M):h = self.global_d_M(M)h = torch.cat((y, h), dim=1)return self.global_d_fc(h)def forward(self,y,M,M_prime):# 復(fù)制特征向量y_exp = y.unsqueeze(-1).unsqueeze(-1)y_exp = y_exp.expand(-1,-1,26,26) # 輸出形狀[b,64,26,26]# 按照特征圖的像素連接特征向量y_M = torch.cat((M,y_exp),dim=1) # 輸出形狀[b,192,26,26]y_M_prime = torch.cat((M_prime,y_exp),dim=1)# 輸出形狀[b,192,26,26]# 計(jì)算局部互信息---互信息的計(jì)算Ej = -F.softplus(-self.local_d(y_M)).mean() # 聯(lián)合分布Em = F.softplus(self.local_d(y_M_prime)).mean() # 邊緣分布LOCAL = (Em - Ej) * self.beta # 最大化互信息---對(duì)互信息執(zhí)行了取反操作。將最大化問(wèn)題變?yōu)樽钚』瘑?wèn)題,在訓(xùn)練過(guò)程中,可以使用最小化損失的方法進(jìn)行處理。# 計(jì)算全局互信息---互信息的計(jì)算Ej = -F.softplus(-self.GlobalD(y, M)).mean() # 聯(lián)合分布Em = F.softplus(self.GlobalD(y, M_prime)).mean() # 邊緣分布GLOBAL = (Em - Ej) * self.alpha # 最大化互信息---對(duì)互信息執(zhí)行了取反操作。將最大化問(wèn)題變?yōu)樽钚』瘑?wèn)題,在訓(xùn)練過(guò)程中,可以使用最小化損失的方法進(jìn)行處理。# 計(jì)算先驗(yàn)損失prior = torch.rand_like(y) # 獲得隨機(jī)數(shù)term_a = torch.log(self.prior_d(prior)).mean() # GAN損失term_b = torch.log(1.0 - self.prior_d(y)).mean()PRIOR = -(term_a + term_b) * self.gamma # 最大化目標(biāo)分布---實(shí)現(xiàn)了判別器的損失函數(shù)。判別器的目標(biāo)是將真實(shí)數(shù)據(jù)和生成數(shù)據(jù)的分布最大化,因此,也需要取反,通過(guò)最小化損失的方法來(lái)實(shí)現(xiàn)。return LOCAL + GLOBAL + PRIOR# #### 在訓(xùn)練過(guò)程中,梯度可以通過(guò)損失函數(shù)直接傳播到編碼器模型,進(jìn)行聯(lián)合優(yōu)化,因此,不需要對(duì)編碼器額外進(jìn)行損失函數(shù)的定義!3.6?實(shí)例化DIM模型并訓(xùn)練------DIM_CIRFAR_train.py(第3部分)
# 1.3 實(shí)例化DIM模型并訓(xùn)練:實(shí)例化模型按照指定次數(shù)迭代訓(xùn)練。在制作邊緣分布樣本時(shí),將批次特征圖的第1條放到最后,以使特征圖與特征向量無(wú)法對(duì)應(yīng),實(shí)現(xiàn)與按批次打亂順序等同的效果。 totalepoch = 100 # 指定訓(xùn)練次數(shù) if __name__ == '__main__':encoder =Encoder().to(device)loss_fn = DeepInfoMaxLoss().to(device)optim = Adam(encoder.parameters(),lr=1e-4)loss_optim = Adam(loss_fn.parameters(),lr=1e-4)epoch_loss = []for epoch in range(totalepoch +1):batch = tqdm(train_loader,total=len(train_dataset)//batch_size)train_loss = []for x,target in batch: # 遍歷數(shù)據(jù)集x = x.to(device)optim.zero_grad()loss_optim.zero_grad()y,M = encoder(x) # 用編碼器生成特征圖和特征向量# 制作邊緣分布樣本M_prime = torch.cat((M[1:],M[0].unsqueeze(0)),dim=0)loss =loss_fn(y,M,M_prime) # 計(jì)算損失train_loss.append(loss.item())batch.set_description(str(epoch) + ' Loss:%.4f'% np.mean(train_loss[-20:]))loss.backward()optim.step() # 調(diào)用編碼器優(yōu)化器loss_optim.step() # 調(diào)用判別器優(yōu)化器if epoch % 10 == 0 : # 保存模型root = Path(r'./DIMmodel/')enc_file = root / Path('encoder' + str(epoch) + '.pth')loss_file = root / Path('loss' + str(epoch) + '.pth')enc_file.parent.mkdir(parents=True, exist_ok=True)torch.save(encoder.state_dict(), str(enc_file))torch.save(loss_fn.state_dict(), str(loss_file))epoch_loss.append(np.mean(train_loss[-20:])) # 收集訓(xùn)練損失plt.plot(np.arange(len(epoch_loss)), epoch_loss, 'r') # 損失可視化plt.show()結(jié)果:
?
3.7 加載模型并搜索圖片------DIM_CIRFAR_loadpath.py
import torch import torch.nn.functional as F from tqdm import tqdm import random# 功能介紹:載入編碼器模型,對(duì)樣本集中所有圖片進(jìn)行編碼,隨機(jī)取一張圖片,找出與該圖片最接近與最不接近的十張圖片 # # 引入本地庫(kù) #引入本地代碼庫(kù) from DIM_CIRFAR_train import ( train_loader,train_dataset,totalepoch,device,batch_size,imshowrow, Encoder)# 加載模型 model_path = r'./DIMmodel/encoder%d.pth'% (totalepoch) encoder = Encoder().to(device) encoder.load_state_dict(torch.load(model_path,map_location=device))# 加載模型樣本,并調(diào)用編碼器生成特征向量 batchesimg = [] batchesenc = [] batch = tqdm(train_loader,total=len(train_dataset)//batch_size) for images ,target in batch :images = images.to(device)with torch.no_grad():encoded,features = encoder(images) # 調(diào)用編碼器生成特征向量batchesimg.append(images)batchesenc.append(encoded) # 將樣本中的圖片與生成的向量沿第1維度展開 batchesenc = torch.cat(batchesenc,axis = 0) batchesimg = torch.cat(batchesimg,axis = 0) # 驗(yàn)證向量的搜索功能 index = random.randrange(0,len(batchesenc)) # 隨機(jī)獲取一個(gè)索引,作為目標(biāo)圖片 batchesenc[index].repeat(len(batchesenc),1) # 將目標(biāo)圖片的特征向量復(fù)制多份 # 使用F.mse_loss()函數(shù)進(jìn)行特征向量間的L2計(jì)算,傳入了參數(shù)reduction='none',這表明對(duì)計(jì)算后的結(jié)果不執(zhí)行任何操作。如果不傳入該參數(shù),那么函數(shù)默認(rèn)會(huì)對(duì)所有結(jié)果取平均值(常用在訓(xùn)練模型場(chǎng)景中) l2_dis = F.mse_loss(batchesenc[index].repeat(len(batchesenc),1),batchesenc,reduction='none').sum(1) # 計(jì)算目標(biāo)圖片與每個(gè)圖片的L2距離 findnum = 10 # 設(shè)置查找圖片的個(gè)數(shù) # 使用topk()方法獲取L2距離最近、最遠(yuǎn)的圖片。該方法會(huì)返回兩個(gè)值,第一個(gè)是真實(shí)的比較值,第二個(gè)是該值對(duì)應(yīng)的索引。 _,indices = l2_dis.topk(findnum,largest=False ) # 查找10個(gè)最相近的圖片 _,indices_far = l2_dis.topk(findnum,) # 查找10個(gè)最不相關(guān)的圖片 # 顯示結(jié)果 indices = torch.cat([torch.tensor([index]).to(device),indices]) indices_far = torch.cat([torch.tensor([index]).to(device),indices_far]) rel = torch.cat([batchesimg[indices],batchesimg[indices_far]],axis = 0) imshowrow(rel.cpu() ,nrow=len(indices)) # 結(jié)果顯示:結(jié)果有兩行,每行的第一列是目標(biāo)圖片,第一行是與目標(biāo)圖片距離最近的搜索結(jié)果,第二行是與目標(biāo)圖片距離最遠(yuǎn)的搜索結(jié)果。?4 代碼總覽
4.1 訓(xùn)練模型:DIM_CIRFAR_train.py
import torch from torch import nn import torch.nn.functional as F import torchvision from torch.optim import Adam from torchvision.transforms import ToTensor from torch.utils.data import DataLoader from torchvision.datasets.cifar import CIFAR10 from matplotlib import pyplot as plt import numpy as np from tqdm import tqdm from pathlib import Path from torchvision.transforms import ToPILImage import os os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"# 1.1 獲取數(shù)據(jù)集并顯示數(shù)據(jù)集 # 指定運(yùn)算設(shè)備 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') print(device) # 加載數(shù)據(jù)集 batch_size = 512 data_dir = r'./cifar10/' # 將CIFAR10數(shù)據(jù)集下載到本地:共有三份文件,標(biāo)簽說(shuō)明文件batches.meta,訓(xùn)練樣本集data_batch_x(一共五個(gè),包含10000條訓(xùn)練樣本),測(cè)試樣本test.batch train_dataset = CIFAR10(data_dir,download=True,transform=ToTensor()) train_loader = DataLoader(train_dataset,batch_size=batch_size,shuffle=True,drop_last=True,pin_memory=torch.cuda.is_available()) print("訓(xùn)練樣本個(gè)數(shù):",len(train_dataset)) # 定義函數(shù)用于顯示圖片 def imshowrow(imgs,nrow):plt.figure(dpi=200) # figsize=(9,4)# ToPILImage()調(diào)用PyTorch的內(nèi)部轉(zhuǎn)換接口,實(shí)現(xiàn)張量===>PLImage類型圖片的轉(zhuǎn)換。# 該接口主要實(shí)現(xiàn)。(1)將張量的每個(gè)元素乘以255。(2)將張量的數(shù)據(jù)類型由FloatTensor轉(zhuǎn)化成uint8。(3)將張量轉(zhuǎn)化成NumPy的ndarray類型。(4)對(duì)ndarray對(duì)象執(zhí)行transpose(1,2,0)的操作。(5)利用Image下的fromarray()函數(shù),將ndarray對(duì)象轉(zhuǎn)化成PILImage形式。(6)輸出PILImage。_img = ToPILImage()(torchvision.utils.make_grid(imgs,nrow=nrow)) # 傳入PLlmage()接口的是由torchvision.utis.make_grid接口返回的張量對(duì)象plt.axis('off')plt.imshow(_img)plt.show()# 定義標(biāo)簽與對(duì)應(yīng)的字符 classes = ('airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck') # 獲取一部分樣本用于顯示 sample = iter(train_loader) images,labels = sample.next() print("樣本形狀:",np.shape(images)) print('樣本標(biāo)簽:',','.join('%2d:%-5s' % (labels[j],classes[labels[j]]) for j in range(len(images[:10])))) imshowrow(images[:10],nrow=10)# 1.2 定義DIM模型 class Encoder(nn.Module): # 通過(guò)多個(gè)卷積層對(duì)輸入數(shù)據(jù)進(jìn)行編碼,生成64維特征向量def __init__(self):super().__init__()self.c0 = nn.Conv2d(3, 64, kernel_size=4, stride=1) # 輸出尺寸29self.c1 = nn.Conv2d(64, 128, kernel_size=4, stride=1) # 輸出尺寸26self.c2 = nn.Conv2d(128, 256, kernel_size=4, stride=1) # 輸出尺寸23self.c3 = nn.Conv2d(256, 512, kernel_size=4, stride=1) # 輸出尺寸20self.l1 = nn.Linear(512*20*20, 64)# 定義BN層self.b1 = nn.BatchNorm2d(128)self.b2 = nn.BatchNorm2d(256)self.b3 = nn.BatchNorm2d(512)def forward(self, x):h = F.relu(self.c0(x))features = F.relu(self.b1(self.c1(h)))#輸出形狀[b 128 26 26]h = F.relu(self.b2(self.c2(features)))h = F.relu(self.b3(self.c3(h)))encoded = self.l1(h.view(x.shape[0], -1))# 輸出形狀[b 64]return encoded, featuresclass DeepInfoMaxLoss(nn.Module): # 實(shí)現(xiàn)全局、局部、先驗(yàn)判別器模型的結(jié)構(gòu)設(shè)計(jì),合并每個(gè)判別器的損失函數(shù),得到總的損失函數(shù)def __init__(self,alpha=0.5,beta=1.0,gamma=0.1):super().__init__()# 初始化損失函數(shù)的加權(quán)參數(shù)self.alpha = alphaself.beta = betaself.gamma = gamma# 定義局部判別模型self.local_d = nn.Sequential(nn.Conv2d(192,512,kernel_size=1),nn.ReLU(True),nn.Conv2d(512,512,kernel_size=1),nn.ReLU(True),nn.Conv2d(512,1,kernel_size=1))# 定義先驗(yàn)判別器模型self.prior_d = nn.Sequential(nn.Linear(64,1000),nn.ReLU(True),nn.Linear(1000,200),nn.ReLU(True),nn.Linear(200,1),nn.Sigmoid() # 在定義先驗(yàn)判別器模型的結(jié)構(gòu)時(shí),最后一層的激活函數(shù)用Sigmoid函數(shù)。這是原始GAN模型的標(biāo)準(zhǔn)用法(可以控制輸出值的范圍為0-1),是與損失函數(shù)配套使用的。)# 定義全局判別器模型self.global_d_M = nn.Sequential(nn.Conv2d(128,64,kernel_size=3), # 輸出形狀[b,64,24,24]nn.ReLU(True),nn.Conv2d(64,32,kernel_size=3), # 輸出形狀 [b,32,32,22]nn.Flatten(),)self.global_d_fc = nn.Sequential(nn.Linear(32*22*22+64,512),nn.ReLU(True),nn.Linear(512,512),nn.ReLU(True),nn.Linear(512,1))def GlobalD(self, y, M):h = self.global_d_M(M)h = torch.cat((y, h), dim=1)return self.global_d_fc(h)def forward(self,y,M,M_prime):# 復(fù)制特征向量y_exp = y.unsqueeze(-1).unsqueeze(-1)y_exp = y_exp.expand(-1,-1,26,26) # 輸出形狀[b,64,26,26]# 按照特征圖的像素連接特征向量y_M = torch.cat((M,y_exp),dim=1) # 輸出形狀[b,192,26,26]y_M_prime = torch.cat((M_prime,y_exp),dim=1)# 輸出形狀[b,192,26,26]# 計(jì)算局部互信息---互信息的計(jì)算Ej = -F.softplus(-self.local_d(y_M)).mean() # 聯(lián)合分布Em = F.softplus(self.local_d(y_M_prime)).mean() # 邊緣分布LOCAL = (Em - Ej) * self.beta # 最大化互信息---對(duì)互信息執(zhí)行了取反操作。將最大化問(wèn)題變?yōu)樽钚』瘑?wèn)題,在訓(xùn)練過(guò)程中,可以使用最小化損失的方法進(jìn)行處理。# 計(jì)算全局互信息---互信息的計(jì)算Ej = -F.softplus(-self.GlobalD(y, M)).mean() # 聯(lián)合分布Em = F.softplus(self.GlobalD(y, M_prime)).mean() # 邊緣分布GLOBAL = (Em - Ej) * self.alpha # 最大化互信息---對(duì)互信息執(zhí)行了取反操作。將最大化問(wèn)題變?yōu)樽钚』瘑?wèn)題,在訓(xùn)練過(guò)程中,可以使用最小化損失的方法進(jìn)行處理。# 計(jì)算先驗(yàn)損失prior = torch.rand_like(y) # 獲得隨機(jī)數(shù)term_a = torch.log(self.prior_d(prior)).mean() # GAN損失term_b = torch.log(1.0 - self.prior_d(y)).mean()PRIOR = -(term_a + term_b) * self.gamma # 最大化目標(biāo)分布---實(shí)現(xiàn)了判別器的損失函數(shù)。判別器的目標(biāo)是將真實(shí)數(shù)據(jù)和生成數(shù)據(jù)的分布最大化,因此,也需要取反,通過(guò)最小化損失的方法來(lái)實(shí)現(xiàn)。return LOCAL + GLOBAL + PRIOR# #### 在訓(xùn)練過(guò)程中,梯度可以通過(guò)損失函數(shù)直接傳播到編碼器模型,進(jìn)行聯(lián)合優(yōu)化,因此,不需要對(duì)編碼器額外進(jìn)行損失函數(shù)的定義!# 1.3 實(shí)例化DIM模型并訓(xùn)練:實(shí)例化模型按照指定次數(shù)迭代訓(xùn)練。在制作邊緣分布樣本時(shí),將批次特征圖的第1條放到最后,以使特征圖與特征向量無(wú)法對(duì)應(yīng),實(shí)現(xiàn)與按批次打亂順序等同的效果。 totalepoch = 10 # 指定訓(xùn)練次數(shù) if __name__ == '__main__':encoder =Encoder().to(device)loss_fn = DeepInfoMaxLoss().to(device)optim = Adam(encoder.parameters(),lr=1e-4)loss_optim = Adam(loss_fn.parameters(),lr=1e-4)epoch_loss = []for epoch in range(totalepoch +1):batch = tqdm(train_loader,total=len(train_dataset)//batch_size)train_loss = []for x,target in batch: # 遍歷數(shù)據(jù)集x = x.to(device)optim.zero_grad()loss_optim.zero_grad()y,M = encoder(x) # 用編碼器生成特征圖和特征向量# 制作邊緣分布樣本M_prime = torch.cat((M[1:],M[0].unsqueeze(0)),dim=0)loss =loss_fn(y,M,M_prime) # 計(jì)算損失train_loss.append(loss.item())batch.set_description(str(epoch) + ' Loss:%.4f'% np.mean(train_loss[-20:]))loss.backward()optim.step() # 調(diào)用編碼器優(yōu)化器loss_optim.step() # 調(diào)用判別器優(yōu)化器if epoch % 10 == 0 : # 保存模型root = Path(r'./DIMmodel/')enc_file = root / Path('encoder' + str(epoch) + '.pth')loss_file = root / Path('loss' + str(epoch) + '.pth')enc_file.parent.mkdir(parents=True, exist_ok=True)torch.save(encoder.state_dict(), str(enc_file))torch.save(loss_fn.state_dict(), str(loss_file))epoch_loss.append(np.mean(train_loss[-20:])) # 收集訓(xùn)練損失plt.plot(np.arange(len(epoch_loss)), epoch_loss, 'r') # 損失可視化plt.show()4.2 加載模型:DIM_CIRFAR_loadpath.py
import torch import torch.nn.functional as F from tqdm import tqdm import random# 功能介紹:載入編碼器模型,對(duì)樣本集中所有圖片進(jìn)行編碼,隨機(jī)取一張圖片,找出與該圖片最接近與最不接近的十張圖片 # # 引入本地庫(kù) #引入本地代碼庫(kù) from DIM_CIRFAR_train import ( train_loader,train_dataset,totalepoch,device,batch_size,imshowrow, Encoder)# 加載模型 model_path = r'./DIMmodel/encoder%d.pth'% (totalepoch) encoder = Encoder().to(device) encoder.load_state_dict(torch.load(model_path,map_location=device))# 加載模型樣本,并調(diào)用編碼器生成特征向量 batchesimg = [] batchesenc = [] batch = tqdm(train_loader,total=len(train_dataset)//batch_size) for images ,target in batch :images = images.to(device)with torch.no_grad():encoded,features = encoder(images) # 調(diào)用編碼器生成特征向量batchesimg.append(images)batchesenc.append(encoded) # 將樣本中的圖片與生成的向量沿第1維度展開 batchesenc = torch.cat(batchesenc,axis = 0) batchesimg = torch.cat(batchesimg,axis = 0) # 驗(yàn)證向量的搜索功能 index = random.randrange(0,len(batchesenc)) # 隨機(jī)獲取一個(gè)索引,作為目標(biāo)圖片 batchesenc[index].repeat(len(batchesenc),1) # 將目標(biāo)圖片的特征向量復(fù)制多份 # 使用F.mse_loss()函數(shù)進(jìn)行特征向量間的L2計(jì)算,傳入了參數(shù)reduction='none',這表明對(duì)計(jì)算后的結(jié)果不執(zhí)行任何操作。如果不傳入該參數(shù),那么函數(shù)默認(rèn)會(huì)對(duì)所有結(jié)果取平均值(常用在訓(xùn)練模型場(chǎng)景中) l2_dis = F.mse_loss(batchesenc[index].repeat(len(batchesenc),1),batchesenc,reduction='none').sum(1) # 計(jì)算目標(biāo)圖片與每個(gè)圖片的L2距離 findnum = 10 # 設(shè)置查找圖片的個(gè)數(shù) # 使用topk()方法獲取L2距離最近、最遠(yuǎn)的圖片。該方法會(huì)返回兩個(gè)值,第一個(gè)是真實(shí)的比較值,第二個(gè)是該值對(duì)應(yīng)的索引。 _,indices = l2_dis.topk(findnum,largest=False ) # 查找10個(gè)最相近的圖片 _,indices_far = l2_dis.topk(findnum,) # 查找10個(gè)最不相關(guān)的圖片 # 顯示結(jié)果 indices = torch.cat([torch.tensor([index]).to(device),indices]) indices_far = torch.cat([torch.tensor([index]).to(device),indices_far]) rel = torch.cat([batchesimg[indices],batchesimg[indices_far]],axis = 0) imshowrow(rel.cpu() ,nrow=len(indices)) # 結(jié)果顯示:結(jié)果有兩行,每行的第一列是目標(biāo)圖片,第一行是與目標(biāo)圖片距離最近的搜索結(jié)果,第二行是與目標(biāo)圖片距離最遠(yuǎn)的搜索結(jié)果。總結(jié)
以上是生活随笔為你收集整理的【Pytorch神经网络实战案例】18 最大化深度互信信息模型DIM实现搜索最相关与最不相关的图片的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
- 上一篇: mysql lib 5.5.28_mys
- 下一篇: 【Pytorch神经网络实战案例】28