/********************************************************
* ██████╗ ██████╗████████╗██╗
* ██╔════╝ ██╔════╝╚══██╔══╝██║
* ██║ ███╗██║ ██║ ██║
* ██║ ██║██║ ██║ ██║
* ╚██████╔╝╚██████╗ ██║ ███████╗
* ╚═════╝ ╚═════╝ ╚═╝ ╚══════╝
* Geophysical Computational Tools & Library (GCTL)
*
* Copyright (c) 2022 Yi Zhang (yizhang-geo@zju.edu.cn)
*
* GCTL is distributed under a dual licensing scheme. You can redistribute
* it and/or modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation, either version 2
* of the License, or (at your option) any later version. You should have
* received a copy of the GNU Lesser General Public License along with this
* program. If not, see <http://www.gnu.org/licenses/>.
*
* If the terms and conditions of the LGPL v.2. would prevent you from using
* the GCTL, please consider the option to obtain a commercial license for a
* fee. These licenses are offered by the GCTL's original author. As a rule,
* licenses are provided "as-is", unlimited in time for a one time fee. Please
* send corresponding requests to: yizhang-geo@zju.edu.cn. Please do not forget
* to include some description of your company and the realm of its activities.
* Also add information on how to contact you by electronic and paper mail.
******************************************************/

#ifndef _GCTL_SGD_H
#define _GCTL_SGD_H

#include <iostream>
#include <string>

#include "gctl/core.h"
#include "gctl/algorithm.h"

#include "gctl_optimization_config.h"
#ifdef GCTL_OPTIMIZATION_TOML
#include "toml.hpp"
#endif // GCTL_OPTIMIZATION_TOML

#if defined _WINDOWS || __WIN32__
#include "windows.h"
#endif // _WINDOWS || __WIN32__

#ifdef GSTL_OPENMP
#include "omp.h"
#endif // GSTL_OPENMP

