diff --git a/lib/lbfgs.c b/lib/lbfgs.c index e9e587b..ee638be 100644 --- a/lib/lbfgs.c +++ b/lib/lbfgs.c @@ -169,6 +169,15 @@ static int update_trial_interval( int *brackt ); +static lbfgsfloatval_t orthantwise_gnorm( + const lbfgsfloatval_t* x, + const lbfgsfloatval_t* g, + const lbfgsfloatval_t c, + const int start, + const int n + ); + + #if defined(USE_SSE) && (defined(__SSE__) || defined(__SSE2__)) static int round_out_variables(int n) { @@ -357,8 +366,12 @@ int lbfgs( /* Make sure that the initial variables are not a minimizer. */ - vecnorm(&gnorm, g, n); vecnorm(&xnorm, x, n); + if (param->orthantwise_c != 0.) { + gnorm = orthantwise_gnorm(x, g, param->orthantwise_c, param->orthantwise_start, n); + } else { + vecnorm(&gnorm, g, n); + } if (xnorm < 1.0) xnorm = 1.0; if (gnorm / xnorm <= param->epsilon) { ret = LBFGS_ALREADY_MINIMIZED; @@ -385,8 +398,12 @@ int lbfgs( } /* Compute x and g norms. */ - vecnorm(&gnorm, g, n); vecnorm(&xnorm, x, n); + if (param->orthantwise_c != 0.) { + gnorm = orthantwise_gnorm(x, g, param->orthantwise_c, param->orthantwise_start, n); + } else { + vecnorm(&gnorm, g, n); + } /* Report the progress. */ if (cd.proc_progress) { @@ -1217,3 +1234,32 @@ static int update_trial_interval( *t = newt; return 0; } + +static lbfgsfloatval_t orthantwise_gnorm( + const lbfgsfloatval_t* x, + const lbfgsfloatval_t* g, + const lbfgsfloatval_t c, + const int start, + const int n + ) +{ + int i; + lbfgsfloatval_t d = 0.; + lbfgsfloatval_t norm = 0.; + + for (i = 0;i < start;++i) { + norm += g[i] * g[i]; + } + + for (i = start;i < n;++i) { + d = g[i]; + if (x[i] < 0.) { + d -= c; + } else if (0. < x[i]) { + d += c; + } + norm += d * d; + } + + return sqrt(norm); +}