PyTorch Implementation of a VAE (Variational Autoencoder), with Code
Contents
- Autoencoder
  - Autoencoder network structure diagram
  - Linear autoencoder code
  - Convolutional autoencoder code
- Variational autoencoder
  - Variational autoencoder network structure diagram
  - Variational autoencoder code
Autoencoder

Autoencoder network structure diagram

Linear autoencoder code:
```python
import torch
import torchvision
from torch import nn
from torch import optim
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.utils import save_image
from torchvision.datasets import MNIST
import os

if not os.path.exists('./vae_img'):
    os.mkdir('./vae_img')


def to_img(x):
    x = x.clamp(0, 1)
    x = x.view(x.size(0), 1, 28, 28)
    return x


num_epochs = 100
batch_size = 128
learning_rate = 1e-3

img_transform = transforms.Compose([
    transforms.ToTensor()
    # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

dataset = MNIST('../data', transform=img_transform, download=True)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)


class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()
        self.fc1 = nn.Linear(784, 400)
        self.fc21 = nn.Linear(400, 20)   # mean
        self.fc22 = nn.Linear(400, 20)   # log variance
        self.fc3 = nn.Linear(20, 400)
        self.fc4 = nn.Linear(400, 784)

    def encode(self, x):
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparametrize(self, mu, logvar):
        std = logvar.mul(0.5).exp_()
        if torch.cuda.is_available():
            eps = torch.cuda.FloatTensor(std.size()).normal_()
        else:
            eps = torch.FloatTensor(std.size()).normal_()
        eps = Variable(eps)
        return eps.mul(std).add_(mu)

    def decode(self, z):
        h3 = F.relu(self.fc3(z))
        # return F.sigmoid(self.fc4(h3))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparametrize(mu, logvar)
        return self.decode(z), mu, logvar


model = VAE()
if torch.cuda.is_available():
    # model.cuda()
    print('cuda is OK!')
    model = model.to('cuda')
else:
    print('cuda is NO!')

reconstruction_function = nn.MSELoss(size_average=False)
# reconstruction_function = nn.MSELoss(reduction='sum')


def loss_function(recon_x, x, mu, logvar):
    """
    recon_x: generated images
    x: original images
    mu: latent mean
    logvar: latent log variance
    """
    BCE = reconstruction_function(recon_x, x)  # mse loss
    # KL = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
    KLD = torch.sum(KLD_element).mul_(-0.5)
    # KL divergence
    return BCE + KLD


optimizer = optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(num_epochs):
    model.train()
    train_loss = 0
    for batch_idx, data in enumerate(dataloader):
        img, _ = data
        img = img.view(img.size(0), -1)
        img = Variable(img)
        if torch.cuda.is_available():
            img = img.cuda()
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(img)
        loss = loss_function(recon_batch, img, mu, logvar)
        loss.backward()
        # train_loss += loss.data[0]
        train_loss += loss.item()
        optimizer.step()
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch,
                batch_idx * len(img),
                len(dataloader.dataset), 100. * batch_idx / len(dataloader),
                # loss.data[0] / len(img)))
                loss.item() / len(img)))

    print('====> Epoch: {} Average loss: {:.4f}'.format(
        epoch, train_loss / len(dataloader.dataset)))
    if epoch % 10 == 0:
        save = to_img(recon_batch.cpu().data)
        save_image(save, './vae_img/image_{}.png'.format(epoch))

torch.save(model.state_dict(), './vae.pth')
```
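For reference, the `KLD` term computed in `loss_function` above is the closed-form KL divergence between the Gaussian `N(mu, sigma^2)` produced by the encoder and the standard normal prior, summed over the 20 latent dimensions and over the batch:

```latex
\mathrm{KLD} \;=\; -\frac{1}{2}\sum_{j}\bigl(1 + \log\sigma_j^{2} - \mu_j^{2} - \sigma_j^{2}\bigr),
\qquad \log\sigma_j^{2} = \texttt{logvar}_j,\;\; \mu_j = \texttt{mu}_j
```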
Convolutional autoencoder code:

```python
import os
import datetime

import torch
import torchvision
from torch import nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.utils import save_image
from torchvision.datasets import MNIST

if not os.path.exists('./dc_img'):
    os.mkdir('./dc_img')


def to_img(x):
    x = 0.5 * (x + 1)
    x = x.clamp(0, 1)
    x = x.view(x.size(0), 1, 28, 28)
    return x


num_epochs = 100
batch_size = 128
learning_rate = 1e-3

img_transform = transforms.Compose([
    transforms.ToTensor(),
    # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    transforms.Normalize([0.5], [0.5])
])

dataset = MNIST('./data', transform=img_transform, download=True)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)


class autoencoder(nn.Module):
    def __init__(self):
        super(autoencoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 16, 3, stride=3, padding=1),   # b, 16, 10, 10
            nn.ReLU(True),
            nn.MaxPool2d(2, stride=2),                   # b, 16, 5, 5
            nn.Conv2d(16, 8, 3, stride=2, padding=1),    # b, 8, 3, 3
            nn.ReLU(True),
            nn.MaxPool2d(2, stride=1)                    # b, 8, 2, 2
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(8, 16, 3, stride=2),              # b, 16, 5, 5
            nn.ReLU(True),
            nn.ConvTranspose2d(16, 8, 5, stride=3, padding=1),   # b, 8, 15, 15
            nn.ReLU(True),
            nn.ConvTranspose2d(8, 1, 2, stride=2, padding=1),    # b, 1, 28, 28
            nn.Tanh()
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x


model = autoencoder().cuda()
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate,
                             weight_decay=1e-5)
starttime = datetime.datetime.now()

for epoch in range(num_epochs):
    for data in dataloader:
        img, label = data
        img = Variable(img).cuda()
        # ===================forward=====================
        output = model(img)
        loss = criterion(output, img)
        # ===================backward====================
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # ===================log========================
    endtime = datetime.datetime.now()
    print('epoch [{}/{}], loss:{:.4f}, time:{:.2f}s'.format(
        epoch + 1, num_epochs, loss.item(), (endtime - starttime).seconds))
    # if epoch % 10 == 0:
    pic = to_img(output.cpu().data)
    save_image(pic, './dc_img/image_{}.png'.format(epoch))

torch.save(model.state_dict(), './conv_autoencoder.pth')
```
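As a quick sanity check after training, here is a minimal sketch that reloads the saved weights and writes reconstructions of one test batch. It assumes it runs in the same session as the script above, so that `autoencoder`, `img_transform` and `to_img` are already defined and `./conv_autoencoder.pth` exists.

```python
# Minimal reconstruction check for the convolutional autoencoder above.
# Assumes `autoencoder`, `img_transform`, `to_img` and the saved weights exist.
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.utils import save_image

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = autoencoder().to(device)
model.load_state_dict(torch.load('./conv_autoencoder.pth', map_location=device))
model.eval()

test_set = MNIST('./data', train=False, transform=img_transform, download=True)
test_loader = DataLoader(test_set, batch_size=64, shuffle=False)

with torch.no_grad():
    img, _ = next(iter(test_loader))
    recon = model(img.to(device))

# Save inputs and reconstructions for a side-by-side visual comparison.
save_image(to_img(img.cpu()), './dc_img/test_input.png')
save_image(to_img(recon.cpu()), './dc_img/test_recon.png')
```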
Variational autoencoder

Variational autoencoder network structure diagram
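In brief, the structure the diagram illustrates (and the code below implements) rests on two ideas: the encoder outputs a mean and a log-variance, and the latent code is drawn with the reparameterization trick so gradients can flow through the sampling step; the training loss pairs a reconstruction term (summed MSE in this implementation) with a KL penalty that pulls the approximate posterior toward the standard normal prior:

```latex
z = \mu + \sigma \odot \varepsilon,\quad \varepsilon \sim \mathcal{N}(0, I)
\qquad
\mathcal{L}(x) = \underbrace{\lVert x - \hat{x} \rVert^{2}}_{\text{reconstruction}}
\;+\; \underbrace{\mathrm{KL}\bigl(\mathcal{N}(\mu,\sigma^{2}) \,\Vert\, \mathcal{N}(0, I)\bigr)}_{\text{regularizer}}
```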
Variational autoencoder code:
```python
import torch
import torchvision
from torch import nn
from torch import optim
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.utils import save_image
from torchvision.datasets import MNIST
import os
import datetime

if not os.path.exists('./vae_img'):
    os.mkdir('./vae_img')


def to_img(x):
    x = x.clamp(0, 1)
    x = x.view(x.size(0), 1, 28, 28)
    return x


num_epochs = 100
batch_size = 128
learning_rate = 1e-3

img_transform = transforms.Compose([
    transforms.ToTensor()
    # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

dataset = MNIST('./data', transform=img_transform, download=True)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)


class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()
        self.fc1 = nn.Linear(784, 400)
        self.fc21 = nn.Linear(400, 20)   # mean
        self.fc22 = nn.Linear(400, 20)   # log variance
        self.fc3 = nn.Linear(20, 400)
        self.fc4 = nn.Linear(400, 784)

    def encode(self, x):
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparametrize(self, mu, logvar):
        std = logvar.mul(0.5).exp_()
        if torch.cuda.is_available():
            eps = torch.cuda.FloatTensor(std.size()).normal_()
        else:
            eps = torch.FloatTensor(std.size()).normal_()
        eps = Variable(eps)
        return eps.mul(std).add_(mu)

    def decode(self, z):
        h3 = F.relu(self.fc3(z))
        # return F.sigmoid(self.fc4(h3))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparametrize(mu, logvar)
        return self.decode(z), mu, logvar


starttime = datetime.datetime.now()
model = VAE()
if torch.cuda.is_available():
    # model.cuda()
    print('cuda is OK!')
    model = model.to('cuda')
else:
    print('cuda is NO!')

reconstruction_function = nn.MSELoss(size_average=False)
# reconstruction_function = nn.MSELoss(reduction='sum')


def loss_function(recon_x, x, mu, logvar):
    """
    recon_x: generated images
    x: original images
    mu: latent mean
    logvar: latent log variance
    """
    BCE = reconstruction_function(recon_x, x)  # mse loss
    # KL = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
    KLD = torch.sum(KLD_element).mul_(-0.5)
    # KL divergence
    return BCE + KLD


optimizer = optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(num_epochs):
    model.train()
    train_loss = 0
    for batch_idx, data in enumerate(dataloader):
        img, _ = data
        img = img.view(img.size(0), -1)
        img = Variable(img)
        img = (img.cuda() if torch.cuda.is_available() else img)
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(img)
        loss = loss_function(recon_batch, img, mu, logvar)
        loss.backward()
        # train_loss += loss.data[0]
        train_loss += loss.item()
        optimizer.step()
        if batch_idx % 100 == 0:
            endtime = datetime.datetime.now()
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f} time:{:.2f}s'.format(
                epoch,
                batch_idx * len(img),
                len(dataloader.dataset), 100. * batch_idx / len(dataloader),
                loss.item() / len(img),
                (endtime - starttime).seconds))

    print('====> Epoch: {} Average loss: {:.4f}'.format(
        epoch, train_loss / len(dataloader.dataset)))
    if epoch % 10 == 0:
        save = to_img(recon_batch.cpu().data)
        save_image(save, './vae_img/image_{}.png'.format(epoch))

torch.save(model.state_dict(), './vae.pth')
```
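Because the decoder was trained against a standard normal prior, new digits can be generated by sampling z from N(0, I) and decoding. A minimal sketch, assuming the trained `model`, the `to_img` helper, and the `./vae_img` directory from the script above:

```python
# Sample new digits from the prior of the trained VAE above.
import torch
from torchvision.utils import save_image

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.eval()

with torch.no_grad():
    z = torch.randn(64, 20, device=device)  # 64 samples from the 20-d prior
    samples = model.decode(z)                # (64, 784), in [0, 1] via sigmoid

save_image(to_img(samples.cpu()), './vae_img/samples_from_prior.png')
```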