diff --git a/lib/lbfgs.c b/lib/lbfgs.c index c73f40c..db4e852 100644 --- a/lib/lbfgs.c +++ b/lib/lbfgs.c @@ -213,15 +213,6 @@ static void owlqn_project( const int end ); -static lbfgsfloatval_t owlqn_direction_line( - const lbfgsfloatval_t* x, - const lbfgsfloatval_t* g, - const lbfgsfloatval_t* s, - const lbfgsfloatval_t c, - const int start, - const int n - ); - #if defined(USE_SSE) && (defined(__SSE__) || defined(__SSE2__)) static int round_out_variables(int n) @@ -342,17 +333,27 @@ int lbfgs( if (n < param.orthantwise_end) { return LBFGSERR_INVALID_ORTHANTWISE_END; } - - switch (param.linesearch) { - case LBFGS_LINESEARCH_MORETHUENTE: - linesearch = line_search_morethuente; - break; - case LBFGS_LINESEARCH_BACKTRACKING: - case LBFGS_LINESEARCH_BACKTRACKING_STRONG: - linesearch = line_search_backtracking; - break; - default: - return LBFGSERR_INVALID_LINESEARCH; + if (param.orthantwise_c != 0.) { + switch (param.linesearch) { + case LBFGS_LINESEARCH_BACKTRACKING: + linesearch = line_search_backtracking_owlqn; + break; + default: + /* Only the backtracking method is available. */ + return LBFGSERR_INVALID_LINESEARCH; + } + } else { + switch (param.linesearch) { + case LBFGS_LINESEARCH_MORETHUENTE: + linesearch = line_search_morethuente; + break; + case LBFGS_LINESEARCH_BACKTRACKING: + case LBFGS_LINESEARCH_BACKTRACKING_STRONG: + linesearch = line_search_backtracking; + break; + default: + return LBFGSERR_INVALID_LINESEARCH; + } } /* Allocate working space. */ @@ -803,12 +804,12 @@ static int line_search_morethuente( lbfgsfloatval_t *stp, const lbfgsfloatval_t* xp, const lbfgsfloatval_t* gp, - lbfgsfloatval_t *wp, + lbfgsfloatval_t *wa, callback_data_t *cd, const lbfgs_parameter_t *param ) { - int i, count = 0; + int count = 0; int brackt, stage1, uinfo = 0; lbfgsfloatval_t dg; lbfgsfloatval_t stx, fx, dgx; @@ -823,11 +824,6 @@ static int line_search_morethuente( return LBFGSERR_INVALIDPARAMETERS; } - /* Choose the orthant for the new point. */ - for (i = 0;i < n;++i) { - wp[i] = (xp[i] == 0.) ? -gp[i] : xp[i]; - } - /* Compute the initial gradient in the search direction. */ vecdot(&dginit, g, s, n); @@ -889,14 +885,9 @@ static int line_search_morethuente( veccpy(x, xp, n); vecadd(x, s, *stp, n); - /* The current point is projected onto the orthant. */ - owlqn_project(x, wp, param->orthantwise_start, param->orthantwise_end); - /* Evaluate the function and gradient values. */ *f = cd->proc_evaluate(cd->instance, x, g, cd->n, *stp); - *f += param->orthantwise_c * owlqn_x1norm(x, param->orthantwise_start, param->orthantwise_end); -// vecdot(&dg, g, s, n); - dg = owlqn_direction_line(x, g, s, param->orthantwise_c, param->orthantwise_start, param->orthantwise_end); + vecdot(&dg, g, s, n); ftest1 = finit + *stp * dgtest; ++count; @@ -1364,47 +1355,3 @@ static void owlqn_project( } } } - -static lbfgsfloatval_t owlqn_direction_line( - const lbfgsfloatval_t* x, - const lbfgsfloatval_t* g, - const lbfgsfloatval_t* s, - const lbfgsfloatval_t c, - const int start, - const int n - ) -{ - int i; - lbfgsfloatval_t d = 0.; - - /* Compute the negative of gradients. */ - for (i = 0;i < start;++i) { - d += s[i] * g[i]; - } - - /* Use psuedo-gradients for orthant-wise updates. */ - for (i = start;i < n;++i) { - /* Notice that: - (-s[i] < 0) <==> (g[i] < -param->orthantwise_c) - (-s[i] > 0) <==> (param->orthantwise_c < g[i]) - as the result of the lbfgs() function for orthant-wise updates. - */ - if (s[i] != 0.) { - if (x[i] < 0.) { - /* Differentiable. */ - d += s[i] * (g[i] - c); - } else if (0. < x[i]) { - /* Differentiable. */ - d += s[i] * (g[i] + c); - } else if (s[i] < 0.) { - /* Take the left partial derivative. */ - d += s[i] * (g[i] - c); - } else if (0. < s[i]) { - /* Take the right partial derivative. */ - d += s[i] * (g[i] + c); - } - } - } - - return d; -}