/********************************************************
* ██████╗ ██████╗████████╗██╗
* ██╔════╝ ██╔════╝╚══██╔══╝██║
* ██║ ███╗██║ ██║ ██║
* ██║ ██║██║ ██║ ██║
* ╚██████╔╝╚██████╗ ██║ ███████╗
* ╚═════╝ ╚═════╝ ╚═╝ ╚══════╝
* Geophysical Computational Tools & Library (GCTL)
*
* Copyright (c) 2022 Yi Zhang (yizhang-geo@zju.edu.cn)
*
* GCTL is distributed under a dual licensing scheme. You can redistribute
* it and/or modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation, either version 2
* of the License, or (at your option) any later version. You should have
* received a copy of the GNU Lesser General Public License along with this
* program. If not, see <http://www.gnu.org/licenses/>.
*
* If the terms and conditions of the LGPL v.2. would prevent you from using
* the GCTL, please consider the option to obtain a commercial license for a
* fee. These licenses are offered by the GCTL's original author. As a rule,
* licenses are provided "as-is", unlimited in time for a one time fee. Please
* send corresponding requests to: yizhang-geo@zju.edu.cn. Please do not forget
* to include some description of your company and the realm of its activities.
* Also add information on how to contact you by electronic and paper mail.
******************************************************/
#include "dwa.h"
/**
 * @brief Construct an empty DWA object.
 *
 * Every scalar member is given a defined value here. Without this,
 * multi_fx_ (accumulated with += in AddSingleLoss()) and fx_n_ (compared
 * against in DWALoss() and used as a loop bound in UpdateWeights()) would be
 * read before anything in this file has written to them.
 * NOTE(review): if the class declaration already provides default member
 * initialisers these assignments are redundant but harmless.
 *
 * InitDWA() must still be called before use to size the internal arrays.
 */
gctl::dwa::dwa()
{
    fx_n_ = 0;        // no loss functions registered yet
    fx_c_ = 0;        // no loss added in the current round
    K_ = 0.0;         // replaced by InitDWA()
    T_ = 1.0;         // same default temperature that InitDWA() assigns
    multi_fx_ = 0.0;  // running weighted loss starts from zero
    l_ready_ = false; // no loss history recorded yet
}
/**
 * @brief Destructor. Performs no explicit work; the members release
 * their own storage.
 */
gctl::dwa::~dwa()
{
}
void gctl::dwa::InitDWA(size_t num, size_t grad_num)
{
fx_n_ = num;
K_ = 1.0*num;
T_ = 1.0;
wgts_.resize(num, 1.0);
L_p1_.resize(num, 1.0);
L_p2_.resize(num, 1.0);
grad_.resize(grad_num, 0.0);
rcd_wgts_.push_back(wgts_);
return;
}
void gctl::dwa::AddSingleLoss(double fx, const array &g)
{
multi_fx_ += wgts_[fx_c_]*fx;
L_p2_[fx_c_] = L_p1_[fx_c_];
L_p1_[fx_c_] = fx;
for (size_t i = 0; i < g.size(); i++)
{
grad_[i] += wgts_[fx_c_]*g[i];
}
fx_c_++;
return;
}
/**
 * @brief Recompute the loss weights and append them to the record list.
 *
 * While l_ready_ is false (no update has been performed yet) all weights are
 * equal. Afterwards each weight is proportional to
 * exp(L_p1_[i]/(L_p2_[i]*T_)), i.e. the exponential of the ratio between the
 * two most recent values of that loss divided by the temperature. In both
 * cases the weights are rescaled so that they sum to K_.
 *
 * The exponentials are evaluated relative to the largest exponent. This is
 * mathematically identical after normalisation but prevents exp() from
 * overflowing when T_ is small; results may differ from the direct formula
 * only by floating-point rounding.
 *
 * NOTE(review): a previous loss of exactly zero or T_ == 0 still produces a
 * non-finite ratio; this function does not attempt to repair that case.
 */
void gctl::dwa::UpdateWeights()
{
    double sum = 0.0;

    if (l_ready_)
    {
        // First pass: store the raw exponents and find the largest one.
        double shift = 0.0;
        for (size_t i = 0; i < fx_n_; i++)
        {
            wgts_[i] = L_p1_[i]/(L_p2_[i]*T_);
            if (i == 0 || wgts_[i] > shift) shift = wgts_[i];
        }

        // Second pass: shifted exponentials, the largest term becomes exp(0) = 1.
        for (size_t i = 0; i < fx_n_; i++)
        {
            wgts_[i] = exp(wgts_[i] - shift);
            sum += wgts_[i];
        }
    }
    else
    {
        for (size_t i = 0; i < fx_n_; i++)
        {
            wgts_[i] = 1.0;
            sum += wgts_[i];
        }
    }

    // Rescale so that the weights add up to K_.
    const double scale = K_/sum;
    for (size_t i = 0; i < fx_n_; i++)
    {
        wgts_[i] *= scale;
    }

    l_ready_ = true;
    rcd_wgts_.push_back(wgts_);
    return;
}
/**
 * @brief Return the weighted total loss of the current round and its
 * gradient, then reset the accumulators for the next round.
 *
 * @param g  Output: receives a copy of the accumulated weighted gradient.
 * @return   Sum of the weighted loss values added since the last call.
 *
 * @throws std::runtime_error if the number of losses added this round
 *         differs from the number declared in InitDWA().
 */
double gctl::dwa::DWALoss(array &g)
{
    if (fx_c_ != fx_n_)
    {
        // The message names this function so the failure can be traced correctly.
        throw std::runtime_error("Not enough loss functions evaluated. From gctl::dwa::DWALoss()");
    }

    const double total = multi_fx_;
    g = grad_;

    // Start the next round from a clean state.
    fx_c_ = 0;
    multi_fx_ = 0.0;
    grad_.assign_all(0.0);

    return total;
}
void gctl::dwa::set_control_temperature(double t)
{
T_ = t;
return;
}
void gctl::dwa::set_normal_sum(double k)
{
K_ = k;
return;
}
/**
 * @brief Copy all recorded weight sets into one flat array.
 *
 * The output is resized to fx_n_ times the number of records and filled
 * record by record: entry r*fx_n_ + c holds weight c of record r.
 *
 * @param logs  Output array that receives the flattened records.
 */
void gctl::dwa::get_records(array &logs)
{
    logs.resize(fx_n_*rcd_wgts_.size());

    size_t pos = 0; // next write position, always equal to r*fx_n_ + c
    for (size_t r = 0; r < rcd_wgts_.size(); r++)
    {
        for (size_t c = 0; c < fx_n_; c++)
        {
            logs[pos] = rcd_wgts_[r][c];
            pos++;
        }
    }
}