/********************************************************
* ██████╗ ██████╗████████╗██╗
* ██╔════╝ ██╔════╝╚══██╔══╝██║
* ██║ ███╗██║ ██║ ██║
* ██║ ██║██║ ██║ ██║
* ╚██████╔╝╚██████╗ ██║ ███████╗
* ╚═════╝ ╚═════╝ ╚═╝ ╚══════╝
* Geophysical Computational Tools & Library (GCTL)
*
* Copyright (c) 2022 Yi Zhang (yizhang-geo@zju.edu.cn)
*
* GCTL is distributed under a dual licensing scheme. You can redistribute
* it and/or modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation, either version 2
* of the License, or (at your option) any later version. You should have
* received a copy of the GNU Lesser General Public License along with this
* program. If not, see <http://www.gnu.org/licenses/>.
*
* If the terms and conditions of the LGPL v.2. would prevent you from using
* the GCTL, please consider the option to obtain a commercial license for a
* fee. These licenses are offered by the GCTL's original author. As a rule,
* licenses are provided "as-is", unlimited in time for a one time fee. Please
* send corresponding requests to: yizhang-geo@zju.edu.cn. Please do not forget
* to include some description of your company and the realm of its activities.
* Also add information on how to contact you by electronic and paper mail.
******************************************************/
#include "sgd.h"
/**
 * Library-wide default parameter set for the SGD solvers. It is copied into
 * sgd_param_ by the constructor and by set_sgd_para(const toml::value&), and
 * returned by default_sgd_para().
 *
 * NOTE(review): the per-value labels below assume sgd_para declares its fields
 * in the order iteration, epsilon, mu, alpha, beta_1, beta_2, sigma (the order
 * used by show_solver() and the TOML loader) -- confirm against sgd.h.
 */
static const gctl::sgd_para sgd_defparam{
	0,     // iteration (the solvers treat a value <= 0 as "no iteration limit")
	1e-6,  // epsilon
	0.01,  // mu
	0.01,  // alpha
	0.9,   // beta_1
	0.999, // beta_2
	1e-8   // sigma
};
/**
 * Construct a solver in its default state:
 * - library default parameters (sgd_defparam);
 * - progress printed at every iteration (report interval 1);
 * - console output enabled;
 * - solver name "Undefined" until SGD_Minimize() selects an algorithm.
 */
gctl::sgd_solver::sgd_solver()
{
	this->solver_name_ = "Undefined";
	this->sgd_silent_  = false;
	this->sgd_inter_   = 1;
	this->sgd_param_   = sgd_defparam;
}
/**
 * Destructor. Nothing is released explicitly here; members are destroyed
 * automatically.
 */
gctl::sgd_solver::~sgd_solver()
{
}
int gctl::sgd_solver::SGD_Progress(double fx, const array &x, const sgd_para ¶m, const int k)
{
if (sgd_silent_) return 0;
if (param.epsilon > 0.0 && fx <= param.epsilon)
{
std::clog << GCTL_CLEARLINE << "\rF(x) = " << fx << ", Train-Times = " << k;
return 0;
}
if (sgd_inter_ > 0 && k%sgd_inter_ == 0)
{
std::clog << GCTL_CLEARLINE << "\rF(x) = " << fx << ", Train-Times = " << k;
}
return 0;
}
void gctl::sgd_solver::sgd_silent()
{
sgd_silent_ = true;
return;
}
void gctl::sgd_solver::set_sgd_report_interval(int inter)
{
sgd_inter_ = inter;
return;
}
void gctl::sgd_solver::set_sgd_para(const sgd_para &in_param)
{
sgd_param_ = in_param;
return;
}
/**
 * Load solver parameters from the "sgd" table of a parsed TOML document.
 *
 * All parameters are first reset to the library defaults. Each of the keys
 * iteration, epsilon, mu, alpha, beta_1, beta_2 and sigma that is present in the
 * "sgd" table then overrides the matching field. Absent keys keep their default
 * value; a document without an "sgd" entry yields the defaults.
 *
 * @param toml_data Parsed TOML document.
 */
