This commit is contained in:
张壹 2025-03-03 19:50:40 +08:00
parent b17e7b529e
commit 4d29b9c8fa
6 changed files with 111 additions and 13 deletions

View File

@ -20,4 +20,4 @@ add_example(ex7 OFF)
add_example(ex8 OFF)
add_example(ex9 OFF)
add_example(ex10 OFF)
add_example(cfg_ex ON)
add_example(cfg_ex OFF)

View File

@ -26,6 +26,7 @@
******************************************************/
#include "../lib/optimization/lcg.h"
#include "gctl/graphic/gnuplot.h"
#define M 1000
#define N 800
@ -126,6 +127,13 @@ int main(int argc, char const *argv[])
ofile << "maximal difference: " << max_diff(fm, m) << std::endl;
test.save_convergence("convergence");
gctl::gnuplot gt;
gt.to_buffer();
gt.send("set terminal png size 800,600");
gt.send("set output \"convergence.png\"");
gt.send("plot \"convergence.txt\" using 1:2 with lines");
gt.send("set output");
gt.send_buffer();
m.assign(0.0);

View File

@ -64,33 +64,43 @@ void gctl::common_gradient::set_weights(const _1d_array &w)
return;
}
/**
 * @brief Set the parameter of the exponential weighting function.
 *
 * Stores the value in T_ for later use when the per-loss weights are
 * recomputed. (Exact role of T_ in the weighting formula is defined where
 * the weights are updated — not visible in this translation unit.)
 *
 * @param T Exponential weighting parameter.
 */
void gctl::common_gradient::set_exp_weight(double T)
{
    T_ = T;
    return;
}
/**
 * @brief Initialize the common_gradient object.
 *
 * Allocates every working array for Ln loss functions and Mn model
 * elements, resets the exponential weighting parameter to its default,
 * and clears any previously recorded weight/objective history so the
 * object can be safely re-initialized.
 *
 * @param Ln Number of loss functions.
 * @param Mn Number of model elements.
 */
