TensorLayer MNIST
Collected and compiled by 生活随笔. This post introduces TensorLayer MNIST; the editor found it worthwhile and is sharing it for reference.
Code
```python
# Written against the TensorFlow 1.x and TensorLayer 1.x APIs
# (placeholders, sessions, tl.layers.*Layer classes).
import tensorflow as tf
import tensorlayer as tl

sess = tf.InteractiveSession()

# Prepare data
X_train, y_train, X_val, y_val, X_test, y_test = \
    tl.files.load_mnist_dataset(shape=(-1, 784))

# Define placeholders
x = tf.placeholder(tf.float32, shape=[None, 784], name='x')
y_ = tf.placeholder(tf.int64, shape=[None], name='y_')

# Define the neural network structure
network = tl.layers.InputLayer(x, name='input_layer')
network = tl.layers.DropoutLayer(network, keep=0.8, name='drop1')
network = tl.layers.DenseLayer(network, n_units=800, act=tf.nn.relu, name='relu1')
network = tl.layers.DropoutLayer(network, keep=0.5, name='drop2')
network = tl.layers.DenseLayer(network, n_units=800, act=tf.nn.relu, name='relu2')
network = tl.layers.DropoutLayer(network, keep=0.5, name='drop3')
# The softmax is implemented internally in tl.cost.cross_entropy(y, y_) to
# speed up computation, so the output layer uses identity (raw logits).
# See tf.nn.sparse_softmax_cross_entropy_with_logits().
network = tl.layers.DenseLayer(network, n_units=10, act=tf.identity, name='output_layer')

# Define cost function and metric
y = network.outputs
cost = tl.cost.cross_entropy(y, y_, 'cost')
correct_prediction = tf.equal(tf.argmax(y, 1), y_)
acc = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
y_op = tf.argmax(tf.nn.softmax(y), 1)  # predicted class ids (not used below)

# Define the optimizer
train_params = network.all_params
train_op = tf.train.AdamOptimizer(learning_rate=0.0001, beta1=0.9, beta2=0.999,
                                  epsilon=1e-08, use_locking=False).minimize(cost, var_list=train_params)

# Initialize all variables in the session
tl.layers.initialize_global_variables(sess)

# Print network information
network.print_params()
network.print_layers()

# Train the network (for a hand-written training loop,
# tl.iterate.minibatches() is recommended)
tl.utils.fit(sess, network, train_op, cost, X_train, y_train, x, y_,
             acc=acc, batch_size=500, n_epoch=500, print_freq=5,
             X_val=X_val, y_val=y_val, eval_train=False)

# Evaluation
tl.utils.test(sess, network, acc, X_test, y_test, x, y_, batch_size=None, cost=cost)

# Save the network to a .npz file
tl.files.save_npz(network.all_params, name='model.npz')

sess.close()
```

Personally, I still prefer to write programs directly against the TensorFlow API, and I am not very fond of TensorLayer or TFLearn.
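The `DropoutLayer(..., keep=0.8)` and `keep=0.5` layers keep each activation with probability `keep`. As far as I know, TensorLayer 1.x builds these layers on `tf.nn.dropout`. That function scales the surviving values by `1/keep` so the expected activation is unchanged (inverted dropout). `tl.utils.fit` feeds the training keep probabilities, and evaluation runs with dropout switched off. The sketch below shows this behavior using only the standard library. The helper `inverted_dropout` is hypothetical and is not a TensorLayer API.

```python
import random

def inverted_dropout(values, keep, training=True, rng=random):
    """Zero each value with probability 1 - keep and scale survivors by 1/keep.

    At evaluation time (training=False) the input passes through unchanged,
    which is equivalent to running with keep = 1.0.
    """
    if not 0.0 < keep <= 1.0:
        raise ValueError("keep must be in (0, 1]")
    if not training or keep == 1.0:
        return list(values)
    # rng.random() < keep happens with probability keep
    return [v / keep if rng.random() < keep else 0.0 for v in values]
```

With `keep=0.5`, roughly half of the 800 hidden units are zeroed on each training step and the survivors are doubled. The average activation therefore stays close to what the layer produces at test time.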
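The comment above the output layer explains why the last `DenseLayer` uses `tf.identity`. The softmax is applied inside the cost, so the network must output raw logits. The TensorFlow documentation warns that the fused softmax cross-entropy ops expect unscaled logits, so feeding them softmax outputs would give wrong results. The following is a minimal pure-Python sketch of what a fused sparse softmax cross-entropy computes. It assumes integer labels like the `int64` placeholder `y_`, and the names `sparse_softmax_cross_entropy` and `accuracy` are hypothetical. The log-sum-exp trick keeps large logits from overflowing.

```python
import math

def sparse_softmax_cross_entropy(logits, labels):
    """Mean cross-entropy over a batch, computed directly from raw logits.

    logits: one list of class scores per example.
    labels: integer class ids, one per example.
    """
    total = 0.0
    for row, label in zip(logits, labels):
        m = max(row)  # subtract the max so exp() cannot overflow
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in row))
        # -log(softmax(row)[label]) == log_sum_exp - row[label]
        total += log_sum_exp - row[label]
    return total / len(labels)

def accuracy(logits, labels):
    """Fraction of rows whose argmax equals the label (mirrors `acc`)."""
    hits = sum(1 for row, label in zip(logits, labels)
               if max(range(len(row)), key=row.__getitem__) == label)
    return hits / len(labels)
```

For example, `sparse_softmax_cross_entropy([[1000.0, 0.0]], [0])` returns `0.0`, whereas evaluating `math.exp(1000.0)` directly raises `OverflowError`. Softmax is monotonic, so taking the argmax of the logits (as `acc` does) picks the same class as `y_op`, which takes the argmax after `tf.nn.softmax`.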
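The optimizer line passes `learning_rate=0.0001, beta1=0.9, beta2=0.999, epsilon=1e-08`. The sketch below shows what those numbers control, using a single scalar parameter and the Adam update as written in Algorithm 1 of Kingma and Ba. The TF 1.x `AdamOptimizer` documentation notes that it uses a slightly rearranged form in which `epsilon` plays the role of the paper's "epsilon hat", so results differ only in how epsilon enters. The helper `adam_step` is hypothetical.

```python
import math

def adam_step(theta, grad, m, v, t, lr=0.0001, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update for a single scalar parameter.

    t is the 1-based step count. Returns (new_theta, new_m, new_v).
    """
    m = beta1 * m + (1 - beta1) * grad         # moving average of gradients
    v = beta2 * v + (1 - beta2) * grad * grad  # moving average of squared gradients
    m_hat = m / (1 - beta1 ** t)               # undo the bias from starting at zero
    v_hat = v / (1 - beta2 ** t)
    theta = theta - lr * m_hat / (math.sqrt(v_hat) + eps)
    return theta, m, v
```

Because `m_hat / sqrt(v_hat)` is about plus or minus 1 on the first step, each parameter initially moves by about `learning_rate`, whether the gradient is 5.0 or -0.01. `beta1` and `beta2` set how long the gradient history is remembered.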
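The training comment recommends `tl.iterate.minibatches()` for hand-written loops. As I understand it, in TensorLayer 1.x it takes `(inputs, targets, batch_size, shuffle)` and yields fixed-size batches, dropping any leftover examples. The generator below sketches that idea using only the standard library. The name and the `rng` parameter are assumptions and do not describe the library's actual implementation.

```python
import random

def minibatches(inputs, targets, batch_size, shuffle=False, rng=random):
    """Yield (inputs_batch, targets_batch) pairs of exactly batch_size items.

    Examples left over after the last full batch are dropped, so every
    batch fed to the placeholders has the same shape.
    """
    if len(inputs) != len(targets):
        raise ValueError("inputs and targets must have the same length")
    indices = list(range(len(inputs)))
    if shuffle:
        rng.shuffle(indices)  # a new order on every call, i.e. every epoch
    for start in range(0, len(indices) - batch_size + 1, batch_size):
        batch = indices[start:start + batch_size]
        yield [inputs[i] for i in batch], [targets[i] for i in batch]
```

Suppose the training split holds 50,000 images. Then `batch_size=500` gives 100 updates per epoch, and `n_epoch=500` amounts to 50,000 Adam steps. In a TF 1.x loop, each batch would be passed to `sess.run(train_op, feed_dict=...)` with `x` and `y_`, together with the dropout keep probabilities in `network.all_drop`.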