634 lines
20 KiB
C++
634 lines
20 KiB
C++
|
/********************************************************
|
||
|
* ██████╗ ██████╗████████╗██╗
|
||
|
* ██╔════╝ ██╔════╝╚══██╔══╝██║
|
||
|
* ██║ ███╗██║ ██║ ██║
|
||
|
* ██║ ██║██║ ██║ ██║
|
||
|
* ╚██████╔╝╚██████╗ ██║ ███████╗
|
||
|
* ╚═════╝ ╚═════╝ ╚═╝ ╚══════╝
|
||
|
* Geophysical Computational Tools & Library (GCTL)
|
||
|
*
|
||
|
* Copyright (c) 2022 Yi Zhang (yizhang-geo@zju.edu.cn)
|
||
|
*
|
||
|
* GCTL is distributed under a dual licensing scheme. You can redistribute
|
||
|
* it and/or modify it under the terms of the GNU Lesser General Public
|
||
|
* License as published by the Free Software Foundation, either version 2
|
||
|
* of the License, or (at your option) any later version. You should have
|
||
|
* received a copy of the GNU Lesser General Public License along with this
|
||
|
* program. If not, see <http://www.gnu.org/licenses/>.
|
||
|
*
|
||
|
* If the terms and conditions of the LGPL v.2. would prevent you from using
|
||
|
* the GCTL, please consider the option to obtain a commercial license for a
|
||
|
* fee. These licenses are offered by the GCTL's original author. As a rule,
|
||
|
* licenses are provided "as-is", unlimited in time for a one time fee. Please
|
||
|
* send corresponding requests to: yizhang-geo@zju.edu.cn. Please do not forget
|
||
|
* to include some description of your company and the realm of its activities.
|
||
|
* Also add information on how to contact you by electronic and paper mail.
|
||
|
******************************************************/
|
||
|
|
||
|
#include "sgd.h"
|
||
|
|
||
|
/**
|
||
|
* Default parameter for the SGD methods.
|
||
|
*/
|
||
|
static const gctl::sgd_para sgd_defparam = {0, 1e-6, 0.01, 0.01, 0.9, 0.999, 1e-8};
|
||
|
|
||
|
/**
 * Construct a solver with the default SGD parameters.
 *
 * Initial state:
 * - progress is reported every iteration, with output enabled;
 * - no method has been selected yet ("Undefined").
 */
gctl::sgd_solver::sgd_solver()
{
	solver_name_ = "Undefined";
	sgd_silent_  = false;
	sgd_inter_   = 1;
	sgd_param_   = sgd_defparam;
}
|
||
|
|
||
|
/// Destructor. No explicit cleanup is performed here.
gctl::sgd_solver::~sgd_solver()
{
}
|
||
|
|
||
|
int gctl::sgd_solver::SGD_Progress(double fx, const array<double> &x, const sgd_para ¶m, const int k)
|
||
|
{
|
||
|
if (sgd_silent_) return 0;
|
||
|
|
||
|
if (param.epsilon > 0.0 && fx <= param.epsilon)
|
||
|
{
|
||
|
std::clog << GCTL_CLEARLINE << "\rF(x) = " << fx << ", Train-Times = " << k;
|
||
|
return 0;
|
||
|
}
|
||
|
|
||
|
if (sgd_inter_ > 0 && k%sgd_inter_ == 0)
|
||
|
{
|
||
|
std::clog << GCTL_CLEARLINE << "\rF(x) = " << fx << ", Train-Times = " << k;
|
||
|
}
|
||
|
return 0;
|
||
|
}
|
||
|
|
||
|
void gctl::sgd_solver::sgd_silent()
|
||
|
{
|
||
|
sgd_silent_ = true;
|
||
|
return;
|
||
|
}
|
||
|
|
||
|
void gctl::sgd_solver::set_sgd_report_interval(int inter)
|
||
|
{
|
||
|
sgd_inter_ = inter;
|
||
|
return;
|
||
|
}
|
||
|
|
||
|
void gctl::sgd_solver::set_sgd_para(const sgd_para &in_param)
|
||
|
{
|
||
|
sgd_param_ = in_param;
|
||
|
return;
|
||
|
}
|
||
|
|
||
|
/**
 * Load SGD parameters from a parsed TOML document.
 *
 * The parameters are first reset to the built-in defaults. Each key present in
 * the optional `sgd` table then overrides the matching field:
 * - `iteration` (int);
 * - `epsilon`, `mu`, `alpha`, `beta_1`, `beta_2`, `sigma` (double).
 *
 * Absent keys, or a missing `sgd` table, leave the defaults in place.
 * A value of the wrong type makes toml::find throw.
 *
 * @param toml_data Parsed TOML document.
 */
