gctl_optimization/lib/optimization/cmn_grad.h
2025-04-08 08:37:45 +08:00

114 lines
4.3 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/********************************************************
* ██████╗ ██████╗████████╗██╗
* ██╔════╝ ██╔════╝╚══██╔══╝██║
* ██║ ███╗██║ ██║ ██║
* ██║ ██║██║ ██║ ██║
* ╚██████╔╝╚██████╗ ██║ ███████╗
* ╚═════╝ ╚═════╝ ╚═╝ ╚══════╝
* Geophysical Computational Tools & Library (GCTL)
*
* Copyright (c) 2023 Yi Zhang (yizhang-geo@zju.edu.cn)
*
* GCTL is distributed under a dual licensing scheme. You can redistribute
* it and/or modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation, either version 2
* of the License, or (at your option) any later version. You should have
* received a copy of the GNU Lesser General Public License along with this
* program. If not, see <http://www.gnu.org/licenses/>.
*
* If the terms and conditions of the LGPL v.2. would prevent you from using
* the GCTL, please consider the option to obtain a commercial license for a
* fee. These licenses are offered by the GCTL's original author. As a rule,
* licenses are provided "as-is", unlimited in time for a one time fee. Please
* send corresponding requests to: yizhang-geo@zju.edu.cn. Please do not forget
* to include some description of your company and the realm of its activities.
* Also add information on how to contact you by electronic and paper mail.
******************************************************/
#ifndef _GCTL_COMMON_GRADIENT_H
#define _GCTL_COMMON_GRADIENT_H
#include "lcg.h"
namespace gctl
{
    /**
     * @brief Solver that combines the model gradients of several loss
     * functions into one common ("conflict free") gradient.
     *
     * The class derives from lcg_solver and supplies the LCG_Ax() product
     * for it. As suggested by the interface, the object is sized with the
     * number of loss functions and model parameters, one gradient per loss
     * function is handed in with fill_model_gradient(), and the combined
     * gradient is then requested with get_common_gradient().
     */
    class common_gradient : public lcg_solver
    {
    public:
        common_gradient(); ///< Default constructor

        /**
         * @brief Construct a new common_gradient object
         *
         * @param Ln Number of loss functions
         * @param Mn Number of model parameters
         */
        common_gradient(size_t Ln, size_t Mn);

        virtual ~common_gradient(); ///< Destructor

        /**
         * @brief Compute the product Ax
         *
         * @param x  Input vector x
         * @param ax Output vector receiving the product Ax
         */
        virtual void LCG_Ax(const array<double> &x, array<double> &ax);

        /**
         * @brief Configure the solver's setups
         *
         * @param para LCG solver parameters
         */
        void set_solver(const lcg_para &para);

        /**
         * @brief Set the weights for the loss functions.
         *
         * The number of weights must equal the number of loss functions.
         * The larger a weight is, the more the calculated gradient depends
         * on the gradient of the corresponding loss function.
         *
         * @param w Weights, one per loss function
         */
        void set_weights(const _1d_array &w);

        /**
         * @brief Initialize the common_gradient object
         *
         * @param Ln Number of loss functions
         * @param Mn Number of model parameters
         */
        void init(size_t Ln, size_t Mn);

        /**
         * @brief Fill the model gradient
         *
         * @param id Loss function index
         * @param fx Objective value
         * @param g  Model gradient
         */
        void fill_model_gradient(size_t id, double fx, const _1d_array &g);

        /**
         * @brief Get the conflict free gradient
         *
         * @param normalized Normalize the output gradient
         * @param fixed_w    Fixed weights
         * @return Calculated model gradient
         */
        const _1d_array &get_common_gradient(bool normalized = true, bool fixed_w = true);

        /**
         * @brief Save the recorded weights.
         *
         * @param file Output file name
         */
        void save_records(std::string file);

    private:
        // NOTE(review): presumably marks the first (zeroth) evaluation round;
        // confirm against the implementation file.
        bool zero_iter_;
        size_t Ln_; // number of loss functions
        size_t Mn_; // number of model parameters
        _2d_matrix G_; // kernel matrix
        _1d_array B_, g_, t_, x_; // variables of the linear system
        // Gradient "module" (presumably the gradient moduli/norms; confirm in
        // the implementation) and the loss functions' weights.
        _1d_array gm_, w_;
        _1d_array fx_, fx0_; // loss function values and their initial values
        array<bool> filled_; // new gradient filled for the current round of evaluation
        std::vector<array<double> > rcd_wgts_; // weights records
        std::vector<array<double> > rcd_fxs_; // fx records
    };
} // namespace gctl
#endif // _GCTL_COMMON_GRADIENT_H