取input 输入_tensorRT动态输入(python)
關于tensorRT動態輸入的例子大多數都是c++版本的,python版本的較少,這里簡單總結下python處理tensorRT動態輸入時,遇到的一些問題及解決方案。
這里的動態輸入是指batch,width,height等不固定大小的輸入。對于目標檢測等問題,如果將圖像壓縮到指定長度和寬度,一般會損失一些性能,一般情況下是壓縮較小的邊到一定數值(例如608,416等),另一條邊按照比例進行壓縮,這就導致了輸入可能是變長的情況。而早期的tensorrt是不支持變長輸入,這一點比較令人費解,因為tensorflow,pytorch等框架是支持變長輸入的,為什么tensorrt支持變長輸入要這么麻煩呢?我想到的一個原因是tensorrt內部做了一些優化策略,而這些優化策略對于變長問題是較難處理的。
一、由onnx模型構建trt引擎
tensorrt6 以后的版本是支持動態輸入的,需要給每個動態輸入綁定一個profile,用于指定最大值,最小值和常規值,如果超出這個范圍會報異常。
profile = builder.create_optimization_profile() profile.set_shape(network.get_input(0).name, (1,3, 32, 32), (1,3, 608, 608), (1,3, 2050, 2050)) config.add_optimization_profile(profile)另外建立engine時是通過config設置參數的,
config.max_workspace_size=1<<30 #1GB而不是通過builder,這一點很重要,如果出現顯存溢出的問題需要重新設置
config.max_workspace_size構建引擎
engine = builder.build_engine(network,config)二、推理
1、申請空間
因為是動態輸入,所以每次申請的空間大小不一樣,為了不用每次推理時都要重新申請空間,可以申請一次所需的最大空間,后面取數據的時候對齊就可以了。
inputs = [] outputs = [] bindings = [] stream = cuda.Stream() tmp=[1,32,16,8] print('engine.get_binding_format_desc',engine.get_binding_format_desc(0)) for count,binding in enumerate(engine):print('binding:',binding)size = trt.volume(engine.get_binding_shape(binding)) * engine.max_batch_size*(int)(h_/tmp[count])*(int)(w_/tmp[count]) dtype = trt.nptype(engine.get_binding_dtype(binding)) # Allocate host and device buffershost_mem = cuda.pagelocked_empty(size, dtype)device_mem = cuda.mem_alloc(host_mem.nbytes)# Append the device buffer to device bindings.bindings.append(int(device_mem))# Append to the appropriate list.if engine.binding_is_input(binding):inputs.append(HostDeviceMem(host_mem, device_mem))else:outputs.append(HostDeviceMem(host_mem, device_mem))2、推理
注意推理拿到的數據長度是前面申請空間中數據,輸入確定了,輸出長度也就確定了,取到合適的長度即可。
trt_outputs = common.do_inference_v3(self.context,bindings=self.bindings,inputs=self.inputs,outputs=self.outputs,stream=self.stream,h_=h_,w_=w_)output_shapes = [(1, int(h_/32), int(w_/32), 21),(1, int(h_/16), int(w_/16), 21),(1, int(h_/8), int(w_/8), 21)]print(output_shapes) trt_outputs = [output[:shape[1]*shape[2]*shape[3]].reshape(shape) for output, shape in zip(trt_outputs, output_shapes)]完整用例請參考
https://github.com/zhaogangthu/keras-yolo3-ocr-tensorrt/tree/master/tensorRT_yolo3?github.com總結
以上是生活随笔為你收集整理的取input 输入_tensorRT动态输入(python)的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: git 拉新项目_git上拉取项目
- 下一篇: python 怎么爬桌软件数据_如何利用