void gctl::sgd_solver::set_sgd_para(const toml::value &toml_data)
{
	sgd_param_ = sgd_defparam;

	const std::string section = "sgd";
	if (!toml_data.contains(section)) return;

	const auto &table = toml_data.at(section);
	if (table.contains("iteration")) sgd_param_.iteration = toml::find<int>(table, "iteration");
	if (table.contains("epsilon"))   sgd_param_.epsilon   = toml::find<double>(table, "epsilon");
	if (table.contains("mu"))        sgd_param_.mu        = toml::find<double>(table, "mu");
	if (table.contains("alpha"))     sgd_param_.alpha     = toml::find<double>(table, "alpha");
	if (table.contains("beta_1"))    sgd_param_.beta_1    = toml::find<double>(table, "beta_1");
	if (table.contains("beta_2"))    sgd_param_.beta_2    = toml::find<double>(table, "beta_2");
	if (table.contains("sigma"))     sgd_param_.sigma     = toml::find<double>(table, "sigma");
}
|
||
|
|
||
|
void gctl::sgd_solver::show_solver()
|
||
|
{
|
||
|
std::clog << "Solver's Setup Panel\n";
|
||
|
std::clog << "-----------------------------\n";
|
||
|
std::clog << "Solver: " << solver_name_ << "\n";
|
||
|
std::clog << "Iteration = " << sgd_param_.iteration << ", Epsilon = " << sgd_param_.epsilon << ", Mu = " << sgd_param_.mu << "\n";
|
||
|
std::clog << "Alpha = " << sgd_param_.alpha << ", Beta1 = " << sgd_param_.beta_1 << ", Beta2 = " << sgd_param_.beta_2 << ", Sigma = " << sgd_param_.sigma << "\n";
|
||
|
std::clog << "=============================\n";
|
||
|
return;
|
||
|
}
|
||
|
|
||
|
/**
 * Report an SGD return code.
 *
 * @param err_code  Status returned by one of the SGD methods; negative values are failures.
 * @param ss        Stream that receives the message.
 * @param err_throw Selects how the message is delivered:
 *                  - true: a failure (err_code < 0) is thrown as a std::string holding
 *                    the message, and a non-failure message is written to `ss` without
 *                    decoration;
 *                  - false: the message is written to `ss` with a coloured
 *                    "SGD Success! " / "SGD Fail! " prefix and a trailing newline.
 * @throws std::string when err_throw is true and err_code < 0.
 */
void gctl::sgd_solver::sgd_error_str(sgd_return_code err_code, std::ostream &ss, bool err_throw)
{
	const bool failed = (err_code < 0);

	// Coloured prefix, for interactive reporting only. On Windows the colour is set on
	// the stderr console; elsewhere ANSI escape sequences are written into the stream.
	// The prefix text is the same on every platform.
	if (!err_throw)
	{
#if defined _WINDOWS || __WIN32__
		SetConsoleTextAttribute(GetStdHandle(STD_ERROR_HANDLE),
			FOREGROUND_INTENSITY | (failed ? FOREGROUND_RED : FOREGROUND_GREEN));
		ss << (failed ? "SGD Fail! " : "SGD Success! ");
#else
		ss << (failed ? "\033[1m\033[31mSGD Fail! " : "\033[1m\033[32mSGD Success! ");
#endif
	}

	std::string err_str;
	switch (err_code)
	{
		case SGD_SUCCESS:
			err_str = "Success."; break;
		case SGD_CONVERGENCE:
			err_str = "The iteration reached convergence."; break;
		case SGD_STOP:
			err_str = "The iteration stopped by the progress evaluation function."; break;
		case SGD_UNKNOWN_ERROR:
			err_str = "Unknown error."; break;
		case SGD_INVALID_VARIABLE_SIZE:
			err_str = "Invalid array size."; break;
		case SGD_REACHED_MAX_ITERATIONS:
			err_str = "The maximal iteration is reached."; break;
		case SGD_INVALID_EPSILON:
			err_str = "Invalid value for epsilon."; break;
		case SGD_INVALID_BETA:
			err_str = "Invalid value for beta."; break;
		case SGD_INVALID_MU:
			err_str = "Invalid value for mu."; break;
		case SGD_INVALID_ALPHA:
			err_str = "Invalid value for alpha."; break;
		case SGD_INVALID_SIGMA:
			err_str = "Invalid value for sigma."; break;
		case SGD_NAN_VALUE:
			err_str = "NaN values found."; break;
		default:
			err_str = "Unknown error."; break;
	}

	// Callers catch the message as a std::string, so keep throwing that type.
	if (err_throw && failed) throw err_str;
	ss << err_str;

	// Restore the default colours and finish the line.
	if (!err_throw)
	{
#if defined _WINDOWS || __WIN32__
		// 7 == FOREGROUND_RED | FOREGROUND_GREEN | FOREGROUND_BLUE (plain light grey).
		SetConsoleTextAttribute(GetStdHandle(STD_ERROR_HANDLE), 7);
		ss << std::endl;
#else
		ss << "\033[0m" << std::endl;
#endif
	}
}
|
||
|
|
||
|
/**
 * Get the built-in default SGD parameters.
 *
 * @return A copy of the file-level default parameter set.
 */
