TF CNN: Training and evaluating a CNN (2+2) model on the CIFAR-10 dataset (TensorBoard visualization)
Contents
1. Code for training the CNN (2+2) model on the CIFAR-10 dataset
2. Evaluating the CNN (2+2) model
3. Viewing the loss curve in TensorBoard
1. Code for training the CNN (2+2) model on the CIFAR-10 dataset
from datetime import datetime
import time

import tensorflow as tf  # TensorFlow 1.x API (tf.app, tf.train.MonitoredTrainingSession)

import cifar10  # cifar10.py from the TensorFlow CIFAR-10 tutorial; it also defines FLAGS.batch_size

FLAGS = tf.app.flags.FLAGS

tf.app.flags.DEFINE_string('train_dir', '/tmp/cifar10_train',
                           """Directory where to write event logs """
                           """and checkpoint.""")  # directory for event logs and checkpoints
tf.app.flags.DEFINE_integer('max_steps', 1000000,
                            """Number of batches to run.""")  # number of batches (steps) to run
tf.app.flags.DEFINE_boolean('log_device_placement', False,
                            """Whether to log device placement.""")  # log which device each op runs on
tf.app.flags.DEFINE_integer('log_frequency', 10,
                            """How often to log results to the console.""")  # print results every N steps


def train():
    """Train CIFAR-10 for a number of steps."""
    with tf.Graph().as_default():
        # On older TF versions: tf.contrib.framework.get_or_create_global_step()
        global_step = tf.train.get_or_create_global_step()

        # Get images and labels for CIFAR-10.
        images, labels = cifar10.distorted_inputs()

        # Build a Graph that computes the logits predictions from the
        # inference model.
        logits = cifar10.inference(images)

        # Calculate loss.
        loss = cifar10.loss(logits, labels)

        # Build a Graph that trains the model with one batch of examples and
        # updates the model parameters.
        train_op = cifar10.train(loss, global_step)

        class _LoggerHook(tf.train.SessionRunHook):
            """Logs loss and runtime."""

            def begin(self):
                self._step = -1
                self._start_time = time.time()

            def before_run(self, run_context):
                self._step += 1
                return tf.train.SessionRunArgs(loss)  # Asks for loss value.

            def after_run(self, run_context, run_values):
                if self._step % FLAGS.log_frequency == 0:
                    current_time = time.time()
                    duration = current_time - self._start_time
                    self._start_time = current_time

                    loss_value = run_values.results
                    examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
                    sec_per_batch = float(duration / FLAGS.log_frequency)

                    format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                                  'sec/batch)')
                    print(format_str % (datetime.now(), self._step, loss_value,
                                        examples_per_sec, sec_per_batch))

        with tf.train.MonitoredTrainingSession(
                checkpoint_dir=FLAGS.train_dir,  # where event logs and checkpoints are written
                hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps),  # stop after max_steps batches
                       tf.train.NanTensorHook(loss),  # abort if the loss becomes NaN
                       _LoggerHook()],
                config=tf.ConfigProto(
                    log_device_placement=FLAGS.log_device_placement)) as mon_sess:
            while not mon_sess.should_stop():
                mon_sess.run(train_op)


def main(argv=None):  # pylint: disable=unused-argument
    cifar10.maybe_download_and_extract()
    if tf.gfile.Exists(FLAGS.train_dir):
        tf.gfile.DeleteRecursively(FLAGS.train_dir)
    tf.gfile.MakeDirs(FLAGS.train_dir)
    train()


if __name__ == '__main__':
    # Override flag defaults with real ints/bools, not strings: the string '10'
    # would make `self._step % FLAGS.log_frequency` raise TypeError, and the
    # string 'False' is truthy.
    FLAGS.train_dir = 'cifar10_train/'
    FLAGS.max_steps = 1000000
    FLAGS.log_device_placement = False
    FLAGS.log_frequency = 10
    tf.app.run()

Console output:
Filling queue with 20000 CIFAR images before starting to train. This will take a few minutes.
2018-09-21 11:15:53.399945: step 0, loss = 4.67 (0.7 examples/sec; 177.888 sec/batch)
2018-09-21 11:17:13.770461: step 10, loss = 4.62 (15.9 examples/sec; 8.037 sec/batch)
2018-09-21 11:19:10.122213: step 20, loss = 4.36 (11.0 examples/sec; 11.635 sec/batch)
2018-09-21 11:21:01.145664: step 30, loss = 4.34 (11.5 examples/sec; 11.102 sec/batch)
2018-09-21 11:22:55.463296: step 40, loss = 4.37 (11.2 examples/sec; 11.432 sec/batch)
2018-09-21 11:24:43.938444: step 50, loss = 4.45 (11.8 examples/sec; 10.848 sec/batch)
2018-09-21 11:26:36.091383: step 60, loss = 4.29 (11.4 examples/sec; 11.215 sec/batch)
2018-09-21 11:28:27.229967: step 70, loss = 4.12 (11.5 examples/sec; 11.114 sec/batch)
2018-09-21 11:30:24.759522: step 80, loss = 4.04 (10.9 examples/sec; 11.753 sec/batch)
2018-09-21 11:32:04.392507: step 90, loss = 4.14 (12.8 examples/sec; 9.963 sec/batch)
2018-09-21 11:33:50.161788: step 100, loss = 4.08 (12.1 examples/sec; 10.577 sec/batch)
2018-09-21 11:35:27.867156: step 110, loss = 4.05 (13.1 examples/sec; 9.771 sec/batch)
2018-09-21 11:36:59.189017: step 120, loss = 3.99 (14.0 examples/sec; 9.132 sec/batch)
2018-09-21 11:38:44.246431: step 130, loss = 3.93 (12.2 examples/sec; 10.506 sec/batch)
2018-09-21 11:40:27.267226: step 140, loss = 4.12 (12.4 examples/sec; 10.302 sec/batch)
2018-09-21 11:42:20.492360: step 150, loss = 3.94 (11.3 examples/sec; 11.323 sec/batch)
2018-09-21 11:44:05.324174: step 160, loss = 3.93 (12.2 examples/sec; 10.483 sec/batch)
2018-09-21 11:45:45.123575: step 170, loss = 3.80 (12.8 examples/sec; 9.980 sec/batch)
2018-09-21 11:47:31.441841: step 180, loss = 3.95 (12.0 examples/sec; 10.632 sec/batch)
2018-09-21 11:49:19.129222: step 190, loss = 3.90 (11.9 examples/sec; 10.769 sec/batch)
2018-09-21 11:50:58.325049: step 200, loss = 4.15 (12.9 examples/sec; 9.920 sec/batch)
2018-09-21 11:52:34.784594: step 210, loss = 3.92 (13.3 examples/sec; 9.646 sec/batch)
2018-09-21 11:54:32.453522: step 220, loss = 3.81 (10.9 examples/sec; 11.767 sec/batch)
2018-09-21 11:56:33.002429: step 230, loss = 3.87 (10.6 examples/sec; 12.055 sec/batch)
2018-09-21 11:58:19.417427: step 240, loss = 3.67 (12.0 examples/sec; 10.641 sec/batch)
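The throughput figures in this log come straight from the arithmetic in _LoggerHook.after_run: every log_frequency steps, the elapsed wall-clock time is divided by the number of steps (sec/batch), and the number of examples processed is divided by the elapsed time (examples/sec). The minimal sketch below reproduces that arithmetic in plain Python. throughput is a hypothetical helper name, and batch_size = 128 is an assumption taken from the default in the tutorial's cifar10.py; it is consistent with the log (10 × 128 / 80.37 s ≈ 15.9 examples/sec at step 10).

```python
def throughput(duration, log_frequency=10, batch_size=128):
    """Hypothetical helper mirroring the arithmetic in _LoggerHook.after_run.

    duration: seconds elapsed over the last `log_frequency` steps.
    batch_size=128 is an assumed value (the cifar10.py default).
    Returns (examples_per_sec, sec_per_batch).
    """
    examples_per_sec = log_frequency * batch_size / duration
    sec_per_batch = float(duration / log_frequency)
    return examples_per_sec, sec_per_batch


# Step 10 above reports 8.037 sec/batch, i.e. about 80.37 s for 10 batches.
eps, spb = throughput(80.37)
print('%.1f examples/sec; %.3f sec/batch' % (eps, spb))
# → 15.9 examples/sec; 8.037 sec/batch
```

The step-0 line appears to be an outlier for a structural reason: its interval starts when the hook's begin() runs, so it covers session start-up and the initial queue fill rather than ten training steps, yet it is still divided by log_frequency.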
2. Evaluating the CNN (2+2) model
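Evaluating the model on the CIFAR-10 test set shows that it already reaches 85.99% accuracy at around 60,000 steps and 86.38% at 100,000 steps; after 150,000 steps the accuracy levels off at about 86.66%.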
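The evaluation code is not included in this post. In the TensorFlow CIFAR-10 tutorial, the companion script cifar10_eval.py reports this accuracy as "precision @ 1": it runs the trained model over the test set, uses tf.nn.in_top_k(logits, labels, 1) to check whether the true label is the top-scoring class, and divides the number of hits by the number of test images. A framework-free sketch of the same metric follows. precision_at_1 and the toy logits are hypothetical, and ties are broken by taking the first highest score, which may differ from how tf.nn.in_top_k treats ties.

```python
def precision_at_1(logits, labels):
    """Hypothetical helper: fraction of examples whose top-scoring class is the true label.

    logits: list of per-class score lists, one per example.
    labels: list of integer class ids.
    Ties go to the first maximal class (tf.nn.in_top_k may handle ties differently).
    """
    if not labels or len(logits) != len(labels):
        raise ValueError('need matching, non-empty logits and labels')
    true_count = 0
    for scores, label in zip(logits, labels):
        # max() returns the first index with the highest score.
        predicted = max(range(len(scores)), key=lambda i: scores[i])
        if predicted == label:
            true_count += 1
    return true_count / float(len(labels))


# Three toy examples over 3 classes; the first two are classified correctly.
logits = [[2.0, 0.1, -1.0], [0.3, 0.2, 1.5], [0.9, 1.1, 0.0]]
labels = [0, 2, 0]
print('precision @ 1 = %.3f' % precision_at_1(logits, labels))
# → precision @ 1 = 0.667
```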
3. Viewing the loss curve in TensorBoard
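The raw training loss is noisy (in the console log it jumps from 3.90 at step 190 to 4.15 at step 200 before dropping back to 3.92), so the loss curve is easier to read in smoothed form. The tutorial's cifar10.py writes each loss to TensorBoard both raw and as an exponential moving average built with tf.train.ExponentialMovingAverage(0.9), and TensorBoard's smoothing slider is also based on an exponential moving average. The minimal sketch below applies the same recurrence to the 25 logged values. ema is a hypothetical helper; seeding the average with the first value is a simplification (TensorFlow and TensorBoard initialize and debias differently), and it runs over values logged every 10 steps rather than every step, so the numbers will not match TensorBoard's curve exactly.

```python
def ema(values, decay=0.9):
    """Hypothetical helper: avg = decay * avg + (1 - decay) * value.

    Seeded with the first value, which is a simplification.
    """
    if not values:
        return []
    avg = values[0]
    smoothed = [avg]
    for value in values[1:]:
        avg = decay * avg + (1 - decay) * value
        smoothed.append(avg)
    return smoothed


# Loss values logged every 10 steps, from step 0 to step 240 in the console output.
losses = [4.67, 4.62, 4.36, 4.34, 4.37, 4.45, 4.29, 4.12, 4.04, 4.14,
          4.08, 4.05, 3.99, 3.93, 4.12, 3.94, 3.93, 3.80, 3.95, 3.90,
          4.15, 3.92, 3.81, 3.87, 3.67]
for i, (raw, smooth) in enumerate(zip(losses, ema(losses))):
    print('step %3d: raw %.2f, smoothed %.2f' % (i * 10, raw, smooth))
```

A higher decay gives a flatter curve that reacts more slowly to real changes in the loss; 0.9 matches the decay used in cifar10.py.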
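Summary
That is the full content of "TF CNN: Training and evaluating a CNN (2+2) model on the CIFAR-10 dataset (TensorBoard visualization)", collected and compiled by 生活随笔. We hope this article helps you solve the problems you have run into.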