gctl_optimization/lib/optimization/sgd.h

201 lines
6.3 KiB
C
Raw Permalink Normal View History

2024-09-10 20:04:47 +08:00
/********************************************************
*
*
*
*
*
*
* Geophysical Computational Tools & Library (GCTL)
*
* Copyright (c) 2022 Yi Zhang (yizhang-geo@zju.edu.cn)
*
* GCTL is distributed under a dual licensing scheme. You can redistribute
* it and/or modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation, either version 2
* of the License, or (at your option) any later version. You should have
* received a copy of the GNU Lesser General Public License along with this
* program. If not, see <http://www.gnu.org/licenses/>.
*
* If the terms and conditions of the LGPL v.2. would prevent you from using
* the GCTL, please consider the option to obtain a commercial license for a
* fee. These licenses are offered by the GCTL's original author. As a rule,
* licenses are provided "as-is", unlimited in time for a one time fee. Please
* send corresponding requests to: yizhang-geo@zju.edu.cn. Please do not forget
* to include some description of your company and the realm of its activities.
* Also add information on how to contact you by electronic and paper mail.
******************************************************/
#ifndef _GCTL_SGD_H
#define _GCTL_SGD_H
#include "gctl/core.h"
#include "gctl/algorithm.h"
#include "gctl_optimization_config.h"
#ifdef GCTL_OPTIMIZATION_TOML
#include "toml.hpp"
#endif // GCTL_OPTIMIZATION_TOML
#if defined _WINDOWS || __WIN32__
#include "windows.h"
#endif // _WINDOWS || __WIN32__
#ifdef GSTL_OPENMP
#include "omp.h"
#endif // GSTL_OPENMP
namespace gctl
{
/**
 * @brief Identifiers of the optimization methods offered by sgd_solver.
 *
 * A value of this type is passed as the solver_id argument of
 * sgd_solver::SGD_Minimize(). The enumerators are numbered consecutively
 * from zero; the values are written out so the numbering is visible at a
 * glance.
 */
enum sgd_solver_type
{
    MOMENTUM  = 0, ///< Classic momentum.
    NAG       = 1, ///< Nesterov's accelerated gradient (NAG).
    ADAGRAD   = 2, ///< AdaGrad method.
    RMSPROP   = 3, ///< RMSProp method.
    ADAM      = 4, ///< Adam method.
    NADAM     = 5, ///< Nadam method.
    ADAMAX    = 6, ///< AdaMax method.
    ADABELIEF = 7, ///< AdaBelief method.
};
/**
 * @brief Status codes returned by the solver methods of sgd_solver.
 *
 * Non-error states are zero or positive. Error states are negative and are
 * numbered upwards from SGD_UNKNOWN_ERROR (-1024).
 */
enum sgd_return_code
{
    // Non-error states.
    SGD_SUCCESS                = 0,     ///< The optimization terminated successfully.
    SGD_CONVERGENCE            = 1,     ///< The optimization reached convergence.
    SGD_STOP                   = 2,     ///< The process stopped by the monitoring function.

    // Error states.
    SGD_UNKNOWN_ERROR          = -1024, ///< Unknown error.
    SGD_INVALID_VARIABLE_SIZE  = -1023, ///< The variable size is negative.
    SGD_INVALID_EPSILON        = -1022, ///< The epsilon is negative.
    SGD_REACHED_MAX_ITERATIONS = -1021, ///< Iteration reached max limit.
    SGD_INVALID_MU             = -1020, ///< Invalid value for mu.
    SGD_INVALID_ALPHA          = -1019, ///< Invalid value for alpha.
    SGD_INVALID_BETA           = -1018, ///< Invalid value for beta.
    SGD_INVALID_SIGMA          = -1017, ///< Invalid value for sigma.
    SGD_NAN_VALUE              = -1016, ///< NaN value.
};
/**
 * @brief Tunable parameters shared by the SGD methods.
 *
 * This is a plain aggregate: the members have no in-class initializers, so an
 * instance that is default-initialized (e.g. `sgd_para p;`) holds indeterminate
 * values until it is filled in. The "default" figures quoted below are the
 * documented defaults of the library, not values enforced by this struct.
 */
struct sgd_para
{
    /// Maximal number of iterations. When this is zero or negative the
    /// iteration only stops once convergence is reached. Documented default: 0.
    int iteration;

    /// Threshold of the convergence test; it controls how accurately the
    /// solution is located. Must be greater than zero. Documented default: 1e-6.
    double epsilon;

    /// Damping rate used by the classic momentum and NAG methods, typically
    /// chosen between 0 and 1. Documented default: 0.01.
    double mu;

    /// Step size of the iteration. Documented default: 0.01 for Adam and AdaMax.
    double alpha;

