// File: gctl_ai/examples/ex1.cpp (C++, last modified 2024-09-10 20:15:33 +08:00)
/********************************************************
*
*
*
*
*
*
* Geophysical Computational Tools & Library (GCTL)
*
* Copyright (c) 2022 Yi Zhang (yizhang-geo@zju.edu.cn)
*
* GCTL is distributed under a dual licensing scheme. You can redistribute
* it and/or modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation, either version 2
* of the License, or (at your option) any later version. You should have
* received a copy of the GNU Lesser General Public License along with this
* program. If not, see <http://www.gnu.org/licenses/>.
*
* If the terms and conditions of the LGPL v.2. would prevent you from using
* the GCTL, please consider the option to obtain a commercial license for a
* fee. These licenses are offered by the GCTL's original author. As a rule,
* licenses are provided "as-is", unlimited in time for a one time fee. Please
* send corresponding requests to: yizhang-geo@zju.edu.cn. Please do not forget
* to include some description of your company and the realm of its activities.
* Also add information on how to contact you by electronic and paper mail.
******************************************************/
#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <iostream>
#include <stdexcept>

#include "../lib/dnn.h"
using namespace gctl;
/**
 * Compute the regression targets for this example: the Euclidean norm of
 * each 2-D observation, target_j = sqrt(x_j^2 + y_j^2).
 *
 * @param train_obs  Observations, one sample per column; row 0 holds x and
 *                   row 1 holds y. Must have at least 2 rows.
 * @param train_tar  Output targets; row 0, column j receives the norm of
 *                   sample j. Must already have at least one row and at least
 *                   as many columns as train_obs (it is not resized here).
 * @throws std::invalid_argument if either matrix is too small for the
 *         indexing performed below.
 */
void data_generator(const matrix<double> &train_obs, matrix<double> &train_tar)
{
	// Validate shapes up front: the element access below goes through
	// operator[], which is assumed not to bounds-check (NOTE(review): confirm
	// for gctl::matrix), so a mis-sized target would otherwise be corrupted silently.
	if (train_obs.row_size() < 2)
	{
		throw std::invalid_argument("data_generator: train_obs must have at least 2 rows");
	}
	if (train_tar.row_size() < 1 || train_tar.col_size() < train_obs.col_size())
	{
		throw std::invalid_argument("data_generator: train_tar is too small for train_obs");
	}

	for (size_t j = 0; j < train_obs.col_size(); j++)
	{
		const double x = train_obs[0][j];
		const double y = train_obs[1][j];
		train_tar[0][j] = std::sqrt(x*x + y*y);
	}
}
int main(int argc, char const *argv[]) try
{
// Prepare the data. In this example, we try to learn the sin() function.
matrix<double> train_obs(2, 1000), train_tar(1, 1000), pre_obs(2, 10), pre_tar(1, 10), predicts(1, 10);
unsigned int seed = 101;
srand(seed);
for (int j = 0; j < 1000; j++)
{
for (int i = 0; i < 2; i++)
{
train_obs[i][j] = random(0.0, 1.0);
}
}
for (int j = 0; j < 10; j++)
{
for (int i = 0; i < 2; i++)
{
pre_obs[i][j] = random(0.0, 1.0);
}
}
data_generator(train_obs, train_tar);
data_generator(pre_obs, pre_tar);
dnn my_nn("Ex-1");
my_nn.add_hind_layer(2, 100, FullyConnected, Identity);
my_nn.add_hind_layer(100, 100, FullyConnected, PReLU);
my_nn.add_hind_layer(100, 100, FullyConnected, PReLU);
my_nn.add_hind_layer(100, 1, FullyConnected, Identity);
my_nn.add_output_layer(RegressionMSE);
my_nn.add_train_set(train_obs, train_tar, 200);
my_nn.show_network();
my_nn.init_network(0.0, 0.1, seed);
sgd_para my_para = my_nn.default_sgd_para();
my_nn.train_network(my_para, gctl::ADAM);
//lgd_para my_para = my_nn.default_lgd_para();
//my_para.flight_times = 5000;
//my_para.lambda = 5e-5;
//my_para.epsilon = 1e-5;
//my_para.batch = 10;
//my_nn.train_network(my_para, gctl::LGD);
my_nn.predict(pre_obs, predicts);
double diff = 0;
for (int i = 0; i < 1; i++)
{
for (int j = 0; j < 10; j++)
{
diff = std::max(fabs(predicts[i][j] - pre_tar[i][j]), diff);
}
}
std::clog << "Max difference = " << diff << "\n";
/*
my_nn.save_network("ex1");
dnn file_nn("File NN");
file_nn.load_network("ex1");
file_nn.show_network();
file_nn.predict(pre_obs, predicts);
diff = 0;
for (int i = 0; i < 1; i++)
{
for (int j = 0; j < 10; j++)
{
diff = std::max(fabs(predicts[i][j] - pre_tar[i][j]), diff);
}
}
std::clog << "Max difference = " << diff << "\n";
*/
return 0;
}
catch (std::exception &e)
{
GCTL_ShowWhatError(e.what(), GCTL_ERROR_ERROR, 0, 0, 0);
}