pytorch实现图像分类代码实例
生活随笔
收集整理的這篇文章主要介紹了
pytorch实现图像分类代码实例
小編覺得挺不錯的，現在分享給大家，幫大家做個參考。
圖像多類別（單標籤）分類例子
import os
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
import matplotlib.pyplot as plt
from matplotlib.ticker import MultipleLocator
# The next three imports come from the original article but are not used below.
from tensorboardX import SummaryWriter
import seaborn as sns
from sklearn.metrics import confusion_matrix

# ---- Configuration ---------------------------------------------------------

# Select the compute device: the GPU when CUDA is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Preprocessing for the two splits ('train' and 'valid').
# The training split is augmented (random rotation, random resized crop,
# horizontal flip); the validation split is only resized and center-cropped.
# Both normalize with the usual ImageNet per-channel mean/std.
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomRotation(30),
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
    'valid': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
}

# Dataset root; must contain 'train/' and 'valid/' subdirectories laid out
# with one folder per class, as torchvision's ImageFolder expects.
data_dir = "/DATA/wanghongzhi/17flowers"
batch_size = 8


# ---- Network definitions ---------------------------------------------------

class Net(nn.Module):
    """Wrap a backbone's ``features`` module with a new fully connected head.

    Not used by the training run below, which fine-tunes ResNet-50 instead.

    Args:
        model: backbone exposing a ``features`` submodule. The first Linear
            layer expects 25088 (= 512 * 7 * 7) flattened features, which
            presumably means a VGG-style backbone fed 224x224 images --
            TODO confirm before using another backbone.
        num_classes: size of the output layer (102 in the original article).
    """

    def __init__(self, model, num_classes=102):
        super().__init__()
        self.features = model.features
        # To train only the new head, freeze the backbone here, e.g.:
        #     for p in self.features.parameters():
        #         p.requires_grad = False
        self.classifier = nn.Sequential(
            nn.Linear(25088, 4096, bias=True),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5, inplace=False),
            nn.Linear(4096, 4096, bias=True),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5, inplace=False),
            nn.Linear(4096, num_classes, bias=True),
        )

    def forward(self, x):
        x = self.features(x)
        # Flatten everything except the batch dimension for the linear head.
        x = x.view(x.shape[0], -1)
        return self.classifier(x)


# ---- Helpers ---------------------------------------------------------------

def hua_loss(loss, save_path=None):
    """Plot a per-epoch training-loss curve and show it.

    Args:
        loss: sequence of loss values, one per epoch (plotted at epochs 1..N).
        save_path: optional file path; when given, the figure is also saved
            there before being shown.
    """
    x = range(1, len(loss) + 1)
    plt.figure(figsize=(20, 8), dpi=80)  # figsize in inches; dpi sets resolution
    plt.title("Train-Epoch-Loss", fontsize=25)
    plt.xlabel("Epoch", fontsize=20)
    plt.ylabel("Loss", fontsize=20)
    plt.plot(x, loss)
    ax = plt.gca()  # current Axes (holds both the x and y axis)
    ax.xaxis.set_major_locator(MultipleLocator(2))     # x ticks every 2 epochs
    ax.yaxis.set_major_locator(MultipleLocator(0.15))  # y ticks every 0.15
    if save_path is not None:
        plt.savefig(save_path)
    plt.show()


def valid_model(model, criterion):
    """Evaluate ``model`` on the validation loader and print loss/accuracy.

    Reads the module-level ``validdataloader``, ``dataset_sizes`` and
    ``device`` set up in the training run below.

    Returns:
        (epoch_loss, epoch_acc): mean per-sample loss and accuracy over the
        whole validation split. (The original returned None; callers that
        ignored the result are unaffected.)
    """
    print('-' * 10)
    running_loss = 0.0
    running_corrects = 0
    model = model.to(device)
    model.eval()  # disable dropout / use running batch-norm statistics
    with torch.no_grad():
        for inputs, labels in validdataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            _, preds = torch.max(outputs, 1)
            # Assumes criterion returns the batch mean (reduction='mean', the
            # CrossEntropyLoss default): weight by batch size so dividing by
            # the dataset size below gives a true per-sample average.
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels).item()
    epoch_loss = running_loss / dataset_sizes['valid']
    epoch_acc = running_corrects / dataset_sizes['valid']
    print('{} Loss: {:.4f} Acc: {:.4f}'.format('valid', epoch_loss, epoch_acc))
    print('-' * 10)
    print()
    return epoch_loss, epoch_acc


