2020-10-19 12:50:18 +08:00
|
|
|
|
/******************************************************//**
|
|
|
|
|
* C++ library of the Stochastic Gradient Descent (SGD) methods.
|
|
|
|
|
*
|
|
|
|
|
* Copyright (c) 2020-2031 Yi Zhang (zhangyiss@icloud.com)
|
|
|
|
|
* All rights reserved.
|
|
|
|
|
*
|
|
|
|
|
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
|
|
|
* of this software and associated documentation files (the "Software"), to deal
|
|
|
|
|
* in the Software without restriction, including without limitation the rights
|
|
|
|
|
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
|
|
|
* copies of the Software, and to permit persons to whom the Software is
|
|
|
|
|
* furnished to do so, subject to the following conditions:
|
|
|
|
|
*
|
|
|
|
|
* The above copyright notice and this permission notice shall be included in
|
|
|
|
|
* all copies or substantial portions of the Software.
|
|
|
|
|
*
|
|
|
|
|
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
|
|
|
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
|
|
|
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
|
|
|
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
|
|
|
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
|
|
|
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
|
|
|
|
* THE SOFTWARE.
|
|
|
|
|
*********************************************************/
|
|
|
|
|
|
|
|
|
|
#ifndef SGD_H
#define SGD_H

/* Standard headers are included before the extern "C" block so that they
 * keep their own linkage declarations. */
#include <stddef.h>

/* __cplusplus (two leading underscores) is the macro predefined by every C++
 * compiler. The linkage block must only be emitted for C++: `extern "C"` is a
 * syntax error in plain C. */
#ifdef __cplusplus
extern "C" {
#endif

/**
 * @brief A simple definition of the float type used by the library.
 * Easy to change in the future. For now it is just an alias of the double type.
 */
typedef double sgd_float;

/**
 * @brief Types of method that could be recognized by the sgd_solver() function.
 */
typedef enum
{
    /**
     * Classic momentum.
     */
    SGD_MOMENTUM,

    /**
     * Nesterov's accelerated gradient (NAG).
     */
    SGD_NAG,

    /**
     * AdaGrad method.
     */
    SGD_ADAGRAD,

    /**
     * RMSProp method.
     */
    SGD_RMSPROP,

    /**
     * Adam method.
     */
    SGD_ADAM,

    /**
     * Nadam method.
     */
    SGD_NADAM,

    /**
     * AdaMax method.
     */
    SGD_ADAMAX,

    /**
     * AdaBelief method.
     */
    SGD_ADABELIEF,
} sgd_solver_enum;

/**
 * @brief Parameters of the optimization process, shared by all solvers
 * listed in sgd_solver_enum. Each field documents which methods use it.
 */
typedef struct
{
    /**
     * Iteration times for the entire observation set. The default is 100.
     */
    int iteration;

    /**
     * Epsilon for convergence test. This parameter determines the accuracy
     * with which the solution is to be found. Must be bigger than zero and
     * the default is 1e-6.
     */
    sgd_float epsilon;

    /**
     * Damping rate of the classic momentum method and the NAG method, which
     * is typically given between 0 and 1. The default is 0.01.
     */
    sgd_float mu;

    /**
     * Step size of the iteration. The default value is 0.001 for Adam and 0.002
     * for AdaMax.
     */
    sgd_float alpha;

    /**
     * Exponential decay rate for the first order moment estimates. The range of
     * this parameter is [0, 1) and the default value is 0.9.
     */
    sgd_float beta_1;

    /**
     * Exponential decay rate for the second order moment estimates. The range of
     * this parameter is [0, 1) and the default value is 0.999.
     */
    sgd_float beta_2;

    /**
     * A small positive number that keeps the algorithm numerically valid
     * (presumably guards against division by zero). The default value is 1e-8.
     */
    sgd_float sigma;
} sgd_para;

/**
 * @brief Callback interface for calculating the value of the objective function
 * and the corresponding model gradients.
 *
 * @param instance The user data passed to sgd_solver() by the client.
 * @param x        Pointer to the current solution.
 * @param g        Pointer to the model gradient to be filled in.
 * @param n_size   Length of the solution (and gradient) array.
 * @param m        Index of the observation.
 *
 * @return Value of the objective function.
 *
 * @note The identifier keeps its historical spelling ("evaulate") for
 *       source compatibility with existing callers.
 */
typedef sgd_float (*sgd_evaulate_ptr)(void *instance, const sgd_float *x, sgd_float *g,
                                      const int n_size, const int m);

/**
 * @brief Callback interface for monitoring the progress and terminating the
 * iteration if necessary.
 *
 * @param instance The user data passed to sgd_solver() by the client.
 * @param fx       Current value of the objective function.
 * @param x        Current solution.
 * @param g        Current model gradients.
 * @param param    User defined iteration parameters.
 * @param n_size   Length of the solution array.
 * @param k        Times of the iteration.
 *
 * @return Zero to continue the optimization process. Any other value
 *         terminates the optimization process.
 */
typedef int (*sgd_progress_ptr)(void *instance, sgd_float fx, const sgd_float *x, const sgd_float *g,
                                const sgd_para *param, const int n_size, const int k);

/**
 * @brief Allocate memory for a sgd_float array.
 *
 * @param[in] n_size Number of elements in the sgd_float array.
 *
 * @return Pointer to the allocated data.
 */
sgd_float *sgd_malloc(const int n_size);

/**
 * @brief Free the memory used by a sgd_float array
 * (presumably one obtained from sgd_malloc()).
 *
 * @param x Pointer to the array.
 */
void sgd_free(sgd_float *x);

/**
 * @brief Return a sgd_para instance filled with default values.
 *
 * @return A sgd_para instance.
 */
sgd_para sgd_default_parameters(void);

/**
 * @brief Return a string explanation of a sgd_solver() return value.
 *
 * @param[in] er_index The status value returned by the sgd_solver() function.
 *
 * @return A string explanation of the status.
 */
const char *sgd_error_str(int er_index);

/**
 * @brief Minimize an objective function with one of the methods listed in
 * sgd_solver_enum.
 *
 * @note The size of all solution-sized arrays must be equal to n_size.
 *
 * @param[in]  Evafp     Callback for calculating the objective function and its gradient.
 * @param[in]  Profp     Callback for monitoring the optimization process.
 * @param[out] fx        Returned best value of the objective function found so far.
 * @param      m         Pointer to the solution array.
 * @param[in]  n_size    Length of the solution array.
 * @param[in]  m_size    Length of the observation.
 * @param[in]  param     Parameters of the optimization process.
 * @param      instance  The user data passed through to the callbacks.
 * @param[in]  solver_id Solver used to minimize the objective. In C++ it defaults
 *                       to SGD_ADAM; C callers must pass it explicitly because C
 *                       has no default arguments.
 *
 * @return Status of the function; use sgd_error_str() for a description.
 */
int sgd_solver(sgd_evaulate_ptr Evafp, sgd_progress_ptr Profp, sgd_float *fx, sgd_float *m,
               const int n_size, const int m_size, const sgd_para *param, void *instance,
               sgd_solver_enum solver_id
#ifdef __cplusplus
               = SGD_ADAM /* default arguments are a C++-only feature */
#endif
               );

#ifdef __cplusplus
}
#endif

#endif /* SGD_H */
|