【Keras】Semantic Segmentation of Remote Sensing Images with SegNet and U-Net
Two months ago I took part in a competition on semantic segmentation of high-resolution remote sensing imagery, grandly named "Eye in the Sky". For the final project of our data mining course over the past two weeks, our group also chose remote sensing image segmentation as the topic, so I took the opportunity to reorganize and strengthen the earlier work. This article records the complete workflow of semantic segmentation of remote sensing images with deep learning, along with some useful ideas and tricks.
Dataset
First, the data. The dataset comes from the CCF Big Data competition (high-resolution remote sensing images of a city in southern China, taken in 2015). It is a small dataset containing 5 large annotated RGB remote sensing images (sizes ranging from 3000×3000 to 6000×6000 pixels). Four object classes are annotated: vegetation (label 1), buildings (label 2), water (label 3) and roads (label 4), with everything else as label 0. Farmland, forest and grassland are all merged into the vegetation class. To better inspect the annotations, we visualize three of the training images (see the colorization sketch below): blue = water, yellow = buildings, green = vegetation, brown = roads. More details about the data can be found in the original competition description.
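A minimal colorization sketch along these lines can turn the single-channel label maps (values 0–4) into color images for inspection; the BGR palette and file paths here are illustrative, only the label values come from the annotation scheme above.

import cv2
import numpy as np

# illustrative BGR colors: other (black), vegetation (green),
# buildings (yellow), water (blue), roads (brown)
PALETTE = {0: (0, 0, 0), 1: (0, 200, 0), 2: (0, 215, 255),
           3: (255, 0, 0), 4: (42, 42, 165)}

def colorize_label(label_path, out_path):
    label = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE)  # values 0..4
    vis = np.zeros((label.shape[0], label.shape[1], 3), dtype=np.uint8)
    for value, color in PALETTE.items():
        vis[label == value] = color  # paint every pixel of this class
    cv2.imwrite(out_path, vis)

colorize_label('./data/label/1.png', './data/vis/1.png')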
Now a word about data preparation. We have 5 large remote sensing images, and we cannot feed them into the network directly: they would not fit in memory and their sizes all differ. So we first crop them randomly: generate random x, y coordinates, cut out the 256×256 patch at that location, and apply the following augmentation operations.
Instead of Keras's built-in augmentation utilities, I wrote the corresponding augmentation functions myself with OpenCV.
import random
import numpy as np
import cv2
from tqdm import tqdm

img_w = 256
img_h = 256

image_sets = ['1.png', '2.png', '3.png', '4.png', '5.png']

def gamma_transform(img, gamma):
    gamma_table = [np.power(x / 255.0, gamma) * 255.0 for x in range(256)]
    gamma_table = np.round(np.array(gamma_table)).astype(np.uint8)
    return cv2.LUT(img, gamma_table)

def random_gamma_transform(img, gamma_vari):
    log_gamma_vari = np.log(gamma_vari)
    alpha = np.random.uniform(-log_gamma_vari, log_gamma_vari)
    gamma = np.exp(alpha)
    return gamma_transform(img, gamma)

def rotate(xb, yb, angle):
    M_rotate = cv2.getRotationMatrix2D((img_w / 2, img_h / 2), angle, 1)
    xb = cv2.warpAffine(xb, M_rotate, (img_w, img_h))
    yb = cv2.warpAffine(yb, M_rotate, (img_w, img_h))
    return xb, yb

def blur(img):
    return cv2.blur(img, (3, 3))

def add_noise(img):
    for i in range(200):  # add salt noise at 200 random pixels
        temp_x = np.random.randint(0, img.shape[0])
        temp_y = np.random.randint(0, img.shape[1])
        img[temp_x][temp_y] = 255
    return img

def data_augment(xb, yb):
    if np.random.random() < 0.25:
        xb, yb = rotate(xb, yb, 90)
    if np.random.random() < 0.25:
        xb, yb = rotate(xb, yb, 180)
    if np.random.random() < 0.25:
        xb, yb = rotate(xb, yb, 270)
    if np.random.random() < 0.25:
        xb = cv2.flip(xb, 1)  # flipCode > 0: horizontal flip (around the y axis)
        yb = cv2.flip(yb, 1)
    if np.random.random() < 0.25:
        xb = random_gamma_transform(xb, 1.0)  # note: gamma_vari must be > 1.0 for this to have any effect
    if np.random.random() < 0.25:
        xb = blur(xb)
    if np.random.random() < 0.2:
        xb = add_noise(xb)
    return xb, yb

def creat_dataset(image_num=100000, mode='original'):
    print('creating dataset...')
    image_each = image_num // len(image_sets)
    g_count = 0
    for i in tqdm(range(len(image_sets))):
        count = 0
        src_img = cv2.imread('./data/src/' + image_sets[i])  # 3 channels
        label_img = cv2.imread('./data/label/' + image_sets[i], cv2.IMREAD_GRAYSCALE)  # single channel
        X_height, X_width, _ = src_img.shape
        while count < image_each:
            random_width = random.randint(0, X_width - img_w - 1)
            random_height = random.randint(0, X_height - img_h - 1)
            src_roi = src_img[random_height: random_height + img_h, random_width: random_width + img_w, :]
            label_roi = label_img[random_height: random_height + img_h, random_width: random_width + img_w]
            if mode == 'augment':
                src_roi, label_roi = data_augment(src_roi, label_roi)
            visualize = label_roi * 50  # stretch labels 0-4 so they are visible in a grayscale image
            cv2.imwrite('./aug/train/visualize/%d.png' % g_count, visualize)
            cv2.imwrite('./aug/train/src/%d.png' % g_count, src_roi)
            cv2.imwrite('./aug/train/label/%d.png' % g_count, label_roi)
            count += 1
            g_count += 1

After the cropping and augmentation above we obtain a much larger training set: 100,000 patches of 256×256.
Convolutional Neural Networks
For this kind of semantic segmentation task there are many classic networks to choose from: FCN, U-Net, SegNet, DeepLab, RefineNet, Mask R-CNN, HED — all well-established architectures widely used in competitions. We can pick one or two of them as the basis of our solution. Given our group's situation, we chose U-Net and SegNet as our main networks for the experiments.
SegNet
SegNet has been around for several years. It is not the newest or best-performing segmentation network, but its structure is clear and easy to understand and it trains quickly with few pitfalls, so we use it for this task as well. SegNet is an elegant encoder-decoder architecture. Note that when SegNet is used for semantic segmentation, a CRF module is often appended as post-processing to further refine the segmentation along object boundaries; see the paper if you want to dig deeper.
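We did not add a CRF stage ourselves, but as a rough sketch, fully-connected CRF post-processing is commonly done with the pydensecrf package roughly as follows; the kernel parameters are illustrative defaults, and the function assumes you already have per-class softmax probabilities for the patch.

import numpy as np
import pydensecrf.densecrf as dcrf
from pydensecrf.utils import unary_from_softmax

def crf_refine(image, probs, n_iters=5):
    # image: HxWx3 uint8 patch; probs: (n_classes, H, W) softmax output of the network
    n_classes, h, w = probs.shape
    d = dcrf.DenseCRF2D(w, h, n_classes)
    # unary potentials from the network's softmax probabilities
    unary = unary_from_softmax(probs)
    d.setUnaryEnergy(np.ascontiguousarray(unary))
    # pairwise terms: a smoothness kernel plus a color-dependent appearance kernel
    d.addPairwiseGaussian(sxy=3, compat=3)
    d.addPairwiseBilateral(sxy=80, srgb=13, rgbim=np.ascontiguousarray(image), compat=10)
    q = d.inference(n_iters)
    return np.argmax(q, axis=0).reshape((h, w)).astype(np.uint8)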
Now for the code. First we define the SegNet network structure.
from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D, UpSampling2D, BatchNormalization, Reshape, Permute, Activation

n_label = 5  # 4 object classes plus 'other' (label 0)

def SegNet():
    model = Sequential()
    # encoder
    model.add(Conv2D(64, (3, 3), strides=(1, 1), input_shape=(3, img_w, img_h), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(Conv2D(64, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(MaxPooling2D(pool_size=(2, 2)))  # (128,128)
    model.add(Conv2D(128, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(Conv2D(128, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(MaxPooling2D(pool_size=(2, 2)))  # (64,64)
    model.add(Conv2D(256, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(Conv2D(256, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(Conv2D(256, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(MaxPooling2D(pool_size=(2, 2)))  # (32,32)
    model.add(Conv2D(512, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(Conv2D(512, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(Conv2D(512, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(MaxPooling2D(pool_size=(2, 2)))  # (16,16)
    model.add(Conv2D(512, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(Conv2D(512, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(Conv2D(512, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(MaxPooling2D(pool_size=(2, 2)))  # (8,8)
    # decoder
    model.add(UpSampling2D(size=(2, 2)))  # (16,16)
    model.add(Conv2D(512, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(Conv2D(512, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(Conv2D(512, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(UpSampling2D(size=(2, 2)))  # (32,32)
    model.add(Conv2D(512, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(Conv2D(512, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(Conv2D(512, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(UpSampling2D(size=(2, 2)))  # (64,64)
    model.add(Conv2D(256, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(Conv2D(256, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(Conv2D(256, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(UpSampling2D(size=(2, 2)))  # (128,128)
    model.add(Conv2D(128, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(Conv2D(128, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(UpSampling2D(size=(2, 2)))  # (256,256)
    model.add(Conv2D(64, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    model.add(Conv2D(64, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(BatchNormalization())
    # per-pixel classification head
    model.add(Conv2D(n_label, (1, 1), strides=(1, 1), padding='same'))
    model.add(Reshape((n_label, img_w * img_h)))
    # swap axes 1 and 2, equivalent to np.swapaxes(layer, 1, 2)
    model.add(Permute((2, 1)))
    model.add(Activation('softmax'))
    model.compile(loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])
    model.summary()
    return model

Next we need to read in the dataset. We hold out 25% of the patches as the validation set.
import os
from keras.preprocessing.image import load_img, img_to_array
from keras.utils import to_categorical
from sklearn.preprocessing import LabelEncoder

filepath = './aug/train/'  # root of the cropped training data written by creat_dataset above

# label encoder mapping the raw label values 0-4 to class indices (defined here for completeness)
labelencoder = LabelEncoder()
labelencoder.fit([0, 1, 2, 3, 4])

def get_train_val(val_rate=0.25):
    train_url = []
    train_set = []
    val_set = []
    for pic in os.listdir(filepath + 'src'):
        train_url.append(pic)
    random.shuffle(train_url)
    total_num = len(train_url)
    val_num = int(val_rate * total_num)
    for i in range(len(train_url)):
        if i < val_num:
            val_set.append(train_url[i])
        else:
            train_set.append(train_url[i])
    return train_set, val_set

# data for training
def generateData(batch_size, data=[]):
    while True:
        train_data = []
        train_label = []
        batch = 0
        for i in range(len(data)):
            url = data[i]
            batch += 1
            img = load_img(filepath + 'src/' + url)
            img = img_to_array(img)
            train_data.append(img)
            label = load_img(filepath + 'label/' + url, grayscale=True)
            label = img_to_array(label).reshape((img_w * img_h,))
            train_label.append(label)
            if batch % batch_size == 0:
                train_data = np.array(train_data)
                train_label = np.array(train_label).flatten()
                train_label = labelencoder.transform(train_label)
                train_label = to_categorical(train_label, num_classes=n_label)
                train_label = train_label.reshape((batch_size, img_w * img_h, n_label))
                yield (train_data, train_label)
                train_data = []
                train_label = []
                batch = 0

# data for validation
def generateValidData(batch_size, data=[]):
    while True:
        valid_data = []
        valid_label = []
        batch = 0
        for i in range(len(data)):
            url = data[i]
            batch += 1
            img = load_img(filepath + 'src/' + url)
            img = img_to_array(img)
            valid_data.append(img)
            label = load_img(filepath + 'label/' + url, grayscale=True)
            label = img_to_array(label).reshape((img_w * img_h,))
            valid_label.append(label)
            if batch % batch_size == 0:
                valid_data = np.array(valid_data)
                valid_label = np.array(valid_label).flatten()
                valid_label = labelencoder.transform(valid_label)
                valid_label = to_categorical(valid_label, num_classes=n_label)
                valid_label = valid_label.reshape((batch_size, img_w * img_h, n_label))
                yield (valid_data, valid_label)
                valid_data = []
                valid_label = []
                batch = 0

Then we define the training procedure. For this task we set the batch size to 16 and train for 30 epochs, saving the best model each time (save_best_only=True), and at the end of training we plot and save the loss/acc curves.
from keras.callbacks import ModelCheckpoint
import matplotlib.pyplot as plt

def train(args):
    EPOCHS = 30
    BS = 16
    model = SegNet()
    modelcheck = ModelCheckpoint(args['model'], monitor='val_acc', save_best_only=True, mode='max')
    callable = [modelcheck]
    train_set, val_set = get_train_val()
    train_numb = len(train_set)
    valid_numb = len(val_set)
    print("the number of train data is", train_numb)
    print("the number of val data is", valid_numb)
    H = model.fit_generator(generator=generateData(BS, train_set), steps_per_epoch=train_numb // BS,
                            epochs=EPOCHS, verbose=1,
                            validation_data=generateValidData(BS, val_set), validation_steps=valid_numb // BS,
                            callbacks=callable, max_q_size=1)

    # plot the training loss and accuracy
    plt.style.use("ggplot")
    plt.figure()
    N = EPOCHS
    plt.plot(np.arange(0, N), H.history["loss"], label="train_loss")
    plt.plot(np.arange(0, N), H.history["val_loss"], label="val_loss")
    plt.plot(np.arange(0, N), H.history["acc"], label="train_acc")
    plt.plot(np.arange(0, N), H.history["val_acc"], label="val_acc")
    plt.title("Training Loss and Accuracy on SegNet Satellite Seg")
    plt.xlabel("Epoch #")
    plt.ylabel("Loss/Accuracy")
    plt.legend(loc="lower left")
    plt.savefig(args["plot"])

Then the long training begins; it took close to 3 days, and the resulting loss/acc curves are shown below.
The training loss drops to around 0.1 and the training accuracy reaches about 0.9, but the validation loss and accuracy are noticeably worse, which suggests some overfitting.
Let's set that aside for now and look at the predictions.
Here we need to think about how to predict an entire remote sensing image. The model was trained on 256×256 inputs, so at prediction time we must also feed it 256×256 patches. The question is how to stitch the predicted patches back into one large image. A basic scheme: zero-pad the large image so that its height and width become multiples of 256, and create an all-zero image A of the same padded size; then slide over the padded image with a stride of 256, feed each patch to the model, and write the predicted patch into the corresponding position of A; finally crop A back to the original image size. That completes the prediction of a whole image.
from keras.models import load_model

image_size = 256
TEST_SET = ['1.png', '2.png', '3.png']  # test image file names (illustrative; adjust to your data)

def predict(args):
    # load the trained convolutional neural network
    print("[INFO] loading network...")
    model = load_model(args["model"])
    stride = args['stride']
    for n in range(len(TEST_SET)):
        path = TEST_SET[n]
        # load the image
        image = cv2.imread('./test/' + path)
        h, w, _ = image.shape
        # pad so that height and width are multiples of the stride
        padding_h = (h // stride + 1) * stride
        padding_w = (w // stride + 1) * stride
        padding_img = np.zeros((padding_h, padding_w, 3), dtype=np.uint8)
        padding_img[0:h, 0:w, :] = image[:, :, :]
        padding_img = padding_img.astype("float") / 255.0
        padding_img = img_to_array(padding_img)  # (3, H, W) with the channels_first data format used here
        print('src:', padding_img.shape)
        mask_whole = np.zeros((padding_h, padding_w), dtype=np.uint8)
        for i in range(padding_h // stride):
            for j in range(padding_w // stride):
                crop = padding_img[:3, i * stride:i * stride + image_size, j * stride:j * stride + image_size]
                _, ch, cw = crop.shape
                if ch != 256 or cw != 256:
                    print('invalid size!')
                    continue
                crop = np.expand_dims(crop, axis=0)
                pred = model.predict_classes(crop, verbose=2)
                pred = labelencoder.inverse_transform(pred[0])
                pred = pred.reshape((256, 256)).astype(np.uint8)
                mask_whole[i * stride:i * stride + image_size, j * stride:j * stride + image_size] = pred[:, :]
        cv2.imwrite('./predict/pre' + str(n + 1) + '.png', mask_whole[0:h, 0:w])

The predicted results look like this:
At first glance the result looks quite good, but on closer inspection there is a big problem: the stitching seams are far too visible. How do we deal with these boundary artifacts? The straightforward idea is to reduce the sliding stride when cropping; for example, with a stride of 128, half of each patch overlaps with its neighbours, which greatly reduces the stitching artifacts. A rough sketch of this overlapped prediction is given below.
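This overlapped prediction was not part of our original script. A minimal sketch of the idea, assuming the padded channels-first image from predict() above and a model whose predict() returns per-pixel class probabilities of shape (1, 256*256, n_label) like the SegNet defined earlier, accumulates probabilities over the overlapping tiles and takes the per-pixel argmax at the end:

import numpy as np

def predict_overlap(model, padding_img, n_label, image_size=256, stride=128):
    # padding_img: (3, H, W) float image whose H and W are multiples of image_size
    _, padding_h, padding_w = padding_img.shape
    votes = np.zeros((padding_h, padding_w, n_label), dtype=np.float32)
    for i in range(0, padding_h - image_size + 1, stride):
        for j in range(0, padding_w - image_size + 1, stride):
            crop = padding_img[:, i:i + image_size, j:j + image_size]
            crop = np.expand_dims(crop, axis=0)
            prob = model.predict(crop, verbose=0)[0]             # (256*256, n_label)
            prob = prob.reshape((image_size, image_size, n_label))
            votes[i:i + image_size, j:j + image_size, :] += prob  # overlapping regions accumulate evidence
    return np.argmax(votes, axis=2).astype(np.uint8)

Because each pixel near a seam is now predicted by several tiles, the averaged probabilities smooth out the hard boundaries between neighbouring patches.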
U-Net
For this segmentation task we chose U-Net without hesitation, for a simple reason: in the write-ups of many similar remote sensing segmentation competitions, the vast majority of the winning entries used U-Net. With that track record, choosing U-Net was an easy decision.
U-Net has many advantages; its biggest selling point is that it can be trained to a good model even on small datasets, which suits our task very well. It also trains quickly, which matters for a final project that needs results in a short time. Architecturally it is quite elegant, with a symmetric U shape that gives it its name. I will not describe the structure in detail here; see the paper if you want to dig deeper.
Now for the code details. First we define the U-Net structure; the deep learning framework is again Keras.
Note that the model trained above is a multi-class model. A better approach in practice is to train one binary model per class (with binary labels), predict each of the 4 object classes separately to obtain 4 prediction maps, and then overlay them into a single map covering all 4 classes; this strategy generally works better than a single 4-way model. So for U-Net our approach is to train one binary classifier for each class and finally combine the per-class predictions into a 4-class result; a sketch of how the binary training labels can be derived from the original annotations follows below.
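We did not keep the exact preprocessing script, but deriving the binary labels is essentially a thresholding step. A minimal sketch: the class-to-label mapping follows the dataset description above, while the directory layout under ./unet_train/ is an assumption chosen to match the training command shown later.

import os
import cv2
import numpy as np

# class names mapped to their label values in the original annotation
CLASSES = {'vegetation': 1, 'buildings': 2, 'water': 3, 'roads': 4}

def make_binary_labels(label_dir, out_root):
    # for each class, write a 0/1 mask for every training label image
    for name, value in CLASSES.items():
        out_dir = os.path.join(out_root, name, 'label')
        os.makedirs(out_dir, exist_ok=True)
        for fname in os.listdir(label_dir):
            label = cv2.imread(os.path.join(label_dir, fname), cv2.IMREAD_GRAYSCALE)
            binary = (label == value).astype(np.uint8)  # 1 where the pixel belongs to this class
            cv2.imwrite(os.path.join(out_dir, fname), binary)

make_binary_labels('./aug/train/label', './unet_train')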
Define the U-Net structure. Note that the loss function here is binary_crossentropy, because we are training binary models.
from keras.models import Model
from keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, concatenate

def unet():
    inputs = Input((3, img_w, img_h))

    # encoder
    conv1 = Conv2D(32, (3, 3), activation="relu", padding="same")(inputs)
    conv1 = Conv2D(32, (3, 3), activation="relu", padding="same")(conv1)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)

    conv2 = Conv2D(64, (3, 3), activation="relu", padding="same")(pool1)
    conv2 = Conv2D(64, (3, 3), activation="relu", padding="same")(conv2)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)

    conv3 = Conv2D(128, (3, 3), activation="relu", padding="same")(pool2)
    conv3 = Conv2D(128, (3, 3), activation="relu", padding="same")(conv3)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)

    conv4 = Conv2D(256, (3, 3), activation="relu", padding="same")(pool3)
    conv4 = Conv2D(256, (3, 3), activation="relu", padding="same")(conv4)
    pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)

    conv5 = Conv2D(512, (3, 3), activation="relu", padding="same")(pool4)
    conv5 = Conv2D(512, (3, 3), activation="relu", padding="same")(conv5)

    # decoder: upsample and concatenate with the matching encoder features (axis=1: channels_first)
    up6 = concatenate([UpSampling2D(size=(2, 2))(conv5), conv4], axis=1)
    conv6 = Conv2D(256, (3, 3), activation="relu", padding="same")(up6)
    conv6 = Conv2D(256, (3, 3), activation="relu", padding="same")(conv6)

    up7 = concatenate([UpSampling2D(size=(2, 2))(conv6), conv3], axis=1)
    conv7 = Conv2D(128, (3, 3), activation="relu", padding="same")(up7)
    conv7 = Conv2D(128, (3, 3), activation="relu", padding="same")(conv7)

    up8 = concatenate([UpSampling2D(size=(2, 2))(conv7), conv2], axis=1)
    conv8 = Conv2D(64, (3, 3), activation="relu", padding="same")(up8)
    conv8 = Conv2D(64, (3, 3), activation="relu", padding="same")(conv8)

    up9 = concatenate([UpSampling2D(size=(2, 2))(conv8), conv1], axis=1)
    conv9 = Conv2D(32, (3, 3), activation="relu", padding="same")(up9)
    conv9 = Conv2D(32, (3, 3), activation="relu", padding="same")(conv9)

    conv10 = Conv2D(n_label, (1, 1), activation="sigmoid")(conv9)

    model = Model(inputs=inputs, outputs=conv10)
    model.compile(optimizer='Adam', loss='binary_crossentropy', metrics=['accuracy'])
    return model

The way we read the data changes slightly.
# data for training
def generateData(batch_size, data=[]):
    while True:
        train_data = []
        train_label = []
        batch = 0
        for i in range(len(data)):
            url = data[i]
            batch += 1
            img = load_img(filepath + 'src/' + url)
            img = img_to_array(img)
            train_data.append(img)
            label = load_img(filepath + 'label/' + url, grayscale=True)
            label = img_to_array(label)
            train_label.append(label)
            if batch % batch_size == 0:
                train_data = np.array(train_data)
                train_label = np.array(train_label)
                yield (train_data, train_label)
                train_data = []
                train_label = []
                batch = 0

# data for validation
def generateValidData(batch_size, data=[]):
    while True:
        valid_data = []
        valid_label = []
        batch = 0
        for i in range(len(data)):
            url = data[i]
            batch += 1
            img = load_img(filepath + 'src/' + url)
            img = img_to_array(img)
            valid_data.append(img)
            label = load_img(filepath + 'label/' + url, grayscale=True)
            label = img_to_array(label)
            valid_label.append(label)
            if batch % batch_size == 0:
                valid_data = np.array(valid_data)
                valid_label = np.array(valid_label)
                yield (valid_data, valid_label)
                valid_data = []
                valid_label = []
                batch = 0

Training: specify the output model name and the location of the training data:
python unet.py --model unet_buildings20.h5 --data ./unet_train/buildings/

When predicting a single remote sensing image we run each of the 4 models, so we get 4 masks (the figure below, for example, is the prediction of the trained buildings model). How should these 4 masks be merged into one? My idea: by inspecting the per-class predictions we can tell which classes are predicted more reliably, and assign the masks priorities accordingly, e.g. priority: buildings > water > roads > vegetation. When a pixel is claimed by more than one mask, it gets the label of the highest-priority mask. The code can be organized as follows:
# mask1_pool / mask2_pool / mask3_pool are the lists of the four per-class mask file names
# (vegetation, buildings, water, roads) for each test image, and img_sets holds the output
# file names; they are defined elsewhere in the project.
def combind_all_mask():
    for mask_num in tqdm(range(3)):
        # an all-zero (black) image with the same size as the corresponding test image
        if mask_num == 0:
            final_mask = np.zeros((5142, 5664), np.uint8)
        elif mask_num == 1:
            final_mask = np.zeros((2470, 4011), np.uint8)
        elif mask_num == 2:
            final_mask = np.zeros((6116, 3356), np.uint8)

        if mask_num == 0:
            mask_pool = mask1_pool
        elif mask_num == 1:
            mask_pool = mask2_pool
        elif mask_num == 2:
            mask_pool = mask3_pool
        final_name = img_sets[mask_num]

        for idx, name in enumerate(mask_pool):
            img = cv2.imread('./predict_mask/' + name, 0)
            height, width = img.shape
            label_value = idx + 1  # corresponding label value
            for i in tqdm(range(height)):  # priority: buildings > water > roads > vegetation
                for j in range(width):
                    if img[i, j] == 255:
                        if label_value == 2:
                            final_mask[i, j] = label_value
                        elif label_value == 3 and final_mask[i, j] != 2:
                            final_mask[i, j] = label_value
                        elif label_value == 4 and final_mask[i, j] != 2 and final_mask[i, j] != 3:
                            final_mask[i, j] = label_value
                        elif label_value == 1 and final_mask[i, j] == 0:
                            final_mask[i, j] = label_value
        cv2.imwrite('./final_result/' + final_name, final_mask)

print('combinding mask...')
combind_all_mask()

Model Ensembling
Ensembling is used all the time in this kind of competition, and to get a good ranking it has to be done well. Briefly: we used two model families, and each model was also trained and run with different settings, so we end up with many predicted masks. We can then ensemble them by majority voting per pixel: for each position, the class with the most votes across the result images becomes the final label for that pixel. As the saying goes, "three cobblers with their wits combined beat a mastermind" — this kind of ensembling removes many obviously misclassified pixels and improves the predictions considerably.
Code for the majority-vote strategy:
import numpy as np
import cv2

RESULT_PREFIXX = ['./result1/', './result2/', './result3/']

# each mask has 5 classes: 0~4
def vote_per_image(image_id):
    result_list = []
    for j in range(len(RESULT_PREFIXX)):
        im = cv2.imread(RESULT_PREFIXX[j] + str(image_id) + '.png', 0)
        result_list.append(im)

    # vote pixel by pixel
    height, width = result_list[0].shape
    vote_mask = np.zeros((height, width))
    for h in range(height):
        for w in range(width):
            record = np.zeros((1, 5))
            for n in range(len(result_list)):
                mask = result_list[n]
                pixel = mask[h, w]
                record[0, pixel] += 1
            label = record.argmax()  # class with the most votes wins
            vote_mask[h, w] = label
    cv2.imwrite('vote_mask' + str(image_id) + '.png', vote_mask)

vote_per_image(3)

The prediction after model ensembling:
We can see that ensembling does improve the predictions noticeably: the obviously misclassified pixels are gone.
An extra idea: GANs
Thinking more about the data side, and about the small size of the dataset, we had an idea: use a generative adversarial network to generate synthetic satellite imagery and further enlarge the dataset. The hope was that training on synthetic plus real data would improve the network's generalization. The idea follows the pix2pix paper, a very interesting piece of work on image-to-image translation. It shows how fake aerial imagery can be generated from annotated maps, which we found refreshing, so we wanted to follow that approach and build our own synthetic satellite image dataset — the Map-to-Aerial results in the paper are quite striking.
However, our own results were not encouraging (see the figure below; the image on the right is one of our generated fakes). There are many possible reasons, the biggest being the annotations: the generated satellite images were of poor quality, so the idea ended in failure and the fake images were never used for training. Still, the approach seems feasible; with better-suited annotations it should be possible to generate very convincing fake imagery.
Summary
For semantic segmentation of remote sensing imagery like this, there are many more possible approaches. The most obvious one is to implement the various classic segmentation networks, see which works best, and then ensemble them; with good ensembling the results are usually quite strong. With just the simple pipeline above (data augmentation, classic models, ensembling) we already reached the top 5% of the competition. There are further tricks that can push the results higher, but I won't go into them here — the overall modeling approach is what matters. The complete code is available on my GitHub.
Data download:
Link: https://pan.baidu.com/s/1i6oMukH
Password: yqj2