机器学习笔记 - 使用Keras + Unet 进行图像分割
一、U-Net簡介
?????????U-Net 是最初為醫學影像分割而提出的一種語義分割技術。 它是較早的深度學習分割模型之一,U-Net 架構也用于許多 GAN 變體,例如 Pix2Pix 生成器。
????????U-Net 在論文 U-Net: Convolutional Networks for Biomedical Image Segmentation 中進行了介紹。 模型架構相當簡單:一個編碼器(用于下采樣)和一個解碼器(用于上采樣),帶有跳躍連接。 如圖 1 所示,它的形狀像字母 U,因此得名 U-Net。
二、數據集說明
????????我們將使用作為 TensorFlow 數據集 (TFDS) 的一部分提供的 Oxford-IIIT 寵物數據集。 它可以很容易地用 TFDS 加載,然后進行一些數據預處理,為訓練分割模型做好準備。
????????可以使用 tfds 通過指定數據集的名稱來加載數據集,并通過設置 with_info=True 來獲取數據集信息:
? ? ? ? 代碼如下,如果多次運行程序,第一次下載完之后可以添加download=False參數,會自動從已經下載好的文件夾下讀取數據。
dataset, info = tfds.load('oxford_iiit_pet:3.*.*', with_info=True)????????使用 print(info) 打印數據集信息,我們將看到牛津寵物數據集的各種詳細信息。 例如,在圖 2 中,我們可以看到共有 7349 張圖像,其中包含內置的測試/訓練拆分。
三、相關代碼
1、unet模型
????????U-Net 的架構相當簡單; 然而,為了在編碼器和解碼器之間創建跳躍連接,我們需要連接一些層。 所以 Keras 函數式 API 最適合這個目的。
????????首先,我們創建一個 build_unet_model 函數,指定輸入、編碼器層、瓶頸、解碼器層,最后是帶有激活 softmax 的 Conv2D 的輸出層。 注意輸入圖像的形狀是 128x128x3。 輸出具有三個通道,對應于模型將為每個像素分類的三個類:背景、前景對象和對象輪廓。
import tensorflow as tf from tensorflow import keras from tensorflow.keras import layers import tensorflow_datasets as tfds import matplotlib.pyplot as plt import numpy as np# 在編碼器和U-Net的瓶頸中使用 def double_conv_block(x, n_filters):# Conv2D then ReLU activationx = layers.Conv2D(n_filters, 3, padding = "same", activation = "relu", kernel_initializer = "he_normal")(x)# Conv2D then ReLU activationx = layers.Conv2D(n_filters, 3, padding = "same", activation = "relu", kernel_initializer = "he_normal")(x)return x# 用于在編碼器中進行下采樣或特征提取 def downsample_block(x, n_filters):f = double_conv_block(x, n_filters)p = layers.MaxPool2D(2)(f)p = layers.Dropout(0.3)(p)return f, p# 上采樣函數 upsample_block def upsample_block(x, conv_features, n_filters):# upsamplex = layers.Conv2DTranspose(n_filters, 3, 2, padding="same")(x)# concatenatex = layers.concatenate([x, conv_features])# dropoutx = layers.Dropout(0.3)(x)# Conv2D twice with ReLU activationx = double_conv_block(x, n_filters)return x# 創建模型 def build_unet_model():# inputsinputs = layers.Input(shape=(128, 128, 3))# encoder: contracting path - downsample# 1 - downsamplef1, p1 = downsample_block(inputs, 64)# 2 - downsamplef2, p2 = downsample_block(p1, 128)# 3 - downsamplef3, p3 = downsample_block(p2, 256)# 4 - downsamplef4, p4 = downsample_block(p3, 512)# 5 - bottleneckbottleneck = double_conv_block(p4, 1024)# decoder: expanding path - upsample# 6 - upsampleu6 = upsample_block(bottleneck, f4, 512)# 7 - upsampleu7 = upsample_block(u6, f3, 256)# 8 - upsampleu8 = upsample_block(u7, f2, 128)# 9 - upsampleu9 = upsample_block(u8, f1, 64)# outputsoutputs = layers.Conv2D(3, 1, padding="same", activation="softmax")(u9)# unet model with Keras Functional APIunet_model = tf.keras.Model(inputs, outputs, name="U-Net")return unet_model2、訓練代碼?
? ? ? ? 運行train函數進行訓練。
dataset, info = tfds.load('oxford_iiit_pet:3.*.*', with_info=True, download=False) train_dataset = dataset["train"].map(load_image_train, num_parallel_calls=tf.data.AUTOTUNE) test_dataset = dataset["test"].map(load_image_test, num_parallel_calls=tf.data.AUTOTUNE)BATCH_SIZE = 32 BUFFER_SIZE = 1000 train_batches = train_dataset.cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE).repeat() train_batches = train_batches.prefetch(buffer_size=tf.data.experimental.AUTOTUNE) validation_batches = test_dataset.take(3000).batch(BATCH_SIZE) test_batches = test_dataset.skip(3000).take(669).batch(BATCH_SIZE)def train():unet_model = build_unet_model()unet_model.compile(optimizer=tf.keras.optimizers.Adam(),loss="sparse_categorical_crossentropy",metrics="accuracy")NUM_EPOCHS = 40TRAIN_LENGTH = info.splits["train"].num_examplesSTEPS_PER_EPOCH = TRAIN_LENGTH // BATCH_SIZEVAL_SUBSPLITS = 5TEST_LENTH = info.splits["test"].num_examplesVALIDATION_STEPS = TEST_LENTH // BATCH_SIZE // VAL_SUBSPLITSmodel_history = unet_model.fit(train_batches,epochs=NUM_EPOCHS,steps_per_epoch=STEPS_PER_EPOCH,validation_steps=VALIDATION_STEPS,validation_data=test_batches)unet_model.save('unet.h5')3、其它函數
# 修改大小 def resize(input_image, input_mask):input_image = tf.image.resize(input_image, (128, 128), method="nearest")input_mask = tf.image.resize(input_mask, (128, 128), method="nearest")return input_image, input_mask# 水平翻轉 def augment(input_image, input_mask):if tf.random.uniform(()) > 0.5:# Random flipping of the image and maskinput_image = tf.image.flip_left_right(input_image)input_mask = tf.image.flip_left_right(input_mask)return input_image, input_mask# 規范化數據集 def normalize(input_image, input_mask):input_image = tf.cast(input_image, tf.float32) / 255.0input_mask -= 1return input_image, input_mask# 加載訓練集 def load_image_train(datapoint):input_image = datapoint["image"]input_mask = datapoint["segmentation_mask"]input_image, input_mask = resize(input_image, input_mask)input_image, input_mask = augment(input_image, input_mask)input_image, input_mask = normalize(input_image, input_mask)return input_image, input_mask# 加載測試集 def load_image_test(datapoint):input_image = datapoint["image"]input_mask = datapoint["segmentation_mask"]input_image, input_mask = resize(input_image, input_mask)input_image, input_mask = normalize(input_image, input_mask)return input_image, input_mask# 創建mask def create_mask(pred_mask):pred_mask = tf.argmax(pred_mask, axis=-1)pred_mask = pred_mask[..., tf.newaxis]return pred_mask[0]# 顯示預測結果 def show_predictions(dataset=None, num=1, unet_model=None):if dataset:for image, mask in dataset.take(num):pred_mask = unet_model.predict(image)display([image[0], mask[0], create_mask(pred_mask)])else:display([sample_image, sample_mask, create_mask(unet_model.predict(sample_image[tf.newaxis, ...]))])# 可視化 def display(display_list):plt.figure(figsize=(15, 15))title = ["Input Image", "True Mask", "Predicted Mask"]for i in range(len(display_list)):plt.subplot(1, len(display_list), i+1)plt.title(title[i])plt.imshow(tf.keras.utils.array_to_img(display_list[i]))plt.axis("off")plt.show()# sample_batch = next(iter(train_batches)) # random_index = np.random.choice(sample_batch[0].shape[0]) # sample_image, sample_mask = sample_batch[0][random_index], sample_batch[1][random_index] # display([sample_image, sample_mask])4、調用模型進行測試
? ? ? ? 加載訓練好的模型,調用上面的函數,可以進行測試,測試結果如下圖
model = load_model('unet.h5') show_predictions(test_batches, 1, model)四、其他參考
機器學習筆記 - Keras + TensorFlow2.0 + Unet進行語義分割_bashendixie5的博客-CSDN博客https://blog.csdn.net/bashendixie5/article/details/115795171
總結
以上是生活随笔為你收集整理的机器学习笔记 - 使用Keras + Unet 进行图像分割的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: qtreewidget点击空白处时取消以
- 下一篇: java弹出提示窗口_Java实现弹窗效