129 lines
3.4 KiB
C++
129 lines
3.4 KiB
C++
|
/********************************************************
 *  ██████╗  ██████╗████████╗██╗
 * ██╔════╝ ██╔════╝╚══██╔══╝██║
 * ██║  ███╗██║        ██║   ██║
 * ██║   ██║██║        ██║   ██║
 * ╚██████╔╝╚██████╗   ██║   ███████╗
 *  ╚═════╝  ╚═════╝   ╚═╝   ╚══════╝
 * Geophysical Computational Tools & Library (GCTL)
 *
 * Copyright (c) 2022  Yi Zhang (yizhang-geo@zju.edu.cn)
 *
 * GCTL is distributed under a dual licensing scheme. You can redistribute
 * it and/or modify it under the terms of the GNU Lesser General Public
 * License as published by the Free Software Foundation, either version 2
 * of the License, or (at your option) any later version. You should have
 * received a copy of the GNU Lesser General Public License along with this
 * program. If not, see <http://www.gnu.org/licenses/>.
 *
 * If the terms and conditions of the LGPL v.2. would prevent you from using
 * the GCTL, please consider the option to obtain a commercial license for a
 * fee. These licenses are offered by the GCTL's original author. As a rule,
 * licenses are provided "as-is", unlimited in time for a one time fee. Please
 * send corresponding requests to: yizhang-geo@zju.edu.cn. Please do not forget
 * to include some description of your company and the realm of its activities.
 * Also add information on how to contact you by electronic and paper mail.
 ******************************************************/
|
||
|
|
||
|
#include "dwa.h"
|
||
|
|
||
|
/**
 * @brief Default constructor.
 *
 * Puts every scalar member used by this class into a defined state so that
 * the object is safe to query before InitDWA() has been called. In
 * particular multi_fx_ is accumulated with '+=' in AddSingleLoss() and must
 * therefore start from zero.
 */
gctl::dwa::dwa()
{
    fx_c_ = 0;        // no loss has been added in the current round
    fx_n_ = 0;        // no loss function is registered until InitDWA()
    multi_fx_ = 0.0;  // weighted loss accumulator
    K_ = 0.0;         // overwritten by InitDWA() / set_normal_sum()
    T_ = 1.0;         // same default temperature as InitDWA()
    l_ready_ = false; // first UpdateWeights() call yields uniform weights
}
|
||
|
|
||
|
/**
 * @brief Destructor. No explicit clean-up is performed in its body.
 */
gctl::dwa::~dwa()
{
}
|
||
|
|
||
|
void gctl::dwa::InitDWA(size_t num, size_t grad_num)
|
||
|
{
|
||
|
fx_n_ = num;
|
||
|
K_ = 1.0*num;
|
||
|
T_ = 1.0;
|
||
|
|
||
|
wgts_.resize(num, 1.0);
|
||
|
L_p1_.resize(num, 1.0);
|
||
|
L_p2_.resize(num, 1.0);
|
||
|
grad_.resize(grad_num, 0.0);
|
||
|
|
||
|
rcd_wgts_.push_back(wgts_);
|
||
|
return;
|
||
|
}
|
||
|
|
||
|
void gctl::dwa::AddSingleLoss(double fx, const array<double> &g)
|
||
|
{
|
||
|
multi_fx_ += wgts_[fx_c_]*fx;
|
||
|
|
||
|
L_p2_[fx_c_] = L_p1_[fx_c_];
|
||
|
L_p1_[fx_c_] = fx;
|
||
|
|
||
|
for (size_t i = 0; i < g.size(); i++)
|
||
|
{
|
||
|
grad_[i] += wgts_[fx_c_]*g[i];
|
||
|
}
|
||
|
|
||
|
fx_c_++;
|
||
|
return;
|
||
|
}
|
||
|
|
||
|
void gctl::dwa::UpdateWeights()
|
||
|
{
|
||
|
double sum = 0.0;
|
||
|
for (size_t i = 0; i < fx_n_; i++)
|
||
|
{
|
||
|
if (l_ready_) wgts_[i] = exp(L_p1_[i]/(L_p2_[i]*T_));
|
||
|
else wgts_[i] = 1.0;
|
||
|
|
||
|
sum += wgts_[i];
|
||
|
}
|
||
|
|
||
|
for (size_t i = 0; i < fx_n_; i++)
|
||
|
{
|
||
|
wgts_[i] *= K_/sum;
|
||
|
}
|
||
|
|
||
|
l_ready_ = true;
|
||
|
rcd_wgts_.push_back(wgts_);
|
||
|
return;
|
||
|
}
|
||
|
|
||
|
/**
 * @brief Return the combined loss of the current round and start a new round.
 *
 * Must be called after exactly fx_n_ calls to AddSingleLoss(). The
 * accumulated gradient is assigned to g, then the slot counter, the combined
 * loss and the accumulated gradient are reset.
 *
 * @param g  Receives the accumulated (weighted) gradient.
 * @return   The accumulated (weighted) loss value.
 *
 * @throws std::runtime_error if the number of losses added in this round
 *         differs from the number declared in InitDWA(). The accumulated
 *         state is left untouched in that case.
 */
double gctl::dwa::DWALoss(array<double> &g)
{
    if (fx_c_ != fx_n_)
    {
        // The message names this function so the failure can be traced.
        throw std::runtime_error("Not enough loss functions evaluated. From gctl::dwa::DWALoss()");
    }

    const double total_fx = multi_fx_;
    g = grad_;

    // Reset the accumulators for the next round.
    fx_c_ = 0;
    multi_fx_ = 0.0;
    grad_.assign_all(0.0);
    return total_fx;
}
|
||
|
|
||
|
void gctl::dwa::set_control_temperature(double t)
|
||
|
{
|
||
|
T_ = t;
|
||
|
return;
|
||
|
}
|
||
|
|
||
|
void gctl::dwa::set_normal_sum(double k)
|
||
|
{
|
||
|
K_ = k;
|
||
|
return;
|
||
|
}
|
||
|
|
||
|
/**
 * @brief Copy the recorded weight history into a flat array.
 *
 * The output is resized to fx_n_ times the number of records and filled
 * record by record, so entry (r, c) of the history ends up at index
 * r*fx_n_ + c.
 *
 * @param logs  Receives the flattened weight history.
 */
void gctl::dwa::get_records(array<double> &logs)
{
    logs.resize(fx_n_*rcd_wgts_.size());

    size_t out_id = 0; // running position in the flat output
    for (size_t r = 0; r < rcd_wgts_.size(); r++)
    {
        for (size_t c = 0; c < fx_n_; c++)
        {
            logs[out_id] = rcd_wgts_[r][c];
            out_id++;
        }
    }
}
|