- Added some note on SSE/SEE2 optimization.
- Supported SSE/SSE2 optimization with GCC. git-svn-id: file:///home/svnrepos/software/liblbfgs/trunk@3 ecf4c44f-38d1-4fa4-9757-a0b4dd0349fc
This commit is contained in:
85
lib/lbfgs.c
85
lib/lbfgs.c
@@ -76,14 +76,14 @@ licence.
|
||||
typedef unsigned int uint32_t;
|
||||
#endif/*_MSC_VER*/
|
||||
|
||||
#if defined(USE_SSE) && defined(__SSE__) && LBFGS_FLOAT == 32
|
||||
/* Use SSE optimization for 32bit float precision. */
|
||||
#include "arithmetic_sse_float.h"
|
||||
|
||||
#elif defined(USE_SSE) && defined(__SSE__) && LBFGS_FLOAT == 64
|
||||
#if defined(USE_SSE) && defined(__SSE2__) && LBFGS_FLOAT == 64
|
||||
/* Use SSE2 optimization for 64bit double precision. */
|
||||
#include "arithmetic_sse_double.h"
|
||||
|
||||
#elif defined(USE_SSE) && defined(__SSE__) && LBFGS_FLOAT == 32
|
||||
/* Use SSE optimization for 32bit float precision. */
|
||||
#include "arithmetic_sse_float.h"
|
||||
|
||||
#else
|
||||
/* No CPU specific optimization. */
|
||||
#include "arithmetic_ansi.h"
|
||||
@@ -94,6 +94,14 @@ typedef unsigned int uint32_t;
|
||||
#define max2(a, b) ((a) >= (b) ? (a) : (b))
|
||||
#define max3(a, b, c) max2(max2((a), (b)), (c));
|
||||
|
||||
struct tag_callback_data {
|
||||
int n;
|
||||
void *instance;
|
||||
lbfgs_evaluate_t proc_evaluate;
|
||||
lbfgs_progress_t proc_progress;
|
||||
};
|
||||
typedef struct tag_callback_data callback_data_t;
|
||||
|
||||
struct tag_iteration_data {
|
||||
lbfgsfloatval_t alpha;
|
||||
lbfgsfloatval_t *s; /* [n] */
|
||||
@@ -118,8 +126,7 @@ typedef int (*line_search_proc)(
|
||||
lbfgsfloatval_t *s,
|
||||
lbfgsfloatval_t *stp,
|
||||
lbfgsfloatval_t *wa,
|
||||
lbfgs_evaluate_t proc_evaluate,
|
||||
void *instance,
|
||||
callback_data_t *cd,
|
||||
const lbfgs_parameter_t *param
|
||||
);
|
||||
|
||||
@@ -131,8 +138,7 @@ static int line_search_backtracking(
|
||||
lbfgsfloatval_t *s,
|
||||
lbfgsfloatval_t *stp,
|
||||
lbfgsfloatval_t *wa,
|
||||
lbfgs_evaluate_t proc_evaluate,
|
||||
void *instance,
|
||||
callback_data_t *cd,
|
||||
const lbfgs_parameter_t *param
|
||||
);
|
||||
|
||||
@@ -144,8 +150,7 @@ static int line_search_morethuente(
|
||||
lbfgsfloatval_t *s,
|
||||
lbfgsfloatval_t *stp,
|
||||
lbfgsfloatval_t *wa,
|
||||
lbfgs_evaluate_t proc_evaluate,
|
||||
void *instance,
|
||||
callback_data_t *cd,
|
||||
const lbfgs_parameter_t *param
|
||||
);
|
||||
|
||||
@@ -164,8 +169,26 @@ static int update_trial_interval(
|
||||
int *brackt
|
||||
);
|
||||
|
||||
static int round_out_variables(int n)
|
||||
{
|
||||
n += 7;
|
||||
n /= 8;
|
||||
n *= 8;
|
||||
return n;
|
||||
}
|
||||
|
||||
lbfgsfloatval_t* lbfgs_malloc(int n)
|
||||
{
|
||||
#if defined(USE_SSE) && (defined(__SSE__) || defined(__SSE2__))
|
||||
n = round_out_variables(n);
|
||||
#endif/*defined(USE_SSE)*/
|
||||
return vecalloc(sizeof(lbfgsfloatval_t) * n);
|
||||
}
|
||||
|
||||
void lbfgs_free(lbfgsfloatval_t *x)
|
||||
{
|
||||
vecfree(x);
|
||||
}
|
||||
|
||||
void lbfgs_parameter_init(lbfgs_parameter_t *param)
|
||||
{
|
||||
@@ -173,7 +196,7 @@ void lbfgs_parameter_init(lbfgs_parameter_t *param)
|
||||
}
|
||||
|
||||
int lbfgs(
|
||||
const int n,
|
||||
int n,
|
||||
lbfgsfloatval_t *x,
|
||||
lbfgsfloatval_t *ptr_fx,
|
||||
lbfgs_evaluate_t proc_evaluate,
|
||||
@@ -197,15 +220,30 @@ int lbfgs(
|
||||
lbfgsfloatval_t fx = 0.;
|
||||
line_search_proc linesearch = line_search_morethuente;
|
||||
|
||||
/* Construct a callback data. */
|
||||
callback_data_t cd;
|
||||
cd.n = n;
|
||||
cd.instance = instance;
|
||||
cd.proc_evaluate = proc_evaluate;
|
||||
cd.proc_progress = proc_progress;
|
||||
|
||||
#if defined(USE_SSE) && (defined(__SSE__) || defined(__SSE2__))
|
||||
/* Round out the number of variables. */
|
||||
n = round_out_variables(n);
|
||||
#endif/*defined(USE_SSE)*/
|
||||
|
||||
/* Check the input parameters for errors. */
|
||||
if (n <= 0) {
|
||||
return LBFGSERR_INVALID_N;
|
||||
}
|
||||
#if defined(USE_SSE) && defined(__SSE__)
|
||||
#if defined(USE_SSE) && (defined(__SSE__) || defined(__SSE2__))
|
||||
if (n % 8 != 0) {
|
||||
return LBFGSERR_INVALID_N_SSE;
|
||||
}
|
||||
#endif/*defined(__SSE__)*/
|
||||
if (((unsigned short)x & 0x000F) != 0) {
|
||||
return LBFGSERR_INVALID_X_SSE;
|
||||
}
|
||||
#endif/*defined(USE_SSE)*/
|
||||
if (param->min_step < 0.) {
|
||||
return LBFGSERR_INVALID_MINSTEP;
|
||||
}
|
||||
@@ -270,7 +308,7 @@ int lbfgs(
|
||||
}
|
||||
|
||||
/* Evaluate the function value and its gradient. */
|
||||
fx = proc_evaluate(instance, x, g, n, 0);
|
||||
fx = cd.proc_evaluate(cd.instance, x, g, cd.n, 0);
|
||||
if (0. < param->orthantwise_c) {
|
||||
/* Compute L1-regularization factor and add it to the object value. */
|
||||
norm = 0.;
|
||||
@@ -319,8 +357,7 @@ int lbfgs(
|
||||
veccpy(gp, g, n);
|
||||
|
||||
/* Search for an optimal step. */
|
||||
ls = linesearch(
|
||||
n, x, &fx, g, d, &step, w, proc_evaluate, instance, param);
|
||||
ls = linesearch(n, x, &fx, g, d, &step, w, &cd, param);
|
||||
if (ls < 0) {
|
||||
ret = ls;
|
||||
goto lbfgs_exit;
|
||||
@@ -331,8 +368,8 @@ int lbfgs(
|
||||
vecnorm(&xnorm, x, n);
|
||||
|
||||
/* Report the progress. */
|
||||
if (proc_progress) {
|
||||
if (ret = proc_progress(instance, x, g, fx, xnorm, gnorm, step, n, k, ls)) {
|
||||
if (cd.proc_progress) {
|
||||
if (ret = cd.proc_progress(cd.instance, x, g, fx, xnorm, gnorm, step, cd.n, k, ls)) {
|
||||
goto lbfgs_exit;
|
||||
}
|
||||
}
|
||||
@@ -487,8 +524,7 @@ static int line_search_backtracking(
|
||||
lbfgsfloatval_t *s,
|
||||
lbfgsfloatval_t *stp,
|
||||
lbfgsfloatval_t *xp,
|
||||
lbfgs_evaluate_t proc_evaluate,
|
||||
void *instance,
|
||||
callback_data_t *cd,
|
||||
const lbfgs_parameter_t *param
|
||||
)
|
||||
{
|
||||
@@ -556,7 +592,7 @@ static int line_search_backtracking(
|
||||
}
|
||||
|
||||
/* Evaluate the function and gradient values. */
|
||||
*f = proc_evaluate(instance, x, g, n, *stp);
|
||||
*f = cd->proc_evaluate(cd->instance, x, g, cd->n, *stp);
|
||||
if (0. < param->orthantwise_c) {
|
||||
/* Compute L1-regularization factor and add it to the object value. */
|
||||
norm = 0.;
|
||||
@@ -601,8 +637,7 @@ static int line_search_morethuente(
|
||||
lbfgsfloatval_t *s,
|
||||
lbfgsfloatval_t *stp,
|
||||
lbfgsfloatval_t *wa,
|
||||
lbfgs_evaluate_t proc_evaluate,
|
||||
void *instance,
|
||||
callback_data_t *cd,
|
||||
const lbfgs_parameter_t *param
|
||||
)
|
||||
{
|
||||
@@ -722,7 +757,7 @@ static int line_search_morethuente(
|
||||
}
|
||||
|
||||
/* Evaluate the function and gradient values. */
|
||||
*f = proc_evaluate(instance, x, g, n, *stp);
|
||||
*f = cd->proc_evaluate(cd->instance, x, g, cd->n, *stp);
|
||||
if (0. < param->orthantwise_c) {
|
||||
/* Compute L1-regularization factor and add it to the object value. */
|
||||
norm = 0.;
|
||||
|
||||
Reference in New Issue
Block a user