From 4d29b9c8fa3c2a49d54b59d91db68b03d3c53ca6 Mon Sep 17 00:00:00 2001 From: Yi Zhang Date: Mon, 3 Mar 2025 19:50:40 +0800 Subject: [PATCH] tmp --- example/CMakeLists.txt | 2 +- example/ex1.cpp | 8 +++++ lib/optimization/cmn_grad.cpp | 65 +++++++++++++++++++++++++++++++++-- lib/optimization/cmn_grad.h | 24 +++++++++++-- lib/optimization/dwa.cpp | 18 ++++++++-- lib/optimization/dwa.h | 7 ++-- 6 files changed, 111 insertions(+), 13 deletions(-) diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index fb39569..f67ce4c 100644 --- a/example/CMakeLists.txt +++ b/example/CMakeLists.txt @@ -20,4 +20,4 @@ add_example(ex7 OFF) add_example(ex8 OFF) add_example(ex9 OFF) add_example(ex10 OFF) -add_example(cfg_ex ON) \ No newline at end of file +add_example(cfg_ex OFF) \ No newline at end of file diff --git a/example/ex1.cpp b/example/ex1.cpp index e8d8eae..1a5dda3 100644 --- a/example/ex1.cpp +++ b/example/ex1.cpp @@ -26,6 +26,7 @@ ******************************************************/ #include "../lib/optimization/lcg.h" +#include "gctl/graphic/gnuplot.h" #define M 1000 #define N 800 @@ -126,6 +127,13 @@ int main(int argc, char const *argv[]) ofile << "maximal difference: " << max_diff(fm, m) << std::endl; test.save_convergence("convergence"); + gctl::gnuplot gt; + gt.to_buffer(); + gt.send("set terminal png size 800,600"); + gt.send("set output \"convergence.png\""); + gt.send("plot \"convergence.txt\" using 1:2 with lines"); + gt.send("set output"); + gt.send_buffer(); m.assign(0.0); diff --git a/lib/optimization/cmn_grad.cpp b/lib/optimization/cmn_grad.cpp index 2d54a6e..7d5798e 100644 --- a/lib/optimization/cmn_grad.cpp +++ b/lib/optimization/cmn_grad.cpp @@ -64,33 +64,43 @@ void gctl::common_gradient::set_weights(const _1d_array &w) return; } +void gctl::common_gradient::set_exp_weight(double T) +{ + T_ = T; + return; +} + void gctl::common_gradient::init(size_t Ln, size_t Mn) { Ln_ = Ln; Mn_ = Mn; + T_ = 1.0; g_.resize(Mn_); B_.resize(Mn_); G_.resize(Ln_, Mn_); t_.resize(Ln_); gm_.resize(Ln_); + lx_.resize(Ln_); + lt_.resize(Ln_, 1.0); w_.resize(Ln_, 1.0); x_.resize(Ln_, 1.0); filled_.resize(Ln_, false); return; } -void gctl::common_gradient::fill_model_gradient(size_t id, const _1d_array &g) +void gctl::common_gradient::fill_model_gradient(size_t id, double fx, const _1d_array &g) { if (id >= Ln_) throw std::runtime_error("[gctl::common_gradient] Invalid index."); if (g.size() != Mn_) throw std::runtime_error("[gctl::common_gradient] Invalid array size."); G_.fill_row(id, g); - filled_[id] = true; gm_[id] = g.module(); + lx_[id] = fx; + filled_[id] = true; return; } -const gctl::_1d_array &gctl::common_gradient::get_common_gradient(bool normalized) +const gctl::_1d_array &gctl::common_gradient::get_common_gradient(bool normalized, bool fixed_w) { for (size_t i = 0; i < Ln_; i++) { @@ -98,6 +108,20 @@ const gctl::_1d_array &gctl::common_gradient::get_common_gradient(bool normalize } filled_.assign(false); + if (!fixed_w) + { + // 计算权重 + double a; + for (size_t i = 0; i < Ln_; i++) + { + a = abs(lx_[i] - lt_[i]); + w_[i] = 1.0/pow(1.0/(1.0 + exp(-0.05*a) + 0.5), 6); + } + } + + rcd_wgts_.push_back(w_); + rcd_fxs_.push_back(lx_); + G_.normalize(RowMajor); matvec(B_, G_, x_, Trans); @@ -120,4 +144,39 @@ const gctl::_1d_array &gctl::common_gradient::get_common_gradient(bool normalize g_.scale(gmod); } return g_; +} + +void gctl::common_gradient::save_records(std::string file) +{ + std::ofstream fout; + open_outfile(fout, file, ".csv"); + + fout << "Num"; + for (size_t j = 0; j < Ln_; j++) + { + fout << ",l" << j; + } + + for (size_t j = 0; j < Ln_; j++) + { + fout << ",w" << j; + } + fout << std::endl; + + for (size_t i = 0; i < rcd_wgts_.size(); i++) + { + fout << i; + for (size_t j = 0; j < Ln_; j++) + { + fout << "," << rcd_fxs_[i][j]; + } + + for (size_t j = 0; j < Ln_; j++) + { + fout << "," << rcd_wgts_[i][j]; + } + fout << std::endl; + } + fout.close(); + return; } \ No newline at end of file diff --git a/lib/optimization/cmn_grad.h b/lib/optimization/cmn_grad.h index 568b6b6..0395afe 100644 --- a/lib/optimization/cmn_grad.h +++ b/lib/optimization/cmn_grad.h @@ -61,6 +61,11 @@ namespace gctl */ void set_weights(const _1d_array &w); + /** + * @brief Set the weight for the exponential weighting function + */ + void set_exp_weight(double T); + /** * @brief Initialize the common_gradient object * @@ -73,24 +78,37 @@ namespace gctl * @brief Fill the model gradient * * @param id Loss function index + * @param fx Objective value * @param g Model gradient */ - void fill_model_gradient(size_t id, const _1d_array &g); + void fill_model_gradient(size_t id, double fx, const _1d_array &g); /** * @brief Get the conflict free gradient * * @param normalized Normalize the output gradient + * @param fixed_w Fixed weights * @return Calculated model gradient */ - const _1d_array &get_common_gradient(bool normalized = true); + const _1d_array &get_common_gradient(bool normalized = true, bool fixed_w = true); + + /** + * @brief Save the recorded weights. + * + * @param file Output file name + */ + void save_records(std::string file); private: - size_t Ln_, Mn_; // Ln_: loss_func number,Mn_: model number + double T_; + size_t Ln_, Mn_; // lc: loss count, Ln_: loss_func number,Mn_: model number _2d_matrix G_; _1d_array B_, g_, t_, x_; _1d_array gm_, w_; + _1d_array lx_, lt_; array filled_; + std::vector > rcd_wgts_; + std::vector > rcd_fxs_; }; }; diff --git a/lib/optimization/dwa.cpp b/lib/optimization/dwa.cpp index 8a723c1..0cd50b0 100644 --- a/lib/optimization/dwa.cpp +++ b/lib/optimization/dwa.cpp @@ -115,15 +115,27 @@ void gctl::dwa::set_normal_sum(double k) return; } -void gctl::dwa::get_records(array &logs) +void gctl::dwa::save_records(std::string file) { - logs.resize(fx_n_*rcd_wgts_.size()); + std::ofstream fout; + open_outfile(fout, file, ".csv"); + + fout << "Num"; + for (size_t j = 0; j < fx_n_; j++) + { + fout << ",l" << j; + } + fout << std::endl; + for (size_t i = 0; i < rcd_wgts_.size(); i++) { + fout << i; for (size_t j = 0; j < fx_n_; j++) { - logs[i*fx_n_ + j] = rcd_wgts_[i][j]; + fout << "," << rcd_wgts_[i][j]; } + fout << std::endl; } + fout.close(); return; } \ No newline at end of file diff --git a/lib/optimization/dwa.h b/lib/optimization/dwa.h index db73693..b3adb36 100644 --- a/lib/optimization/dwa.h +++ b/lib/optimization/dwa.h @@ -29,6 +29,7 @@ #define _GCTL_DWA_H #include "gctl/core.h" +#include "gctl/utility.h" namespace gctl { @@ -103,11 +104,11 @@ namespace gctl void set_normal_sum(double k); /** - * @brief Get the recorded weights. Size of the log equals the function size times iteration times. + * @brief Save the recorded weights. * - * @param logs Output log + * @param file Output file name */ - void get_records(array &logs); + void save_records(std::string file); }; };