diff --git a/include/lbfgs.h b/include/lbfgs.h index d23fc21..c41a939 100644 --- a/include/lbfgs.h +++ b/include/lbfgs.h @@ -105,6 +105,8 @@ enum { LBFGSERR_INVALID_MAXLINESEARCH, /** Invalid parameter lbfgs_parameter_t::orthantwise_c specified. */ LBFGSERR_INVALID_ORTHANTWISE, + /** Invalid parameter lbfgs_parameter_t::orthantwise_start specified. */ + LBFGSERR_INVALID_ORTHANTWISE_START, /** The line-search step went out of the interval of uncertainty. */ LBFGSERR_OUTOFINTERVAL, /** A logic error occurred; alternatively, the interval of uncertainty @@ -250,6 +252,8 @@ typedef struct { * F(x) and gradients G(x) as usual. The default value is zero. */ lbfgsfloatval_t orthantwise_c; + + int orthantwise_start; } lbfgs_parameter_t; diff --git a/lib/lbfgs.c b/lib/lbfgs.c index 7bd0a1c..51d8ebb 100644 --- a/lib/lbfgs.c +++ b/lib/lbfgs.c @@ -113,7 +113,7 @@ typedef struct tag_iteration_data iteration_data_t; static const lbfgs_parameter_t _defparam = { 6, 1e-5, 0, LBFGS_LINESEARCH_DEFAULT, 20, 1e-20, 1e20, 1e-4, 0.9, 1.0e-16, - 0.0, + 0.0, 0, }; /* Forward function declarations. */ @@ -267,6 +267,9 @@ int lbfgs( if (param->orthantwise_c < 0.) { return LBFGSERR_INVALID_ORTHANTWISE; } + if (param->orthantwise_start < 0 || n < param->orthantwise_start) { + return LBFGSERR_INVALID_ORTHANTWISE_START; + } switch (param->linesearch) { case LBFGS_LINESEARCH_MORETHUENTE: linesearch = line_search_morethuente; @@ -314,7 +317,7 @@ int lbfgs( if (0. < param->orthantwise_c) { /* Compute L1-regularization factor and add it to the object value. */ norm = 0.; - for (i = 0;i < n;++i) { + for (i = param->orthantwise_start;i < n;++i) { norm += fabs(x[i]); } fx += norm * param->orthantwise_c; @@ -325,7 +328,7 @@ int lbfgs( vecncpy(d, g, n); } else { /* Compute the negative of psuedo-gradients. */ - for (i = 0;i < n;++i) { + for (i = param->orthantwise_start;i < n;++i) { if (x[i] < 0.) { /* Differentiable. */ d[i] = -g[i] + param->orthantwise_c; @@ -430,7 +433,7 @@ int lbfgs( vecncpy(d, g, n); } else { /* Compute the negative of psuedo-gradients. */ - for (i = 0;i < n;++i) { + for (i = param->orthantwise_start;i < n;++i) { if (x[i] < 0.) { /* Differentiable. */ d[i] = -g[i] + param->orthantwise_c; @@ -480,7 +483,7 @@ int lbfgs( Constrain the search direction for orthant-wise updates. */ if (param->orthantwise_c != 0.) { - for (i = 0;i < n;++i) { + for (i = param->orthantwise_start;i < n;++i) { if (d[i] * w[i] <= 0) { d[i] = 0; } @@ -542,7 +545,7 @@ static int line_search_backtracking( /* Compute the initial gradient in the search direction. */ if (param->orthantwise_c != 0.) { /* Use psuedo-gradients for orthant-wise updates. */ - for (i = 0;i < n;++i) { + for (i = param->orthantwise_start;i < n;++i) { /* Notice that: (-s[i] < 0) <==> (g[i] < -param->orthantwise_c) (-s[i] > 0) <==> (param->orthantwise_c < g[i]) @@ -586,7 +589,7 @@ static int line_search_backtracking( if (param->orthantwise_c != 0.) { /* The current point is projected onto the orthant of the initial one. */ - for (i = 0;i < n;++i) { + for (i = param->orthantwise_start;i < n;++i) { if (x[i] * xp[i] < 0.) { x[i] = 0.; } @@ -598,7 +601,7 @@ static int line_search_backtracking( if (0. < param->orthantwise_c) { /* Compute L1-regularization factor and add it to the object value. */ norm = 0.; - for (i = 0;i < n;++i) { + for (i = param->orthantwise_start;i < n;++i) { norm += fabs(x[i]); } *f += norm * param->orthantwise_c; @@ -662,7 +665,7 @@ static int line_search_morethuente( if (param->orthantwise_c != 0.) { /* Use psuedo-gradients for orthant-wise updates. */ dginit = 0.; - for (i = 0;i < n;++i) { + for (i = param->orthantwise_start;i < n;++i) { /* Notice that: (-s[i] < 0) <==> (g[i] < -param->orthantwise_c) (-s[i] > 0) <==> (param->orthantwise_c < g[i]) @@ -751,7 +754,7 @@ static int line_search_morethuente( if (param->orthantwise_c != 0.) { /* The current point is projected onto the orthant of the previous one. */ - for (i = 0;i < n;++i) { + for (i = param->orthantwise_start;i < n;++i) { if (x[i] * wa[i] < 0.) { x[i] = 0.; } @@ -763,14 +766,14 @@ static int line_search_morethuente( if (0. < param->orthantwise_c) { /* Compute L1-regularization factor and add it to the object value. */ norm = 0.; - for (i = 0;i < n;++i) { + for (i = param->orthantwise_start;i < n;++i) { norm += fabs(x[i]); } *f += norm * param->orthantwise_c; /* Use psuedo-gradients for orthant-wise updates. */ dg = 0.; - for (i = 0;i < n;++i) { + for (i = param->orthantwise_start;i < n;++i) { if (x[i] < 0.) { /* Differentiable. */ dg += s[i] * (g[i] - param->orthantwise_c);