void gctl::sgd_solver::set_sgd_para(const toml::value &toml_data)
{
	sgd_param_ = sgd_defparam;

	std::string SGD = "sgd";
	if (!toml_data.contains(SGD)) return;

	const toml::value &sgd_table = toml_data.at(SGD);
	// toml::find needs an explicit target type to convert a TOML value into a
	// C++ value; use the declared type of each sgd_para field so it always matches.
	if (sgd_table.contains("iteration")) sgd_param_.iteration = toml::find<decltype(sgd_param_.iteration)>(sgd_table, "iteration");
	if (sgd_table.contains("epsilon"))   sgd_param_.epsilon   = toml::find<decltype(sgd_param_.epsilon)>(sgd_table, "epsilon");
	if (sgd_table.contains("mu"))        sgd_param_.mu        = toml::find<decltype(sgd_param_.mu)>(sgd_table, "mu");
	if (sgd_table.contains("alpha"))     sgd_param_.alpha     = toml::find<decltype(sgd_param_.alpha)>(sgd_table, "alpha");
	if (sgd_table.contains("beta_1"))    sgd_param_.beta_1    = toml::find<decltype(sgd_param_.beta_1)>(sgd_table, "beta_1");
	if (sgd_table.contains("beta_2"))    sgd_param_.beta_2    = toml::find<decltype(sgd_param_.beta_2)>(sgd_table, "beta_2");
	if (sgd_table.contains("sigma"))     sgd_param_.sigma     = toml::find<decltype(sgd_param_.sigma)>(sgd_table, "sigma");
}
/**
 * Print a short summary of the solver name and all parameters to std::clog.
 */
void gctl::sgd_solver::show_solver()
{
	const char *thin_rule  = "-----------------------------\n";
	const char *thick_rule = "=============================\n";

	std::clog << "Solver's Setup Panel\n" << thin_rule
		<< "Solver: " << solver_name_ << "\n"
		<< "Iteration = " << sgd_param_.iteration
		<< ", Epsilon = " << sgd_param_.epsilon
		<< ", Mu = " << sgd_param_.mu << "\n"
		<< "Alpha = " << sgd_param_.alpha
		<< ", Beta1 = " << sgd_param_.beta_1
		<< ", Beta2 = " << sgd_param_.beta_2
		<< ", Sigma = " << sgd_param_.sigma << "\n"
		<< thick_rule;
}
/**
 * Report an SGD return code as a human-readable message.
 *
 * When err_throw is false, the message is written to `ss`:
 * - it is prefixed with a colored "Success!"/"Fail!" tag (green for codes >= 0,
 *   red for negative codes);
 * - it is terminated by a color reset and std::endl.
 * Colors use console attributes on Windows builds and ANSI escapes elsewhere.
 *
 * When err_throw is true:
 * - an error code (< 0) is thrown as a std::string holding the message;
 * - a non-negative code writes the bare message to `ss`, with no tag, color
 *   or newline.
 *
 * @param err_code  Return code produced by one of the SGD solvers.
 * @param ss        Output stream for the message.
 * @param err_throw Throw error messages instead of printing them.
 * @throws std::string when err_throw is true and err_code < 0.
 */
void gctl::sgd_solver::sgd_error_str(sgd_return_code err_code, std::ostream &ss, bool err_throw)
{
#if defined(_WINDOWS) || defined(__WIN32__)
	if (!err_throw)
	{
		if (err_code >= 0)
		{
			SetConsoleTextAttribute(GetStdHandle(STD_ERROR_HANDLE), FOREGROUND_INTENSITY | FOREGROUND_GREEN);
			ss << "Success! ";
		}
		else
		{
			SetConsoleTextAttribute(GetStdHandle(STD_ERROR_HANDLE), FOREGROUND_INTENSITY | FOREGROUND_RED);
			ss << "Fail! ";
		}
	}
#else
	if (!err_throw)
	{
		// Bold green for success, bold red for failure.
		ss << (err_code >= 0 ? "\033[1m\033[32mSGD Success! " : "\033[1m\033[31mSGD Fail! ");
	}
#endif

	std::string err_str;
	switch (err_code)
	{
		case SGD_SUCCESS:
			err_str = "Success."; break;
		case SGD_CONVERGENCE:
			err_str = "The iteration reached convergence."; break;
		case SGD_STOP:
			err_str = "The iteration stopped by the progress evaluation function."; break;
		case SGD_UNKNOWN_ERROR:
			err_str = "Unknown error."; break;
		case SGD_INVALID_VARIABLE_SIZE:
			err_str = "Invalid array size."; break;
		case SGD_REACHED_MAX_ITERATIONS:
			err_str = "The maximal iteration is reached."; break;
		case SGD_INVALID_EPSILON:
			err_str = "Invalid value for epsilon."; break;
		case SGD_INVALID_BETA:
			err_str = "Invalid value for beta."; break;
		case SGD_INVALID_MU:
			err_str = "Invalid value for mu."; break;
		case SGD_INVALID_ALPHA:
			err_str = "Invalid value for alpha."; break;
		case SGD_INVALID_SIGMA:
			err_str = "Invalid value for sigma."; break;
		case SGD_NAN_VALUE:
			err_str = "NaN values found."; break;
		default:
			err_str = "Unknown error."; break;
	}

	if (err_throw && err_code < 0) throw err_str;
	ss << err_str;

	if (!err_throw)
	{
		// Restore the default text color, then end the line (and flush).
#if defined(_WINDOWS) || defined(__WIN32__)
		// 7 = FOREGROUND_RED | FOREGROUND_GREEN | FOREGROUND_BLUE (plain grey text).
		SetConsoleTextAttribute(GetStdHandle(STD_ERROR_HANDLE), 7);
#else
		ss << "\033[0m";
#endif
		ss << std::endl;
	}
}
/**
 * Return a copy of the library's default SGD parameter set (sgd_defparam).
 */
