diff --git a/include/lbfgs.h b/include/lbfgs.h
index 525111f..b594c95 100644
--- a/include/lbfgs.h
+++ b/include/lbfgs.h
@@ -154,6 +154,8 @@ enum {
     LBFGS_LINESEARCH_MORETHUENTE = 0,
     /** Backtracking method. */
     LBFGS_LINESEARCH_BACKTRACKING,
+    /** Backtracking method with strong Wolfe condition. */
+    LBFGS_LINESEARCH_BACKTRACKING_STRONGWOLFE,
 };
 
 /**
diff --git a/lib/lbfgs.c b/lib/lbfgs.c
index 0cc03c6..29b7956 100644
--- a/lib/lbfgs.c
+++ b/lib/lbfgs.c
@@ -143,6 +143,18 @@ static int line_search_backtracking(
     const lbfgs_parameter_t *param
     );
 
+static int line_search_backtracking_strong_wolfe(
+    int n,
+    lbfgsfloatval_t *x,
+    lbfgsfloatval_t *f,
+    lbfgsfloatval_t *g,
+    lbfgsfloatval_t *s,
+    lbfgsfloatval_t *stp,
+    lbfgsfloatval_t *xp,
+    callback_data_t *cd,
+    const lbfgs_parameter_t *param
+    );
+
 static int line_search_morethuente(
     int n,
     lbfgsfloatval_t *x,
@@ -334,6 +346,9 @@ int lbfgs(
     case LBFGS_LINESEARCH_BACKTRACKING:
         linesearch = line_search_backtracking;
         break;
+    case LBFGS_LINESEARCH_BACKTRACKING_STRONGWOLFE:
+        linesearch = line_search_backtracking_strong_wolfe;
+        break;
     default:
         return LBFGSERR_INVALID_LINESEARCH;
     }
@@ -678,6 +693,145 @@ static int line_search_backtracking(
 
 
 
+/*
+ * Backtracking line search that accepts a step only when it satisfies both
+ * the sufficient decrease condition and the strong Wolfe condition.
+ *
+ * Each trial point is built from xp, s and *stp (veccpy followed by vecadd)
+ * and evaluated through cd->proc_evaluate. The step is multiplied by dec
+ * when the decrease test fails or dg is too positive, and by inc when dg is
+ * still too negative.
+ *
+ *  n      Element count passed to the vector helpers.
+ *  x      In: starting point. Out: the accepted trial point on success;
+ *         restored from xp when the search gives up inside the loop.
+ *  f      In: objective value at the starting point. Out: value at the last
+ *         evaluated trial point.
+ *  g      In: gradient at the starting point. Out: gradient written by the
+ *         last evaluation.
+ *  s      Search direction.
+ *  stp    In: initial step, must be positive. Out: last step evaluated.
+ *  xp     Work area that receives a copy of the starting x.
+ *  cd     Callback data; proc_evaluate, instance and n are used.
+ *  param  Supplies ftol, min_step, max_linesearch and the orthantwise_*
+ *         settings.
+ *
+ * Returns the number of evaluations performed on success, otherwise
+ * LBFGSERR_INVALIDPARAMETERS, LBFGSERR_INCREASEGRADIENT,
+ * LBFGSERR_MINIMUMSTEP or LBFGSERR_MAXIMUMLINESEARCH.
+ */
+static int line_search_backtracking_strong_wolfe(
+    int n,
+    lbfgsfloatval_t *x,
+    lbfgsfloatval_t *f,
+    lbfgsfloatval_t *g,
+    lbfgsfloatval_t *s,
+    lbfgsfloatval_t *stp,
+    lbfgsfloatval_t *xp,
+    callback_data_t *cd,
+    const lbfgs_parameter_t *param
+    )
+{
+    int ret = 0, count = 0;
+    lbfgsfloatval_t dg, norm, mult;
+    lbfgsfloatval_t finit, dginit = 0., dgtest;
+    const lbfgsfloatval_t wolfe = 0.9, dec = 0.7, inc = 1.5;
+    /* One flag for every orthant-wise branch so that dginit and dg are
+       always computed the same way. */
+    const int orthantwise = (param->orthantwise_c != 0.);
+
+    /* Check the input parameters for errors. */
+    if (*stp <= 0.) {
+        return LBFGSERR_INVALIDPARAMETERS;
+    }
+
+    /* Compute the initial gradient in the search direction. */
+    if (orthantwise) {
+        dginit = owlqn_direction_line(x, g, s, param->orthantwise_c, param->orthantwise_start, param->orthantwise_end);
+    } else {
+        vecdot(&dginit, g, s, n);
+    }
+
+    /* Make sure that s points to a descent direction. */
+    if (0 < dginit) {
+        return LBFGSERR_INCREASEGRADIENT;
+    }
+
+    /* The initial value of the objective function. */
+    finit = *f;
+    dgtest = param->ftol * dginit;
+
+    /* Copy the value of x to the work area. */
+    veccpy(xp, x, n);
+
+    for (;;) {
+        /* Rebuild the trial point from the saved start for the current step. */
+        veccpy(x, xp, n);
+        vecadd(x, s, *stp, n);
+
+        if (orthantwise) {
+            /* The current point is projected onto the orthant of the initial one. */
+            owlqn_project(x, xp, param->orthantwise_start, param->orthantwise_end);
+        }
+
+        /* Evaluate the function and gradient values. */
+        *f = cd->proc_evaluate(cd->instance, x, g, cd->n, *stp);
+        if (orthantwise) {
+            /* Compute the L1 norm of the variables and add it to the objective value. */
+            norm = owlqn_x1norm(x, param->orthantwise_start, param->orthantwise_end);
+            *f += norm * param->orthantwise_c;
+
+            dg = owlqn_direction_line(x, g, s, param->orthantwise_c, param->orthantwise_start, param->orthantwise_end);
+        } else {
+            vecdot(&dg, g, s, n);
+        }
+
+        ++count;
+
+        if (*f <= finit + *stp * dgtest) {
+            /* The sufficient decrease condition holds; dginit <= 0 here, so
+               dg is accepted inside [wolfe * dginit, -wolfe * dginit]. */
+            if (dg > -wolfe * dginit) {
+                /* dg too positive: shrink the step. */
+                mult = dec;
+            } else if (dg < wolfe * dginit) {
+                /* dg still too negative: enlarge the step. */
+                mult = inc;
+            } else {
+                /* Strong Wolfe condition. */
+                return count;
+            }
+        } else {
+            /* Insufficient decrease: shrink the step. */
+            mult = dec;
+        }
+
+        /* These checks run before the step is rescaled, so min_step is
+           compared against the step that was just evaluated. */
+        if (*stp < param->min_step) {
+            /* The step fell below the minimum value. */
+            ret = LBFGSERR_MINIMUMSTEP;
+            break;
+        }
+        if (param->max_linesearch <= count) {
+            /* Maximum number of iterations. */
+            ret = LBFGSERR_MAXIMUMLINESEARCH;
+            break;
+        }
+
+        /* NOTE(review): the step can grow (mult == inc) without any upper
+           bound being checked -- confirm whether a maximum step should be
+           enforced here. */
+        *stp *= mult;
+    }
+
+    /* Revert to the previous position. */
+    veccpy(x, xp, n);
+    return ret;
+}
+
+
+
 static int line_search_morethuente(
     int n,
     lbfgsfloatval_t *x,