gctl::sgd_para gctl::sgd_solver::default_sgd_para()
{
	return sgd_defparam;
}
|
||
|
|
||
|
/**
 * Minimize the objective with the selected SGD method.
 *
 * The objective and its gradient come from SGD_Evaluate(). `m` holds the initial
 * guess on entry and the final parameters on return.
 *
 * @param m         Model parameters, updated in place.
 * @param solver_id Method to run: MOMENTUM, NAG, ADAGRAD, RMSPROP, ADAM, NADAM,
 *                  ADAMAX or ADABELIEF.
 * @param ss        Stream that receives the final status message.
 * @param verbose   When true, the status message is reported on success too.
 *                  Failures are always reported.
 * @param err_throw When true:
 *                  - failures are thrown (as std::string, via sgd_error_str());
 *                  - the timing line is not printed.
 * @throws std::invalid_argument if solver_id is not one of the methods above.
 *
 * Silent mode (sgd_silent()) behaves differently:
 * - the run is not timed;
 * - nothing is printed on success;
 * - any failure is thrown regardless of `err_throw`.
 */
void gctl::sgd_solver::SGD_Minimize(array<double> &m, sgd_solver_type solver_id, std::ostream &ss, bool verbose, bool err_throw)
{
	// Record the selected method's name and run it. Keeping the dispatch in one
	// place stops the silent and the timed code paths from drifting apart.
	auto run_solver = [&]() -> sgd_return_code
	{
		switch (solver_id)
		{
			case MOMENTUM:  solver_name_ = "MOMENTUM";  return momentum(m);
			case NAG:       solver_name_ = "NAG";       return nag(m);
			case ADAGRAD:   solver_name_ = "ADAGRAD";   return adagrad(m);
			case RMSPROP:   solver_name_ = "RMSPROP";   return rmsprop(m);
			case ADAM:      solver_name_ = "ADAM";      return adam(m);
			case NADAM:     solver_name_ = "NADAM";     return nadam(m);
			case ADAMAX:    solver_name_ = "ADAMAX";    return adamax(m);
			case ADABELIEF: solver_name_ = "ADABELIEF"; return adabelief(m);
			default:
				throw std::invalid_argument("Invalid solver type. gctl::sgd_solver::SGD_Minimize(...)");
		}
	};

	if (sgd_silent_)
	{
		const sgd_return_code ret = run_solver();
		if (ret < 0) sgd_error_str(ret, ss, true);
		return;
	}

#ifdef GCTL_OPENMP
	const double start = omp_get_wtime();
	const sgd_return_code ret = run_solver();
	const double costime = 1000*(omp_get_wtime() - start); // wall-clock time in ms
#else
	const clock_t start = clock();
	const sgd_return_code ret = run_solver();
	const double costime = 1000*(clock() - start)/(double)CLOCKS_PER_SEC; // processor time in ms
#endif

	if (!err_throw)
	{
		// End the "\r"-refreshed progress line before printing the summary.
		std::clog << std::endl;
		std::clog << "Solver: " << solver_name_ << ". Time cost: " << costime << " ms" << std::endl;
	}

	// Failures are always reported; successes only when requested.
	if (verbose || ret < 0) sgd_error_str(ret, ss, err_throw);
}
|
||
|
|
||
|
/**
 * Gradient descent with classical momentum.
 *
 * Each iteration evaluates the objective and its gradient g at m, then applies
 *     v <- mu*v + g
 *     m <- m - alpha*v
 * with the velocity v starting at zero.
 *
 * @param m Initial parameters on entry; updated in place.
 * @return One of:
 *         - SGD_CONVERGENCE once the objective falls below epsilon;
 *         - SGD_STOP if SGD_Progress() returns non-zero;
 *         - SGD_REACHED_MAX_ITERATIONS after `iteration` steps (only when iteration > 0);
 *         - SGD_NAN_VALUE if an updated parameter is NaN;
 *         - SGD_INVALID_VARIABLE_SIZE, SGD_INVALID_EPSILON, SGD_INVALID_ALPHA or
 *           SGD_INVALID_MU for invalid input.
 */