gctl::sgd_para gctl::sgd_solver::default_sgd_para()
{
	return sgd_defparam;
}
/**
 * Minimise the objective with the selected SGD variant, updating `m` in place.
 *
 * The chosen algorithm's name is stored in solver_name_ (shown by show_solver()).
 *
 * Silent mode (after sgd_silent()):
 * - the solver runs without timing or reporting;
 * - any error code is thrown via sgd_error_str(), regardless of err_throw.
 *
 * Otherwise:
 * - the run is timed with omp_get_wtime() when GCTL_OPENMP is defined, and
 *   with clock() otherwise;
 * - unless err_throw is set, the solver name and elapsed time in milliseconds
 *   are printed to std::clog;
 * - the return code is passed to sgd_error_str() when verbose is true or when
 *   the code signals an error (< 0).
 *
 * @param m          Initial guess on input, final iterate on output.
 * @param solver_id  Which SGD variant to run.
 * @param ss         Stream that receives the status message.
 * @param verbose    Also report non-error return codes.
 * @param err_throw  Throw error messages instead of printing them.
 * @throws std::invalid_argument if solver_id is not a known solver type.
 * @throws std::string (from sgd_error_str()) when a solver fails and errors are thrown.
 */
void gctl::sgd_solver::SGD_Minimize(array &m, sgd_solver_type solver_id, std::ostream &ss, bool verbose, bool err_throw)
{
	// Select the algorithm, record its name, and run it on m.
	auto run_solver = [&]() -> sgd_return_code
	{
		switch (solver_id)
		{
			case MOMENTUM:  solver_name_ = "MOMENTUM";  return momentum(m);
			case NAG:       solver_name_ = "NAG";       return nag(m);
			case ADAGRAD:   solver_name_ = "ADAGRAD";   return adagrad(m);
			case RMSPROP:   solver_name_ = "RMSPROP";   return rmsprop(m);
			case ADAM:      solver_name_ = "ADAM";      return adam(m);
			case NADAM:     solver_name_ = "NADAM";     return nadam(m);
			case ADAMAX:    solver_name_ = "ADAMAX";    return adamax(m);
			case ADABELIEF: solver_name_ = "ADABELIEF"; return adabelief(m);
			default:
				throw std::invalid_argument("Invalid solver type. gctl::sgd_solver::SGD_Minimize(...)");
		}
	};

	if (sgd_silent_)
	{
		// No timing or report in silent mode; failures are always thrown.
		sgd_return_code ret = run_solver();
		if (ret < 0) sgd_error_str(ret, ss, true);
		return;
	}

#ifdef GCTL_OPENMP
	double start = omp_get_wtime();
	sgd_return_code ret = run_solver();
	double end = omp_get_wtime();
	double costime = 1000*(end-start);
#else
	clock_t start = clock();
	sgd_return_code ret = run_solver();
	clock_t end = clock();
	double costime = 1000*(end-start)/(double)CLOCKS_PER_SEC;
#endif

	if (!err_throw)
	{
		std::clog << std::endl;
		std::clog << "Solver: " << solver_name_ << ". Time cost: " << costime << " ms" << std::endl;
	}

	if (verbose || ret < 0) sgd_error_str(ret, ss, err_throw);
}
/**
 * Gradient descent with classical momentum.
 *
 * Update per iteration: mk = mu*mk + g;  m = m - alpha*mk.
 * SGD_Evaluate(m, g) is expected to return the objective at m and fill g with
 * its gradient (presumably implemented by the user -- confirm in sgd.h).
 *
 * @param m Initial guess on input, last iterate on output.
 * @return The reason the iteration ended:
 *   - SGD_CONVERGENCE: fx < epsilon;
 *   - SGD_STOP: SGD_Progress() returned non-zero;
 *   - SGD_REACHED_MAX_ITERATIONS: the iteration limit was reached (when iteration > 0);
 *   - SGD_NAN_VALUE: a NaN appeared in m;
 *   - an SGD_INVALID_* code: a parameter failed validation.
 */
