A Code Walkthrough of Remote Sensing Image Segmentation Based on SegNet and UNet
Contents
- A code walkthrough of remote sensing image segmentation based on SegNet and UNet
- Preface
- Overview
- Code structure
- Detailed code analysis
- Splitting the dataset: gen_dataset.py
- UNet model training: unet_train.py
- Model fusion: combind.py
- UNet model prediction: unet_predict.py
- Ensembling classification results: ensemble.py
- SegNet model training: segnet_train.py
Preface
After a semester of coursework, I had some free time over the winter break, so I read through previous years' papers and the code of several competitions, and this post is the write-up. It focuses on hands-on practice, mainly analysing the code.
Overview
Let me start by sharing a small project: a remote sensing image segmentation competition entry based on SegNet and UNet. The code comes from GitHub; below is a brief walkthrough of the project.
Code structure
The project has four subdirectories: deprecated, ensemble, segnet and unet. deprecated contains some of the author's draft code; ensemble holds the code that ensembles the different classification results; segnet and unet each contain the corresponding network's architecture, training code, prediction code, and the code for splitting the training and test sets. A sketch of the layout follows.
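Reconstructed from that description (the tree is illustrative; only the four directory names and their roles are stated in the project):

```
.
├── deprecated/   # the author's draft code
├── ensemble/     # ensembling of different classification results
├── segnet/       # SegNet architecture, training, prediction, dataset split
└── unet/         # UNet architecture, training, prediction, dataset split
```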
Detailed code analysis
Splitting the dataset: gen_dataset.py
```python
import cv2
import random
import os
import numpy as np
from tqdm import tqdm

img_w = 256
img_h = 256

# the dataset consists of 5 images
image_sets = ['1.png', '2.png', '3.png', '4.png', '5.png']

def gamma_transform(img, gamma):
    gamma_table = [np.power(x / 255.0, gamma) * 255.0 for x in range(256)]
    gamma_table = np.round(np.array(gamma_table)).astype(np.uint8)
    # LUT: look-up table; remapping pixels through the table changes the
    # image's exposure and colour
    return cv2.LUT(img, gamma_table)

def random_gamma_transform(img, gamma_vari):
    log_gamma_vari = np.log(gamma_vari)
    alpha = np.random.uniform(-log_gamma_vari, log_gamma_vari)
    gamma = np.exp(alpha)
    return gamma_transform(img, gamma)

# rotate the image and its label together
def rotate(xb, yb, angle):
    M_rotate = cv2.getRotationMatrix2D((img_w / 2, img_h / 2), angle, 1)
    xb = cv2.warpAffine(xb, M_rotate, (img_w, img_h))
    yb = cv2.warpAffine(yb, M_rotate, (img_w, img_h))
    return xb, yb

def blur(img):
    # cv2.blur(img, (size, size)) smooths img with a size x size mean filter
    img = cv2.blur(img, (3, 3))
    return img

# add noise
def add_noise(img):
    for i in range(200):  # add 200 salt-noise pixels
        temp_x = np.random.randint(0, img.shape[0])
        temp_y = np.random.randint(0, img.shape[1])
        img[temp_x][temp_y] = 255
    return img

# data augmentation: rotation, gamma transform, blur and noise
def data_augment(xb, yb):
    if np.random.random() < 0.25:
        xb, yb = rotate(xb, yb, 90)
    if np.random.random() < 0.25:
        xb, yb = rotate(xb, yb, 180)
    if np.random.random() < 0.25:
        xb, yb = rotate(xb, yb, 270)
    if np.random.random() < 0.25:
        xb = cv2.flip(xb, 1)  # flipCode > 0: flip around the y-axis
        yb = cv2.flip(yb, 1)
    if np.random.random() < 0.25:
        xb = random_gamma_transform(xb, 1.0)
    if np.random.random() < 0.25:
        xb = blur(xb)
    if np.random.random() < 0.2:
        xb = add_noise(xb)
    return xb, yb

# build the dataset
def creat_dataset(image_num=50000, mode='original'):
    print('creating dataset...')
    image_each = image_num / len(image_sets)  # len(image_sets) = 5
    g_count = 0
    for i in tqdm(range(len(image_sets))):
        count = 0
        # read the source image and its label image
        src_img = cv2.imread('./data/src/' + image_sets[i])  # 3 channels
        label_img = cv2.imread('./data/road_label/' + image_sets[i],
                               cv2.IMREAD_GRAYSCALE)  # single channel
        X_height, X_width, _ = src_img.shape
        while count < image_each:
            # img_w = img_h = 256
            random_width = random.randint(0, X_width - img_w - 1)
            random_height = random.randint(0, X_height - img_h - 1)
            # crop a random img_h x img_w patch
            src_roi = src_img[random_height: random_height + img_h,
                              random_width: random_width + img_w, :]
            label_roi = label_img[random_height: random_height + img_h,
                                  random_width: random_width + img_w]
            # in augment mode, apply the augmentations to source and label
            if mode == 'augment':
                src_roi, label_roi = data_augment(src_roi, label_roi)
            # scale the labels up so they are visible in a preview image
            visualize = np.zeros((256, 256)).astype(np.uint8)
            visualize = label_roi * 50
            # write out the generated samples
            cv2.imwrite(('./unet_train/visualize/%d.png' % g_count), visualize)
            cv2.imwrite(('./unet_train/road/src/%d.png' % g_count), src_roi)
            cv2.imwrite(('./unet_train/road/label/%d.png' % g_count), label_roi)
            count += 1
            g_count += 1

if __name__ == '__main__':
    creat_dataset(mode='augment')
```
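Two details of the augmentation deserve a closer look. First, `cv2.LUT` simply remaps every pixel value through a 256-entry table, which is what changes the exposure. A minimal sketch of what `gamma_transform` does, using a synthetic gradient image (the test image and printed values are illustrative):

```python
import cv2
import numpy as np

# a horizontal gradient covering every grey value, tiled to 16 rows
img = np.tile(np.arange(256, dtype=np.uint8), (16, 1))
gamma = 0.5  # gamma < 1 lifts dark values, gamma > 1 darkens them
table = np.round(((np.arange(256) / 255.0) ** gamma) * 255.0).astype(np.uint8)
out = cv2.LUT(img, table)
print(img[0, 64], '->', out[0, 64])  # 64 -> 128: dark pixels are lifted
```

Second, note that `data_augment` calls `random_gamma_transform(xb, 1.0)`: with `gamma_vari = 1.0`, `log_gamma_vari` is 0, so the sampled gamma is always `exp(0) = 1` and the transform is effectively a no-op. A value such as 2.0 would give gammas in roughly [0.5, 2].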
UNet model training: unet_train.py

```python
#coding=utf-8
import matplotlib
# matplotlib.use('Agg') must come before `import matplotlib.pyplot as plt`:
# it selects a non-interactive backend, so figures are saved to disk rather
# than shown on screen
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import argparse
import numpy as np
from keras.models import Sequential, Model
from keras.layers import Conv2D, MaxPooling2D, UpSampling2D, BatchNormalization, Reshape, Permute, Activation, Input
from keras.layers.merge import concatenate
from keras.utils.np_utils import to_categorical
from keras.preprocessing.image import img_to_array
from keras.callbacks import ModelCheckpoint
from sklearn.preprocessing import LabelEncoder
from PIL import Image
import cv2
import random
import os
from tqdm import tqdm

os.environ["CUDA_VISIBLE_DEVICES"] = "4"

# fix the random seed so every run draws the same random numbers, making
# experiments on the same data comparable
seed = 7
np.random.seed(seed)

#data_shape = 360*480
img_w = 256
img_h = 256
# in the multi-class setting one class would be background
#n_label = 4+1
n_label = 1

# 5 classes in total
classes = [0., 1., 2., 3., 4.]
labelencoder = LabelEncoder()
labelencoder.fit(classes)

image_sets = ['1.png', '2.png', '3.png']

def load_img(path, grayscale=False):
    if grayscale:
        # cv2.IMREAD_GRAYSCALE reads a single-channel grayscale image;
        # otherwise cv2.imread returns a 3-channel colour image
        img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    else:
        img = cv2.imread(path)
    # normalise to [0, 1]
    img = np.array(img, dtype="float") / 255.0
    return img

# training data path
filepath = './unet_train/'

# split into training and validation sets; 25% of the data is validation
def get_train_val(val_rate=0.25):
    train_url = []
    train_set = []
    val_set = []
    for pic in os.listdir(filepath + 'src'):
        train_url.append(pic)
    random.shuffle(train_url)
    total_num = len(train_url)
    val_num = int(val_rate * total_num)
    # after shuffling, the first 25% becomes the validation set and the
    # remaining 75% the training set
    for i in range(len(train_url)):
        if i < val_num:
            val_set.append(train_url[i])
        else:
            train_set.append(train_url[i])
    return train_set, val_set

# data for training
def generateData(batch_size, data=[]):
    while True:
        train_data = []
        train_label = []
        batch = 0
        for i in range(len(data)):
            url = data[i]
            batch += 1
            img = load_img(filepath + 'src/' + url)
            img = img_to_array(img)
            train_data.append(img)
            label = load_img(filepath + 'label/' + url, grayscale=True)
            label = img_to_array(label)
            train_label.append(label)
            if batch % batch_size == 0:
                train_data = np.array(train_data)
                train_label = np.array(train_label)
                yield (train_data, train_label)
                train_data = []
                train_label = []
                batch = 0

# data for validation
def generateValidData(batch_size, data=[]):
    while True:
        valid_data = []
        valid_label = []
        batch = 0
        for i in range(len(data)):
            url = data[i]
            batch += 1
            img = load_img(filepath + 'src/' + url)
            img = img_to_array(img)
            valid_data.append(img)
            label = load_img(filepath + 'label/' + url, grayscale=True)
            label = img_to_array(label)
            valid_label.append(label)
            if batch % batch_size == 0:
                valid_data = np.array(valid_data)
                valid_label = np.array(valid_label)
                yield (valid_data, valid_label)
                valid_data = []
                valid_label = []
                batch = 0

# define the U-Net: overall a symmetric, U-shaped structure
def unet():
    inputs = Input((3, img_w, img_h))

    conv1 = Conv2D(32, (3, 3), activation="relu", padding="same")(inputs)
    conv1 = Conv2D(32, (3, 3), activation="relu", padding="same")(conv1)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)

    conv2 = Conv2D(64, (3, 3), activation="relu", padding="same")(pool1)
    conv2 = Conv2D(64, (3, 3), activation="relu", padding="same")(conv2)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)

    conv3 = Conv2D(128, (3, 3), activation="relu", padding="same")(pool2)
    conv3 = Conv2D(128, (3, 3), activation="relu", padding="same")(conv3)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)

    conv4 = Conv2D(256, (3, 3), activation="relu", padding="same")(pool3)
    conv4 = Conv2D(256, (3, 3), activation="relu", padding="same")(conv4)
    pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)

    conv5 = Conv2D(512, (3, 3), activation="relu", padding="same")(pool4)
    conv5 = Conv2D(512, (3, 3), activation="relu", padding="same")(conv5)

    # The decoder upsamples the feature maps by simple interpolation:
    # UpSampling2D(size=size)(x) repeats the rows and columns of x size[0]
    # and size[1] times respectively. With size=(2, 2), [[1,2],[3,4]]
    # becomes [[1,1,2,2],[1,1,2,2],[3,3,4,4],[3,3,4,4]].
    up6 = concatenate([UpSampling2D(size=(2, 2))(conv5), conv4], axis=1)
    conv6 = Conv2D(256, (3, 3), activation="relu", padding="same")(up6)
    conv6 = Conv2D(256, (3, 3), activation="relu", padding="same")(conv6)

    up7 = concatenate([UpSampling2D(size=(2, 2))(conv6), conv3], axis=1)
    conv7 = Conv2D(128, (3, 3), activation="relu", padding="same")(up7)
    conv7 = Conv2D(128, (3, 3), activation="relu", padding="same")(conv7)

    up8 = concatenate([UpSampling2D(size=(2, 2))(conv7), conv2], axis=1)
    conv8 = Conv2D(64, (3, 3), activation="relu", padding="same")(up8)
    conv8 = Conv2D(64, (3, 3), activation="relu", padding="same")(conv8)

    up9 = concatenate([UpSampling2D(size=(2, 2))(conv8), conv1], axis=1)
    conv9 = Conv2D(32, (3, 3), activation="relu", padding="same")(up9)
    conv9 = Conv2D(32, (3, 3), activation="relu", padding="same")(conv9)

    conv10 = Conv2D(n_label, (1, 1), activation="sigmoid")(conv9)
    #conv10 = Conv2D(n_label, (1, 1), activation="softmax")(conv9)
    model = Model(inputs=inputs, outputs=conv10)
    # binary cross-entropy for the two-class (road / background) problem;
    # categorical cross-entropy would also work, since the multi-class loss
    # covers the binary case
    model.compile(optimizer='Adam', loss='binary_crossentropy', metrics=['accuracy'])
    return model

def train(args):
    EPOCHS = 10
    BS = 16  # batch size
    #model = SegNet()
    model = unet()
    # keep only the weights with the best validation accuracy
    modelcheck = ModelCheckpoint(args['model'], monitor='val_acc',
                                 save_best_only=True, mode='max')
    callable = [modelcheck]
    train_set, val_set = get_train_val()
    train_numb = len(train_set)
    valid_numb = len(val_set)
    print("the number of train data is", train_numb)
    print("the number of val data is", valid_numb)
    # max_q_size sets the maximum size of the internal generator queue
    H = model.fit_generator(generator=generateData(BS, train_set),
                            steps_per_epoch=train_numb // BS,
                            epochs=EPOCHS, verbose=1,
                            validation_data=generateValidData(BS, val_set),
                            validation_steps=valid_numb // BS,
                            callbacks=callable, max_q_size=1)

    # plot the training loss and accuracy
    # plt.style.use('ggplot') applies the ggplot look; see
    # plt.style.available for the full list of built-in styles
    plt.style.use("ggplot")
    plt.figure()
    N = EPOCHS
    plt.plot(np.arange(0, N), H.history["loss"], label="train_loss")
    plt.plot(np.arange(0, N), H.history["val_loss"], label="val_loss")
    plt.plot(np.arange(0, N), H.history["acc"], label="train_acc")
    plt.plot(np.arange(0, N), H.history["val_acc"], label="val_acc")
    plt.title("Training Loss and Accuracy on U-Net Satellite Seg")
    plt.xlabel("Epoch #")
    plt.ylabel("Loss/Accuracy")
    # draw the legend in the lower-left corner
    plt.legend(loc="lower left")
    plt.savefig(args["plot"])

# command-line arguments, with help text and defaults
def args_parse():
    # construct the argument parser and parse the arguments
    ap = argparse.ArgumentParser()
    ap.add_argument("-d", "--data", help="training data's path", default=True)
    ap.add_argument("-m", "--model", required=True,
                    help="path to output model")
    ap.add_argument("-p", "--plot", type=str, default="plot.png",
                    help="path to output accuracy/loss plot")
    args = vars(ap.parse_args())
    return args

if __name__ == '__main__':
    args = args_parse()
    filepath = args['data']
    train(args)
    #predict()
```
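The comment on `UpSampling2D` inside `unet()` is easy to verify with plain numpy: upsampling with `size=(2, 2)` simply repeats every row and column twice (nearest-neighbour), with no learned parameters. A small sketch:

```python
import numpy as np

x = np.array([[1, 2],
              [3, 4]])
# what UpSampling2D(size=(2, 2)) does to each channel of a feature map
up = np.repeat(np.repeat(x, 2, axis=0), 2, axis=1)
print(up)
# [[1 1 2 2]
#  [1 1 2 2]
#  [3 3 4 4]
#  [3 3 4 4]]
```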
To see the shape of each layer's input and output tensors in the U-Net, the model summary is printed below:

```
Layer (type)                     Output Shape          Param #     Connected to
==================================================================================================
input_7 (InputLayer)             (None, 3, 256, 256)   0
conv2d_79 (Conv2D)               (None, 32, 256, 256)  896         input_7[0][0]
conv2d_80 (Conv2D)               (None, 32, 256, 256)  9248        conv2d_79[0][0]
max_pooling2d_29 (MaxPooling2D)  (None, 32, 128, 128)  0           conv2d_80[0][0]
conv2d_81 (Conv2D)               (None, 64, 128, 128)  18496       max_pooling2d_29[0][0]
conv2d_82 (Conv2D)               (None, 64, 128, 128)  36928       conv2d_81[0][0]
max_pooling2d_30 (MaxPooling2D)  (None, 64, 64, 64)    0           conv2d_82[0][0]
conv2d_83 (Conv2D)               (None, 128, 64, 64)   73856       max_pooling2d_30[0][0]
conv2d_84 (Conv2D)               (None, 128, 64, 64)   147584      conv2d_83[0][0]
max_pooling2d_31 (MaxPooling2D)  (None, 128, 32, 32)   0           conv2d_84[0][0]
conv2d_85 (Conv2D)               (None, 256, 32, 32)   295168      max_pooling2d_31[0][0]
conv2d_86 (Conv2D)               (None, 256, 32, 32)   590080      conv2d_85[0][0]
max_pooling2d_32 (MaxPooling2D)  (None, 256, 16, 16)   0           conv2d_86[0][0]
conv2d_87 (Conv2D)               (None, 512, 16, 16)   1180160     max_pooling2d_32[0][0]
conv2d_88 (Conv2D)               (None, 512, 16, 16)   2359808     conv2d_87[0][0]
up_sampling2d_13 (UpSampling2D)  (None, 512, 32, 32)   0           conv2d_88[0][0]
concatenate_13 (Concatenate)     (None, 768, 32, 32)   0           up_sampling2d_13[0][0]
                                                                   conv2d_86[0][0]
conv2d_89 (Conv2D)               (None, 256, 32, 32)   1769728     concatenate_13[0][0]
conv2d_90 (Conv2D)               (None, 256, 32, 32)   590080      conv2d_89[0][0]
up_sampling2d_14 (UpSampling2D)  (None, 256, 64, 64)   0           conv2d_90[0][0]
concatenate_14 (Concatenate)     (None, 384, 64, 64)   0           up_sampling2d_14[0][0]
                                                                   conv2d_84[0][0]
conv2d_91 (Conv2D)               (None, 128, 64, 64)   442496      concatenate_14[0][0]
conv2d_92 (Conv2D)               (None, 128, 64, 64)   147584      conv2d_91[0][0]
up_sampling2d_15 (UpSampling2D)  (None, 128, 128, 128) 0           conv2d_92[0][0]
concatenate_15 (Concatenate)     (None, 192, 128, 128) 0           up_sampling2d_15[0][0]
                                                                   conv2d_82[0][0]
conv2d_93 (Conv2D)               (None, 64, 128, 128)  110656      concatenate_15[0][0]
conv2d_94 (Conv2D)               (None, 64, 128, 128)  36928       conv2d_93[0][0]
up_sampling2d_16 (UpSampling2D)  (None, 64, 256, 256)  0           conv2d_94[0][0]
concatenate_16 (Concatenate)     (None, 96, 256, 256)  0           up_sampling2d_16[0][0]
                                                                   conv2d_80[0][0]
conv2d_95 (Conv2D)               (None, 32, 256, 256)  27680       concatenate_16[0][0]
conv2d_96 (Conv2D)               (None, 32, 256, 256)  9248        conv2d_95[0][0]
conv2d_97 (Conv2D)               (None, 1, 256, 256)   33          conv2d_96[0][0]
==================================================================================================
Total params: 7,846,657
Trainable params: 7,846,657
Non-trainable params: 0
```
Model fusion: combind.py

```python
#coding=utf-8
import numpy as np
import cv2
import csv
from tqdm import tqdm

# the three groups of single-class masks, one group per test image
mask1_pool = ['testing1_vegetation_predict.png', 'testing1_building_predict.png',
              'testing1_water_predict.png', 'testing1_road_predict.png']
mask2_pool = ['testing2_vegetation_predict.png', 'testing2_building_predict.png',
              'testing2_water_predict.png', 'testing2_road_predict.png']
mask3_pool = ['testing3_vegetation_predict.png', 'testing3_building_predict.png',
              'testing3_water_predict.png', 'testing3_road_predict.png']

## 0: none  1: vegetation  2: building  3: water  4: road

# output names after combining the masks
img_sets = ['pre1.png', 'pre2.png', 'pre3.png']

def combind_all_mask():
    for mask_num in tqdm(range(3)):
        # start from an all-zero (black) image the same size as the original
        if mask_num == 0:
            final_mask = np.zeros((5142, 5664), np.uint8)
        elif mask_num == 1:
            final_mask = np.zeros((2470, 4011), np.uint8)
        elif mask_num == 2:
            final_mask = np.zeros((6116, 3356), np.uint8)
        #final_mask = cv2.imread('final_1_8bits_predict.png', 0)
        if mask_num == 0:
            mask_pool = mask1_pool
        elif mask_num == 1:
            mask_pool = mask2_pool
        elif mask_num == 2:
            mask_pool = mask3_pool
        final_name = img_sets[mask_num]
        for idx, name in enumerate(mask_pool):
            img = cv2.imread('./predict_mask/' + name, 0)
            height, width = img.shape
            label_value = idx + 1  # corresponding label value
            for i in tqdm(range(height)):
                for j in range(width):
                    # mask fusion: a white (255) pixel belongs to this class's
                    # mask, but several masks may claim the same pixel, so
                    # conflicts are resolved by priority:
                    # building > water > road > vegetation
                    if img[i, j] == 255:
                        if label_value == 2:
                            final_mask[i, j] = label_value
                        elif label_value == 3 and final_mask[i, j] != 2:
                            final_mask[i, j] = label_value
                        elif label_value == 4 and final_mask[i, j] != 2 and final_mask[i, j] != 3:
                            final_mask[i, j] = label_value
                        elif label_value == 1 and final_mask[i, j] == 0:
                            final_mask[i, j] = label_value
        cv2.imwrite('./final_result/' + final_name, final_mask)

print('combinding mask...')
combind_all_mask()
```
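The per-pixel double loop in `combind_all_mask` is very slow on images of this size. The same priority rule (building > water > road > vegetation) can be written with boolean masks: assign the classes in increasing priority order so that a later write wins every conflict. A vectorised sketch, assuming the four 0/255 masks are already loaded as same-shaped uint8 arrays (`merge_masks` is a hypothetical helper, not part of the repository):

```python
import numpy as np

def merge_masks(vegetation, building, water, road):
    """Combine four 0/255 class masks into one label map with values 0..4."""
    final = np.zeros(vegetation.shape, np.uint8)
    # lowest priority first; later assignments overwrite earlier ones
    for label, mask in ((1, vegetation), (4, road), (3, water), (2, building)):
        final[mask == 255] = label
    return final
```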
UNet model prediction: unet_predict.py

```python
import cv2
import random
import numpy as np
import os
import argparse
from keras.preprocessing.image import img_to_array
from keras.models import load_model
from sklearn.preprocessing import LabelEncoder

# run on the GPU with index 1
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

TEST_SET = ['1.png', '2.png', '3.png']

image_size = 256

classes = [0., 1., 2., 3., 4.]
labelencoder = LabelEncoder()
labelencoder.fit(classes)

def args_parse():
    # construct the argument parser and parse the arguments
    ap = argparse.ArgumentParser()
    ap.add_argument("-m", "--model", required=True,
                    help="path to trained model")
    ap.add_argument("-s", "--stride", required=False,
                    help="crop slide stride", type=int, default=image_size)
    args = vars(ap.parse_args())
    return args

def predict(args):
    # load the trained convolutional neural network
    print("[INFO] loading network...")
    model = load_model(args["model"])
    stride = args['stride']
    for n in range(len(TEST_SET)):
        path = TEST_SET[n]
        # load the test image
        image = cv2.imread('./test/' + path)
        h, w, _ = image.shape
        # How do we predict on a large image? The network was trained on
        # 256x256 inputs, so at test time it is also fed 256x256 crops.
        # First zero-pad the image so its size is divisible by 256.
        padding_h = (h // stride + 1) * stride
        padding_w = (w // stride + 1) * stride
        padding_img = np.zeros((padding_h, padding_w, 3), dtype=np.uint8)
        # the part beyond the original image stays zero
        padding_img[0:h, 0:w, :] = image[:, :, :]
        # NOTE: this line is commented out in the original code, but training
        # normalised inputs to [0, 1]; leaving it out feeds 0-255 values to
        # the network at test time
        #padding_img = padding_img.astype("float") / 255.0
        padding_img = img_to_array(padding_img)
        print('src:', padding_img.shape)
        mask_whole = np.zeros((padding_h, padding_w), dtype=np.uint8)
        for i in range(padding_h // stride):
            for j in range(padding_w // stride):
                # take the crop at the corresponding position in the padded image
                crop = padding_img[:3, i*stride:i*stride+image_size,
                                   j*stride:j*stride+image_size]
                _, ch, cw = crop.shape
                if ch != 256 or cw != 256:
                    print('invalid size!')
                    continue
                crop = np.expand_dims(crop, axis=0)
                # for predict/evaluate, verbose=0 prints nothing and verbose=1
                # shows a progress bar (fit additionally supports verbose=2,
                # one line per epoch)
                pred = model.predict(crop, verbose=2)
                # NOTE: the sigmoid output lies in [0, 1], so this uint8 cast
                # truncates almost everything to 0; a threshold such as
                # (pred > 0.5) may be what was intended
                pred = pred.reshape((256, 256)).astype(np.uint8)
                # write the block back into the corresponding position
                mask_whole[i*stride:i*stride+image_size,
                           j*stride:j*stride+image_size] = pred[:, :]
        # finally crop the mask back to the original image size
        cv2.imwrite('./predict/pre' + str(n+1) + '.png', mask_whole[0:h, 0:w])

if __name__ == '__main__':
    args = args_parse()
    predict(args)
```
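The padding arithmetic deserves a quick sanity check: `(h // stride + 1) * stride` rounds up to the next multiple of `stride` (adding a full extra tile when `h` is already a multiple). Plugging in the three test-image heights used in `combind.py`:

```python
stride = 256
for h in (5142, 2470, 6116):
    padding_h = (h // stride + 1) * stride
    print(h, '->', padding_h)  # 5142 -> 5376, 2470 -> 2560, 6116 -> 6144
```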
Ensembling classification results: ensemble.py

```python
import numpy as np
import cv2
import argparse

RESULT_PREFIXX = ['./result1/', './result2/', './result3/']

# each mask has 5 classes: 0~4
def vote_per_image(image_id):
    result_list = []
    for j in range(len(RESULT_PREFIXX)):
        im = cv2.imread(RESULT_PREFIXX[j] + str(image_id) + '.png', 0)
        result_list.append(im)
    # vote pixel by pixel
    height, width = result_list[0].shape
    vote_mask = np.zeros((height, width))
    for h in range(height):
        for w in range(width):
            # `record` is a 1x5 array holding, for this pixel, the number of
            # votes each of the 5 classes received
            record = np.zeros((1, 5))
            for n in range(len(result_list)):
                # for each model's result, read this pixel's class and add
                # one vote for it
                mask = result_list[n]
                pixel = mask[h, w]
                record[0, pixel] += 1
            # ensembling: the class with the most votes wins
            label = record.argmax()
            vote_mask[h, w] = label
    cv2.imwrite('vote_mask' + str(image_id) + '.png', vote_mask)

# vote over the three sets of results for image 3
vote_per_image(3)
```
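As with `combind.py`, the pixel loops can be vectorised. One way, assuming the label maps are aligned uint8 arrays with values 0 to 4 (`vote` is a hypothetical rewrite, not the repository's code):

```python
import numpy as np

def vote(result_list, n_classes=5):
    stack = np.stack(result_list)                      # (n_models, H, W)
    onehot = np.eye(n_classes, dtype=np.uint8)[stack]  # one-hot votes per pixel
    votes = onehot.sum(axis=0)                         # (H, W, n_classes)
    return votes.argmax(axis=-1).astype(np.uint8)      # majority label per pixel
```

Like `record.argmax()` in the loop version, `argmax` breaks ties in favour of the smaller class index.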
SegNet model training: segnet_train.py

```python
#coding=utf-8
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import argparse
import numpy as np
from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D, UpSampling2D, BatchNormalization, Reshape, Permute, Activation
from keras.utils.np_utils import to_categorical
from keras.preprocessing.image import img_to_array
from keras.callbacks import ModelCheckpoint
from sklearn.preprocessing import LabelEncoder
from PIL import Image
import cv2
import random
import os
from tqdm import tqdm

os.environ["CUDA_VISIBLE_DEVICES"] = "1"

seed = 7
np.random.seed(seed)

#data_shape = 360*480
img_w = 256
img_h = 256
# 4 foreground classes plus background
n_label = 4 + 1

classes = [0., 1., 2., 3., 4.]
labelencoder = LabelEncoder()
labelencoder.fit(classes)

image_sets = ['1.png', '2.png', '3.png']

def load_img(path, grayscale=False):
    if grayscale:
        img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    else:
        img = cv2.imread(path)
    img = np.array(img, dtype="float") / 255.0
    return img

filepath = './train/'

def get_train_val(val_rate=0.25):
    train_url = []
    train_set = []
    val_set = []
    for pic in os.listdir(filepath + 'src'):
        train_url.append(pic)
    random.shuffle(train_url)
    total_num = len(train_url)
    val_num = int(val_rate * total_num)
    for i in range(len(train_url)):
        if i < val_num:
            val_set.append(train_url[i])
        else:
            train_set.append(train_url[i])
    return train_set, val_set

# data for training
def generateData(batch_size, data=[]):
    while True:
        train_data = []
        train_label = []
        batch = 0
        for i in range(len(data)):
            url = data[i]
            batch += 1
            img = load_img(filepath + 'src/' + url)
            img = img_to_array(img)
            train_data.append(img)
            label = load_img(filepath + 'label/' + url, grayscale=True)
            label = img_to_array(label).reshape((img_w * img_h,))
            train_label.append(label)
            if batch % batch_size == 0:
                train_data = np.array(train_data)
                # flatten, encode and one-hot the labels, then reshape to
                # (batch, H*W, n_label) to match the model's output
                train_label = np.array(train_label).flatten()
                train_label = labelencoder.transform(train_label)
                train_label = to_categorical(train_label, num_classes=n_label)
                train_label = train_label.reshape((batch_size, img_w * img_h, n_label))
                yield (train_data, train_label)
                train_data = []
                train_label = []
                batch = 0

# data for validation
def generateValidData(batch_size, data=[]):
    while True:
        valid_data = []
        valid_label = []
        batch = 0
        for i in range(len(data)):
            url = data[i]
            batch += 1
            img = load_img(filepath + 'src/' + url)
            img = img_to_array(img)
            valid_data.append(img)
            label = load_img(filepath + 'label/' + url, grayscale=True)
            label = img_to_array(label).reshape((img_w * img_h,))
            valid_label.append(label)
            if batch % batch_size == 0:
                valid_data = np.array(valid_data)
                valid_label = np.array(valid_label).flatten()
                valid_label = labelencoder.transform(valid_label)
                valid_label = to_categorical(valid_label, num_classes=n_label)
                valid_label = valid_label.reshape((batch_size, img_w * img_h, n_label))
                yield (valid_data, valid_label)
                valid_data = []
                valid_label = []
                batch = 0

def SegNet():
    model = Sequential()
    # encoder
    model.add(Conv2D(64, (3, 3), strides=(1, 1), input_shape=(3, img_w, img_h), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(Conv2D(64, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(MaxPooling2D(pool_size=(2, 2), dim_ordering='th'))
    # (128, 128)
    model.add(Conv2D(128, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(Conv2D(128, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(MaxPooling2D(pool_size=(2, 2), dim_ordering='th'))
    # (64, 64)
    model.add(Conv2D(256, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(Conv2D(256, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(Conv2D(256, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(MaxPooling2D(pool_size=(2, 2), dim_ordering='th'))
    # (32, 32)
    model.add(Conv2D(512, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(Conv2D(512, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(Conv2D(512, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(MaxPooling2D(pool_size=(2, 2), dim_ordering='th'))
    # (16, 16)
    model.add(Conv2D(512, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(Conv2D(512, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(Conv2D(512, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(MaxPooling2D(pool_size=(2, 2), dim_ordering='th'))
    # (8, 8)
    # decoder
    model.add(UpSampling2D(size=(2, 2)))
    # (16, 16)
    model.add(Conv2D(512, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(Conv2D(512, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(Conv2D(512, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(UpSampling2D(size=(2, 2)))
    # (32, 32)
    model.add(Conv2D(512, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(Conv2D(512, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(Conv2D(512, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(UpSampling2D(size=(2, 2)))
    # (64, 64)
    model.add(Conv2D(256, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(Conv2D(256, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(Conv2D(256, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(UpSampling2D(size=(2, 2)))
    # (128, 128)
    model.add(Conv2D(128, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(Conv2D(128, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(UpSampling2D(size=(2, 2)))
    # (256, 256)
    model.add(Conv2D(64, (3, 3), strides=(1, 1), input_shape=(3, img_w, img_h), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(Conv2D(64, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(Conv2D(n_label, (1, 1), strides=(1, 1), padding='same'))
    model.add(Reshape((n_label, img_w * img_h)))
    # swap axis 1 and axis 2, equivalent to np.swapaxes(layer, 1, 2)
    model.add(Permute((2, 1)))
    model.add(Activation('softmax'))
    model.compile(loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])
    return model

def train(args):
    EPOCHS = 30
    BS = 16
    model = SegNet()
    modelcheck = ModelCheckpoint(args['model'], monitor='val_acc',
                                 save_best_only=True, mode='max')
    callable = [modelcheck]
    train_set, val_set = get_train_val()
    train_numb = len(train_set)
    valid_numb = len(val_set)
    print("the number of train data is", train_numb)
    print("the number of val data is", valid_numb)
    H = model.fit_generator(generator=generateData(BS, train_set),
                            steps_per_epoch=train_numb // BS,
                            epochs=EPOCHS, verbose=1,
                            validation_data=generateValidData(BS, val_set),
                            validation_steps=valid_numb // BS,
                            callbacks=callable, max_q_size=1)

    # plot the training loss and accuracy
    plt.style.use("ggplot")
    plt.figure()
    N = EPOCHS
    plt.plot(np.arange(0, N), H.history["loss"], label="train_loss")
    plt.plot(np.arange(0, N), H.history["val_loss"], label="val_loss")
    plt.plot(np.arange(0, N), H.history["acc"], label="train_acc")
    plt.plot(np.arange(0, N), H.history["val_acc"], label="val_acc")
    plt.title("Training Loss and Accuracy on SegNet Satellite Seg")
    plt.xlabel("Epoch #")
    plt.ylabel("Loss/Accuracy")
    plt.legend(loc="lower left")
    plt.savefig(args["plot"])

def args_parse():
    # construct the argument parser and parse the arguments
    ap = argparse.ArgumentParser()
    ap.add_argument("-a", "--augment", help="using data augment or not",
                    action="store_true", default=False)
    ap.add_argument("-m", "--model", required=True,
                    help="path to output model")
    ap.add_argument("-p", "--plot", type=str, default="plot.png",
                    help="path to output accuracy/loss plot")
    args = vars(ap.parse_args())
    return args

if __name__ == '__main__':
    args = args_parse()
    if args['augment'] == True:
        filepath = './aug/train/'
    train(args)
    #predict()
```
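The main difference from the U-Net script is the label pipeline in the generators: each mask is flattened, one-hot encoded with `to_categorical`, and reshaped to `(batch, H*W, n_label)` so it lines up with the `Reshape`/`Permute`/`softmax` head of the model. A toy-sized sketch of that transformation (the 2x2 mask is illustrative):

```python
import numpy as np
from keras.utils.np_utils import to_categorical

label = np.array([[0, 2],
                  [4, 2]])                    # a 2x2 mask with classes 0..4
flat = label.reshape(-1)                      # (4,)
onehot = to_categorical(flat, num_classes=5)  # (4, 5): one row per pixel
batch = onehot.reshape((1, 4, 5))             # (batch, H*W, n_label)
print(batch.shape)                            # (1, 4, 5)
```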
Likewise, to see what each layer's input and output tensors look like in SegNet, the shapes are printed below. (Note that this particular printout was evidently produced with n_label = 1: the last Conv2D has a single output channel, whereas with n_label = 4+1 as in the script above it would have 5.)

```
Layer (type)                 Output Shape              Param #
=================================================================
conv2d_98 (Conv2D)           (None, 64, 256, 256)      1792
batch_normalization_1 (Batch (None, 64, 256, 256)      1024
conv2d_99 (Conv2D)           (None, 64, 256, 256)      36928
batch_normalization_2 (Batch (None, 64, 256, 256)      1024
max_pooling2d_33 (MaxPooling (None, 64, 128, 128)      0
conv2d_100 (Conv2D)          (None, 128, 128, 128)     73856
batch_normalization_3 (Batch (None, 128, 128, 128)     512
conv2d_101 (Conv2D)          (None, 128, 128, 128)     147584
batch_normalization_4 (Batch (None, 128, 128, 128)     512
max_pooling2d_34 (MaxPooling (None, 128, 64, 64)       0
conv2d_102 (Conv2D)          (None, 256, 64, 64)       295168
batch_normalization_5 (Batch (None, 256, 64, 64)       256
conv2d_103 (Conv2D)          (None, 256, 64, 64)       590080
batch_normalization_6 (Batch (None, 256, 64, 64)       256
conv2d_104 (Conv2D)          (None, 256, 64, 64)       590080
batch_normalization_7 (Batch (None, 256, 64, 64)       256
max_pooling2d_35 (MaxPooling (None, 256, 32, 32)       0
conv2d_105 (Conv2D)          (None, 512, 32, 32)       1180160
batch_normalization_8 (Batch (None, 512, 32, 32)       128
conv2d_106 (Conv2D)          (None, 512, 32, 32)       2359808
batch_normalization_9 (Batch (None, 512, 32, 32)       128
conv2d_107 (Conv2D)          (None, 512, 32, 32)       2359808
batch_normalization_10 (Batc (None, 512, 32, 32)       128
max_pooling2d_36 (MaxPooling (None, 512, 16, 16)       0
conv2d_108 (Conv2D)          (None, 512, 16, 16)       2359808
batch_normalization_11 (Batc (None, 512, 16, 16)       64
conv2d_109 (Conv2D)          (None, 512, 16, 16)       2359808
batch_normalization_12 (Batc (None, 512, 16, 16)       64
conv2d_110 (Conv2D)          (None, 512, 16, 16)       2359808
batch_normalization_13 (Batc (None, 512, 16, 16)       64
max_pooling2d_37 (MaxPooling (None, 512, 8, 8)         0
up_sampling2d_17 (UpSampling (None, 512, 16, 16)       0
conv2d_111 (Conv2D)          (None, 512, 16, 16)       2359808
batch_normalization_14 (Batc (None, 512, 16, 16)       64
conv2d_112 (Conv2D)          (None, 512, 16, 16)       2359808
batch_normalization_15 (Batc (None, 512, 16, 16)       64
conv2d_113 (Conv2D)          (None, 512, 16, 16)       2359808
batch_normalization_16 (Batc (None, 512, 16, 16)       64
up_sampling2d_18 (UpSampling (None, 512, 32, 32)       0
conv2d_114 (Conv2D)          (None, 512, 32, 32)       2359808
batch_normalization_17 (Batc (None, 512, 32, 32)       128
conv2d_115 (Conv2D)          (None, 512, 32, 32)       2359808
batch_normalization_18 (Batc (None, 512, 32, 32)       128
conv2d_116 (Conv2D)          (None, 512, 32, 32)       2359808
batch_normalization_19 (Batc (None, 512, 32, 32)       128
up_sampling2d_19 (UpSampling (None, 512, 64, 64)       0
conv2d_117 (Conv2D)          (None, 256, 64, 64)       1179904
batch_normalization_20 (Batc (None, 256, 64, 64)       256
conv2d_118 (Conv2D)          (None, 256, 64, 64)       590080
batch_normalization_21 (Batc (None, 256, 64, 64)       256
conv2d_119 (Conv2D)          (None, 256, 64, 64)       590080
batch_normalization_22 (Batc (None, 256, 64, 64)       256
up_sampling2d_20 (UpSampling (None, 256, 128, 128)     0
conv2d_120 (Conv2D)          (None, 128, 128, 128)     295040
batch_normalization_23 (Batc (None, 128, 128, 128)     512
conv2d_121 (Conv2D)          (None, 128, 128, 128)     147584
batch_normalization_24 (Batc (None, 128, 128, 128)     512
up_sampling2d_21 (UpSampling (None, 128, 256, 256)     0
conv2d_122 (Conv2D)          (None, 64, 256, 256)      73792
batch_normalization_25 (Batc (None, 64, 256, 256)      1024
conv2d_123 (Conv2D)          (None, 64, 256, 256)      36928
batch_normalization_26 (Batc (None, 64, 256, 256)      1024
conv2d_124 (Conv2D)          (None, 1, 256, 256)       65
reshape_1 (Reshape)          (None, 1, 65536)          0
permute_1 (Permute)          (None, 65536, 1)          0
activation_1 (Activation)    (None, 65536, 1)          0
=================================================================
Total params: 31,795,841
Trainable params: 31,791,425
Non-trainable params: 4,416
```