【Pytorch神经网络实战案例】16 条件WGAN模型生成可控Fashon-MNST模拟数据
生活随笔
收集整理的這篇文章主要介紹了
【Pytorch神经网络实战案例】16 条件WGAN模型生成可控Fashon-MNST模拟数据
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
1 條件GAN前置知識
條件GAN也可以使GAN所生成的數據可控,使模型變得實用,
1.1 實驗描述
搭建條件GAN模型,實現向模型中輸入標簽,并使其生成與標簽類別對應的模擬數據的功能,基于WGAN-gp模型改造實現帶有條件的wGAN-gp模型。
2?實例代碼編寫
條件GAN與條件自編碼神經網絡的做法幾乎一樣,在GAN的基礎之上,為每個模型輸入都添加一個標簽向量。
2.1 代碼實戰:引入模塊并載入樣本----WGAN_cond_237.py(第1部分)
import torch import torchvision from torchvision import transforms from torch.utils.data import DataLoader from torch import nn import torch.autograd as autograd import matplotlib.pyplot as plt import numpy as np import matplotlib import os os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"# 1.1 引入模塊并載入樣本:定義基本函數,加載FashionMNIST數據集 def to_img(x):x = 0.5 * (x+1)x = x.clamp(0,1)x = x.view(x.size(0),1,28,28)return xdef imshow(img,filename = None):npimg = img.numpy()plt.axis('off')array = np.transpose(npimg,(1,2,0))if filename != None:matplotlib.image.imsave(filename,array)else:plt.imshow(array)# plt.savefig(filename) # 保存圖片 注釋掉,因為會報錯,暫時不知道什么原因 2022.3.26 15:20plt.show()img_transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=[0.5],std=[0.5])] )data_dir = './fashion_mnist'train_dataset = torchvision.datasets.FashionMNIST(data_dir,train=True,transform=img_transform,download=True) train_loader = DataLoader(train_dataset,batch_size=1024,shuffle=True) # 測試數據集 val_dataset = torchvision.datasets.FashionMNIST(data_dir,train=False,transform=img_transform) test_loader = DataLoader(val_dataset,batch_size=10,shuffle=False) # 指定設備 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") print(device)2.2 代碼實戰:實現生成器和判別器----WGAN_cond_237.py(第2部分)
# 1.2 實現生成器和判別器 :因為復雜部分都放在loss值的計算方面了,所以生成器和判別器就會簡單一些。 # 生成器和判別器各自有兩個卷積和兩個全連接層。生成器最終輸出與輸入圖片相同維度的數據作為模擬樣本。 # 判別器的輸出不需要有激活函數,并且輸出維度為1的數值用來表示結果。 # 在GAN模型中,因判別器的輸入則是具體的樣本數據,要區分每個數據的分布特征,所以判別器使用實例歸一化, class WGAN_D(nn.Module): # 定義判別器類D :有兩個卷積和兩個全連接層def __init__(self,inputch=1):super(WGAN_D, self).__init__()self.conv1 = nn.Sequential(nn.Conv2d(inputch,64,4,2,1), # 輸出形狀為[batch,64,28,28]nn.LeakyReLU(0.2,True),nn.InstanceNorm2d(64,affine=True))self.conv2 = nn.Sequential(nn.Conv2d(64,128,4,2,1),# 輸出形狀為[batch,64,14,14]nn.LeakyReLU(0.2,True),nn.InstanceNorm2d(128,affine=True))self.fc = nn.Sequential(nn.Linear(128*7*7,1024),nn.LeakyReLU(0.2,True))self.fc2 = nn.Sequential(nn.InstanceNorm1d(1,affine=True),nn.Flatten(),nn.Linear(1024,1))def forward(self,x,*arg): # 正向傳播x = self.conv1(x)x = self.conv2(x)x = x.view(x.size(0),-1)x = self.fc(x)x = x.reshape(x.size(0),1,-1)x = self.fc2(x)return x.view(-1,1).squeeze(1)# 在GAN模型中,因生成器的初始輸入是隨機值,所以生成器使用批量歸一化。 class WGAN_G(nn.Module): # 定義生成器類G:有兩個卷積和兩個全連接層def __init__(self,input_size,input_n=1):super(WGAN_G, self).__init__()self.fc1 = nn.Sequential(nn.Linear(input_size * input_n,1024),nn.ReLU(True),nn.BatchNorm1d(1024))self.fc2 = nn.Sequential(nn.Linear(1024,7*7*128),nn.ReLU(True),nn.BatchNorm1d(7*7*128))self.upsample1 = nn.Sequential(nn.ConvTranspose2d(128,64,4,2,padding=1,bias=False), # 輸出形狀為[batch,64,14,14]nn.ReLU(True),nn.BatchNorm2d(64))self.upsample2 = nn.Sequential(nn.ConvTranspose2d(64,1,4,2,padding=1,bias=False), # 輸出形狀為[batch,64,28,28]nn.Tanh())def forward(self,x,*arg): # 正向傳播x = self.fc1(x)x = self.fc2(x)x = x.view(x.size(0),128,7,7)x = self.upsample1(x)img = self.upsample2(x)return img2.3?代碼實戰:定義函數完成梯度懲罰項----WGAN_cond_237.py(第3部分)
# 1.3 定義函數compute_gradient_penalty()完成梯度懲罰項 # 懲罰項的樣本X_inter由一部分Pg分布和一部分Pr分布組成,同時對D(X_inter)求梯度,并計算梯度與1的平方差,最終得到gradient_penalties lambda_gp = 10 # 計算梯度懲罰項 def compute_gradient_penalty(D,real_samples,fake_samples,y_one_hot):# 獲取一個隨機數,作為真假樣本的采樣比例eps = torch.FloatTensor(real_samples.size(0),1,1,1).uniform_(0,1).to(device)# 按照eps比例生成真假樣本采樣值X_interX_inter = (eps * real_samples + ((1-eps)*fake_samples)).requires_grad_(True)d_interpolates = D(X_inter,y_one_hot)fake = torch.full((real_samples.size(0),),1,device=device) # 計算梯度輸出的掩碼,在本例中需要對所有梯度進行計算,故需要按照樣本個數生成全為1的張量。# 求梯度gradients = autograd.grad(outputs=d_interpolates, # 輸出值outputs,傳入計算過的張量結果inputs=X_inter,# 待求梯度的輸入值inputs,傳入可導的張量,即requires_grad=Truegrad_outputs=fake, # 傳出梯度的掩碼grad_outputs,使用1和0組成的掩碼,在計算梯度之后,會將求導結果與該掩碼進行相乘得到最終結果。create_graph=True,retain_graph=True,only_inputs=True)[0]gradients = gradients.view(gradients.size(0),-1)gradient_penaltys = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * lambda_gpreturn gradient_penaltys2.4?代碼實戰:定義模型的訓練函數----WGAN_cond_237.py(第4部分)
# 1.4 定義模型的訓練函數 # 定義函數train(),實現模型的訓練過程。 # 在函數train()中,按照對抗神經網絡專題(一)中的式(8-24)實現模型的損失函數。 # 判別器的loss為D(fake_samples)-D(real_samples)再加上聯合分布樣本的梯度懲罰項gradient_penalties,其中fake_samples為生成的模擬數據,real_Samples為真實數據, # 生成器的loss為-D(fake_samples)。 def train(D,G,outdir,z_dimension,num_epochs=30):d_optimizer = torch.optim.Adam(D.parameters(),lr=0.001) # 定義優化器g_optimizer = torch.optim.Adam(G.parameters(),lr=0.001)os.makedirs(outdir,exist_ok=True) # 創建輸出文件夾# 在函數train()中,判別器和生成器是分開訓練的。讓判別器學習的次數多一些,判別器每訓練5次,生成器優化1次。# WGAN_gp不會因為判別器準確率太高而引起生成器梯度消失的問題,所以好的判別器會讓生成器有更好的模擬效果。for epoch in range(num_epochs):for i,(img,lab) in enumerate(train_loader):num_img = img.size(0)# 訓練判別器real_img = img.to(device)y_one_hot = torch.zeros(lab.shape[0],10).scatter_(1,lab.view(lab.shape[0],1),1).to(device)for ii in range(5): # 循環訓練5次d_optimizer.zero_grad() # 梯度清零# 對real_img進行判別real_out = D(real_img,y_one_hot)# 生成隨機值z = torch.randn(num_img,z_dimension).to(device)fake_img = G(z,y_one_hot) # 生成fake_imgfake_out = D(fake_img,y_one_hot) # 對fake_img進行判別# 計算梯度懲罰項gradient_penalty = compute_gradient_penalty(D,real_img.data,fake_img.data,y_one_hot)# 計算判別器的lossd_loss = -torch.mean(real_out)+torch.mean(fake_out)+gradient_penaltyd_loss.backward()d_optimizer.step()# 訓練生成器for ii in range(1): # 訓練一次g_optimizer.zero_grad() # 梯度清0z = torch.randn(num_img,z_dimension).to(device)fake_img = G(z,y_one_hot)fake_out = D(fake_img,y_one_hot)g_loss = -torch.mean(fake_out)g_loss.backward()g_optimizer.step()# 輸出可視化結果,并將生成的結果以圖片的形式存儲在硬盤中fake_images = to_img(fake_img.cpu().data)real_images = to_img(real_img.cpu().data)rel = torch.cat([to_img(real_images[:10]), fake_images[:10]], axis=0)imshow(torchvision.utils.make_grid(rel, nrow=10),os.path.join(outdir, 'fake_images-{}.png'.format(epoch + 1)))# 輸出訓練結果print('Epoch [{}/{}], d_loss: {:.6f}, g_loss: {:.6f} ''D real: {:.6f}, D fake: {:.6f}'.format(epoch, num_epochs, d_loss.data, g_loss.data,real_out.data.mean(), fake_out.data.mean()))# 保存訓練模型torch.save(G.state_dict(), os.path.join(outdir, 'generator.pth'))torch.save(D.state_dict(), os.path.join(outdir, 'discriminator.pth'))2.5?代碼實戰:現可視化模型結果----WGAN_cond_237.py(第5部分)
# 1.5 定義函數,實現可視化模型結果:獲取一部分測試數據,顯示由模型生成的模擬數據。 def displayAndTest(D,G,z_dimension): # 可視化結果sample = iter(test_loader)images, labels = sample.next()y_one_hot = torch.zeros(labels.shape[0], 10).scatter_(1,labels.view(labels.shape[0], 1), 1).to(device)num_img = images.size(0) # 獲取樣本個數with torch.no_grad():z = torch.randn(num_img, z_dimension).to(device) # 生成隨機數fake_img = G(z, y_one_hot)fake_images = to_img(fake_img.cpu().data) # 生成模擬樣本rel = torch.cat([to_img(images[:10]), fake_images[:10]], axis=0)imshow(torchvision.utils.make_grid(rel, nrow=10))print(labels[:10])2.6?定義判別器類CondWGAN_D----WGAN_cond_237.py
(第6部分)
# 1.6 定義判別器類CondWGAN_D # 在判別器和生成器類的正向結構中,增加標簽向量的輸入,并使用全連接網絡對標簽向量的維度進行擴展,同時將其連接到輸入數據。 class CondWGAN_D(WGAN_D): # 定義判別器類CondWGAN_D,使其繼承自WGAN_D類。def __init__(self, inputch=2):super(CondWGAN_D, self).__init__(inputch)self.labfc1 = nn.Linear(10, 28 * 28)def forward(self, x, lab): # 添加輸入標簽,batch, width, height, channel=1d_in = torch.cat((x.view(x.size(0), -1), self.labfc1(lab)), -1)x = d_in.view(d_in.size(0), 2, 28, 28)return super(CondWGAN_D, self).forward(x, lab)2.7?定義生成器類CondWGAN_G----WGAN_cond_237.py(第7部分)
# 1.7 定義生成器類CondWGAN_G # 在判別器和生成器類的正向結構中,增加標簽向量的輸入,并使用全連接網絡對標簽向量的維度進行擴展,同時將其連接到輸入數據。 class CondWGAN_G(WGAN_G): # 定義生成器類CondWGAN_G,使其繼承自WGAN_G類。def __init__(self, input_size, input_n=2):super(CondWGAN_G, self).__init__(input_size, input_n)self.labfc1 = nn.Linear(10, input_size)def forward(self, x, lab): # 添加輸入標簽,batch, width, height, channel=1d_in = torch.cat((x, self.labfc1(lab)), -1)return super(CondWGAN_G, self).forward(d_in, lab)2.8?調用函數并訓練模型----WGAN_cond_237.py(第6部分)
# 1.8 調用函數并訓練模型:實例化判別器和生成器模型,并調用函數進行訓練 if __name__ == '__main__':z_dimension = 40 # 設置輸入隨機數的維度D = CondWGAN_D().to(device) # 實例化判別器G = CondWGAN_G(z_dimension).to(device) # 實例化生成器train(D, G, './condw_img', z_dimension) # 訓練模型displayAndTest(D, G, z_dimension) # 輸出可視化在訓練之后,模型輸出了可視化結果,如圖所示,第1行是原始樣本,第2行是輸出的模擬樣本。
同時,程序也輸出了圖8-20中樣本對應的類標簽,如下:
? ? tensor([9,2,1,1,6,1,4,6,5,7])
從輸出的樣本中可以看到,輸出的模擬樣本與原始樣本的類別一致,這表明生成器可以按照指定的標簽生成模擬數據。
?3??代碼匯總(WGAN_cond_237.py)
import torch import torchvision from torchvision import transforms from torch.utils.data import DataLoader from torch import nn import torch.autograd as autograd import matplotlib.pyplot as plt import numpy as np import matplotlib import os os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"# 1.1 引入模塊并載入樣本:定義基本函數,加載FashionMNIST數據集 def to_img(x):x = 0.5 * (x+1)x = x.clamp(0,1)x = x.view(x.size(0),1,28,28)return xdef imshow(img,filename = None):npimg = img.numpy()plt.axis('off')array = np.transpose(npimg,(1,2,0))if filename != None:matplotlib.image.imsave(filename,array)else:plt.imshow(array)# plt.savefig(filename) # 保存圖片 注釋掉,因為會報錯,暫時不知道什么原因 2022.3.26 15:20plt.show()img_transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=[0.5],std=[0.5])] )data_dir = './fashion_mnist'train_dataset = torchvision.datasets.FashionMNIST(data_dir,train=True,transform=img_transform,download=True) train_loader = DataLoader(train_dataset,batch_size=1024,shuffle=True) # 測試數據集 val_dataset = torchvision.datasets.FashionMNIST(data_dir,train=False,transform=img_transform) test_loader = DataLoader(val_dataset,batch_size=10,shuffle=False) # 指定設備 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") print(device)# 1.2 實現生成器和判別器 :因為復雜部分都放在loss值的計算方面了,所以生成器和判別器就會簡單一些。 # 生成器和判別器各自有兩個卷積和兩個全連接層。生成器最終輸出與輸入圖片相同維度的數據作為模擬樣本。 # 判別器的輸出不需要有激活函數,并且輸出維度為1的數值用來表示結果。 # 在GAN模型中,因判別器的輸入則是具體的樣本數據,要區分每個數據的分布特征,所以判別器使用實例歸一化, class WGAN_D(nn.Module): # 定義判別器類D :有兩個卷積和兩個全連接層def __init__(self,inputch=1):super(WGAN_D, self).__init__()self.conv1 = nn.Sequential(nn.Conv2d(inputch,64,4,2,1), # 輸出形狀為[batch,64,28,28]nn.LeakyReLU(0.2,True),nn.InstanceNorm2d(64,affine=True))self.conv2 = nn.Sequential(nn.Conv2d(64,128,4,2,1),# 輸出形狀為[batch,64,14,14]nn.LeakyReLU(0.2,True),nn.InstanceNorm2d(128,affine=True))self.fc = nn.Sequential(nn.Linear(128*7*7,1024),nn.LeakyReLU(0.2,True))self.fc2 = nn.Sequential(nn.InstanceNorm1d(1,affine=True),nn.Flatten(),nn.Linear(1024,1))def forward(self,x,*arg): # 正向傳播x = self.conv1(x)x = self.conv2(x)x = x.view(x.size(0),-1)x = self.fc(x)x = x.reshape(x.size(0),1,-1)x = self.fc2(x)return x.view(-1,1).squeeze(1)# 在GAN模型中,因生成器的初始輸入是隨機值,所以生成器使用批量歸一化。 class WGAN_G(nn.Module): # 定義生成器類G:有兩個卷積和兩個全連接層def __init__(self,input_size,input_n=1):super(WGAN_G, self).__init__()self.fc1 = nn.Sequential(nn.Linear(input_size * input_n,1024),nn.ReLU(True),nn.BatchNorm1d(1024))self.fc2 = nn.Sequential(nn.Linear(1024,7*7*128),nn.ReLU(True),nn.BatchNorm1d(7*7*128))self.upsample1 = nn.Sequential(nn.ConvTranspose2d(128,64,4,2,padding=1,bias=False), # 輸出形狀為[batch,64,14,14]nn.ReLU(True),nn.BatchNorm2d(64))self.upsample2 = nn.Sequential(nn.ConvTranspose2d(64,1,4,2,padding=1,bias=False), # 輸出形狀為[batch,64,28,28]nn.Tanh())def forward(self,x,*arg): # 正向傳播x = self.fc1(x)x = self.fc2(x)x = x.view(x.size(0),128,7,7)x = self.upsample1(x)img = self.upsample2(x)return img# 1.3 定義函數compute_gradient_penalty()完成梯度懲罰項 # 懲罰項的樣本X_inter由一部分Pg分布和一部分Pr分布組成,同時對D(X_inter)求梯度,并計算梯度與1的平方差,最終得到gradient_penalties lambda_gp = 10 # 計算梯度懲罰項 def compute_gradient_penalty(D,real_samples,fake_samples,y_one_hot):# 獲取一個隨機數,作為真假樣本的采樣比例eps = torch.FloatTensor(real_samples.size(0),1,1,1).uniform_(0,1).to(device)# 按照eps比例生成真假樣本采樣值X_interX_inter = (eps * real_samples + ((1-eps)*fake_samples)).requires_grad_(True)d_interpolates = D(X_inter,y_one_hot)fake = torch.full((real_samples.size(0),),1,device=device) # 計算梯度輸出的掩碼,在本例中需要對所有梯度進行計算,故需要按照樣本個數生成全為1的張量。# 求梯度gradients = autograd.grad(outputs=d_interpolates, # 輸出值outputs,傳入計算過的張量結果inputs=X_inter,# 待求梯度的輸入值inputs,傳入可導的張量,即requires_grad=Truegrad_outputs=fake, # 傳出梯度的掩碼grad_outputs,使用1和0組成的掩碼,在計算梯度之后,會將求導結果與該掩碼進行相乘得到最終結果。create_graph=True,retain_graph=True,only_inputs=True)[0]gradients = gradients.view(gradients.size(0),-1)gradient_penaltys = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * lambda_gpreturn gradient_penaltys# 1.4 定義模型的訓練函數 # 定義函數train(),實現模型的訓練過程。 # 在函數train()中,按照對抗神經網絡專題(一)中的式(8-24)實現模型的損失函數。 # 判別器的loss為D(fake_samples)-D(real_samples)再加上聯合分布樣本的梯度懲罰項gradient_penalties,其中fake_samples為生成的模擬數據,real_Samples為真實數據, # 生成器的loss為-D(fake_samples)。 def train(D,G,outdir,z_dimension,num_epochs=30):d_optimizer = torch.optim.Adam(D.parameters(),lr=0.001) # 定義優化器g_optimizer = torch.optim.Adam(G.parameters(),lr=0.001)os.makedirs(outdir,exist_ok=True) # 創建輸出文件夾# 在函數train()中,判別器和生成器是分開訓練的。讓判別器學習的次數多一些,判別器每訓練5次,生成器優化1次。# WGAN_gp不會因為判別器準確率太高而引起生成器梯度消失的問題,所以好的判別器會讓生成器有更好的模擬效果。for epoch in range(num_epochs):for i,(img,lab) in enumerate(train_loader):num_img = img.size(0)# 訓練判別器real_img = img.to(device)y_one_hot = torch.zeros(lab.shape[0],10).scatter_(1,lab.view(lab.shape[0],1),1).to(device)for ii in range(5): # 循環訓練5次d_optimizer.zero_grad() # 梯度清零# 對real_img進行判別real_out = D(real_img,y_one_hot)# 生成隨機值z = torch.randn(num_img,z_dimension).to(device)fake_img = G(z,y_one_hot) # 生成fake_imgfake_out = D(fake_img,y_one_hot) # 對fake_img進行判別# 計算梯度懲罰項gradient_penalty = compute_gradient_penalty(D,real_img.data,fake_img.data,y_one_hot)# 計算判別器的lossd_loss = -torch.mean(real_out)+torch.mean(fake_out)+gradient_penaltyd_loss.backward()d_optimizer.step()# 訓練生成器for ii in range(1): # 訓練一次g_optimizer.zero_grad() # 梯度清0z = torch.randn(num_img,z_dimension).to(device)fake_img = G(z,y_one_hot)fake_out = D(fake_img,y_one_hot)g_loss = -torch.mean(fake_out)g_loss.backward()g_optimizer.step()# 輸出可視化結果,并將生成的結果以圖片的形式存儲在硬盤中fake_images = to_img(fake_img.cpu().data)real_images = to_img(real_img.cpu().data)rel = torch.cat([to_img(real_images[:10]), fake_images[:10]], axis=0)imshow(torchvision.utils.make_grid(rel, nrow=10),os.path.join(outdir, 'fake_images-{}.png'.format(epoch + 1)))# 輸出訓練結果print('Epoch [{}/{}], d_loss: {:.6f}, g_loss: {:.6f} ''D real: {:.6f}, D fake: {:.6f}'.format(epoch, num_epochs, d_loss.data, g_loss.data,real_out.data.mean(), fake_out.data.mean()))# 保存訓練模型torch.save(G.state_dict(), os.path.join(outdir, 'cond_generator.pth'))torch.save(D.state_dict(), os.path.join(outdir, 'cond_discriminator.pth'))# 1.5 定義函數,實現可視化模型結果:獲取一部分測試數據,顯示由模型生成的模擬數據。 def displayAndTest(D,G,z_dimension): # 可視化結果sample = iter(test_loader)images, labels = sample.next()y_one_hot = torch.zeros(labels.shape[0], 10).scatter_(1,labels.view(labels.shape[0], 1), 1).to(device)num_img = images.size(0) # 獲取樣本個數with torch.no_grad():z = torch.randn(num_img, z_dimension).to(device) # 生成隨機數fake_img = G(z, y_one_hot)fake_images = to_img(fake_img.cpu().data) # 生成模擬樣本rel = torch.cat([to_img(images[:10]), fake_images[:10]], axis=0)imshow(torchvision.utils.make_grid(rel, nrow=10))print(labels[:10])# 1.6 定義判別器類CondWGAN_D # 在判別器和生成器類的正向結構中,增加標簽向量的輸入,并使用全連接網絡對標簽向量的維度進行擴展,同時將其連接到輸入數據。 class CondWGAN_D(WGAN_D): # 定義判別器類CondWGAN_D,使其繼承自WGAN_D類。def __init__(self, inputch=2):super(CondWGAN_D, self).__init__(inputch)self.labfc1 = nn.Linear(10, 28 * 28)def forward(self, x, lab): # 添加輸入標簽,batch, width, height, channel=1d_in = torch.cat((x.view(x.size(0), -1), self.labfc1(lab)), -1)x = d_in.view(d_in.size(0), 2, 28, 28)return super(CondWGAN_D, self).forward(x, lab)# 1.7 定義生成器類CondWGAN_G # 在判別器和生成器類的正向結構中,增加標簽向量的輸入,并使用全連接網絡對標簽向量的維度進行擴展,同時將其連接到輸入數據。 class CondWGAN_G(WGAN_G): # 定義生成器類CondWGAN_G,使其繼承自WGAN_G類。def __init__(self, input_size, input_n=2):super(CondWGAN_G, self).__init__(input_size, input_n)self.labfc1 = nn.Linear(10, input_size)def forward(self, x, lab): # 添加輸入標簽,batch, width, height, channel=1d_in = torch.cat((x, self.labfc1(lab)), -1)return super(CondWGAN_G, self).forward(d_in, lab)# 1.8 調用函數并訓練模型:實例化判別器和生成器模型,并調用函數進行訓練 if __name__ == '__main__':z_dimension = 40 # 設置輸入隨機數的維度D = CondWGAN_D().to(device) # 實例化判別器G = CondWGAN_G(z_dimension).to(device) # 實例化生成器train(D, G, './condw_img', z_dimension) # 訓練模型displayAndTest(D, G, z_dimension) # 輸出可視化總結
以上是生活随笔為你收集整理的【Pytorch神经网络实战案例】16 条件WGAN模型生成可控Fashon-MNST模拟数据的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 摩尔斯电码转换python编码_Mors
- 下一篇: 【Pytorch神经网络实战案例】29