gctl_ai/lib/dnn/dnn.h

138 lines
5.2 KiB
C
Raw Normal View History

2024-09-10 20:15:33 +08:00
/********************************************************
*
*
*
*
*
*
* Geophysical Computational Tools & Library (GCTL)
*
* Copyright (c) 2022 Yi Zhang (yizhang-geo@zju.edu.cn)
*
* GCTL is distributed under a dual licensing scheme. You can redistribute
* it and/or modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation, either version 2
* of the License, or (at your option) any later version. You should have
* received a copy of the GNU Lesser General Public License along with this
* program. If not, see <http://www.gnu.org/licenses/>.
*
* If the terms and conditions of the LGPL v.2. would prevent you from using
* the GCTL, please consider the option to obtain a commercial license for a
* fee. These licenses are offered by the GCTL's original author. As a rule,
* licenses are provided "as-is", unlimited in time for a one time fee. Please
* send corresponding requests to: yizhang-geo@zju.edu.cn. Please do not forget
* to include some description of your company and the realm of its activities.
* Also add information on how to contact you by electronic and paper mail.
******************************************************/
#ifndef _GCTL_DNN_DNN_H
#define _GCTL_DNN_DNN_H
#include "hlayer_fully_connected.h"
#include "hlayer_maxpooling.h"
#include "hlayer_avgpooling.h"
#include "hlayer_convolution.h"
#include "olayer_rmse.h"
#include "olayer_binaryentropy.h"
#include "olayer_multientropy.h"
#include "gctl/optimization.h"
namespace gctl
{
/**
 * @brief Status flag of a dnn object (see dnn::status_).
 *
 * The numeric values are written out explicitly; they are the same values
 * the enumerators would receive implicitly. The per-state notes are inferred
 * from the enumerator names only; the code that moves between states is
 * not part of this header. TODO confirm against the implementation.
 */
enum dnn_status_e
{
    NoneSet     = 0, ///< Presumably: nothing has been configured yet.
    PartSet     = 1, ///< Presumably: the network is only partly configured.
    AllSet      = 2, ///< Presumably: the configuration is complete.
    Initialized = 3, ///< Presumably: init_network() has been run.
    Trained     = 4  ///< Presumably: train_network() has been run.
};
/**
 * @brief Neural network object assembled from a list of hind (presumably
 * "hidden") layers and one output layer.
 *
 * The class derives from both sgd_solver and lgd_solver and declares the
 * SGD_* / LGD_* Evaluate and Progress members; presumably these are the
 * callbacks the two solvers drive during training (the base classes are
 * declared elsewhere, confirm there).
 *
 * NOTE(review): the layers are held through raw pointers and the class has a
 * user-declared destructor but no copy control. If the destructor releases
 * the layers, copying a dnn would be unsafe; verify in the implementation.
 */
class dnn : public sgd_solver, public lgd_solver
{
public:
    /**
     * @brief Create a network object.
     *
     * @param info Free-form text; presumably kept in info_ (the constructor
     *             body is not visible in this header).
     */
    dnn(std::string info);

    virtual ~dnn();

    /**
     * @brief Read a network from a file.
     *
     * @param filename Name of the previously saved binary file.
     */
    void load_network(std::string filename);

    /**
     * @brief Write the network to a file.
     *
     * @param filename Name of the binary file to be written.
     */
    void save_network(std::string filename) const;

    /**
     * @brief Write a single layer of the network to a text file.
     *
     * @param l_idx    Index of the layer.
     * @param filename Name of the text file to be written.
     */
    void save_layer2text(int l_idx, std::string filename) const;

    /**
     * @brief Print the network's setup on screen.
     */
    void show_network();

    /**
     * @brief Append a fully connected hind layer.
     *
     * @param in_s   Size of the layer's input.
     * @param out_s  Size of the layer's output.
     * @param h_type Kind of layer. Currently this has to be FullyConnected.
     * @param a_type Activation selected for the layer's output.
     */
    void add_hind_layer(int in_s, int out_s, hlayer_type_e h_type, activation_type_e a_type);

    /**
     * @brief Append a hind layer described by an input shape, a pooling
     * window and strides.
     *
     * Presumably intended for the pooling layer kinds (the max- and
     * average-pooling layer headers are included by this file); the accepted
     * h_type values are decided in the implementation. TODO confirm.
     *
     * @param in_rows     Number of input rows.
     * @param in_cols     Number of input columns.
     * @param pool_rows   Number of rows of the pooling window.
     * @param pool_cols   Number of columns of the pooling window.
     * @param stride_rows Stride along the rows.
     * @param stride_cols Stride along the columns.
     * @param h_type      Kind of layer.
     * @param p_type      Padding mode.
     * @param acti_type   Activation selected for the layer's output.
     */
    void add_hind_layer(int in_rows, int in_cols, int pool_rows, int pool_cols, int stride_rows,
        int stride_cols, hlayer_type_e h_type, pad_type_e p_type, activation_type_e acti_type);

