5.2 使用pytorch搭建GoogLeNet网络 笔记
生活随笔
收集整理的這篇文章主要介紹了
5.2 使用pytorch搭建GoogLeNet网络 笔记
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
B站資源
csdn本家
文章目錄
- model
- train
- predict
model
1.BasicConv2d類
2.Inception
3.InceptionAux(nn.Module):#輔助分類器
4.GoogLeNet
其中上表格查的第一個參數,輸入:
train
"""Train GoogLeNet on the 5-class flower dataset.

Builds train/val ImageFolder datasets, trains with the main classifier
plus the two auxiliary classifiers (aux losses weighted by 0.3 as in the
GoogLeNet paper), validates each epoch, and keeps the best checkpoint.
"""
import torch
import torch.nn as nn
from torchvision import transforms, datasets
import torchvision
import json
import matplotlib.pyplot as plt
import os
import torch.optim as optim
from model import GoogLeNet

# Pick GPU when available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

data_transform = {
    "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                 transforms.RandomHorizontalFlip(),
                                 transforms.ToTensor(),
                                 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
    "val": transforms.Compose([transforms.Resize((224, 224)),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
image_path = os.path.join(data_root, "data_set", "flower_data")  # flower data set path

train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                     transform=data_transform["train"])
train_num = len(train_dataset)

# {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
flower_list = train_dataset.class_to_idx
cla_dict = dict((val, key) for key, val in flower_list.items())
# Write the index->class mapping to JSON so predict.py can translate
# predicted indices back into class names.
json_str = json.dumps(cla_dict, indent=4)
with open('class_indices.json', 'w') as json_file:
    json_file.write(json_str)

batch_size = 32
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=batch_size, shuffle=True,
                                           num_workers=0)

validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                        transform=data_transform["val"])
val_num = len(validate_dataset)
validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                              batch_size=batch_size, shuffle=False,
                                              num_workers=0)

# aux_logits=True: the two auxiliary classifiers contribute to the loss
# during training only (they are dropped in eval mode).
net = GoogLeNet(num_classes=5, aux_logits=True, init_weights=True)
net.to(device)
loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.0003)

best_acc = 0.0
save_path = './googleNet.pth'
# Resume from a previous checkpoint when one exists.
if os.path.exists(save_path):
    net.load_state_dict(torch.load(save_path))

for epoch in range(30):
    # ---- train ----
    net.train()  # enable training-mode behaviour (aux classifiers, dropout)
    running_loss = 0.0
    for step, data in enumerate(train_loader, start=0):
        images, labels = data
        optimizer.zero_grad()
        # Three outputs: main classifier + two auxiliary classifiers.
        logits, aux_logits2, aux_logits1 = net(images.to(device))
        loss0 = loss_function(logits, labels.to(device))
        loss1 = loss_function(aux_logits1, labels.to(device))
        loss2 = loss_function(aux_logits2, labels.to(device))
        # Auxiliary losses weighted by 0.3, the value given in the paper.
        loss = loss0 + loss1 * 0.3 + loss2 * 0.3
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        # Text progress bar for the current epoch.
        rate = (step + 1) / len(train_loader)
        a = "*" * int(rate * 50)
        b = "." * int((1 - rate) * 50)
        print("\rtrain loss: {:^3.0f}%[{}->{}]{:.3f}".format(int(rate * 100), a, b, loss), end="")
    print()

    # ---- validate ----
    net.eval()  # eval mode: the model returns only the main output
    acc = 0.0  # accumulated number of correct predictions this epoch
    with torch.no_grad():
        for val_data in validate_loader:
            val_images, val_labels = val_data
            outputs = net(val_images.to(device))
            predict_y = torch.max(outputs, dim=1)[1]
            acc += (predict_y == val_labels.to(device)).sum().item()
        val_accurate = acc / val_num
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)
        # FIX: average over the number of batches (step + 1); the original
        # divided by the last batch index `step`, an off-by-one.
        print('[epoch %d] train_loss: %.3f test_accuracy: %.3f' %
              (epoch + 1, running_loss / (step + 1), val_accurate))

print('Finished Training')

predict
"""Run single-image inference with a trained GoogLeNet checkpoint."""
import torch
from model import GoogLeNet
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
import json

# Preprocessing must match the validation transform used during training.
data_transform = transforms.Compose(
    [transforms.Resize((224, 224)),
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# load image
img = Image.open("../tulip.jpg")
plt.imshow(img)
# [N, C, H, W]
img = data_transform(img)
# expand batch dimension
img = torch.unsqueeze(img, dim=0)

# Read the class-index mapping written by train.py.
# FIX: use a context manager so the file handle is always closed
# (the original opened it and never closed it).
try:
    with open('./class_indices.json', 'r') as json_file:
        class_indict = json.load(json_file)
except Exception as e:
    print(e)
    exit(-1)

# Create the model WITHOUT auxiliary classifiers — they are only needed
# during training.
model = GoogLeNet(num_classes=5, aux_logits=False)
# strict=False: skip the aux-classifier weights present in the checkpoint
# but absent from this inference-only model.
model_weight_path = "./googleNet.pth"
missing_keys, unexpected_keys = model.load_state_dict(torch.load(model_weight_path),
                                                      strict=False)
model.eval()
with torch.no_grad():
    # predict class
    output = torch.squeeze(model(img))
    predict = torch.softmax(output, dim=0)
    predict_cla = torch.argmax(predict).numpy()
print(class_indict[str(predict_cla)])
plt.show()

總結
以上是生活随笔為你收集整理的5.2 使用pytorch搭建GoogLeNet网络 笔记的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: ACM取余
- 下一篇: 5.3 使用tensorflow搭建Go