/********************************************************
 *  ██████╗  ██████╗████████╗██╗
 * ██╔════╝ ██╔════╝╚══██╔══╝██║
 * ██║  ███╗██║        ██║   ██║
 * ██║   ██║██║        ██║   ██║
 * ╚██████╔╝╚██████╗   ██║   ███████╗
 *  ╚═════╝  ╚═════╝   ╚═╝   ╚══════╝
 * Geophysical Computational Tools & Library (GCTL)
 *
 * Copyright (c) 2022 Yi Zhang (yizhang-geo@zju.edu.cn)
 *
 * GCTL is distributed under a dual licensing scheme. You can redistribute
 * it and/or modify it under the terms of the GNU Lesser General Public
 * License as published by the Free Software Foundation, either version 2
 * of the License, or (at your option) any later version. You should have
 * received a copy of the GNU Lesser General Public License along with this
 * program. If not, see <https://www.gnu.org/licenses/>.
 *
 * If the terms and conditions of the LGPL v.2. would prevent you from using
 * the GCTL, please consider the option to obtain a commercial license for a
 * fee. These licenses are offered by the GCTL's original author. As a rule,
 * licenses are provided "as-is", unlimited in time for a one-time fee. Please
 * send corresponding requests to: yizhang-geo@zju.edu.cn. Please do not forget
 * to include a description of your company and the realm of its activities.
 * Also add information on how to contact you by electronic and paper mail.
 ********************************************************/

#ifndef _GCTL_GRADNORM_H
#define _GCTL_GRADNORM_H

#include <vector>

#include "gctl/core.h"
#include "gctl/io.h"

namespace gctl
{
    /**
     * @brief Gradient-normalized (balanced) multitask evaluation.
     *
     * @note Reference: Zhao Chen et al., 2018. GradNorm: Gradient normalization
     * for adaptive loss balancing in deep multitask networks.
     */
    class grad_norm
    {
    private:
        bool initialized_;
        size_t fx_n_, fx_c_;              // number of loss functions; count of losses added so far
        double resi_T_, T_;               // residual and target sum of the weights
        double lamda_, alpha_, multi_fx_; // weight learning rate, control factor, merged objective
        array<bool> fst_iter_;            // per-task first-iteration flags
        array<double> wgts_;              // task weights
        array<double> fx0_;               // initial loss values
        array<double> Gw_, Gdw_, Lx_;     // per-task gradient statistics used by the weight update
        array<double> grad_;              // merged model gradient
        array<double> rcd_fxs_;           // recorded loss values
        array<double> fixed_wgts_;        // user-fixed task weights
        std::vector<double> rcd_wgts_;    // weight records over iterations

    public:
        grad_norm();
        virtual ~grad_norm();

        /**
         * @brief Set the number of loss functions and the size of the model gradient.
         *
         * @note This function must be called first.
         *
         * @param num      Number of loss functions
         * @param grad_num Size of the model gradient
         */
        void InitGradNorm(size_t num, size_t grad_num);

        /**
         * @brief Add the value of a single loss function and the current model gradient.
         *
         * @param fx Objective value
         * @param g  Model gradient
         *
         * @return Weighted value of the current loss function
         */
        double AddSingleLoss(double fx, const array<double> &g);

        /**
         * @brief Get the merged objective value and model gradient.
         *
         * @note All single loss functions must be added before calling this function. The
         * merged objective value and model gradient are reset after the call.
         *
         * @param g Model gradient
         *
         * @return Objective value
         */
        double GradNormLoss(array<double> &g);

        /**
         * @brief Update the weights of the single loss functions using the GradNorm algorithm.
         */
        void UpdateWeights();

        /**
         * @brief Show statistics of the task weights and loss function values.
         *
         * @param ss       Output stream
         * @param one_line Print the statistics on a single line
         */
        void ShowStatistics(std::ostream &ss = std::clog, bool one_line = false);

        /**
         * @brief Set the control factor alpha. The default is 1.0.
         *
         * @param a Input alpha
         */
        void set_control_weight(double a);

        /**
         * @brief Set the target sum of the weights. The default equals the number of loss functions.
         *
         * @param t Input sum
         */
        void set_normal_sum(double t);

        /**
         * @brief Set the learning rate of the weights. The default is 0.001.
         *
         * @param l Input learning rate
         */
        void set_weight_step(double l);

        /**
         * @brief Fix the weight of a single loss function.
         *
         * @param id  Index of the loss function
         * @param wgt Weight of the loss function
         */
        void set_fixed_weight(int id, double wgt);

        /**
         * @brief Set the initial weights.
         *
         * @param w Input weights
         */
        void set_initial_weights(const array<double> &w);

        /**
         * @brief Get the recorded weights. The size of the log equals the number of
         * loss functions times the number of iterations.
         *
         * @param logs Output log
         */
        void get_records(array<double> &logs);

        /**
         * @brief Save the recorded weights to a file.
         *
         * @param file File name
         */
        void save_records(std::string file);
    };
}

#endif // _GCTL_GRADNORM_H
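
/**
 * A minimal usage sketch for grad_norm (illustrative only): two loss terms,
 * e.g. a data misfit plus a regularization term, balanced over a dummy
 * inversion loop. The loss values and gradients are placeholders, and the
 * gctl::array<double> size constructor and operator[] access used below are
 * assumptions about the container declared in gctl/core.h.
 *
 * @code
 * #include <iostream>
 * // include this header via its path in your GCTL installation
 *
 * int main()
 * {
 *     size_t grad_num = 3; // size of the model gradient
 *
 *     gctl::grad_norm gn;
 *     gn.InitGradNorm(2, grad_num); // two loss functions; call this first
 *     gn.set_control_weight(1.0);   // GradNorm's control factor alpha
 *     gn.set_weight_step(0.001);    // learning rate of the task weights
 *
 *     gctl::array<double> g1(grad_num), g2(grad_num), g(grad_num);
 *     for (int iter = 0; iter < 10; ++iter)
 *     {
 *         // Placeholder losses and gradients. A real application evaluates
 *         // each task's objective and its gradient w.r.t. the model here.
 *         double fx1 = 1.0/(iter + 1.0), fx2 = 2.0/(iter + 1.0);
 *         for (size_t i = 0; i < grad_num; i++) { g1[i] = 1.0; g2[i] = 0.5; }
 *
 *         gn.AddSingleLoss(fx1, g1);
 *         gn.AddSingleLoss(fx2, g2);
 *
 *         double fx = gn.GradNormLoss(g); // merged objective and gradient
 *         gn.UpdateWeights();             // re-balance the task weights
 *         gn.ShowStatistics(std::clog, true);
 *         // ... pass fx and g to the optimizer's model update ...
 *     }
 *
 *     gn.save_records("gradnorm_weights.log");
 *     return 0;
 * }
 * @endcode
 */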