gctl_optimization/lib/optimization/loss_func.h
2024-10-25 10:27:08 +08:00

125 lines
4.6 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/********************************************************
* ██████╗ ██████╗████████╗██╗
* ██╔════╝ ██╔════╝╚══██╔══╝██║
* ██║ ███╗██║ ██║ ██║
* ██║ ██║██║ ██║ ██║
* ╚██████╔╝╚██████╗ ██║ ███████╗
* ╚═════╝ ╚═════╝ ╚═╝ ╚══════╝
* Geophysical Computational Tools & Library (GCTL)
*
* Copyright (c) 2023 Yi Zhang (yizhang-geo@zju.edu.cn)
*
* GCTL is distributed under a dual licensing scheme. You can redistribute
* it and/or modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation, either version 2
* of the License, or (at your option) any later version. You should have
* received a copy of the GNU Lesser General Public License along with this
* program. If not, see <http://www.gnu.org/licenses/>.
*
* If the terms and conditions of the LGPL v.2. would prevent you from using
* the GCTL, please consider the option to obtain a commercial license for a
* fee. These licenses are offered by the GCTL's original author. As a rule,
* licenses are provided "as-is", unlimited in time for a one time fee. Please
* send corresponding requests to: yizhang-geo@zju.edu.cn. Please do not forget
* to include some description of your company and the realm of its activities.
* Also add information on how to contact you by electronic and paper mail.
******************************************************/
#ifndef _GCTL_LOSS_FUNC_H
#define _GCTL_LOSS_FUNC_H
// library's head files
#include "gctl/core.h"
namespace gctl
{
/**
* @brief 损失函数对象可计算L1范数, L2范数平方Lp范数定义的数据拟合差及相应的模型偏导数按数据个数归一化
* 损失函数的定义为Phi = Lp(d - d^tar)^2/num(d)
*/
class loss_func
{
public:
loss_func(); ///< 构造函数
/**
* @brief 构造函数
*
* @param tar 数据拟合差目标
* @param n_type 拟合差函数范数类型
* @param p Lp范数的阶次
* @param eps Lp范数分母内的小值防止分母变为奇异值
*/
loss_func(const array<double> &tar, norm_type_e n_type, double p = 2.0, double eps = 1e-16);
virtual ~loss_func(); ///< 析构函数
/**
* @brief 初始化函数
*
* @param tar 数据拟合差目标
* @param n_type 拟合差函数范数类型
* @param p Lp范数的阶次
* @param eps Lp范数分母内的小值防止分母变为奇异值
*/
void init(const array<double> &tar, norm_type_e n_type, double p = 2.0, double eps = 1e-16);
/**
* @brief 设置目标数据的不确定度
*
* @param uncer 不确定度
*/
void set_uncertainty(double uncer);
/**
* @brief 设置目标数据的不确定度
*
* @param uncer 不确定度数组,长度与目标数据一致
*/
void set_uncertainty(const array<double> &uncer);
/**
* @brief 计算单个输入模型数据的拟合差,同时将计算值累计至内部变量
*
* @param inp 输入数据值
* @param id 输入数据的索引
* @return 单个数据拟合差值
*/
double evaluate(double inp, int id);
/**
* @brief 计算输入模型的数据拟合差与模型梯度
*
* @param x 输入模型,长度与目标数据相等
* @param g 数据拟合差相对于模型的梯度
* @return 数据拟合差值
*/
double evaluate(const array<double> &x, array<double> &g);
/**
* @brief 返回内置的数据拟合差函数值然后将值重设为0
*
* @return 累计的数据拟合差
*/
double get_loss();
/**
* @brief 计算数据拟合差函数相对于单个输入模型数据的梯度
*
* @param inp 输入数据值
* @param id 输入数据的索引
* @return 单个数据拟合差函数的梯度
*/
double gradient(double inp, int id);
private:
bool init_;
double loss_, eps_, p_;
unsigned int tnum_;
norm_type_e ntype_;
array<double> tars_, diff_;
array<double> us_;
};
}
#endif // _GCTL_LOSS_FUNC_H