tensoflow_yolov3 计算平均识别个数(平均识别数)
生活随笔
收集整理的這篇文章主要介紹了
tensoflow_yolov3 计算平均识别个数(平均识别数)
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
# -*- coding: utf-8 -*-
"""
@File : 20200221_Target_Recognition_光照度對模型識別率影響(計算平均識別個數).py
@Time : 2020/2/21 11:07
@Author : Dontla
@Email : sxana@qq.com
@Software: PyCharm
"""import tracebackimport cv2
import numpy as np
import tensorflow as tf
import core.utils as utils
from core.config import cfg
from core.yolov3 import YOLOV3
import pyrealsense2 as rs
import time
import sysclass YoloTest(object):def __init__(self):# D·C 191111:__C.TEST.INPUT_SIZE = 544self.input_size = cfg.TEST.INPUT_SIZEself.anchor_per_scale = cfg.YOLO.ANCHOR_PER_SCALE# Dontla 191106注釋:初始化class.names文件的字典信息屬性self.classes = utils.read_class_names(cfg.YOLO.CLASSES)# D·C 191115:類數量屬性self.num_classes = len(self.classes)self.anchors = np.array(utils.get_anchors(cfg.YOLO.ANCHORS))# D·C 191111:__C.TEST.SCORE_THRESHOLD = 0.3self.score_threshold = cfg.TEST.SCORE_THRESHOLD# D·C 191120:__C.TEST.IOU_THRESHOLD = 0.45self.iou_threshold = cfg.TEST.IOU_THRESHOLDself.moving_ave_decay = cfg.YOLO.MOVING_AVE_DECAY# D·C 191120:__C.TEST.ANNOT_PATH = "./data/dataset/Dontla/20191023_Artificial_Flower/test.txt"self.annotation_path = cfg.TEST.ANNOT_PATH# D·C 191120:__C.TEST.WEIGHT_FILE = "./checkpoint/f_g_c_weights_files/yolov3_test_loss=15.8845.ckpt-47"self.weight_file = cfg.TEST.WEIGHT_FILE# D·C 191115:可寫標記(bool類型值)self.write_image = cfg.TEST.WRITE_IMAGE# D·C 191115:__C.TEST.WRITE_IMAGE_PATH = "./data/detection/"(識別圖片畫框并標注文本后寫入的圖片路徑)self.write_image_path = cfg.TEST.WRITE_IMAGE_PATH# D·C 191116:TEST.SHOW_LABEL設置為Trueself.show_label = cfg.TEST.SHOW_LABEL# D·C 191120:創建命名空間“input”with tf.name_scope('input'):# D·C 191120:建立變量(創建占位符開辟內存空間)self.input_data = tf.placeholder(dtype=tf.float32, name='input_data')self.trainable = tf.placeholder(dtype=tf.bool, name='trainable')model = YOLOV3(self.input_data, self.trainable)self.pred_sbbox, self.pred_mbbox, self.pred_lbbox = model.pred_sbbox, model.pred_mbbox, model.pred_lbbox# D·C 191120:創建命名空間“指數滑動平均”with tf.name_scope('ema'):ema_obj = tf.train.ExponentialMovingAverage(self.moving_ave_decay)# D·C 191120:在允許軟設備放置的會話中啟動圖形并記錄放置決策。(不懂啥意思。。。)allow_soft_placement=True表示允許tf自動選擇可用的GPU和CPUself.sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))# D·C 191120:variables_to_restore()用于加載模型計算滑動平均值時將影子變量直接映射到變量本身self.saver = tf.train.Saver(ema_obj.variables_to_restore())# D·C 191120:用于下次訓練時恢復模型self.saver.restore(self.sess, self.weight_file)# 攝像頭序列號# self.cam_serials = ['838212073161']# self.cam_serials = ['827312070790']self.cam_serials = ['836612072369']self.cam_num = len(self.cam_serials)def predict(self, image):# D·C 191107:復制一份圖片的鏡像,避免對圖片直接操作改變圖片的內在屬性org_image = np.copy(image)# D·C 191107:獲取圖片尺寸org_h, org_w, _ = org_image.shape# D·C 191108:該函數將源圖結合input_size,將其轉換成預投喂的方形圖像(作者默認544×544,中間為縮小尺寸的源圖,上下空區域為灰圖):image_data = utils.image_preprocess(image, [self.input_size, self.input_size])# D·C 191108:打印維度看看:# print(image_data.shape)# (544, 544, 3)# D·C 191108:創建新軸,不懂要創建新軸干嘛?image_data = image_data[np.newaxis, ...]# D·C 191108:打印維度看看:# print(image_data.shape)# (1, 544, 544, 3)# D·C 191110:三個box可能存放了預測框圖(可能是N多的框,有用的沒用的重疊的都在里面)的信息(但是打印出來的值完全看不懂啊喂?)pred_sbbox, pred_mbbox, pred_lbbox = self.sess.run([self.pred_sbbox, self.pred_mbbox, self.pred_lbbox],feed_dict={self.input_data: image_data,self.trainable: False})# D·C 191110:打印三個box的類型、形狀和值看看:# print(type(pred_sbbox))# print(type(pred_mbbox))# print(type(pred_lbbox))# 都是<class 'numpy.ndarray'># print(pred_sbbox.shape)# print(pred_mbbox.shape)# print(pred_lbbox.shape)# (1, 68, 68, 3, 6)# (1, 34, 34, 3, 6)# (1, 17, 17, 3, 6)# print(pred_sbbox)# print(pred_mbbox)# print(pred_lbbox)# D·C 191110:(-1,6)表示不知道有多少行,反正你給我整成6列,然后concatenate又把它們仨給疊起來,最終得到無數個6列數組(后面self.num_classes)個數存放的貌似是這個框屬于類的概率)pred_bbox = np.concatenate([np.reshape(pred_sbbox, (-1, 5 + self.num_classes)),np.reshape(pred_mbbox, (-1, 5 + self.num_classes)),np.reshape(pred_lbbox, (-1, 5 + self.num_classes))], axis=0)# D·C 191111:打印pred_bbox和它的維度看看:# print(pred_bbox)# print(pred_bbox.shape)# (18207, 6)# D·C 191111:猜測是第一道過濾,過濾掉score_threshold以下的圖片,過濾完之后少了好多:# D·C 191115:bboxes維度為[n,6],前四列是坐標,第五列是得分,第六列是對應類下標bboxes = utils.postprocess_boxes(pred_bbox, (org_h, org_w), self.input_size, self.score_threshold)# D·C 191111:猜測是第二道過濾,過濾掉iou_threshold以下的圖片:bboxes = utils.nms(bboxes, self.iou_threshold)return bboxesdef cam_conti_veri(self, cam_num, ctx):"""攝像頭連續驗證、連續驗證機制"""# D·C 1911202:創建最大驗證次數max_veri_times;創建連續穩定值continuous_stable_value,用于判斷設備重置后是否處于穩定狀態max_veri_times = 100continuous_stable_value = 5print('\n', end='')print('開始連續驗證,連續驗證穩定值:{},最大驗證次數:{}:'.format(continuous_stable_value, max_veri_times))continuous_value = 0veri_times = 0while True:devices = ctx.query_devices()connected_cam_num = len(devices)print('攝像頭個數:{}'.format(connected_cam_num))if connected_cam_num == cam_num:continuous_value += 1if continuous_value == continuous_stable_value:breakelse:continuous_value = 0veri_times += 1if veri_times == max_veri_times:print("檢測超時,請檢查攝像頭連接!")sys.exit()def cam_hardware_reset(self, ctx, cam_serials):"""循環reset攝像頭"""# hardware_reset()后是不是應該延遲一段時間?不延遲就會報錯print('\n', end='')print('開始初始化攝像頭:')for dev in ctx.query_devices():# 先將設備的序列號放進一個變量里,免得在下面for循環里訪問設備的信息過多(雖然不知道它會不會每次都重新訪問)dev_serial = dev.get_info(rs.camera_info.serial_number)# 匹配序列號,重置我們需重置的特定攝像頭(注意兩個for循環順序,哪個在外哪個在內很重要,不然會導致剛重置的攝像頭又被訪問導致報錯)for serial in cam_serials:if serial == dev_serial:dev.hardware_reset()# 像下面這條語句居然不會報錯,不是剛剛才重置了dev嗎?莫非區別在于沒有通過for循環ctx.query_devices()去訪問?# 是不是剛重置后可以通過ctx.query_devices()去查看有這個設備,但是卻沒有存儲設備地址?如果是這樣,# 也就能夠解釋為啥能夠通過len(ctx.query_devices())函數獲取設備數量,但訪問序列號等信息就會報錯的原因了print('攝像頭{}初始化成功'.format(dev.get_info(rs.camera_info.serial_number)))# 如果只有一個攝像頭,要讓它睡夠5秒(避免出錯,保險起見)time.sleep(5 / len(cam_serials))def get_cam_serials(self):passdef calculate_detection_num(self, calcu_list, detect_num):"""計算一段次數內平均識別個數"""# 將列表calcu_list作為隊列,右為頭,左為尾,頭為先進的幀,尾為后進的幀# 定義需做平均的隊列幀數量frame_num = 50# 判斷傳進來的隊列大小,如果小于frame_num就把元素添加到左邊,如果大于或等于50,就把右邊超過50的咔掉,并拋出最右邊那個,將元素加到最左邊if len(calcu_list) < frame_num:calcu_list.insert(0, detect_num)else:calcu_list = calcu_list[:frame_num]calcu_list.pop()calcu_list.insert(0, detect_num)# if len(calcu_list) > frame_num:# calcu_list = calcu_list[:frame_num]# elif len(calcu_list)==frame_num:# 求列表均值average_num = np.mean(calcu_list)return calcu_list, average_numdef dontla_evaluate_detect(self):# 攝像頭個數(在這里設置所需使用攝像頭的總個數)ctx = rs.context()# 連續驗證機制self.cam_conti_veri(self.cam_num, ctx)# 循環reset攝像頭self.cam_hardware_reset(ctx, self.cam_serials)# 連續驗證機制self.cam_conti_veri(self.cam_num, ctx)# 打印攝像頭序列號和接口號并創建需要顯示在窗口上的備注信息字符串列表(窗口名)print('\n', end='')cam_id = 0serial_list = []for i in ctx.query_devices():cam_id += 1serial_list.append('camera{}; serials number {}; usb port {}'.format(cam_id, i.get_info(rs.camera_info.serial_number),i.get_info(rs.camera_info.usb_type_descriptor)))print('serial number {}:{};usb port:{}'.format(cam_id, i.get_info(rs.camera_info.serial_number),i.get_info(rs.camera_info.usb_type_descriptor)))# print(serial_list)# 配置各個攝像頭的基本對象for i in range(self.cam_num):# D·C 191203:括號里是否有必要加ctx,加了沒加好像沒多大區別,但不加它又會提示黃色locals()['pipeline' + str(i)] = rs.pipeline(ctx)locals()['config' + str(i)] = rs.config()# Dontla 20200221 存疑,為何不以前面指定的攝像頭序列號啟動,而要重新獲取序列號?locals()['serial' + str(i)] = ctx.devices[i].get_info(rs.camera_info.serial_number)locals()['config' + str(i)].enable_device(locals()['serial' + str(i)])locals()['config' + str(i)].enable_stream(rs.stream.depth, 640, 360, rs.format.z16, 30)locals()['config' + str(i)].enable_stream(rs.stream.color, 640, 360, rs.format.bgr8, 30)locals()['pipeline' + str(i)].start(locals()['config' + str(i)])# 創建對齊對象(深度對齊顏色)locals()['align' + str(i)] = rs.align(rs.stream.color)# 運行流并進行識別print('\n', end='')print('開始識別:')try:# 設置break標志,方便按下按鈕跳出循環退出窗口break2 = False# 初始化計數列表calcu_list = []while True:for i in range(self.cam_num):locals()['frames' + str(i)] = locals()['pipeline' + str(i)].wait_for_frames()# 獲取對齊幀集locals()['aligned_frames' + str(i)] = locals()['align' + str(i)].process(locals()['frames' + str(i)])# 獲取對齊后的深度幀和彩色幀locals()['aligned_depth_frame' + str(i)] = locals()['aligned_frames' + str(i)].get_depth_frame()locals()['color_frame' + str(i)] = locals()['aligned_frames' + str(i)].get_color_frame()if not locals()['aligned_depth_frame' + str(i)] or not locals()['color_frame' + str(i)]:continue# 獲取顏色幀內參locals()['color_profile' + str(i)] = locals()['color_frame' + str(i)].get_profile()locals()['cvsprofile' + str(i)] = rs.video_stream_profile(locals()['color_profile' + str(i)])locals()['color_intrin' + str(i)] = locals()['cvsprofile' + str(i)].get_intrinsics()locals()['color_intrin_part' + str(i)] = [locals()['color_intrin' + str(i)].ppx,locals()['color_intrin' + str(i)].ppy,locals()['color_intrin' + str(i)].fx,locals()['color_intrin' + str(i)].fy]locals()['color_image' + str(i)] = np.asanyarray(locals()['color_frame' + str(i)].get_data())locals()['bboxes_pr' + str(i)] = self.predict(locals()['color_image' + str(i)])# Dontla 20200221 打印識別個數# print(np.array(locals()['bboxes_pr' + str(i)]).shape)detect_num = len(locals()['bboxes_pr' + str(i)])# print('識別個數:{}'.format(detect_num))# Dontla 20200221 計算平均識別個數(這里只針對一個攝像頭情況,多個攝像頭到時再重構)calcu_list, mean_detect_num = self.calculate_detection_num(calcu_list, detect_num)# Dontla 20200221 打印平均識別個數print(calcu_list)print('平均識別個數:{}'.format(mean_detect_num))locals()['image' + str(i)] = utils.draw_bbox(locals()['color_image' + str(i)],locals()['bboxes_pr' + str(i)],locals()['aligned_depth_frame' + str(i)],locals()['color_intrin_part' + str(i)],show_label=self.show_label)# D·C 191202:本想創建固定比例的大小可調的窗口,發現無法使用,opencv bug?# cv2.namedWindow('{}'.format(serial_list[i]),# flags=cv2.WINDOW_NORMAL | cv2.WINDOW_FREERATIO | cv2.WINDOW_GUI_EXPANDED)cv2.imshow('{}'.format(serial_list[i]), locals()['image' + str(i)])key = cv2.waitKey(1)# 如果按下ESC,則跳出循環if key == 27:# 貌似直接用return也行# returnbreak2 = Truebreakif break2:breakexcept Exception as e:print("掉幀了!")traceback.print_exc()traceback.print_exc(file=open('traceback.txt', 'w+'))finally:# 大概覺得先關閉窗口再停止流比較靠譜# 銷毀所有窗口cv2.destroyAllWindows()print('\n', end='')print('已關閉所有窗口!')# 停止所有流for i in range(self.cam_num):locals()['pipeline' + str(i)].stop()print('正在停止所有流,請等待數秒至程序穩定結束!')if __name__ == '__main__':YoloTest().dontla_evaluate_detect()print('程序已結束!')
總結
以上是生活随笔為你收集整理的tensoflow_yolov3 计算平均识别个数(平均识别数)的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: numpy报错:ModuleNotFou
- 下一篇: python wheel库(安装包查找)