gctl::sgd_return_code gctl::sgd_solver::momentum(array &m)
{
	int n_size = m.size();
	// Check parameters; alpha is validated the same way as in the Adam family.
	if (n_size <= 0) return SGD_INVALID_VARIABLE_SIZE;
	if (sgd_param_.epsilon < 0) return SGD_INVALID_EPSILON;
	if (sgd_param_.alpha < 0) return SGD_INVALID_ALPHA;
	if (sgd_param_.mu < 0 || sgd_param_.mu >= 1.0) return SGD_INVALID_MU;

	array mk(n_size, 0.0);
	array g(n_size);
	int t = 0;
	double fx;
	while (1)
	{
		fx = SGD_Evaluate(m, g);
		if (SGD_Progress(fx, m, sgd_param_, t)) return SGD_STOP;
		if (fx < sgd_param_.epsilon) return SGD_CONVERGENCE;

		for (int i = 0; i < n_size; i++)
		{
			mk[i] = sgd_param_.mu*mk[i] + g[i];
			m[i] = m[i] - sgd_param_.alpha * mk[i];
			if (m[i] != m[i]) return SGD_NAN_VALUE; // x != x only for NaN
		}

		t++;
		// A non-positive iteration limit means "run until convergence or stop".
		if (sgd_param_.iteration > 0 && t >= sgd_param_.iteration) break;
	}
	return SGD_REACHED_MAX_ITERATIONS;
}
/**
 * Nesterov accelerated gradient (NAG).
 *
 * The gradient is taken at the look-ahead point xk = m - mu*alpha*mk. Then:
 *   mk = mu*mk + g;  m = m - alpha*mk.
 * Note that the objective passed to SGD_Progress() and tested against epsilon
 * is the one evaluated at xk, not at m.
 *
 * @param m Initial guess on input, last iterate on output.
 * @return The reason the iteration ended:
 *   - SGD_CONVERGENCE: fx < epsilon;
 *   - SGD_STOP: SGD_Progress() returned non-zero;
 *   - SGD_REACHED_MAX_ITERATIONS: the iteration limit was reached (when iteration > 0);
 *   - SGD_NAN_VALUE: a NaN appeared in m;
 *   - an SGD_INVALID_* code: a parameter failed validation.
 */
gctl::sgd_return_code gctl::sgd_solver::nag(array &m)
{
	int n_size = m.size();
	// Check parameters; alpha is validated the same way as in the Adam family.
	if (n_size <= 0) return SGD_INVALID_VARIABLE_SIZE;
	if (sgd_param_.epsilon < 0) return SGD_INVALID_EPSILON;
	if (sgd_param_.alpha < 0) return SGD_INVALID_ALPHA;
	if (sgd_param_.mu < 0 || sgd_param_.mu >= 1.0) return SGD_INVALID_MU;

	array mk(n_size, 0.0);
	array xk(n_size);
	array g (n_size);
	int t = 0;
	double fx;
	while (1)
	{
		// Look-ahead point along the current momentum direction.
		for (int i = 0; i < n_size; i++)
		{
			xk[i] = m[i] - sgd_param_.mu*sgd_param_.alpha*mk[i];
		}

		fx = SGD_Evaluate(xk, g);
		if (SGD_Progress(fx, m, sgd_param_, t)) return SGD_STOP;
		if (fx < sgd_param_.epsilon) return SGD_CONVERGENCE;

		for (int i = 0; i < n_size; i++)
		{
			mk[i] = sgd_param_.mu*mk[i] + g[i];
			m[i] = m[i] - sgd_param_.alpha * mk[i];
			if (m[i] != m[i]) return SGD_NAN_VALUE; // x != x only for NaN
		}

		t++;
		if (sgd_param_.iteration > 0 && t >= sgd_param_.iteration) break;
	}
	return SGD_REACHED_MAX_ITERATIONS;
}
/**
 * AdaGrad: per-component step size scaled by accumulated squared gradients.
 *
 * Update per iteration: mk += g*g;  m = m - alpha*g/(sqrt(mk) + sigma).
 *
 * @param m Initial guess on input, last iterate on output.
 * @return The reason the iteration ended:
 *   - SGD_CONVERGENCE: fx < epsilon;
 *   - SGD_STOP: SGD_Progress() returned non-zero;
 *   - SGD_REACHED_MAX_ITERATIONS: the iteration limit was reached (when iteration > 0);
 *   - SGD_NAN_VALUE: a NaN appeared in m;
 *   - an SGD_INVALID_* code: a parameter failed validation.
 */