namespace gctl
|
|||
|
{
/**
 * @brief Optimization methods that can be selected through the
 * sgd_solver::SGD_Minimize() method (its `solver_id` argument).
 *
 * The enumerator values are implicit (MOMENTUM == 0 ... ADABELIEF == 7).
 * Append new methods at the end so existing values stay stable.
 */
enum sgd_solver_type
{
    /** Classic momentum. */
    MOMENTUM,

    /** Nesterov's accelerated gradient (NAG). */
    NAG,

    /** AdaGrad method. */
    ADAGRAD,

    /** RMSProp method. */
    RMSPROP,

    /** Adam method. */
    ADAM,

    /** Nadam method. */
    NADAM,

    /** AdaMax method. */
    ADAMAX,

    /** AdaBelief method. */
    ADABELIEF,
};

/**
 * @brief Status codes returned by the method-specific solver functions of
 * sgd_solver (momentum(), nag(), adagrad(), rmsprop(), adam(), nadam(),
 * adamax() and adabelief()).
 *
 * Values 0..2 report a normal termination. Values from -1024 upward report
 * an error or abnormal termination, such as reaching the iteration limit.
 * The numeric values are part of the interface; do not reorder enumerators.
 */
enum sgd_return_code
{
    SGD_SUCCESS = 0,            ///< The optimization terminated successfully.
    SGD_CONVERGENCE = 1,        ///< The optimization reached convergence.
    SGD_STOP,                   ///< The process stopped by the monitoring function (value 2).
    SGD_UNKNOWN_ERROR = -1024,  ///< Unknown error.
    SGD_INVALID_VARIABLE_SIZE,  ///< The variable size is negative.
    SGD_INVALID_EPSILON,        ///< The epsilon is negative.
    SGD_REACHED_MAX_ITERATIONS, ///< Iteration reached max limit.
    SGD_INVALID_MU,             ///< Invalid value for mu.
    SGD_INVALID_ALPHA,          ///< Invalid value for alpha.
    SGD_INVALID_BETA,           ///< Invalid value for beta.
    SGD_INVALID_SIGMA,          ///< Invalid value for sigma.
    SGD_NAN_VALUE,              ///< NaN value.
};

/**
 * @brief Parameters of the SGD methods.
 *
 * This is a plain aggregate. None of its members has an in-class initializer,
 * so a default-initialized sgd_para holds indeterminate values; set every
 * field explicitly, e.g. by aggregate initialization in declaration order.
 * NOTE(review): sgd_solver::default_sgd_para() presumably returns the
 * defaults listed below; confirm against its definition.
 */
struct sgd_para
{
    /**
     * Maximal number of iterations. If this value is zero or negative, the
     * iteration only stops once convergence is reached. The default is 0.
     */
    int iteration;

    /**
     * Epsilon for the convergence test. It determines the accuracy with
     * which the solution is to be found. Must be bigger than zero; the
     * default is 1e-6.
     */
    double epsilon;

    /**
     * Damping rate of the classic momentum method and the NAG method,
     * typically given between 0 and 1. The default is 0.01.
     */
    double mu;

    /**
     * Step size of the iteration. The default value is 0.01 for Adam and AdaMax.
     */
    double alpha;

    /**
     * Exponential decay rate for the first-order moment estimates. The valid
     * range is [0, 1) and the default value is 0.9.
     */
    double beta_1;

    /**
     * Exponential decay rate for the second-order moment estimates. The valid
     * range is [0, 1) and the default value is 0.999.
     */
    double beta_2;

    /**
     * A small positive number that keeps the algorithm numerically valid.
     * NOTE(review): presumably the offset added to denominators of the
     * adaptive methods; confirm. The default value is 1e-8.
     */
    double sigma;
};

class sgd_solver
|
|||
|
{
|
|||
|
private:
|
|||
|
sgd_para sgd_param_;
|
|||
|
int sgd_inter_;
|
|||
|
bool sgd_silent_;
|
|||
|
std::string solver_name_;
|
|||
|
|
|||
|
public:
|
|||
|
sgd_solver();
|
|||
|
virtual ~sgd_solver();
|
|||
|
|
|||
|
virtual double SGD_Evaluate(const array<double> &x, array<double> &g) = 0;
|
|||
|
virtual int SGD_Progress(double fx, const array<double> &x, const sgd_para ¶m, const int k);
|
|||
|
|
|||
|
void sgd_silent();
|
|||
|
void set_sgd_report_interval(int inter);
|
|||
|
void set_sgd_para(const sgd_para ¶m);
|
|||
|
void show_solver();
|
|||
|
void sgd_error_str(sgd_return_code err_code, std::ostream &ss = std::clog, bool err_throw = false);
|
|||
|
sgd_para default_sgd_para();
|
|||
|
|
|||
|
#ifdef GCTL_OPTIMIZATION_TOML
|
|||
|
void set_sgd_para(const toml::value &toml_data);
|
|||
|
#endif // GCTL_OPTIMIZATION_TOML
|
|||
|
|
|||
|
sgd_return_code momentum(array<double> &m);
|
|||
|
sgd_return_code nag(array<double> &m);
|
|||
|
sgd_return_code adagrad(array<double> &m);
|
|||
|
sgd_return_code rmsprop(array<double> &m);
|
|||
|
sgd_return_code adam(array<double> &m);
|
|||
|
sgd_return_code nadam(array<double> &m);
|
|||
|
sgd_return_code adamax(array<double> &m);
|
|||
|
sgd_return_code adabelief(array<double> &m);
|
|||
|
|
|||
|
void SGD_Minimize(array<double> &m, sgd_solver_type solver_id = ADAM, std::ostream &ss = std::clog, bool verbose = true, bool err_throw = false);
|
|||
|
};
|
|||
|
}
|
|||
|
|
|||
|
#endif // _GCTL_SGD_H
|