Pytorch基础-07-自动编码器
自動編碼器(AutoEncoder)是一種可以進行無監督學習的神經網絡模型。一般而言,一個完整的自動編碼器主要由兩部分組成,分別是用于核心特征提取的編碼部分和可以實現數據重構的解碼部分。
1 自動編碼器入門
在自動編碼器中負責編碼的部分也叫作編碼器(Encoder),而負 責解碼的部分也叫作解碼器(Decoder)。編碼器主要負責對原始的輸 入數據進行壓縮并提取數據中的核心特征,而解碼器主要是對在編碼器 中提取的核心特征進行展開并重新構造出之前的輸入數據。
如上圖就是一個簡化的自動編碼器模型,它的主要結構是神 經網絡,該模型的最左邊是用于數據輸入的輸入層,在輸入數據通過神 經網絡的層層傳遞之后得到了中間輸入數據的核心特征,這就完成了在 自編碼器中輸入數據的編碼過程。然后,將輸入數據的核心特征再傳遞 到一個逆向的神經網絡中,核心特征會被解壓并重構,最后得到了一個 和輸入數據相近的輸出數據,這就是自動編碼器中的解碼過程。輸入數 據通過自動編碼器模型的處理后又被重新還原了。
我們會好奇自動編碼器模型這種先編碼后解碼的神經網絡模型到底 有什么作用,下面進行講解。自動編碼器模型的最大用途就是實現輸入 數據的清洗,比如去除輸入數據中的噪聲數據、對輸入數據的某些關鍵特征進行增強和放大,等等。舉一個比較簡單的例子,假設我們現在有 一些被打上了馬賽克的圖片需要進行除碼處理,這時就可以通過自動編碼器模型來解決這個問題。其實可以將這個除碼的過程看作對數據進行除噪的過程,這也是我們接下來會實現的實踐案例。下面看看具體如何實現基于PyTorch的自動編碼器。
2 PyTorch之自動編碼實戰
2.1 通過線性變換實現自動編碼器模型
完成代碼:
import torch import torchvision from torchvision import datasets,transforms from torch.autograd import Variable import numpy as np import matplotlib.pyplot as plttransform = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=[0.5],std=[0.5])]) dataset_train = datasets.MNIST(root='./data',transform = transform,train = True,download= False) dataset_test = datasets.MNIST(root= './data',transform=transform,train=False) train_load = torch.utils.data.DataLoader(dataset = dataset_train,batch_size=4,shuffle=True) test_load = torch.utils.data.DataLoader(dataset = dataset_test,batch_size=4,shuffle=True) images,label = next(iter(train_load)) print (images.shape) images_example = torchvision.utils.make_grid(images) images_example = images_example.numpy().transpose(1,2,0) mean = [0.5] std = [0.5] images_example = images_example*std + mean plt.imshow(images_example) plt.show() noisy_images = images_example + 0.5*np.random.randn(*images_example.shape) noisy_images = np.clip(noisy_images,0.,1.) plt.imshow(noisy_images) plt.show()
在以上代碼中損失函數使用的是torch.nn.MSELoss,即計算的是均方誤差,我們在之前處理的都是圖片分類相關的問題,所以在這里使用交叉熵來計算損失值。而在這個問題中我們需要衡量的是圖片在去碼后和原始圖片之間的誤差,所以選擇均方誤差這類損失函數作為度量。總體的訓練流程是我們首先獲取一個批次的圖片,然后對這個批次的圖片進行打碼處理并裁剪到指定的像素值范圍內,因為之前說過,在MNIST數據集使用的圖片中每個像素點的數字值在0到1之間。在得到了經過打碼處理的圖片后,將其輸入搭建好的自動編碼器模型中,經過模型處理后輸出一個預測圖片,用這個預測圖片和原始圖片進行損失值計算,通過這個損失值對模型進行后向傳播,最后就能得到去除圖片馬賽克效果的模型了。
在每輪訓練中,我們都對預測圖片和原始圖片計算得到的損失值進 行輸出,在訓練10輪之后,輸出的結果如下:
從以上結果可以看出,我們得到的損失值在逐漸減小,而且損失值 已經在一個足夠小的范圍內了。最后,我們通過使用一部分測試數據集中的圖片來驗證我們的模型能否正常工作,代碼如下:
data_loader_test = torch.utils.data.DataLoader(dataset=dataset_test,batch_size = 4,shuffle = True) x_test,_ = next(iter(data_loader_test))img1 = torchvision.utils.make_grid(x_test) img1 = img1.numpy().transpose(1,2,0) std = [0.5] mean = [0.5] img1 =img1*std + meannoisy_x_test = img1 +0.5*np.random.randn(*img1.shape) noisy_x_test = np.clip(noisy_x_test,0.,1.)plt.figure() plt.imshow(noisy_x_test)img2 = x_test + 0.5*torch.randn(*x_test.shape) img2 = torch.clamp(img2,0.,1.)img2 = Variable(img2.view(-1,28*28))test_pred = model(img2)img_test = test_pred.data.view(-1,1,28,28) img2 = torchvision.utils.make_grid(img_test) img2 = img2.numpy().transpose(1,2,0) img2 = img2*std +mean img2 = np.clip(img2,0.,1.) plt.figure() plt.imshow(img2)
下面是使用普通的濾波器的效果:
自動編碼器的去噪效果,還是可圈可點的。
2.2 通過卷積變換實現自動編碼器模型
以卷積變換的方式和以線性變換方式構建的自動編碼器模型會有較大的區別,而且相對復雜一些,卷積變換的方式僅使用卷積層、最大池化層、上采樣層和激活函數作為神經網絡結構的主要組成部分,代碼如下:
class AutoEncoder2(torch.nn.Module):def __init__(self):super(AutoEncoder2, self).__init__()self.encoder = torch.nn.Sequential(torch.nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1),torch.nn.ReLU(),torch.nn.MaxPool2d(kernel_size=2, stride=2),torch.nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),torch.nn.ReLU(),torch.nn.MaxPool2d(kernel_size=2, stride=2))self.decoder = torch.nn.Sequential(torch.nn.Upsample(scale_factor=2, mode='nearest'),torch.nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),torch.nn.ReLU(),torch.nn.Upsample(scale_factor=2, mode='nearest'),torch.nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1))def forward(self, input):output = self.encoder(input)output = self.decoder(output)return output在以上代碼中出現了一個我們之前從來沒有接觸過的上采樣層,即 torch.nn.Upsample類。這個類的作用就是對我們提取到的核心特征進行 解壓,實現圖片的重寫構建,傳遞給它的參數一共有兩個,分別是 scale_factor和mode:前者用于確定解壓的倍數;后者用于定義圖片重構 的模式,可選擇的模式有nearest、linear、bilinear和trilinear,其中nearest 是最鄰近法,linear是線性插值法,bilinear是雙線性插值法,trilinear是 三線性插值法。因為在我們的代碼中使用的是最鄰近法,所以這里通過 一張圖片來看一下最鄰近法的具體工作方式。
訓練代碼如下:
model = AutoEncoder2() print(model)optimizer = torch.optim.Adam(model.parameters()) loss_f = torch.nn.MSELoss()epoch_n = 5 for epoch in range(epoch_n):running_loss = 0print('Epoch {}/{}'.format(epoch, epoch_n))print('===' * 10)for data in train_load:X_train, _ = datanoisy_X_train = X_train + 0.5 * torch.rand(X_train.shape)noisy_X_train = torch.clamp(noisy_X_train, 0., 1.)X_train, noisy_X_train = Variable(X_train), Variable(noisy_X_train)train_pre = model(noisy_X_train)loss = loss_f(train_pre, X_train)optimizer.zero_grad()loss.backward()optimizer.step()running_loss += loss.dataprint('Loss is:{}'.format(running_loss / len(dataset_train)))我們在每輪訓練中都對預測圖片和原始圖片計算得到的損失值進行 輸出,在訓練5輪之后,輸出結果如下:
看一下測試效果:
總結
以上是生活随笔為你收集整理的Pytorch基础-07-自动编码器的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 整理常用的PHP函数
- 下一篇: Ubuntu解压安装包及make命令相关