Caffe源码中Solver文件分析
Caffe源碼(caffe version commit: 09868ac , date: 2015.08.15)中有一些重要的頭文件,這里介紹下include/caffe/solver.hpp文件的內容:
1.??????include文件:
<caffe/solver.hpp>:此文件的介紹可以參考:?http://blog.csdn.net/fengbingchun/article/details/62423060
2.??????模板類Solver:虛基類
3.??????模板類WorkerSolver:繼承父類Solver,用于多GPU訓練時僅計算梯度
4.??????模板類SGDSolver:繼承父類Solver
5.??????模板類NesterovSolver:繼承SGDSolver
6.??????模板類AdaGradSolver:繼承SGDSolver
7.??????模板類RMSPropSolver:繼承SGDSolver
8.??????模板類AdaDeltaSolver:繼承SGDSolver
9.??????模板類AdamSolver:繼承SGDSolver
10.??函數GetSolver:new solver對象
Solver通過協調Net的前向推斷計算和反向梯度計算(forward inference and backward gradients),來對參數進行更新,從而達到減少loss的目的。Caffe模型的學習被分為兩個部分:由Solver進行優化、更新參數,由Net計算出loss和gradient。
solver.prototxt是一個配置文件用來告知Caffe怎樣對網絡進行訓練。
有了Net就可以進行神經網絡的前后向傳播計算了,但是還缺少神經網絡的訓練和預測功能,Solver類進一步封裝了訓練和預測相關的一些功能。Solver定義了針對Net網絡模型的求解方法,記錄神經網絡的訓練過程,保存神經網絡模型參數,中斷并恢復網絡的訓練過程。自定義Solver能夠實現不同的神經網絡求解方式。
Caffe支持的solvers包括:
(1)、Stochastic Gradient Descent(type: “SGD”)即隨機梯度下降:利用負梯度和上一次權重的更新值的線性組合來更新權重。學習率(learning rate)是負梯度的權重。動量是上一次更新值的權重。一般將學習速率初始化為0.01,然后在訓練(training)中當loss達到穩定時,將學習速率除以一個常數(例如10),將這個過程重復多次。對于動量一般設置為0.9,動量使weight得更新更為平緩,使學習過程更為穩定、快速。
(2)、AdaDelta(type:“AdaDelta”):是一種”魯棒的學習率方法”,同SGD一樣是一種基于梯度的優化方法。
(3)、Adaptive Gradient(type: “AdaGrad”)即自適應梯度下降,與隨機梯度下降一樣是基于梯度的優化方法。
(4)、Adam(type:“Adam”):也是一種基于梯度的優化方法。它包含一對自適應時刻估計變量,可以看做是AdaGrad的一種泛化形式。
(5)、Nesterov’s Accelerated Gradient(type: “Nesterov”):Nesterov提出的加速梯度下降(Nesterov’s accelerated gradient)是凸優化的一種最優算法,其收斂速度可以達到O(1/t^2),而不是O(1/t)。盡管在使用Caffe訓練深度神經網絡時很難滿足O(1/t^2)收斂條件,但實際中NAG對于某些特定結構的深度學習模型仍是一個非常有效的方法。
(6)、RMSprop(type:“RMSProp”):是一種基于梯度的優化方法(同SGD類似)。
Solver:
(1)、用于優化過程的記錄、創建訓練網絡(用于學習)和測試網絡(用于評估);
(2)、通過forward和backward過程來迭代地優化和更新參數;
(3)、周期性地用測試網絡評估模型性能;
(4)、在優化過程中記錄模型和solver狀態的快照(snapshot)。
每一次迭代過程中:
(1)、調用Net的前向過程計算出輸出和loss;
(2)、調用Net的反向過程計算出梯度(loss對每層的權重w和偏置b求導);
(3)、根據下面所講的Solver方法,利用梯度更新參數;
(4)、根據學習率(learning rate),歷史數據和求解方法更新solver的狀態,使權重從初始化狀態逐步更新到最終的學習到的狀態。
Solvers的運行模式有CPU/GPU兩種模式。
Solver方法:用于最小化損失(loss)值。給定一個數據集D,優化的目標是D中所有數據損失的均值,即平均損失,取得最小值。
注:以上關于Solver內容的介紹主要摘自由CaffeCN社區翻譯的《Caffe官方教程中譯本》。
<caffe/solver.hpp>文件的詳細介紹如下:
#ifndef CAFFE_OPTIMIZATION_SOLVER_HPP_
#define CAFFE_OPTIMIZATION_SOLVER_HPP_#include <string>
#include <vector>#include "caffe/net.hpp"namespace caffe {/*** @brief An interface for classes that perform optimization on Net%s.** Requires implementation of ApplyUpdate to compute a parameter update* given the current state of the Net parameters.*/
template <typename Dtype>
class Solver { // Solver模板類,虛基類public:
// 顯示構造函數, 內部會調用Init函數explicit Solver(const SolverParameter& param, const Solver* root_solver = NULL);explicit Solver(const string& param_file, const Solver* root_solver = NULL);
// 成員變量賦值,包括param_、iter_、current_step_,并調用InitTrainNet和InitTestNets函數void Init(const SolverParameter& param);
// 為成員變量net_賦值void InitTrainNet();
// 為成員變量test_nets_賦值void InitTestNets();// The main entry of the solver function. In default, iter will be zero. Pass// in a non-zero iter number to resume training for a pre-trained net.
// 依次調用函數Restore、Step、Snapshot,然后執行net_的前向傳播函數ForwardPrefilled,最后調用TestAll函數virtual void Solve(const char* resume_file = NULL);inline void Solve(const string resume_file) { Solve(resume_file.c_str()); }
// 反復執行net前向傳播反向傳播計算,期間會調用函數TestAll、ApplyUpdate、Snapshot及類Callback兩個成員函數void Step(int iters);// The Restore method simply dispatches to one of the// RestoreSolverStateFrom___ protected methods. You should implement these// methods to restore the state from the appropriate snapshot type.
// 加載已有的模型void Restore(const char* resume_file);
// 虛析構函數virtual ~Solver() {}// 獲得slover parameterinline const SolverParameter& param() const { return param_; }
// 獲得train Netinline shared_ptr<Net<Dtype> > net() { return net_; }
// 獲得test Netinline const vector<shared_ptr<Net<Dtype> > >& test_nets() {return test_nets_;}
// 獲得當前的迭代數int iter() { return iter_; }// Invoked at specific points during an iteration
// 內部Callback類,僅在多卡GPU模式下使用class Callback {protected:virtual void on_start() = 0;virtual void on_gradients_ready() = 0;template <typename T>friend class Solver;};
// 獲得Callbackconst vector<Callback*>& callbacks() const { return callbacks_; }
// 添加一個Callbackvoid add_callback(Callback* value) { callbacks_.push_back(value); }protected:// Make and apply the update value for the current iteration.
// 更新net的權值和偏置virtual void ApplyUpdate() = 0;// The Solver::Snapshot function implements the basic snapshotting utility// that stores the learned net. You should implement the SnapshotSolverState()// function that produces a SolverState protocol buffer that needs to be// written to disk together with the learned net.
// 快照,內部會調用SnapshotToBinaryProto或SnapshotToHDF5、SnapshotSolverState函數void Snapshot();
// 獲取快照文件名string SnapshotFilename(const string extension);
// 寫proto到.caffemodelstring SnapshotToBinaryProto();
// 寫proto到HDF5文件string SnapshotToHDF5();// The test routine
// 內部會循環調用Test函數void TestAll();
// 執行測試網絡,net前向傳播void Test(const int test_net_id = 0);
// 存儲snapshot solver statevirtual void SnapshotSolverState(const string& model_filename) = 0;
// 讀HDF5文件到solver statevirtual void RestoreSolverStateFromHDF5(const string& state_file) = 0;
// 讀二進制文件.solverstate到solver statevirtual void RestoreSolverStateFromBinaryProto(const string& state_file) = 0;
// dummy function,只有聲明沒有實現void DisplayOutputBlobs(const int net_id);// Caffe中類的成員變量名都帶有后綴"_",這樣就容易區分臨時變量和類成員變量SolverParameter param_; // solver parameterint iter_; // 當前的迭代數int current_step_; // shared_ptr<Net<Dtype> > net_; // train netvector<shared_ptr<Net<Dtype> > > test_nets_; // test netvector<Callback*> callbacks_; // Callback// The root solver that holds root nets (actually containing shared layers)// in data parallelismconst Solver* const root_solver_;// 禁止使用Solver類的拷貝和賦值操作DISABLE_COPY_AND_ASSIGN(Solver);
};/*** @brief Solver that only computes gradients, used as worker* for multi-GPU training.*/
template <typename Dtype>
class WorkerSolver : public Solver<Dtype> { // 模板類WorkerSolver,繼承父類Solverpublic:
// 顯示構造函數explicit WorkerSolver(const SolverParameter& param, const Solver<Dtype>* root_solver = NULL): Solver<Dtype>(param, root_solver) {}protected:void ApplyUpdate() {}void SnapshotSolverState(const string& model_filename) {LOG(FATAL) << "Should not be called on worker solver.";}void RestoreSolverStateFromBinaryProto(const string& state_file) {LOG(FATAL) << "Should not be called on worker solver.";}void RestoreSolverStateFromHDF5(const string& state_file) {LOG(FATAL) << "Should not be called on worker solver.";}
};/*** @brief Optimizes the parameters of a Net using* stochastic gradient descent (SGD) with momentum.*/
template <typename Dtype>
class SGDSolver : public Solver<Dtype> { // 模板類SGDSolver,繼承父類Solverpublic:
// 顯示構造函數,調用PreSolve函數explicit SGDSolver(const SolverParameter& param) : Solver<Dtype>(param) { PreSolve(); }explicit SGDSolver(const string& param_file) : Solver<Dtype>(param_file) { PreSolve(); }
// 獲取history數據const vector<shared_ptr<Blob<Dtype> > >& history() { return history_; }protected:
// 成員變量history_, update_, temp_初始化void PreSolve();
// 獲取學習率Dtype GetLearningRate();
// 內部會調用ClipGradients、Normalize、Regularize、ComputeUpdateValue,更新net權值和偏置virtual void ApplyUpdate();
// 調用caffe_scal函數virtual void Normalize(int param_id);
// 調用caffe_axpy函數virtual void Regularize(int param_id);
// 計算并更新相應Blob值,調用caffe_cpu_axpby和caffe_copy函數virtual void ComputeUpdateValue(int param_id, Dtype rate);
// clip parameter gradients to that L2 norm,如果梯度值過大,就會對梯度做一個修剪,
// 對所有的參數乘以一個縮放因子,使得所有參數的平方和不超過參數中設定的梯度總值virtual void ClipGradients();
// 存儲snapshot solver state,內部會掉用SnapshotSolverStateToBinaryProto或SnapshotSolverStateToHDF5函數virtual void SnapshotSolverState(const string& model_filename);
// 寫solver state到二進制文件.solverstatevirtual void SnapshotSolverStateToBinaryProto(const string& model_filename);
// 寫solver state到HDF5virtual void SnapshotSolverStateToHDF5(const string& model_filename);// 讀HDF5文件到solver statevirtual void RestoreSolverStateFromHDF5(const string& state_file);// 讀二進制文件.solverstate到solver statevirtual void RestoreSolverStateFromBinaryProto(const string& state_file);// history maintains the historical momentum data.// update maintains update related data and is not needed in snapshots.// temp maintains other information that might be needed in computation// of gradients/updates and is not needed in snapshots
// Caffe中類的成員變量名都帶有后綴"_",這樣就容易區分臨時變量和類成員變量vector<shared_ptr<Blob<Dtype> > > history_, update_, temp_;// 禁止使用SGDSolver類的拷貝和賦值操作DISABLE_COPY_AND_ASSIGN(SGDSolver);
};template <typename Dtype>
class NesterovSolver : public SGDSolver<Dtype> { // 模板類NesterovSolver,繼承SGDSolverpublic:
// 顯示構造函數explicit NesterovSolver(const SolverParameter& param) : SGDSolver<Dtype>(param) {}explicit NesterovSolver(const string& param_file) : SGDSolver<Dtype>(param_file) {}protected:
// 計算并更新相應Blob值,調用caffe_cpu_axpby和caffe_copy函數virtual void ComputeUpdateValue(int param_id, Dtype rate);// 禁止使用NesterovSolver類的拷貝和賦值操作DISABLE_COPY_AND_ASSIGN(NesterovSolver);
};template <typename Dtype>
class AdaGradSolver : public SGDSolver<Dtype> { // 模板類AdaGradSolver,繼承SGDSolverpublic:
// 顯示構造函數,調用constuctor_sanity_check函數explicit AdaGradSolver(const SolverParameter& param) : SGDSolver<Dtype>(param) { constructor_sanity_check(); }explicit AdaGradSolver(const string& param_file) : SGDSolver<Dtype>(param_file) { constructor_sanity_check(); }protected:
// 計算并更新相應Blob值virtual void ComputeUpdateValue(int param_id, Dtype rate);void constructor_sanity_check() {CHECK_EQ(0, this->param_.momentum())<< "Momentum cannot be used with AdaGrad.";}// 禁止使用AdaGradSolver類的拷貝和賦值操作DISABLE_COPY_AND_ASSIGN(AdaGradSolver);
};template <typename Dtype>
class RMSPropSolver : public SGDSolver<Dtype> { // 模板類RMSPropSolver,繼承SGDSolverpublic:
// 顯示構造函數,調用constructor_sanity_check函數explicit RMSPropSolver(const SolverParameter& param) : SGDSolver<Dtype>(param) { constructor_sanity_check(); }explicit RMSPropSolver(const string& param_file) : SGDSolver<Dtype>(param_file) { constructor_sanity_check(); }protected:
// 計算并更新相應Blob值virtual void ComputeUpdateValue(int param_id, Dtype rate);void constructor_sanity_check() {CHECK_EQ(0, this->param_.momentum())<< "Momentum cannot be used with RMSProp.";CHECK_GE(this->param_.rms_decay(), 0)<< "rms_decay should lie between 0 and 1.";CHECK_LT(this->param_.rms_decay(), 1)<< "rms_decay should lie between 0 and 1.";}// 禁止使用RMSPropSolver類的拷貝和賦值操作DISABLE_COPY_AND_ASSIGN(RMSPropSolver);
};template <typename Dtype>
class AdaDeltaSolver : public SGDSolver<Dtype> { // 模板類AdaDeltaSolver,繼承SGDSolverpublic:
// 顯示構造函數,調用AdaDeltaPreSolve函數explicit AdaDeltaSolver(const SolverParameter& param) : SGDSolver<Dtype>(param) { AdaDeltaPreSolve(); }explicit AdaDeltaSolver(const string& param_file) : SGDSolver<Dtype>(param_file) { AdaDeltaPreSolve(); }protected:void AdaDeltaPreSolve();
// 計算并更新相應Blob值virtual void ComputeUpdateValue(int param_id, Dtype rate);// 禁止使用AdaDeltaSolver類的拷貝和賦值操作DISABLE_COPY_AND_ASSIGN(AdaDeltaSolver);
};/*** @brief AdamSolver, an algorithm for first-order gradient-based optimization* of stochastic objective functions, based on adaptive estimates of* lower-order moments. Described in [1].** [1] D. P. Kingma and J. L. Ba, "ADAM: A Method for Stochastic Optimization."* arXiv preprint arXiv:1412.6980v8 (2014).*/
template <typename Dtype>
class AdamSolver : public SGDSolver<Dtype> { // 模板類AdamSolver,繼承SGDSolverpublic:
// 顯示構造函數,調用AdamPreSolve函數explicit AdamSolver(const SolverParameter& param) : SGDSolver<Dtype>(param) { AdamPreSolve();}explicit AdamSolver(const string& param_file) : SGDSolver<Dtype>(param_file) { AdamPreSolve(); }protected:void AdamPreSolve();
// 計算并更新相應Blob值virtual void ComputeUpdateValue(int param_id, Dtype rate);// 禁止使用AdamSolver類的拷貝和賦值操作DISABLE_COPY_AND_ASSIGN(AdamSolver);
};// new一個指定的solver方法對象
template <typename Dtype>
Solver<Dtype>* GetSolver(const SolverParameter& param) {SolverParameter_SolverType type = param.solver_type();switch (type) {case SolverParameter_SolverType_SGD:return new SGDSolver<Dtype>(param);case SolverParameter_SolverType_NESTEROV:return new NesterovSolver<Dtype>(param);case SolverParameter_SolverType_ADAGRAD:return new AdaGradSolver<Dtype>(param);case SolverParameter_SolverType_RMSPROP:return new RMSPropSolver<Dtype>(param);case SolverParameter_SolverType_ADADELTA:return new AdaDeltaSolver<Dtype>(param);case SolverParameter_SolverType_ADAM:return new AdamSolver<Dtype>(param);default:LOG(FATAL) << "Unknown SolverType: " << type;}return (Solver<Dtype>*) NULL;
}} // namespace caffe#endif // CAFFE_OPTIMIZATION_SOLVER_HPP_
在caffe.proto文件中,主要有一個message是與solver相關的,如下:
// NOTE
// Update the next available ID when you add a new SolverParameter field.
//
// SolverParameter next available ID: 40 (last added: momentum2)
message SolverParameter { // Solver參數//// Specifying the train and test networks//// Exactly one train net must be specified using one of the following fields:// train_net_param, train_net, net_param, net// One or more test nets may be specified using any of the following fields:// test_net_param, test_net, net_param, net// If more than one test net field is specified (e.g., both net and// test_net are specified), they will be evaluated in the field order given// above: (1) test_net_param, (2) test_net, (3) net_param/net.// A test_iter must be specified for each test_net.// A test_level and/or a test_stage may also be specified for each test_net.//// Proto filename for the train net, possibly combined with one or more test nets.optional string net = 24; // .prototxt文件名, train or test net// Inline train net param, possibly combined with one or more test nets.optional NetParameter net_param = 25; // net parameter類optional string train_net = 1; // Proto filename for the train net, .prototxt文件名,train netrepeated string test_net = 2; // Proto filenames for the test nets, .prototxt文件名,test netoptional NetParameter train_net_param = 21; // Inline train net params, train net parameter類repeated NetParameter test_net_param = 22; // Inline test net params, test net parameter類// The states for the train/test nets. Must be unspecified or// specified once per net.//// By default, all states will have solver = true;// train_state will have phase = TRAIN,// and all test_state's will have phase = TEST.// Other defaults are set according to the NetState defaults.optional NetState train_state = 26; // train net staterepeated NetState test_state = 27; // test net state// The number of iterations for each test net.repeated int32 test_iter = 3; // 對于測試網絡(用于評估)執行一次需要迭代的次數, test_iter * batch_size = 測試圖像總數量// The number of iterations between two testing phases.optional int32 test_interval = 4 [default = 0]; // 指定執行多少次訓練網絡執行一次測試網絡optional bool test_compute_loss = 19 [default = false]; // 執行測試網絡時是否計算loss// If true, run an initial test pass before the first iteration,// ensuring memory availability and printing the starting value of the loss.optional bool test_initialization = 32 [default = true]; // 在總的開始前,是否先執行一次測試網絡optional float base_lr = 5; // The base learning rate,基礎學習率// the number of iterations between displaying info. If display = 0, no info// will be displayed.optional int32 display = 6; // 指定迭代多少次顯示一次結果信息// Display the loss averaged over the last average_loss iterationsoptional int32 average_loss = 33 [default = 1]; // optional int32 max_iter = 7; // the maximum number of iterations// accumulate gradients over `iter_size` x `batch_size` instancesoptional int32 iter_size = 36 [default = 1]; // // The learning rate decay policy. The currently implemented learning rate// policies are as follows: // 學習率衰減策略// - fixed: always return base_lr.// - step: return base_lr * gamma ^ (floor(iter / step))// - exp: return base_lr * gamma ^ iter// - inv: return base_lr * (1 + gamma * iter) ^ (- power)// - multistep: similar to step but it allows non uniform steps defined by// stepvalue// - poly: the effective learning rate follows a polynomial decay, to be// zero by the max_iter. return base_lr (1 - iter/max_iter) ^ (power)// - sigmoid: the effective learning rate follows a sigmod decay// return base_lr ( 1/(1 + exp(-gamma * (iter - stepsize))))//// where base_lr, max_iter, gamma, step, stepvalue and power are defined// in the solver parameter protocol buffer, and iter is the current iteration.optional string lr_policy = 8; // 學習策略,可取的值包括:fixed、step、exp、inv、multistep、poly、sigmoidoptional float gamma = 9; // The parameter to compute the learning rate.optional float power = 10; // The parameter to compute the learning rate.optional float momentum = 11; // The momentum value, 動量optional float weight_decay = 12; // The weight decay. // // regularization types supported: L1 and L2// controlled by weight_decayoptional string regularization_type = 29 [default = "L2"]; // L1 or L2// the stepsize for learning rate policy "step"optional int32 stepsize = 13; //// the stepsize for learning rate policy "multistep"repeated int32 stepvalue = 34; //// Set clip_gradients to >= 0 to clip parameter gradients to that L2 norm,// whenever their actual L2 norm is larger.optional float clip_gradients = 35 [default = -1]; //optional int32 snapshot = 14 [default = 0]; // The snapshot interval, 迭代多少次保存下結果(如權值、偏置)optional string snapshot_prefix = 15; // The prefix for the snapshot,指定保存文件名的前綴// whether to snapshot diff in the results or not. Snapshotting diff will help// debugging but the final protocol buffer size will be much larger.optional bool snapshot_diff = 16 [default = false]; //enum SnapshotFormat {HDF5 = 0;BINARYPROTO = 1;}optional SnapshotFormat snapshot_format = 37 [default = BINARYPROTO]; // HDF5 or BINARYPROTO// the mode solver will use: 0 for CPU and 1 for GPU. Use GPU in default.enum SolverMode {CPU = 0;GPU = 1;}optional SolverMode solver_mode = 17 [default = GPU]; // 指定solve mode是CPU還是GPU// the device_id will that be used in GPU mode. Use device_id = 0 in default.optional int32 device_id = 18 [default = 0]; // GPU mode下使用// If non-negative, the seed with which the Solver will initialize the Caffe// random number generator -- useful for reproducible results. Otherwise,// (and by default) initialize using a seed derived from the system clock.optional int64 random_seed = 20 [default = -1]; // // Solver typeenum SolverType { // solver優化方法SGD = 0;NESTEROV = 1;ADAGRAD = 2;RMSPROP = 3;ADADELTA = 4;ADAM = 5;}optional SolverType solver_type = 30 [default = SGD]; // 指定solver優化方法// numerical stability for RMSProp, AdaGrad and AdaDelta and Adamoptional float delta = 31 [default = 1e-8]; // // parameters for the Adam solveroptional float momentum2 = 39 [default = 0.999]; // // RMSProp decay value// MeanSquare(t) = rms_decay*MeanSquare(t-1) + (1-rms_decay)*SquareGradient(t)optional float rms_decay = 38; // // If true, print information about the state of the net that may help with// debugging learning problems.optional bool debug_info = 23 [default = false]; // // If false, don't save a snapshot after training finishes.optional bool snapshot_after_train = 28 [default = true]; //
}
solver的測試代碼如下:
#include "funset.hpp"
#include <string>
#include <vector>
#include <map>
#include "common.hpp"int test_caffe_solver()
{caffe::Caffe::set_mode(caffe::Caffe::CPU); // set run caffe modeconst std::string solver_prototxt{ "E:/GitCode/Caffe_Test/test_data/model/mnist/lenet_solver.prototxt" };caffe::SolverParameter solver_param;if (!caffe::ReadProtoFromTextFile(solver_prototxt.c_str(), &solver_param)) {fprintf(stderr, "parse solver.prototxt fail\n");return -1;}boost::shared_ptr<caffe::Solver<float> > solver(caffe::GetSolver<float>(solver_param));caffe::SolverParameter param = solver->param();if (param.has_net())fprintf(stderr, "net: %s\n", param.net().c_str());if (param.has_net_param()) {fprintf(stderr, "has net param\n");caffe::NetParameter net_param = param.net_param();if (net_param.has_name())fprintf(stderr, "net param name: %s\n", net_param.name().c_str());}if (param.has_train_net())fprintf(stderr, "train_net: %s\n", param.train_net());if (param.test_net().size() > 0) {for (auto test_net : param.test_net())fprintf(stderr, "test_net: %s\n", test_net);}if (param.has_train_net_param()) {fprintf(stderr, "has train net param\n");caffe::NetParameter train_net_param = param.train_net_param();}if (param.test_net_param().size() > 0) {fprintf(stderr, "has test net param\n");std::vector<caffe::NetParameter> test_net_param;for (auto net_param : param.test_net_param())test_net_param.push_back(net_param);}if (param.has_train_state()) {fprintf(stderr, "has train state\n");caffe::NetState state = param.train_state();}if (param.test_state().size()) {fprintf(stderr, "has test state\n");}if (param.test_iter_size() > 0) {fprintf(stderr, "has test iter\n");for (auto iter : param.test_iter())fprintf(stderr, " %d ", iter);fprintf(stderr, "\n");}if (param.has_test_interval())fprintf(stderr, "test interval: %d\n", param.test_interval());bool test_compute_loss = param.test_compute_loss();fprintf(stderr, "test compute loss: %d\n", test_compute_loss);bool test_initialization = param.test_initialization();fprintf(stderr, "test initializtion: %d\n", test_initialization);if (param.has_base_lr()) {float base_lr = param.base_lr();fprintf(stderr, "base lr: %f\n", base_lr);}if (param.has_display()) {int display = param.display();fprintf(stderr, "display: %d\n", display);}int average_loss = param.average_loss();fprintf(stderr, "average loss: %d\n", average_loss);if (param.has_max_iter()) {int max_iter = param.max_iter();fprintf(stderr, "max iter: %d\n", max_iter);}int iter_size = param.iter_size();fprintf(stderr, "iter size: %d\n", iter_size);if (param.has_lr_policy())fprintf(stderr, "lr policy: %s\n", param.lr_policy().c_str());if (param.has_gamma())fprintf(stderr, "gamma: %f\n", param.gamma());if (param.has_power())fprintf(stderr, "power: %f\n", param.power());if (param.has_momentum())fprintf(stderr, "momentum: %f\n", param.momentum());if (param.has_weight_decay())fprintf(stderr, "weight decay: %f\n", param.weight_decay());std::string regularization_type = param.regularization_type();fprintf(stderr, "regularization type: %s\n", param.regularization_type().c_str());if (param.has_stepsize())fprintf(stderr, "stepsize: %d\n", param.stepsize());if (param.stepvalue_size() > 0) {fprintf(stderr, "has stepvalue\n");for (auto value : param.stepvalue())fprintf(stderr, " %d ", value);fprintf(stderr, "\n");}fprintf(stderr, "clip gradients: %f\n", param.clip_gradients());fprintf(stderr, "snapshot: %d\n", param.snapshot());if (param.has_snapshot_prefix())fprintf(stderr, "snapshot prefix: %s\n", param.snapshot_prefix().c_str());fprintf(stderr, "snapshot diff: %d\n", param.snapshot_diff());caffe::SolverParameter_SnapshotFormat snapshot_format = param.snapshot_format();fprintf(stderr, "snapshot format: %s\n", snapshot_format == 0 ? "HDF5" : "BINARYPROTO");caffe::SolverParameter_SolverMode solver_mode = param.solver_mode();fprintf(stderr, "solver mode: %s\n", solver_mode == 0 ? "CPU" : "GPU");if (param.has_device_id())fprintf(stderr, "device id: %d\n", param.device_id());fprintf(stderr, "random seed: %d\n", param.random_seed());caffe::SolverParameter_SolverType solver_type = param.solver_type();std::string solver_method[] {"SGD", "NESTEROV", "ADAGRAD", "RMSPROP", "ADADELTA", "ADAM"};fprintf(stderr, "solver type: %s\n", solver_method[solver_type].c_str());fprintf(stderr, "delta: %f\n", param.delta());fprintf(stderr, "momentum2: %f\n", param.momentum2());if (param.has_rms_decay())fprintf(stderr, "rms decy: %f\n", param.rms_decay());fprintf(stderr, "debug info: %d\n", param.debug_info());fprintf(stderr, "snapshot after train: %d\n", param.snapshot_after_train());boost::shared_ptr<caffe::Net<float>> net = solver->net();std::vector<boost::shared_ptr<caffe::Net<float>>> test_nets = solver->test_nets();fprintf(stderr, "test nets size: %d\n", test_nets.size());fprintf(stderr, "iter: %d\n", solver->iter());return 0;
}
部分輸出結果如下:
GitHub: https://github.com/fengbingchun/Caffe_Test
總結
以上是生活随笔為你收集整理的Caffe源码中Solver文件分析的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: C++中nothrow的介绍及使用
- 下一篇: C++中try/catch/throw的