libsgd/src/lib/sgd.h
2020-10-21 15:37:13 +08:00

225 lines
6.4 KiB
C
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/******************************************************//**
* C++ library of the Stochastic Gradient Descent (SGD) methods.
*
* Copyright (c) 2020-2031 Yi Zhang (zhangyiss@icloud.com)
* All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*********************************************************/
#ifndef _SGD_H
#define _SGD_H
/* Standard definitions are included before the linkage block so that the
 * system header is never wrapped in extern "C". */
#include <stddef.h>
/* __cplusplus is the macro predefined by C++ compilers. Only C++ translation
 * units understand the extern "C" linkage specification, so it must be hidden
 * from C compilers. */
#ifdef __cplusplus
extern "C"
{
#endif
/**
 * @brief Floating-point type used by the library interface.
 *
 * Currently an alias of double. The numeric arrays and scalar options declared
 * in this header are expressed in terms of sgd_float so that the precision can
 * be changed in a single place in the future.
 */
typedef double sgd_float;
/**
 * @brief Identifiers of the optimization methods recognized by sgd_solver().
 *
 * Values are consecutive integers starting at zero.
 */
typedef enum
{
    /** Classic momentum. */
    SGD_MOMENTUM = 0,
    /** Nesterov's accelerated gradient (NAG). */
    SGD_NAG = 1,
    /** AdaGrad method. */
    SGD_ADAGRAD = 2,
    /** RMSProp method. */
    SGD_RMSPROP = 3,
    /** Adam method. */
    SGD_ADAM = 4,
    /** Nadam method. */
    SGD_NADAM = 5,
    /** AdaMax method. */
    SGD_ADAMAX = 6,
    /** AdaBelief method. */
    SGD_ADABELIEF = 7
} sgd_solver_enum;
/**
 * @brief Iteration parameters passed to sgd_solver().
 *
 * One structure serves every method selectable through sgd_solver_enum; as the
 * field descriptions note, some fields are specific to particular methods
 * (for example mu for momentum/NAG). Use sgd_default_parameters() to obtain an
 * instance filled with default values.
 */
typedef struct
{
/**
 * Number of iterations over the entire observation set. The default is 100.
 */
int iteration;
/**
 * Epsilon for the convergence test. This parameter determines the accuracy
 * with which the solution is to be found. Must be greater than zero;
 * the default is 1e-6.
 */
sgd_float epsilon;
/**
 * Damping rate of the classic momentum method and the NAG method, which
 * is typically chosen between 0 and 1. The default is 0.01.
 */
sgd_float mu;
/**
 * Step size of the iteration. The default value is 0.001 for Adam and 0.002
 * for AdaMax.
 */
sgd_float alpha;
/**
 * Exponential decay rate for the first-order moment estimates. The range of
 * this parameter is [0, 1) and the default value is 0.9.
 */
sgd_float beta_1;
/**
 * Exponential decay rate for the second-order moment estimates. The range of
 * this parameter is [0, 1) and the default value is 0.999.
 */
sgd_float beta_2;
/**
 * A small positive number that keeps the algorithm valid. The default value
 * is 1e-8.
 * NOTE(review): presumably a guard against division by zero in the update
 * rule -- confirm against the implementation.
 */
sgd_float sigma;
} sgd_para;
/**
 * @brief Callback type that computes the value of the objective function
 * and the corresponding model gradient.
 *
 * @note The type name is spelled "evaulate"; the spelling is part of the
 * public interface.
 *
 * @param instance The user data passed to sgd_solver() by the client.
 * @param x        Pointer to the current solution (not modified by the callback).
 * @param g        Pointer to the array holding the model gradient.
 * @param n_size   Length of the solution.
 * @param m        Index of the observation.
 *
 * @return Value of the objective function.
 */
typedef sgd_float (*sgd_evaulate_ptr)(void *instance, const sgd_float *x, sgd_float *g,
const int n_size, const int m);
/**
 * @brief Callback type for monitoring the progress of the optimization and
 * terminating the iteration if necessary.
 *
 * @param instance The user data passed to sgd_solver() by the client.
 * @param fx       Current value of the objective function.
 * @param x        Current solution.
 * @param g        Current model gradient.
 * @param param    User-defined iteration parameters.
 * @param n_size   Length of the solution array.
 * @param k        Iteration count.
 *
 * @return Zero to continue the optimization process. Any other value
 * terminates the optimization process.
 */
typedef int (*sgd_progress_ptr)(void *instance, sgd_float fx, const sgd_float *x, const sgd_float *g,
const sgd_para *param, const int n_size, const int k);
/**
 * @brief Allocate memory for an array of sgd_float.
 *
 * NOTE(review): presumably the returned array is to be released with
 * sgd_free(); the behavior on allocation failure is not stated here --
 * confirm against the implementation.
 *
 * @param[in] n_size Number of elements in the sgd_float array.
 *
 * @return Pointer to the allocated array.
 */
sgd_float *sgd_malloc(const int n_size);
/**
 * @brief Release the memory used by an sgd_float array.
 *
 * @param x Pointer to the array.
 */
void sgd_free(sgd_float *x);
/**
 * @brief Return an sgd_para instance filled with default values.
 *
 * The parameter list is spelled (void) so that the declaration is a real
 * prototype in C as well as in C++; an empty () list in C leaves the
 * arguments unchecked.
 *
 * @return An sgd_para instance, returned by value.
 */
sgd_para sgd_default_parameters(void);
/**
 * @brief Return a textual explanation of a status value returned by sgd_solver().
 *
 * @param[in] er_index The error index returned by the sgd_solver() function.
 *
 * @return A string describing the error.
 */
const char* sgd_error_str(int er_index);
/**
 * @brief Run a stochastic gradient descent optimization with the selected method.
 *
 * @note The size of all arrays must be equal to n_size.
 * @note The default value of solver_id is only available to C++ callers,
 * because default arguments do not exist in C. C callers must pass the
 * solver type explicitly.
 *
 * @param[in] Evafp     Callback function for calculating the objective function and its gradient.
 * @param[in] Profp     Callback function for monitoring the optimization process.
 * @param     fx        Receives the best value of the objective function found so far.
 * @param     m         Pointer to the solution array.
 * @param[in] n_size    Length of the solution array.
 * @param[in] m_size    Length of the observation.
 * @param[in] param     Parameters of the optimization process.
 * @param     instance  The user data handed to the callbacks by the client.
 * @param     solver_id Solver type used to solve the objective. In C++ the default value is SGD_ADAM.
 *
 * @return Status of the function; sgd_error_str() gives a textual explanation.
 */
int sgd_solver(sgd_evaulate_ptr Evafp, sgd_progress_ptr Profp, sgd_float *fx, sgd_float *m,
    const int n_size, const int m_size, const sgd_para *param, void *instance,
    sgd_solver_enum solver_id
#ifdef __cplusplus
    = SGD_ADAM /* default arguments are a C++-only feature */
#endif
    );
/* Close the extern "C" linkage block. The brace is emitted only for C++
 * compilers (which predefine __cplusplus); a C compiler must not see it. */
#ifdef __cplusplus
}
#endif
#endif // _SGD_H