TensorFlow实现Unet遥感图像分割
生活随笔
收集整理的這篇文章主要介紹了
TensorFlow实现Unet遥感图像分割
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
U-Net是一種U型網絡，由左右兩部分卷積組成：左側（編碼器）通過卷積與下採樣提取高維特徵；右側（解碼器）通過上採樣恢復分辨率，並與左側對應層的特徵拼接融合（跳躍連接），最終實現逐像素的圖像分割。這裡使用TensorFlow實現U-Net網絡，對遙感影像進行道路分割。
訓練數據:
標籤圖像（單通道PNG，像素值即類別編號：0為背景，1為道路；若標籤為0/255，需先轉換為0/1）:
Unet實現:
import glob
import itertools
import os

import cv2
import numpy as np
import tensorflow as tf


class UNet:
    """U-Net with unpadded ('valid') convolutions for semantic segmentation.

    Each encoder level applies two 3x3 ReLU convolutions; levels are separated
    by 2x2 max-pooling. Each decoder level upsamples by 2, concatenates the
    centre-cropped encoder feature map of the same level (skip connection) and
    applies two more 3x3 convolutions. A final 1x1 convolution + softmax gives
    per-pixel class probabilities. Because no convolution is padded, the output
    is smaller than the input (572x572 in -> 388x388 out).
    """

    # Filters per encoder level; the last entry is the bottleneck.
    _ENCODER_FILTERS = (64, 128, 256, 512, 1024)

    def __init__(self, input_width, input_height, num_classes, train_images,
                 train_instances, val_images, val_instances, epochs, lr,
                 lr_decay, batch_size, save_path):
        """Store the training configuration.

        Args:
            input_width, input_height: size images are resized to before being
                fed to the network.
            num_classes: number of segmentation classes (softmax channels).
            train_images, val_images: directories containing ``*.jpg`` images.
            train_instances, val_instances: directories containing ``*.png``
                masks whose pixel values are class indices (0..num_classes-1).
                Images and masks are paired by sorted file order.
            epochs: number of training epochs.
            lr: initial Adam learning rate.
            lr_decay: inverse-time decay per optimizer step,
                i.e. lr_t = lr / (1 + lr_decay * t).
            batch_size: samples per batch.
            save_path: where the trained model is saved (and resumed from).
        """
        self.input_width = input_width
        self.input_height = input_height
        self.num_classes = num_classes
        self.train_images = train_images
        self.train_instances = train_instances
        self.val_images = val_images
        self.val_instances = val_instances
        self.epochs = epochs
        self.lr = lr
        self.lr_decay = lr_decay
        self.batch_size = batch_size
        self.save_path = save_path

    @staticmethod
    def _double_conv(x, filters):
        """Apply two unpadded 3x3 ReLU convolutions with ``filters`` channels."""
        x = tf.keras.layers.Conv2D(filters, (3, 3), padding='valid', activation='relu')(x)
        return tf.keras.layers.Conv2D(filters, (3, 3), padding='valid', activation='relu')(x)

    @staticmethod
    def _crop_to_match(skip, target):
        """Centre-crop ``skip`` so its height/width equal those of ``target``.

        The crop is computed from the static shapes instead of being
        hard-coded, so any input size accepted by the encoder works; an odd
        difference is cropped one pixel more on the bottom/right.
        """
        dh = int(skip.shape[1]) - int(target.shape[1])
        dw = int(skip.shape[2]) - int(target.shape[2])
        if dh < 0 or dw < 0:
            raise ValueError(
                f'skip connection {skip.shape} is smaller than decoder map {target.shape}')
        cropping = ((dh // 2, dh - dh // 2), (dw // 2, dw - dw // 2))
        return tf.keras.layers.Cropping2D(cropping)(skip)

    def _output_size(self, size):
        """Return the output height/width the network produces for an input of ``size`` pixels."""
        depth = len(self._ENCODER_FILTERS) - 1
        for _ in range(depth):
            size = (size - 4) // 2  # two valid 3x3 convs, then 2x2 pooling
        size -= 4  # bottleneck convs
        for _ in range(depth):
            size = 2 * size - 4  # upsample x2, then two valid 3x3 convs
        return size

    def leftNetwork(self, inputs):
        """Build the contracting path.

        Returns:
            The feature map after each level's two convolutions, from the
            highest resolution down to the bottleneck: ``[o_1, ..., o_5]``.
        """
        outputs = []
        x = inputs
        for level, filters in enumerate(self._ENCODER_FILTERS):
            if level > 0:
                x = tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(2, 2))(x)
            x = self._double_conv(x, filters)
            outputs.append(x)
        return outputs

    def rightNetwork(self, inputs):
        """Build the expanding path and the per-pixel softmax classifier.

        Args:
            inputs: the list returned by :meth:`leftNetwork`; the last element
                is the bottleneck and the others are skip connections.
        """
        *skips, x = inputs
        for skip in reversed(skips):
            x = tf.keras.layers.UpSampling2D((2, 2))(x)
            x = tf.keras.layers.concatenate([self._crop_to_match(skip, x), x], axis=3)
            # Each decoder level uses as many filters as its skip connection.
            x = self._double_conv(x, int(skip.shape[-1]))
        x = tf.keras.layers.Conv2D(self.num_classes, (1, 1), padding='valid')(x)
        return tf.keras.layers.Activation('softmax')(x)

    def build_model(self):
        """Build the (uncompiled) Keras model for the configured input size."""
        inputs = tf.keras.Input(shape=[self.input_height, self.input_width, 3])
        left_output = self.leftNetwork(inputs)
        right_output = self.rightNetwork(left_output)
        return tf.keras.Model(inputs=inputs, outputs=right_output)

    def train(self, steps_per_epoch=5, validation_steps=5, resume=True):
        """Train the model and save it to ``save_path``.

        Args:
            steps_per_epoch: training batches per epoch.
            validation_steps: validation batches per epoch.
            resume: if true and ``save_path`` exists, continue training from
                that checkpoint; otherwise build a fresh model.
        """
        if resume and os.path.exists(self.save_path):
            model = tf.keras.models.load_model(self.save_path)
        else:
            model = self.build_model()
        # Adam's second positional argument is beta_1, so the decay must not be
        # passed positionally. This schedule reproduces legacy Keras ``decay``:
        # lr_t = lr / (1 + lr_decay * step).
        schedule = tf.keras.optimizers.schedules.InverseTimeDecay(
            initial_learning_rate=self.lr, decay_steps=1, decay_rate=self.lr_decay)
        model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=schedule),
                      loss='categorical_crossentropy',
                      metrics=['accuracy'])
        # Labels must match the (smaller) network output, not the input size.
        label_size = tuple(model.output_shape[1:3])
        G_train = self.dataGenerator(model='training', label_size=label_size)
        G_eval = self.dataGenerator(model='validation', label_size=label_size)
        model.fit(G_train, steps_per_epoch=steps_per_epoch,
                  validation_data=G_eval, validation_steps=validation_steps,
                  epochs=self.epochs)
        model.save(self.save_path)

    def dataGenerator(self, model, label_size=None):
        """Return an endless generator of ``(images, one_hot_labels)`` batches.

        Args:
            model: ``'training'`` or ``'validation'``; selects the dataset. The
                name is kept for backward compatibility and is a mode string,
                not a Keras model.
            label_size: ``(height, width)`` that masks are resized to so they
                match the network output. It defaults to the output size for
                the configured input size.

        Raises:
            ValueError: unknown mode, no images found, or image/mask counts differ.
        """
        if model == 'training':
            image_dir, label_dir = self.train_images, self.train_instances
        elif model == 'validation':
            image_dir, label_dir = self.val_images, self.val_instances
        else:
            raise ValueError(f"model must be 'training' or 'validation', got {model!r}")
        if label_size is None:
            label_size = (self._output_size(self.input_height),
                          self._output_size(self.input_width))
        images = sorted(glob.glob(os.path.join(image_dir, '*.jpg')))
        labels = sorted(glob.glob(os.path.join(label_dir, '*.png')))
        if not images:
            raise ValueError(f'no *.jpg images found in {image_dir!r}')
        if len(images) != len(labels):
            raise ValueError(
                f'{len(images)} images in {image_dir!r} but {len(labels)} masks in {label_dir!r}')
        return self._batches(list(zip(images, labels)), label_size)

    def _batches(self, pairs, label_size):
        """Yield ``batch_size`` batches forever, cycling through ``pairs`` in order."""
        label_height, label_width = label_size
        cycled = itertools.cycle(pairs)
        while True:
            x_batch, y_batch = [], []
            for _ in range(self.batch_size):
                image_path, label_path = next(cycled)
                x_batch.append(self._load_image(image_path))
                y_batch.append(self._load_label(label_path, label_width, label_height))
            yield np.stack(x_batch), np.stack(y_batch)

    def _load_image(self, path):
        """Read a BGR image, resize it to the network input and scale it to [0, 1] float32."""
        image = cv2.imread(path, cv2.IMREAD_COLOR)
        if image is None:
            raise OSError(f'cannot read image: {path}')
        image = cv2.resize(image, (self.input_width, self.input_height))
        return image.astype(np.float32) / 255.0

    def _load_label(self, path, width, height):
        """Read a class-index mask, resize it to ``(height, width)`` and one-hot encode it.

        NOTE(review): with unpadded convolutions, output pixel (i, j) is roughly
        centred on input pixel (i + 92, j + 92) for a 572 input. Stretching the
        whole mask to the output size therefore aligns exactly only at the
        centre. Cropping the centre of a mask resized to the input size would
        align exactly. The stretch is kept to match the inference/overlay
        scripts, which stretch the prediction back over the whole image.
        """
        label = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if label is None:
            raise OSError(f'cannot read mask: {path}')
        if label.shape != (height, width):
            # Nearest-neighbour keeps class indices intact (no blended values).
            label = cv2.resize(label, (width, height), interpolation=cv2.INTER_NEAREST)
        if label.max() >= self.num_classes:
            raise ValueError(
                f'{path}: mask value {label.max()} >= num_classes ({self.num_classes}); '
                'masks must store class indices (e.g. 0/1, not 0/255)')
        return tf.keras.utils.to_categorical(label, self.num_classes)

