This commit is contained in:
张壹 2025-03-03 19:50:40 +08:00
parent b17e7b529e
commit 4d29b9c8fa
6 changed files with 111 additions and 13 deletions

View File

@ -20,4 +20,4 @@ add_example(ex7 OFF)
add_example(ex8 OFF) add_example(ex8 OFF)
add_example(ex9 OFF) add_example(ex9 OFF)
add_example(ex10 OFF) add_example(ex10 OFF)
add_example(cfg_ex ON) add_example(cfg_ex OFF)

View File

@ -26,6 +26,7 @@
******************************************************/ ******************************************************/
#include "../lib/optimization/lcg.h" #include "../lib/optimization/lcg.h"
#include "gctl/graphic/gnuplot.h"
#define M 1000 #define M 1000
#define N 800 #define N 800
@ -126,6 +127,13 @@ int main(int argc, char const *argv[])
ofile << "maximal difference: " << max_diff(fm, m) << std::endl; ofile << "maximal difference: " << max_diff(fm, m) << std::endl;
test.save_convergence("convergence"); test.save_convergence("convergence");
gctl::gnuplot gt;
gt.to_buffer();
gt.send("set terminal png size 800,600");
gt.send("set output \"convergence.png\"");
gt.send("plot \"convergence.txt\" using 1:2 with lines");
gt.send("set output");
gt.send_buffer();
m.assign(0.0); m.assign(0.0);

View File

@ -64,33 +64,43 @@ void gctl::common_gradient::set_weights(const _1d_array &w)
return; return;
} }
void gctl::common_gradient::set_exp_weight(double T)
{
T_ = T;
return;
}
void gctl::common_gradient::init(size_t Ln, size_t Mn) void gctl::common_gradient::init(size_t Ln, size_t Mn)
{ {
Ln_ = Ln; Ln_ = Ln;
Mn_ = Mn; Mn_ = Mn;
T_ = 1.0;
g_.resize(Mn_); g_.resize(Mn_);
B_.resize(Mn_); B_.resize(Mn_);
G_.resize(Ln_, Mn_); G_.resize(Ln_, Mn_);
t_.resize(Ln_); t_.resize(Ln_);
gm_.resize(Ln_); gm_.resize(Ln_);
lx_.resize(Ln_);
lt_.resize(Ln_, 1.0);
w_.resize(Ln_, 1.0); w_.resize(Ln_, 1.0);
x_.resize(Ln_, 1.0); x_.resize(Ln_, 1.0);
filled_.resize(Ln_, false); filled_.resize(Ln_, false);
return; return;
} }
void gctl::common_gradient::fill_model_gradient(size_t id, const _1d_array &g) void gctl::common_gradient::fill_model_gradient(size_t id, double fx, const _1d_array &g)
{ {
if (id >= Ln_) throw std::runtime_error("[gctl::common_gradient] Invalid index."); if (id >= Ln_) throw std::runtime_error("[gctl::common_gradient] Invalid index.");
if (g.size() != Mn_) throw std::runtime_error("[gctl::common_gradient] Invalid array size."); if (g.size() != Mn_) throw std::runtime_error("[gctl::common_gradient] Invalid array size.");
G_.fill_row(id, g); G_.fill_row(id, g);
filled_[id] = true;
gm_[id] = g.module(); gm_[id] = g.module();
lx_[id] = fx;
filled_[id] = true;
return; return;
} }
const gctl::_1d_array &gctl::common_gradient::get_common_gradient(bool normalized) const gctl::_1d_array &gctl::common_gradient::get_common_gradient(bool normalized, bool fixed_w)
{ {
for (size_t i = 0; i < Ln_; i++) for (size_t i = 0; i < Ln_; i++)
{ {
@ -98,6 +108,20 @@ const gctl::_1d_array &gctl::common_gradient::get_common_gradient(bool normalize
} }
filled_.assign(false); filled_.assign(false);
if (!fixed_w)
{
// 计算权重
double a;
for (size_t i = 0; i < Ln_; i++)
{
a = abs(lx_[i] - lt_[i]);
w_[i] = 1.0/pow(1.0/(1.0 + exp(-0.05*a) + 0.5), 6);
}
}
rcd_wgts_.push_back(w_);
rcd_fxs_.push_back(lx_);
G_.normalize(RowMajor); G_.normalize(RowMajor);
matvec(B_, G_, x_, Trans); matvec(B_, G_, x_, Trans);
@ -121,3 +145,38 @@ const gctl::_1d_array &gctl::common_gradient::get_common_gradient(bool normalize
} }
return g_; return g_;
} }
void gctl::common_gradient::save_records(std::string file)
{
std::ofstream fout;
open_outfile(fout, file, ".csv");
fout << "Num";
for (size_t j = 0; j < Ln_; j++)
{
fout << ",l" << j;
}
for (size_t j = 0; j < Ln_; j++)
{
fout << ",w" << j;
}
fout << std::endl;
for (size_t i = 0; i < rcd_wgts_.size(); i++)
{
fout << i;
for (size_t j = 0; j < Ln_; j++)
{
fout << "," << rcd_fxs_[i][j];
}
for (size_t j = 0; j < Ln_; j++)
{
fout << "," << rcd_wgts_[i][j];
}
fout << std::endl;
}
fout.close();
return;
}

