/******************************************************** * ██████╗ ██████╗████████╗██╗ * ██╔════╝ ██╔════╝╚══██╔══╝██║ * ██║ ███╗██║ ██║ ██║ * ██║ ██║██║ ██║ ██║ * ╚██████╔╝╚██████╗ ██║ ███████╗ * ╚═════╝ ╚═════╝ ╚═╝ ╚══════╝ * Geophysical Computational Tools & Library (GCTL) * * Copyright (c) 2022 Yi Zhang (yizhang-geo@zju.edu.cn) * * GCTL is distributed under a dual licensing scheme. You can redistribute * it and/or modify it under the terms of the GNU Lesser General Public * License as published by the Free Software Foundation, either version 2 * of the License, or (at your option) any later version. You should have * received a copy of the GNU Lesser General Public License along with this * program. If not, see . * * If the terms and conditions of the LGPL v.2. would prevent you from using * the GCTL, please consider the option to obtain a commercial license for a * fee. These licenses are offered by the GCTL's original author. As a rule, * licenses are provided "as-is", unlimited in time for a one time fee. Please * send corresponding requests to: yizhang-geo@zju.edu.cn. Please do not forget * to include some description of your company and the realm of its activities. * Also add information on how to contact you by electronic and paper mail. ******************************************************/ #include "dwa.h" gctl::dwa::dwa() { fx_c_ = 0; l_ready_ = false; } gctl::dwa::~dwa(){} void gctl::dwa::InitDWA(size_t num, size_t grad_num) { fx_n_ = num; K_ = 1.0*num; T_ = 1.0; wgts_.resize(num, 1.0); L_p1_.resize(num, 1.0); L_p2_.resize(num, 1.0); grad_.resize(grad_num, 0.0); rcd_wgts_.push_back(wgts_); return; } void gctl::dwa::AddSingleLoss(double fx, const array &g) { multi_fx_ += wgts_[fx_c_]*fx; L_p2_[fx_c_] = L_p1_[fx_c_]; L_p1_[fx_c_] = fx; for (size_t i = 0; i < g.size(); i++) { grad_[i] += wgts_[fx_c_]*g[i]; } fx_c_++; return; } void gctl::dwa::UpdateWeights() { double sum = 0.0; for (size_t i = 0; i < fx_n_; i++) { if (l_ready_) wgts_[i] = exp(L_p1_[i]/(L_p2_[i]*T_)); else wgts_[i] = 1.0; sum += wgts_[i]; } for (size_t i = 0; i < fx_n_; i++) { wgts_[i] *= K_/sum; } l_ready_ = true; rcd_wgts_.push_back(wgts_); return; } double gctl::dwa::DWALoss(array &g) { if (fx_c_ != fx_n_) { throw std::runtime_error("Not enough loss functions evaluated. From gctl::dwa::UpdateWeights()"); } double fx = multi_fx_; g = grad_; fx_c_ = 0; multi_fx_ = 0.0; grad_.assign_all(0.0); return fx; } void gctl::dwa::set_control_temperature(double t) { T_ = t; return; } void gctl::dwa::set_normal_sum(double k) { K_ = k; return; } void gctl::dwa::get_records(array &logs) { logs.resize(fx_n_*rcd_wgts_.size()); for (size_t i = 0; i < rcd_wgts_.size(); i++) { for (size_t j = 0; j < fx_n_; j++) { logs[i*fx_n_ + j] = rcd_wgts_[i][j]; } } return; }