# Training script:
if __name__ == '__main__':
    # Assumes UNet is defined in (or imported into) this module.
    # NOTE(review): batch_size=100 at 572x572 needs a lot of activation memory
    # (the first conv output alone is 100x570x570x64 float32 values, ~8 GB).
    # Lower it if training runs out of memory.
    unet = UNet(
        input_width=572,
        input_height=572,
        num_classes=2,
        train_images='./datasets/train/images/',
        train_instances='./datasets/train/instances/',
        val_images='./datasets/validation/images/',
        val_instances='./datasets/validation/instances/',
        epochs=100,
        lr=0.0001,
        lr_decay=0.00001,
        batch_size=100,
        save_path='model.h5',
    )
    unet.train()

# Only roads and background are segmented, i.e. binary classification. The
# network output for each image has shape 388x388x2 (height, width, classes;
# Keras is channels-last). After 100 epochs the model is saved and then used
# for inference.
推理腳本:
import cv2
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf


def predict_mask(model, image_path):
    """Run ``model`` on one image and return an 8-bit mask (class 1 -> 255).

    The image is preprocessed exactly like in training: BGR, resized to the
    model's input size, scaled to [0, 1]. The returned mask has the model's
    output size, which is smaller than its input because the convolutions are
    unpadded.
    """
    image = cv2.imread(image_path)
    if image is None:
        raise OSError(f'cannot read image: {image_path}')
    # Take the input size from the model instead of hard-coding 572; assumes
    # the model was built with a fixed input shape.
    height, width = model.input_shape[1:3]
    image = cv2.resize(image, (width, height)).astype(np.float32) / 255.0
    probs = model.predict(image[np.newaxis])[0]
    classes = np.argmax(probs, axis=-1)
    # argmax returns int64, which cv2.imwrite cannot encode, so convert to uint8.
    return np.where(classes == 1, 255, classes).astype(np.uint8)