gctl::sgd_return_code gctl::sgd_solver::momentum(array<double> &m)
{
	int n_size = m.size();

	// Validate inputs. The step-size check matches the Adam-family solvers.
	if (n_size <= 0) return SGD_INVALID_VARIABLE_SIZE;
	if (sgd_param_.epsilon < 0) return SGD_INVALID_EPSILON;
	if (sgd_param_.alpha < 0) return SGD_INVALID_ALPHA;
	if (sgd_param_.mu < 0 || sgd_param_.mu >= 1.0) return SGD_INVALID_MU;

	array<double> mk(n_size, 0.0); // velocity
	array<double> g(n_size);       // gradient buffer passed to SGD_Evaluate()

	int t = 0;
	double fx;
	while (1)
	{
		fx = SGD_Evaluate(m, g);

		if (SGD_Progress(fx, m, sgd_param_, t)) return SGD_STOP;
		if (fx < sgd_param_.epsilon) return SGD_CONVERGENCE;

		for (int i = 0; i < n_size; i++)
		{
			mk[i] = sgd_param_.mu*mk[i] + g[i];
			m[i] = m[i] - sgd_param_.alpha * mk[i];
			if (m[i] != m[i]) return SGD_NAN_VALUE; // only NaN compares unequal to itself
		}

		t++;
		if (sgd_param_.iteration > 0 && t >= sgd_param_.iteration) break;
	}

	return SGD_REACHED_MAX_ITERATIONS;
}
|
||
|
|
||
|
/**
 * Nesterov accelerated gradient (NAG).
 *
 * Each iteration:
 * 1. Evaluate the gradient g at the look-ahead point x = m - mu*alpha*v.
 * 2. Update v <- mu*v + g.
 * 3. Update m <- m - alpha*v.
 * The velocity v starts at zero.
 *
 * Note: the objective value passed to SGD_Progress() is the one at the
 * look-ahead point, while the parameters passed are the current m.
 *
 * @param m Initial parameters on entry; updated in place.
 * @return One of:
 *         - SGD_CONVERGENCE once the objective falls below epsilon;
 *         - SGD_STOP if SGD_Progress() returns non-zero;
 *         - SGD_REACHED_MAX_ITERATIONS after `iteration` steps (only when iteration > 0);
 *         - SGD_NAN_VALUE if an updated parameter is NaN;
 *         - SGD_INVALID_VARIABLE_SIZE, SGD_INVALID_EPSILON, SGD_INVALID_ALPHA or
 *           SGD_INVALID_MU for invalid input.
 */
gctl::sgd_return_code gctl::sgd_solver::nag(array<double> &m)
{
	int n_size = m.size();

	// Validate inputs. The step-size check matches the Adam-family solvers.
	if (n_size <= 0) return SGD_INVALID_VARIABLE_SIZE;
	if (sgd_param_.epsilon < 0) return SGD_INVALID_EPSILON;
	if (sgd_param_.alpha < 0) return SGD_INVALID_ALPHA;
	if (sgd_param_.mu < 0 || sgd_param_.mu >= 1.0) return SGD_INVALID_MU;

	array<double> mk(n_size, 0.0); // velocity
	array<double> xk(n_size);      // look-ahead point
	array<double> g (n_size);      // gradient buffer passed to SGD_Evaluate()

	int t = 0;
	double fx;
	while (1)
	{
		for (int i = 0; i < n_size; i++)
		{
			xk[i] = m[i] - sgd_param_.mu*sgd_param_.alpha*mk[i];
		}

		fx = SGD_Evaluate(xk, g);

		if (SGD_Progress(fx, m, sgd_param_, t)) return SGD_STOP;
		if (fx < sgd_param_.epsilon) return SGD_CONVERGENCE;

		for (int i = 0; i < n_size; i++)
		{
			mk[i] = sgd_param_.mu*mk[i] + g[i];
			m[i] = m[i] - sgd_param_.alpha * mk[i];
			if (m[i] != m[i]) return SGD_NAN_VALUE; // only NaN compares unequal to itself
		}

		t++;
		if (sgd_param_.iteration > 0 && t >= sgd_param_.iteration) break;
	}

	return SGD_REACHED_MAX_ITERATIONS;
}
|
||
|
|
||
|
/**
 * AdaGrad.
 *
 * Accumulates squared gradients, G <- G + g^2 (G starts at zero), and takes
 *     m <- m - alpha*g/(sqrt(G) + sigma).
 *
 * @param m Initial parameters on entry; updated in place.
 * @return One of:
 *         - SGD_CONVERGENCE once the objective falls below epsilon;
 *         - SGD_STOP if SGD_Progress() returns non-zero;
 *         - SGD_REACHED_MAX_ITERATIONS after `iteration` steps (only when iteration > 0);
 *         - SGD_NAN_VALUE if an updated parameter is NaN;
 *         - SGD_INVALID_VARIABLE_SIZE, SGD_INVALID_EPSILON, SGD_INVALID_ALPHA or
 *           SGD_INVALID_SIGMA for invalid input.
 */
