深度学习模型保存_Web服务部署深度学习模型
本文的目的是介紹如何使用Web服務快速部署深度學習模型,雖然TF有TFserving可以進行模型部署,但是對于Pytorch無能為力(如果要使用的話需要把torch模型進行轉換,有些麻煩);因此,本文在這里介紹一種使用Web服務部署深度學習的方法(簡單有效,不喜勿噴)。
本文以簡單的新聞分類模型來舉例,模型:BERT;數據來源:清華新聞語料(地址:
THUCTC: 一個高效的中文文本分類工具),清華新聞語料共有14個類別,分別是體育,娛樂,家居,彩票,房產,教育,時尚,時政,星座,游戲,社會,科技,股票和財經。為了快速訓練模型,本人在每個類別中分別隨機挑選1000個作為訓練集,200個作為驗證集。數據預處理、模型訓練和pb模型保存代碼見:新聞分類模型訓練github地址。(非重點,不過多介紹了,github上有詳細的使用說明,有問題可留言。)
為了使web服務部署變得簡潔,因此本人構造一個方法類,方便加載pb模型,對傳入文本進行數據預處理以及進行模型預測。
模型初始化代碼如下:
import bert_tokenization import tensorflow as tf from tensorflow.python.platform import gfile import numpy as np import osclass ClassificationModel(object):def __init__(self):self.tokenizer = Noneself.sess = Noneself.is_train = Noneself.input_ids = Noneself.input_mask = Noneself.segment_ids = Noneself.predictions = Noneself.max_seq_length = Noneself.label_dict = ['體育', '娛樂', '家居', '彩票', '房產', '教育', '時尚', '時政', '星座', '游戲', '社會', '科技', '股票', '財經']其中,tokenizer 為分詞器;sess為TF的session模塊;is_train、input_ids、input_mask和segment_ids分別是pb模型的輸入;predictions為pb模型的輸出;max_seq_length為模型的最大輸入長度;label_dict為新聞分類標簽。
加載pb模型代碼如下:
def load_model(self, gpu_id, vocab_file, gpu_memory_fraction, model_path, max_seq_length):os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'os.environ['CUDA_VISIBLE_DEVICES'] = gpu_idself.tokenizer = bert_tokenization.FullTokenizer(vocab_file=vocab_file, do_lower_case=True)gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=gpu_memory_fraction)sess_config = tf.ConfigProto(gpu_options=gpu_options)self.sess = tf.Session(config=sess_config)with gfile.FastGFile(model_path, "rb") as f:graph_def = tf.GraphDef()graph_def.ParseFromString(f.read())self.sess.graph.as_default()tf.import_graph_def(graph_def, name="")self.sess.run(tf.global_variables_initializer())self.is_train = self.sess.graph.get_tensor_by_name("input/is_train:0")self.input_ids = self.sess.graph.get_tensor_by_name("input/input_ids:0")self.input_mask = self.sess.graph.get_tensor_by_name("input/input_mask:0")self.segment_ids = self.sess.graph.get_tensor_by_name("input/segment_ids:0")self.predictions = self.sess.graph.get_tensor_by_name("output_layer/predictions:0")self.max_seq_length = max_seq_length其中,gpu_id為使用GPU的序號;vocab_file為BERT模型所使用的字典路徑;gpu_memory_fraction為使用GPU時所占用的比例;model_path為pb模型的路徑;max_seq_length為BERT模型的最大長度。
將傳入文本轉化成模型所需格式代碼如下:
def convert_fearture(self, text):max_seq_length = self.max_seq_lengthmax_length_context = max_seq_length - 2content_token = self.tokenizer.tokenize(text)if len(content_token) > max_length_context:content_token = content_token[:max_length_context]tokens = []segment_ids = []tokens.append("[CLS]")segment_ids.append(0)for token in content_token:tokens.append(token)segment_ids.append(0)tokens.append("[SEP]")segment_ids.append(0)input_ids = self.tokenizer.convert_tokens_to_ids(tokens)input_mask = [1] * len(input_ids)while len(input_ids) < max_seq_length:input_ids.append(0)input_mask.append(0)segment_ids.append(0)assert len(input_ids) == max_seq_lengthassert len(input_mask) == max_seq_lengthassert len(segment_ids) == max_seq_lengthinput_ids = np.array(input_ids)input_mask = np.array(input_mask)segment_ids = np.array(segment_ids)return input_ids, input_mask, segment_ids預測代碼如下:
def predict(self, text):input_ids_temp, input_mask_temp, segment_ids_temp = self.convert_fearture(text)feed = {self.is_train: False,self.input_ids: input_ids_temp.reshape(1, self.max_seq_length),self.input_mask: input_mask_temp.reshape(1, self.max_seq_length),self.segment_ids: segment_ids_temp.reshape(1, self.max_seq_length)}[label] = self.sess.run([self.predictions], feed)label_name = self.label_dict[label[0]]return label[0], label_name其中,輸入是一個新聞文本,輸出為類別序號以及對應的標簽名稱。詳細完整代碼見github:
ClassificationModel.py文件。
(劃重點)上面介紹的都是如何方便簡潔地加載模型,下面開始使用web服務掛起模型。通俗地講,其實本人就是通過flask框架,搭建了一個web服務,來獲取外部的輸入;并且使用掛載的模型進行預測;最后將預測結果通過web服務傳出。
from gevent import monkey monkey.patch_all() from flask import Flask, request from gevent import wsgi import json from ClassificationModel import ClassificationModeldef start_sever(http_id, port, gpu_id, vocab_file, gpu_memory_fraction, model_path, max_seq_length):model = ClassificationModel()model.load_model(gpu_id, vocab_file, gpu_memory_fraction, model_path, max_seq_length)print("load model ending!")app = Flask(__name__)@app.route('/')def index():return "This is News Classification Model Server"@app.route('/news-classification', methods=['Get', 'POST'])def response_request():if request.method == 'POST':text = request.form.get('text')else:text = request.args.get('text')label, label_name = model.predict(text)d = {"label": str(label), "label_name": label_name}print(d)return json.dumps(d, ensure_ascii=False)server = wsgi.WSGIServer((str(http_id), port), app)server.serve_forever()其中,http_id為web服務的地址;port為端口號;gpu_id、vocab_file、gpu_memory_fraction、model_path和max_seq_length為上面介紹的加載模型所需要的參數,詳細見上文。
index函數用于檢驗web服務是否暢通。如圖1所示。
圖1response_request函數為響應函數。定義了兩種請求數據的方式,get和post。當使用get方法獲取web輸入時,獲取命令為request.args.get('text');當使用post方法獲取web輸入時,獲取命令為request.form.get('text')。
當web服務起起來之后,就可以調用啦!!!
瀏覽器調用如圖2所示。
圖2Code調用如下:
import requestsdef http_test(text):url = 'http://127.0.0.1:5555/news-classification'raw_data = {'text': text}res = requests.post(url, raw_data)result = res.json()return resultif __name__ == "__main__":text = "姚明在NBA打球,很強。"result = http_test(text)print(result["label_name"])以上就是通過web服務部署深度學習模型的全部內容,喜歡的同學還請多多點贊~~~~~
推薦幾篇本人之前寫的一些文章:
劉聰NLP:短文本相似度算法研究
劉聰NLP:閱讀筆記:開放域檢索問答(ORQA)
劉聰NLP:論文閱讀筆記:文本蘊含之BiMPM
喜歡的同學,可以關注一下專欄,關注一下作者,還請多多點贊~~~~~~
總結
以上是生活随笔為你收集整理的深度学习模型保存_Web服务部署深度学习模型的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: python读取路径中字符串_pytho
- 下一篇: connection refused_E