An Example of Using Logistic Regression in OpenCV 3.3
OpenCV 3.3 provides an implementation of logistic regression, the cv::ml::LogisticRegression class. The class is declared in include/opencv2/ml.hpp and implemented in modules/ml/src/lr.cpp, and it supports both binary and multi-class classification. Key points:
(1) The cv::ml::LogisticRegression class inherits from cv::ml::StatModel, which in turn inherits from cv::Algorithm;
(2) setLearningRate sets the learning rate; getLearningRate returns the current learning rate;
(3) setIterations sets the number of iterations; getIterations returns it;
(4) setRegularization selects the regularization method; L1 norm and L2 norm are currently supported, and regularization can also be switched off with REG_DISABLE (as in the test code below). Regularization is mainly used to prevent overfitting; getRegularization returns the selected method;
(5) setTrainMethod selects the training method; BATCH and MINI_BATCH are currently supported; getTrainMethod returns the selected method;
(6) setMiniBatchSize sets the number of training samples taken in each step of mini-batch gradient descent; getMiniBatchSize returns it;
(7) setTermCriteria sets the termination criteria for training, i.e. the number of iterations and the desired accuracy; getTermCriteria returns them;
(8) get_learnt_thetas returns the learned parameters (see the short sketch after this list);
(9) create is a static function; it creates a LogisticRegression object by newing a LogisticRegressionImpl internally;
(10) train (inherited from the base class StatModel) performs training;
(11) predict performs prediction;
(12) save (inherited from the base class Algorithm) saves a trained model; xml, yaml and json formats are supported;
(13) load loads a previously trained model.
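To make the parameter-setting interface above concrete, here is a minimal, self-contained sketch. It is not taken from the original test code; the toy 1-D data and all parameter values are made up for illustration. It configures a LogisticRegression with a termination criterion, trains it with batch gradient descent, and reads back the learned thetas:

```cpp
#include <iostream>
#include <opencv2/opencv.hpp>
#include <opencv2/ml.hpp>

int main()
{
    // toy 1-D training set: values below 5 belong to class 0, values above 5 to class 1
    float train_data[8] = { 1.f, 2.f, 3.f, 4.f, 6.f, 7.f, 8.f, 9.f };
    float train_labels[8] = { 0.f, 0.f, 0.f, 0.f, 1.f, 1.f, 1.f, 1.f };
    cv::Mat data(8, 1, CV_32FC1, train_data);     // one sample per row
    cv::Mat labels(8, 1, CV_32FC1, train_labels); // training labels must be CV_32F

    cv::Ptr<cv::ml::LogisticRegression> lr = cv::ml::LogisticRegression::create();
    lr->setLearningRate(0.001);
    lr->setIterations(1000);
    lr->setRegularization(cv::ml::LogisticRegression::REG_L2); // L2 norm regularization
    lr->setTrainMethod(cv::ml::LogisticRegression::BATCH);     // batch gradient descent
    // termination criteria: at most 1000 iterations and a desired accuracy of 1e-6
    lr->setTermCriteria(cv::TermCriteria(cv::TermCriteria::COUNT + cv::TermCriteria::EPS, 1000, 1e-6));

    lr->train(data, cv::ml::ROW_SAMPLE, labels);

    // learned parameters (thetas)
    std::cout << "thetas: " << lr->get_learnt_thetas() << std::endl;

    // predict two new samples; the returned labels are CV_32S
    cv::Mat test = (cv::Mat_<float>(2, 1) << 2.5f, 7.5f);
    cv::Mat result;
    lr->predict(test, result);
    std::cout << "predicted labels: " << result.t() << std::endl;

    return 0;
}
```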
The following is the binary-classification test code. The training set consists of 10 images of "0" and 10 images of "1" randomly selected from the MNIST training set; the test set consists of another 10 images of each digit randomly selected from the MNIST test set. In other words, for each digit the first 10 images (numbered 1-10) are used for training and the next 10 (numbered 11-20) for testing:
#include "opencv.hpp"
#include <string>
#include <vector>
#include <memory>
#include <opencv2/opencv.hpp>
#include <opencv2/ml.hpp>
#include "common.hpp"// Logistic Regression ///
static void show_image(const cv::Mat& data, int columns, const std::string& name)
{cv::Mat big_image;for (int i = 0; i < data.rows; ++i) {big_image.push_back(data.row(i).reshape(0, columns));}cv::imshow(name, big_image);cv::waitKey(0);
}static float calculate_accuracy_percent(const cv::Mat& original, const cv::Mat& predicted)
{return 100 * (float)cv::countNonZero(original == predicted) / predicted.rows;
}int test_opencv_logistic_regression_train()
{const std::string image_path{ "E:/GitCode/NN_Test/data/images/digit/handwriting_0_and_1/" };cv::Mat data, labels, result;for (int i = 1; i < 11; ++i) {const std::vector<std::string> label{ "0_", "1_" };for (const auto& value : label) {std::string name = std::to_string(i);name = image_path + value + name + ".jpg";cv::Mat image = cv::imread(name, 0);if (image.empty()) {fprintf(stderr, "read image fail: %s\n", name.c_str());return -1;}data.push_back(image.reshape(0, 1));}}data.convertTo(data, CV_32F);//show_image(data, 28, "train data");std::unique_ptr<float[]> tmp(new float[20]);for (int i = 0; i < 20; ++i) {if (i % 2 == 0) tmp[i] = 0.f;else tmp[i] = 1.f;}labels = cv::Mat(20, 1, CV_32FC1, tmp.get());cv::Ptr<cv::ml::LogisticRegression> lr = cv::ml::LogisticRegression::create();lr->setLearningRate(0.00001);lr->setIterations(100);lr->setRegularization(cv::ml::LogisticRegression::REG_DISABLE);lr->setTrainMethod(cv::ml::LogisticRegression::MINI_BATCH);lr->setMiniBatchSize(1);CHECK(lr->train(data, cv::ml::ROW_SAMPLE, labels));const std::string save_file{ "E:/GitCode/NN_Test/data/logistic_regression_model.xml" }; // .xml, .yaml, .jsonslr->save(save_file);return 0;
}int test_opencv_logistic_regression_predict()
{const std::string image_path{ "E:/GitCode/NN_Test/data/images/digit/handwriting_0_and_1/" };cv::Mat data, labels, result;for (int i = 11; i < 21; ++i) {const std::vector<std::string> label{ "0_", "1_" };for (const auto& value : label) {std::string name = std::to_string(i);name = image_path + value + name + ".jpg";cv::Mat image = cv::imread(name, 0);if (image.empty()) {fprintf(stderr, "read image fail: %s\n", name.c_str());return -1;}data.push_back(image.reshape(0, 1));}}data.convertTo(data, CV_32F);//show_image(data, 28, "test data");std::unique_ptr<int[]> tmp(new int[20]);for (int i = 0; i < 20; ++i) {if (i % 2 == 0) tmp[i] = 0;else tmp[i] = 1;}labels = cv::Mat(20, 1, CV_32SC1, tmp.get());const std::string model_file{ "E:/GitCode/NN_Test/data/logistic_regression_model.xml" };cv::Ptr<cv::ml::LogisticRegression> lr = cv::ml::LogisticRegression::load(model_file);lr->predict(data, result);fprintf(stdout, "predict result: \n");std::cout << "actual: " << labels.t() << std::endl;std::cout << "target: " << result.t() << std::endl;fprintf(stdout, "accuracy: %.2f%%\n", calculate_accuracy_percent(labels, result));return 0;
}
In the test code, test_opencv_logistic_regression_train performs training and produces a model file named logistic_regression_model.xml; test_opencv_logistic_regression_predict loads that model and performs prediction. In the resulting output, all 20 test images are classified correctly (100% accuracy).
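For completeness, a minimal driver is sketched below. It is not part of the original code; it assumes it is compiled in the same translation unit as the two test functions above, and simply runs training followed by prediction:

```cpp
// hypothetical driver; assumes it is compiled together with the test functions above
int main()
{
    if (test_opencv_logistic_regression_train() != 0) {
        fprintf(stderr, "training failed\n");
        return -1;
    }
    if (test_opencv_logistic_regression_predict() != 0) {
        fprintf(stderr, "prediction failed\n");
        return -1;
    }
    return 0;
}
```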
GitHub: https://github.com/fengbingchun/NN_Test