update on docs
This commit is contained in:
parent
86e539756b
commit
ddec148579
@ -25,34 +25,25 @@
|
|||||||
*/
|
*/
|
||||||
enum sgd_return_e
|
enum sgd_return_e
|
||||||
{
|
{
|
||||||
SGD_SUCCESS = 0,
|
SGD_SUCCESS = 0, ///< The optimization terminated successfully.
|
||||||
SGD_CONVERGENCE = 1,
|
SGD_CONVERGENCE = 1, ///< The optimization reached convergence.
|
||||||
SGD_STOP, //2
|
SGD_STOP, ///< The process stopped by the monitoring function.
|
||||||
SGD_UNKNOWN_ERROR = -1024,
|
SGD_UNKNOWN_ERROR = -1024, ///< Unknown error.
|
||||||
// The variable size is negative
|
SGD_INVALID_VARIABLE_SIZE, ///< The variable size is negative
|
||||||
SGD_INVILAD_VARIABLE_SIZE, //-1023
|
SGD_INVALID_MAX_ITERATIONS, ///< The maximal iteration times is negative.
|
||||||
// The maximal iteration times is negative.
|
SGD_INVALID_EPSILON, ///< The epsilon is negative.
|
||||||
SGD_INVILAD_MAX_ITERATIONS, //-1022
|
SGD_REACHED_MAX_ITERATIONS, ///< Iteration reached max limit.
|
||||||
// The epsilon is negative.
|
SGD_INVALID_MU, ///< Invalid value for mu.
|
||||||
SGD_INVILAD_EPSILON, //-1021
|
SGD_INVALID_ALPHA, ///< Invalid value for alpha.
|
||||||
// Iteration reached max limit
|
SGD_INVALID_BETA, ///< Invalid value for beta.
|
||||||
SGD_REACHED_MAX_ITERATIONS,
|
SGD_INVALID_SIGMA, ///< Invalid value for sigma.
|
||||||
// Invalid value for mu
|
SGD_NAN_VALUE, ///< Nan value.
|
||||||
SGD_INVALID_MU,
|
|
||||||
// Invalid value for alpha
|
|
||||||
SGD_INVALID_ALPHA,
|
|
||||||
// Invalid value for beta
|
|
||||||
SGD_INVALID_BETA,
|
|
||||||
// Invalid value for sigma
|
|
||||||
SGD_INVALID_SIGMA,
|
|
||||||
// Nan value
|
|
||||||
SGD_NAN_VALUE,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Default parameter for the SGD methods.
|
* Default parameter for the SGD methods.
|
||||||
*/
|
*/
|
||||||
static const sgd_para defparam = {100, 1e-6, 0.01, 0.001, 0.9, 0.999, 1e-8};
|
static const sgd_para defparam = {300, 1e-6, 0.01, 0.001, 0.9, 0.999, 1e-8};
|
||||||
|
|
||||||
sgd_float *sgd_malloc(const int n_size)
|
sgd_float *sgd_malloc(const int n_size)
|
||||||
{
|
{
|
||||||
@ -85,13 +76,13 @@ const char* sgd_error_str(int er_index)
|
|||||||
return "The iteration stopped by the progress evaluation function.";
|
return "The iteration stopped by the progress evaluation function.";
|
||||||
case SGD_UNKNOWN_ERROR:
|
case SGD_UNKNOWN_ERROR:
|
||||||
return "Unknown error.";
|
return "Unknown error.";
|
||||||
case SGD_INVILAD_VARIABLE_SIZE:
|
case SGD_INVALID_VARIABLE_SIZE:
|
||||||
return "Invalid array size.";
|
return "Invalid array size.";
|
||||||
case SGD_INVILAD_MAX_ITERATIONS:
|
case SGD_INVALID_MAX_ITERATIONS:
|
||||||
return "Invalid maximal iteration times.";
|
return "Invalid maximal iteration times.";
|
||||||
case SGD_REACHED_MAX_ITERATIONS:
|
case SGD_REACHED_MAX_ITERATIONS:
|
||||||
return "The maximal iteration is reached.";
|
return "The maximal iteration is reached.";
|
||||||
case SGD_INVILAD_EPSILON:
|
case SGD_INVALID_EPSILON:
|
||||||
return "Invalid value for epsilon.";
|
return "Invalid value for epsilon.";
|
||||||
case SGD_INVALID_BETA:
|
case SGD_INVALID_BETA:
|
||||||
return "Invalid value for beta.";
|
return "Invalid value for beta.";
|
||||||
@ -190,9 +181,9 @@ int momentum(sgd_evaulate_ptr Evafp, sgd_progress_ptr Profp, sgd_float *fx, sgd_
|
|||||||
sgd_para para = (param != nullptr) ? (*param) : defparam;
|
sgd_para para = (param != nullptr) ? (*param) : defparam;
|
||||||
|
|
||||||
//check parameters
|
//check parameters
|
||||||
if (n_size <= 0) return SGD_INVILAD_VARIABLE_SIZE;
|
if (n_size <= 0) return SGD_INVALID_VARIABLE_SIZE;
|
||||||
if (para.iteration <= 0) return SGD_INVILAD_MAX_ITERATIONS;
|
if (para.iteration <= 0) return SGD_INVALID_MAX_ITERATIONS;
|
||||||
if (para.epsilon < 0) return SGD_INVILAD_EPSILON;
|
if (para.epsilon < 0) return SGD_INVALID_EPSILON;
|
||||||
if (para.mu < 0 || para.mu >= 1.0) return SGD_INVALID_MU;
|
if (para.mu < 0 || para.mu >= 1.0) return SGD_INVALID_MU;
|
||||||
|
|
||||||
sgd_float *mk = sgd_malloc(n_size);
|
sgd_float *mk = sgd_malloc(n_size);
|
||||||
@ -238,9 +229,9 @@ int nag(sgd_evaulate_ptr Evafp, sgd_progress_ptr Profp, sgd_float *fx, sgd_float
|
|||||||
sgd_para para = (param != nullptr) ? (*param) : defparam;
|
sgd_para para = (param != nullptr) ? (*param) : defparam;
|
||||||
|
|
||||||
//check parameters
|
//check parameters
|
||||||
if (n_size <= 0) return SGD_INVILAD_VARIABLE_SIZE;
|
if (n_size <= 0) return SGD_INVALID_VARIABLE_SIZE;
|
||||||
if (para.iteration <= 0) return SGD_INVILAD_MAX_ITERATIONS;
|
if (para.iteration <= 0) return SGD_INVALID_MAX_ITERATIONS;
|
||||||
if (para.epsilon < 0) return SGD_INVILAD_EPSILON;
|
if (para.epsilon < 0) return SGD_INVALID_EPSILON;
|
||||||
if (para.mu < 0 || para.mu >= 1.0) return SGD_INVALID_MU;
|
if (para.mu < 0 || para.mu >= 1.0) return SGD_INVALID_MU;
|
||||||
|
|
||||||
sgd_float *mk = sgd_malloc(n_size);
|
sgd_float *mk = sgd_malloc(n_size);
|
||||||
@ -293,9 +284,9 @@ int adagrad(sgd_evaulate_ptr Evafp, sgd_progress_ptr Profp, sgd_float *fx, sgd_f
|
|||||||
sgd_para para = (param != nullptr) ? (*param) : defparam;
|
sgd_para para = (param != nullptr) ? (*param) : defparam;
|
||||||
|
|
||||||
//check parameters
|
//check parameters
|
||||||
if (n_size <= 0) return SGD_INVILAD_VARIABLE_SIZE;
|
if (n_size <= 0) return SGD_INVALID_VARIABLE_SIZE;
|
||||||
if (para.iteration <= 0) return SGD_INVILAD_MAX_ITERATIONS;
|
if (para.iteration <= 0) return SGD_INVALID_MAX_ITERATIONS;
|
||||||
if (para.epsilon < 0.0) return SGD_INVILAD_EPSILON;
|
if (para.epsilon < 0.0) return SGD_INVALID_EPSILON;
|
||||||
if (para.sigma < 0.0) return SGD_INVALID_SIGMA;
|
if (para.sigma < 0.0) return SGD_INVALID_SIGMA;
|
||||||
|
|
||||||
sgd_float *mk = sgd_malloc(n_size);
|
sgd_float *mk = sgd_malloc(n_size);
|
||||||
@ -341,9 +332,9 @@ int rmsprop(sgd_evaulate_ptr Evafp, sgd_progress_ptr Profp, sgd_float *fx, sgd_f
|
|||||||
sgd_para para = (param != nullptr) ? (*param) : defparam;
|
sgd_para para = (param != nullptr) ? (*param) : defparam;
|
||||||
|
|
||||||
//check parameters
|
//check parameters
|
||||||
if (n_size <= 0) return SGD_INVILAD_VARIABLE_SIZE;
|
if (n_size <= 0) return SGD_INVALID_VARIABLE_SIZE;
|
||||||
if (para.iteration <= 0) return SGD_INVILAD_MAX_ITERATIONS;
|
if (para.iteration <= 0) return SGD_INVALID_MAX_ITERATIONS;
|
||||||
if (para.epsilon < 0.0) return SGD_INVILAD_EPSILON;
|
if (para.epsilon < 0.0) return SGD_INVALID_EPSILON;
|
||||||
if (para.sigma < 0.0) return SGD_INVALID_SIGMA;
|
if (para.sigma < 0.0) return SGD_INVALID_SIGMA;
|
||||||
|
|
||||||
sgd_float *vk = sgd_malloc(n_size);
|
sgd_float *vk = sgd_malloc(n_size);
|
||||||
@ -389,9 +380,9 @@ int adam(sgd_evaulate_ptr Evafp, sgd_progress_ptr Profp, sgd_float *fx, sgd_floa
|
|||||||
sgd_para para = (param != nullptr) ? (*param) : defparam;
|
sgd_para para = (param != nullptr) ? (*param) : defparam;
|
||||||
|
|
||||||
//check parameters
|
//check parameters
|
||||||
if (n_size <= 0) return SGD_INVILAD_VARIABLE_SIZE;
|
if (n_size <= 0) return SGD_INVALID_VARIABLE_SIZE;
|
||||||
if (para.iteration <= 0) return SGD_INVILAD_MAX_ITERATIONS;
|
if (para.iteration <= 0) return SGD_INVALID_MAX_ITERATIONS;
|
||||||
if (para.epsilon < 0) return SGD_INVILAD_EPSILON;
|
if (para.epsilon < 0) return SGD_INVALID_EPSILON;
|
||||||
if (para.alpha < 0) return SGD_INVALID_ALPHA;
|
if (para.alpha < 0) return SGD_INVALID_ALPHA;
|
||||||
if (para.beta_1 < 0.0 || para.beta_1 >= 1.0) return SGD_INVALID_BETA;
|
if (para.beta_1 < 0.0 || para.beta_1 >= 1.0) return SGD_INVALID_BETA;
|
||||||
if (para.beta_2 < 0.0 || para.beta_2 >= 1.0) return SGD_INVALID_BETA;
|
if (para.beta_2 < 0.0 || para.beta_2 >= 1.0) return SGD_INVALID_BETA;
|
||||||
@ -451,9 +442,9 @@ int nadam(sgd_evaulate_ptr Evafp, sgd_progress_ptr Profp, sgd_float *fx, sgd_flo
|
|||||||
sgd_para para = (param != nullptr) ? (*param) : defparam;
|
sgd_para para = (param != nullptr) ? (*param) : defparam;
|
||||||
|
|
||||||
//check parameters
|
//check parameters
|
||||||
if (n_size <= 0) return SGD_INVILAD_VARIABLE_SIZE;
|
if (n_size <= 0) return SGD_INVALID_VARIABLE_SIZE;
|
||||||
if (para.iteration <= 0) return SGD_INVILAD_MAX_ITERATIONS;
|
if (para.iteration <= 0) return SGD_INVALID_MAX_ITERATIONS;
|
||||||
if (para.epsilon < 0) return SGD_INVILAD_EPSILON;
|
if (para.epsilon < 0) return SGD_INVALID_EPSILON;
|
||||||
if (para.alpha < 0) return SGD_INVALID_ALPHA;
|
if (para.alpha < 0) return SGD_INVALID_ALPHA;
|
||||||
if (para.beta_1 < 0.0 || para.beta_1 >= 1.0) return SGD_INVALID_BETA;
|
if (para.beta_1 < 0.0 || para.beta_1 >= 1.0) return SGD_INVALID_BETA;
|
||||||
if (para.beta_2 < 0.0 || para.beta_2 >= 1.0) return SGD_INVALID_BETA;
|
if (para.beta_2 < 0.0 || para.beta_2 >= 1.0) return SGD_INVALID_BETA;
|
||||||
@ -523,9 +514,9 @@ int adamax(sgd_evaulate_ptr Evafp, sgd_progress_ptr Profp, sgd_float *fx, sgd_fl
|
|||||||
if (param == nullptr) para.alpha = 0.002;
|
if (param == nullptr) para.alpha = 0.002;
|
||||||
|
|
||||||
//check parameters
|
//check parameters
|
||||||
if (n_size <= 0) return SGD_INVILAD_VARIABLE_SIZE;
|
if (n_size <= 0) return SGD_INVALID_VARIABLE_SIZE;
|
||||||
if (para.iteration <= 0) return SGD_INVILAD_MAX_ITERATIONS;
|
if (para.iteration <= 0) return SGD_INVALID_MAX_ITERATIONS;
|
||||||
if (para.epsilon < 0) return SGD_INVILAD_EPSILON;
|
if (para.epsilon < 0) return SGD_INVALID_EPSILON;
|
||||||
if (para.alpha < 0) return SGD_INVALID_ALPHA;
|
if (para.alpha < 0) return SGD_INVALID_ALPHA;
|
||||||
if (para.beta_1 < 0.0 || para.beta_1 >= 1.0) return SGD_INVALID_BETA;
|
if (para.beta_1 < 0.0 || para.beta_1 >= 1.0) return SGD_INVALID_BETA;
|
||||||
if (para.beta_2 < 0.0 || para.beta_2 >= 1.0) return SGD_INVALID_BETA;
|
if (para.beta_2 < 0.0 || para.beta_2 >= 1.0) return SGD_INVALID_BETA;
|
||||||
@ -581,9 +572,9 @@ int adabelief(sgd_evaulate_ptr Evafp, sgd_progress_ptr Profp, sgd_float *fx, sgd
|
|||||||
sgd_para para = (param != nullptr) ? (*param) : defparam;
|
sgd_para para = (param != nullptr) ? (*param) : defparam;
|
||||||
|
|
||||||
//check parameters
|
//check parameters
|
||||||
if (n_size <= 0) return SGD_INVILAD_VARIABLE_SIZE;
|
if (n_size <= 0) return SGD_INVALID_VARIABLE_SIZE;
|
||||||
if (para.iteration <= 0) return SGD_INVILAD_MAX_ITERATIONS;
|
if (para.iteration <= 0) return SGD_INVALID_MAX_ITERATIONS;
|
||||||
if (para.epsilon < 0) return SGD_INVILAD_EPSILON;
|
if (para.epsilon < 0) return SGD_INVALID_EPSILON;
|
||||||
if (para.alpha < 0) return SGD_INVALID_ALPHA;
|
if (para.alpha < 0) return SGD_INVALID_ALPHA;
|
||||||
if (para.beta_1 < 0.0 || para.beta_1 >= 1.0) return SGD_INVALID_BETA;
|
if (para.beta_1 < 0.0 || para.beta_1 >= 1.0) return SGD_INVALID_BETA;
|
||||||
if (para.beta_2 < 0.0 || para.beta_2 >= 1.0) return SGD_INVALID_BETA;
|
if (para.beta_2 < 0.0 || para.beta_2 >= 1.0) return SGD_INVALID_BETA;
|
||||||
|
@ -86,12 +86,12 @@ typedef enum
|
|||||||
} sgd_solver_enum;
|
} sgd_solver_enum;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Parameters of the Adam method.
|
* @brief Parameters of the SGD methods.
|
||||||
*/
|
*/
|
||||||
typedef struct
|
typedef struct
|
||||||
{
|
{
|
||||||
/**
|
/**
|
||||||
* Iteration times for the entire observation set. The default is 100.
|
* Iteration times for the entire observation set. The default is 300.
|
||||||
*/
|
*/
|
||||||
int iteration;
|
int iteration;
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user