def main(model_path='model.h5', image_path='17.jpg', result_path='result.jpg'):
    """Predict the road mask for ``image_path``, save it and display it."""
    model = tf.keras.models.load_model(model_path)
    mask = predict_mask(model, image_path)
    # NOTE(review): JPEG is lossy and blurs the binary mask. A .png path would
    # keep it exact, but the overlay script reads result.jpg, so change both together.
    cv2.imwrite(result_path, mask)
    plt.imshow(mask, cmap='gray')
    plt.show()


if __name__ == '__main__':
    main()

# Test image:
推理結果:
將推理結果的道路邊緣與原始圖像疊加顯示（邊緣以紅色標出）：
import cv2
import numpy as np


def main(img_path='17.jpg', result_path='result.jpg', edges_path='temp.jpg',
         out_path='out.jpg'):
    """Draw the edges of the predicted road mask in red on the original image.

    Writes the edge map to ``edges_path`` and the annotated image to ``out_path``.
    """
    img = cv2.imread(img_path)
    if img is None:
        raise OSError(f'cannot read image: {img_path}')
    height, width = img.shape[:2]
    result = cv2.imread(result_path, cv2.IMREAD_GRAYSCALE)
    if result is None:
        raise OSError(f'cannot read mask: {result_path}')
    # cv2.resize takes dsize as (width, height), and interpolation must be
    # passed by keyword (the third positional parameter is dst).
    result = cv2.resize(result, (width, height), interpolation=cv2.INTER_LINEAR)
    # The mask was saved as JPEG with values 0/255. Re-binarise it so that
    # compression artefacts and interpolation ramps do not produce spurious
    # Canny edges.
    result = np.where(result > 127, 255, 0).astype(np.uint8)
    edges = cv2.Canny(result, 0, 255)
    # Paint edge pixels red (OpenCV images are BGR).
    img[edges == 255] = (0, 0, 255)
    cv2.imwrite(edges_path, edges)
    cv2.imwrite(out_path, img)


if __name__ == '__main__':
    main()

# Summary
以上是生活随笔為你收集整理的TensorFlow实现Unet遥感图像分割的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: TensorFlow模型转换h5转pb
- 下一篇: OpenCV图像处理基础操作汇总