gctl::sgd_return_code gctl::sgd_solver::adagrad(array<double> &m)
{
	int n_size = m.size();

	// Validate inputs. The step-size check matches the Adam-family solvers.
	if (n_size <= 0) return SGD_INVALID_VARIABLE_SIZE;
	if (sgd_param_.epsilon < 0.0) return SGD_INVALID_EPSILON;
	if (sgd_param_.alpha < 0.0) return SGD_INVALID_ALPHA;
	if (sgd_param_.sigma < 0.0) return SGD_INVALID_SIGMA;

	array<double> mk(n_size, 0.0); // running sum of squared gradients
	array<double> g (n_size);      // gradient buffer passed to SGD_Evaluate()

	int t = 0;
	double fx;
	while (1)
	{
		fx = SGD_Evaluate(m, g);

		if (SGD_Progress(fx, m, sgd_param_, t)) return SGD_STOP;
		if (fx < sgd_param_.epsilon) return SGD_CONVERGENCE;

		for (int i = 0; i < n_size; i++)
		{
			mk[i] = mk[i] + g[i]*g[i];
			m[i] = m[i] - sgd_param_.alpha * g[i]/(sqrt(mk[i]) + sgd_param_.sigma);
			if (m[i] != m[i]) return SGD_NAN_VALUE; // only NaN compares unequal to itself
		}

		t++;
		if (sgd_param_.iteration > 0 && t >= sgd_param_.iteration) break;
	}

	return SGD_REACHED_MAX_ITERATIONS;
}
|
||
|
|
||
|
/**
 * RMSProp.
 *
 * Keeps an exponential moving average of squared gradients (starting at zero),
 *     v <- beta_2*v + (1 - beta_2)*g^2,
 * and takes m <- m - alpha*g/(sqrt(v) + sigma).
 *
 * beta_2 must lie in [0, 1): a value outside that range makes v grow without
 * bound or turn negative (sqrt -> NaN).
 *
 * @param m Initial parameters on entry; updated in place.
 * @return One of:
 *         - SGD_CONVERGENCE once the objective falls below epsilon;
 *         - SGD_STOP if SGD_Progress() returns non-zero;
 *         - SGD_REACHED_MAX_ITERATIONS after `iteration` steps (only when iteration > 0);
 *         - SGD_NAN_VALUE if an updated parameter is NaN;
 *         - SGD_INVALID_VARIABLE_SIZE, SGD_INVALID_EPSILON, SGD_INVALID_ALPHA,
 *           SGD_INVALID_BETA or SGD_INVALID_SIGMA for invalid input.
 */
gctl::sgd_return_code gctl::sgd_solver::rmsprop(array<double> &m)
{
	int n_size = m.size();

	// Validate inputs. The alpha/beta_2 checks match the Adam-family solvers.
	if (n_size <= 0) return SGD_INVALID_VARIABLE_SIZE;
	if (sgd_param_.epsilon < 0.0) return SGD_INVALID_EPSILON;
	if (sgd_param_.alpha < 0.0) return SGD_INVALID_ALPHA;
	if (sgd_param_.beta_2 < 0.0 || sgd_param_.beta_2 >= 1.0) return SGD_INVALID_BETA;
	if (sgd_param_.sigma < 0.0) return SGD_INVALID_SIGMA;

	array<double> vk(n_size, 0.0); // moving average of squared gradients
	array<double> g (n_size);      // gradient buffer passed to SGD_Evaluate()

	int t = 0;
	double fx;
	while (1)
	{
		fx = SGD_Evaluate(m, g);

		if (SGD_Progress(fx, m, sgd_param_, t)) return SGD_STOP;
		if (fx < sgd_param_.epsilon) return SGD_CONVERGENCE;

		for (int i = 0; i < n_size; i++)
		{
			vk[i] = sgd_param_.beta_2 * vk[i] + (1.0 - sgd_param_.beta_2)*g[i]*g[i];
			m[i] = m[i] - sgd_param_.alpha * g[i]/(sqrt(vk[i]) + sgd_param_.sigma);
			if (m[i] != m[i]) return SGD_NAN_VALUE; // only NaN compares unequal to itself
		}

		t++;
		if (sgd_param_.iteration > 0 && t >= sgd_param_.iteration) break;
	}

	return SGD_REACHED_MAX_ITERATIONS;
}
|
||
|
|
||
|
/**
 * Adam.
 *
 *     m_k <- beta_1*m_k + (1 - beta_1)*g
 *     v_k <- beta_2*v_k + (1 - beta_2)*g^2
 *     x   <- x - alpha*sqrt(1 - beta_2^t)/(1 - beta_1^t) * m_k/(sqrt(v_k) + sigma)
 * The bias corrections are folded into the step size. The per-component update
 * runs in an OpenMP parallel loop when OpenMP is enabled.
 *
 * @param m Initial parameters on entry; updated in place.
 * @return One of:
 *         - SGD_CONVERGENCE once the objective falls below epsilon;
 *         - SGD_STOP if SGD_Progress() returns non-zero;
 *         - SGD_REACHED_MAX_ITERATIONS after `iteration` steps (only when iteration > 0);
 *         - SGD_NAN_VALUE if an updated parameter is NaN;
 *         - SGD_INVALID_VARIABLE_SIZE, SGD_INVALID_EPSILON, SGD_INVALID_ALPHA,
 *           SGD_INVALID_BETA or SGD_INVALID_SIGMA for invalid input.
 */