gctl::sgd_return_code gctl::sgd_solver::adagrad(array &m)
{
	int n_size = m.size();
	// Check parameters; alpha is validated the same way as in the Adam family.
	if (n_size <= 0) return SGD_INVALID_VARIABLE_SIZE;
	if (sgd_param_.epsilon < 0.0) return SGD_INVALID_EPSILON;
	if (sgd_param_.alpha < 0.0) return SGD_INVALID_ALPHA;
	if (sgd_param_.sigma < 0.0) return SGD_INVALID_SIGMA;

	array mk(n_size, 0.0);
	array g (n_size);
	int t = 0;
	double fx;
	while (1)
	{
		fx = SGD_Evaluate(m, g);
		if (SGD_Progress(fx, m, sgd_param_, t)) return SGD_STOP;
		if (fx < sgd_param_.epsilon) return SGD_CONVERGENCE;

		for (int i = 0; i < n_size; i++)
		{
			mk[i] = mk[i] + g[i]*g[i];
			// sigma guards against division by zero while mk[i] is still 0.
			m[i] = m[i] - sgd_param_.alpha * g[i]/(sqrt(mk[i]) + sgd_param_.sigma);
			if (m[i] != m[i]) return SGD_NAN_VALUE;
		}

		t++;
		if (sgd_param_.iteration > 0 && t >= sgd_param_.iteration) break;
	}
	return SGD_REACHED_MAX_ITERATIONS;
}
/**
 * RMSProp: step size scaled by an exponential moving average of squared gradients.
 *
 * beta_2 is the decay rate of the average:
 *   vk = beta_2*vk + (1 - beta_2)*g*g;  m = m - alpha*g/(sqrt(vk) + sigma).
 * beta_2 must lie in [0, 1), as in the Adam family; a value > 1 would drive vk
 * negative and produce NaN through sqrt.
 *
 * @param m Initial guess on input, last iterate on output.
 * @return The reason the iteration ended:
 *   - SGD_CONVERGENCE: fx < epsilon;
 *   - SGD_STOP: SGD_Progress() returned non-zero;
 *   - SGD_REACHED_MAX_ITERATIONS: the iteration limit was reached (when iteration > 0);
 *   - SGD_NAN_VALUE: a NaN appeared in m;
 *   - an SGD_INVALID_* code: a parameter failed validation.
 */
gctl::sgd_return_code gctl::sgd_solver::rmsprop(array &m)
{
	int n_size = m.size();
	// Check parameters.
	if (n_size <= 0) return SGD_INVALID_VARIABLE_SIZE;
	if (sgd_param_.epsilon < 0.0) return SGD_INVALID_EPSILON;
	if (sgd_param_.alpha < 0.0) return SGD_INVALID_ALPHA;
	if (sgd_param_.beta_2 < 0.0 || sgd_param_.beta_2 >= 1.0) return SGD_INVALID_BETA;
	if (sgd_param_.sigma < 0.0) return SGD_INVALID_SIGMA;

	array vk(n_size, 0.0);
	array g (n_size);
	int t = 0;
	double fx;
	while (1)
	{
		fx = SGD_Evaluate(m, g);
		if (SGD_Progress(fx, m, sgd_param_, t)) return SGD_STOP;
		if (fx < sgd_param_.epsilon) return SGD_CONVERGENCE;

		for (int i = 0; i < n_size; i++)
		{
			vk[i] = sgd_param_.beta_2 * vk[i] + (1.0 - sgd_param_.beta_2)*g[i]*g[i];
			m[i] = m[i] - sgd_param_.alpha * g[i]/(sqrt(vk[i]) + sgd_param_.sigma);
			if (m[i] != m[i]) return SGD_NAN_VALUE;
		}

		t++;
		if (sgd_param_.iteration > 0 && t >= sgd_param_.iteration) break;
	}
	return SGD_REACHED_MAX_ITERATIONS;
}
/**
 * Adam, with the bias corrections folded into the step size:
 *   alpha_k = alpha*sqrt(1 - beta_2^t)/(1 - beta_1^t).
 *
 * The per-component update may run in parallel under OpenMP. Because `return`
 * is not allowed inside an OpenMP worksharing loop, NaNs are counted with a
 * reduction and reported after the loop. The whole of m is therefore updated
 * before SGD_NAN_VALUE is returned.
 *
 * @param m Initial guess on input, last iterate on output.
 * @return The reason the iteration ended:
 *   - SGD_CONVERGENCE: fx < epsilon;
 *   - SGD_STOP: SGD_Progress() returned non-zero;
 *   - SGD_REACHED_MAX_ITERATIONS: the iteration limit was reached (when iteration > 0);
 *   - SGD_NAN_VALUE: a NaN appeared in m;
 *   - an SGD_INVALID_* code: a parameter failed validation.
 */
