Constrain the search direction.
git-svn-id: file:///home/svnrepos/software/liblbfgs/trunk@49 ecf4c44f-38d1-4fa4-9757-a0b4dd0349fc
This commit is contained in:
parent
6a1b860ea3
commit
f4362676e6
189
lib/lbfgs.c
189
lib/lbfgs.c
@ -260,7 +260,9 @@ int lbfgs(
|
||||
lbfgs_parameter_t param = (_param != NULL) ? (*_param) : _defparam;
|
||||
const int m = param.m;
|
||||
|
||||
lbfgsfloatval_t *xp = NULL, *g = NULL, *gp = NULL, *pg = NULL, *d = NULL, *w = NULL, *pf = NULL;
|
||||
lbfgsfloatval_t *xp = NULL;
|
||||
lbfgsfloatval_t *g = NULL, *gp = NULL, *pg = NULL;
|
||||
lbfgsfloatval_t *d = NULL, *w = NULL, *pf = NULL;
|
||||
iteration_data_t *lm = NULL, *it = NULL;
|
||||
lbfgsfloatval_t ys, yy;
|
||||
lbfgsfloatval_t xnorm, gnorm, beta;
|
||||
@ -360,14 +362,22 @@ int lbfgs(
|
||||
xp = (lbfgsfloatval_t*)vecalloc(n * sizeof(lbfgsfloatval_t));
|
||||
g = (lbfgsfloatval_t*)vecalloc(n * sizeof(lbfgsfloatval_t));
|
||||
gp = (lbfgsfloatval_t*)vecalloc(n * sizeof(lbfgsfloatval_t));
|
||||
pg = (lbfgsfloatval_t*)vecalloc(n * sizeof(lbfgsfloatval_t));
|
||||
d = (lbfgsfloatval_t*)vecalloc(n * sizeof(lbfgsfloatval_t));
|
||||
w = (lbfgsfloatval_t*)vecalloc(n * sizeof(lbfgsfloatval_t));
|
||||
if (xp == NULL || g == NULL || gp == NULL || pg == NULL || d == NULL || w == NULL) {
|
||||
if (xp == NULL || g == NULL || gp == NULL || d == NULL || w == NULL) {
|
||||
ret = LBFGSERR_OUTOFMEMORY;
|
||||
goto lbfgs_exit;
|
||||
}
|
||||
|
||||
if (param.orthantwise_c != 0.) {
|
||||
/* Allocate working space for OW-LQN. */
|
||||
pg = (lbfgsfloatval_t*)vecalloc(n * sizeof(lbfgsfloatval_t));
|
||||
if (pg == NULL) {
|
||||
ret = LBFGSERR_OUTOFMEMORY;
|
||||
goto lbfgs_exit;
|
||||
}
|
||||
}
|
||||
|
||||
/* Allocate limited memory storage. */
|
||||
lm = (iteration_data_t*)vecalloc(m * sizeof(iteration_data_t));
|
||||
if (lm == NULL) {
|
||||
@ -399,7 +409,10 @@ int lbfgs(
|
||||
/* Compute the L1 norm of the variables and add it to the objective value. */
|
||||
xnorm = owlqn_x1norm(x, param.orthantwise_start, param.orthantwise_end);
|
||||
fx += xnorm * param.orthantwise_c;
|
||||
owlqn_pseudo_gradient(pg, x, g, n, param.orthantwise_c, param.orthantwise_start, param.orthantwise_end);
|
||||
owlqn_pseudo_gradient(
|
||||
pg, x, g, n,
|
||||
param.orthantwise_c, param.orthantwise_start, param.orthantwise_end
|
||||
);
|
||||
}
|
||||
|
||||
/* Store the initial value of the objective function. */
|
||||
@ -449,7 +462,10 @@ int lbfgs(
|
||||
ls = linesearch(n, x, &fx, g, d, &step, xp, gp, w, &cd, &param);
|
||||
} else {
|
||||
ls = linesearch(n, x, &fx, g, d, &step, xp, pg, w, &cd, &param);
|
||||
owlqn_pseudo_gradient(pg, x, g, n, param.orthantwise_c, param.orthantwise_start, param.orthantwise_end);
|
||||
owlqn_pseudo_gradient(
|
||||
pg, x, g, n,
|
||||
param.orthantwise_c, param.orthantwise_start, param.orthantwise_end
|
||||
);
|
||||
}
|
||||
if (ls < 0) {
|
||||
/* Revert to the previous point. */
|
||||
@ -576,12 +592,15 @@ int lbfgs(
|
||||
j = (j + 1) % m; /* if (++j == m) j = 0; */
|
||||
}
|
||||
|
||||
/*
|
||||
Constrain the search direction for orthant-wise updates.
|
||||
*/
|
||||
if (param.orthantwise_c != 0.) {
|
||||
vecdot(&gnorm, d, pg, n);
|
||||
if (gnorm >= 0) {
|
||||
vecncpy(gp, pg, n);
|
||||
owlqn_project(d, gp, param.orthantwise_start, param.orthantwise_end);
|
||||
}
|
||||
for (i = param.orthantwise_start;i < param.orthantwise_end;++i) {
|
||||
if (d[i] * pg[i] >= 0) {
|
||||
d[i] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
@ -606,9 +625,9 @@ lbfgs_exit:
|
||||
}
|
||||
vecfree(lm);
|
||||
}
|
||||
vecfree(pg);
|
||||
vecfree(w);
|
||||
vecfree(d);
|
||||
vecfree(pg);
|
||||
vecfree(gp);
|
||||
vecfree(g);
|
||||
vecfree(xp);
|
||||
@ -618,80 +637,6 @@ lbfgs_exit:
|
||||
|
||||
|
||||
|
||||
static int line_search_backtracking_owlqn(
|
||||
int n,
|
||||
lbfgsfloatval_t *x,
|
||||
lbfgsfloatval_t *f,
|
||||
lbfgsfloatval_t *g,
|
||||
lbfgsfloatval_t *s,
|
||||
lbfgsfloatval_t *stp,
|
||||
const lbfgsfloatval_t* xp,
|
||||
const lbfgsfloatval_t* gp,
|
||||
lbfgsfloatval_t *wp,
|
||||
callback_data_t *cd,
|
||||
const lbfgs_parameter_t *param
|
||||
)
|
||||
{
|
||||
int i, ret = 0, count = 0;
|
||||
lbfgsfloatval_t width = 0.5, norm = 0.;
|
||||
lbfgsfloatval_t finit = *f, dgtest;
|
||||
|
||||
/* Check the input parameters for errors. */
|
||||
if (*stp <= 0.) {
|
||||
return LBFGSERR_INVALIDPARAMETERS;
|
||||
}
|
||||
|
||||
/* Choose the orthant for the new point. */
|
||||
for (i = 0;i < n;++i) {
|
||||
wp[i] = (xp[i] == 0.) ? -gp[i] : xp[i];
|
||||
}
|
||||
|
||||
for (;;) {
|
||||
/* Update the current point. */
|
||||
veccpy(x, xp, n);
|
||||
vecadd(x, s, *stp, n);
|
||||
|
||||
/* The current point is projected onto the orthant. */
|
||||
owlqn_project(x, wp, param->orthantwise_start, param->orthantwise_end);
|
||||
|
||||
/* Evaluate the function and gradient values. */
|
||||
*f = cd->proc_evaluate(cd->instance, x, g, cd->n, *stp);
|
||||
|
||||
/* Compute the L1 norm of the variables and add it to the object value. */
|
||||
norm = owlqn_x1norm(x, param->orthantwise_start, param->orthantwise_end);
|
||||
*f += norm * param->orthantwise_c;
|
||||
|
||||
++count;
|
||||
|
||||
dgtest = 0.;
|
||||
for (i = 0;i < n;++i) {
|
||||
dgtest += (x[i] - xp[i]) * gp[i];
|
||||
}
|
||||
|
||||
if (*f <= finit + param->ftol * dgtest) {
|
||||
/* The sufficient decrease condition. */
|
||||
return count;
|
||||
}
|
||||
|
||||
if (*stp < param->min_step) {
|
||||
/* The step is the minimum value. */
|
||||
return LBFGSERR_MINIMUMSTEP;
|
||||
}
|
||||
if (*stp > param->max_step) {
|
||||
/* The step is the maximum value. */
|
||||
return LBFGSERR_MAXIMUMSTEP;
|
||||
}
|
||||
if (param->max_linesearch <= count) {
|
||||
/* Maximum number of iteration. */
|
||||
return LBFGSERR_MAXIMUMLINESEARCH;
|
||||
}
|
||||
|
||||
(*stp) *= width;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
static int line_search_backtracking(
|
||||
int n,
|
||||
lbfgsfloatval_t *x,
|
||||
@ -778,6 +723,80 @@ static int line_search_backtracking(
|
||||
|
||||
|
||||
|
||||
static int line_search_backtracking_owlqn(
|
||||
int n,
|
||||
lbfgsfloatval_t *x,
|
||||
lbfgsfloatval_t *f,
|
||||
lbfgsfloatval_t *g,
|
||||
lbfgsfloatval_t *s,
|
||||
lbfgsfloatval_t *stp,
|
||||
const lbfgsfloatval_t* xp,
|
||||
const lbfgsfloatval_t* gp,
|
||||
lbfgsfloatval_t *wp,
|
||||
callback_data_t *cd,
|
||||
const lbfgs_parameter_t *param
|
||||
)
|
||||
{
|
||||
int i, ret = 0, count = 0;
|
||||
lbfgsfloatval_t width = 0.5, norm = 0.;
|
||||
lbfgsfloatval_t finit = *f, dgtest;
|
||||
|
||||
/* Check the input parameters for errors. */
|
||||
if (*stp <= 0.) {
|
||||
return LBFGSERR_INVALIDPARAMETERS;
|
||||
}
|
||||
|
||||
/* Choose the orthant for the new point. */
|
||||
for (i = 0;i < n;++i) {
|
||||
wp[i] = (xp[i] == 0.) ? -gp[i] : xp[i];
|
||||
}
|
||||
|
||||
for (;;) {
|
||||
/* Update the current point. */
|
||||
veccpy(x, xp, n);
|
||||
vecadd(x, s, *stp, n);
|
||||
|
||||
/* The current point is projected onto the orthant. */
|
||||
owlqn_project(x, wp, param->orthantwise_start, param->orthantwise_end);
|
||||
|
||||
/* Evaluate the function and gradient values. */
|
||||
*f = cd->proc_evaluate(cd->instance, x, g, cd->n, *stp);
|
||||
|
||||
/* Compute the L1 norm of the variables and add it to the object value. */
|
||||
norm = owlqn_x1norm(x, param->orthantwise_start, param->orthantwise_end);
|
||||
*f += norm * param->orthantwise_c;
|
||||
|
||||
++count;
|
||||
|
||||
dgtest = 0.;
|
||||
for (i = 0;i < n;++i) {
|
||||
dgtest += (x[i] - xp[i]) * gp[i];
|
||||
}
|
||||
|
||||
if (*f <= finit + param->ftol * dgtest) {
|
||||
/* The sufficient decrease condition. */
|
||||
return count;
|
||||
}
|
||||
|
||||
if (*stp < param->min_step) {
|
||||
/* The step is the minimum value. */
|
||||
return LBFGSERR_MINIMUMSTEP;
|
||||
}
|
||||
if (*stp > param->max_step) {
|
||||
/* The step is the maximum value. */
|
||||
return LBFGSERR_MAXIMUMSTEP;
|
||||
}
|
||||
if (param->max_linesearch <= count) {
|
||||
/* Maximum number of iteration. */
|
||||
return LBFGSERR_MAXIMUMLINESEARCH;
|
||||
}
|
||||
|
||||
(*stp) *= width;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
static int line_search_morethuente(
|
||||
int n,
|
||||
lbfgsfloatval_t *x,
|
||||
|
Loading…
Reference in New Issue
Block a user