TensorFlow——基于Keras子类API的fashion-mnist数据集图像分类
生活随笔收集整理的這篇文章主要介紹了 TensorFlow——基于Keras子类API的fashion-mnist数据集图像分类。
小編覺得挺不錯的，現在分享給大家，幫大家做個參考。
https://tensorflow.google.cn/tutorials/keras/classification
解決方案
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Classify Fashion-MNIST images with a Keras model built via the subclassing API.

@version: 0.0.1
author: ShenTuZhiGang
@time: 2021/01/25 16:33
@file: 12.py
@function: train/evaluate a Fashion-MNIST classifier, log metrics for
           TensorBoard and plot predictions.
@modify:
"""
import datetime

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow import summary

# Project-local loader (not shown here); the unpacking below expects it to
# return ((train_images, train_labels), (test_images, test_labels)).
import mnist_reader

EPOCHS = 10  # training epochs; also the TensorBoard step of the final test metrics
LOG_ROOT = '/content/drive/My Drive/colab notebooks/output/tsboardx/'

# One TensorBoard run directory per split, all tagged with the same start time.
current_time = str(datetime.datetime.now().timestamp())
train_log_dir = LOG_ROOT + 'train/' + current_time
test_log_dir = LOG_ROOT + 'test/' + current_time
val_log_dir = LOG_ROOT + 'val/' + current_time
train_summary_writer = summary.create_file_writer(train_log_dir)
# NOTE(review): nothing is written to the val writer, because fit() below is
# called without validation data. It is kept so the directory layout is unchanged.
val_summary_writer = summary.create_file_writer(val_log_dir)
test_summary_writer = summary.create_file_writer(test_log_dir)

(train_images, train_labels), (test_images, test_labels) = mnist_reader.load_data('../data/fashion')
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

# Scale pixel values by 1/255 (presumably uint8 in [0, 255]) into [0, 1].
train_images = train_images / 255.0
test_images = test_images / 255.0

# Preview the first 25 training images with their class names.
plt.figure(figsize=(10, 10))
for i in range(25):
    plt.subplot(5, 5, i + 1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(train_images[i], cmap=plt.cm.binary)
    plt.xlabel(class_names[train_labels[i]])
plt.show()


class FashionMnistModel(keras.Model):
    """Flatten -> Dense(128, ReLU) -> Dense(10) image classifier.

    ``call`` returns raw logits (the last Dense layer has no activation),
    which is why the loss below is compiled with ``from_logits=True``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # No `input_shape` here: a subclassed model gets its input shape from
        # build() / the first call, so the layer kwarg would be ignored.
        self.input_ = keras.layers.Flatten()
        self.hidden1 = keras.layers.Dense(128, activation="relu")
        self.main_output = keras.layers.Dense(10)

    def call(self, inputs, **kwargs):
        """Map a batch of images to 10 unnormalised class scores."""
        x = self.input_(inputs)
        x = self.hidden1(x)
        return self.main_output(x)


model = FashionMnistModel()
# None leaves the batch dimension unspecified (0 would build for an empty batch
# and make summary() report output shape (0, 10)).
model.build(input_shape=(None, 28, 28))
model.summary()
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
history = model.fit(train_images, train_labels, epochs=EPOCHS)

# Log per-epoch training metrics. Steps 1..EPOCHS line up with the test
# metrics, which are logged at step EPOCHS.
with train_summary_writer.as_default():
    for epoch, (loss, acc) in enumerate(
            zip(history.history['loss'], history.history['accuracy']), start=1):
        summary.scalar('loss', loss, step=epoch)
        summary.scalar('accuracy', acc, step=epoch)
train_summary_writer.flush()

test_loss, test_acc = model.evaluate(test_images, test_labels, verbose=2)
with test_summary_writer.as_default():
    summary.scalar('loss', test_loss, step=EPOCHS)
    summary.scalar('accuracy', test_acc, step=EPOCHS)
