PYTORCH:DenseNet做新冠肺炎CT照片是否确诊分类
生活随笔
收集整理的這篇文章主要介紹了
PYTORCH:DenseNet做新冠肺炎CT照片是否确诊分类
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
完整項目代碼:https://github.com/SPECTRELWF/pytorch-cnn-study
DenseNet網絡結構
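DenseNet 是黃高（現任教於清華大學）等人發表在 CVPR 2017 的工作，在 ResNet 提出後的第二年發表，並獲得了當年的 CVPR 最佳論文獎。其核心思想是：在每個稠密塊（dense block）內，每一層都把前面所有層輸出的特徵圖拼接（concatenate）起來作為輸入，從而加強特徵重用並緩解梯度消失。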
數據集描述
數據集使用的是來自格物鈦（Graviti）的一個公開數據集，下載地址：https://gas.graviti.cn/dataset/data-decorators/COVID_CT
其中共包含 715 張圖片，分為確診和未確診兩類，比例大約為一比一；圖像均為經過處理的 CT 圖像。
網絡結構
使用 PyTorch 的 torchvision 中提供的 densenet121()，未使用預訓練模型，並在其後再加上兩層全連接層（1000→512→2）：
# !/usr/bin/python3
# -*- coding:utf-8 -*-
# Author:WeiFeng Liu
# @Time: 2021/11/9 4:57 PM
import torchvision
import torch.nn as nn


class my_densenet(nn.Module):
    """DenseNet-121 backbone followed by a small fully connected head.

    The torchvision DenseNet-121 is built without pretrained weights and keeps
    its own 1000-way output layer; that 1000-dim output is fed through
    ``fc2`` (1000 -> 512) and ``fc3`` (512 -> ``num_classes``).

    Args:
        num_classes: Number of output logits. Defaults to 2, the binary
            covid / no_covid setup used by the rest of this project.
    """

    def __init__(self, num_classes=2):
        super().__init__()
        self.backbone = torchvision.models.densenet121(pretrained=False)
        # Nonlinearity between the head's linear layers. Without it the
        # backbone's final linear layer, fc2 and fc3 compose into one linear
        # map, so the extra layers add parameters but no modelling power.
        # ReLU has no parameters, so state_dict keys are unchanged and older
        # checkpoints still load -- but they were trained with a purely
        # linear head and should be retrained.
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(1000, 512)
        self.fc3 = nn.Linear(512, num_classes)

    def forward(self, x):
        """Return unnormalized class logits of shape (batch, num_classes).

        Assumes x is a (batch, 3, H, W) float image tensor -- the scripts in
        this project feed 224x224 RGB images; confirm for other callers.
        """
        x = self.backbone(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x


# train:
# !/usr/bin/python3
# -*- coding:utf-8 -*-
# Author:WeiFeng Liu
# @Time: 2021/11/9 4:48 PM
import os

import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
import torch.utils.data as data
from torch.utils.data import DataLoader
from dataload.COVID_Dataload import COVID
from densenet import my_densenet
from torch import nn, optim
# Imported up front so a missing module fails before training, not after it.
from utils import plot_curve

# Training preprocessing: resize to 224x224 and apply random horizontal flips
# as light augmentation. Named train_transform so it does not shadow the
# `transforms` module imported above.
train_transform = transforms.Compose([
    transforms.Resize([224, 224]),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

batch_size = 32
train_set = COVID(transformer=train_transform, train=True)
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters
epochs = 200
lr = 1e-4
save_dir = 'model/no_pretrain'

# .to(device) rather than .cuda(device): .cuda() fails on CPU-only machines
# even though `device` falls back to CPU above.
net = my_densenet().to(device)
loss_func = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9)
# torch.save does not create missing directories.
os.makedirs(save_dir, exist_ok=True)

train_loss = []  # per-batch loss history, plotted after training
for epoch in range(epochs):
    net.train()
    sum_loss = 0
    for batch_idx, (x, y) in enumerate(train_loader):
        x = x.to(device)
        y = y.to(device)
        pred = net(x)
        optimizer.zero_grad()
        loss = loss_func(pred, y)
        loss.backward()
        optimizer.step()
        sum_loss += loss.item()
        train_loss.append(loss.item())
        print("epoch:%d , batch:%d , loss:%.3f" % (epoch, batch_idx, loss.item()))
    # Report the mean batch loss of the epoch (sum_loss was previously unused).
    print("epoch:%d , mean loss:%.3f" % (epoch, sum_loss / max(len(train_loader), 1)))
    # One checkpoint per epoch: epoch1.pth ... epoch200.pth
    torch.save(net.state_dict(), os.path.join(save_dir, 'epoch' + str(epoch + 1) + '.pth'))

plot_curve(train_loss)
# 訓練loss (training loss curve)
test:
# !/usr/bin/python3
# -*- coding:utf-8 -*-
# Author:WeiFeng Liu
# @Time: 2021/11/4 1:29 PM
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from dataload.COVID_Dataload import COVID
from densenet import my_densenet

# Use the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Evaluation preprocessing must be deterministic: no RandomHorizontalFlip,
# otherwise the reported accuracy changes from run to run.
transform = transforms.Compose([
    transforms.Resize([224, 224]),
    transforms.ToTensor(),
])
test_dataset = COVID(train=False, transformer=transform)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)


def predict(checkpoint='model/no_pretrain/epoch200.pth'):
    """Evaluate a trained my_densenet checkpoint on the COVID test split.

    Prints and returns the top-1 accuracy (0.0 for an empty test set).

    Args:
        checkpoint: Path to a state_dict. The training script saves to
            model/no_pretrain/; the previous hard-coded model/pretrain/ path
            is never written by that script.
    """
    net = my_densenet().to(device)
    # map_location lets checkpoints saved on a GPU load on CPU-only machines.
    net.load_state_dict(torch.load(checkpoint, map_location=device))
    # eval(): BatchNorm uses its running statistics instead of per-batch ones.
    net.eval()

    total_correct = 0
    with torch.no_grad():  # inference only, no autograd graph needed
        for x, y in test_loader:
            x = x.to(device)
            y = y.to(device)
            pred = net(x).argmax(dim=1)
            total_correct += pred.eq(y).sum().item()

    total_num = len(test_loader.dataset)
    acc = total_correct / total_num if total_num else 0.0
    print("test acc:", acc)
    return acc


predict()
# predict
# !/usr/bin/python3
# -*- coding:utf-8 -*-
# Author:WeiFeng Liu
# @Time: 2021/11/4 2:38 PM
# Read an image file and display it with the predicted class.
import matplotlib.pyplot as plt
import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image
from densenet import my_densenet

# Deterministic preprocessing: no random flip at inference time.
transform = transforms.Compose([
    transforms.Resize([224, 224]),
    transforms.ToTensor(),
])

file_name = input("輸入要預測的文件名:")
show_img = Image.open(file_name).convert("RGB")  # kept untransformed for display

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# unsqueeze(0) adds the batch dimension expected by the network.
img = transform(show_img).unsqueeze(0).to(device)

net = my_densenet().to(device)
# map_location lets checkpoints saved on a GPU load on CPU-only machines.
net.load_state_dict(torch.load(r'model/no_pretrain/epoch200.pth', map_location=device))
# eval(): with a batch of one image, train-mode BatchNorm would normalize with
# that single image's statistics instead of the learned running statistics.
net.eval()

with torch.no_grad():
    pred = net(img)
print(pred)
label = pred.argmax(dim=1).item()  # computed once, reused below
print(label)

# Class index 0 is shown as no_covid, any other index as covid.
res = 'pred:no_covid' if label == 0 else 'pred:covid'

plt.figure("Predict")
plt.imshow(show_img)
plt.axis("off")
plt.title(res)
plt.show()
# 總結 (Summary)
以上是生活随笔為你收集整理的PYTORCH:DenseNet做新冠肺炎CT照片是否确诊分类的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: pytorch:ResNet50做新冠肺
- 下一篇: python将灰度图转换为RGB彩色图