生活随笔
收集整理的這篇文章主要介紹了
caffe源码c++学习笔记
小編覺得挺不錯的,現在分享給大家,幫大家做個參考。
轉載自:深度學習(七)caffe源碼c++學習筆記 - hjimce的專欄 - 博客頻道 - CSDN.NET
http://blog.csdn.net/hjimce/article/details/48933845
一、預測分類
最近幾天為了希望深入理解caffe,于是便開始學起了caffe函數的c++調用,caffe的函數調用例子網上很少,需要自己慢慢的摸索,即便是找到了例子,有的時候caffe版本不一樣,也會出現錯誤。對于預測分類的函數調用,caffe為我們提供了一個例子,一開始我懶得解讀這個例子,網上找了一些分類預測的例子,總是會出現各種各樣的錯誤,于是沒辦法最后只能老老實實的學官方給的例子比較實在,因此最后自己把代碼解讀了一下,然后自己整理成自己的類,這個類主要用于訓練好模型后,我們要進行調用預測一張新輸入圖片的類別。
頭文件:
[cpp]?view plaincopy
? ? ? ? ? ?? ?? #ifndef?CLASSIFIER_H_?? #define?CLASSIFIER_H_?? ?? ?? #include?<caffe/caffe.hpp>?? ?? #include?<opencv2/core/core.hpp>?? #include?<opencv2/highgui/highgui.hpp>?? #include?<opencv2/imgproc/imgproc.hpp>?? ?? #include?<algorithm>?? #include?<iosfwd>?? #include?<memory>?? #include?<string>?? #include?<utility>?? #include?<vector>?? ?? ?? using?namespace?caffe;?? using?std::string;?? ?? ?? typedef?std::pair<string,?float>?Prediction;?? ?? class?Classifier?? {?? ?public:?? ????Classifier(const?string&?model_file,?const?string&?trained_file,const?string&?mean_file);?? ????std::vector<Prediction>?Classify(const?cv::Mat&?img,?int?N?=?1);?? ????void?SetLabelString(std::vector<string>strlabel);?? private:?? ?? ????void?SetMean(const?string&?mean_file);?? ?? ???std::vector<float>?Predict(const?cv::Mat&?img);?? ?? ???void?WrapInputLayer(std::vector<cv::Mat>*?input_channels);?? ?? ???void?Preprocess(const?cv::Mat&?img,?? ??????????????????std::vector<cv::Mat>*?input_channels);?? ?? private:?? ???shared_ptr<Net<float>?>?net_;?? ???cv::Size?input_geometry_;?? ???int?num_channels_;?? ???cv::Mat?mean_;?? ???std::vector<string>?labels_;?? };?? ?? ?? #endif?/*?CLASSIFIER_H_?*/??
源文件:
[cpp]?view plaincopy
? ? ? ? ? ?? ?? #include?"Classifier.h"?? using?namespace?caffe;?? Classifier::Classifier(const?string&?model_file,const?string&?trained_file,const?string&?mean_file)?? {?? ?????? ??Caffe::set_mode(Caffe::CPU);?? ?? ?? ??? ??net_.reset(new?Net<float>(model_file,?TEST));?? ???? ??net_->CopyTrainedLayersFrom(trained_file);?? ?? ??CHECK_EQ(net_->num_inputs(),?1)?<<?"Network?should?have?exactly?one?input.";?? ??CHECK_EQ(net_->num_outputs(),?1)?<<?"Network?should?have?exactly?one?output.";?? ??? ??Blob<float>*?input_layer?=?net_->input_blobs()[0];?? ??num_channels_?=?input_layer->channels();?? ???? ??CHECK(num_channels_?==?3?||?num_channels_?==?1)<<?"Input?layer?should?have?1?or?3?channels.";?? ???? ??input_geometry_?=?cv::Size(input_layer->width(),?input_layer->height());?? ?? ??? ??SetMean(mean_file);?? ?? }?? ?? static?bool?PairCompare(const?std::pair<float,?int>&?lhs,?? ????????????????????????const?std::pair<float,?int>&?rhs)?{?? ??return?lhs.first?>?rhs.first;?? }?? ?? ?? ?? static?std::vector<int>?Argmax(const?std::vector<float>&?v,?int?N)?? {?? ?????? ??std::vector<std::pair<float,?int>?>?pairs;?? ??for?(size_t?i?=?0;?i?<?v.size();?++i)?? ????pairs.push_back(std::make_pair(v[i],?i));?? ??std::partial_sort(pairs.begin(),?pairs.begin()?+?N,?pairs.end(),?PairCompare);?? ?? ??std::vector<int>?result;?? ??for?(int?i?=?0;?i?<?N;?++i)?? ????result.push_back(pairs[i].second);?? ??return?result;?? }?? ?? ?? ?? std::vector<Prediction>?Classifier::Classify(const?cv::Mat&?img,?int?N)?{?? ??std::vector<float>?output?=?Predict(img);?? ?? ??N?=?std::min<int>(labels_.size(),?N);?? ??std::vector<int>?maxN?=?Argmax(output,?N);?? ??std::vector<Prediction>?predictions;?? ??for?(int?i?=?0;?i?<?N;?++i)?{?? ????int?idx?=?maxN[i];?? ????predictions.push_back(std::make_pair(labels_[idx],?output[idx]));?? ??}?? ?? ??return?predictions;?? }?? void?Classifier::SetLabelString(std::vector<string>strlabel)?? {?? ????labels_=strlabel;?? }?? ?? ?? ?? ?? ?? ?? ?? ?? ?? ?? 
void?Classifier::SetMean(const?string&?mean_file)?? {?? ??BlobProto?blob_proto;?? ??ReadProtoFromBinaryFileOrDie(mean_file.c_str(),?&blob_proto);?? ?? ???? ??Blob<float>?mean_blob;?? ??mean_blob.FromProto(blob_proto);?? ???? ??CHECK_EQ(mean_blob.channels(),?num_channels_)<<?"Number?of?channels?of?mean?file?doesn't?match?input?layer.";?? ?? ??? ??std::vector<cv::Mat>?channels;?? ??float*?data?=?mean_blob.mutable_cpu_data();?? ??for?(int?i?=?0;?i?<?num_channels_;?++i)?{?? ?? ????cv::Mat?channel(mean_blob.height(),?mean_blob.width(),?CV_32FC1,?data);?? ????channels.push_back(channel);?? ????data?+=?mean_blob.height()?*?mean_blob.width();?? ??}?? ?? ?? ??cv::Mat?mean;?? ??cv::merge(channels,?mean);?? ?? ?? ???? ??cv::Scalar?channel_mean?=?cv::mean(mean);?? ??mean_?=?cv::Mat(input_geometry_,?mean.type(),?channel_mean);?? }?? ?? std::vector<float>?Classifier::Predict(const?cv::Mat&?img)?? {?? ?????? ????Blob<float>*?input_layer?=?net_->input_blobs()[0];?? ????input_layer->Reshape(1,?num_channels_,?input_geometry_.height,?input_geometry_.width);?? ????net_->Reshape();?? ????? ????std::vector<cv::Mat>?input_channels;?? ????WrapInputLayer(&input_channels);?? ?? ????Preprocess(img,?&input_channels);?? ????? ????net_->ForwardPrefilled();?? ?? ???? ??Blob<float>*?output_layer?=?net_->output_blobs()[0];?? ??const?float*?begin?=?output_layer->cpu_data();?? ??const?float*?end?=?begin?+?output_layer->channels();?? ??return?std::vector<float>(begin,?end);?? }?? ?? ?? void?Classifier::WrapInputLayer(std::vector<cv::Mat>*?input_channels)?? {?? ??Blob<float>*?input_layer?=?net_->input_blobs()[0];?? ?? ??int?width?=?input_layer->width();?? ??int?height?=?input_layer->height();?? ??float*?input_data?=?input_layer->mutable_cpu_data();?? ??for?(int?i?=?0;?i?<?input_layer->channels();?++i)?{?? ????cv::Mat?channel(height,?width,?CV_32FC1,?input_data);?? ????input_channels->push_back(channel);?? ????input_data?+=?width?*?height;?? ??}?? }?? ?? ?? 
void?Classifier::Preprocess(const?cv::Mat&?img,std::vector<cv::Mat>*?input_channels)?? {?? ?? ??cv::Mat?sample;?? ???? ??if?(img.channels()?==?3?&&?num_channels_?==?1)?? ????cv::cvtColor(img,?sample,?CV_BGR2GRAY);?? ??else?if?(img.channels()?==?4?&&?num_channels_?==?1)?? ????cv::cvtColor(img,?sample,?CV_BGRA2GRAY);?? ???? ??else?if?(img.channels()?==?4?&&?num_channels_?==?3)?? ????cv::cvtColor(img,?sample,?CV_BGRA2BGR);?? ??else?if?(img.channels()?==?1?&&?num_channels_?==?3)?? ????cv::cvtColor(img,?sample,?CV_GRAY2BGR);?? ??else?? ????sample?=?img;?? ?? ??cv::Mat?sample_resized;?? ??if?(sample.size()?!=?input_geometry_)?? ????cv::resize(sample,?sample_resized,?input_geometry_);?? ??else?? ????sample_resized?=?sample;?? ?? ??cv::Mat?sample_float;?? ??if?(num_channels_?==?3)?? ????sample_resized.convertTo(sample_float,?CV_32FC3);?? ??else?? ????sample_resized.convertTo(sample_float,?CV_32FC1);?? ?? ??cv::Mat?sample_normalized;?? ??cv::subtract(sample_float,?mean_,?sample_normalized);?? ?? ???? ??cv::split(sample_normalized,?*input_channels);?? ?? ??CHECK(reinterpret_cast<float*>(input_channels->at(0).data)?==?net_->input_blobs()[0]->cpu_data())?<<?"Input?channels?are?not?wrapping?the?input?layer?of?the?network.";?? }??
調用實例,下面這個實例是要用于性別預測的例子:
[cpp]?view plaincopy
?? ?? ?? ?? ?? ?? ?? ?? #include?<string>?? #include?<vector>?? #include?<fstream>?? #include?"caffe/caffe.hpp"?? #include?<opencv2/opencv.hpp>?? #include"Classifier.h"?? ?? int?main()?? {?? ?????caffe::Caffe::set_mode(caffe::Caffe::CPU);?? ????cv::Mat?src1;?? ????src1?=?cv::imread("4.jpg");?? ????Classifier?cl("deploy.prototxt",?"gender_net.caffemodel","imagenet_mean.binaryproto");?? ????std::vector<string>label;?? ????label.push_back("male");?? ????label.push_back("female");?? ????cl.SetLabelString(label);?? ????std::vector<Prediction>pre=cl.Classify(src1);?? ????cv::imshow("1.jpg",src1);?? ?? ????std::cout?<<pre[0].first<<?std::endl;?? ????return?0;?? }??
二、文件數據
[cpp]?view plaincopy
/函數(shù)的作用是讀取一張圖片,并保存到到datum中?? ?? ?? ?? ?? ? ?? ?? bool?ReadImageToDatum(const?string&?filename,?const?int?label,?? ????const?int?height,?const?int?width,?const?bool?is_color,?? ????const?std::string?&?encoding,?Datum*?datum)?{?? ??cv::Mat?cv_img?=?ReadImageToCVMat(filename,?height,?width,?is_color);?? ??if?(cv_img.data)?{?? ????if?(encoding.size())?{?? ??????if?(?(cv_img.channels()?==?3)?==?is_color?&&?!height?&&?!width?&&?? ??????????matchExt(filename,?encoding)?)?? ????????return?ReadFileToDatum(filename,?label,?datum);?? ??????std::vector<uchar>?buf;?? ??????cv::imencode("."+encoding,?cv_img,?buf);?? ??????datum->set_data(std::string(reinterpret_cast<char*>(&buf[0]),?? ??????????????????????buf.size()));?? ??????datum->set_label(label);?? ??????datum->set_encoded(true);?? ??????return?true;?? ????}?? ????CVMatToDatum(cv_img,?datum);?? ????datum->set_label(label);?? ????return?true;?? ??}?else?{?? ????return?false;?? ??}?? }??
總結
以上是生活随笔為你收集整理的caffe源码c++学习笔记的全部內容,希望文章能夠幫你解決所遇到的問題。
如果覺得生活随笔網站內容還不錯,歡迎將生活随笔推薦給好友。