/******************************************************** * ██████╗ ██████╗████████╗██╗ * ██╔════╝ ██╔════╝╚══██╔══╝██║ * ██║ ███╗██║ ██║ ██║ * ██║ ██║██║ ██║ ██║ * ╚██████╔╝╚██████╗ ██║ ███████╗ * ╚═════╝ ╚═════╝ ╚═╝ ╚══════╝ * Geophysical Computational Tools & Library (GCTL) * * Copyright (c) 2022 Yi Zhang (yizhang-geo@zju.edu.cn) * * GCTL is distributed under a dual licensing scheme. You can redistribute * it and/or modify it under the terms of the GNU Lesser General Public * License as published by the Free Software Foundation, either version 2 * of the License, or (at your option) any later version. You should have * received a copy of the GNU Lesser General Public License along with this * program. If not, see . * * If the terms and conditions of the LGPL v.2. would prevent you from using * the GCTL, please consider the option to obtain a commercial license for a * fee. These licenses are offered by the GCTL's original author. As a rule, * licenses are provided "as-is", unlimited in time for a one time fee. Please * send corresponding requests to: yizhang-geo@zju.edu.cn. Please do not forget * to include some description of your company and the realm of its activities. * Also add information on how to contact you by electronic and paper mail. ******************************************************/ #include "sgd.h" /** * Default parameter for the SGD methods. */ static const gctl::sgd_para sgd_defparam = {0, 1e-6, 0.01, 0.01, 0.9, 0.999, 1e-8}; gctl::sgd_solver::sgd_solver() { sgd_param_ = sgd_defparam; sgd_inter_ = 1; sgd_silent_ = false; solver_name_ = "Undefined"; } gctl::sgd_solver::~sgd_solver(){} int gctl::sgd_solver::SGD_Progress(double fx, const array &x, const sgd_para ¶m, const int k) { if (sgd_silent_) return 0; if (param.epsilon > 0.0 && fx <= param.epsilon) { std::clog << GCTL_CLEARLINE << "\rF(x) = " << fx << ", Train-Times = " << k; return 0; } if (sgd_inter_ > 0 && k%sgd_inter_ == 0) { std::clog << GCTL_CLEARLINE << "\rF(x) = " << fx << ", Train-Times = " << k; } return 0; } void gctl::sgd_solver::sgd_silent() { sgd_silent_ = true; return; } void gctl::sgd_solver::set_sgd_report_interval(int inter) { sgd_inter_ = inter; return; } void gctl::sgd_solver::set_sgd_para(const sgd_para &in_param) { sgd_param_ = in_param; return; } void gctl::sgd_solver::set_sgd_para(const toml::value &toml_data) { sgd_param_ = sgd_defparam; std::string SGD = "sgd"; if (toml_data.contains(SGD)) { if (toml_data.at(SGD).contains("iteration")) sgd_param_.iteration = toml::find(toml_data, SGD, "iteration"); if (toml_data.at(SGD).contains("epsilon")) sgd_param_.epsilon = toml::find(toml_data, SGD, "epsilon"); if (toml_data.at(SGD).contains("mu")) sgd_param_.mu = toml::find(toml_data, SGD, "mu"); if (toml_data.at(SGD).contains("alpha")) sgd_param_.alpha = toml::find(toml_data, SGD, "alpha"); if (toml_data.at(SGD).contains("beta_1")) sgd_param_.beta_1 = toml::find(toml_data, SGD, "beta_1"); if (toml_data.at(SGD).contains("beta_2")) sgd_param_.beta_2 = toml::find(toml_data, SGD, "beta_2"); if (toml_data.at(SGD).contains("sigma")) sgd_param_.sigma = toml::find(toml_data, SGD, "sigma"); } return; } void gctl::sgd_solver::show_solver() { std::clog << "Solver's Setup Panel\n"; std::clog << "-----------------------------\n"; std::clog << "Solver: " << solver_name_ << "\n"; std::clog << "Iteration = " << sgd_param_.iteration << ", Epsilon = " << sgd_param_.epsilon << ", Mu = " << sgd_param_.mu << "\n"; std::clog << "Alpha = " << sgd_param_.alpha << ", Beta1 = " << sgd_param_.beta_1 << ", Beta2 = " << sgd_param_.beta_2 << ", Sigma = " << sgd_param_.sigma << "\n"; std::clog << "=============================\n"; return; } void gctl::sgd_solver::sgd_error_str(sgd_return_code err_code, std::ostream &ss, bool err_throw) { #if defined _WINDOWS || __WIN32__ if (!err_throw) { if (err_code >= 0) { SetConsoleTextAttribute(GetStdHandle(STD_ERROR_HANDLE), FOREGROUND_INTENSITY | FOREGROUND_GREEN); ss << "Success! "; } else { SetConsoleTextAttribute(GetStdHandle(STD_ERROR_HANDLE), FOREGROUND_INTENSITY | FOREGROUND_RED); ss << "Fail! "; } } #else if (!err_throw) { if (err_code >= 0) ss << "\033[1m\033[32mSGD Success! "; else ss << "\033[1m\033[31mSGD Fail! "; } #endif std::string err_str; switch (err_code) { case SGD_SUCCESS: err_str = "Success."; break; case SGD_CONVERGENCE: err_str = "The iteration reached convergence."; break; case SGD_STOP: err_str = "The iteration stopped by the progress evaluation function."; break; case SGD_UNKNOWN_ERROR: err_str = "Unknown error."; break; case SGD_INVALID_VARIABLE_SIZE: err_str = "Invalid array size."; break; case SGD_REACHED_MAX_ITERATIONS: err_str = "The maximal iteration is reached."; break; case SGD_INVALID_EPSILON: err_str = "Invalid value for epsilon."; break; case SGD_INVALID_BETA: err_str = "Invalid value for beta."; break; case SGD_INVALID_MU: err_str = "Invalid value for mu."; break; case SGD_INVALID_ALPHA: err_str = "Invalid value for alpha."; break; case SGD_INVALID_SIGMA: err_str = "Invalid value for sigma."; break; case SGD_NAN_VALUE: err_str = "NaN values found."; break; default: err_str = "Unknown error."; break; } if (err_throw && err_code < 0) throw err_str; else ss << err_str; #if defined _WINDOWS || __WIN32__ if (!err_throw) { if (err_code >= 0) { SetConsoleTextAttribute(GetStdHandle(STD_ERROR_HANDLE), 7); ss << std::endl; } else { SetConsoleTextAttribute(GetStdHandle(STD_ERROR_HANDLE), 7); ss << std::endl; } } #else if (!err_throw) { if (err_code >= 0) ss << "\033[0m" << std::endl; else ss << "\033[0m" << std::endl; } #endif return; } gctl::sgd_para gctl::sgd_solver::default_sgd_para() { sgd_para dp = sgd_defparam; return dp; } void gctl::sgd_solver::SGD_Minimize(array &m, sgd_solver_type solver_id, std::ostream &ss, bool verbose, bool err_throw) { if (sgd_silent_) { sgd_return_code ret; if (solver_id == MOMENTUM) {solver_name_ = "MOMENTUM"; ret = momentum(m);} else if (solver_id == NAG) {solver_name_ = "NAG"; ret = nag(m);} else if (solver_id == ADAGRAD) {solver_name_ = "ADAGRAD"; ret = adagrad(m);} else if (solver_id == RMSPROP) {solver_name_ = "RMSPROP"; ret = rmsprop(m);} else if (solver_id == ADAM) {solver_name_ = "ADAM"; ret = adam(m);} else if (solver_id == NADAM) {solver_name_ = "NADAM"; ret = nadam(m);} else if (solver_id == ADAMAX) {solver_name_ = "ADAMAX"; ret = adamax(m);} else if (solver_id == ADABELIEF) {solver_name_ = "ADABELIEF"; ret = adabelief(m);} else throw std::invalid_argument("Invalid solver type. gstl::sgd_solver::SGD_Minimize(...)"); if (ret < 0) sgd_error_str(ret, ss, true); return; } // 使用lcg求解 注意当我们使用函数指针来调用求解函数时默认参数不可以省略 #ifdef GCTL_OPENMP double start = omp_get_wtime(); sgd_return_code ret; if (solver_id == MOMENTUM) {solver_name_ = "MOMENTUM"; ret = momentum(m);} else if (solver_id == NAG) {solver_name_ = "NAG"; ret = nag(m);} else if (solver_id == ADAGRAD) {solver_name_ = "ADAGRAD"; ret = adagrad(m);} else if (solver_id == RMSPROP) {solver_name_ = "RMSPROP"; ret = rmsprop(m);} else if (solver_id == ADAM) {solver_name_ = "ADAM"; ret = adam(m);} else if (solver_id == NADAM) {solver_name_ = "NADAM"; ret = nadam(m);} else if (solver_id == ADAMAX) {solver_name_ = "ADAMAX"; ret = adamax(m);} else if (solver_id == ADABELIEF) {solver_name_ = "ADABELIEF"; ret = adabelief(m);} else throw std::invalid_argument("Invalid solver type. gstl::sgd_solver::SGD_Minimize(...)"); double end = omp_get_wtime(); double costime = 1000*(end-start); #else clock_t start = clock(); sgd_return_code ret; if (solver_id == MOMENTUM) {solver_name_ = "MOMENTUM"; ret = momentum(m);} else if (solver_id == NAG) {solver_name_ = "NAG"; ret = nag(m);} else if (solver_id == ADAGRAD) {solver_name_ = "ADAGRAD"; ret = adagrad(m);} else if (solver_id == RMSPROP) {solver_name_ = "RMSPROP"; ret = rmsprop(m);} else if (solver_id == ADAM) {solver_name_ = "ADAM"; ret = adam(m);} else if (solver_id == NADAM) {solver_name_ = "NADAM"; ret = nadam(m);} else if (solver_id == ADAMAX) {solver_name_ = "ADAMAX"; ret = adamax(m);} else if (solver_id == ADABELIEF) {solver_name_ = "ADABELIEF"; ret = adabelief(m);} else throw std::invalid_argument("Invalid solver type. gstl::sgd_solver::SGD_Minimize(...)"); clock_t end = clock(); double costime = 1000*(end-start)/(double)CLOCKS_PER_SEC; #endif if (!err_throw) { std::clog << std::endl; switch (solver_id) { case MOMENTUM: std::clog << "Solver: MOMENTUM. Time cost: " << costime << " ms" << std::endl; break; case NAG: std::clog << "Solver: NAG. Time cost: " << costime << " ms" << std::endl; break; case ADAGRAD: std::clog << "Solver: ADAGRAD. Time cost: " << costime << " ms" << std::endl; break; case RMSPROP: std::clog << "Solver: RMSPROP. Time cost: " << costime << " ms" << std::endl; break; case ADAM: std::clog << "Solver: ADAM. Time cost: " << costime << " ms" << std::endl; break; case NADAM: std::clog << "Solver: NADAM. Time cost: " << costime << " ms" << std::endl; break; case ADAMAX: std::clog << "Solver: ADAMAX. Time cost: " << costime << " ms" << std::endl; break; case ADABELIEF: std::clog << "Solver: ADABELIEF. Time cost: " << costime << " ms" << std::endl; break; default: std::clog << "Solver: Unknown. Time cost: " << costime << " ms" << std::endl; break; } } if (verbose) sgd_error_str(ret, ss, err_throw); else if (ret < 0) sgd_error_str(ret, ss, err_throw); return; } gctl::sgd_return_code gctl::sgd_solver::momentum(array &m) { int n_size = m.size(); //check parameters if (n_size <= 0) return SGD_INVALID_VARIABLE_SIZE; if (sgd_param_.epsilon < 0) return SGD_INVALID_EPSILON; if (sgd_param_.mu < 0 || sgd_param_.mu >= 1.0) return SGD_INVALID_MU; array mk(n_size, 0.0); array g(n_size); int t = 0; double fx; while (1) { fx = SGD_Evaluate(m, g); if (SGD_Progress(fx, m, sgd_param_, t)) return SGD_STOP; if (fx < sgd_param_.epsilon) return SGD_CONVERGENCE; for (int i = 0; i < n_size; i++) { mk[i] = sgd_param_.mu*mk[i] + g[i]; m[i] = m[i] - sgd_param_.alpha * mk[i]; if (m[i] != m[i]) return SGD_NAN_VALUE; } t++; if (sgd_param_.iteration > 0 && t >= sgd_param_.iteration) break; } return SGD_REACHED_MAX_ITERATIONS; } gctl::sgd_return_code gctl::sgd_solver::nag(array &m) { int n_size = m.size(); //check parameters if (n_size <= 0) return SGD_INVALID_VARIABLE_SIZE; if (sgd_param_.epsilon < 0) return SGD_INVALID_EPSILON; if (sgd_param_.mu < 0 || sgd_param_.mu >= 1.0) return SGD_INVALID_MU; array mk(n_size, 0.0); array xk(n_size); array g (n_size); int t = 0; double fx; while (1) { for (int i = 0; i < n_size; i++) { xk[i] = m[i] - sgd_param_.mu*sgd_param_.alpha*mk[i]; } fx = SGD_Evaluate(xk, g); if (SGD_Progress(fx, m, sgd_param_, t)) return SGD_STOP; if (fx < sgd_param_.epsilon) return SGD_CONVERGENCE; for (int i = 0; i < n_size; i++) { mk[i] = sgd_param_.mu*mk[i] + g[i]; m[i] = m[i] - sgd_param_.alpha * mk[i]; if (m[i] != m[i]) return SGD_NAN_VALUE; } t++; if (sgd_param_.iteration > 0 && t >= sgd_param_.iteration) break; } return SGD_REACHED_MAX_ITERATIONS; } gctl::sgd_return_code gctl::sgd_solver::adagrad(array &m) { int n_size = m.size(); //check parameters if (n_size <= 0) return SGD_INVALID_VARIABLE_SIZE; if (sgd_param_.epsilon < 0.0) return SGD_INVALID_EPSILON; if (sgd_param_.sigma < 0.0) return SGD_INVALID_SIGMA; array mk(n_size, 0.0); array g (n_size); int t = 0; double fx; while (1) { fx = SGD_Evaluate(m, g); if (SGD_Progress(fx, m, sgd_param_, t)) return SGD_STOP; if (fx < sgd_param_.epsilon) return SGD_CONVERGENCE; for (int i = 0; i < n_size; i++) { mk[i] = mk[i] + g[i]*g[i]; m[i] = m[i] - sgd_param_.alpha * g[i]/(sqrt(mk[i]) + sgd_param_.sigma); if (m[i] != m[i]) return SGD_NAN_VALUE; } t++; if (sgd_param_.iteration > 0 && t >= sgd_param_.iteration) break; } return SGD_REACHED_MAX_ITERATIONS; } gctl::sgd_return_code gctl::sgd_solver::rmsprop(array &m) { int n_size = m.size(); //check parameters if (n_size <= 0) return SGD_INVALID_VARIABLE_SIZE; if (sgd_param_.epsilon < 0.0) return SGD_INVALID_EPSILON; if (sgd_param_.sigma < 0.0) return SGD_INVALID_SIGMA; array vk(n_size, 0.0); array g (n_size); int t = 0; double fx; while (1) { fx = SGD_Evaluate(m, g); if (SGD_Progress(fx, m, sgd_param_, t)) return SGD_STOP; if (fx < sgd_param_.epsilon) return SGD_CONVERGENCE; for (int i = 0; i < n_size; i++) { vk[i] = sgd_param_.beta_2 * vk[i] + (1.0 - sgd_param_.beta_2)*g[i]*g[i]; m[i] = m[i] - sgd_param_.alpha * g[i]/(sqrt(vk[i]) + sgd_param_.sigma); if (m[i] != m[i]) return SGD_NAN_VALUE; } t++; if (sgd_param_.iteration > 0 && t >= sgd_param_.iteration) break; } return SGD_REACHED_MAX_ITERATIONS; } gctl::sgd_return_code gctl::sgd_solver::adam(array &m) { int n_size = m.size(); //check parameters if (n_size <= 0) return SGD_INVALID_VARIABLE_SIZE; if (sgd_param_.epsilon < 0) return SGD_INVALID_EPSILON; if (sgd_param_.alpha < 0) return SGD_INVALID_ALPHA; if (sgd_param_.beta_1 < 0.0 || sgd_param_.beta_1 >= 1.0) return SGD_INVALID_BETA; if (sgd_param_.beta_2 < 0.0 || sgd_param_.beta_2 >= 1.0) return SGD_INVALID_BETA; if (sgd_param_.sigma < 0.0) return SGD_INVALID_SIGMA; array mk(n_size, 0.0); array vk(n_size, 0.0); array g (n_size); double beta_1t = 1.0, beta_2t = 1.0; double alpha_k; int t = 0; double fx; while (1) { fx = SGD_Evaluate(m, g); if (SGD_Progress(fx, m, sgd_param_, t)) return SGD_STOP; if (fx < sgd_param_.epsilon) return SGD_CONVERGENCE; beta_1t *= sgd_param_.beta_1; beta_2t *= sgd_param_.beta_2; alpha_k = sgd_param_.alpha * sqrt(1.0 - beta_2t)/(1.0 - beta_1t); int i; #pragma omp parallel for private (i) schedule(guided) for (i = 0; i < n_size; i++) { mk[i] = sgd_param_.beta_1*mk[i] + (1.0 - sgd_param_.beta_1)*g[i]; vk[i] = sgd_param_.beta_2*vk[i] + (1.0 - sgd_param_.beta_2)*g[i]*g[i]; m[i] = m[i] - alpha_k * mk[i]/(sqrt(vk[i]) + sgd_param_.sigma); //if (m[i] != m[i]) return SGD_NAN_VALUE; } t++; if (sgd_param_.iteration > 0 && t >= sgd_param_.iteration) break; } return SGD_REACHED_MAX_ITERATIONS; } gctl::sgd_return_code gctl::sgd_solver::nadam(array &m) { int n_size = m.size(); //check parameters if (n_size <= 0) return SGD_INVALID_VARIABLE_SIZE; if (sgd_param_.epsilon < 0) return SGD_INVALID_EPSILON; if (sgd_param_.alpha < 0) return SGD_INVALID_ALPHA; if (sgd_param_.beta_1 < 0.0 || sgd_param_.beta_1 >= 1.0) return SGD_INVALID_BETA; if (sgd_param_.beta_2 < 0.0 || sgd_param_.beta_2 >= 1.0) return SGD_INVALID_BETA; if (sgd_param_.sigma < 0.0) return SGD_INVALID_SIGMA; array mk(n_size, 0.0); array mk_hat(n_size); array nk(n_size, 0.0); array nk_hat(n_size); array g (n_size); array g_hat(n_size); double beta_1t = 1.0, beta_1t1 = sgd_param_.beta_1, beta_2t = 1.0; int t = 0; double fx; while (1) { fx = SGD_Evaluate(m, g); if (SGD_Progress(fx, m, sgd_param_, t)) return SGD_STOP; if (fx < sgd_param_.epsilon) return SGD_CONVERGENCE; beta_1t *= sgd_param_.beta_1; beta_1t1 *= sgd_param_.beta_1; beta_2t *= sgd_param_.beta_2; for (int i = 0; i < n_size; i++) { g_hat[i] = g[i]/(1.0 - beta_1t); mk[i] = sgd_param_.beta_1*mk[i] + (1.0 - sgd_param_.beta_1)*g[i]; nk[i] = sgd_param_.beta_2*nk[i] + (1.0 - sgd_param_.beta_2)*g[i]*g[i]; mk_hat[i] = mk[i]/(1.0 - beta_1t1); nk_hat[i] = nk[i]/(1.0 - beta_2t); m[i] = m[i] - sgd_param_.alpha * ((1.0 - beta_1t)*g_hat[i] + beta_1t1*mk_hat[i])/(sqrt(nk_hat[i]) + sgd_param_.sigma); if (m[i] != m[i]) return SGD_NAN_VALUE; } t++; if (sgd_param_.iteration > 0 && t >= sgd_param_.iteration) break; } return SGD_REACHED_MAX_ITERATIONS; } gctl::sgd_return_code gctl::sgd_solver::adamax(array &m) { int n_size = m.size(); //check parameters if (n_size <= 0) return SGD_INVALID_VARIABLE_SIZE; if (sgd_param_.epsilon < 0) return SGD_INVALID_EPSILON; if (sgd_param_.alpha < 0) return SGD_INVALID_ALPHA; if (sgd_param_.beta_1 < 0.0 || sgd_param_.beta_1 >= 1.0) return SGD_INVALID_BETA; if (sgd_param_.beta_2 < 0.0 || sgd_param_.beta_2 >= 1.0) return SGD_INVALID_BETA; if (sgd_param_.sigma < 0.0) return SGD_INVALID_SIGMA; array mk(n_size, 0.0); array vk(n_size, 0.0); array g (n_size); double beta_1t = 1.0; int t = 0; double fx; while (1) { fx = SGD_Evaluate(m, g); if (SGD_Progress(fx, m, sgd_param_, t)) return SGD_STOP; if (fx < sgd_param_.epsilon) return SGD_CONVERGENCE; beta_1t *= sgd_param_.beta_1; for (int i = 0; i < n_size; i++) { mk[i] = sgd_param_.beta_1*mk[i] + (1.0 - sgd_param_.beta_1)*g[i]; vk[i] = std::max(sgd_param_.beta_2*vk[i], std::fabs(g[i])); m[i] = m[i] - sgd_param_.alpha * mk[i]/((1.0 - beta_1t)*vk[i]); if (m[i] != m[i]) return SGD_NAN_VALUE; } t++; if (sgd_param_.iteration > 0 && t >= sgd_param_.iteration) break; } return SGD_REACHED_MAX_ITERATIONS; } gctl::sgd_return_code gctl::sgd_solver::adabelief(array &m) { int n_size = m.size(); //check parameters if (n_size <= 0) return SGD_INVALID_VARIABLE_SIZE; if (sgd_param_.epsilon < 0) return SGD_INVALID_EPSILON; if (sgd_param_.alpha < 0) return SGD_INVALID_ALPHA; if (sgd_param_.beta_1 < 0.0 || sgd_param_.beta_1 >= 1.0) return SGD_INVALID_BETA; if (sgd_param_.beta_2 < 0.0 || sgd_param_.beta_2 >= 1.0) return SGD_INVALID_BETA; if (sgd_param_.sigma < 0.0) return SGD_INVALID_SIGMA; array mk(n_size, 0.0); array vk(n_size, 0.0); array g (n_size); double beta_1t = 1.0, beta_2t = 1.0; double alpha_k; int t = 0; double fx; while (1) { fx = SGD_Evaluate(m, g); if (SGD_Progress(fx, m, sgd_param_, t)) return SGD_STOP; if (fx < sgd_param_.epsilon) return SGD_CONVERGENCE; beta_1t *= sgd_param_.beta_1; beta_2t *= sgd_param_.beta_2; alpha_k = sgd_param_.alpha * sqrt(1.0 - beta_2t)/(1.0 - beta_1t); for (int i = 0; i < n_size; i++) { mk[i] = sgd_param_.beta_1*mk[i] + (1.0 - sgd_param_.beta_1)*g[i]; vk[i] = sgd_param_.beta_2*vk[i] + (1.0 - sgd_param_.beta_2)*(g[i] - mk[i])*(g[i] - mk[i]); m[i] = m[i] - alpha_k * mk[i]/(sqrt(vk[i]) + sgd_param_.sigma); if (m[i] != m[i]) return SGD_NAN_VALUE; } t++; if (sgd_param_.iteration > 0 && t >= sgd_param_.iteration) break; } return SGD_REACHED_MAX_ITERATIONS; }