201 lines
6.3 KiB
C++
201 lines
6.3 KiB
C++
/********************************************************
 *  ██████╗  ██████╗████████╗██╗
 * ██╔════╝ ██╔════╝╚══██╔══╝██║
 * ██║  ███╗██║        ██║   ██║
 * ██║   ██║██║        ██║   ██║
 * ╚██████╔╝╚██████╗   ██║   ███████╗
 *  ╚═════╝  ╚═════╝   ╚═╝   ╚══════╝
 * Geophysical Computational Tools & Library (GCTL)
 *
 * Copyright (c) 2022  Yi Zhang (yizhang-geo@zju.edu.cn)
 *
 * GCTL is distributed under a dual licensing scheme. You can redistribute
 * it and/or modify it under the terms of the GNU Lesser General Public
 * License as published by the Free Software Foundation, either version 2
 * of the License, or (at your option) any later version. You should have
 * received a copy of the GNU Lesser General Public License along with this
 * program. If not, see <http://www.gnu.org/licenses/>.
 *
 * If the terms and conditions of the LGPL v.2. would prevent you from using
 * the GCTL, please consider the option to obtain a commercial license for a
 * fee. These licenses are offered by the GCTL's original author. As a rule,
 * licenses are provided "as-is", unlimited in time for a one time fee. Please
 * send corresponding requests to: yizhang-geo@zju.edu.cn. Please do not forget
 * to include some description of your company and the realm of its activities.
 * Also add information on how to contact you by electronic and paper mail.
 ******************************************************/
|
||
|
||
#ifndef _GCTL_SGD_H
#define _GCTL_SGD_H

// Core GCTL headers used by the declarations in this file.
#include "gctl/core.h"
#include "gctl/algorithm.h"

// Build configuration; toml.hpp is only included when GCTL_OPTIMIZATION_TOML
// is defined.
#include "gctl_optimization_config.h"
#ifdef GCTL_OPTIMIZATION_TOML
#include "toml.hpp"
#endif // GCTL_OPTIMIZATION_TOML

#if defined _WINDOWS || __WIN32__
#include "windows.h"
#endif // _WINDOWS || __WIN32__

