This commit is contained in:
2022-03-20 10:44:52 +08:00
parent 24235f44ca
commit f27bc17948
4 changed files with 123 additions and 3 deletions

View File

@@ -148,6 +148,20 @@ static int line_search_backtracking(
const lbfgs_parameter_t *param
);
static int line_search_backtracking_quad(
int n,
lbfgsfloatval_t *x,
lbfgsfloatval_t *f,
lbfgsfloatval_t *g,
lbfgsfloatval_t *s,
lbfgsfloatval_t *stp,
const lbfgsfloatval_t* xp,
const lbfgsfloatval_t* gp,
lbfgsfloatval_t *wa,
callback_data_t *cd,
const lbfgs_parameter_t *param
);
static int line_search_backtracking_owlqn(
int n,
lbfgsfloatval_t *x,
@@ -370,6 +384,11 @@ int lbfgs(
case LBFGS_LINESEARCH_BACKTRACKING_STRONG_WOLFE:
linesearch = line_search_backtracking;
break;
case LBFGS_LINESEARCH_BACKTRACKING_ARMIJO_QUAD:
//case LBFGS_LINESEARCH_BACKTRACKING_WOLFE_QUAD:
//case LBFGS_LINESEARCH_BACKTRACKING_STRONG_WOLFE_QUAD:
linesearch = line_search_backtracking_quad;
break;
default:
return LBFGSERR_INVALID_LINESEARCH;
}
@@ -945,6 +964,91 @@ static int line_search_backtracking(
}
}
static int line_search_backtracking_quad(
int n,
lbfgsfloatval_t *x,
lbfgsfloatval_t *f,
lbfgsfloatval_t *g,
lbfgsfloatval_t *s,
lbfgsfloatval_t *stp,
const lbfgsfloatval_t* xp,
const lbfgsfloatval_t* gp,
lbfgsfloatval_t *wp,
callback_data_t *cd,
const lbfgs_parameter_t *param
)
{
int count = 0;
lbfgsfloatval_t dg, stp2;
lbfgsfloatval_t finit, dginit = 0., dgtest;
const lbfgsfloatval_t dec = 0.5;
/* Check the input parameters for errors. */
if (*stp <= 0.) {
return LBFGSERR_INVALIDPARAMETERS;
}
/* Compute the initial gradient in the search direction. */
vecdot(&dginit, g, s, n); //计算点积 g为梯度方向 s为下降方向
/* Make sure that s points to a descent direction. */
if (0 < dginit) {
return LBFGSERR_INCREASEGRADIENT;
}
/* The initial value of the objective function. */
finit = *f;
dgtest = param->ftol * dginit; // ftol 大概为 function tolerance
for (;;) {
veccpy(x, xp, n);
vecadd(x, s, *stp, n); // vecadd x += (*stp)*s
/* Evaluate the function and gradient values. */
// 这里我们发现的cd的用法即传递函数指针
*f = cd->proc_evaluate(cd->instance, x, g, cd->n, *stp);
++count;
// 充分下降条件
if (*f > finit + *stp * dgtest) {
stp2 = 0.5*dginit*(*stp)*(*stp)/(finit - (*f) + dginit*(*stp));
if (stp2 < 0) {
(*stp) *= dec;
}
else {
(*stp) = stp2;
}
} else {
// 充分下降条件满足并搜索方法为backtracking搜索条件为Armijo则可以退出了。否则更新步长继续搜索。
/* The sufficient decrease condition (Armijo condition). */
if (param->linesearch == LBFGS_LINESEARCH_BACKTRACKING_ARMIJO_QUAD) {
/* Exit with the Armijo condition. */
return count;
}
}
// 以下情况返回的步长不能保证满足搜索条件
if (*stp < param->min_step) {
/* The step is the minimum value. */
// 退出 此时步长小于最小步长
return LBFGSERR_MINIMUMSTEP;
}
if (*stp > param->max_step) {
/* The step is the maximum value. */
// 退出 此时步长大于最大步长
return LBFGSERR_MAXIMUMSTEP;
}
if (param->max_linesearch <= count) {
/* Maximum number of iteration. */
// 退出 线性搜索次数超过了最大限制
return LBFGSERR_MAXIMUMLINESEARCH;
}
}
}
// 还是反向搜索 只是添加了L1模方向
static int line_search_backtracking_owlqn(
int n,