gctl::sgd_return_code gctl::sgd_solver::adam(array &m)
{
	int n_size = m.size();
	// Check parameters.
	if (n_size <= 0) return SGD_INVALID_VARIABLE_SIZE;
	if (sgd_param_.epsilon < 0) return SGD_INVALID_EPSILON;
	if (sgd_param_.alpha < 0) return SGD_INVALID_ALPHA;
	if (sgd_param_.beta_1 < 0.0 || sgd_param_.beta_1 >= 1.0) return SGD_INVALID_BETA;
	if (sgd_param_.beta_2 < 0.0 || sgd_param_.beta_2 >= 1.0) return SGD_INVALID_BETA;
	if (sgd_param_.sigma < 0.0) return SGD_INVALID_SIGMA;

	array mk(n_size, 0.0);
	array vk(n_size, 0.0);
	array g (n_size);
	double beta_1t = 1.0, beta_2t = 1.0; // beta_1^t and beta_2^t
	double alpha_k;
	int t = 0;
	double fx;
	while (1)
	{
		fx = SGD_Evaluate(m, g);
		if (SGD_Progress(fx, m, sgd_param_, t)) return SGD_STOP;
		if (fx < sgd_param_.epsilon) return SGD_CONVERGENCE;

		beta_1t *= sgd_param_.beta_1;
		beta_2t *= sgd_param_.beta_2;
		alpha_k = sgd_param_.alpha * sqrt(1.0 - beta_2t)/(1.0 - beta_1t);

		int nan_count = 0;
#pragma omp parallel for reduction(+:nan_count) schedule(guided)
		for (int i = 0; i < n_size; i++)
		{
			mk[i] = sgd_param_.beta_1*mk[i] + (1.0 - sgd_param_.beta_1)*g[i];
			vk[i] = sgd_param_.beta_2*vk[i] + (1.0 - sgd_param_.beta_2)*g[i]*g[i];
			m[i] = m[i] - alpha_k * mk[i]/(sqrt(vk[i]) + sgd_param_.sigma);
			if (m[i] != m[i]) nan_count++;
		}
		if (nan_count > 0) return SGD_NAN_VALUE;

		t++;
		if (sgd_param_.iteration > 0 && t >= sgd_param_.iteration) break;
	}
	return SGD_REACHED_MAX_ITERATIONS;
}
/**
 * NAdam (Adam with Nesterov momentum, Dozat 2016) with a constant momentum
 * coefficient mu = beta_1 and nu = beta_2.
 *
 * With beta_1^t, beta_1^(t+1) and beta_2^t tracked as running products:
 *   g_hat = g/(1 - beta_1^t)
 *   m_hat = mk/(1 - beta_1^(t+1))
 *   n_hat = nk/(1 - beta_2^t)
 *   m_bar = (1 - beta_1)*g_hat + beta_1*m_hat
 *   m    -= alpha*m_bar/(sqrt(n_hat) + sigma)
 *
 * In the paper the coefficients of g_hat and m_hat are the momentum-schedule
 * values mu_t and mu_(t+1), which equal beta_1 for a constant schedule. They are
 * not the powers beta_1^t: using those collapses (1 - beta_1^t)*g_hat to g and
 * makes the momentum term vanish as t grows.
 *
 * @param m Initial guess on input, last iterate on output.
 * @return The reason the iteration ended:
 *   - SGD_CONVERGENCE: fx < epsilon;
 *   - SGD_STOP: SGD_Progress() returned non-zero;
 *   - SGD_REACHED_MAX_ITERATIONS: the iteration limit was reached (when iteration > 0);
 *   - SGD_NAN_VALUE: a NaN appeared in m;
 *   - an SGD_INVALID_* code: a parameter failed validation.
 */
