Added a new parameter orthantwise_start so that users can protect some variables from being regularized.
git-svn-id: file:///home/svnrepos/software/liblbfgs/trunk@10 ecf4c44f-38d1-4fa4-9757-a0b4dd0349fc
This commit is contained in:
parent
554fbdaed6
commit
4473a5dde1
@ -105,6 +105,8 @@ enum {
|
||||
LBFGSERR_INVALID_MAXLINESEARCH,
|
||||
/** Invalid parameter lbfgs_parameter_t::orthantwise_c specified. */
|
||||
LBFGSERR_INVALID_ORTHANTWISE,
|
||||
/** Invalid parameter lbfgs_parameter_t::orthantwise_start specified. */
|
||||
LBFGSERR_INVALID_ORTHANTWISE_START,
|
||||
/** The line-search step went out of the interval of uncertainty. */
|
||||
LBFGSERR_OUTOFINTERVAL,
|
||||
/** A logic error occurred; alternatively, the interval of uncertainty
|
||||
@ -250,6 +252,8 @@ typedef struct {
|
||||
* F(x) and gradients G(x) as usual. The default value is zero.
|
||||
*/
|
||||
lbfgsfloatval_t orthantwise_c;
|
||||
|
||||
int orthantwise_start;
|
||||
} lbfgs_parameter_t;
|
||||
|
||||
|
||||
|
27
lib/lbfgs.c
27
lib/lbfgs.c
@ -113,7 +113,7 @@ typedef struct tag_iteration_data iteration_data_t;
|
||||
static const lbfgs_parameter_t _defparam = {
|
||||
6, 1e-5, 0, LBFGS_LINESEARCH_DEFAULT, 20,
|
||||
1e-20, 1e20, 1e-4, 0.9, 1.0e-16,
|
||||
0.0,
|
||||
0.0, 0,
|
||||
};
|
||||
|
||||
/* Forward function declarations. */
|
||||
@ -267,6 +267,9 @@ int lbfgs(
|
||||
if (param->orthantwise_c < 0.) {
|
||||
return LBFGSERR_INVALID_ORTHANTWISE;
|
||||
}
|
||||
if (param->orthantwise_start < 0 || n < param->orthantwise_start) {
|
||||
return LBFGSERR_INVALID_ORTHANTWISE_START;
|
||||
}
|
||||
switch (param->linesearch) {
|
||||
case LBFGS_LINESEARCH_MORETHUENTE:
|
||||
linesearch = line_search_morethuente;
|
||||
@ -314,7 +317,7 @@ int lbfgs(
|
||||
if (0. < param->orthantwise_c) {
|
||||
/* Compute L1-regularization factor and add it to the object value. */
|
||||
norm = 0.;
|
||||
for (i = 0;i < n;++i) {
|
||||
for (i = param->orthantwise_start;i < n;++i) {
|
||||
norm += fabs(x[i]);
|
||||
}
|
||||
fx += norm * param->orthantwise_c;
|
||||
@ -325,7 +328,7 @@ int lbfgs(
|
||||
vecncpy(d, g, n);
|
||||
} else {
|
||||
/* Compute the negative of psuedo-gradients. */
|
||||
for (i = 0;i < n;++i) {
|
||||
for (i = param->orthantwise_start;i < n;++i) {
|
||||
if (x[i] < 0.) {
|
||||
/* Differentiable. */
|
||||
d[i] = -g[i] + param->orthantwise_c;
|
||||
@ -430,7 +433,7 @@ int lbfgs(
|
||||
vecncpy(d, g, n);
|
||||
} else {
|
||||
/* Compute the negative of psuedo-gradients. */
|
||||
for (i = 0;i < n;++i) {
|
||||
for (i = param->orthantwise_start;i < n;++i) {
|
||||
if (x[i] < 0.) {
|
||||
/* Differentiable. */
|
||||
d[i] = -g[i] + param->orthantwise_c;
|
||||
@ -480,7 +483,7 @@ int lbfgs(
|
||||
Constrain the search direction for orthant-wise updates.
|
||||
*/
|
||||
if (param->orthantwise_c != 0.) {
|
||||
for (i = 0;i < n;++i) {
|
||||
for (i = param->orthantwise_start;i < n;++i) {
|
||||
if (d[i] * w[i] <= 0) {
|
||||
d[i] = 0;
|
||||
}
|
||||
@ -542,7 +545,7 @@ static int line_search_backtracking(
|
||||
/* Compute the initial gradient in the search direction. */
|
||||
if (param->orthantwise_c != 0.) {
|
||||
/* Use psuedo-gradients for orthant-wise updates. */
|
||||
for (i = 0;i < n;++i) {
|
||||
for (i = param->orthantwise_start;i < n;++i) {
|
||||
/* Notice that:
|
||||
(-s[i] < 0) <==> (g[i] < -param->orthantwise_c)
|
||||
(-s[i] > 0) <==> (param->orthantwise_c < g[i])
|
||||
@ -586,7 +589,7 @@ static int line_search_backtracking(
|
||||
|
||||
if (param->orthantwise_c != 0.) {
|
||||
/* The current point is projected onto the orthant of the initial one. */
|
||||
for (i = 0;i < n;++i) {
|
||||
for (i = param->orthantwise_start;i < n;++i) {
|
||||
if (x[i] * xp[i] < 0.) {
|
||||
x[i] = 0.;
|
||||
}
|
||||
@ -598,7 +601,7 @@ static int line_search_backtracking(
|
||||
if (0. < param->orthantwise_c) {
|
||||
/* Compute L1-regularization factor and add it to the object value. */
|
||||
norm = 0.;
|
||||
for (i = 0;i < n;++i) {
|
||||
for (i = param->orthantwise_start;i < n;++i) {
|
||||
norm += fabs(x[i]);
|
||||
}
|
||||
*f += norm * param->orthantwise_c;
|
||||
@ -662,7 +665,7 @@ static int line_search_morethuente(
|
||||
if (param->orthantwise_c != 0.) {
|
||||
/* Use psuedo-gradients for orthant-wise updates. */
|
||||
dginit = 0.;
|
||||
for (i = 0;i < n;++i) {
|
||||
for (i = param->orthantwise_start;i < n;++i) {
|
||||
/* Notice that:
|
||||
(-s[i] < 0) <==> (g[i] < -param->orthantwise_c)
|
||||
(-s[i] > 0) <==> (param->orthantwise_c < g[i])
|
||||
@ -751,7 +754,7 @@ static int line_search_morethuente(
|
||||
|
||||
if (param->orthantwise_c != 0.) {
|
||||
/* The current point is projected onto the orthant of the previous one. */
|
||||
for (i = 0;i < n;++i) {
|
||||
for (i = param->orthantwise_start;i < n;++i) {
|
||||
if (x[i] * wa[i] < 0.) {
|
||||
x[i] = 0.;
|
||||
}
|
||||
@ -763,14 +766,14 @@ static int line_search_morethuente(
|
||||
if (0. < param->orthantwise_c) {
|
||||
/* Compute L1-regularization factor and add it to the object value. */
|
||||
norm = 0.;
|
||||
for (i = 0;i < n;++i) {
|
||||
for (i = param->orthantwise_start;i < n;++i) {
|
||||
norm += fabs(x[i]);
|
||||
}
|
||||
*f += norm * param->orthantwise_c;
|
||||
|
||||
/* Use psuedo-gradients for orthant-wise updates. */
|
||||
dg = 0.;
|
||||
for (i = 0;i < n;++i) {
|
||||
for (i = param->orthantwise_start;i < n;++i) {
|
||||
if (x[i] < 0.) {
|
||||
/* Differentiable. */
|
||||
dg += s[i] * (g[i] - param->orthantwise_c);
|
||||
|
Loading…
Reference in New Issue
Block a user