TensorFlow模型实现:UNet模型
生活随笔收集整理的这篇文章主要介绍了TensorFlow模型实现:UNet模型,小编觉得挺不错的,现在分享给大家,帮大家做个参考。
TensorFlow模型实现:UNet模型
1.UNet模型
# -*- coding: utf-8 -*-
"""
@Project: triple_path_networks
@File   : UNet.py
@Author : panjq
@E-mail : pan_jinquan@163.com
@Date   : 2019-01-24 11:18:15

U-Net built with TensorFlow 1.x and tf.contrib.slim.
"""
import tensorflow as tf
import tensorflow.contrib.slim as slim


def lrelu(x):
    """Leaky ReLU with a fixed negative slope of 0.2, i.e. max(0.2 * x, x)."""
    return tf.maximum(x * 0.2, x)


# Activation shared by every hidden convolution in the network.
activation_fn = lrelu

# Output heads supported by UNet(); see its docstring.
_OUT_TYPES = ('UNet_1X', 'UNet_3X')


def _conv_block(x, num_outputs, scope_prefix, reg):
    """Apply two 3x3 slim convolutions with the shared activation.

    The layers are scoped '<scope_prefix>_1' and '<scope_prefix>_2' (e.g.
    'g_conv1_1', 'g_conv1_2') so variable names match the original layer-by-layer
    definition and previously saved checkpoints still restore.
    """
    x = slim.conv2d(x, num_outputs, [3, 3], rate=1, activation_fn=activation_fn,
                    scope=scope_prefix + '_1', weights_regularizer=reg)
    x = slim.conv2d(x, num_outputs, [3, 3], rate=1, activation_fn=activation_fn,
                    scope=scope_prefix + '_2', weights_regularizer=reg)
    return x


def upsample_and_concat(x1, x2, output_channels, in_channels):
    """Upsample x1 by 2x with a transposed convolution and concatenate it with x2.

    Args:
        x1: NHWC feature map to upsample; must have `in_channels` channels.
        x2: NHWC skip-connection feature map. The transposed convolution's output
            shape is taken from tf.shape(x2), so x2 is expected to have
            `output_channels` channels.
        output_channels: number of channels produced by the transposed convolution.
        in_channels: number of channels of x1.

    Returns:
        Tensor of x2's spatial size with `output_channels * 2` channels
        (upsampled x1 followed by x2 along the channel axis).
    """
    pool_size = 2
    # NOTE: this filter is a plain tf.Variable, so the weights regularizer passed
    # to UNet() is not applied to it (unlike the slim.conv2d layers).
    deconv_filter = tf.Variable(
        tf.truncated_normal([pool_size, pool_size, output_channels, in_channels], stddev=0.02))
    deconv = tf.nn.conv2d_transpose(x1, deconv_filter, tf.shape(x2),
                                    strides=[1, pool_size, pool_size, 1])
    deconv_output = tf.concat([deconv, x2], 3)
    # The dynamic output_shape can hide the channel count; restore it statically.
    deconv_output.set_shape([None, None, None, output_channels * 2])
    return deconv_output


def UNet(inputs, reg, out_type='UNet_1X'):
    """Build a 4-level U-Net (32-64-128-256-512 channels) on `inputs`.

    Args:
        inputs: NHWC input tensor (the demo below feeds shape [4, 100, 200, 3]).
        reg: weights regularizer applied to every slim.conv2d layer,
            e.g. slim.l2_regularizer(scale=...).
        out_type: output head to build.
            'UNet_1X' (default): 1x1 convolution to 6 channels at the same
            spatial size as `inputs`.
            'UNet_3X': 1x1 convolution to 27 channels followed by
            tf.depth_to_space(..., 3), i.e. 3 channels at 3x the spatial size.

    Returns:
        The output tensor of the selected head.

    Raises:
        ValueError: if `out_type` is not one of 'UNet_1X' or 'UNet_3X'.
    """
    # Validate before creating any variables, so a bad argument leaves no partial graph.
    if out_type not in _OUT_TYPES:
        raise ValueError("Unsupported out_type {!r}; expected one of {}".format(out_type, _OUT_TYPES))

    # Encoder: two 3x3 convs per level, then 2x2 max-pooling.
    conv1 = _conv_block(inputs, 32, 'g_conv1', reg)
    pool1 = slim.max_pool2d(conv1, [2, 2], padding='SAME')
    conv2 = _conv_block(pool1, 64, 'g_conv2', reg)
    pool2 = slim.max_pool2d(conv2, [2, 2], padding='SAME')
    conv3 = _conv_block(pool2, 128, 'g_conv3', reg)
    pool3 = slim.max_pool2d(conv3, [2, 2], padding='SAME')
    conv4 = _conv_block(pool3, 256, 'g_conv4', reg)
    pool4 = slim.max_pool2d(conv4, [2, 2], padding='SAME')
    conv5 = _conv_block(pool4, 512, 'g_conv5', reg)

    # Decoder: upsample, concatenate with the matching encoder level, two 3x3 convs.
    up6 = upsample_and_concat(conv5, conv4, 256, 512)
    conv6 = _conv_block(up6, 256, 'g_conv6', reg)
    up7 = upsample_and_concat(conv6, conv3, 128, 256)
    conv7 = _conv_block(up7, 128, 'g_conv7', reg)
    up8 = upsample_and_concat(conv7, conv2, 64, 128)
    conv8 = _conv_block(up8, 64, 'g_conv8', reg)
    up9 = upsample_and_concat(conv8, conv1, 32, 64)
    conv9 = _conv_block(up9, 32, 'g_conv9', reg)
    print("conv9.shape:{}".format(conv9.get_shape()))

    with tf.variable_scope(name_or_scope="output"):
        if out_type == 'UNet_3X':
            # 27 = 3 output channels * 3 * 3; depth_to_space(3) upscales H and W by 3.
            conv10 = slim.conv2d(conv9, 27, [1, 1], rate=1, activation_fn=None,
                                 scope='g_conv10', weights_regularizer=reg)
            out = tf.depth_to_space(conv10, 3)
        else:  # 'UNet_1X': output keeps the input's spatial size.
            out = slim.conv2d(conv9, 6, [1, 1], rate=1, activation_fn=None,
                              scope='g_conv10', weights_regularizer=reg)
    return out


if __name__ == "__main__":
    weight_decay = 0.001
    reg = slim.l2_regularizer(scale=weight_decay)
    inputs = tf.ones(shape=[4, 100, 200, 3])
    out = UNet(inputs, reg)
    print("inputs.shape:{}".format(inputs.get_shape()))
    print("out.shape:{}".format(out.get_shape()))
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # Run one forward pass so the demo actually exercises the built graph.
        print("out value shape:{}".format(sess.run(out).shape))
总结
以上是生活随笔为你收集整理的TensorFlow模型实现:UNet模型的全部内容,希望文章能够帮你解决所遇到的问题。
- 上一篇: Dilated/Atrous conv
- 下一篇: python实现交并比IOU