gctl::sgd_return_code gctl::sgd_solver::nadam(array &m)
{
	int n_size = m.size();
	// Check parameters.
	if (n_size <= 0) return SGD_INVALID_VARIABLE_SIZE;
	if (sgd_param_.epsilon < 0) return SGD_INVALID_EPSILON;
	if (sgd_param_.alpha < 0) return SGD_INVALID_ALPHA;
	if (sgd_param_.beta_1 < 0.0 || sgd_param_.beta_1 >= 1.0) return SGD_INVALID_BETA;
	if (sgd_param_.beta_2 < 0.0 || sgd_param_.beta_2 >= 1.0) return SGD_INVALID_BETA;
	if (sgd_param_.sigma < 0.0) return SGD_INVALID_SIGMA;

	array mk(n_size, 0.0);
	array nk(n_size, 0.0);
	array g (n_size);
	// Running products: beta_1^t, beta_1^(t+1), beta_2^t (t counted from 1).
	double beta_1t = 1.0, beta_1t1 = sgd_param_.beta_1, beta_2t = 1.0;
	int t = 0;
	double fx;
	while (1)
	{
		fx = SGD_Evaluate(m, g);
		if (SGD_Progress(fx, m, sgd_param_, t)) return SGD_STOP;
		if (fx < sgd_param_.epsilon) return SGD_CONVERGENCE;

		beta_1t *= sgd_param_.beta_1;
		beta_1t1 *= sgd_param_.beta_1;
		beta_2t *= sgd_param_.beta_2;

		for (int i = 0; i < n_size; i++)
		{
			const double g_hat = g[i]/(1.0 - beta_1t);
			mk[i] = sgd_param_.beta_1*mk[i] + (1.0 - sgd_param_.beta_1)*g[i];
			nk[i] = sgd_param_.beta_2*nk[i] + (1.0 - sgd_param_.beta_2)*g[i]*g[i];
			const double mk_hat = mk[i]/(1.0 - beta_1t1);
			const double nk_hat = nk[i]/(1.0 - beta_2t);
			// Nesterov combination of the current gradient and the look-ahead momentum.
			const double m_bar = (1.0 - sgd_param_.beta_1)*g_hat + sgd_param_.beta_1*mk_hat;
			m[i] = m[i] - sgd_param_.alpha * m_bar/(sqrt(nk_hat) + sgd_param_.sigma);
			if (m[i] != m[i]) return SGD_NAN_VALUE;
		}

		t++;
		if (sgd_param_.iteration > 0 && t >= sgd_param_.iteration) break;
	}
	return SGD_REACHED_MAX_ITERATIONS;
}
/**
 * AdaMax (the infinity-norm variant of Adam).
 *
 *   mk = beta_1*mk + (1 - beta_1)*g
 *   vk = max(beta_2*vk, |g| + sigma)
 *   m -= alpha*mk/((1 - beta_1^t)*vk)
 *
 * sigma is added inside the max, as in PyTorch's Adamax. Without it, any
 * component whose gradient is exactly 0 while vk is still 0 (e.g. on the first
 * iteration) gives 0/0, and the solver stops with a spurious SGD_NAN_VALUE.
 * With sigma == 0 that situation can still occur.
 *
 * @param m Initial guess on input, last iterate on output.
 * @return The reason the iteration ended:
 *   - SGD_CONVERGENCE: fx < epsilon;
 *   - SGD_STOP: SGD_Progress() returned non-zero;
 *   - SGD_REACHED_MAX_ITERATIONS: the iteration limit was reached (when iteration > 0);
 *   - SGD_NAN_VALUE: a NaN appeared in m;
 *   - an SGD_INVALID_* code: a parameter failed validation.
 */
