Constrain the search direction.
git-svn-id: file:///home/svnrepos/software/liblbfgs/trunk@49 ecf4c44f-38d1-4fa4-9757-a0b4dd0349fc
This commit is contained in:
		
							
								
								
									
										189
									
								
								lib/lbfgs.c
									
									
									
									
									
								
							
							
						
						
									
										189
									
								
								lib/lbfgs.c
									
									
									
									
									
								
							@@ -260,7 +260,9 @@ int lbfgs(
 | 
				
			|||||||
    lbfgs_parameter_t param = (_param != NULL) ? (*_param) : _defparam;
 | 
					    lbfgs_parameter_t param = (_param != NULL) ? (*_param) : _defparam;
 | 
				
			||||||
    const int m = param.m;
 | 
					    const int m = param.m;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    lbfgsfloatval_t *xp = NULL, *g = NULL, *gp = NULL, *pg = NULL, *d = NULL, *w = NULL, *pf = NULL;
 | 
					    lbfgsfloatval_t *xp = NULL;
 | 
				
			||||||
 | 
					    lbfgsfloatval_t *g = NULL, *gp = NULL, *pg = NULL;
 | 
				
			||||||
 | 
					    lbfgsfloatval_t *d = NULL, *w = NULL, *pf = NULL;
 | 
				
			||||||
    iteration_data_t *lm = NULL, *it = NULL;
 | 
					    iteration_data_t *lm = NULL, *it = NULL;
 | 
				
			||||||
    lbfgsfloatval_t ys, yy;
 | 
					    lbfgsfloatval_t ys, yy;
 | 
				
			||||||
    lbfgsfloatval_t xnorm, gnorm, beta;
 | 
					    lbfgsfloatval_t xnorm, gnorm, beta;
 | 
				
			||||||
@@ -360,14 +362,22 @@ int lbfgs(
 | 
				
			|||||||
    xp = (lbfgsfloatval_t*)vecalloc(n * sizeof(lbfgsfloatval_t));
 | 
					    xp = (lbfgsfloatval_t*)vecalloc(n * sizeof(lbfgsfloatval_t));
 | 
				
			||||||
    g = (lbfgsfloatval_t*)vecalloc(n * sizeof(lbfgsfloatval_t));
 | 
					    g = (lbfgsfloatval_t*)vecalloc(n * sizeof(lbfgsfloatval_t));
 | 
				
			||||||
    gp = (lbfgsfloatval_t*)vecalloc(n * sizeof(lbfgsfloatval_t));
 | 
					    gp = (lbfgsfloatval_t*)vecalloc(n * sizeof(lbfgsfloatval_t));
 | 
				
			||||||
    pg = (lbfgsfloatval_t*)vecalloc(n * sizeof(lbfgsfloatval_t));
 | 
					 | 
				
			||||||
    d = (lbfgsfloatval_t*)vecalloc(n * sizeof(lbfgsfloatval_t));
 | 
					    d = (lbfgsfloatval_t*)vecalloc(n * sizeof(lbfgsfloatval_t));
 | 
				
			||||||
    w = (lbfgsfloatval_t*)vecalloc(n * sizeof(lbfgsfloatval_t));
 | 
					    w = (lbfgsfloatval_t*)vecalloc(n * sizeof(lbfgsfloatval_t));
 | 
				
			||||||
    if (xp == NULL || g == NULL || gp == NULL || pg == NULL || d == NULL || w == NULL) {
 | 
					    if (xp == NULL || g == NULL || gp == NULL || d == NULL || w == NULL) {
 | 
				
			||||||
        ret = LBFGSERR_OUTOFMEMORY;
 | 
					        ret = LBFGSERR_OUTOFMEMORY;
 | 
				
			||||||
        goto lbfgs_exit;
 | 
					        goto lbfgs_exit;
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if (param.orthantwise_c != 0.) {
 | 
				
			||||||
 | 
					        /* Allocate working space for OW-LQN. */
 | 
				
			||||||
 | 
					        pg = (lbfgsfloatval_t*)vecalloc(n * sizeof(lbfgsfloatval_t));
 | 
				
			||||||
 | 
					        if (pg == NULL) {
 | 
				
			||||||
 | 
					            ret = LBFGSERR_OUTOFMEMORY;
 | 
				
			||||||
 | 
					            goto lbfgs_exit;
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    /* Allocate limited memory storage. */
 | 
					    /* Allocate limited memory storage. */
 | 
				
			||||||
    lm = (iteration_data_t*)vecalloc(m * sizeof(iteration_data_t));
 | 
					    lm = (iteration_data_t*)vecalloc(m * sizeof(iteration_data_t));
 | 
				
			||||||
    if (lm == NULL) {
 | 
					    if (lm == NULL) {
 | 
				
			||||||
@@ -399,7 +409,10 @@ int lbfgs(
 | 
				
			|||||||
        /* Compute the L1 norm of the variable and add it to the object value. */
 | 
					        /* Compute the L1 norm of the variable and add it to the object value. */
 | 
				
			||||||
        xnorm = owlqn_x1norm(x, param.orthantwise_start, param.orthantwise_end);
 | 
					        xnorm = owlqn_x1norm(x, param.orthantwise_start, param.orthantwise_end);
 | 
				
			||||||
        fx += xnorm * param.orthantwise_c;
 | 
					        fx += xnorm * param.orthantwise_c;
 | 
				
			||||||
        owlqn_pseudo_gradient(pg, x, g, n, param.orthantwise_c, param.orthantwise_start, param.orthantwise_end);
 | 
					        owlqn_pseudo_gradient(
 | 
				
			||||||
 | 
					            pg, x, g, n,
 | 
				
			||||||
 | 
					            param.orthantwise_c, param.orthantwise_start, param.orthantwise_end
 | 
				
			||||||
 | 
					            );
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    /* Store the initial value of the objective function. */
 | 
					    /* Store the initial value of the objective function. */
 | 
				
			||||||
@@ -449,7 +462,10 @@ int lbfgs(
 | 
				
			|||||||
            ls = linesearch(n, x, &fx, g, d, &step, xp, gp, w, &cd, ¶m);
 | 
					            ls = linesearch(n, x, &fx, g, d, &step, xp, gp, w, &cd, ¶m);
 | 
				
			||||||
        } else {
 | 
					        } else {
 | 
				
			||||||
            ls = linesearch(n, x, &fx, g, d, &step, xp, pg, w, &cd, ¶m);
 | 
					            ls = linesearch(n, x, &fx, g, d, &step, xp, pg, w, &cd, ¶m);
 | 
				
			||||||
            owlqn_pseudo_gradient(pg, x, g, n, param.orthantwise_c, param.orthantwise_start, param.orthantwise_end);
 | 
					            owlqn_pseudo_gradient(
 | 
				
			||||||
 | 
					                pg, x, g, n,
 | 
				
			||||||
 | 
					                param.orthantwise_c, param.orthantwise_start, param.orthantwise_end
 | 
				
			||||||
 | 
					                );
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
        if (ls < 0) {
 | 
					        if (ls < 0) {
 | 
				
			||||||
            /* Revert to the previous point. */
 | 
					            /* Revert to the previous point. */
 | 
				
			||||||
@@ -576,12 +592,15 @@ int lbfgs(
 | 
				
			|||||||
            j = (j + 1) % m;        /* if (++j == m) j = 0; */
 | 
					            j = (j + 1) % m;        /* if (++j == m) j = 0; */
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        /*
 | 
				
			||||||
 | 
					            Constrain the search direction for orthant-wise updates.
 | 
				
			||||||
 | 
					         */
 | 
				
			||||||
        if (param.orthantwise_c != 0.) {
 | 
					        if (param.orthantwise_c != 0.) {
 | 
				
			||||||
            vecdot(&gnorm, d, pg, n);
 | 
					            for (i = param.orthantwise_start;i < param.orthantwise_end;++i) {
 | 
				
			||||||
            if (gnorm >= 0) {
 | 
								    if (d[i] * pg[i] >= 0) {
 | 
				
			||||||
                vecncpy(gp, pg, n);
 | 
									    d[i] = 0;
 | 
				
			||||||
                owlqn_project(d, gp, param.orthantwise_start, param.orthantwise_end);
 | 
								    }
 | 
				
			||||||
            }
 | 
							    }
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        /*
 | 
					        /*
 | 
				
			||||||
@@ -606,9 +625,9 @@ lbfgs_exit:
 | 
				
			|||||||
        }
 | 
					        }
 | 
				
			||||||
        vecfree(lm);
 | 
					        vecfree(lm);
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					    vecfree(pg);
 | 
				
			||||||
    vecfree(w);
 | 
					    vecfree(w);
 | 
				
			||||||
    vecfree(d);
 | 
					    vecfree(d);
 | 
				
			||||||
    vecfree(pg);
 | 
					 | 
				
			||||||
    vecfree(gp);
 | 
					    vecfree(gp);
 | 
				
			||||||
    vecfree(g);
 | 
					    vecfree(g);
 | 
				
			||||||
    vecfree(xp);
 | 
					    vecfree(xp);
 | 
				
			||||||
@@ -618,80 +637,6 @@ lbfgs_exit:
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
static int line_search_backtracking_owlqn(
 | 
					 | 
				
			||||||
    int n,
 | 
					 | 
				
			||||||
    lbfgsfloatval_t *x,
 | 
					 | 
				
			||||||
    lbfgsfloatval_t *f,
 | 
					 | 
				
			||||||
    lbfgsfloatval_t *g,
 | 
					 | 
				
			||||||
    lbfgsfloatval_t *s,
 | 
					 | 
				
			||||||
    lbfgsfloatval_t *stp,
 | 
					 | 
				
			||||||
    const lbfgsfloatval_t* xp,
 | 
					 | 
				
			||||||
    const lbfgsfloatval_t* gp,
 | 
					 | 
				
			||||||
    lbfgsfloatval_t *wp,
 | 
					 | 
				
			||||||
    callback_data_t *cd,
 | 
					 | 
				
			||||||
    const lbfgs_parameter_t *param
 | 
					 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
{
 | 
					 | 
				
			||||||
    int i, ret = 0, count = 0;
 | 
					 | 
				
			||||||
    lbfgsfloatval_t width = 0.5, norm = 0.;
 | 
					 | 
				
			||||||
    lbfgsfloatval_t finit = *f, dgtest;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    /* Check the input parameters for errors. */
 | 
					 | 
				
			||||||
    if (*stp <= 0.) {
 | 
					 | 
				
			||||||
        return LBFGSERR_INVALIDPARAMETERS;
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    /* Choose the orthant for the new point. */
 | 
					 | 
				
			||||||
    for (i = 0;i < n;++i) {
 | 
					 | 
				
			||||||
        wp[i] = (xp[i] == 0.) ? -gp[i] : xp[i];
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    for (;;) {
 | 
					 | 
				
			||||||
        /* Update the current point. */
 | 
					 | 
				
			||||||
        veccpy(x, xp, n);
 | 
					 | 
				
			||||||
        vecadd(x, s, *stp, n);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        /* The current point is projected onto the orthant. */
 | 
					 | 
				
			||||||
        owlqn_project(x, wp, param->orthantwise_start, param->orthantwise_end);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        /* Evaluate the function and gradient values. */
 | 
					 | 
				
			||||||
        *f = cd->proc_evaluate(cd->instance, x, g, cd->n, *stp);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        /* Compute the L1 norm of the variables and add it to the object value. */
 | 
					 | 
				
			||||||
        norm = owlqn_x1norm(x, param->orthantwise_start, param->orthantwise_end);
 | 
					 | 
				
			||||||
        *f += norm * param->orthantwise_c;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        ++count;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        dgtest = 0.;
 | 
					 | 
				
			||||||
        for (i = 0;i < n;++i) {
 | 
					 | 
				
			||||||
            dgtest += (x[i] - xp[i]) * gp[i];
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if (*f <= finit + param->ftol * dgtest) {
 | 
					 | 
				
			||||||
            /* The sufficient decrease condition. */
 | 
					 | 
				
			||||||
            return count;
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if (*stp < param->min_step) {
 | 
					 | 
				
			||||||
            /* The step is the minimum value. */
 | 
					 | 
				
			||||||
            return LBFGSERR_MINIMUMSTEP;
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
        if (*stp > param->max_step) {
 | 
					 | 
				
			||||||
            /* The step is the maximum value. */
 | 
					 | 
				
			||||||
            return LBFGSERR_MAXIMUMSTEP;
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
        if (param->max_linesearch <= count) {
 | 
					 | 
				
			||||||
            /* Maximum number of iteration. */
 | 
					 | 
				
			||||||
            return LBFGSERR_MAXIMUMLINESEARCH;
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        (*stp) *= width;
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
static int line_search_backtracking(
 | 
					static int line_search_backtracking(
 | 
				
			||||||
    int n,
 | 
					    int n,
 | 
				
			||||||
    lbfgsfloatval_t *x,
 | 
					    lbfgsfloatval_t *x,
 | 
				
			||||||
@@ -778,6 +723,80 @@ static int line_search_backtracking(
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					static int line_search_backtracking_owlqn(
 | 
				
			||||||
 | 
					    int n,
 | 
				
			||||||
 | 
					    lbfgsfloatval_t *x,
 | 
				
			||||||
 | 
					    lbfgsfloatval_t *f,
 | 
				
			||||||
 | 
					    lbfgsfloatval_t *g,
 | 
				
			||||||
 | 
					    lbfgsfloatval_t *s,
 | 
				
			||||||
 | 
					    lbfgsfloatval_t *stp,
 | 
				
			||||||
 | 
					    const lbfgsfloatval_t* xp,
 | 
				
			||||||
 | 
					    const lbfgsfloatval_t* gp,
 | 
				
			||||||
 | 
					    lbfgsfloatval_t *wp,
 | 
				
			||||||
 | 
					    callback_data_t *cd,
 | 
				
			||||||
 | 
					    const lbfgs_parameter_t *param
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					{
 | 
				
			||||||
 | 
					    int i, ret = 0, count = 0;
 | 
				
			||||||
 | 
					    lbfgsfloatval_t width = 0.5, norm = 0.;
 | 
				
			||||||
 | 
					    lbfgsfloatval_t finit = *f, dgtest;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    /* Check the input parameters for errors. */
 | 
				
			||||||
 | 
					    if (*stp <= 0.) {
 | 
				
			||||||
 | 
					        return LBFGSERR_INVALIDPARAMETERS;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    /* Choose the orthant for the new point. */
 | 
				
			||||||
 | 
					    for (i = 0;i < n;++i) {
 | 
				
			||||||
 | 
					        wp[i] = (xp[i] == 0.) ? -gp[i] : xp[i];
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    for (;;) {
 | 
				
			||||||
 | 
					        /* Update the current point. */
 | 
				
			||||||
 | 
					        veccpy(x, xp, n);
 | 
				
			||||||
 | 
					        vecadd(x, s, *stp, n);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        /* The current point is projected onto the orthant. */
 | 
				
			||||||
 | 
					        owlqn_project(x, wp, param->orthantwise_start, param->orthantwise_end);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        /* Evaluate the function and gradient values. */
 | 
				
			||||||
 | 
					        *f = cd->proc_evaluate(cd->instance, x, g, cd->n, *stp);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        /* Compute the L1 norm of the variables and add it to the object value. */
 | 
				
			||||||
 | 
					        norm = owlqn_x1norm(x, param->orthantwise_start, param->orthantwise_end);
 | 
				
			||||||
 | 
					        *f += norm * param->orthantwise_c;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        ++count;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        dgtest = 0.;
 | 
				
			||||||
 | 
					        for (i = 0;i < n;++i) {
 | 
				
			||||||
 | 
					            dgtest += (x[i] - xp[i]) * gp[i];
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if (*f <= finit + param->ftol * dgtest) {
 | 
				
			||||||
 | 
					            /* The sufficient decrease condition. */
 | 
				
			||||||
 | 
					            return count;
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if (*stp < param->min_step) {
 | 
				
			||||||
 | 
					            /* The step is the minimum value. */
 | 
				
			||||||
 | 
					            return LBFGSERR_MINIMUMSTEP;
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        if (*stp > param->max_step) {
 | 
				
			||||||
 | 
					            /* The step is the maximum value. */
 | 
				
			||||||
 | 
					            return LBFGSERR_MAXIMUMSTEP;
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        if (param->max_linesearch <= count) {
 | 
				
			||||||
 | 
					            /* Maximum number of iteration. */
 | 
				
			||||||
 | 
					            return LBFGSERR_MAXIMUMLINESEARCH;
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        (*stp) *= width;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
static int line_search_morethuente(
 | 
					static int line_search_morethuente(
 | 
				
			||||||
    int n,
 | 
					    int n,
 | 
				
			||||||
    lbfgsfloatval_t *x,
 | 
					    lbfgsfloatval_t *x,
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user