gctl::sgd_return_code gctl::sgd_solver::adam(array<double> &m)
{
	int n_size = m.size();

	// Validate inputs.
	if (n_size <= 0) return SGD_INVALID_VARIABLE_SIZE;
	if (sgd_param_.epsilon < 0) return SGD_INVALID_EPSILON;
	if (sgd_param_.alpha < 0) return SGD_INVALID_ALPHA;
	if (sgd_param_.beta_1 < 0.0 || sgd_param_.beta_1 >= 1.0) return SGD_INVALID_BETA;
	if (sgd_param_.beta_2 < 0.0 || sgd_param_.beta_2 >= 1.0) return SGD_INVALID_BETA;
	if (sgd_param_.sigma < 0.0) return SGD_INVALID_SIGMA;

	array<double> mk(n_size, 0.0); // first moment
	array<double> vk(n_size, 0.0); // second moment
	array<double> g (n_size);      // gradient buffer passed to SGD_Evaluate()

	double beta_1t = 1.0, beta_2t = 1.0; // beta_1^t and beta_2^t
	double alpha_k;

	int t = 0;
	double fx;
	while (1)
	{
		fx = SGD_Evaluate(m, g);

		if (SGD_Progress(fx, m, sgd_param_, t)) return SGD_STOP;
		if (fx < sgd_param_.epsilon) return SGD_CONVERGENCE;

		beta_1t *= sgd_param_.beta_1;
		beta_2t *= sgd_param_.beta_2;
		alpha_k = sgd_param_.alpha * sqrt(1.0 - beta_2t)/(1.0 - beta_1t);

		// A return statement is not allowed inside an OpenMP loop. NaNs are therefore
		// flagged through a reduction and reported after the loop, so Adam reports
		// them like the other methods instead of silently carrying on.
		int nan_found = 0;
		int i;
#pragma omp parallel for private (i) schedule(guided) reduction(||:nan_found)
		for (i = 0; i < n_size; i++)
		{
			mk[i] = sgd_param_.beta_1*mk[i] + (1.0 - sgd_param_.beta_1)*g[i];
			vk[i] = sgd_param_.beta_2*vk[i] + (1.0 - sgd_param_.beta_2)*g[i]*g[i];

			m[i] = m[i] - alpha_k * mk[i]/(sqrt(vk[i]) + sgd_param_.sigma);
			if (m[i] != m[i]) nan_found = 1; // only NaN compares unequal to itself
		}
		if (nan_found) return SGD_NAN_VALUE;

		t++;
		if (sgd_param_.iteration > 0 && t >= sgd_param_.iteration) break;
	}

	return SGD_REACHED_MAX_ITERATIONS;
}
|
||
|
|
||
|
/**
 * Nadam: Adam with Nesterov momentum (Dozat, 2016), using a constant beta_1.
 *
 * At step t (1-based), with b1 = beta_1 and b2 = beta_2:
 *     g_hat = g / (1 - b1^t)
 *     m_k   <- b1*m_k + (1 - b1)*g,        m_hat = m_k / (1 - b1^(t+1))
 *     n_k   <- b2*n_k + (1 - b2)*g^2,      n_hat = n_k / (1 - b2^t)
 *     x     <- x - alpha * ((1 - b1)*g_hat + b1*m_hat) / (sqrt(n_hat) + sigma)
 *
 * The Nesterov combination uses the momentum coefficient b1 itself, not its
 * running powers. Weighting by (1 - b1^t) and b1^(t+1) would make the momentum
 * term vanish as t grows.
 *
 * @param m Initial parameters on entry; updated in place.
 * @return One of:
 *         - SGD_CONVERGENCE once the objective falls below epsilon;
 *         - SGD_STOP if SGD_Progress() returns non-zero;
 *         - SGD_REACHED_MAX_ITERATIONS after `iteration` steps (only when iteration > 0);
 *         - SGD_NAN_VALUE if an updated parameter is NaN;
 *         - SGD_INVALID_VARIABLE_SIZE, SGD_INVALID_EPSILON, SGD_INVALID_ALPHA,
 *           SGD_INVALID_BETA or SGD_INVALID_SIGMA for invalid input.
 */
