CNN鲜花分类
CNN鮮花分類
- 1、數據集介紹
- 2、代碼實戰
- 2.1 導入依賴
- 2.2 下載數據
- 2.3 統計數據集
- 2.4 創建dataset
- 2.5 可視化一個batch_size
- 2.6 將數據集緩存到內存中,加速讀取
- 2.7 搭建模型
- 2.8 編譯模型
- 2.9 模型訓練
- 2.10 可視化訓練結果
- 3、模型優化
- 3.1 數據增強設置
- 3.2 顯示數據增強后的效果
- 3.3 搭建新的模型
- 3.4 編譯模型
- 3.5 模型訓練
- 3.6 可視化訓練結果
- 3.7 模型預測
1、數據集介紹
總共5種花,按照文件夾區分花朵的類別。
下載下來的是個壓縮包,需要將其解壓。
數據集下載地址:https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
2、代碼實戰
2.1 導入依賴
import PIL import numpy as np import matplotlib.pyplot as plt import pathlib import libimport tensorflow as tf from tensorflow.keras import layers,models2.2 下載數據
# 下載數據集到本地 data_url='https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz' data_dir=tf.keras.utils.get_file('flower_photos',origin=data_url,untar=True)#untar=True 下載后解壓 data_dir=pathlib.Path(data_dir)2.3 統計數據集
# 統計數據集大小 dataset_size=len(list(data_dir.glob('*/*.jpg'))) dataset_size總共3670張照片,比上次小狗分類那個少多了。
# 顯示部分圖片 imgs=list(data_dir.glob('*/*.jpg')) imgs查看下第1張圖片
img1=imgs[0] #第一張圖片 img1 str(img1) PIL.Image.open(str(img1)) #讀取并顯示再查看下第2張圖片
img2=imgs[1] #第2張圖片 PIL.Image.open(str(img2))2.4 創建dataset
訓練集:
# 3 創建dataset BATCH_SIZE=32 HEIGHT=180 WIDTH=180#80%是訓練集,20%是驗證集 train_ds=tf.keras.preprocessing.image_dataset_from_directory(directory=data_dir,batch_size=BATCH_SIZE,validation_split=0.2,subset='training',seed=666,image_size=(HEIGHT,WIDTH)) train_ds class_names=train_ds.class_names #數據集類別 class_names驗證集:
val_ds=tf.keras.preprocessing.image_dataset_from_directory(directory=data_dir,batch_size=BATCH_SIZE,validation_split=0.2,subset='validation',seed=666,image_size=(HEIGHT,WIDTH)) val_ds2.5 可視化一個batch_size
# 可視化一個batch_size的數據 for images,labels in train_ds.take(1):for i in range(9): # 一個batch_size有32張,這里只顯示9張plt.subplot(3,3,i+1)plt.imshow(images[i].numpy().astype('uint8'))plt.title(class_names[labels[i]])plt.axis('off')2.6 將數據集緩存到內存中,加速讀取
#將數據集緩存到內存中,加速讀取 AUTOTUNE=tf.data.AUTOTUNE train_ds=train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE) val_ds=val_ds.cache().prefetch(buffer_size=AUTOTUNE)2.7 搭建模型
這里僅作測試,并沒有使用預訓練模型
#搭建模型 model=models.Sequential([layers.experimental.preprocessing.Rescaling(1./255,input_shape=(HEIGHT,WIDTH,3)),# 數據歸一化layers.Conv2D(16,3,padding='same',activation='relu'),layers.MaxPool2D(),layers.Conv2D(32,3,padding='same',activation='relu'),layers.MaxPool2D(),layers.Conv2D(64,3,padding='same',activation='relu'),layers.MaxPool2D(),layers.Flatten(),layers.Dense(128,activation='relu'),layers.Dense(5) ]) model.summary()2.8 編譯模型
#編譯模型 model.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=['accuracy'])這里使用的SparseCategoricalCrossentropy會自動幫我們
2.9 模型訓練
#模型訓練 EPOCHS=10 history=model.fit(train_ds,validation_data=val_ds,epochs=EPOCHS)這里由于設備太拉跨,略微出手已是顯卡極限,所以就只設置了10個epoch
2.10 可視化訓練結果
# 可視化訓練結果 ranges=range(EPOCHS) train_acc=history.history['accuracy'] val_acc=history.history['val_accuracy']train_loss=history.history['loss'] val_loss=history.history['val_loss']plt.figure(figsize=(16,8)) plt.subplot(1,2,1) plt.plot(ranges,train_acc,label='train_acc') plt.plot(ranges,val_acc,label='val_acc') plt.title('Accuracy') plt.xlabel('Epochs') plt.ylabel('Accuracy')plt.subplot(1,2,2) plt.plot(ranges,train_loss,label='train_loss') plt.plot(ranges,val_loss,label='val_loss') plt.title('Loss') plt.xlabel('Epochs') plt.ylabel('Loss')plt.show()過擬合非常嚴重,下面對模型進行優化
3、模型優化
3.1 數據增強設置
# 數據增強參數設置 data_argumentation=tf.keras.Sequential([# 隨機水平翻轉layers.experimental.preprocessing.RandomFlip('horizontal',input_shape=(HEIGHT,WIDTH,3)),# 隨機旋轉layers.experimental.preprocessing.RandomRotation(0.1), # 旋轉# 隨機縮放layers.experimental.preprocessing.RandomZoom(0.1), # ])這塊的API太多了,多去查查官網。
3.2 顯示數據增強后的效果
# 顯示數據增強后的效果 for images,labels in train_ds.take(1):for i in range(9): # 一個batch_size有32張,這里只顯示9張plt.subplot(3,3,i+1)argumeng_images=data_argumentation(images) #數據增強plt.imshow(argumeng_images[i].numpy().astype('uint8')) # 顯示plt.title(class_names[labels[i]])plt.axis('off')3.3 搭建新的模型
#搭建新的模型 model_2=models.Sequential([data_argumentation, # 數據增強layers.experimental.preprocessing.Rescaling(1./255),# 數據歸一化layers.Conv2D(16,3,padding='same',activation='relu'),layers.MaxPool2D(),layers.Conv2D(32,3,padding='same',activation='relu'),layers.MaxPool2D(),layers.Conv2D(64,3,padding='same',activation='relu'),layers.MaxPool2D(),layers.Dropout(0.2),layers.Flatten(),layers.Dense(128,activation='relu'),layers.Dense(5) ]) model_2.summary()3.4 編譯模型
#編譯模型 model_2.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=['accuracy'])3.5 模型訓練
#模型訓練 history=model_2.fit(train_ds,validation_data=val_ds,epochs=EPOCHS)3.6 可視化訓練結果
# 可視化訓練結果 ranges=range(EPOCHS) train_acc=history.history['accuracy'] val_acc=history.history['val_accuracy']train_loss=history.history['loss'] val_loss=history.history['val_loss']plt.figure(figsize=(16,8)) plt.subplot(1,2,1) plt.plot(ranges,train_acc,label='train_acc') plt.plot(ranges,val_acc,label='val_acc') plt.title('Accuracy') plt.xlabel('Epochs') plt.ylabel('Accuracy')plt.subplot(1,2,2) plt.plot(ranges,train_loss,label='train_loss') plt.plot(ranges,val_loss,label='val_loss') plt.title('Loss') plt.xlabel('Epochs') plt.ylabel('Loss')plt.show()現在這個效果比優化之前的好多了。
3.7 模型預測
# 模型預測 test_img=tf.keras.preprocessing.image.load_img('sunfloor.jpg',target_size=(HEIGHT,WIDTH)) test_img這里我們自己在網上下載一張向日葵的圖片進行預測
test_img=tf.keras.preprocessing.image.img_to_array(test_img) # 類型變換 test_img.shape將數據擴充一維,因為第一個維度是batchsize
test_img=tf.expand_dims(test_img,0) #擴充一維 test_img.shape預測:
preds=model_2.predict(test_img) #預測 preds.shape得分:
preds #得分得分轉換成概率:
scores=tf.nn.softmax(preds[0])# 得分轉換成概率 scores print('模型預測可能性最大的類別是:{},概率值為:{}'.format(class_names[np.argmax(scores)],np.max(scores)))這里最后一個全連接層可以直接加上個softmax激活函數,這樣預測后就不用再轉化了。
總結
- 上一篇: linux下php反编译apk,php反
- 下一篇: 信息抽取--关键词提取