// NOTE(review): every other macro tested in this header uses the GCTL_ prefix;
// confirm that GSTL_OPENMP is really the name the build system defines,
// otherwise omp.h is never included.
#ifdef GSTL_OPENMP
#include "omp.h"
#endif // GSTL_OPENMP
|
||
|
||
namespace gctl
|
||
{
|
||
/**
 * @brief Identifiers of the SGD methods. A value of this type is passed as
 * the solver_id argument of sgd_solver::SGD_Minimize().
 *
 * The enumerators carry no explicit values, so they are numbered
 * consecutively from 0 (MOMENTUM) to 7 (ADABELIEF).
 */
enum sgd_solver_type
{
    /**
     * Classic momentum.
     */
    MOMENTUM,

    /**
     * Nesterov's accelerated gradient (NAG)
     */
    NAG,

    /**
     * AdaGrad method.
     */
    ADAGRAD,

    /**
     * RMSProp method.
     */
    RMSPROP,

    /**
     * Adam method.
     */
    ADAM,

    /**
     * Nadam method.
     */
    NADAM,

    /**
     * AdaMax method.
     */
    ADAMAX,

    /**
     * AdaBelief method.
     */
    ADABELIEF,
};
|
||
|
||
/**
 * @brief Return codes of the SGD solver methods.
 *
 * Non-negative values describe a regular termination. The error codes start
 * at -1024 (SGD_UNKNOWN_ERROR) and, having no explicit values of their own,
 * count upwards from there in declaration order.
 */
enum sgd_return_code
{
    SGD_SUCCESS = 0,            ///< The optimization terminated successfully.
    SGD_CONVERGENCE = 1,        ///< The optimization reached convergence.
    SGD_STOP,                   ///< (2) The process stopped by the monitoring function.
    SGD_UNKNOWN_ERROR = -1024,  ///< Unknown error.
    SGD_INVALID_VARIABLE_SIZE,  ///< (-1023) The variable size is negative.
    SGD_INVALID_EPSILON,        ///< (-1022) The epsilon is negative.
    SGD_REACHED_MAX_ITERATIONS, ///< (-1021) Iteration reached max limit.
    SGD_INVALID_MU,             ///< (-1020) Invalid value for mu.
    SGD_INVALID_ALPHA,          ///< (-1019) Invalid value for alpha.
    SGD_INVALID_BETA,           ///< (-1018) Invalid value for beta.
    SGD_INVALID_SIGMA,          ///< (-1017) Invalid value for sigma.
    SGD_NAN_VALUE,              ///< (-1016) Nan value.
};
|
||
|
||
/**
 * @brief Parameters of the SGD methods.
 *
 * Plain aggregate: the members have no in-class initializers, so an object
 * must be filled explicitly (e.g. by brace initialization in declaration
 * order) before use. The "default" values quoted below are the documented
 * defaults, they are not applied by this struct itself.
 */
struct sgd_para
{
    /**
     * Maximal iteration times. The iteration won't stop unless the convergence
     * is reached if this parameter is equal to or smaller than zero. The default
     * is 0.
     */
    int iteration;

    /**
     * Epsilon for convergence test. This parameter determines the accuracy
     * with which the solution is to be found. Must be bigger than zero and
     * the default is 1e-6.
     */
    double epsilon;

    /**
     * Damping rate of the classic momentum method and the NAG method, which
     * is typically given between 0 and 1. The default is 0.01.
     */
    double mu;

    /**
     * Step size of the iteration. The default value is 0.01 for Adam and AdaMax.
     */
    double alpha;

    /**
     * Exponential decay rate for the first order moment estimates. The range of
     * this parameter is [0, 1) and the default value is 0.9.
     */
    double beta_1;

    /**
     * Exponential decay rate for the second order moment estimates. The range of
     * this parameter is [0, 1) and the default value is 0.999.
     */
    double beta_2;

    /**
     * A small positive number used to keep the algorithm valid. The default
     * value is 1e-8.
     */
    double sigma;
};
|
||
|
||
class sgd_solver
|
||
{
|
||
private:
|
||
sgd_para sgd_param_;
|
||
int sgd_inter_;
|
||
bool sgd_silent_;
|
||
std::string solver_name_;
|
||
|
||
public:
|
||
sgd_solver();
|
||
virtual ~sgd_solver();
|
||
|
||
virtual double SGD_Evaluate(const array<double> &x, array<double> &g) = 0;
|
||
virtual int SGD_Progress(double fx, const array<double> &x, const sgd_para ¶m, const int k);
|
||
|
||
void sgd_silent();
|
||
void set_sgd_report_interval(int inter);
|
||
void set_sgd_para(const sgd_para ¶m);
|
||
void show_solver();
|
||
void sgd_error_str(sgd_return_code err_code, std::ostream &ss = std::clog, bool err_throw = false);
|
||
sgd_para default_sgd_para();
|
||
|
||
#ifdef GCTL_OPTIMIZATION_TOML
|
||
void set_sgd_para(const toml::value &toml_data);
|
||
#endif // GCTL_OPTIMIZATION_TOML
|
||
|
||
sgd_return_code momentum(array<double> &m);
|
||
sgd_return_code nag(array<double> &m);
|
||
sgd_return_code adagrad(array<double> &m);
|
||
sgd_return_code rmsprop(array<double> &m);
|
||
sgd_return_code adam(array<double> &m);
|
||
sgd_return_code nadam(array<double> &m);
|
||
sgd_return_code adamax(array<double> &m);
|
||
sgd_return_code adabelief(array<double> &m);
|
||
|
||
void SGD_Minimize(array<double> &m, sgd_solver_type solver_id = ADAM, std::ostream &ss = std::clog, bool verbose = true, bool err_throw = false);
|
||
};
|
||
}
|
||
|
||
#endif // _GCTL_SGD_H
|