test_summary_writer.flush()
print('\nTest accuracy:', test_acc)

# Append a softmax so predictions are class probabilities instead of logits.
probability_model = tf.keras.Sequential([model, tf.keras.layers.Softmax()])
predictions = probability_model.predict(test_images)
print(predictions[0])
print(np.argmax(predictions[0]))
print(test_labels[0])


def plot_image(i, predictions_array, true_label, img):
    """Show ``img[i]`` labelled "<predicted> <confidence>% (<true>)".

    The label is blue when the arg-max of ``predictions_array`` equals
    ``true_label[i]`` and red otherwise.

    Args:
        i: index into ``true_label`` and ``img``.
        predictions_array: class probabilities for this one image.
        true_label: integer labels for the whole set.
        img: images for the whole set.
    """
    true_label, img = true_label[i], img[i]
    plt.grid(False)
    plt.xticks([])
    plt.yticks([])
    plt.imshow(img, cmap=plt.cm.binary)

    predicted_label = np.argmax(predictions_array)
    color = 'blue' if predicted_label == true_label else 'red'
    plt.xlabel("{} {:2.0f}% ({})".format(class_names[predicted_label],
                                         100 * np.max(predictions_array),
                                         class_names[true_label]),
               color=color)


def plot_value_array(i, predictions_array, true_label):
    """Bar-plot one image's class probabilities.

    The predicted class bar is red and the true class (``true_label[i]``) bar
    is blue. Blue is applied last, so a correct prediction shows blue only.
    """
    true_label = true_label[i]
    num_classes = len(predictions_array)
    plt.grid(False)
    plt.xticks(range(num_classes))
    plt.yticks([])
    thisplot = plt.bar(range(num_classes), predictions_array, color="#777777")
    plt.ylim([0, 1])
    predicted_label = np.argmax(predictions_array)
    thisplot[predicted_label].set_color('red')
    thisplot[true_label].set_color('blue')


# Inspect two individual test predictions side by side with their scores.
for i in (0, 12):
    plt.figure(figsize=(6, 3))
    plt.subplot(1, 2, 1)
    plot_image(i, predictions[i], test_labels, test_images)
    plt.subplot(1, 2, 2)
    plot_value_array(i, predictions[i], test_labels)
    plt.show()

# Plot the first X test images, their predicted labels, and the true labels.
# Color correct predictions in blue and incorrect predictions in red.
num_rows = 5
num_cols = 3
num_images = num_rows * num_cols
# Each test image takes two adjacent cells: the image, then its score bars.
plt.figure(figsize=(2 * 2 * num_cols, 2 * num_rows))
for i in range(num_images):
    plt.subplot(num_rows, 2 * num_cols, 2 * i + 1)
    plot_image(i, predictions[i], test_labels, test_images)
    plt.subplot(num_rows, 2 * num_cols, 2 * i + 2)
    plot_value_array(i, predictions[i], test_labels)
plt.tight_layout()
plt.show()

# Grab an image from the test dataset.
img = test_images[1]
print(img.shape)

# Add the image to a batch where it's the only member
# (predict() works on batches).
img = np.expand_dims(img, 0)
print(img.shape)

predictions_single = probability_model.predict(img)
print(predictions_single)

plot_value_array(1, predictions_single[0], test_labels)
_ = plt.xticks(range(10), class_names, rotation=45)
print(np.argmax(predictions_single[0]))
# Notebooks display figures automatically; a plain script must call show(),
# otherwise this last chart is never displayed.
plt.show()

# References (參考文章):
TensorFlow——本地加載fashion-mnist數據集
TensorFlow 教程——基本分類：對服裝圖像進行分類
總結
以上是生活随笔為你收集整理的TensorFlow——基于Keras子类API的fashion-mnist数据集图像分类的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: Python OpenCV——函数 cv
- 下一篇: Python——基于OpenCV获取倾斜