深度学习之基于卷积神经网络实现超大Mnist数据集识别
生活随笔
收集整理的這篇文章主要介紹了
深度学习之基于卷积神经网络实现超大Mnist数据集识别
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
在以往的手寫數字識別中,數據集一共是70000張圖片,模型準確率可以達到99%以上的準確率。而本次實驗的手寫數字數據集中有120000張圖片,而且數據集的預處理方式也是之前沒有遇到過的。最終在驗證集上的模型準確率達到了99.1%。在模型訓練過程中,加入了上一篇文章中提到的早停策略以及模型保存策略。
1.導入庫
import numpy as np import tensorflow as tf import matplotlib.pyplot as plt import os,PIL,pathlib,warnings,pickle,pngwarnings.filterwarnings("ignore")#忽略警告信息# 支持中文 plt.rcParams['font.sans-serif'] = ['SimHei'] # 用來正常顯示中文標簽 plt.rcParams['axes.unicode_minus'] = False # 用來正常顯示負號 os.environ['TF_CPP_MIN_LOG_LEVEL']='2'2.數據處理
原始數據如下所示:
這是經過序列化的圖片數據,因此需要我們自己反序列化,讀入內存中
讀入內存中的數據,需要轉化為圖片格式,按照它所屬的標簽,存放到不同的文件夾中。
num = data.shape[0] #如果不存在文件夾,就新建文件夾 if not os.path.exists('E:/tmp/.keras/datasets/QMnist/dataset'):os.mkdir('E:/tmp/.keras/datasets/QMnist/dataset') for i in range(0,num):x = data[i]y = str(labels[i])name = str(i)#二級文件夾,存放0-9不同種類的圖片if not os.path.exists('E:/tmp/.keras/datasets/QMnist/dataset/{}'.format(y)):os.mkdir('E:/tmp/.keras/datasets/QMnist/dataset/{}'.format(y)) #存放圖片 png.from_array(x,mode="L").save("E:/tmp/.keras/datasets/QMnist/dataset/{}/{}.png".format(y,name))最終處理出來的圖片數據如下所示:
其中[4]中的部分圖片如下所示:
3.劃分訓練集、測試集、驗證集
這一部分屬于老生常談的問題了~
data_dir = "E:/tmp/.keras/datasets/QMnist/dataset" data_dir = pathlib.Path(data_dir)image_count = len(list(data_dir.glob('*/*.png'))) # print(image_count)#120000all_images_paths = list(data_dir.glob('*')) all_images_paths = [str(path) for path in all_images_paths] all_label_names = [path.split("\\")[5].split(".")[0] for path in all_images_paths] # print(all_label_names) height = 75 width = 75 batch_size = 8 epochs = 50train_data_gen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255,validation_split=0.2 ) train_ds = train_data_gen.flow_from_directory(directory=data_dir,target_size=(height,width),batch_size=batch_size,shuffle=True,class_mode='categorical',subset='training',seed=42 )validation_data_gen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255,validation_split=0.2 ) val_ds = validation_data_gen.flow_from_directory(directory=data_dir,target_size=(height,width),batch_size=batch_size,shuffle=True,class_mode='categorical',subset='validation' )test_data_gen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255,validation_split=0.1 ) test_ds = test_data_gen.flow_from_directory(directory=data_dir,target_size=(height,width),batch_size=batch_size,shuffle=True,class_mode='categorical',subset='validation' )經過處理之后,查看圖片:
plt.figure(figsize=(15, 10)) # 圖形的寬為15高為10for images, labels in train_ds:for i in range(40):ax = plt.subplot(5, 8, i + 1)plt.imshow(images[i])plt.title(all_label_names[np.argmax(labels[i])])plt.axis("off")break plt.show()4.網絡搭建
一開始采用的是VGG16模型,但是跑的實在是太慢了,而且不知道哪方面出了問題,準確率很低。
model = tf.keras.Sequential([tf.keras.layers.Conv2D(filters=32,kernel_size=(3,3),padding="same",activation="relu",input_shape=[64, 64, 3]),tf.keras.layers.MaxPooling2D((2,2)),tf.keras.layers.Conv2D(filters=64,kernel_size=(3,3),padding="same",activation="relu"),tf.keras.layers.MaxPooling2D((2,2)),tf.keras.layers.Conv2D(filters=64,kernel_size=(3,3),padding="same",activation="relu"),tf.keras.layers.MaxPooling2D((2,2)),tf.keras.layers.Flatten(),tf.keras.layers.Dense(64, activation="relu"),tf.keras.layers.Dense(10, activation="softmax") ])早停策略以及模型保存
Earlystop = tf.keras.callbacks.EarlyStopping(monitor='loss',mode='min',restore_best_weights=True ) Checkpoint = tf.keras.callbacks.ModelCheckpoint(filepath='E:/Users/yqx/PycharmProjects/Qmnist/model.h5',save_best_only=True,monitor='val_accuracy',mode='max' )網絡編譯&&訓練
model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy']) history = model.fit(train_ds,validation_data=val_ds,epochs=epochs,callbacks=[Earlystop,Checkpoint] )Accuracy以及Loss圖如下所示:
epochs設置的為50,但是在第7個epoch訓練結束后,就停止了,實現了早停策略。
5.模型測試&&混淆矩陣
模型加載:
model = tf.keras.models.load_model('cloud/model.h5')對測試集進行模型測試:
model.evaluate(test_ds)最終結果如下所示:
1500/1500 [==============================] - 9s 6ms/step - loss: 0.0469 - accuracy: 0.9912 [0.046884261071681976, 0.9911637306213379]繪制混淆矩陣:
from sklearn.metrics import confusion_matrix from sklearn.metrics import classification_report import seaborn as snspred = model.predict(test_ds).argmax(axis=1) labels = list(train_ds.class_indices.keys())cm = confusion_matrix(test_data.classes, pred) plt.figure(figsize=(15,10)) sns.heatmap(cm, annot=True, fmt='g', xticklabels=labels, yticklabels=labels, cmap="BuPu") plt.title('Confusion Matrix') plt.show()
努力加油a啊
總結
以上是生活随笔為你收集整理的深度学习之基于卷积神经网络实现超大Mnist数据集识别的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 不一样的春节思维导图的内容_电学思维导图
- 下一篇: R语言软件安装教程「建议收藏」(Proj