/* libsgd/src/lib/sgd.h */
/******************************************************//**
* C++ library of the Stochastic Gradient Descent (SGD) methods.
*
* Copyright (c) 2020-2031 Yi Zhang (zhangyiss@icloud.com)
* All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*********************************************************/
#ifndef _SGD_H
#define _SGD_H

/* Standard headers stay outside the extern "C" block: wrapping a system
 * header in a linkage specification is not portable. */
#include <stddef.h>

/* __cplusplus (two leading underscores) is the macro predefined by every C++
 * compiler. The C linkage block must only be opened when compiling as C++;
 * a C compiler rejects extern "C". */
#ifdef __cplusplus
extern "C"
{
#endif

/**
 * @brief A simple definition of the float type we use here.
 * Easy to change in the future. For now it is just an alias of the double type.
 */
typedef double sgd_float;

/**
 * @brief Types of method that could be recognized by the sgd_solver() function.
 */
typedef enum
{
    /**
     * Classic momentum.
     */
    SGD_MOMENTUM,
    /**
     * Nesterov's accelerated gradient (NAG).
     */
    SGD_NAG,
    /**
     * AdaGrad method.
     */
    SGD_ADAGRAD,
    /**
     * RMSProp method.
     */
    SGD_RMSPROP,
    /**
     * Adam method.
     */
    SGD_ADAM,
    /**
     * Nadam method.
     */
    SGD_NADAM,
    /**
     * AdaMax method.
     */
    SGD_ADAMAX,
    /**
     * AdaBelief method.
     */
    SGD_ADABELIEF /* no trailing comma: it is ill-formed in C89/C++98 */
} sgd_solver_enum;

/**
 * @brief Parameters of the SGD solvers listed in sgd_solver_enum.
 *
 * Not every field is used by every method (e.g. mu only applies to the
 * momentum and NAG methods). Use sgd_default_parameters() to obtain an
 * instance filled with default values.
 */
typedef struct
{
    /**
     * Number of iterations over the entire observation set. The default is 100.
     */
    int iteration;
    /**
     * Epsilon for convergence test. This parameter determines the accuracy
     * with which the solution is to be found. Must be bigger than zero and
     * the default is 1e-6.
     */
    sgd_float epsilon;
    /**
     * Damping rate of the classic momentum method and the NAG method, which
     * is typically given between 0 and 1. The default is 0.01.
     */
    sgd_float mu;
    /**
     * Step size of the iteration. The default value is 0.001 for Adam and 0.002
     * for AdaMax.
     */
    sgd_float alpha;
    /**
     * Exponential decay rate for the first order moment estimates. The range of
     * this parameter is [0, 1) and the default value is 0.9.
     */
    sgd_float beta_1;
    /**
     * Exponential decay rate for the second order moment estimates. The range of
     * this parameter is [0, 1) and the default value is 0.999.
     */
    sgd_float beta_2;
    /**
     * A small positive constant used by the update rules. The default value is
     * 1e-8. NOTE(review): presumably guards against division by zero in the
     * adaptive methods — confirm against the implementation.
     */
    sgd_float sigma;
} sgd_para;

/**
 * @brief Callback interface for calculating the value of objective function
 * and the corresponding model gradients.
 *
 * @param instance The user data passed to sgd_solver() by the client.
 * @param[in] x Pointer to the current solution (length n_size).
 * @param[out] g Pointer to the model gradient to be filled in (length n_size).
 * @param n_size Length of the solution.
 * @param m Index of the observation.
 *
 * @return Value of objective function.
 */
typedef sgd_float (*sgd_evaulate_ptr)(void *instance, const sgd_float *x, sgd_float *g,
    int n_size, int m);

/**
 * @brief Correctly spelled alias of sgd_evaulate_ptr.
 *
 * The misspelled name is kept so that existing client code keeps compiling.
 */
typedef sgd_evaulate_ptr sgd_evaluate_ptr;

/**
 * @brief Callback interface for monitoring the progress and terminating the
 * iteration if necessary.
 *
 * @param instance The user data passed to sgd_solver() by the client.
 * @param fx Current value of the objective function.
 * @param x Current solution.
 * @param g Current model gradients.
 * @param param User defined iteration parameters.
 * @param n_size Length of the solution array.
 * @param k Times of the iteration.
 *
 * @return int Zero to continue the optimization process. Otherwise, the optimization
 * process will be terminated.
 */
typedef int (*sgd_progress_ptr)(void *instance, sgd_float fx, const sgd_float *x, const sgd_float *g,
    const sgd_para *param, int n_size, int k);

/**
 * @brief Allocate memory for an sgd_float array.
 *
 * @param[in] n_size Number of elements in the array.
 *
 * @return Pointer to the array. Release it with sgd_free().
 */
sgd_float *sgd_malloc(int n_size);

/**
 * @brief Release memory obtained from sgd_malloc().
 *
 * @param x Pointer to the array.
 */
void sgd_free(sgd_float *x);

/**
 * @brief Return a sgd_para type instance with default values.
 *
 * @return A sgd_para type instance.
 */
sgd_para sgd_default_parameters(void); /* (void): a real prototype in C; identical in C++ */

/**
 * @brief Return a string explanation for the sgd_solver() function's return values.
 *
 * @param[in] er_index The error index returned by the sgd_solver() function.
 *
 * @return A string explanation of the error.
 */
const char* sgd_error_str(int er_index);

/**
 * @brief Minimize an objective function with the SGD-family method selected
 * by solver_id.
 *
 * @note The size of all arrays must be equal to n_size.
 * @note Default arguments only exist in C++. C callers must pass solver_id
 * explicitly (e.g. SGD_ADAM).
 *
 * @param[in] Evafp Callback function for calculating the objective function and its gradient.
 * @param[in] Profp Callback function for monitoring the optimization process.
 * @param[out] fx Returned best value of the objective function by now.
 * @param m Pointer to the solution array (length n_size). NOTE(review):
 * presumably the initial guess on entry and the solution on return — confirm.
 * @param[in] n_size Length of the solution array.
 * @param[in] m_size Length of the observation.
 * @param[in] param Parameters of optimization process.
 * @param instance The user data passed to the callbacks.
 * @param solver_id Solver type used to solve the objective. The default value
 * (C++ only) is SGD_ADAM.
 *
 * @return Status of the function; pass it to sgd_error_str() for a description.
 */
int sgd_solver(sgd_evaulate_ptr Evafp, sgd_progress_ptr Profp, sgd_float *fx, sgd_float *m,
    int n_size, int m_size, const sgd_para *param, void *instance,
#ifdef __cplusplus
    sgd_solver_enum solver_id = SGD_ADAM);
#else
    sgd_solver_enum solver_id);
#endif

#ifdef __cplusplus
}
#endif

#endif // _SGD_H