/******************************************************** * ██████╗ ██████╗████████╗██╗ * ██╔════╝ ██╔════╝╚══██╔══╝██║ * ██║ ███╗██║ ██║ ██║ * ██║ ██║██║ ██║ ██║ * ╚██████╔╝╚██████╗ ██║ ███████╗ * ╚═════╝ ╚═════╝ ╚═╝ ╚══════╝ * Geophysical Computational Tools & Library (GCTL) * * Copyright (c) 2022 Yi Zhang (yizhang-geo@zju.edu.cn) * * GCTL is distributed under a dual licensing scheme. You can redistribute * it and/or modify it under the terms of the GNU Lesser General Public * License as published by the Free Software Foundation, either version 2 * of the License, or (at your option) any later version. You should have * received a copy of the GNU Lesser General Public License along with this * program. If not, see . * * If the terms and conditions of the LGPL v.2. would prevent you from using * the GCTL, please consider the option to obtain a commercial license for a * fee. These licenses are offered by the GCTL's original author. As a rule, * licenses are provided "as-is", unlimited in time for a one time fee. Please * send corresponding requests to: yizhang-geo@zju.edu.cn. Please do not forget * to include some description of your company and the realm of its activities. * Also add information on how to contact you by electronic and paper mail. ******************************************************/ #ifndef _GCTL_DNN_DNN_H #define _GCTL_DNN_DNN_H #include "hlayer_fully_connected.h" #include "hlayer_maxpooling.h" #include "hlayer_avgpooling.h" #include "hlayer_convolution.h" #include "olayer_rmse.h" #include "olayer_binaryentropy.h" #include "olayer_multientropy.h" #include "gctl/optimization.h" namespace gctl { enum dnn_status_e { NoneSet, PartSet, AllSet, Initialized, Trained, }; class dnn : public sgd_solver, public lgd_solver { public: dnn(std::string info); virtual ~dnn(); /** * @brief Load DNN network from file. * * @param filename Name of the saved binary file. */ void load_network(std::string filename); /** * @brief Save DNN network to file. * * @param filename Name of the output binary file. */ void save_network(std::string filename) const; /** * @brief Save DNN's layer to a text file. * * @param l_idx Layer index. * @param filename Name of the output file. */ void save_layer2text(int l_idx, std::string filename) const; /** * @brief Display DNN's setup on screen. * */ void show_network(); /** * @brief Add a fully connected hind layer * * @param in_s Input parameter size. * @param out_s Output parameter size. * @param h_type Layer type. For now it must be FullyConnected. * @param a_type Output type. */ void add_hind_layer(int in_s, int out_s, hlayer_type_e h_type, activation_type_e a_type); void add_hind_layer(int in_rows, int in_cols, int pool_rows, int pool_cols, int stride_rows, int stride_cols, hlayer_type_e h_type, pad_type_e p_type, activation_type_e acti_type); void add_hind_layer(int channels, int in_rows, int in_cols, int filter_rows, int filter_cols, int stride_rows, int stride_cols, hlayer_type_e h_type, pad_type_e p_type, activation_type_e acti_type); void add_output_layer(olayer_type_e o_type); void add_train_set(const matrix &train_obs, const matrix &train_tar, int batch_size = 0); void init_network(double mu, double sigma, unsigned int seed = 0); void train_network(const sgd_para &pa, sgd_solver_type solver = ADAM); void train_network(const lgd_para &pa); void predict(const matrix &pre_obs, matrix &predicts); double SGD_Evaluate(const array &x, array &g); int SGD_Progress(double fx, const array &x, const sgd_para ¶m, const int k); double LGD_Evaluate(const array &x, array &g); int LGD_Progress(const int curr_t, const double curr_fx, const double mean_fx, const double best_fx, const lgd_para ¶m); private: void forward_propagation(const matrix &input); void backward_propagation(const matrix &input, const matrix &target); private: int hlayer_size_; std::vector hind_layers_; dnn_olayer *output_layer_; bool has_stats_; int wd_st_, wd_size_; array weights_, ders_; array weights_mean_, weights_std_; int batch_iter_, batch_size_, batch_num_, obs_num_; array > obs_, targets_; int h_ss, o_ss, t_ss; dnn_status_e status_; std::string info_; }; } #endif // _GCTL_DNN_DNN_H