gctl_optimization/lib/optimization/sgd.h
2025-02-22 18:10:15 +08:00

202 lines
6.4 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/********************************************************
* ██████╗ ██████╗████████╗██╗
* ██╔════╝ ██╔════╝╚══██╔══╝██║
* ██║ ███╗██║ ██║ ██║
* ██║ ██║██║ ██║ ██║
* ╚██████╔╝╚██████╗ ██║ ███████╗
* ╚═════╝ ╚═════╝ ╚═╝ ╚══════╝
* Geophysical Computational Tools & Library (GCTL)
*
* Copyright (c) 2022 Yi Zhang (yizhang-geo@zju.edu.cn)
*
* GCTL is distributed under a dual licensing scheme. You can redistribute
* it and/or modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation, either version 2
* of the License, or (at your option) any later version. You should have
* received a copy of the GNU Lesser General Public License along with this
* program. If not, see <http://www.gnu.org/licenses/>.
*
* If the terms and conditions of the LGPL v.2. would prevent you from using
* the GCTL, please consider the option to obtain a commercial license for a
* fee. These licenses are offered by the GCTL's original author. As a rule,
* licenses are provided "as-is", unlimited in time for a one time fee. Please
* send corresponding requests to: yizhang-geo@zju.edu.cn. Please do not forget
* to include some description of your company and the realm of its activities.
* Also add information on how to contact you by electronic and paper mail.
******************************************************/
#ifndef _GCTL_SGD_H
#define _GCTL_SGD_H
#include "gctl/utility.h"
#include "gctl/core.h"
#include "gctl/algorithms.h"
#include "gctl_optimization_config.h"
#ifdef GCTL_OPTIMIZATION_TOML
#include "toml.hpp"
#endif // GCTL_OPTIMIZATION_TOML
#if defined _WINDOWS || __WIN32__
#include "windows.h"
#endif // _WINDOWS || __WIN32__
#ifdef GSTL_OPENMP
#include "omp.h"
#endif // GSTL_OPENMP
namespace gctl
{
/**
* @brief Identifiers of the optimization methods that can be selected through
* the solver_id argument of sgd_solver::SGD_Minimize().
*
* The enumerator values are written out explicitly; they are exactly the values
* the compiler assigns implicitly (0 to 7 in declaration order).
*/
enum sgd_solver_type
{
    MOMENTUM  = 0, ///< Classic momentum.
    NAG       = 1, ///< Nesterov's accelerated gradient (NAG).
    ADAGRAD   = 2, ///< AdaGrad method.
    RMSPROP   = 3, ///< RMSProp method.
    ADAM      = 4, ///< Adam method.
    NADAM     = 5, ///< Nadam method.
    ADAMAX    = 6, ///< AdaMax method.
    ADABELIEF = 7, ///< AdaBelief method.
};
/**
* @brief Status codes returned by the solver member functions of sgd_solver
* (momentum(), nag(), adam(), ...) and accepted by sgd_solver::sgd_error_str().
*
* Non-negative values are non-error outcomes; error codes start at -1024 and
* count upwards. All values are written out explicitly and equal the values
* the compiler would otherwise assign implicitly.
*/
enum sgd_return_code
{
    // Non-error outcomes.
    SGD_SUCCESS     = 0, ///< The optimization terminated successfully.
    SGD_CONVERGENCE = 1, ///< The optimization reached convergence.
    SGD_STOP        = 2, ///< The process was stopped by the monitoring function.

    // Error codes.
    SGD_UNKNOWN_ERROR          = -1024, ///< Unknown error.
    SGD_INVALID_VARIABLE_SIZE  = -1023, ///< The variable size is negative.
    SGD_INVALID_EPSILON        = -1022, ///< The epsilon is negative.
    SGD_REACHED_MAX_ITERATIONS = -1021, ///< The iteration count reached its maximal limit.
    SGD_INVALID_MU             = -1020, ///< Invalid value for mu.
    SGD_INVALID_ALPHA          = -1019, ///< Invalid value for alpha.
    SGD_INVALID_BETA           = -1018, ///< Invalid value for beta.
    SGD_INVALID_SIGMA          = -1017, ///< Invalid value for sigma.
    SGD_NAN_VALUE              = -1016, ///< A NaN value was encountered.
};
/**
* @brief Parameter set shared by the SGD methods.
*
* Plain aggregate: the members below can be brace-initialized in declaration
* order (iteration, epsilon, mu, alpha, beta_1, beta_2, sigma).
*/
struct sgd_para
{
    /// Maximal number of iterations. A value equal to or smaller than zero
    /// removes the limit, so the iteration only stops once convergence is
    /// reached. The default is 0.
    int iteration;

    /// Epsilon of the convergence test; it determines the accuracy with which
    /// the solution is to be found. Must be bigger than zero. The default is 1e-6.
    double epsilon;

    /// Damping rate used by the classic momentum method and the NAG method,
    /// typically given between 0 and 1. The default is 0.01.
    double mu;