    /// Exponential decay rate of the first order moment estimates.
    /// Valid range is [0, 1). Documented default: 0.9.
    double beta_1;

    /// Exponential decay rate of the second order moment estimates.
    /// Valid range is [0, 1). Documented default: 0.999.
    double beta_2;

    /// Small positive number that keeps the algorithm valid.
    /// Documented default: 1e-8.
    double sigma;
};
/**
 * @brief Abstract base class of the SGD-family solvers.
 *
 * SGD_Evaluate() is pure virtual, so the class is used by deriving from it and
 * implementing that callback. Every member function is only declared in this
 * header; the definitions live elsewhere.
 */
class sgd_solver
{
private:
// Parameter set held by the solver.
// NOTE(review): presumably the values installed through set_sgd_para() - confirm in the implementation.
sgd_para sgd_param_;
// NOTE(review): presumably the report interval given to set_sgd_report_interval() - confirm.
int sgd_inter_;
// NOTE(review): presumably the flag raised by sgd_silent() - confirm.
bool sgd_silent_;
// NOTE(review): presumably the display name of the selected method - confirm.
std::string solver_name_;
public:
sgd_solver();
// Virtual so that derived solvers can be destroyed through a base pointer.
virtual ~sgd_solver();
/**
 * @brief User-supplied evaluation callback (pure virtual).
 *
 * @param x Read-only array.
 * @param g Writable array. NOTE(review): presumably receives the gradient
 *          evaluated at x - confirm against the implementation.
 * @return A double. NOTE(review): presumably the objective value at x - confirm.
 */
virtual double SGD_Evaluate(const array<double> &x, array<double> &g) = 0;
/**
 * @brief Overridable progress callback; a base implementation is declared
 *        here and defined elsewhere.
 *
 * @param fx    A double. NOTE(review): presumably the current objective value - confirm.
 * @param x     Read-only array. NOTE(review): presumably the current solution - confirm.
 * @param param Read-only parameter set.
 * @param k     An int. NOTE(review): presumably the iteration counter - confirm.
 * @return An int. NOTE(review): how the value is interpreted (e.g. whether a
 *         non-zero value leads to SGD_STOP) is not visible here - confirm.
 */
virtual int SGD_Progress(double fx, const array<double> &x, const sgd_para &param, const int k);
// NOTE(review): presumably suppresses the solver's reporting output - confirm.
void sgd_silent();
// Takes the report interval as an int. NOTE(review): units and meaning of non-positive values are not visible here.
void set_sgd_report_interval(int inter);
// Takes a parameter set by const reference. NOTE(review): any validation performed is not visible here.
void set_sgd_para(const sgd_para &param);
// NOTE(review): presumably prints information about the solver - confirm destination and content.
void show_solver();
/**
 * @brief Handles a return code; by default works on std::clog and does not
 *        request throwing (err_throw = false).
 *
 * @param err_code  The code to handle.
 * @param ss        Output stream, std::clog by default.
 * @param err_throw NOTE(review): presumably turns error codes into exceptions
 *                  when true - confirm exception type in the implementation.
 */
void sgd_error_str(sgd_return_code err_code, std::ostream &ss = std::clog, bool err_throw = false);
// Returns an sgd_para by value. NOTE(review): presumably the library's default parameter set - confirm.
sgd_para default_sgd_para();
#ifdef GCTL_OPTIMIZATION_TOML
// Overload taking a toml::value; only declared when GCTL_OPTIMIZATION_TOML is defined.
// NOTE(review): the expected TOML keys are not visible in this header.
void set_sgd_para(const toml::value &toml_data);
#endif // GCTL_OPTIMIZATION_TOML
// One entry point per method; each takes the array m by non-const reference
// and returns an sgd_return_code. The names mirror the sgd_solver_type enumerators.
// NOTE(review): presumably m holds the start model on entry and the solution on exit - confirm.
sgd_return_code momentum(array<double> &m);
sgd_return_code nag(array<double> &m);
sgd_return_code adagrad(array<double> &m);
sgd_return_code rmsprop(array<double> &m);
sgd_return_code adam(array<double> &m);
sgd_return_code nadam(array<double> &m);
sgd_return_code adamax(array<double> &m);
sgd_return_code adabelief(array<double> &m);
/**
 * @brief General entry point taking the method as an argument.
 *
 * @param m         Array passed by non-const reference.
 * @param solver_id Method selector, ADAM by default.
 * @param ss        Output stream, std::clog by default.
 * @param verbose   True by default. NOTE(review): exact effect on reporting is not visible here.
 * @param err_throw False by default. NOTE(review): presumably forwarded to the error handling - confirm.
 */
void SGD_Minimize(array<double> &m, sgd_solver_type solver_id = ADAM, std::ostream &ss = std::clog, bool verbose = true, bool err_throw = false);
};
}
#endif // _GCTL_SGD_H