/******************************************************** * ██████╗ ██████╗████████╗██╗ * ██╔════╝ ██╔════╝╚══██╔══╝██║ * ██║ ███╗██║ ██║ ██║ * ██║ ██║██║ ██║ ██║ * ╚██████╔╝╚██████╗ ██║ ███████╗ * ╚═════╝ ╚═════╝ ╚═╝ ╚══════╝ * Geophysical Computational Tools & Library (GCTL) * * Copyright (c) 2022 Yi Zhang (yizhang-geo@zju.edu.cn) * * GCTL is distributed under a dual licensing scheme. You can redistribute * it and/or modify it under the terms of the GNU Lesser General Public * License as published by the Free Software Foundation, either version 2 * of the License, or (at your option) any later version. You should have * received a copy of the GNU Lesser General Public License along with this * program. If not, see <http://www.gnu.org/licenses/>. * * If the terms and conditions of the LGPL v.2. would prevent you from using * the GCTL, please consider the option to obtain a commercial license for a * fee. These licenses are offered by the GCTL's original author. As a rule, * licenses are provided "as-is", unlimited in time for a one time fee. Please * send corresponding requests to: yizhang-geo@zju.edu.cn. Please do not forget * to include some description of your company and the realm of its activities. * Also add information on how to contact you by electronic and paper mail. 
******************************************************/ #include "gradnorm.h" gctl::grad_norm::grad_norm() { fx_c_ = 0; alpha_ = 1.0; lamda_ = 0.001; initialized_ = false; } gctl::grad_norm::~grad_norm(){} void gctl::grad_norm::InitGradNorm(size_t num, size_t grad_num) { fx_n_ = num; T_ = 1.0; resi_T_ = 0.0; fst_iter_.resize(num, true); wgts_.resize(num, 1.0/num); fx0_.resize(num, 0.0); Gw_.resize(num, 0.0); Gdw_.resize(num, 0.0); Lx_.resize(num, 0.0); grad_.resize(grad_num, 0.0); rcd_fxs_.resize(num, 0.0); fixed_wgts_.resize(num, -1.0); rcd_wgts_.reserve(100000); for (size_t i = 0; i < fx_n_; i++) { rcd_wgts_.push_back(wgts_[i]); } initialized_ = true; return; } double gctl::grad_norm::AddSingleLoss(double fx, const array &g) { if (fst_iter_[fx_c_]) { fx0_[fx_c_] = fx; fst_iter_[fx_c_] = false; } Lx_[fx_c_] = fx/fx0_[fx_c_]; double curr_fx = wgts_[fx_c_]*fx; multi_fx_ += curr_fx; rcd_fxs_[fx_c_] = fx; double sum = 0.0; for (size_t i = 0; i < g.size(); i++) { sum += g[i]*g[i]; grad_[i] += wgts_[fx_c_]*g[i]; } Gw_[fx_c_] = sqrt(wgts_[fx_c_]*wgts_[fx_c_]*sum); Gdw_[fx_c_] = sqrt(sum); // wgts_[fx_c_]*sum/Gw_[fx_c_] fx_c_++; return curr_fx; } void gctl::grad_norm::UpdateWeights() { double ac = 0; double avg_Lx = 0.0, avg_Gw = 0.0; resi_T_ = T_; for (size_t i = 0; i < fx_n_; i++) { if (fixed_wgts_[i] < 0.0) { avg_Lx += Lx_[i]; avg_Gw += Gw_[i]; ac += 1.0; } else resi_T_ -= fixed_wgts_[i]; } avg_Lx /= ac; avg_Gw /= ac; double r_i, sum = 0.0; // L1 norm approach for (size_t i = 0; i < fx_n_; i++) { if (fixed_wgts_[i] < 0.0) { r_i = Lx_[i]/avg_Lx; if (Gw_[i] >= avg_Gw*pow(r_i, alpha_)) { wgts_[i] -= lamda_*Gdw_[i]; } else wgts_[i] += lamda_*Gdw_[i]; // make sure the weights are positive wgts_[i] = std::max(wgts_[i], 1e-16); sum += wgts_[i]; } } for (size_t i = 0; i < fx_n_; i++) { if (fixed_wgts_[i] < 0.0) wgts_[i] *= resi_T_/sum; rcd_wgts_.push_back(wgts_[i]); } return; } void gctl::grad_norm::ShowStatistics(std::ostream &ss, bool one_line) { double s, t = 0.0; if 
(one_line) { ss << "Wgts:"; for (size_t i = 0; i < fx_n_; i++) { ss << " " << wgts_[i]; } ss << ", Loss:"; for (size_t i = 0; i < fx_n_; i++) { ss << " " << rcd_fxs_[i]; } ss << ", WgtLoss:"; for (size_t i = 0; i < fx_n_; i++) { s = wgts_[i]*rcd_fxs_[i]; ss << " " << s; t += s; } ss << ", Total: " << t << "\n"; return; } ss << "----------------------------\n"; ss << "GradNorm's Progress\n"; ss << "Tasks' weight: "; for (size_t i = 0; i < fx_n_; i++) { ss << wgts_[i] << " | "; } ss << "\n"; ss << "Tasks' loss: "; for (size_t i = 0; i < fx_n_; i++) { ss << rcd_fxs_[i] << " | "; } ss << "\n"; ss << "Weighted losses: "; for (size_t i = 0; i < fx_n_; i++) { s = wgts_[i]*rcd_fxs_[i]; ss << s << " | "; t += s; } ss << t << " (total) |\n"; ss << "----------------------------\n"; return; } double gctl::grad_norm::GradNormLoss(array &g) { if (fx_c_ != fx_n_) { throw std::runtime_error("Not all loss functions evaluated. From gctl::grad_norm::GradNormLoss()"); } if (!initialized_) { throw std::runtime_error("GradNorm is not initialized. From gctl::grad_norm::GradNormLoss()"); } double fx = multi_fx_; g = grad_; fx_c_ = 0; multi_fx_ = 0.0; grad_.assign(0.0); return fx; } void gctl::grad_norm::set_control_weight(double a) { alpha_ = a; return; } void gctl::grad_norm::set_normal_sum(double t) { T_ = t; return; } void gctl::grad_norm::set_weight_step(double l) { lamda_ = l; return; } void gctl::grad_norm::set_fixed_weight(int id, double wgt) { if (id < 0 || id >= fx_n_) { throw std::runtime_error("Invalid loss function's index. From gctl::grad_norm::set_fixed_weight(...)"); } if (wgt <= 0.0 || wgt >= T_) { throw std::runtime_error("Invalid fixed weight value. From gctl::grad_norm::set_fixed_weight(...)"); } fixed_wgts_[id] = wgt; wgts_[id] = wgt; resi_T_ = T_; double ac = 0.0; for (size_t i = 0; i < fx_n_; i++) { if (fixed_wgts_[i] > 0.0) resi_T_ -= fixed_wgts_[i]; else ac += 1.0; } if (resi_T_ <= 0.0) { throw std::runtime_error("Invalid tasks' weight detected. 
From gctl::grad_norm::UpdateWeights()"); } for (size_t i = 0; i < fx_n_; i++) { if (fixed_wgts_[i] < 0.0) wgts_[i] = resi_T_/ac; } for (size_t i = 0; i < fx_n_; i++) { rcd_wgts_[i] = wgts_[i]; } return; } void gctl::grad_norm::set_initial_weights(const array &w) { if (w.size() != fx_n_) { throw std::runtime_error("Invalid input array size. From gctl::grad_norm::set_initial_weights(...)"); } double sum = 0.0; for (size_t i = 0; i < fx_n_; i++) { wgts_[i] = w[i]; sum += wgts_[i]; } for (size_t i = 0; i < fx_n_; i++) { wgts_[i] *= T_/sum; rcd_wgts_[i] = wgts_[i]; } return; } void gctl::grad_norm::get_records(array &logs) { logs.resize(rcd_wgts_.size()); for (size_t i = 0; i < rcd_wgts_.size(); i++) { logs[i] = rcd_wgts_[i]; } return; } void gctl::grad_norm::save_records(std::string file) { std::ofstream ofile; open_outfile(ofile, file, ".txt"); ofile << "# 'tw' for 'task weight'\n# "; for (size_t j = 0; j < fx_n_; j++) { ofile << "tw" << std::to_string(j) << " "; } ofile << "\n"; for (int i = 0; i < rcd_wgts_.size(); i++) { ofile << rcd_wgts_[i] << " "; if ((i+1)%fx_n_ == 0) ofile << "\n"; } ofile.close(); }