    /// Step size of the iteration. The default value is 0.01 for Adam and AdaMax.
    double alpha;

    /// Exponential decay rate of the first order moment estimates.
    /// Valid range is [0, 1). The default value is 0.9.
    double beta_1;

    /// Exponential decay rate of the second order moment estimates.
    /// Valid range is [0, 1). The default value is 0.999.
    double beta_2;

    /// A small positive number required for the algorithm to be valid.
    /// The default value is 1e-8.
    double sigma;
};
/**
* @brief Abstract base class of the SGD-type solvers.
*
* The class declares the pure virtual function SGD_Evaluate(), so it cannot be
* instantiated directly: a derived class has to implement SGD_Evaluate(). All
* member functions are only declared here; their definitions are not part of
* this header.
*/
class sgd_solver
{
private:
// Parameter set of the solver (see sgd_para).
sgd_para sgd_param_;
// NOTE(review): presumably the report interval given to set_sgd_report_interval() -- confirm in the implementation.
int sgd_inter_;
// NOTE(review): presumably the flag switched by sgd_silent() -- confirm in the implementation.
bool sgd_silent_;
// Name string of the solver; how it is filled and used is not visible in this header.
std::string solver_name_;
public:
sgd_solver();
virtual ~sgd_solver();
/**
* @brief Pure virtual evaluation callback that must be implemented by a derived class.
*
* NOTE(review): judging by name and signature this presumably returns the
* objective value at x and writes the corresponding gradient to g -- confirm
* against the implementation.
*
* @param x Input array (read-only).
* @param g Array passed by non-const reference, i.e. writable by the callback.
* @return A double value.
*/
virtual double SGD_Evaluate(const array<double> &x, array<double> &g) = 0;
/**
* @brief Virtual callback that a derived class may override; it is not pure
* virtual, so a definition is expected outside this header.
*
* NOTE(review): presumably the monitoring function mentioned by SGD_STOP,
* called during the iterations; the meaning of the returned int is not visible
* here -- confirm against the implementation.
*
* @param fx A double value (presumably the current objective value).
* @param x Input array (read-only).
* @param param Parameter set (read-only).
* @param k An integer (presumably the current iteration index).
* @return An int value.
*/
virtual int SGD_Progress(double fx, const array<double> &x, const sgd_para &param, const int k);
// NOTE(review): presumably suppresses the solver's output -- confirm in the implementation.
void sgd_silent();
// NOTE(review): presumably sets the interval (in iterations) between progress reports -- confirm.
void set_sgd_report_interval(int inter);
// Takes a parameter set by const reference; presumably stores it as the solver's parameters -- confirm.
void set_sgd_para(const sgd_para &param);
// NOTE(review): presumably displays information about the solver -- confirm in the implementation.
void show_solver();
/**
* @brief Error reporting helper for a return code.
*
* NOTE(review): presumably writes a description of err_code to ss and throws
* instead when err_throw is true -- confirm against the implementation.
*
* @param err_code Return code to report.
* @param ss Output stream; defaults to std::clog.
* @param err_throw Defaults to false.
*/
void sgd_error_str(sgd_return_code err_code, std::ostream &ss = std::clog, bool err_throw = false);
// Returns a sgd_para object by value; presumably filled with the default parameters -- confirm.
sgd_para default_sgd_para();
#ifdef GCTL_OPTIMIZATION_TOML
// Overload taking a toml::value; only declared when GCTL_OPTIMIZATION_TOML is defined.
void set_sgd_para(const toml::value &toml_data);
#endif // GCTL_OPTIMIZATION_TOML
/*
* One member function per sgd_solver_type enumerator. Each takes the array m
* by non-const reference and returns a sgd_return_code.
* NOTE(review): presumably m holds the initial solution on entry and the
* optimized solution on return -- confirm against the implementation.
*/
sgd_return_code momentum(array<double> &m);
sgd_return_code nag(array<double> &m);
sgd_return_code adagrad(array<double> &m);
sgd_return_code rmsprop(array<double> &m);
sgd_return_code adam(array<double> &m);
sgd_return_code nadam(array<double> &m);
sgd_return_code adamax(array<double> &m);
sgd_return_code adabelief(array<double> &m);
/**
* @brief Entry point taking the method identifier as an argument.
*
* NOTE(review): presumably dispatches to the member function matching
* solver_id and reports the outcome through sgd_error_str() -- confirm
* against the implementation.
*
* @param m Array passed by non-const reference.
* @param solver_id Method identifier; defaults to ADAM.
* @param ss Output stream; defaults to std::clog.
* @param verbose Defaults to true.
* @param err_throw Defaults to false.
*/
void SGD_Minimize(array<double> &m, sgd_solver_type solver_id = ADAM, std::ostream &ss = std::clog, bool verbose = true, bool err_throw = false);
};
};
#endif // _GCTL_SGD_H