View File

@ -61,6 +61,11 @@ namespace gctl
*/ */
void set_weights(const _1d_array &w); void set_weights(const _1d_array &w);
/**
* @brief Set the weight for the exponential weighting function
*/
void set_exp_weight(double T);
/** /**
* @brief Initialize the common_gradient object * @brief Initialize the common_gradient object
* *
@ -73,24 +78,37 @@ namespace gctl
* @brief Fill the model gradient * @brief Fill the model gradient
* *
* @param id Loss function index * @param id Loss function index
* @param fx Objective value
* @param g Model gradient * @param g Model gradient
*/ */
void fill_model_gradient(size_t id, const _1d_array &g); void fill_model_gradient(size_t id, double fx, const _1d_array &g);
/** /**
* @brief Get the conflict free gradient * @brief Get the conflict free gradient
* *
* @param normalized Normalize the output gradient * @param normalized Normalize the output gradient
* @param fixed_w Fixed weights
* @return Calculated model gradient * @return Calculated model gradient
*/ */
const _1d_array &get_common_gradient(bool normalized = true); const _1d_array &get_common_gradient(bool normalized = true, bool fixed_w = true);
/**
* @brief Save the recorded weights.
*
* @param file Output file name
*/
void save_records(std::string file);
private: private:
size_t Ln_, Mn_; // Ln_: loss_func numberMn_: model number double T_;
size_t Ln_, Mn_; // lc: loss count, Ln_: loss_func numberMn_: model number
_2d_matrix G_; _2d_matrix G_;
_1d_array B_, g_, t_, x_; _1d_array B_, g_, t_, x_;
_1d_array gm_, w_; _1d_array gm_, w_;
_1d_array lx_, lt_;
array<bool> filled_; array<bool> filled_;
std::vector<array<double> > rcd_wgts_;
std::vector<array<double> > rcd_fxs_;
}; };
}; };

View File

@ -115,15 +115,27 @@ void gctl::dwa::set_normal_sum(double k)
return; return;
} }
void gctl::dwa::get_records(array<double> &logs) void gctl::dwa::save_records(std::string file)
{ {
logs.resize(fx_n_*rcd_wgts_.size()); std::ofstream fout;
open_outfile(fout, file, ".csv");
fout << "Num";
for (size_t j = 0; j < fx_n_; j++)
{
fout << ",l" << j;
}
fout << std::endl;
for (size_t i = 0; i < rcd_wgts_.size(); i++) for (size_t i = 0; i < rcd_wgts_.size(); i++)
{ {
fout << i;
for (size_t j = 0; j < fx_n_; j++) for (size_t j = 0; j < fx_n_; j++)
{ {
logs[i*fx_n_ + j] = rcd_wgts_[i][j]; fout << "," << rcd_wgts_[i][j];
} }
fout << std::endl;
} }
fout.close();
return; return;
} }

View File

@ -29,6 +29,7 @@
#define _GCTL_DWA_H #define _GCTL_DWA_H
#include "gctl/core.h" #include "gctl/core.h"
#include "gctl/utility.h"
namespace gctl namespace gctl
{ {
@ -103,11 +104,11 @@ namespace gctl
void set_normal_sum(double k); void set_normal_sum(double k);
/** /**
* @brief Get the recorded weights. Size of the log equals the function size times iteration times. * @brief Save the recorded weights.
* *
* @param logs Output log * @param file Output file name
*/ */
void get_records(array<double> &logs); void save_records(std::string file);
}; };
}; };