From f27bc179484129d1a368c179fc8067a3d150bedc Mon Sep 17 00:00:00 2001 From: zhangyi Date: Sun, 20 Mar 2022 10:44:52 +0800 Subject: [PATCH] update --- src/lib/lbfgs.c | 104 +++++++++++++++++++++++++++++++++++++++++ src/lib/lbfgs.h | 8 ++++ src/sample/sample2.cpp | 12 +++-- src/sample/sample4.cpp | 2 + 4 files changed, 123 insertions(+), 3 deletions(-) diff --git a/src/lib/lbfgs.c b/src/lib/lbfgs.c index 12973c2..0db9023 100644 --- a/src/lib/lbfgs.c +++ b/src/lib/lbfgs.c @@ -148,6 +148,20 @@ static int line_search_backtracking( const lbfgs_parameter_t *param ); +static int line_search_backtracking_quad( + int n, + lbfgsfloatval_t *x, + lbfgsfloatval_t *f, + lbfgsfloatval_t *g, + lbfgsfloatval_t *s, + lbfgsfloatval_t *stp, + const lbfgsfloatval_t* xp, + const lbfgsfloatval_t* gp, + lbfgsfloatval_t *wa, + callback_data_t *cd, + const lbfgs_parameter_t *param + ); + static int line_search_backtracking_owlqn( int n, lbfgsfloatval_t *x, @@ -370,6 +384,11 @@ int lbfgs( case LBFGS_LINESEARCH_BACKTRACKING_STRONG_WOLFE: linesearch = line_search_backtracking; break; + case LBFGS_LINESEARCH_BACKTRACKING_ARMIJO_QUAD: + //case LBFGS_LINESEARCH_BACKTRACKING_WOLFE_QUAD: + //case LBFGS_LINESEARCH_BACKTRACKING_STRONG_WOLFE_QUAD: + linesearch = line_search_backtracking_quad; + break; default: return LBFGSERR_INVALID_LINESEARCH; } @@ -945,6 +964,91 @@ static int line_search_backtracking( } } +static int line_search_backtracking_quad( + int n, + lbfgsfloatval_t *x, + lbfgsfloatval_t *f, + lbfgsfloatval_t *g, + lbfgsfloatval_t *s, + lbfgsfloatval_t *stp, + const lbfgsfloatval_t* xp, + const lbfgsfloatval_t* gp, + lbfgsfloatval_t *wp, + callback_data_t *cd, + const lbfgs_parameter_t *param + ) +{ + int count = 0; + lbfgsfloatval_t dg, stp2; + lbfgsfloatval_t finit, dginit = 0., dgtest; + const lbfgsfloatval_t dec = 0.5; + + /* Check the input parameters for errors. */ + if (*stp <= 0.) { + return LBFGSERR_INVALIDPARAMETERS; + } + + /* Compute the initial gradient in the search direction. */ + vecdot(&dginit, g, s, n); //计算点积 g为梯度方向 s为下降方向 + + /* Make sure that s points to a descent direction. */ + if (0 < dginit) { + return LBFGSERR_INCREASEGRADIENT; + } + + /* The initial value of the objective function. */ + finit = *f; + dgtest = param->ftol * dginit; // ftol 大概为 function tolerance + + for (;;) { + veccpy(x, xp, n); + vecadd(x, s, *stp, n); // vecadd x += (*stp)*s + + /* Evaluate the function and gradient values. */ + // 这里我们发现的cd的用法,即传递函数指针 + *f = cd->proc_evaluate(cd->instance, x, g, cd->n, *stp); + + ++count; + + // 充分下降条件 + if (*f > finit + *stp * dgtest) { + stp2 = 0.5*dginit*(*stp)*(*stp)/(finit - (*f) + dginit*(*stp)); + if (stp2 < 0) { + (*stp) *= dec; + } + else { + (*stp) = stp2; + } + + } else { + // 充分下降条件满足并搜索方法为backtracking,搜索条件为Armijo,则可以退出了。否则更新步长,继续搜索。 + /* The sufficient decrease condition (Armijo condition). */ + if (param->linesearch == LBFGS_LINESEARCH_BACKTRACKING_ARMIJO_QUAD) { + /* Exit with the Armijo condition. */ + return count; + } + + } + + // 以下情况返回的步长不能保证满足搜索条件 + if (*stp < param->min_step) { + /* The step is the minimum value. */ + // 退出 此时步长小于最小步长 + return LBFGSERR_MINIMUMSTEP; + } + if (*stp > param->max_step) { + /* The step is the maximum value. */ + // 退出 此时步长大于最大步长 + return LBFGSERR_MAXIMUMSTEP; + } + if (param->max_linesearch <= count) { + /* Maximum number of iteration. */ + // 退出 线性搜索次数超过了最大限制 + return LBFGSERR_MAXIMUMLINESEARCH; + } + } +} + // 还是反向搜索 只是添加了L1模方向 static int line_search_backtracking_owlqn( int n, diff --git a/src/lib/lbfgs.h b/src/lib/lbfgs.h index 4df1ee0..74e474a 100644 --- a/src/lib/lbfgs.h +++ b/src/lib/lbfgs.h @@ -197,6 +197,14 @@ enum { * a is the step length. */ LBFGS_LINESEARCH_BACKTRACKING_STRONG_WOLFE = 3, + + LBFGS_LINESEARCH_BACKTRACKING_ARMIJO_QUAD = 4, + + //LBFGS_LINESEARCH_BACKTRACKING_QUAD = 5, + + //LBFGS_LINESEARCH_BACKTRACKING_WOLFE_QUAD = 5, + + //LBFGS_LINESEARCH_BACKTRACKING_STRONG_WOLFE_QUAD = 6, }; // L-BFGS参数类型。参数很多,简要说明如下: diff --git a/src/sample/sample2.cpp b/src/sample/sample2.cpp index e0ccc6b..e9facd1 100644 --- a/src/sample/sample2.cpp +++ b/src/sample/sample2.cpp @@ -39,10 +39,16 @@ public: Start the L-BFGS optimization; this will invoke the callback functions evaluate() and progress() when necessary. */ - int ret = lbfgs(N, m_x, &fx, _evaluate, _progress, this, NULL, NULL); + lbfgs_parameter_t self_para; + lbfgs_parameter_init(&self_para); + self_para.epsilon = 1e-7; + //self_para.linesearch = LBFGS_LINESEARCH_BACKTRACKING_ARMIJO_QUAD; + //self_para.ftol = 1e-3; + + int ret = lbfgs(N, m_x, &fx, _evaluate, _progress, this, &self_para, NULL); /* Report the result. */ - printf("L-BFGS optimization terminated with status code = %d\n", ret); + printf("L-BFGS optimization terminated with status: %s\n", lbfgs_strerror(ret)); printf(" fx = %f, x[0] = %f, x[1] = %f\n", fx, m_x[0], m_x[1]); return ret; @@ -111,7 +117,7 @@ protected: { printf("Iteration %d:\n", k); printf(" fx = %f, x[0] = %f, x[1] = %f\n", fx, x[0], x[1]); - printf(" xnorm = %f, gnorm = %f, step = %f\n", xnorm, gnorm, step); + printf(" xnorm = %f, gnorm = %f, step = %f, convergence = %e\n", xnorm, gnorm, step, gnorm/xnorm); printf("\n"); return 0; } diff --git a/src/sample/sample4.cpp b/src/sample/sample4.cpp index 9af17ee..6347ea1 100644 --- a/src/sample/sample4.cpp +++ b/src/sample/sample4.cpp @@ -186,6 +186,8 @@ int TEST_FUNC::Routine() lbfgs_parameter_t self_para; lbfgs_parameter_init(&self_para); self_para.epsilon = 1e-7; + self_para.linesearch = LBFGS_LINESEARCH_BACKTRACKING_ARMIJO_QUAD; + self_para.ftol = 1e-3; int ret = lbfgs(N, x, &fx, _Func, _Progress, this, &self_para, _Precondition);