Added a new parameter orthantwise_start so that users can protect some variables from being regularized.

git-svn-id: file:///home/svnrepos/software/liblbfgs/trunk@10 ecf4c44f-38d1-4fa4-9757-a0b4dd0349fc
This commit is contained in:
naoaki 2008-07-02 03:04:48 +00:00
parent 554fbdaed6
commit 4473a5dde1
2 changed files with 19 additions and 12 deletions

View File

@ -105,6 +105,8 @@ enum {
LBFGSERR_INVALID_MAXLINESEARCH,
/** Invalid parameter lbfgs_parameter_t::orthantwise_c specified. */
LBFGSERR_INVALID_ORTHANTWISE,
/** Invalid parameter lbfgs_parameter_t::orthantwise_start specified. */
LBFGSERR_INVALID_ORTHANTWISE_START,
/** The line-search step went out of the interval of uncertainty. */
LBFGSERR_OUTOFINTERVAL,
/** A logic error occurred; alternatively, the interval of uncertainty
@ -250,6 +252,8 @@ typedef struct {
* F(x) and gradients G(x) as usual. The default value is zero.
*/
lbfgsfloatval_t orthantwise_c;
int orthantwise_start;
} lbfgs_parameter_t;

View File

@ -113,7 +113,7 @@ typedef struct tag_iteration_data iteration_data_t;
static const lbfgs_parameter_t _defparam = {
6, 1e-5, 0, LBFGS_LINESEARCH_DEFAULT, 20,
1e-20, 1e20, 1e-4, 0.9, 1.0e-16,
0.0,
0.0, 0,
};
/* Forward function declarations. */
@ -267,6 +267,9 @@ int lbfgs(
if (param->orthantwise_c < 0.) {
return LBFGSERR_INVALID_ORTHANTWISE;
}
if (param->orthantwise_start < 0 || n < param->orthantwise_start) {
return LBFGSERR_INVALID_ORTHANTWISE_START;
}
switch (param->linesearch) {
case LBFGS_LINESEARCH_MORETHUENTE:
linesearch = line_search_morethuente;
@ -314,7 +317,7 @@ int lbfgs(
if (0. < param->orthantwise_c) {
/* Compute L1-regularization factor and add it to the objective value. */
norm = 0.;
for (i = 0;i < n;++i) {
for (i = param->orthantwise_start;i < n;++i) {
norm += fabs(x[i]);
}
fx += norm * param->orthantwise_c;
@ -325,7 +328,7 @@ int lbfgs(
vecncpy(d, g, n);
} else {
/* Compute the negative of pseudo-gradients. */
for (i = 0;i < n;++i) {
for (i = param->orthantwise_start;i < n;++i) {
if (x[i] < 0.) {
/* Differentiable. */
d[i] = -g[i] + param->orthantwise_c;
@ -430,7 +433,7 @@ int lbfgs(
vecncpy(d, g, n);
} else {
/* Compute the negative of pseudo-gradients. */
for (i = 0;i < n;++i) {
for (i = param->orthantwise_start;i < n;++i) {
if (x[i] < 0.) {
/* Differentiable. */
d[i] = -g[i] + param->orthantwise_c;
@ -480,7 +483,7 @@ int lbfgs(
Constrain the search direction for orthant-wise updates.
*/
if (param->orthantwise_c != 0.) {
for (i = 0;i < n;++i) {
for (i = param->orthantwise_start;i < n;++i) {
if (d[i] * w[i] <= 0) {
d[i] = 0;
}
@ -542,7 +545,7 @@ static int line_search_backtracking(
/* Compute the initial gradient in the search direction. */
if (param->orthantwise_c != 0.) {
/* Use pseudo-gradients for orthant-wise updates. */
for (i = 0;i < n;++i) {
for (i = param->orthantwise_start;i < n;++i) {
/* Notice that:
(-s[i] < 0) <==> (g[i] < -param->orthantwise_c)
(-s[i] > 0) <==> (param->orthantwise_c < g[i])
@ -586,7 +589,7 @@ static int line_search_backtracking(
if (param->orthantwise_c != 0.) {
/* The current point is projected onto the orthant of the initial one. */
for (i = 0;i < n;++i) {
for (i = param->orthantwise_start;i < n;++i) {
if (x[i] * xp[i] < 0.) {
x[i] = 0.;
}
@ -598,7 +601,7 @@ static int line_search_backtracking(
if (0. < param->orthantwise_c) {
/* Compute L1-regularization factor and add it to the objective value. */
norm = 0.;
for (i = 0;i < n;++i) {
for (i = param->orthantwise_start;i < n;++i) {
norm += fabs(x[i]);
}
*f += norm * param->orthantwise_c;
@ -662,7 +665,7 @@ static int line_search_morethuente(
if (param->orthantwise_c != 0.) {
/* Use pseudo-gradients for orthant-wise updates. */
dginit = 0.;
for (i = 0;i < n;++i) {
for (i = param->orthantwise_start;i < n;++i) {
/* Notice that:
(-s[i] < 0) <==> (g[i] < -param->orthantwise_c)
(-s[i] > 0) <==> (param->orthantwise_c < g[i])
@ -751,7 +754,7 @@ static int line_search_morethuente(
if (param->orthantwise_c != 0.) {
/* The current point is projected onto the orthant of the previous one. */
for (i = 0;i < n;++i) {
for (i = param->orthantwise_start;i < n;++i) {
if (x[i] * wa[i] < 0.) {
x[i] = 0.;
}
@ -763,14 +766,14 @@ static int line_search_morethuente(
if (0. < param->orthantwise_c) {
/* Compute L1-regularization factor and add it to the objective value. */
norm = 0.;
for (i = 0;i < n;++i) {
for (i = param->orthantwise_start;i < n;++i) {
norm += fabs(x[i]);
}
*f += norm * param->orthantwise_c;
/* Use pseudo-gradients for orthant-wise updates. */
dg = 0.;
for (i = 0;i < n;++i) {
for (i = param->orthantwise_start;i < n;++i) {
if (x[i] < 0.) {
/* Differentiable. */
dg += s[i] * (g[i] - param->orthantwise_c);