def train_model(model, criterion, optimizer, num_epochs=5, valid_every=5):
    """Train ``model`` for ``num_epochs`` epochs, validating periodically.

    Reads the module-level ``traindataloader``, ``dataset_sizes`` and
    ``device`` set up in the training run below. Plots the training-loss
    curve at the end and prints the best validation accuracy seen.

    Args:
        model: network to train (moved to ``device``).
        criterion: loss function, assumed mean-reduced over the batch.
        optimizer: optimizer over ``model``'s trainable parameters.
        num_epochs: number of passes over the training set.
        valid_every: validate after every ``valid_every``-th epoch (must be
            positive); the final epoch is always validated.

    Returns:
        The trained model.
    """
    best_acc = 0.0
    train_loss = []
    model = model.to(device)
    for epoch in range(num_epochs):
        print('-' * 10)
        print('Epoch {}/{}'.format(epoch + 1, num_epochs))
        # Re-enter training mode every epoch: validation switches to eval mode.
        model.train()
        running_loss = 0.0
        running_corrects = 0
        for inputs, labels in traindataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            _, preds = torch.max(outputs, 1)
            # Weight the batch-mean loss by batch size (see valid_model).
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels).item()
        epoch_loss = running_loss / dataset_sizes['train']
        epoch_acc = running_corrects / dataset_sizes['train']
        print('{} Loss: {:.4f} Acc: {:.4f}'.format('train', epoch_loss, epoch_acc))
        print()
        train_loss.append(epoch_loss)
        # Validate *after* this epoch's training. The original validated
        # before training epochs 5, 10, ..., so the last epoch was never
        # evaluated, and it reported the best *training* accuracy as
        # "Best val Acc".
        if (epoch + 1) % valid_every == 0 or epoch + 1 == num_epochs:
            _, val_acc = valid_model(model, criterion)
            best_acc = max(best_acc, val_acc)
    hua_loss(train_loss)
    print('Best val Acc: {:.4f}'.format(best_acc))
    return model


# ---- Training run ----------------------------------------------------------
# Everything below has side effects (reads the dataset, loads weights, trains).
# It is guarded because DataLoader worker processes re-import this module under
# the "spawn" start method (the default on Windows and macOS), and re-running
# it there fails or repeats the expensive setup.
if __name__ == "__main__":
    image_datasets = {
        x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
        for x in ['train', 'valid']
    }
    traindataset = image_datasets['train']
    validdataset = image_datasets['valid']
    # ImageFolder assigns label indices from each split's sorted class-folder
    # names, so both splits must have the same class folders for labels to agree.
    if traindataset.classes != validdataset.classes:
        raise ValueError('train and valid splits have different class folders')

    dataloaders = {
        x: torch.utils.data.DataLoader(
            image_datasets[x],
            batch_size=batch_size,
            shuffle=(x == 'train'),  # evaluation does not need shuffling
            num_workers=4,
        )
        for x in ['train', 'valid']
    }
    traindataloader = dataloaders['train']
    validdataloader = dataloaders['valid']
    dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'valid']}

    net = models.resnet50()
    # Load the local checkpoint into the stock 1000-way ResNet-50 (presumably
    # ImageNet-pretrained weights -- confirm). map_location lets a checkpoint
    # saved from a GPU load on a CPU-only machine. torch.load unpickles the
    # file, so only point it at a trusted checkpoint.
    net.load_state_dict(
        torch.load('/home/wanghongzhi/zuoye/resnet50.pth', map_location=device))
    # Replace the 1000-way head with one sized to this dataset's classes. This
    # is done after loading so the checkpoint's fc weights still match shape.
    net.fc = nn.Linear(net.fc.in_features, len(traindataset.classes))
    net = net.to(device)

    criterion = nn.CrossEntropyLoss()
    # Built after the head swap so the new fc parameters are optimized too.
    optimizer = torch.optim.SGD(
        filter(lambda p: p.requires_grad, net.parameters()),
        lr=0.0001,
        momentum=0.9,
    )

    epochs = 5
    model = train_model(net, criterion, optimizer, epochs)

# Reference: blog post "Pytorch實現鮮花分類 (102 Category Flower Dataset)".
總結
以上是生活随笔為你收集整理的「pytorch实现图像分类代码实例」的全部內容，希望文章能夠幫你解決所遇到的問題。
- 上一篇: linux的挂载点是什么意思(linux
- 下一篇: ddos攻击内网(ddos攻击内网ip)