diff --git a/lib/lbfgs.c b/lib/lbfgs.c index f5ee917..a1dbc04 100644 --- a/lib/lbfgs.c +++ b/lib/lbfgs.c @@ -133,7 +133,7 @@ typedef int (*line_search_proc)( const lbfgs_parameter_t *param ); -static int line_search_backtracking_loose( +static int line_search_backtracking( int n, lbfgsfloatval_t *x, lbfgsfloatval_t *f, @@ -161,20 +161,6 @@ static int line_search_backtracking_owlqn( const lbfgs_parameter_t *param ); -static int line_search_backtracking_strong_wolfe( - int n, - lbfgsfloatval_t *x, - lbfgsfloatval_t *f, - lbfgsfloatval_t *g, - lbfgsfloatval_t *s, - lbfgsfloatval_t *stp, - const lbfgsfloatval_t* xp, - const lbfgsfloatval_t* gp, - lbfgsfloatval_t *wa, - callback_data_t *cd, - const lbfgs_parameter_t *param - ); - static int line_search_morethuente( int n, lbfgsfloatval_t *x, @@ -376,10 +362,10 @@ int lbfgs( linesearch = line_search_morethuente; break; case LBFGS_LINESEARCH_BACKTRACKING: - linesearch = line_search_backtracking_strong_wolfe; + linesearch = line_search_backtracking; break; case LBFGS_LINESEARCH_BACKTRACKING_LOOSE: - linesearch = line_search_backtracking_loose; + linesearch = line_search_backtracking; break; default: return LBFGSERR_INVALID_LINESEARCH; @@ -722,7 +708,7 @@ static int line_search_backtracking_owlqn( -static int line_search_backtracking_loose( +static int line_search_backtracking( int n, lbfgsfloatval_t *x, lbfgsfloatval_t *f, @@ -737,8 +723,9 @@ static int line_search_backtracking_loose( ) { int ret = 0, count = 0; - lbfgsfloatval_t width = 0.5, norm = 0.; + lbfgsfloatval_t width, dg, norm = 0.; lbfgsfloatval_t finit, dginit = 0., dgtest; + const lbfgsfloatval_t wolfe = 0.9, dec = 0.5, inc = 2.1; /* Check the input parameters for errors. */ if (*stp <= 0.) { @@ -768,8 +755,26 @@ static int line_search_backtracking_loose( if (*f <= finit + *stp * dgtest) { /* The sufficient decrease condition. */ - return count; + + if (param->linesearch == LBFGS_LINESEARCH_BACKTRACKING) { + /* Check the strong Wolfe condition. */ + vecdot(&dg, g, s, n); + if (dg > -wolfe * dginit) { + width = dec; + } else if (dg < wolfe * dginit) { + width = inc; + } else { + /* Strong Wolfe condition. */ + return count; + } + } else { + /* Exit with the loose Wolfe condition. */ + return count; + } + } else { + width = dec; } + if (*stp < param->min_step) { /* The step is the minimum value. */ return LBFGSERR_MINIMUMSTEP; @@ -789,102 +794,6 @@ static int line_search_backtracking_loose( -static int line_search_backtracking_strong_wolfe( - int n, - lbfgsfloatval_t *x, - lbfgsfloatval_t *f, - lbfgsfloatval_t *g, - lbfgsfloatval_t *s, - lbfgsfloatval_t *stp, - const lbfgsfloatval_t* xp, - const lbfgsfloatval_t* gp, - lbfgsfloatval_t *wp, - callback_data_t *cd, - const lbfgs_parameter_t *param - ) -{ - int ret = 0, count = 0; - lbfgsfloatval_t dg, norm, mult; - lbfgsfloatval_t finit, dginit = 0., dgtest; - const lbfgsfloatval_t wolfe = 0.9, dec = 0.7, inc = 1.5; - - /* Check the input parameters for errors. */ - if (*stp <= 0.) { - return LBFGSERR_INVALIDPARAMETERS; - } - - /* Compute the initial gradient in the search direction. */ - if (param->orthantwise_c != 0.) { - dginit = owlqn_direction_line(x, g, s, param->orthantwise_c, param->orthantwise_start, param->orthantwise_end); - } else { - vecdot(&dginit, g, s, n); - } - - /* Make sure that s points to a descent direction. */ - if (0 < dginit) { - return LBFGSERR_INCREASEGRADIENT; - } - - /* The initial value of the objective function. */ - finit = *f; - dgtest = param->ftol * dginit; - - for (;;) { - veccpy(x, xp, n); - vecadd(x, s, *stp, n); - - if (param->orthantwise_c != 0.) { - /* The current point is projected onto the orthant of the initial one. */ - owlqn_project(x, xp, param->orthantwise_start, param->orthantwise_end); - } - - /* Evaluate the function and gradient values. */ - *f = cd->proc_evaluate(cd->instance, x, g, cd->n, *stp); - if (0. < param->orthantwise_c) { - /* Compute the L1 norm of the variables and add it to the object value. */ - norm = owlqn_x1norm(x, param->orthantwise_start, param->orthantwise_end); - *f += norm * param->orthantwise_c; - - dg = owlqn_direction_line(x, g, s, param->orthantwise_c, param->orthantwise_start, param->orthantwise_end); - } else { - vecdot(&dg, g, s, n); - } - - ++count; - - if (*f <= finit + *stp * dgtest) { - /* The sufficient decrease condition. */ - if (dg > -wolfe * dginit) { - mult = dec; - } else if (dg < wolfe * dginit) { - mult = inc; - } else { - /* Strong Wolfe condition. */ - return count; - } - } else { - mult = dec; - } - - if (*stp < param->min_step) { - /* The step is the minimum value. */ - return LBFGSERR_MINIMUMSTEP; - } - if (*stp > param->max_step) { - /* The step is the maximum value. */ - return LBFGSERR_MAXIMUMSTEP; - } - if (param->max_linesearch <= count) { - /* Maximum number of iteration. */ - return LBFGSERR_MAXIMUMLINESEARCH; - } - - *stp *= mult; - } -} - - - static int line_search_morethuente( int n, lbfgsfloatval_t *x,