add data folder
BIN
data/gauss_dist.nc
Normal file
4092
data/trace.txt
Normal file
1267
data/trace_adabelief.eps
Normal file
BIN
data/trace_adabelief.png
Normal file
After Width: | Height: | Size: 45 KiB |
110
data/trace_adabelief.txt
Normal file
@ -0,0 +1,110 @@
|
||||
25 20
|
||||
80.5556 75.5556
|
||||
51.6773 105.89
|
||||
45.4879 100.534
|
||||
55.9072 83.2383
|
||||
78.3211 52.1536
|
||||
73.1432 13.608
|
||||
66.97 -5.04074
|
||||
57.998 4.01014
|
||||
69.4042 20.8659
|
||||
87.2395 31.5036
|
||||
84.4461 32.9223
|
||||
73.2929 22.1821
|
||||
63.9285 11.7412
|
||||
68.635 6.36007
|
||||
73.9149 17.7964
|
||||
78.6996 28.3878
|
||||
82.167 33.0463
|
||||
80.9749 26.8905
|
||||
76.2556 18.829
|
||||
71.5822 11.8813
|
||||
67.2713 11.7077
|
||||
69.386 16.2861
|
||||
74.492 21.0133
|
||||
79.1576 25.3007
|
||||
81.8828 27.9281
|
||||
79.1782 26.7454
|
||||
75.1944 23.1341
|
||||
71.7028 19.4667
|
||||
69.5162 15.86
|
||||
70.6333 13.5124
|
||||
73.1503 14.7693
|
||||
75.153 18.339
|
||||
76.8358 21.8347
|
||||
78.1961 24.9817
|
||||
78.6096 26.6716
|
||||
77.8659 25.466
|
||||
76.5785 22.784
|
||||
75.2886 20.1543
|
||||
74.1119 17.7624
|
||||
73.0065 15.8739
|
||||
72.0711 15.4734
|
||||
71.777 16.6778
|
||||
72.3059 18.4678
|
||||
73.2825 20.1744
|
||||
74.36 21.653
|
||||
75.4398 22.8751
|
||||
76.4937 23.6887
|
||||
77.3604 23.8793
|
||||
77.733 23.4227
|
||||
77.4493 22.5751
|
||||
76.7263 21.6403
|
||||
75.8919 20.7665
|
||||
75.1045 19.9774
|
||||
74.388 19.2622
|
||||
73.7405 18.6319
|
||||
73.2026 18.1416
|
||||
72.8708 17.8755
|
||||
72.8359 17.8887
|
||||
73.0863 18.154
|
||||
73.508 18.5766
|
||||
73.98 19.0557
|
||||
74.4339 19.5261
|
||||
74.8485 19.9604
|
||||
75.2248 20.354
|
||||
75.5666 20.7108
|
||||
75.8698 21.0348
|
||||
76.1188 21.3249
|
||||
76.2893 21.5738
|
||||
76.36 21.7684
|
||||
76.3253 21.8946
|
||||
76.2012 21.9437
|
||||
76.0174 21.9171
|
||||
75.8056 21.8264
|
||||
75.5902 21.6891
|
||||
75.3863 21.5236
|
||||
75.2007 21.3457
|
||||
75.0351 21.167
|
||||
74.8885 20.9946
|
||||
74.7589 20.8324
|
||||
74.6443 20.6817
|
||||
74.5433 20.5425
|
||||
74.4551 20.4141
|
||||
74.3796 20.2955
|
||||
74.317 20.1858
|
||||
74.2677 20.0846
|
||||
74.2321 19.9915
|
||||
74.2101 19.9067
|
||||
74.2014 19.8307
|
||||
74.2049 19.7638
|
||||
74.2192 19.7065
|
||||
74.2426 19.6593
|
||||
74.2733 19.6222
|
||||
74.3093 19.5953
|
||||
74.3489 19.5781
|
||||
74.3905 19.5703
|
||||
74.4326 19.5709
|
||||
74.4742 19.5791
|
||||
74.5144 19.5938
|
||||
74.5524 19.6139
|
||||
74.5881 19.6381
|
||||
74.621 19.6654
|
||||
74.6511 19.6946
|
||||
74.6785 19.7248
|
||||
74.7032 19.7552
|
||||
74.7255 19.785
|
||||
74.7455 19.8136
|
||||
74.7635 19.8408
|
||||
74.7797 19.8661
|
||||
74.7943 19.8894
|
1176
data/trace_adagrad.eps
Normal file
BIN
data/trace_adagrad.png
Normal file
After Width: | Height: | Size: 39 KiB |
189
data/trace_adagrad.txt
Normal file
@ -0,0 +1,189 @@
|
||||
25 20
|
||||
55 50
|
||||
66.2949 29.3676
|
||||
95.1386 19.1501
|
||||
72.1298 35.9407
|
||||
73.1951 17.2262
|
||||
73.4971 18.3337
|
||||
73.6789 18.6274
|
||||
73.7984 18.8091
|
||||
73.8855 18.9379
|
||||
73.9531 19.0359
|
||||
74.0075 19.114
|
||||
74.0527 19.1782
|
||||
74.0909 19.2323
|
||||
74.1239 19.2786
|
||||
74.1527 19.3189
|
||||
74.1782 19.3544
|
||||
74.2009 19.386
|
||||
74.2213 19.4143
|
||||
74.2398 19.4398
|
||||
74.2566 19.463
|
||||
74.272 19.4842
|
||||
74.2862 19.5037
|
||||
74.2992 19.5217
|
||||
74.3114 19.5383
|
||||
74.3226 19.5537
|
||||
74.3331 19.5681
|
||||
74.343 19.5816
|
||||
74.3522 19.5942
|
||||
74.3609 19.606
|
||||
74.369 19.6172
|
||||
74.3767 19.6277
|
||||
74.384 19.6376
|
||||
74.3909 19.6471
|
||||
74.3975 19.656
|
||||
74.4037 19.6645
|
||||
74.4097 19.6725
|
||||
74.4153 19.6802
|
||||
74.4207 19.6876
|
||||
74.4259 19.6946
|
||||
74.4308 19.7013
|
||||
74.4355 19.7077
|
||||
74.44 19.7138
|
||||
74.4444 19.7197
|
||||
74.4485 19.7253
|
||||
74.4526 19.7308
|
||||
74.4564 19.736
|
||||
74.4601 19.741
|
||||
74.4637 19.7458
|
||||
74.4671 19.7505
|
||||
74.4704 19.755
|
||||
74.4736 19.7593
|
||||
74.4767 19.7635
|
||||
74.4797 19.7676
|
||||
74.4826 19.7715
|
||||
74.4854 19.7753
|
||||
74.4881 19.7789
|
||||
74.4907 19.7825
|
||||
74.4933 19.7859
|
||||
74.4957 19.7892
|
||||
74.4981 19.7924
|
||||
74.5004 19.7956
|
||||
74.5027 19.7986
|
||||
74.5048 19.8016
|
||||
74.507 19.8044
|
||||
74.509 19.8072
|
||||
74.511 19.8099
|
||||
74.513 19.8125
|
||||
74.5149 19.8151
|
||||
74.5167 19.8176
|
||||
74.5185 19.82
|
||||
74.5202 19.8223
|
||||
74.5219 19.8246
|
||||
74.5236 19.8269
|
||||
74.5252 19.829
|
||||
74.5268 19.8312
|
||||
74.5283 19.8332
|
||||
74.5298 19.8352
|
||||
74.5313 19.8372
|
||||
74.5327 19.8391
|
||||
74.5341 19.841
|
||||
74.5354 19.8428
|
||||
74.5367 19.8446
|
||||
74.538 19.8463
|
||||
74.5393 19.848
|
||||
74.5405 19.8497
|
||||
74.5417 19.8513
|
||||
74.5429 19.8529
|
||||
74.544 19.8545
|
||||
74.5452 19.856
|
||||
74.5463 19.8575
|
||||
74.5473 19.8589
|
||||
74.5484 19.8603
|
||||
74.5494 19.8617
|
||||
74.5504 19.8631
|
||||
74.5514 19.8644
|
||||
74.5524 19.8657
|
||||
74.5533 19.867
|
||||
74.5542 19.8682
|
||||
74.5552 19.8694
|
||||
74.556 19.8706
|
||||
74.5569 19.8718
|
||||
74.5578 19.8729
|
||||
74.5586 19.8741
|
||||
74.5594 19.8752
|
||||
74.5602 19.8762
|
||||
74.561 19.8773
|
||||
74.5618 19.8783
|
||||
74.5625 19.8794
|
||||
74.5632 19.8803
|
||||
74.564 19.8813
|
||||
74.5647 19.8823
|
||||
74.5654 19.8832
|
||||
74.5661 19.8841
|
||||
74.5667 19.885
|
||||
74.5674 19.8859
|
||||
74.568 19.8868
|
||||
74.5687 19.8876
|
||||
74.5693 19.8885
|
||||
74.5699 19.8893
|
||||
74.5705 19.8901
|
||||
74.5711 19.8909
|
||||
74.5717 19.8917
|
||||
74.5722 19.8924
|
||||
74.5728 19.8932
|
||||
74.5733 19.8939
|
||||
74.5739 19.8946
|
||||
74.5744 19.8954
|
||||
74.5749 19.896
|
||||
74.5754 19.8967
|
||||
74.5759 19.8974
|
||||
74.5764 19.8981
|
||||
74.5769 19.8987
|
||||
74.5774 19.8993
|
||||
74.5778 19.9
|
||||
74.5783 19.9006
|
||||
74.5787 19.9012
|
||||
74.5792 19.9018
|
||||
74.5796 19.9024
|
||||
74.58 19.9029
|
||||
74.5804 19.9035
|
||||
74.5809 19.9041
|
||||
74.5813 19.9046
|
||||
74.5817 19.9051
|
||||
74.582 19.9057
|
||||
74.5824 19.9062
|
||||
74.5828 19.9067
|
||||
74.5832 19.9072
|
||||
74.5835 19.9077
|
||||
74.5839 19.9082
|
||||
74.5843 19.9086
|
||||
74.5846 19.9091
|
||||
74.585 19.9096
|
||||
74.5853 19.91
|
||||
74.5856 19.9105
|
||||
74.5859 19.9109
|
||||
74.5863 19.9113
|
||||
74.5866 19.9118
|
||||
74.5869 19.9122
|
||||
74.5872 19.9126
|
||||
74.5875 19.913
|
||||
74.5878 19.9134
|
||||
74.5881 19.9138
|
||||
74.5884 19.9142
|
||||
74.5886 19.9145
|
||||
74.5889 19.9149
|
||||
74.5892 19.9153
|
||||
74.5895 19.9156
|
||||
74.5897 19.916
|
||||
74.59 19.9163
|
||||
74.5902 19.9167
|
||||
74.5905 19.917
|
||||
74.5907 19.9174
|
||||
74.591 19.9177
|
||||
74.5912 19.918
|
||||
74.5915 19.9183
|
||||
74.5917 19.9186
|
||||
74.5919 19.919
|
||||
74.5922 19.9193
|
||||
74.5924 19.9196
|
||||
74.5926 19.9199
|
||||
74.5928 19.9201
|
||||
74.593 19.9204
|
||||
74.5932 19.9207
|
||||
74.5934 19.921
|
||||
74.5936 19.9213
|
||||
74.5938 19.9215
|
||||
74.594 19.9218
|
||||
74.5942 19.922
|
1392
data/trace_adam.eps
Normal file
BIN
data/trace_adam.png
Normal file
After Width: | Height: | Size: 46 KiB |
235
data/trace_adam.txt
Normal file
@ -0,0 +1,235 @@
|
||||
25 20
|
||||
75 70
|
||||
44.8917 98.7743
|
||||
32.6504 101.649
|
||||
22.9108 100.006
|
||||
13.9398 91.9168
|
||||
6.26957 70.5065
|
||||
12.6287 46.3368
|
||||
24.2283 42.7528
|
||||
42.4555 44.5843
|
||||
64.0646 39.6989
|
||||
76.5126 41.289
|
||||
83.8675 32.4828
|
||||
83.3832 16.8001
|
||||
71.9312 14.558
|
||||
61.3246 14.1031
|
||||
62.2715 12.6225
|
||||
70.2794 11.7761
|
||||
78.6079 15.2732
|
||||
82.9537 23.8788
|
||||
82.9833 33.0889
|
||||
80.5439 34.6091
|
||||
77.3775 28.9631
|
||||
74.5517 20.5407
|
||||
71.8284 12.4204
|
||||
69.1634 7.6459
|
||||
67.0811 9.8592
|
||||
67.7948 15.5237
|
||||
71.2757 21.3197
|
||||
75.6961 26.2578
|
||||
80.3762 29.642
|
||||
83.615 29.8663
|
||||
82.903 27.4007
|
||||
78.8916 23.7271
|
||||
74.3965 19.9595
|
||||
70.0392 16.3068
|
||||
66.9447 13.1213
|
||||
67.4768 11.6077
|
||||
70.6815 12.8852
|
||||
74.4029 16.2754
|
||||
77.8055 20.2134
|
||||
80.5939 24.3253
|
||||
81.4676 28.0699
|
||||
80.0297 29.4851
|
||||
77.4852 27.5109
|
||||
75.0053 23.5604
|
||||
72.7439 19.3651
|
||||
70.6955 15.2677
|
||||
69.5575 12.15
|
||||
69.902 12.1445
|
||||
71.4347 15.0991
|
||||
73.4413 18.9205
|
||||
75.437 22.6598
|
||||
77.4032 26.1848
|
||||
79.1786 28.239
|
||||
79.976 27.4448
|
||||
79.1681 24.6099
|
||||
77.3326 21.3023
|
||||
75.2772 18.132
|
||||
73.1298 15.2758
|
||||
70.9565 13.7054
|
||||
69.6987 14.3439
|
||||
70.4542 16.5259
|
||||
72.5719 19.0922
|
||||
74.8867 21.581
|
||||
77.156 23.9753
|
||||
79.1573 25.8808
|
||||
79.8167 26.3385
|
||||
78.5637 25.0643
|
||||
76.4686 22.8929
|
||||
74.3527 20.6166
|
||||
72.3013 18.3931
|
||||
70.6869 16.2977
|
||||
70.593 14.9457
|
||||
72.0387 15.2142
|
||||
73.9103 16.9383
|
||||
75.6486 19.1758
|
||||
77.2344 21.4688
|
||||
78.4362 23.7376
|
||||
78.5991 25.4837
|
||||
77.6031 25.6657
|
||||
76.195 24.1184
|
||||
74.8641 21.8743
|
||||
73.6138 19.6024
|
||||
72.4603 17.3887
|
||||
71.7523 15.6054
|
||||
71.9244 15.3427
|
||||
72.8145 16.8337
|
||||
73.957 18.9384
|
||||
75.0973 21.0313
|
||||
76.209 23.0676
|
||||
77.2338 24.729
|
||||
77.8396 25.0527
|
||||
77.5992 23.7982
|
||||
76.7004 21.9088
|
||||
75.6233 20.0037
|
||||
74.5342 18.1739
|
||||
73.3864 16.6615
|
||||
72.3085 16.2221
|
||||
71.9422 17.1199
|
||||
72.6233 18.6582
|
||||
73.8035 20.2415
|
||||
75.0444 21.7519
|
||||
76.2887 23.141
|
||||
77.399 24.0265
|
||||
77.7772 23.8462
|
||||
77.0762 22.7435
|
||||
75.916 21.3529
|
||||
74.741 19.9759
|
||||
73.5958 18.6331
|
||||
72.6004 17.4532
|
||||
72.293 16.9387
|
||||
72.9647 17.4791
|
||||
74.0395 18.6491
|
||||
75.1214 19.9283
|
||||
76.1709 21.1803
|
||||
77.0755 22.3923
|
||||
77.3447 23.3212
|
||||
76.7152 23.4459
|
||||
75.7252 22.6754
|
||||
74.7834 21.4923
|
||||
73.9288 20.2414
|
||||
73.1956 18.9872
|
||||
72.8522 17.8483
|
||||
73.1665 17.3066
|
||||
73.8667 17.8229
|
||||
74.579 18.984
|
||||
75.2469 20.2335
|
||||
75.8959 21.4525
|
||||
76.4576 22.5921
|
||||
76.6889 23.2558
|
||||
76.421 22.9101
|
||||
75.853 21.877
|
||||
75.2267 20.725
|
||||
74.6168 19.6027
|
||||
74.0228 18.5231
|
||||
73.5022 17.7726
|
||||
73.2715 17.9035
|
||||
73.5224 18.7708
|
||||
74.0557 19.8216
|
||||
74.6409 20.8483
|
||||
75.2341 21.8263
|
||||
75.8682 22.5693
|
||||
76.4285 22.666
|
||||
76.5476 22.0675
|
||||
76.1345 21.1878
|
||||
75.5051 20.3088
|
||||
74.8531 19.4702
|
||||
74.1909 18.7122
|
||||
73.5765 18.2768
|
||||
73.3257 18.4608
|
||||
73.6698 19.1171
|
||||
74.2861 19.8913
|
||||
74.9146 20.6503
|
||||
75.525 21.3937
|
||||
76.0897 22.0471
|
||||
76.3769 22.3214
|
||||
76.1306 21.9949
|
||||
75.586 21.3293
|
||||
75.0104 20.6167
|
||||
74.449 19.9263
|
||||
73.9103 19.2565
|
||||
73.5702 18.7207
|
||||
73.7193 18.6096
|
||||
74.2157 19.0085
|
||||
74.762 19.6297
|
||||
75.2855 20.2684
|
||||
75.7874 20.8955
|
||||
76.144 21.5027
|
||||
76.0918 21.9374
|
||||
75.6681 21.9391
|
||||
75.1773 21.5005
|
||||
74.7481 20.8789
|
||||
74.359 20.2384
|
||||
74.0211 19.6023
|
||||
73.8647 19.0301
|
||||
74.0238 18.7674
|
||||
74.3709 19.0398
|
||||
74.726 19.6195
|
||||
75.063 20.2302
|
||||
75.3991 20.8192
|
||||
75.7117 21.3869
|
||||
75.8757 21.7965
|
||||
75.7575 21.7636
|
||||
75.4546 21.3216
|
||||
75.1176 20.773
|
||||
74.7898 20.2421
|
||||
74.4643 19.7225
|
||||
74.1683 19.243
|
||||
74.0295 19.0024
|
||||
74.1621 19.2077
|
||||
74.4539 19.6715
|
||||
74.7659 20.1644
|
||||
75.0642 20.6393
|
||||
75.3571 21.1108
|
||||
75.6181 21.5115
|
||||
75.7323 21.6107
|
||||
75.6079 21.3136
|
||||
75.3531 20.8597
|
||||
75.0889 20.4055
|
||||
74.8447 19.9595
|
||||
74.6086 19.5182
|
||||
74.3753 19.2108
|
||||
74.2094 19.2741
|
||||
74.2308 19.6369
|
||||
74.4005 20.0735
|
||||
74.6114 20.4977
|
||||
74.8285 20.9042
|
||||
75.0857 21.2358
|
||||
75.4002 21.352
|
||||
75.6575 21.213
|
||||
75.6857 20.913
|
||||
75.4609 20.586
|
||||
75.1703 20.2776
|
||||
74.8996 19.9696
|
||||
74.6394 19.6619
|
||||
74.3883 19.4396
|
||||
74.2385 19.4568
|
||||
74.3091 19.7078
|
||||
74.5132 20.0318
|
||||
74.7219 20.3499
|
||||
74.9003 20.6694
|
||||
75.0631 20.9991
|
||||
75.2406 21.2482
|
||||
75.4245 21.2329
|
||||
75.51 20.9884
|
||||
75.4574 20.6705
|
||||
75.3186 20.3613
|
||||
75.1492 20.0742
|
||||
74.9546 19.8206
|
||||
74.7167 19.6535
|
||||
74.463 19.6256
|
||||
74.2914 19.743
|
||||
74.3352 19.9538
|
||||
74.5347 20.1763
|
1204
data/trace_adamax.eps
Normal file
BIN
data/trace_adamax.png
Normal file
After Width: | Height: | Size: 42 KiB |
120
data/trace_adamax.txt
Normal file
@ -0,0 +1,120 @@
|
||||
25 20
|
||||
75 70
|
||||
52.4066 89.5808
|
||||
49.4441 80.1188
|
||||
58.2611 61.5919
|
||||
65.4849 44.7609
|
||||
64.1839 31.2917
|
||||
64.034 27.9776
|
||||
69.3821 26.3762
|
||||
80.4096 17.0992
|
||||
83.1103 17.2225
|
||||
77.388 24.5788
|
||||
72.3171 30.4105
|
||||
70.3201 30.1327
|
||||
71.3279 25.4333
|
||||
74.3233 18.3546
|
||||
76.9477 12.1999
|
||||
76.2211 13.0431
|
||||
73.5507 18.1171
|
||||
71.2072 22.674
|
||||
70.3186 25.5195
|
||||
72.0322 25.0801
|
||||
74.9455 22.6984
|
||||
77.5819 20.4578
|
||||
79.6884 18.6104
|
||||
79.6809 18.7226
|
||||
77.8487 20.498
|
||||
75.8952 22.2778
|
||||
74.14 23.8346
|
||||
72.8235 24.7088
|
||||
72.4955 24.1424
|
||||
73.0477 22.4309
|
||||
73.8359 20.5245
|
||||
74.5625 18.7988
|
||||
75.2109 17.2737
|
||||
75.6527 16.2401
|
||||
75.6318 16.2041
|
||||
75.1936 17.0748
|
||||
74.6375 18.2519
|
||||
74.1192 19.3822
|
||||
73.6564 20.4032
|
||||
73.2664 21.3043
|
||||
73.0291 21.9998
|
||||
73.0481 22.3554
|
||||
73.3427 22.3287
|
||||
73.8138 22.0311
|
||||
74.3342 21.6255
|
||||
74.8287 21.2179
|
||||
75.2765 20.8451
|
||||
75.6795 20.5093
|
||||
76.0386 20.2087
|
||||
76.3413 19.9536
|
||||
76.5607 19.7699
|
||||
76.668 19.6871
|
||||
76.6518 19.7178
|
||||
76.5293 19.8452
|
||||
76.3379 20.032
|
||||
76.1162 20.241
|
||||
75.8919 20.447
|
||||
75.68 20.6382
|
||||
75.4859 20.8113
|
||||
75.3104 20.967
|
||||
75.1521 21.1066
|
||||
75.0096 21.2308
|
||||
74.8823 21.3386
|
||||
74.7708 21.4276
|
||||
74.6763 21.4945
|
||||
74.6004 21.5358
|
||||
74.5443 21.5493
|
||||
74.5085 21.5346
|
||||
74.4917 21.4937
|
||||
74.4917 21.4307
|
||||
74.505 21.351
|
||||
74.5281 21.2606
|
||||
74.5575 21.165
|
||||
74.5902 21.0687
|
||||
74.6238 20.9752
|
||||
74.6567 20.8865
|
||||
74.6881 20.8042
|
||||
74.7171 20.7287
|
||||
74.7436 20.6601
|
||||
74.7676 20.5982
|
||||
74.7891 20.5426
|
||||
74.8083 20.4927
|
||||
74.8254 20.448
|
||||
74.8407 20.4079
|
||||
74.8542 20.372
|
||||
74.8664 20.3397
|
||||
74.8772 20.3106
|
||||
74.887 20.2844
|
||||
74.8957 20.2607
|
||||
74.9036 20.2391
|
||||
74.9108 20.2195
|
||||
74.9173 20.2017
|
||||
74.9233 20.1854
|
||||
74.9288 20.1704
|
||||
74.9338 20.1566
|
||||
74.9384 20.1439
|
||||
74.9427 20.1322
|
||||
74.9467 20.1214
|
||||
74.9504 20.1114
|
||||
74.9539 20.102
|
||||
74.9571 20.0933
|
||||
74.9601 20.0852
|
||||
74.963 20.0777
|
||||
74.9657 20.0706
|
||||
74.9682 20.0641
|
||||
74.9705 20.0579
|
||||
74.9728 20.0522
|
||||
74.9749 20.0468
|
||||
74.9768 20.0417
|
||||
74.9787 20.037
|
||||
74.9804 20.0326
|
||||
74.9821 20.0285
|
||||
74.9836 20.0247
|
||||
74.9851 20.0212
|
||||
74.9864 20.0178
|
||||
74.9877 20.0148
|
||||
74.9888 20.0119
|
||||
74.9899 20.0093
|
1583
data/trace_momentum.eps
Normal file
BIN
data/trace_momentum.png
Normal file
After Width: | Height: | Size: 145 KiB |
1975
data/trace_momentum.txt
Normal file
1186
data/trace_nadam.eps
Normal file
BIN
data/trace_nadam.png
Normal file
After Width: | Height: | Size: 39 KiB |
124
data/trace_nadam.txt
Normal file
@ -0,0 +1,124 @@
|
||||
25 20
|
||||
32.5595 27.5595
|
||||
42.0721 36.9959
|
||||
50.0029 42.3389
|
||||
53.7483 43.8042
|
||||
55.6612 44.4976
|
||||
56.1215 44.6281
|
||||
55.8143 44.3795
|
||||
55.3315 43.9639
|
||||
54.8672 43.5509
|
||||
54.4728 43.2138
|
||||
54.1629 42.9631
|
||||
53.9334 42.7859
|
||||
53.7703 42.6639
|
||||
53.6569 42.5808
|
||||
53.5785 42.5241
|
||||
53.5237 42.4847
|
||||
53.4848 42.4568
|
||||
53.4566 42.4366
|
||||
53.4357 42.4217
|
||||
53.4202 42.4106
|
||||
53.4084 42.4022
|
||||
53.3993 42.3957
|
||||
53.3924 42.3908
|
||||
53.387 42.3869
|
||||
53.3828 42.384
|
||||
53.3795 42.3816
|
||||
53.377 42.3798
|
||||
53.3749 42.3783
|
||||
53.3733 42.3772
|
||||
53.3721 42.3763
|
||||
53.3711 42.3756
|
||||
53.3703 42.375
|
||||
53.3696 42.3745
|
||||
53.3691 42.3742
|
||||
53.3687 42.3739
|
||||
53.3684 42.3737
|
||||
53.3681 42.3735
|
||||
53.3679 42.3733
|
||||
53.3678 42.3732
|
||||
53.3676 42.3731
|
||||
53.3675 42.373
|
||||
53.3674 42.373
|
||||
53.3674 42.3729
|
||||
53.3673 42.3729
|
||||
53.3673 42.3728
|
||||
53.3672 42.3728
|
||||
53.3672 42.3728
|
||||
53.3672 42.3728
|
||||
53.3672 42.3727
|
||||
53.367 42.3729
|
||||
53.3675 42.3724
|
||||
53.3658 42.3741
|
||||
53.3723 42.3675
|
||||
53.3473 42.3928
|
||||
53.4465 42.2925
|
||||
53.0386 42.7054
|
||||
54.7703 40.959
|
||||
47.1163 48.78
|
||||
67.8731 28.2337
|
||||
88.9984 12.8653
|
||||
70.4088 30.016
|
||||
79.8802 17.9192
|
||||
70.8517 25.6199
|
||||
78.7836 17.1841
|
||||
72.2691 23.6545
|
||||
75.4284 20.2173
|
||||
75.4219 20.2216
|
||||
75.4156 20.2258
|
||||
75.4096 20.2298
|
||||
75.4037 20.2337
|
||||
75.398 20.2376
|
||||
75.3925 20.2413
|
||||
75.3871 20.2449
|
||||
75.382 20.2484
|
||||
75.377 20.2517
|
||||
75.3721 20.255
|
||||
75.3675 20.2582
|
||||
75.3629 20.2612
|
||||
75.3586 20.2641
|
||||
75.3543 20.267
|
||||
75.3502 20.2697
|
||||
75.3462 20.2724
|
||||
75.3424 20.275
|
||||
75.3387 20.2774
|
||||
75.335 20.2798
|
||||
75.3315 20.2821
|
||||
75.3282 20.2844
|
||||
75.3249 20.2865
|
||||
75.3217 20.2886
|
||||
75.3186 20.2906
|
||||
75.3156 20.2926
|
||||
75.3128 20.2944
|
||||
75.31 20.2962
|
||||
75.3073 20.298
|
||||
75.3047 20.2996
|
||||
75.3021 20.3013
|
||||
75.2997 20.3028
|
||||
75.2973 20.3043
|
||||
75.2951 20.3058
|
||||
75.2929 20.3072
|
||||
75.2907 20.3085
|
||||
75.2887 20.3098
|
||||
75.2867 20.3111
|
||||
75.2848 20.3122
|
||||
75.2829 20.3134
|
||||
75.2812 20.3145
|
||||
75.2794 20.3156
|
||||
75.2778 20.3166
|
||||
75.2762 20.3176
|
||||
75.2747 20.3185
|
||||
75.2732 20.3194
|
||||
75.2718 20.3203
|
||||
75.2705 20.3211
|
||||
75.2692 20.3219
|
||||
75.268 20.3226
|
||||
75.2668 20.3234
|
||||
75.2657 20.324
|
||||
75.2646 20.3247
|
||||
75.2636 20.3253
|
||||
75.2626 20.3259
|
||||
75.2617 20.3265
|
||||
75.2608 20.327
|
||||
75.2599 20.3275
|
1485
data/trace_nag.eps
Normal file
BIN
data/trace_nag.png
Normal file
After Width: | Height: | Size: 96 KiB |
1038
data/trace_nag.txt
Normal file
1249
data/trace_rmsprop.eps
Normal file
BIN
data/trace_rmsprop.png
Normal file
After Width: | Height: | Size: 42 KiB |
293
data/trace_rmsprop.txt
Normal file
@ -0,0 +1,293 @@
|
||||
25 20
|
||||
47.3607 42.3607
|
||||
64.3811 37.1944
|
||||
44.1966 54.9809
|
||||
55.2149 40.7588
|
||||
52.472 44.1797
|
||||
54.2387 41.0592
|
||||
52.7227 43.5739
|
||||
53.9651 41.4312
|
||||
52.8768 43.253
|
||||
53.84 41.6117
|
||||
52.9345 43.1331
|
||||
53.8158 41.6453
|
||||
52.9113 43.1649
|
||||
53.8835 41.5386
|
||||
52.7876 43.3737
|
||||
54.0925 41.2193
|
||||
52.4661 43.9197
|
||||
54.6433 40.4269
|
||||
51.562 45.4018
|
||||
56.4953 38.2008
|
||||
47.9994 50.0915
|
||||
61.5949 35.5684
|
||||
48.2771 49.6706
|
||||
58.0885 39.524
|
||||
51.5696 44.6053
|
||||
54.0015 41.8871
|
||||
53.3 42.52
|
||||
53.4078 42.3568
|
||||
53.3665 42.3832
|
||||
53.3708 42.372
|
||||
53.3671 42.3737
|
||||
53.3675 42.3726
|
||||
53.3671 42.3728
|
||||
53.3671 42.3727
|
||||
53.3671 42.3727
|
||||
53.3671 42.3727
|
||||
53.3671 42.3727
|
||||
53.3671 42.3727
|
||||
53.3671 42.3727
|
||||
53.3671 42.3727
|
||||
53.3671 42.3727
|
||||
53.3671 42.3727
|
||||
53.3671 42.3727
|
||||
53.3671 42.3727
|
||||
53.3671 42.3727
|
||||
53.3671 42.3727
|
||||
53.3671 42.3727
|
||||
53.3671 42.3727
|
||||
53.3671 42.3727
|
||||
53.3671 42.3727
|
||||
53.3671 42.3727
|
||||
53.3671 42.3727
|
||||
53.3671 42.3727
|
||||
53.3671 42.3727
|
||||
53.3671 42.3727
|
||||
53.3671 42.3727
|
||||
53.3671 42.3727
|
||||
53.3671 42.3727
|
||||
53.3671 42.3727
|
||||
53.3671 42.3727
|
||||
53.367 42.3728
|
||||
53.3672 42.3726
|
||||
53.3669 42.3729
|
||||
53.3674 42.3724
|
||||
53.3664 42.3734
|
||||
53.3686 42.3711
|
||||
53.3636 42.3765
|
||||
53.3756 42.3635
|
||||
53.3458 42.3957
|
||||
53.4223 42.3131
|
||||
53.2186 42.5333
|
||||
53.7806 41.9271
|
||||
52.1677 43.676
|
||||
57.0134 38.5292
|
||||
42.2502 54.225
|
||||
57.0924 38.9732
|
||||
47.0601 48.5418
|
||||
61.4578 34.9508
|
||||
47.9543 49.7698
|
||||
58.5967 39.585
|
||||
51.0987 44.7044
|
||||
54.1173 41.7503
|
||||
53.2433 42.5363
|
||||
53.4096 42.347
|
||||
53.362 42.3832
|
||||
53.3706 42.3711
|
||||
53.3667 42.3737
|
||||
53.3675 42.3725
|
||||
53.367 42.3728
|
||||
53.3671 42.3727
|
||||
53.3671 42.3727
|
||||
53.3671 42.3727
|
||||
53.3671 42.3727
|
||||
53.3671 42.3727
|
||||
53.3671 42.3727
|
||||
53.3671 42.3727
|
||||
53.3671 42.3727
|
||||
53.3671 42.3727
|
||||
53.3671 42.3727
|
||||
53.3671 42.3727
|
||||
53.3671 42.3727
|
||||
53.3671 42.3727
|
||||
53.3671 42.3727
|
||||
53.3671 42.3727
|
||||
53.3671 42.3727
|
||||
53.3671 42.3727
|
||||
53.3671 42.3727
|
||||
53.3671 42.3727
|
||||
53.3671 42.3727
|
||||
53.3671 42.3727
|
||||
53.3671 42.3727
|
||||
53.3671 42.3727
|
||||
53.3671 42.3727
|
||||
53.3671 42.3727
|
||||
53.3671 42.3727
|
||||
53.3671 42.3727
|
||||
53.367 42.3728
|
||||
53.3672 42.3726
|
||||
53.3669 42.3729
|
||||
53.3675 42.3724
|
||||
53.3663 42.3735
|
||||
53.3689 42.371
|
||||
53.3629 42.3768
|
||||
53.3771 42.3629
|
||||
53.3421 42.3972
|
||||
53.4318 42.3095
|
||||
53.1934 42.5427
|
||||
53.8495 41.9018
|
||||
51.9681 43.7484
|
||||
57.588 38.2952
|
||||
41.3952 54.6574
|
||||
53.4034 41.369
|
||||
52.6462 43.2953
|
||||
54.6114 41.2513
|
||||
51.4107 44.1017
|
||||
56.5254 39.6417
|
||||
47.5145 47.772
|
||||
62.2085 34.1123
|
||||
48.7595 49.8925
|
||||
59.9043 39.419
|
||||
49.5499 45.659
|
||||
55.0408 40.7082
|
||||
52.7934 42.9632
|
||||
53.5264 42.2319
|
||||
53.3271 42.4169
|
||||
53.3816 42.3601
|
||||
53.3626 42.3776
|
||||
53.369 42.371
|
||||
53.3663 42.3735
|
||||
53.3674 42.3724
|
||||
53.3669 42.3729
|
||||
53.3672 42.3726
|
||||
53.367 42.3728
|
||||
53.3671 42.3727
|
||||
53.3671 42.3727
|
||||
53.3671 42.3727
|
||||
53.3671 42.3727
|
||||
53.3671 42.3727
|
||||
53.3671 42.3727
|
||||
53.3671 42.3727
|
||||
53.3671 42.3727
|
||||
53.3671 42.3727
|
||||
53.3671 42.3727
|
||||
53.3671 42.3727
|
||||
53.3671 42.3727
|
||||
53.3671 42.3727
|
||||
53.3671 42.3727
|
||||
53.3671 42.3727
|
||||
53.3671 42.3727
|
||||
53.3671 42.3727
|
||||
53.367 42.3728
|
||||
53.3672 42.3726
|
||||
53.367 42.3728
|
||||
53.3673 42.3725
|
||||
53.3667 42.3731
|
||||
53.3677 42.3721
|
||||
53.3658 42.374
|
||||
53.3696 42.3703
|
||||
53.3619 42.3778
|
||||
53.3783 42.3618
|
||||
53.3422 42.3971
|
||||
53.4248 42.3164
|
||||
53.2286 42.5082
|
||||
53.7115 42.0366
|
||||
52.4741 43.2489
|
||||
55.7701 40.0449
|
||||
46.2665 49.5197
|
||||
65.0259 31.1729
|
||||
74.1984 38.6656
|
||||
65.4135 23.9187
|
||||
86.1499 13.4603
|
||||
68.5796 30.5295
|
||||
74.7087 24.3691
|
||||
75.3856 22.7927
|
||||
75.4449 22.5206
|
||||
75.4751 22.337
|
||||
75.492 22.1991
|
||||
75.5017 22.0893
|
||||
75.5069 21.9983
|
||||
75.5093 21.921
|
||||
75.5097 21.8538
|
||||
75.5086 21.7947
|
||||
75.5066 21.7419
|
||||
75.5039 21.6943
|
||||
75.5006 21.6511
|
||||
75.4969 21.6115
|
||||
75.4929 21.575
|
||||
75.4886 21.5412
|
||||
75.4842 21.5098
|
||||
75.4796 21.4805
|
||||
75.475 21.453
|
||||
75.4703 21.4271
|
||||
75.4655 21.4027
|
||||
75.4607 21.3797
|
||||
75.456 21.3578
|
||||
75.4512 21.3371
|
||||
75.4464 21.3173
|
||||
75.4417 21.2985
|
||||
75.437 21.2805
|
||||
75.4323 21.2634
|
||||
75.4277 21.2469
|
||||
75.4231 21.2311
|
||||
75.4185 21.216
|
||||
75.414 21.2015
|
||||
75.4096 21.1875
|
||||
75.4052 21.174
|
||||
75.4008 21.161
|
||||
75.3966 21.1485
|
||||
75.3923 21.1365
|
||||
75.3882 21.1248
|
||||
75.3841 21.1135
|
||||
75.38 21.1027
|
||||
75.376 21.0921
|
||||
75.3721 21.0819
|
||||
75.3682 21.0721
|
||||
75.3644 21.0625
|
||||
75.3607 21.0533
|
||||
75.357 21.0443
|
||||
75.3534 21.0356
|
||||
75.3498 21.0271
|
||||
75.3463 21.0189
|
||||
75.3429 21.011
|
||||
75.3395 21.0033
|
||||
75.3362 20.9958
|
||||
75.3329 20.9885
|
||||
75.3297 20.9815
|
||||
75.3265 20.9746
|
||||
75.3235 20.9679
|
||||
75.3204 20.9615
|
||||
75.3175 20.9552
|
||||
75.3146 20.9491
|
||||
75.3117 20.9431
|
||||
75.3089 20.9374
|
||||
75.3062 20.9318
|
||||
75.3035 20.9263
|
||||
75.3009 20.921
|
||||
75.2983 20.9159
|
||||
75.2958 20.9109
|
||||
75.2933 20.9061
|
||||
75.2909 20.9014
|
||||
75.2886 20.8968
|
||||
75.2863 20.8924
|
||||
75.284 20.8881
|
||||
75.2819 20.8839
|
||||
75.2797 20.8798
|
||||
75.2777 20.8759
|
||||
75.2756 20.8721
|
||||
75.2737 20.8684
|
||||
75.2718 20.8649
|
||||
75.2699 20.8614
|
||||
75.2681 20.8581
|
||||
75.2663 20.8549
|
||||
75.2646 20.8517
|
||||
75.263 20.8487
|
||||
75.2614 20.8458
|
||||
75.2598 20.843
|
||||
75.2583 20.8403
|
||||
75.2569 20.8377
|
||||
75.2554 20.8351
|
||||
75.2541 20.8327
|
||||
75.2528 20.8304
|
||||
75.2515 20.8281
|
||||
75.2503 20.826
|
||||
75.2491 20.8239
|
||||
75.248 20.8219
|
||||
75.2469 20.82
|
||||
75.2459 20.8182
|
||||
75.2449 20.8164
|
||||
75.2439 20.8148
|
||||
75.243 20.8132
|
||||
75.2421 20.8117
|
||||
75.2413 20.8102
|
287
src/lib/sgd.cpp
@ -1,6 +1,8 @@
|
||||
#include "sgd.h"
|
||||
#include "cmath"
|
||||
|
||||
#include "iostream"
|
||||
|
||||
/**
|
||||
* @brief return absolute value
|
||||
*
|
||||
@ -25,7 +27,7 @@ enum sgd_return_e
|
||||
{
|
||||
SGD_SUCCESS = 0,
|
||||
SGD_CONVERGENCE = 1,
|
||||
SGD_STOP, //1
|
||||
SGD_STOP, //2
|
||||
SGD_UNKNOWN_ERROR = -1024,
|
||||
// The variable size is negative
|
||||
SGD_INVILAD_VARIABLE_SIZE, //-1023
|
||||
@ -35,6 +37,8 @@ enum sgd_return_e
|
||||
SGD_INVILAD_EPSILON, //-1021
|
||||
// Iteration reached max limit
|
||||
SGD_REACHED_MAX_ITERATIONS,
|
||||
// Invalid value for mu
|
||||
SGD_INVALID_MU,
|
||||
// Invalid value for alpha
|
||||
SGD_INVALID_ALPHA,
|
||||
// Invalid value for beta
|
||||
@ -48,7 +52,7 @@ enum sgd_return_e
|
||||
/**
|
||||
* Default parameter for the SGD methods.
|
||||
*/
|
||||
static const sgd_para defparam = {100, 1e-6, 0.001, 0.9, 0.999, 1e-8};
|
||||
static const sgd_para defparam = {100, 1e-6, 0.01, 0.001, 0.9, 0.999, 1e-8};
|
||||
|
||||
sgd_float *sgd_malloc(const int n_size)
|
||||
{
|
||||
@ -91,6 +95,8 @@ const char* sgd_error_str(int er_index)
|
||||
return "Invalid value for epsilon.";
|
||||
case SGD_INVALID_BETA:
|
||||
return "Invalid value for beta.";
|
||||
case SGD_INVALID_MU:
|
||||
return "Invalid value for mu.";
|
||||
case SGD_INVALID_ALPHA:
|
||||
return "Invalid value for alpha.";
|
||||
case SGD_INVALID_SIGMA:
|
||||
@ -121,8 +127,18 @@ const char* sgd_error_str(int er_index)
|
||||
typedef int (*sgd_solver_ptr)(sgd_evaulate_ptr Evafp, sgd_progress_ptr Profp, sgd_float *fx, sgd_float *m,
|
||||
const int n_size, const int m_size, const sgd_para *param, void *instance);
|
||||
|
||||
int momentum(sgd_evaulate_ptr Evafp, sgd_progress_ptr Profp, sgd_float *fx, sgd_float *x,
|
||||
const int n_size, const int m_size, const sgd_para *param, void *instance);
|
||||
int nag(sgd_evaulate_ptr Evafp, sgd_progress_ptr Profp, sgd_float *fx, sgd_float *x,
|
||||
const int n_size, const int m_size, const sgd_para *param, void *instance);
|
||||
int adagrad(sgd_evaulate_ptr Evafp, sgd_progress_ptr Profp, sgd_float *fx, sgd_float *x,
|
||||
const int n_size, const int m_size, const sgd_para *param, void *instance);
|
||||
int rmsprop(sgd_evaulate_ptr Evafp, sgd_progress_ptr Profp, sgd_float *fx, sgd_float *x,
|
||||
const int n_size, const int m_size, const sgd_para *param, void *instance);
|
||||
int adam(sgd_evaulate_ptr Evafp, sgd_progress_ptr Profp, sgd_float *fx, sgd_float *x,
|
||||
const int n_size, const int m_size, const sgd_para *param, void *instance);
|
||||
int nadam(sgd_evaulate_ptr Evafp, sgd_progress_ptr Profp, sgd_float *fx, sgd_float *x,
|
||||
const int n_size, const int m_size, const sgd_para *param, void *instance);
|
||||
int adamax(sgd_evaulate_ptr Evafp, sgd_progress_ptr Profp, sgd_float *fx, sgd_float *x,
|
||||
const int n_size, const int m_size, const sgd_para *param, void *instance);
|
||||
int adabelief(sgd_evaulate_ptr Evafp, sgd_progress_ptr Profp, sgd_float *fx, sgd_float *x,
|
||||
@ -135,9 +151,24 @@ int sgd_solver(sgd_evaulate_ptr Evafp, sgd_progress_ptr Profp, sgd_float *fx, sg
|
||||
sgd_solver_ptr solver;
|
||||
switch (solver_id)
|
||||
{
|
||||
case SGD_MOMENTUM:
|
||||
solver = momentum;
|
||||
break;
|
||||
case SGD_NAG:
|
||||
solver = nag;
|
||||
break;
|
||||
case SGD_ADAGRAD:
|
||||
solver = adagrad;
|
||||
break;
|
||||
case SGD_RMSPROP:
|
||||
solver = rmsprop;
|
||||
break;
|
||||
case SGD_ADAM:
|
||||
solver = adam;
|
||||
break;
|
||||
case SGD_NADAM:
|
||||
solver = nadam;
|
||||
break;
|
||||
case SGD_ADAMAX:
|
||||
solver = adamax;
|
||||
break;
|
||||
@ -152,6 +183,188 @@ int sgd_solver(sgd_evaulate_ptr Evafp, sgd_progress_ptr Profp, sgd_float *fx, sg
|
||||
return solver(Evafp, Profp, fx, m, n_size, m_size, param, instance);
|
||||
}
|
||||
|
||||
int momentum(sgd_evaulate_ptr Evafp, sgd_progress_ptr Profp, sgd_float *fx, sgd_float *x,
|
||||
const int n_size, const int m_size, const sgd_para *param, void *instance)
|
||||
{
|
||||
// set the Adam's parameters
|
||||
sgd_para para = (param != nullptr) ? (*param) : defparam;
|
||||
|
||||
//check parameters
|
||||
if (n_size <= 0) return SGD_INVILAD_VARIABLE_SIZE;
|
||||
if (para.iteration <= 0) return SGD_INVILAD_MAX_ITERATIONS;
|
||||
if (para.epsilon < 0) return SGD_INVILAD_EPSILON;
|
||||
if (para.mu < 0 || para.mu >= 1.0) return SGD_INVALID_MU;
|
||||
|
||||
sgd_float *mk = sgd_malloc(n_size);
|
||||
sgd_float *g = sgd_malloc(n_size);
|
||||
|
||||
for (int i = 0; i < n_size; i++)
|
||||
{
|
||||
mk[i] = 0.0;
|
||||
}
|
||||
|
||||
for (int t = 0; t < para.iteration; t++)
|
||||
{
|
||||
for (int m = 0; m < m_size; m++)
|
||||
{
|
||||
*fx = Evafp(instance, x, g, n_size, m);
|
||||
|
||||
for (int i = 0; i < n_size; i++)
|
||||
{
|
||||
mk[i] = para.mu*mk[i] + g[i];
|
||||
|
||||
x[i] = x[i] - para.alpha * mk[i];
|
||||
if (x[i] != x[i]) return SGD_NAN_VALUE;
|
||||
}
|
||||
}
|
||||
|
||||
if (Profp(instance, *fx, x, g, param, n_size, t)) return SGD_STOP;
|
||||
if (*fx < para.epsilon) return SGD_CONVERGENCE;
|
||||
}
|
||||
|
||||
sgd_free(mk);
|
||||
sgd_free(g);
|
||||
return SGD_REACHED_MAX_ITERATIONS;
|
||||
}
|
||||
|
||||
int nag(sgd_evaulate_ptr Evafp, sgd_progress_ptr Profp, sgd_float *fx, sgd_float *x,
|
||||
const int n_size, const int m_size, const sgd_para *param, void *instance)
|
||||
{
|
||||
// set the Adam's parameters
|
||||
sgd_para para = (param != nullptr) ? (*param) : defparam;
|
||||
|
||||
//check parameters
|
||||
if (n_size <= 0) return SGD_INVILAD_VARIABLE_SIZE;
|
||||
if (para.iteration <= 0) return SGD_INVILAD_MAX_ITERATIONS;
|
||||
if (para.epsilon < 0) return SGD_INVILAD_EPSILON;
|
||||
if (para.mu < 0 || para.mu >= 1.0) return SGD_INVALID_MU;
|
||||
|
||||
sgd_float *mk = sgd_malloc(n_size);
|
||||
sgd_float *xk = sgd_malloc(n_size);
|
||||
sgd_float *g = sgd_malloc(n_size);
|
||||
|
||||
for (int i = 0; i < n_size; i++)
|
||||
{
|
||||
mk[i] = 0.0;
|
||||
}
|
||||
|
||||
for (int t = 0; t < para.iteration; t++)
|
||||
{
|
||||
for (int m = 0; m < m_size; m++)
|
||||
{
|
||||
for (int i = 0; i < n_size; i++)
|
||||
{
|
||||
xk[i] = x[i] - para.mu*para.alpha*mk[i];
|
||||
}
|
||||
|
||||
*fx = Evafp(instance, xk, g, n_size, m);
|
||||
|
||||
for (int i = 0; i < n_size; i++)
|
||||
{
|
||||
mk[i] = para.mu*mk[i] + g[i];
|
||||
|
||||
x[i] = x[i] - para.alpha * mk[i];
|
||||
if (x[i] != x[i]) return SGD_NAN_VALUE;
|
||||
}
|
||||
}
|
||||
|
||||
if (Profp(instance, *fx, x, g, param, n_size, t)) return SGD_STOP;
|
||||
if (*fx < para.epsilon) return SGD_CONVERGENCE;
|
||||
}
|
||||
|
||||
sgd_free(mk);
|
||||
sgd_free(xk);
|
||||
sgd_free(g);
|
||||
return SGD_REACHED_MAX_ITERATIONS;
|
||||
}
|
||||
|
||||
int adagrad(sgd_evaulate_ptr Evafp, sgd_progress_ptr Profp, sgd_float *fx, sgd_float *x,
|
||||
const int n_size, const int m_size, const sgd_para *param, void *instance)
|
||||
{
|
||||
// set the Adam's parameters
|
||||
sgd_para para = (param != nullptr) ? (*param) : defparam;
|
||||
|
||||
//check parameters
|
||||
if (n_size <= 0) return SGD_INVILAD_VARIABLE_SIZE;
|
||||
if (para.iteration <= 0) return SGD_INVILAD_MAX_ITERATIONS;
|
||||
if (para.epsilon < 0.0) return SGD_INVILAD_EPSILON;
|
||||
if (para.sigma < 0.0) return SGD_INVALID_SIGMA;
|
||||
|
||||
sgd_float *mk = sgd_malloc(n_size);
|
||||
sgd_float *g = sgd_malloc(n_size);
|
||||
|
||||
for (int i = 0; i < n_size; i++)
|
||||
{
|
||||
mk[i] = 0.0;
|
||||
}
|
||||
|
||||
for (int t = 0; t < para.iteration; t++)
|
||||
{
|
||||
for (int m = 0; m < m_size; m++)
|
||||
{
|
||||
*fx = Evafp(instance, x, g, n_size, m);
|
||||
|
||||
for (int i = 0; i < n_size; i++)
|
||||
{
|
||||
mk[i] = mk[i] + g[i]*g[i];
|
||||
|
||||
x[i] = x[i] - para.alpha * g[i]/(sqrt(mk[i]) + para.sigma);
|
||||
if (x[i] != x[i]) return SGD_NAN_VALUE;
|
||||
}
|
||||
}
|
||||
|
||||
if (Profp(instance, *fx, x, g, param, n_size, t)) return SGD_STOP;
|
||||
if (*fx < para.epsilon) return SGD_CONVERGENCE;
|
||||
}
|
||||
|
||||
sgd_free(mk);
|
||||
sgd_free(g);
|
||||
return SGD_REACHED_MAX_ITERATIONS;
|
||||
}
|
||||
|
||||
int rmsprop(sgd_evaulate_ptr Evafp, sgd_progress_ptr Profp, sgd_float *fx, sgd_float *x,
|
||||
const int n_size, const int m_size, const sgd_para *param, void *instance)
|
||||
{
|
||||
// set the Adam's parameters
|
||||
sgd_para para = (param != nullptr) ? (*param) : defparam;
|
||||
|
||||
//check parameters
|
||||
if (n_size <= 0) return SGD_INVILAD_VARIABLE_SIZE;
|
||||
if (para.iteration <= 0) return SGD_INVILAD_MAX_ITERATIONS;
|
||||
if (para.epsilon < 0.0) return SGD_INVILAD_EPSILON;
|
||||
if (para.sigma < 0.0) return SGD_INVALID_SIGMA;
|
||||
|
||||
sgd_float *vk = sgd_malloc(n_size);
|
||||
sgd_float *g = sgd_malloc(n_size);
|
||||
|
||||
for (int i = 0; i < n_size; i++)
|
||||
{
|
||||
vk[i] = 0.0;
|
||||
}
|
||||
|
||||
for (int t = 0; t < para.iteration; t++)
|
||||
{
|
||||
for (int m = 0; m < m_size; m++)
|
||||
{
|
||||
*fx = Evafp(instance, x, g, n_size, m);
|
||||
|
||||
for (int i = 0; i < n_size; i++)
|
||||
{
|
||||
vk[i] = para.beta_2 * vk[i] + (1.0 - para.beta_2)*g[i]*g[i];
|
||||
|
||||
x[i] = x[i] - para.alpha * g[i]/(sqrt(vk[i]) + para.sigma);
|
||||
if (x[i] != x[i]) return SGD_NAN_VALUE;
|
||||
}
|
||||
}
|
||||
|
||||
if (Profp(instance, *fx, x, g, param, n_size, t)) return SGD_STOP;
|
||||
if (*fx < para.epsilon) return SGD_CONVERGENCE;
|
||||
}
|
||||
|
||||
sgd_free(vk);
|
||||
sgd_free(g);
|
||||
return SGD_REACHED_MAX_ITERATIONS;
|
||||
}
|
||||
|
||||
int adam(sgd_evaulate_ptr Evafp, sgd_progress_ptr Profp, sgd_float *fx, sgd_float *x,
|
||||
const int n_size, const int m_size, const sgd_para *param, void *instance)
|
||||
@ -180,7 +393,6 @@ int adam(sgd_evaulate_ptr Evafp, sgd_progress_ptr Profp, sgd_float *fx, sgd_floa
|
||||
sgd_float beta_1t = 1.0, beta_2t = 1.0;
|
||||
sgd_float alpha_k;
|
||||
|
||||
int overall_iteration = para.iteration * m_size;
|
||||
for (int t = 0; t < para.iteration; t++)
|
||||
{
|
||||
beta_1t *= para.beta_1;
|
||||
@ -212,6 +424,73 @@ int adam(sgd_evaulate_ptr Evafp, sgd_progress_ptr Profp, sgd_float *fx, sgd_floa
|
||||
return SGD_REACHED_MAX_ITERATIONS;
|
||||
}
|
||||
|
||||
int nadam(sgd_evaulate_ptr Evafp, sgd_progress_ptr Profp, sgd_float *fx, sgd_float *x,
|
||||
const int n_size, const int m_size, const sgd_para *param, void *instance)
|
||||
{
|
||||
// set the Adam's parameters
|
||||
sgd_para para = (param != nullptr) ? (*param) : defparam;
|
||||
|
||||
//check parameters
|
||||
if (n_size <= 0) return SGD_INVILAD_VARIABLE_SIZE;
|
||||
if (para.iteration <= 0) return SGD_INVILAD_MAX_ITERATIONS;
|
||||
if (para.epsilon < 0) return SGD_INVILAD_EPSILON;
|
||||
if (para.alpha < 0) return SGD_INVALID_ALPHA;
|
||||
if (para.beta_1 < 0.0 || para.beta_1 >= 1.0) return SGD_INVALID_BETA;
|
||||
if (para.beta_2 < 0.0 || para.beta_2 >= 1.0) return SGD_INVALID_BETA;
|
||||
if (para.sigma < 0.0) return SGD_INVALID_SIGMA;
|
||||
|
||||
sgd_float *mk = sgd_malloc(n_size);
|
||||
sgd_float *mk_hat = sgd_malloc(n_size);
|
||||
sgd_float *nk = sgd_malloc(n_size);
|
||||
sgd_float *nk_hat = sgd_malloc(n_size);
|
||||
sgd_float *g = sgd_malloc(n_size);
|
||||
sgd_float *g_hat = sgd_malloc(n_size);
|
||||
|
||||
for (int i = 0; i < n_size; i++)
|
||||
{
|
||||
mk[i] = nk[i] = 0.0;
|
||||
}
|
||||
|
||||
sgd_float beta_1t = 1.0, beta_1t1 = para.beta_1, beta_2t = 1.0;
|
||||
|
||||
for (int t = 0; t < para.iteration; t++)
|
||||
{
|
||||
beta_1t *= para.beta_1;
|
||||
beta_1t1 *= para.beta_1;
|
||||
beta_2t *= para.beta_2;
|
||||
|
||||
for (int m = 0; m < m_size; m++)
|
||||
{
|
||||
*fx = Evafp(instance, x, g, n_size, m);
|
||||
|
||||
for (int i = 0; i < n_size; i++)
|
||||
{
|
||||
g_hat[i] = g[i]/(1.0 - beta_1t);
|
||||
mk[i] = para.beta_1*mk[i] + (1.0 - para.beta_1)*g[i];
|
||||
nk[i] = para.beta_2*nk[i] + (1.0 - para.beta_2)*g[i]*g[i];
|
||||
|
||||
mk_hat[i] = mk[i]/(1.0 - beta_1t1);
|
||||
nk_hat[i] = nk[i]/(1.0 - beta_2t);
|
||||
|
||||
x[i] = x[i] - para.alpha * ((1.0 - beta_1t)*g_hat[i]
|
||||
+ beta_1t1*mk_hat[i])/(sqrt(nk_hat[i]) + para.sigma);
|
||||
if (x[i] != x[i]) return SGD_NAN_VALUE;
|
||||
}
|
||||
}
|
||||
|
||||
if (Profp(instance, *fx, x, g, param, n_size, t)) return SGD_STOP;
|
||||
if (*fx < para.epsilon) return SGD_CONVERGENCE;
|
||||
}
|
||||
|
||||
sgd_free(mk);
|
||||
sgd_free(mk_hat);
|
||||
sgd_free(nk);
|
||||
sgd_free(nk_hat);
|
||||
sgd_free(g);
|
||||
sgd_free(g_hat);
|
||||
return SGD_REACHED_MAX_ITERATIONS;
|
||||
}
|
||||
|
||||
int adamax(sgd_evaulate_ptr Evafp, sgd_progress_ptr Profp, sgd_float *fx, sgd_float *x,
|
||||
const int n_size, const int m_size, const sgd_para *param, void *instance)
|
||||
{
|
||||
@ -239,7 +518,6 @@ int adamax(sgd_evaulate_ptr Evafp, sgd_progress_ptr Profp, sgd_float *fx, sgd_fl
|
||||
|
||||
sgd_float beta_1t = 1.0;
|
||||
|
||||
int overall_iteration = para.iteration * m_size;
|
||||
for (int t = 0; t < para.iteration; t++)
|
||||
{
|
||||
beta_1t *= para.beta_1;
|
||||
@ -295,7 +573,6 @@ int adabelief(sgd_evaulate_ptr Evafp, sgd_progress_ptr Profp, sgd_float *fx, sgd
|
||||
sgd_float beta_1t = 1.0, beta_2t = 1.0;
|
||||
sgd_float alpha_k;
|
||||
|
||||
int overall_iteration = para.iteration * m_size;
|
||||
for (int t = 0; t < para.iteration; t++)
|
||||
{
|
||||
beta_1t *= para.beta_1;
|
||||
|
@ -44,11 +44,36 @@ typedef double sgd_float;
|
||||
*/
|
||||
typedef enum
|
||||
{
|
||||
/**
|
||||
* Classic momentum.
|
||||
*/
|
||||
SGD_MOMENTUM,
|
||||
|
||||
/**
|
||||
* Nesterov’s accelerated gradient (NAG)
|
||||
*/
|
||||
SGD_NAG,
|
||||
|
||||
/**
|
||||
* AdaGrad method.
|
||||
*/
|
||||
SGD_ADAGRAD,
|
||||
|
||||
/**
|
||||
* RMSProp method.
|
||||
*/
|
||||
SGD_RMSPROP,
|
||||
|
||||
/**
|
||||
* Adam method.
|
||||
*/
|
||||
SGD_ADAM,
|
||||
|
||||
/**
|
||||
* Nadam method.
|
||||
*/
|
||||
SGD_NADAM,
|
||||
|
||||
/**
|
||||
* AdaMax method.
|
||||
*/
|
||||
@ -77,6 +102,12 @@ typedef struct
|
||||
*/
|
||||
sgd_float epsilon;
|
||||
|
||||
/**
|
||||
* Damping rate of the classic momentum method, which is typically given
|
||||
* between 0 and 1. The default is 0.01.
|
||||
*/
|
||||
sgd_float mu;
|
||||
|
||||
/**
|
||||
* Step size of the iteration. The default value is 0.001 for Adam and 0.002
|
||||
* for AdaMax.
|
||||
|
@ -66,7 +66,8 @@ double evaluate(void *instance, const double *x, double *g, const int n, const i
|
||||
g[1] += -1.0*gaussian_distribution(x[0], x[1], para[i], Dy);
|
||||
}
|
||||
|
||||
fx += 12.82906044;
|
||||
fx += 15.78257991; // 取5组参数 非凸的情况
|
||||
//fx += 12.82906044; // 取2组参数 凸的情况
|
||||
|
||||
g[0] *= 2.0*fx;
|
||||
g[1] *= 2.0*fx;
|
||||
@ -77,7 +78,11 @@ int progress(void *instance, sgd_float fx, const sgd_float *x, const sgd_float *
|
||||
const sgd_para *param, const int n_size, const int k)
|
||||
{
|
||||
std::clog << "iteration time: " << k << ", fx: " << fx << "\r";
|
||||
if (fx < param->epsilon) std::clog << std::endl;
|
||||
if (fx < param->epsilon)
|
||||
{
|
||||
std::clog << std::endl;
|
||||
std::cout << ">" << std::endl;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
@ -91,28 +96,100 @@ int main(int argc, char *argv[])
|
||||
tmp_p.mu_x = 0.60; tmp_p.mu_y = 0.70; tmp_p.sigma_x = 0.10; tmp_p.sigma_y = 0.20; tmp_p.rho = 0.0;
|
||||
para.push_back(tmp_p);
|
||||
|
||||
tmp_p.mu_x = 0.75; tmp_p.mu_y = 0.20; tmp_p.sigma_x = 0.10; tmp_p.sigma_y = 0.12; tmp_p.rho = 0.5;
|
||||
para.push_back(tmp_p);
|
||||
|
||||
tmp_p.mu_x = 0.10; tmp_p.mu_y = 0.40; tmp_p.sigma_x = 0.60; tmp_p.sigma_y = 0.70; tmp_p.rho = 0.1;
|
||||
para.push_back(tmp_p);
|
||||
|
||||
tmp_p.mu_x = 0.22; tmp_p.mu_y = 0.66; tmp_p.sigma_x = 0.15; tmp_p.sigma_y = 0.12; tmp_p.rho = -0.2;
|
||||
para.push_back(tmp_p);
|
||||
|
||||
sgd_float fx;
|
||||
sgd_float x[2] = {0.25, 0.20};
|
||||
|
||||
sgd_para my_para = sgd_default_parameters();
|
||||
my_para.iteration = 10000;
|
||||
my_para.alpha = 0.01;
|
||||
my_para.iteration = 20000;
|
||||
|
||||
int ret = sgd_solver(evaluate, progress, &fx, &x[0], 2, 1, &my_para, nullptr);
|
||||
std::clog << "Adam return: " << sgd_error_str(ret) << std::endl;
|
||||
std::clog << "fx = " << fx << std::endl;
|
||||
std::clog << "model: " << x[0] << " " << x[1] << std::endl;
|
||||
my_para.mu = 0.02;
|
||||
|
||||
my_para.alpha = 0.0008;
|
||||
int ret = sgd_solver(evaluate, progress, &fx, &x[0], 2, 1, &my_para, nullptr, SGD_MOMENTUM);
|
||||
std::clog << "Momentum return: " << sgd_error_str(ret);
|
||||
if (ret > 0) std::clog << "\033[1m\033[32m Successed! \033[0m" << std::endl;
|
||||
else std::clog << "\033[1m\033[31m Failed! \033[0m" << std::endl;
|
||||
std::clog << "fx = " << fx << " ,initial step = " << my_para.alpha << std::endl;
|
||||
std::clog << "model: " << x[0] << " " << x[1] << std::endl << std::endl;
|
||||
|
||||
my_para.alpha = 0.0005;
|
||||
x[0] = 0.25; x[1] = 0.20;
|
||||
ret = sgd_solver(evaluate, progress, &fx, &x[0], 2, 1, &my_para, nullptr, SGD_NAG);
|
||||
std::clog << "NAG return: " << sgd_error_str(ret);
|
||||
if (ret > 0) std::clog << "\033[1m\033[32m Successed! \033[0m" << std::endl;
|
||||
else std::clog << "\033[1m\033[31m Failed! \033[0m" << std::endl;
|
||||
std::clog << "fx = " << fx << " ,initial step = " << my_para.alpha << std::endl;
|
||||
std::clog << "model: " << x[0] << " " << x[1] << std::endl << std::endl;
|
||||
|
||||
my_para.alpha = 0.3;
|
||||
x[0] = 0.25; x[1] = 0.20;
|
||||
ret = sgd_solver(evaluate, progress, &fx, &x[0], 2, 1, &my_para, nullptr, SGD_ADAGRAD);
|
||||
std::clog << "Adagrad return: " << sgd_error_str(ret);
|
||||
if (ret > 0) std::clog << "\033[1m\033[32m Successed! \033[0m" << std::endl;
|
||||
else std::clog << "\033[1m\033[31m Failed! \033[0m" << std::endl;
|
||||
std::clog << "fx = " << fx << " ,initial step = " << my_para.alpha << std::endl;
|
||||
std::clog << "model: " << x[0] << " " << x[1] << std::endl << std::endl;
|
||||
|
||||
my_para.alpha = 0.05;
|
||||
my_para.beta_2 = 0.95;
|
||||
x[0] = 0.25; x[1] = 0.20;
|
||||
ret = sgd_solver(evaluate, progress, &fx, &x[0], 2, 1, &my_para, nullptr, SGD_RMSPROP);
|
||||
std::clog << "RMSProp return: " << sgd_error_str(ret);
|
||||
if (ret > 0) std::clog << "\033[1m\033[32m Successed! \033[0m" << std::endl;
|
||||
else std::clog << "\033[1m\033[31m Failed! \033[0m" << std::endl;
|
||||
std::clog << "fx = " << fx << " ,initial step = " << my_para.alpha << std::endl;
|
||||
std::clog << "model: " << x[0] << " " << x[1] << std::endl << std::endl;
|
||||
|
||||
my_para.alpha = 0.5;
|
||||
my_para.beta_1 = 0.95;
|
||||
x[0] = 0.25; x[1] = 0.20;
|
||||
ret = sgd_solver(evaluate, progress, &fx, &x[0], 2, 1, &my_para, nullptr, SGD_ADAM);
|
||||
std::clog << "Adam return: " << sgd_error_str(ret);
|
||||
if (ret > 0) std::clog << "\033[1m\033[32m Successed! \033[0m" << std::endl;
|
||||
else std::clog << "\033[1m\033[31m Failed! \033[0m" << std::endl;
|
||||
std::clog << "fx = " << fx << " ,initial step = " << my_para.alpha << std::endl;
|
||||
std::clog << "model: " << x[0] << " " << x[1] << std::endl << std::endl;
|
||||
|
||||
my_para.alpha = 0.053;
|
||||
my_para.beta_1 = 0.9;
|
||||
my_para.beta_2 = 0.95;
|
||||
x[0] = 0.25; x[1] = 0.20;
|
||||
ret = sgd_solver(evaluate, progress, &fx, &x[0], 2, 1, &my_para, nullptr, SGD_NADAM);
|
||||
std::clog << "Nadam return: " << sgd_error_str(ret);
|
||||
if (ret > 0) std::clog << "\033[1m\033[32m Successed! \033[0m" << std::endl;
|
||||
else std::clog << "\033[1m\033[31m Failed! \033[0m" << std::endl;
|
||||
std::clog << "fx = " << fx << " ,initial step = " << my_para.alpha << std::endl;
|
||||
std::clog << "model: " << x[0] << " " << x[1] << std::endl << std::endl;
|
||||
|
||||
my_para.alpha = 0.5;
|
||||
my_para.beta_1 = 0.9;
|
||||
my_para.beta_2 = 0.999;
|
||||
x[0] = 0.25; x[1] = 0.20;
|
||||
ret = sgd_solver(evaluate, progress, &fx, &x[0], 2, 1, &my_para, nullptr, SGD_ADAMAX);
|
||||
std::clog << "AdaMax return: " << sgd_error_str(ret) << std::endl;
|
||||
std::clog << "fx = " << fx << std::endl;
|
||||
std::clog << "model: " << x[0] << " " << x[1] << std::endl;
|
||||
std::clog << "AdaMax return: " << sgd_error_str(ret);
|
||||
if (ret > 0) std::clog << "\033[1m\033[32m Successed! \033[0m" << std::endl;
|
||||
else std::clog << "\033[1m\033[31m Failed! \033[0m" << std::endl;
|
||||
std::clog << "fx = " << fx << " ,initial step = " << my_para.alpha << std::endl;
|
||||
std::clog << "model: " << x[0] << " " << x[1] << std::endl << std::endl;
|
||||
|
||||
my_para.alpha = 0.5;
|
||||
my_para.beta_1 = 0.9;
|
||||
my_para.beta_2 = 0.999;
|
||||
x[0] = 0.25; x[1] = 0.20;
|
||||
ret = sgd_solver(evaluate, progress, &fx, &x[0], 2, 1, &my_para, nullptr, SGD_ADABELIEF);
|
||||
std::clog << "AdaBelief return: " << sgd_error_str(ret) << std::endl;
|
||||
std::clog << "fx = " << fx << std::endl;
|
||||
std::clog << "AdaBelief return: " << sgd_error_str(ret);
|
||||
if (ret > 0) std::clog << "\033[1m\033[32m Successed! \033[0m" << std::endl;
|
||||
else std::clog << "\033[1m\033[31m Failed! \033[0m" << std::endl;
|
||||
std::clog << "fx = " << fx << " ,initial step = " << my_para.alpha << std::endl;
|
||||
std::clog << "model: " << x[0] << " " << x[1] << std::endl;
|
||||
return 0;
|
||||
}
|