gctl::sgd_return_code gctl::sgd_solver::nadam(array<double> &m)
{
	int n_size = m.size();

	// Validate inputs.
	if (n_size <= 0) return SGD_INVALID_VARIABLE_SIZE;
	if (sgd_param_.epsilon < 0) return SGD_INVALID_EPSILON;
	if (sgd_param_.alpha < 0) return SGD_INVALID_ALPHA;
	if (sgd_param_.beta_1 < 0.0 || sgd_param_.beta_1 >= 1.0) return SGD_INVALID_BETA;
	if (sgd_param_.beta_2 < 0.0 || sgd_param_.beta_2 >= 1.0) return SGD_INVALID_BETA;
	if (sgd_param_.sigma < 0.0) return SGD_INVALID_SIGMA;

	array<double> mk(n_size, 0.0); // first moment
	array<double> nk(n_size, 0.0); // second moment
	array<double> g (n_size);      // gradient buffer passed to SGD_Evaluate()

	double beta_1t = 1.0;                 // beta_1^t
	double beta_1t1 = sgd_param_.beta_1;  // beta_1^(t+1)
	double beta_2t = 1.0;                 // beta_2^t

	int t = 0;
	double fx;
	while (1)
	{
		fx = SGD_Evaluate(m, g);

		if (SGD_Progress(fx, m, sgd_param_, t)) return SGD_STOP;
		if (fx < sgd_param_.epsilon) return SGD_CONVERGENCE;

		const double b1 = sgd_param_.beta_1;
		const double b2 = sgd_param_.beta_2;

		beta_1t *= b1;
		beta_1t1 *= b1;
		beta_2t *= b2;

		for (int i = 0; i < n_size; i++)
		{
			const double g_hat = g[i]/(1.0 - beta_1t);

			mk[i] = b1*mk[i] + (1.0 - b1)*g[i];
			nk[i] = b2*nk[i] + (1.0 - b2)*g[i]*g[i];

			const double mk_hat = mk[i]/(1.0 - beta_1t1);
			const double nk_hat = nk[i]/(1.0 - beta_2t);

			m[i] = m[i] - sgd_param_.alpha * ((1.0 - b1)*g_hat + b1*mk_hat)/(sqrt(nk_hat) + sgd_param_.sigma);
			if (m[i] != m[i]) return SGD_NAN_VALUE; // only NaN compares unequal to itself
		}

		t++;
		if (sgd_param_.iteration > 0 && t >= sgd_param_.iteration) break;
	}

	return SGD_REACHED_MAX_ITERATIONS;
}
|
||
|
|
||
|
/**
 * AdaMax: the infinity-norm variant of Adam.
 *
 *     m_k <- beta_1*m_k + (1 - beta_1)*g
 *     u_k <- max(beta_2*u_k, |g|)
 *     x   <- x - alpha/(1 - beta_1^t) * m_k/(u_k + sigma)
 *
 * sigma keeps the step finite for a component whose gradient has been zero so
 * far (u_k == 0). Without it the update is 0/0 = NaN, and the solver would
 * abort with SGD_NAN_VALUE.
 *
 * @param m Initial parameters on entry; updated in place.
 * @return One of:
 *         - SGD_CONVERGENCE once the objective falls below epsilon;
 *         - SGD_STOP if SGD_Progress() returns non-zero;
 *         - SGD_REACHED_MAX_ITERATIONS after `iteration` steps (only when iteration > 0);
 *         - SGD_NAN_VALUE if an updated parameter is NaN;
 *         - SGD_INVALID_VARIABLE_SIZE, SGD_INVALID_EPSILON, SGD_INVALID_ALPHA,
 *           SGD_INVALID_BETA or SGD_INVALID_SIGMA for invalid input.
 */
