gctl_optimization/lib/optimization/cmn_grad.h
2025-02-22 18:10:15 +08:00

92 lines
3.3 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/********************************************************
* ██████╗ ██████╗████████╗██╗
* ██╔════╝ ██╔════╝╚══██╔══╝██║
* ██║ ███╗██║ ██║ ██║
* ██║ ██║██║ ██║ ██║
* ╚██████╔╝╚██████╗ ██║ ███████╗
* ╚═════╝ ╚═════╝ ╚═╝ ╚══════╝
* Geophysical Computational Tools & Library (GCTL)
*
* Copyright (c) 2023 Yi Zhang (yizhang-geo@zju.edu.cn)
*
* GCTL is distributed under a dual licensing scheme. You can redistribute
* it and/or modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation, either version 2
* of the License, or (at your option) any later version. You should have
* received a copy of the GNU Lesser General Public License along with this
* program. If not, see <http://www.gnu.org/licenses/>.
*
* If the terms and conditions of the LGPL v.2. would prevent you from using
* the GCTL, please consider the option to obtain a commercial license for a
* fee. These licenses are offered by the GCTL's original author. As a rule,
* licenses are provided "as-is", unlimited in time for a one time fee. Please
* send corresponding requests to: yizhang-geo@zju.edu.cn. Please do not forget
* to include some description of your company and the realm of its activities.
* Also add information on how to contact you by electronic and paper mail.
******************************************************/
#ifndef _GCTL_COMMON_GRADIENT_H
#define _GCTL_COMMON_GRADIENT_H
#include "lcg.h"
namespace gctl
{
	/**
	 * @brief Computes a common (conflict-free) model gradient from the model
	 * gradients of several loss functions.
	 *
	 * The class derives from lcg_solver and provides LCG_Ax, so the common
	 * gradient is presumably obtained by solving a linear system with the
	 * linear conjugate gradient solver. NOTE(review): confirm against the
	 * implementation file.
	 *
	 * Usage suggested by the public interface (confirm with the .cpp):
	 *   1. construct with (Ln, Mn) or call init(Ln, Mn);
	 *   2. optionally call set_solver() to configure the LCG parameters;
	 *   3. call fill_model_gradient(id, g) for each loss function id;
	 *   4. call get_common_gradient().
	 */
	class common_gradient : public lcg_solver
	{
	public:
		common_gradient(); ///< Default constructor
		/**
		 * @brief Construct a new common_gradient object
		 *
		 * @param Ln Number of loss functions
		 * @param Mn Number of model parameters
		 */
		common_gradient(size_t Ln, size_t Mn);
		virtual ~common_gradient(); ///< Destructor
		/**
		 * @brief Compute the product ax = A * x.
		 *
		 * Virtual hook of lcg_solver; presumably called by the base class
		 * during the conjugate gradient iterations.
		 *
		 * @param x  Input vector
		 * @param ax Output vector receiving A * x
		 */
		virtual void LCG_Ax(const array<double> &x, array<double> &ax);
		/**
		 * @brief Configure the solver's setups
		 *
		 * @param para LCG solver parameters
		 */
		void set_solver(const lcg_para &para);
		/**
		 * @brief Initialize the common_gradient object
		 *
		 * @param Ln Number of loss functions
		 * @param Mn Number of model parameters
		 */
		void init(size_t Ln, size_t Mn);
		/**
		 * @brief Fill the model gradient of one loss function
		 *
		 * @param id Loss function index (presumably in [0, Ln)) — confirm range checking in the .cpp
		 * @param g  Model gradient of that loss function (presumably of length Mn)
		 */
		void fill_model_gradient(size_t id, const _1d_array &g);
		/**
		 * @brief Get the conflict-free (common) gradient
		 *
		 * @param normalized Normalize the output gradient (default: true)
		 * @return Calculated model gradient. This is a const reference;
		 *         NOTE(review): it presumably refers to internal storage, so it
		 *         may be overwritten by later calls on this object — confirm.
		 */
		const _1d_array &get_common_gradient(bool normalized = true);
	private:
		size_t Ln_; ///< Number of loss functions
		size_t Mn_; ///< Number of model parameters
		_2d_matrix G_; ///< NOTE(review): presumably the stacked loss gradients — confirm
		_1d_array B_, g_, t_, x_; ///< NOTE(review): presumably work arrays for the LCG solve — confirm
		_1d_array gm_; ///< NOTE(review): presumably the gradient returned by get_common_gradient() — confirm
		array<bool> filled_; ///< NOTE(review): presumably marks which loss gradients were filled — confirm
	};
} // namespace gctl
#endif // _GCTL_COMMON_GRADIENT_H