void gctl::common_gradient::init(size_t Ln, size_t Mn)
{
    Ln_ = Ln;
    Mn_ = Mn;
    T_ = 1.0; // default exponential-weighting parameter (see set_exp_weight)

    // Model-sized storage.
    g_.resize(Mn_);
    B_.resize(Mn_);
    G_.resize(Ln_, Mn_);

    // Per-loss storage; weights and scale factors start at 1, nothing filled yet.
    t_.resize(Ln_);
    gm_.resize(Ln_);
    lx_.resize(Ln_);
    lt_.resize(Ln_, 1.0);
    w_.resize(Ln_, 1.0);
    x_.resize(Ln_, 1.0);
    filled_.resize(Ln_, false);

    // Drop records left over from a previous run; otherwise save_records()
    // would mix the histories of two different setups after a re-init.
    rcd_wgts_.clear();
    rcd_fxs_.clear();
    return;
}
void gctl::common_gradient::fill_model_gradient(size_t id, const _1d_array &g)
void gctl::common_gradient::fill_model_gradient(size_t id, double fx, const _1d_array &g)
{
if (id >= Ln_) throw std::runtime_error("[gctl::common_gradient] Invalid index.");
if (g.size() != Mn_) throw std::runtime_error("[gctl::common_gradient] Invalid array size.");
G_.fill_row(id, g);
filled_[id] = true;
gm_[id] = g.module();
lx_[id] = fx;
filled_[id] = true;
return;
}
const gctl::_1d_array &gctl::common_gradient::get_common_gradient(bool normalized)
const gctl::_1d_array &gctl::common_gradient::get_common_gradient(bool normalized, bool fixed_w)
{
for (size_t i = 0; i < Ln_; i++)
{
@ -98,6 +108,20 @@ const gctl::_1d_array &gctl::common_gradient::get_common_gradient(bool normalize
}
filled_.assign(false);
if (!fixed_w)
{
// 计算权重
double a;
for (size_t i = 0; i < Ln_; i++)
{
a = abs(lx_[i] - lt_[i]);
w_[i] = 1.0/pow(1.0/(1.0 + exp(-0.05*a) + 0.5), 6);
}
}
rcd_wgts_.push_back(w_);
rcd_fxs_.push_back(lx_);
G_.normalize(RowMajor);
matvec(B_, G_, x_, Trans);
@ -120,4 +144,39 @@ const gctl::_1d_array &gctl::common_gradient::get_common_gradient(bool normalize
g_.scale(gmod);
}
return g_;
}
/**
 * @brief Save the recorded objective values and weights to a CSV file.
 *
 * Writes one header row ("Num", l0..l(Ln-1), w0..w(Ln-1)) followed by one
 * row per recorded iteration: the iteration index, the recorded objective
 * values (rcd_fxs_) and the weights used (rcd_wgts_).
 *
 * @param file Output file name (".csv" is passed to open_outfile as the
 *             extension — presumably appended; confirm against open_outfile).
 */
void gctl::common_gradient::save_records(std::string file)
{
    std::ofstream fout;
    open_outfile(fout, file, ".csv");

    // Header row.
    fout << "Num";
    for (size_t j = 0; j < Ln_; j++)
    {
        fout << ",l" << j;
    }
    for (size_t j = 0; j < Ln_; j++)
    {
        fout << ",w" << j;
    }
    // '\n' instead of std::endl: identical bytes, no flush per row.
    fout << '\n';

    // One row per recorded iteration: objectives first, then weights.
    for (size_t i = 0; i < rcd_wgts_.size(); i++)
    {
        fout << i;
        for (size_t j = 0; j < Ln_; j++)
        {
            fout << "," << rcd_fxs_[i][j];
        }
        for (size_t j = 0; j < Ln_; j++)
        {
            fout << "," << rcd_wgts_[i][j];
        }
        fout << '\n';
    }
    fout.close();
    return;
}

View File

@ -61,6 +61,11 @@ namespace gctl
*/
void set_weights(const _1d_array &w);
/**
* @brief Set the weight for the exponential weighting function
*/
void set_exp_weight(double T);
/**
* @brief Initialize the common_gradient object
*
@ -73,24 +78,37 @@ namespace gctl
* @brief Fill the model gradient
*
* @param id Loss function index
* @param fx Objective value
* @param g Model gradient
*/
void fill_model_gradient(size_t id, const _1d_array &g);
void fill_model_gradient(size_t id, double fx, const _1d_array &g);
/**
* @brief Get the conflict free gradient
*
* @param normalized Normalize the output gradient
* @param fixed_w Fixed weights
* @return Calculated model gradient
*/
const _1d_array &get_common_gradient(bool normalized = true);
const _1d_array &get_common_gradient(bool normalized = true, bool fixed_w = true);
/**
* @brief Save the recorded weights.
*
* @param file Output file name
*/
void save_records(std::string file);
private:
size_t Ln_, Mn_; // Ln_: loss_func numberMn_: model number
double T_;
size_t Ln_, Mn_; // lc: loss count, Ln_: loss function number, Mn_: model number
_2d_matrix G_;
_1d_array B_, g_, t_, x_;
_1d_array gm_, w_;
_1d_array lx_, lt_;
array<bool> filled_;
std::vector<array<double> > rcd_wgts_;
std::vector<array<double> > rcd_fxs_;
};
};

View File

@ -115,15 +115,27 @@ void gctl::dwa::set_normal_sum(double k)
return;
}
void gctl::dwa::get_records(array<double> &logs)
void gctl::dwa::save_records(std::string file)
{
logs.resize(fx_n_*rcd_wgts_.size());
std::ofstream fout;
open_outfile(fout, file, ".csv");
fout << "Num";
for (size_t j = 0; j < fx_n_; j++)
{
fout << ",l" << j;
}
fout << std::endl;
for (size_t i = 0; i < rcd_wgts_.size(); i++)
{
fout << i;
for (size_t j = 0; j < fx_n_; j++)
{
logs[i*fx_n_ + j] = rcd_wgts_[i][j];
fout << "," << rcd_wgts_[i][j];
}
fout << std::endl;
}
fout.close();
return;
}

View File

@ -29,6 +29,7 @@
#define _GCTL_DWA_H
#include "gctl/core.h"
#include "gctl/utility.h"
namespace gctl
{
@ -103,11 +104,11 @@ namespace gctl
void set_normal_sum(double k);
/**
* @brief Get the recorded weights. Size of the log equals the function size times iteration times.
* @brief Save the recorded weights.
*
* @param logs Output log
* @param file Output file name
*/
void get_records(array<double> &logs);
void save_records(std::string file);
};
};