gctl::sgd_return_code gctl::sgd_solver::adamax(array<double> &m)
{
	int n_size = m.size();

	// Validate inputs.
	if (n_size <= 0) return SGD_INVALID_VARIABLE_SIZE;
	if (sgd_param_.epsilon < 0) return SGD_INVALID_EPSILON;
	if (sgd_param_.alpha < 0) return SGD_INVALID_ALPHA;
	if (sgd_param_.beta_1 < 0.0 || sgd_param_.beta_1 >= 1.0) return SGD_INVALID_BETA;
	if (sgd_param_.beta_2 < 0.0 || sgd_param_.beta_2 >= 1.0) return SGD_INVALID_BETA;
	if (sgd_param_.sigma < 0.0) return SGD_INVALID_SIGMA;

	array<double> mk(n_size, 0.0); // first moment
	array<double> vk(n_size, 0.0); // exponentially weighted infinity norm
	array<double> g (n_size);      // gradient buffer passed to SGD_Evaluate()

	double beta_1t = 1.0; // beta_1^t
	double alpha_k;

	int t = 0;
	double fx;
	while (1)
	{
		fx = SGD_Evaluate(m, g);

		if (SGD_Progress(fx, m, sgd_param_, t)) return SGD_STOP;
		if (fx < sgd_param_.epsilon) return SGD_CONVERGENCE;

		beta_1t *= sgd_param_.beta_1;
		// Bias correction of the first moment, folded into the step size.
		alpha_k = sgd_param_.alpha/(1.0 - beta_1t);

		for (int i = 0; i < n_size; i++)
		{
			mk[i] = sgd_param_.beta_1*mk[i] + (1.0 - sgd_param_.beta_1)*g[i];
			vk[i] = std::max(sgd_param_.beta_2*vk[i], std::fabs(g[i]));

			m[i] = m[i] - alpha_k * mk[i]/(vk[i] + sgd_param_.sigma);
			if (m[i] != m[i]) return SGD_NAN_VALUE; // only NaN compares unequal to itself
		}

		t++;
		if (sgd_param_.iteration > 0 && t >= sgd_param_.iteration) break;
	}

	return SGD_REACHED_MAX_ITERATIONS;
}
|
||
|
|
||
|
/**
 * AdaBelief.
 *
 * Works like Adam, except the second moment tracks the squared deviation of
 * the gradient from its running mean:
 *     m_k <- beta_1*m_k + (1 - beta_1)*g
 *     s_k <- beta_2*s_k + (1 - beta_2)*(g - m_k)^2
 *     x   <- x - alpha*sqrt(1 - beta_2^t)/(1 - beta_1^t) * m_k/(sqrt(s_k) + sigma)
 *
 * @param m Initial parameters on entry; updated in place.
 * @return One of:
 *         - SGD_CONVERGENCE once the objective falls below epsilon;
 *         - SGD_STOP if SGD_Progress() returns non-zero;
 *         - SGD_REACHED_MAX_ITERATIONS after `iteration` steps (only when iteration > 0);
 *         - SGD_NAN_VALUE if an updated parameter is NaN;
 *         - SGD_INVALID_VARIABLE_SIZE, SGD_INVALID_EPSILON, SGD_INVALID_ALPHA,
 *           SGD_INVALID_BETA or SGD_INVALID_SIGMA for invalid input.
 */
gctl::sgd_return_code gctl::sgd_solver::adabelief(array<double> &m)
{
	const int n_size = m.size();

	// Reject unusable input before allocating any work buffers.
	if (n_size <= 0) return SGD_INVALID_VARIABLE_SIZE;
	if (sgd_param_.epsilon < 0) return SGD_INVALID_EPSILON;
	if (sgd_param_.alpha < 0) return SGD_INVALID_ALPHA;
	if (sgd_param_.beta_1 < 0.0 || sgd_param_.beta_1 >= 1.0) return SGD_INVALID_BETA;
	if (sgd_param_.beta_2 < 0.0 || sgd_param_.beta_2 >= 1.0) return SGD_INVALID_BETA;
	if (sgd_param_.sigma < 0.0) return SGD_INVALID_SIGMA;

	array<double> first_moment(n_size, 0.0);
	array<double> belief(n_size, 0.0);
	array<double> grad(n_size);

	double b1_pow = 1.0; // beta_1^t
	double b2_pow = 1.0; // beta_2^t

	for (int step = 0; ; )
	{
		const double fx = SGD_Evaluate(m, grad);

		if (SGD_Progress(fx, m, sgd_param_, step)) return SGD_STOP;
		if (fx < sgd_param_.epsilon) return SGD_CONVERGENCE;

		b1_pow *= sgd_param_.beta_1;
		b2_pow *= sgd_param_.beta_2;

		// Both bias corrections are folded into a single step size.
		const double step_size = sgd_param_.alpha * sqrt(1.0 - b2_pow)/(1.0 - b1_pow);

		for (int i = 0; i < n_size; ++i)
		{
			first_moment[i] = sgd_param_.beta_1*first_moment[i] + (1.0 - sgd_param_.beta_1)*grad[i];

			const double dev = grad[i] - first_moment[i];
			belief[i] = sgd_param_.beta_2*belief[i] + (1.0 - sgd_param_.beta_2)*dev*dev;

			m[i] = m[i] - step_size * first_moment[i]/(sqrt(belief[i]) + sgd_param_.sigma);
			if (m[i] != m[i]) return SGD_NAN_VALUE; // only NaN compares unequal to itself
		}

		++step;
		if (sgd_param_.iteration > 0 && step >= sgd_param_.iteration) return SGD_REACHED_MAX_ITERATIONS;
	}
}
|