OpenCV3.3中 K-最近邻法(KNN)接口简介及使用
OpenCV 3.3中給出了K-最近鄰(KNN)算法的實(shí)現(xiàn),即cv::ml::Knearest類(lèi),此類(lèi)的聲明在include/opecv2/ml.hpp文件中,實(shí)現(xiàn)在modules/ml/src/knearest.cpp文件中。其中:
(1)、cv::ml::Knearest類(lèi):繼承自cv::ml::StateModel,而cv::ml::StateModel又繼承自cv::Algorithm;
(2)、create函數(shù):為static,new一個(gè)KNearestImpl用來(lái)創(chuàng)建一個(gè)KNearest對(duì)象;
(3)、setDefaultK/getDefaultK函數(shù):在預(yù)測(cè)時(shí),設(shè)置/獲取的K值;
(4)、setIsClassifier/getIsClassifier函數(shù):設(shè)置/獲取應(yīng)用KNN是進(jìn)行分類(lèi)還是回歸;
(5)、setEmax/getEmax函數(shù):在使用KDTree算法時(shí),設(shè)置/獲取Emax參數(shù)值;
(6)、setAlgorithmType/getAlgorithmType函數(shù):設(shè)置/獲取KNN算法類(lèi)型,目前支持兩種:brute_force和KDTree;
(7)、findNearest函數(shù):根據(jù)輸入預(yù)測(cè)分類(lèi)/回歸結(jié)果。
關(guān)于KNN算法介紹可以參考:?http://blog.csdn.net/fengbingchun/article/details/78464169 ?
以下是從數(shù)據(jù)集MNIST中提取的40幅圖像,0,1,2,3四類(lèi)各20張,每類(lèi)的前10幅來(lái)自于訓(xùn)練樣本,用于訓(xùn)練,后10幅來(lái)自測(cè)試樣本,用于測(cè)試,如下圖:
關(guān)于MNIST的介紹可以參考:? http://blog.csdn.net/fengbingchun/article/details/49611549?
測(cè)試代碼如下:
#include "opencv.hpp"
#include <string>
#include <vector>
#include <memory>
#include <algorithm>
#include <opencv2/opencv.hpp>
#include <opencv2/ml.hpp>
#include "common.hpp"/// K-Nearest Neighbor(KNN) //
int test_opencv_knn_predict()
{const int K{ 3 };cv::Ptr<cv::ml::KNearest> knn = cv::ml::KNearest::create();knn->setDefaultK(K);knn->setIsClassifier(true);knn->setAlgorithmType(cv::ml::KNearest::BRUTE_FORCE);const std::string image_path{"E:/GitCode/NN_Test/data/images/digit/handwriting_0_and_1/"};cv::Mat tmp = cv::imread(image_path + "0_1.jpg", 0);const int train_samples_number{ 40 }, predict_samples_number{ 40 };const int every_class_number{ 10 };cv::Mat train_data(train_samples_number, tmp.rows * tmp.cols, CV_32FC1);cv::Mat train_labels(train_samples_number, 1, CV_32FC1);float* p = (float*)train_labels.data;for (int i = 0; i < 4; ++i) {std::for_each(p + i * every_class_number, p + (i + 1)*every_class_number, [i](float& v){v = (float)i; });}// train datafor (int i = 0; i < 4; ++i) {static const std::vector<std::string> digit{ "0_", "1_", "2_", "3_" };static const std::string suffix{ ".jpg" };for (int j = 1; j <= every_class_number; ++j) {std::string image_name = image_path + digit[i] + std::to_string(j) + suffix;cv::Mat image = cv::imread(image_name, 0);CHECK(!image.empty() && image.isContinuous());image.convertTo(image, CV_32FC1);image = image.reshape(0, 1);tmp = train_data.rowRange(i * every_class_number + j - 1, i * every_class_number + j);image.copyTo(tmp);}}knn->train(train_data, cv::ml::ROW_SAMPLE, train_labels);// predict dattacv::Mat predict_data(predict_samples_number, tmp.rows * tmp.cols, CV_32FC1);for (int i = 0; i < 4; ++i) {static const std::vector<std::string> digit{ "0_", "1_", "2_", "3_" };static const std::string suffix{ ".jpg" };for (int j = 11; j <= every_class_number+10; ++j) {std::string image_name = image_path + digit[i] + std::to_string(j) + suffix;cv::Mat image = cv::imread(image_name, 0);CHECK(!image.empty() && image.isContinuous());image.convertTo(image, CV_32FC1);image = image.reshape(0, 1);tmp = predict_data.rowRange(i * every_class_number + j - 10 - 1, i * every_class_number + j - 10);image.copyTo(tmp);}}cv::Mat result;knn->findNearest(predict_data, K, result);CHECK(result.rows == predict_samples_number);cv::Mat predict_labels(predict_samples_number, 1, CV_32FC1);p = (float*)predict_labels.data;for (int i = 0; i < 4; ++i) {std::for_each(p + i * every_class_number, p + (i + 1)*every_class_number, [i](float& v){v = (float)i; });}int count{ 0 };for (int i = 0; i < predict_samples_number; ++i) {float value1 = ((float*)predict_labels.data)[i];float value2 = ((float*)result.data)[i];fprintf(stdout, "expected value: %f, actual value: %f\n", value1, value2);if (int(value1) == int(value2)) ++count;}fprintf(stdout, "when K = %d, accuracy: %f\n", K, count * 1.f / predict_samples_number);return 0;
}測(cè)試結(jié)果如下:
GitHub:?https://github.com/fengbingchun/NN_Test ??
總結(jié)
以上是生活随笔為你收集整理的OpenCV3.3中 K-最近邻法(KNN)接口简介及使用的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
 
                            
                        - 上一篇: K-最近邻法(KNN)简介
- 下一篇: Brute Force算法介绍及C++实