gctl::sgd_return_code gctl::sgd_solver::adamax(array &m)
{
	int n_size = m.size();
	// Check parameters.
	if (n_size <= 0) return SGD_INVALID_VARIABLE_SIZE;
	if (sgd_param_.epsilon < 0) return SGD_INVALID_EPSILON;
	if (sgd_param_.alpha < 0) return SGD_INVALID_ALPHA;
	if (sgd_param_.beta_1 < 0.0 || sgd_param_.beta_1 >= 1.0) return SGD_INVALID_BETA;
	if (sgd_param_.beta_2 < 0.0 || sgd_param_.beta_2 >= 1.0) return SGD_INVALID_BETA;
	if (sgd_param_.sigma < 0.0) return SGD_INVALID_SIGMA;

	array mk(n_size, 0.0);
	array vk(n_size, 0.0);
	array g (n_size);
	double beta_1t = 1.0; // beta_1^t
	int t = 0;
	double fx;
	while (1)
	{
		fx = SGD_Evaluate(m, g);
		if (SGD_Progress(fx, m, sgd_param_, t)) return SGD_STOP;
		if (fx < sgd_param_.epsilon) return SGD_CONVERGENCE;

		beta_1t *= sgd_param_.beta_1;
		for (int i = 0; i < n_size; i++)
		{
			mk[i] = sgd_param_.beta_1*mk[i] + (1.0 - sgd_param_.beta_1)*g[i];
			vk[i] = std::max(sgd_param_.beta_2*vk[i], std::fabs(g[i]) + sgd_param_.sigma);
			m[i] = m[i] - sgd_param_.alpha * mk[i]/((1.0 - beta_1t)*vk[i]);
			if (m[i] != m[i]) return SGD_NAN_VALUE;
		}

		t++;
		if (sgd_param_.iteration > 0 && t >= sgd_param_.iteration) break;
	}
	return SGD_REACHED_MAX_ITERATIONS;
}
/**
 * AdaBelief: like Adam, but the second moment tracks the squared deviation of
 * the gradient from its moving average, (g - mk)^2, instead of g^2.
 *
 * Bias corrections are folded into the step size:
 *   alpha_k = alpha*sqrt(1 - beta_2^t)/(1 - beta_1^t).
 *
 * @param m Initial guess on input, last iterate on output.
 * @return The reason the iteration ended:
 *   - SGD_CONVERGENCE: fx < epsilon;
 *   - SGD_STOP: SGD_Progress() returned non-zero;
 *   - SGD_REACHED_MAX_ITERATIONS: the iteration limit was reached (when iteration > 0);
 *   - SGD_NAN_VALUE: a NaN appeared in m;
 *   - an SGD_INVALID_* code: a parameter failed validation.
 */
gctl::sgd_return_code gctl::sgd_solver::adabelief(array &m)
{
	const int len = m.size();

	// Parameter validation, in the same order as the other Adam-family solvers.
	if (len <= 0) return SGD_INVALID_VARIABLE_SIZE;
	if (sgd_param_.epsilon < 0) return SGD_INVALID_EPSILON;
	if (sgd_param_.alpha < 0) return SGD_INVALID_ALPHA;
	if (sgd_param_.beta_1 < 0.0 || sgd_param_.beta_1 >= 1.0) return SGD_INVALID_BETA;
	if (sgd_param_.beta_2 < 0.0 || sgd_param_.beta_2 >= 1.0) return SGD_INVALID_BETA;
	if (sgd_param_.sigma < 0.0) return SGD_INVALID_SIGMA;

	const double b1 = sgd_param_.beta_1;
	const double b2 = sgd_param_.beta_2;

	array first_moment(len, 0.0);
	array belief(len, 0.0);
	array grad (len);

	double b1_pow = 1.0; // beta_1^iter
	double b2_pow = 1.0; // beta_2^iter

	for (int iter = 0; ; ++iter)
	{
		const double obj = SGD_Evaluate(m, grad);
		if (SGD_Progress(obj, m, sgd_param_, iter)) return SGD_STOP;
		if (obj < sgd_param_.epsilon) return SGD_CONVERGENCE;

		b1_pow *= b1;
		b2_pow *= b2;
		const double step = sgd_param_.alpha * sqrt(1.0 - b2_pow)/(1.0 - b1_pow);

		for (int k = 0; k < len; k++)
		{
			first_moment[k] = b1*first_moment[k] + (1.0 - b1)*grad[k];
			const double resid = grad[k] - first_moment[k];
			belief[k] = b2*belief[k] + (1.0 - b2)*resid*resid;
			m[k] = m[k] - step * first_moment[k]/(sqrt(belief[k]) + sgd_param_.sigma);
			if (m[k] != m[k]) return SGD_NAN_VALUE;
		}

		// Equivalent to incrementing the counter and testing it against the limit.
		if (sgd_param_.iteration > 0 && iter + 1 >= sgd_param_.iteration)
			return SGD_REACHED_MAX_ITERATIONS;
	}
}