    /**
     * @brief Append a hind layer described by a channel count, an input
     * shape, a filter shape and strides.
     *
     * Presumably intended for the convolution layer kind (the convolution
     * layer header is included by this file); the accepted h_type values are
     * decided in the implementation. TODO confirm.
     *
     * @param channels    Number of channels.
     * @param in_rows     Number of input rows.
     * @param in_cols     Number of input columns.
     * @param filter_rows Number of rows of the filter.
     * @param filter_cols Number of columns of the filter.
     * @param stride_rows Stride along the rows.
     * @param stride_cols Stride along the columns.
     * @param h_type      Kind of layer.
     * @param p_type      Padding mode.
     * @param acti_type   Activation selected for the layer's output.
     */
    void add_hind_layer(int channels, int in_rows, int in_cols, int filter_rows, int filter_cols,
        int stride_rows, int stride_cols, hlayer_type_e h_type, pad_type_e p_type, activation_type_e acti_type);

    /**
     * @brief Set the output layer of the network.
     *
     * @param o_type Kind of output layer.
     */
    void add_output_layer(olayer_type_e o_type);

    /**
     * @brief Hand the training data to the network.
     *
     * @param train_obs  Training observations.
     * @param train_tar  Training targets.
     * @param batch_size Batch size, 0 by default. How 0 is interpreted is
     *                   decided in the implementation; verify there.
     */
    void add_train_set(const matrix<double> &train_obs, const matrix<double> &train_tar, int batch_size = 0);

    /**
     * @brief Initialize the network.
     *
     * @param mu    Presumably the mean used for the random initial values. TODO confirm.
     * @param sigma Presumably the standard deviation used for the random initial values. TODO confirm.
     * @param seed  Seed value, 0 by default. How 0 is interpreted is decided
     *              in the implementation; verify there.
     */
    void init_network(double mu, double sigma, unsigned int seed = 0);

    /**
     * @brief Train the network using the parameters of the SGD solver.
     *
     * @param pa     SGD parameter set.
     * @param solver SGD solver variant, ADAM by default.
     */
    void train_network(const sgd_para &pa, sgd_solver_type solver = ADAM);

    /**
     * @brief Train the network using the parameters of the LGD solver.
     *
     * @param pa LGD parameter set.
     */
    void train_network(const lgd_para &pa);

    /**
     * @brief Run the network on a set of observations.
     *
     * @param pre_obs  Observations to be fed to the network.
     * @param predicts Receives the predictions.
     */
    void predict(const matrix<double> &pre_obs, matrix<double> &predicts);

    // Evaluate/Progress members for the two solver bases. Presumably they
    // implement virtual interfaces of sgd_solver and lgd_solver, where the
    // parameter and return conventions are defined; confirm there.
    double SGD_Evaluate(const array<double> &x, array<double> &g);
    int SGD_Progress(double fx, const array<double> &x, const sgd_para &param, const int k);
    double LGD_Evaluate(const array<double> &x, array<double> &g);
    int LGD_Progress(const int curr_t, const double curr_fx, const double mean_fx, const double best_fx, const lgd_para &param);

private:
    // Internal passes through the network; defined outside this header.
    void forward_propagation(const matrix<double> &input);
    void backward_propagation(const matrix<double> &input, const matrix<double> &target);

    // Layers. Declaration order of all data members below is kept as in the
    // original header, since it fixes their initialization order.
    int hlayer_size_;
    std::vector<dnn_hlayer*> hind_layers_;
    dnn_olayer *output_layer_;

    // Weight and derivative storage together with its bookkeeping.
    // NOTE(review): exact meaning of the wd_* and *_mean_/*_std_ members is
    // not visible here; verify against the implementation.
    bool has_stats_;
    int wd_st_;
    int wd_size_;
    array<double> weights_;
    array<double> ders_;
    array<double> weights_mean_;
    array<double> weights_std_;

    // Batch bookkeeping and the stored observation/target sets.
    int batch_iter_;
    int batch_size_;
    int batch_num_;
    int obs_num_;
    array<matrix<double> > obs_;
    array<matrix<double> > targets_;

    // NOTE(review): purpose of h_ss, o_ss and t_ss is not visible here.
    int h_ss;
    int o_ss;
    int t_ss;

    dnn_status_e status_;
    std::string info_;
};
}
#endif // _GCTL_DNN_DNN_H