/********************************************************
* ██████╗ ██████╗████████╗██╗
* ██╔════╝ ██╔════╝╚══██╔══╝██║
* ██║ ███╗██║ ██║ ██║
* ██║ ██║██║ ██║ ██║
* ╚██████╔╝╚██████╗ ██║ ███████╗
* ╚═════╝ ╚═════╝ ╚═╝ ╚══════╝
* Geophysical Computational Tools & Library (GCTL)
*
* Copyright (c) 2023 Yi Zhang (yizhang-geo@zju.edu.cn)
*
* GCTL is distributed under a dual licensing scheme. You can redistribute
* it and/or modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation, either version 2
* of the License, or (at your option) any later version. You should have
* received a copy of the GNU Lesser General Public License along with this
* program. If not, see <http://www.gnu.org/licenses/>.
*
* If the terms and conditions of the LGPL v.2. would prevent you from using
* the GCTL, please consider the option to obtain a commercial license for a
* fee. These licenses are offered by the GCTL's original author. As a rule,
* licenses are provided "as-is", unlimited in time for a one time fee. Please
* send corresponding requests to: yizhang-geo@zju.edu.cn. Please do not forget
* to include some description of your company and the realm of its activities.
* Also add information on how to contact you by electronic and paper mail.
******************************************************/
#include "sinkhorn.h"
#include "../io/netcdf_io.h"
// Default constructor. Sets no solver parameters; call init() before make_plan_from().
gctl::sinkhorn1d::sinkhorn1d()
{
}
// Destructor. Performs no explicit cleanup.
gctl::sinkhorn1d::~sinkhorn1d()
{
}
/**
 * Construct and configure the solver in one step.
 * Equivalent to default construction followed by init(tar, tmin, tmax, eta, eps, nt);
 * any std::invalid_argument thrown by init() propagates to the caller.
 */
gctl::sinkhorn1d::sinkhorn1d(const array &tar, double tmin, double tmax,
	double eta, double eps, norm_type_e nt)
{
	this->init(tar, tmin, tmax, eta, eps, nt);
}
/**
 * Configure the target (y) distribution and the regularisation settings.
 *
 * @param tar  Target probabilities sampled on a regular grid over [tmin, tmax]; needs at
 *             least two entries, each non-negative and not NaN, summing to one (within 1e-12).
 * @param tmin Lower bound of the target grid.
 * @param tmax Upper bound of the target grid; must be greater than tmin.
 * @param eta  Entropic regularisation strength (> 0); the kernel is exp(-cost/eta).
 * @param eps  Convergence threshold on the RMS target-marginal error (> 0).
 * @param nt   Ground-cost norm; only L1 and L2 are accepted.
 * @throws std::invalid_argument if any parameter or distribution value is invalid.
 */
void gctl::sinkhorn1d::init(const array &tar, double tmin, double tmax, double eta, double eps, norm_type_e nt)
{
	const bool bad_params = eta <= 0 || eps <= 0 || tar.size() <= 1 || tmin >= tmax || (nt != L1 && nt != L2);
	if (bad_params)
		throw std::invalid_argument("[GCTL] Invalid initiating parameters for gctl::sinkhorn1d");

	nt_  = nt;
	eta_ = eta;
	eps_ = eps;

	// Regular target grid: ynum_ nodes spaced dy_ apart from ymin_ to ymax_.
	ymin_ = tmin;
	ymax_ = tmax;
	ynum_ = tar.size();
	dy_   = (ymax_ - ymin_)/(ynum_ - 1);

	py_.resize(ynum_);
	double total = 0.0;
	for (size_t k = 0; k < ynum_; ++k)
	{
		const double val = tar[k];
		if (val < 0 || std::isnan(val))
			throw std::invalid_argument("[GCTL] Invalid targeting distribution for gctl::sinkhorn1d");
		py_[k] = val;
		total += val;
	}

	if (fabs(total - 1) > 1e-12)
		throw std::invalid_argument("[GCTL] The total amount of distribution must be one for gctl::sinkhorn1d");
}
/**
 * Transport cost of the current plan: the sum over all (i, j) of P_[i][j] * c(x_j, y_i).
 * Here c is L2_distance when nt_ == L2 and L1_distance otherwise, x_j = xmin_ + j*dx_
 * and y_i = ymin_ + i*dy_. Uses the plan built by make_plan_from().
 */
double gctl::sinkhorn1d::get_distance()
{
	double total = 0;
	for (int i = 0; i < ynum_; i++)
	{
		const double yi = ymin_ + dy_*i;
		for (int j = 0; j < xnum_; j++)
		{
			const double xj = xmin_ + dx_*j;
			const double cost = (nt_ == L2) ? L2_distance(xj, yi) : L1_distance(xj, yi);
			total += P_[i][j]*cost;
		}
	}
	return total;
}
/**
 * Compute the transport cost of the current plan and fill `grad` with one entry per
 * input-grid node (presumably the derivative of the cost with respect to the input
 * distribution — TODO confirm against the derivation this was written from).
 *
 * Uses K_, u_ and P_ from the most recent make_plan_from() call.
 *
 * @param grad Output; resized to xnum_ entries. NOTE(review): the loops below accumulate
 *             with +=, so this assumes array::resize(n, 0) zero-fills existing entries even
 *             when grad already has xnum_ elements — confirm.
 * @return The same value as get_distance().
 */
double gctl::sinkhorn1d::get_distance(array &grad)
{
double x, y, s;
// Finite difference of the input distribution px_ with spacing dx_, stored in px_grad_
// (the meaning of the final argument of difference_1d is assumed to be the derivative
// order — verify).
gctl::difference_1d(px_, px_grad_, dx_, 1);
grad.resize(xnum_, 0);
if (nt_ == L2)
{
// Term 1: for each input node j, the kernel-weighted mean cost
// sum_i c(x_j, y_i) K_ij u_i / sum_i K_ij u_i, multiplied by px_grad_[j].
for (int j = 0; j < xnum_; j++)
{
x = xmin_ + dx_*j;
s = 0;
for (int i = 0; i < ynum_; i++)
{
s += K_[i][j]*u_[i];
}
for (int i = 0; i < ynum_; i++)
{
y = ymin_ + dy_*i;
grad[j] += L2_distance(x, y)*K_[i][j]*u_[i]/s;
}
grad[j] *= px_grad_[j];
}
// Term 2: add sum_i P_ij * d/dx (y - x)^2 = sum_i P_ij * (-2 (y - x)).
for (int j = 0; j < xnum_; j++)
{
x = xmin_ + dx_*j;
for (int i = 0; i < ynum_; i++)
{
y = ymin_ + dy_*i;
grad[j] += -2*(y-x)*P_[i][j];
}
}
}
else // nt_ == L1
{
// Term 1: same as the L2 branch but with the L1 cost |y - x|.
for (int j = 0; j < xnum_; j++)
{
x = xmin_ + dx_*j;
s = 0;
for (int i = 0; i < ynum_; i++)
{
s += K_[i][j]*u_[i];
}
for (int i = 0; i < ynum_; i++)
{
y = ymin_ + dy_*i;
grad[j] += L1_distance(x, y)*K_[i][j]*u_[i]/s;
}
grad[j] *= px_grad_[j];
}
// Term 2: add sum_i P_ij * d/dx |y - x|, i.e. -1 when y > x, +1 when y < x,
// and 0 (no contribution) when y == x.
for (int j = 0; j < xnum_; j++)
{
x = xmin_ + dx_*j;
for (int i = 0; i < ynum_; i++)
{
y = ymin_ + dy_*i;
if (y > x) grad[j] -= P_[i][j];
else if (y < x) grad[j] += P_[i][j];
}
}
}
return get_distance();
}
/**
 * Resample values in place so that they follow the transport plan.
 *
 * For each value v in in_out:
 * - The input-grid node at or below v (clamped to [0, xnum_-1]) selects one column of P_.
 * - If v lies inside [ymin_, ymax_], it is kept when it passes an acceptance test against
 *   the linearly interpolated column.
 * - Otherwise it is replaced by a value drawn from [ymin_, ymax_] by rejection sampling
 *   against that column, with the column maximum as the envelope.
 *
 * Requires make_plan_from() to have been called. The random engine is seeded from the
 * system clock, so results differ between calls.
 *
 * @param in_out Values to transform; overwritten in place.
 */
void gctl::sinkhorn1d::sampling_to_target(array &in_out)
{
	// Column maxima of P_: the rejection-sampling envelope for each input node.
	px_maxi_.resize(xnum_);
	for (int j = 0; j < xnum_; j++)
	{
		px_maxi_[j] = P_[0][j];
		for (int i = 1; i < ynum_; i++)
		{
			if (P_[i][j] > px_maxi_[j]) px_maxi_[j] = P_[i][j];
		}
	}

	// Linear interpolation of column `col` of P_ at target coordinate t
	// (the bracketing row index is clamped to [0, ynum_-2]).
	auto plan_at = [this](int col, double t)
	{
		int row = static_cast<int>(floor((t - ymin_)/dy_));
		if (row < 0) row = 0;
		if (row > ynum_-2) row = ynum_-2;
		return line_interpolate(ymin_ + row*dy_, ymin_ + (row+1)*dy_, P_[row][col], P_[row+1][col], t);
	};

	unsigned int seed = std::chrono::system_clock::now().time_since_epoch().count();
	std::default_random_engine generator(seed);
	// The template argument must be spelled out: with CTAD, (0, 1) deduces
	// uniform_real_distribution<int>, and a non-floating-point RealType is not allowed.
	std::uniform_real_distribution<double> dist(0.0, 1.0);

	for (size_t i = 0; i < in_out.size(); i++)
	{
		int x_id = static_cast<int>(floor((in_out[i] - xmin_)/dx_));
		if (x_id < 0) x_id = 0;
		if (x_id > xnum_-1) x_id = xnum_-1;

		// Keep the current value if it passes the acceptance test at its own location.
		if (in_out[i] >= ymin_ && in_out[i] <= ymax_)
		{
			const double l_val = plan_at(x_id, in_out[i]);
			const double s_val = px_maxi_[x_id]*dist(generator);
			if (s_val <= l_val) continue;
		}

		// Otherwise draw uniform candidates over [ymin_, ymax_] until one is accepted.
		double loc, l_val, s_val;
		do
		{
			loc = (ymax_ - ymin_)*dist(generator) + ymin_;
			l_val = plan_at(x_id, loc);
			s_val = px_maxi_[x_id]*dist(generator);
		}
		while (s_val > l_val);
		in_out[i] = loc;
	}
}
/**
 * Mutable access to the transport plan P_ built by make_plan_from():
 * ynum_ rows (target nodes) by xnum_ columns (input nodes).
 */
gctl::matrix &gctl::sinkhorn1d::get_plan()
{
	return this->P_;
}
/**
 * Write the transport plan P_ to `filename` via save_netcdf_grid. The grid starts at
 * (xmin_, ymin_) with spacings (dx_, dy_) and uses the names "x", "y" and "transports"
 * (presumably the dimension and variable names in the netCDF file).
 */
void gctl::sinkhorn1d::save_plan(std::string filename)
{
	save_netcdf_grid(filename, P_,
		xmin_, dx_,
		ymin_, dy_,
		"x", "y", "transports");
}
/**
 * Build the entropy-regularised transport plan from an input (x) distribution to the
 * target (y) distribution configured by init(), using Sinkhorn iterations.
 *
 * @param inp     Input probabilities on a regular grid over [imin, imax]; needs at least two
 *                entries, each non-negative and not NaN, summing to one (within 1e-12).
 * @param imin    Lower bound of the input grid.
 * @param imax    Upper bound of the input grid; must be greater than imin.
 * @param verbose If true, print the iteration count, final RMS error and get_distance()
 *                to std::clog.
 * @throws std::invalid_argument on invalid grid parameters or distribution values.
 *
 * On return:
 * - K_[i][j] = exp(-c(x_j, y_i)/eta_).
 * - u_ and v_ hold the Sinkhorn scalings.
 * - P_[i][j] = K_[i][j]*u_[i]*v_[j].
 */
void gctl::sinkhorn1d::make_plan_from(const array &inp, double imin, double imax, bool verbose)
{
	if (inp.size() <= 1 || imin >= imax)
		throw std::invalid_argument("[GCTL] Invalid parameters for gctl::sinkhorn1d::make_plan_from(...)");

	// Regular input grid: xnum_ nodes spaced dx_ apart from xmin_ to xmax_.
	xmin_ = imin;
	xmax_ = imax;
	xnum_ = inp.size();
	dx_   = (xmax_ - xmin_)/(xnum_ - 1);

	px_.resize(xnum_);
	double total = 0.0;
	for (size_t k = 0; k < xnum_; ++k)
	{
		const double val = inp[k];
		if (val < 0 || std::isnan(val))
			throw std::invalid_argument("[GCTL] Invalid input distribution for gctl::sinkhorn1d::make_plan_from(...)");
		px_[k] = val;
		total += val;
	}
	if (fabs(total - 1) > 1e-12)
		throw std::invalid_argument("[GCTL] The total amount of distribution must be one for gctl::sinkhorn1d::make_plan_from(...)");

	u_.resize(ynum_, 0.0);
	v_.resize(xnum_, 1.0);
	K_.resize(ynum_, xnum_);
	P_.resize(ynum_, xnum_);

	// Gibbs kernel built from the selected ground cost.
	for (int i = 0; i < ynum_; i++)
	{
		const double yi = ymin_ + dy_*i;
		for (int j = 0; j < xnum_; j++)
		{
			const double xj = xmin_ + dx_*j;
			const double cost = (nt_ == L2) ? L2_distance(xj, yi) : L1_distance(xj, yi);
			K_[i][j] = exp(-1.0*cost/eta_);
		}
	}

	// Iteratively update u and v, and compute the RMS error of the py marginal
	// (row sums u_[i]*(K v)_i against py_[i]) until it drops to eps_ or below.
	int iter = 0;
	double rms = 0.0;
	do
	{
		for (int i = 0; i < ynum_; i++)
		{
			double kv = 0;
			for (int j = 0; j < xnum_; j++) kv += K_[i][j]*v_[j];
			u_[i] = py_[i]/kv;
		}
		for (int j = 0; j < xnum_; j++)
		{
			double ktu = 0;
			for (int i = 0; i < ynum_; i++) ktu += K_[i][j]*u_[i];
			v_[j] = px_[j]/ktu;
		}
		double sq_err = 0;
		for (int i = 0; i < ynum_; i++)
		{
			double kv = 0;
			for (int j = 0; j < xnum_; j++) kv += K_[i][j]*v_[j];
			const double r = u_[i]*kv - py_[i];
			sq_err += r*r;
		}
		rms = sqrt(sq_err/ynum_);
		++iter;
	}
	while (rms > eps_);

	for (int i = 0; i < ynum_; i++)
		for (int j = 0; j < xnum_; j++)
			P_[i][j] = K_[i][j]*u_[i]*v_[j];

	if (verbose)
	{
		std::clog << "[GCTL Sinkhorn] Iterations: " << iter << ", Eps. = " << rms;
		if (nt_ == L1) std::clog << ", W_1 = " << get_distance() << "\n";
		else std::clog << ", W_2 = " << get_distance() << "\n";
	}
}
// L1 ground cost between an input node x and a target node y: |y - x|.
double gctl::sinkhorn1d::L1_distance(double x, double y)
{
	const double gap = y - x;
	return fabs(gap);
}
// L2 ground cost between an input node x and a target node y: power2(y - x),
// i.e. the squared difference (no square root is taken).
double gctl::sinkhorn1d::L2_distance(double x, double y)
{
	const double gap = y - x;
	return power2(gap);
}
/**
 * Definitions for gctl::sinkhorn2d start here.
 */
// Default constructor. Sets no solver parameters; call init() before make_plan_from().
gctl::sinkhorn2d::sinkhorn2d()
{
}
// Destructor. Performs no explicit cleanup.
gctl::sinkhorn2d::~sinkhorn2d()
{
}
/**
 * Construct and configure the 2-D solver in one step.
 * Equivalent to default construction followed by
 * init(tar, xmin, xmax, ymin, ymax, eta, eps, nt); exceptions from init() propagate.
 */
gctl::sinkhorn2d::sinkhorn2d(const matrix &tar, double xmin, double xmax,
	double ymin, double ymax, double eta, double eps, norm_type_e nt)
{
	this->init(tar, xmin, xmax, ymin, ymax, eta, eps, nt);
}
/**
 * Configure the 2-D target distribution and the regularisation settings.
 *
 * @param tar   Target probabilities on a regular grid (rows along y, columns along x);
 *              at least 2x2, every entry non-negative and not NaN, summing to one
 *              (within 1e-12). Stored row-major in py_.
 * @param xmin  Lower x bound of the target grid.
 * @param xmax  Upper x bound of the target grid; must be greater than xmin.
 * @param ymin  Lower y bound of the target grid.
 * @param ymax  Upper y bound of the target grid; must be greater than ymin.
 * @param eta   Entropic regularisation strength (> 0).
 * @param eps   Convergence threshold on the RMS target-marginal error (> 0).
 * @param nt    Ground-cost norm; only L1 and L2 are accepted.
 * @throws std::invalid_argument if any parameter or distribution value is invalid.
 */
void gctl::sinkhorn2d::init(const matrix &tar, double xmin, double xmax,
	double ymin, double ymax, double eta, double eps, norm_type_e nt)
{
	const bool bad_setup = eta <= 0 || eps <= 0 || (nt != L1 && nt != L2);
	const bool bad_grid  = tar.row_size() <= 1 || tar.col_size() <= 1 || xmin >= xmax || ymin >= ymax;
	if (bad_setup || bad_grid)
		throw std::invalid_argument("[GCTL] Invalid initiating parameters for gctl::sinkhorn2d");

	nt_  = nt;
	eta_ = eta;
	eps_ = eps;

	// Target grid geometry (t_ prefix): columns map to x, rows map to y.
	t_xmin_ = xmin; t_xmax_ = xmax; t_xnum_ = tar.col_size();
	t_ymin_ = ymin; t_ymax_ = ymax; t_ynum_ = tar.row_size();
	t_dx_ = (t_xmax_ - t_xmin_)/(t_xnum_ - 1);
	t_dy_ = (t_ymax_ - t_ymin_)/(t_ynum_ - 1);

	py_num_ = t_ynum_*t_xnum_;
	py_.resize(t_ynum_*t_xnum_);
	double total = 0.0;
	for (size_t row = 0; row < t_ynum_; ++row)
	{
		for (size_t col = 0; col < t_xnum_; ++col)
		{
			const double w = tar[row][col];
			if (w < 0 || std::isnan(w))
				throw std::invalid_argument("[GCTL] Invalid targeting distribution for gctl::sinkhorn2d");
			py_[row*t_xnum_ + col] = w;
			total += w;
		}
	}

	if (fabs(total - 1) > 1e-12)
		throw std::invalid_argument("[GCTL] The total amount of distribution must be one for gctl::sinkhorn2d");
}
/**
 * Build the entropy-regularised transport plan from a 2-D input distribution to the
 * 2-D target distribution configured by init(), using Sinkhorn iterations.
 *
 * Both grids are flattened row-major: node k of a grid with nx columns sits at
 * column k % nx and row k / nx.
 *
 * @param inp     Input probabilities (rows along y, columns along x); at least 2x2,
 *                every entry non-negative and not NaN, summing to one (within 1e-12).
 * @param xmin    Lower x bound of the input grid.
 * @param xmax    Upper x bound of the input grid; must be greater than xmin.
 * @param ymin    Lower y bound of the input grid.
 * @param ymax    Upper y bound of the input grid; must be greater than ymin.
 * @param verbose If true, print the iteration count, final RMS error and get_distance()
 *                to std::clog.
 * @throws std::invalid_argument on invalid grid parameters or distribution values.
 *
 * On return, P_[i][j] = K_[i][j]*u_[i]*v_[j] (py_num_ x px_num_). RP_ holds the same plan
 * rearranged as a mosaic: the t_ynum_ x t_xnum_ block starting at
 * (iy*t_ynum_, ix*t_xnum_) is the transport from input node (iy, ix) to every target node.
 */
void gctl::sinkhorn2d::make_plan_from(const matrix &inp, double xmin, double xmax,
	double ymin, double ymax, bool verbose)
{
	if (inp.row_size() <= 1 || inp.col_size() <= 1 || xmin >= xmax || ymin >= ymax)
	{
		throw std::invalid_argument("[GCTL] Invalid parameters for gctl::sinkhorn2d::make_plan_from(...)");
	}

	// Input grid geometry (i_ prefix): columns map to x, rows map to y.
	i_xmin_ = xmin; i_ymin_ = ymin;
	i_xmax_ = xmax; i_ymax_ = ymax;
	i_xnum_ = inp.col_size(); i_ynum_ = inp.row_size();
	i_dx_ = (i_xmax_ - i_xmin_)/(i_xnum_ - 1);
	i_dy_ = (i_ymax_ - i_ymin_)/(i_ynum_ - 1);

	// Flatten and validate the input distribution.
	double tsum = 0;
	px_num_ = i_ynum_*i_xnum_;
	px_.resize(i_ynum_*i_xnum_);
	for (size_t i = 0; i < i_ynum_; i++)
	{
		for (size_t j = 0; j < i_xnum_; j++)
		{
			if (inp[i][j] < 0 || std::isnan(inp[i][j]))
				throw std::invalid_argument("[GCTL] Invalid input distribution for gctl::sinkhorn2d::make_plan_from(...)");
			px_[i*i_xnum_ + j] = inp[i][j];
			tsum += inp[i][j];
		}
	}
	if (fabs(tsum - 1) > 1e-12)
	{
		throw std::invalid_argument("[GCTL] The total amount of distribution must be one for gctl::sinkhorn2d::make_plan_from(...)");
	}

	u_.resize(py_num_, 0.0);
	v_.resize(px_num_, 1.0);
	K_.resize(py_num_, px_num_);
	P_.resize(py_num_, px_num_);
	// RP_ is indexed as RP_[iy*t_ynum_ + p][ix*t_xnum_ + q] (below, in sampling_to_target
	// and in save_plan). It therefore needs i_ynum_*t_ynum_ rows and i_xnum_*t_xnum_
	// columns. A py_num_ x px_num_ shape has the same element count but indexes out of
	// bounds whenever the target and input grids differ in shape.
	RP_.resize(i_ynum_*t_ynum_, i_xnum_*t_xnum_);

	// Gibbs kernel: K_[i][j] = exp(-c(target node i, input node j)/eta_).
	double x, y, x2, y2;
	for (int i = 0; i < py_num_; i++)
	{
		x = t_xmin_ + t_dx_*(i%t_xnum_);
		y = t_ymin_ + t_dy_*(i/t_xnum_);
		for (int j = 0; j < px_num_; j++)
		{
			x2 = i_xmin_ + i_dx_*(j%i_xnum_);
			y2 = i_ymin_ + i_dy_*(j/i_xnum_);
			const double cost = (nt_ == L2) ? L2_distance(x, y, x2, y2) : L1_distance(x, y, x2, y2);
			K_[i][j] = exp(-1.0*cost/eta_);
		}
	}

	// Iteratively update u and v, and compute the RMS error of the py marginal
	// (row sums u_[i]*(K v)_i against py_[i]) until it drops to eps_ or below.
	int t = 0;
	double rms_py, tmp_d, sum;
	do
	{
		for (int i = 0; i < py_num_; i++)
		{
			tmp_d = 0;
			for (int j = 0; j < px_num_; j++) tmp_d += K_[i][j]*v_[j];
			u_[i] = py_[i]/tmp_d;
		}
		for (int j = 0; j < px_num_; j++)
		{
			tmp_d = 0;
			for (int i = 0; i < py_num_; i++) tmp_d += K_[i][j]*u_[i];
			v_[j] = px_[j]/tmp_d;
		}
		sum = 0;
		for (int i = 0; i < py_num_; i++)
		{
			tmp_d = 0;
			for (int j = 0; j < px_num_; j++) tmp_d += K_[i][j]*v_[j];
			sum += (u_[i]*tmp_d - py_[i])*(u_[i]*tmp_d - py_[i]);
		}
		rms_py = sqrt(sum/py_num_);
		t++;
	}
	while (rms_py > eps_);

	for (int i = 0; i < py_num_; i++)
	{
		for (int j = 0; j < px_num_; j++)
		{
			P_[i][j] = K_[i][j]*u_[i]*v_[j];
		}
	}

	// Rearrange P_ so that each input node's transport field is a contiguous
	// t_ynum_ x t_xnum_ block of RP_.
	for (size_t i = 0; i < i_ynum_; i++)
	{
		for (size_t j = 0; j < i_xnum_; j++)
		{
			for (size_t p = 0; p < t_ynum_; p++)
			{
				for (size_t q = 0; q < t_xnum_; q++)
				{
					RP_[i*t_ynum_+p][j*t_xnum_+q] = P_[p*t_xnum_+q][i*i_xnum_+j];
				}
			}
		}
	}

	if (verbose)
	{
		std::clog << "[GCTL Sinkhorn] Iterations: " << t << ", Eps. = " << rms_py;
		if (nt_ == L1) std::clog << ", W_1 = " << get_distance() << "\n";
		else std::clog << ", W_2 = " << get_distance() << "\n"; // nt_ == L2
	}
}
double gctl::sinkhorn2d::get_distance()
{
double dist = 0;
double x, y, x2, y2;
if (nt_ == L2)
{
for (int i = 0; i < py_num_; i++)
{
x = t_xmin_ + t_dx_*(i%t_xnum_);
y = t_ymin_ + t_dy_*(i/t_xnum_);
for (int j = 0; j < px_num_; j++)
{
x2 = i_xmin_ + i_dx_*(j%i_xnum_);
y2 = i_ymin_ + i_dy_*(j/i_xnum_);
dist += P_[i][j]*L2_distance(x, y, x2, y2);
}
}
}
else // nt_ == L1
{
for (int i = 0; i < py_num_; i++)
{
x = t_xmin_ + t_dx_*(i%t_xnum_);
y = t_ymin_ + t_dy_*(i/t_xnum_);
for (int j = 0; j < px_num_; j++)
{
x2 = i_xmin_ + i_dx_*(j%i_xnum_);
y2 = i_ymin_ + i_dy_*(j/i_xnum_);
dist += P_[i][j]*L1_distance(x, y, x2, y2);
}
}
}
return dist;
}
/**
 * Resample 2-D points in place so that they follow the transport plan.
 *
 * For each point (inx[i], iny[i]):
 * - The input-grid node at or below the point (clamped to the grid) selects one
 *   t_ynum_ x t_xnum_ block of RP_.
 * - If the point lies inside the target domain, it is kept when it passes an acceptance
 *   test against the interpolated block.
 * - Otherwise a new point is drawn uniformly over the target domain by rejection
 *   sampling, using the block maximum as the envelope.
 *
 * Requires make_plan_from(). The random engine is seeded from the system clock.
 *
 * @param inx x coordinates; overwritten in place.
 * @param iny y coordinates; overwritten in place; must have the same size as inx.
 * @throws std::invalid_argument if inx and iny differ in size.
 */
void gctl::sinkhorn2d::sampling_to_target(array &inx, array &iny)
{
	if (inx.size() != iny.size()) throw std::invalid_argument("[GCTL Sinkhorn] Invalid inquring coordinates for gctl::sinkhorn2d::sampling_to_target(...)");

	// Rejection envelope: the maximum over each input node's whole block of RP_.
	// Both p and q must start at 0; starting either at 1 would skip the block's first
	// row or column, underestimate the envelope and bias the sampling.
	rp_maxi_.resize(i_ynum_, i_xnum_);
	for (size_t i = 0; i < i_ynum_; i++)
	{
		for (size_t j = 0; j < i_xnum_; j++)
		{
			double peak = RP_[i*t_ynum_][j*t_xnum_];
			for (size_t p = 0; p < t_ynum_; p++)
			{
				for (size_t q = 0; q < t_xnum_; q++)
				{
					if (RP_[i*t_ynum_+p][j*t_xnum_+q] > peak) peak = RP_[i*t_ynum_+p][j*t_xnum_+q];
				}
			}
			rp_maxi_[i][j] = peak;
		}
	}

	// Interpolated value of the RP_ block of input node (row, col) at target location
	// (qx, qy). The bracketing cell indices are clamped to the target grid.
	auto block_value = [this](int row, int col, double qx, double qy)
	{
		int xid = static_cast<int>(floor((qx - t_xmin_)/t_dx_));
		int yid = static_cast<int>(floor((qy - t_ymin_)/t_dy_));
		if (xid < 0) xid = 0;
		if (xid > t_xnum_-2) xid = t_xnum_-2;
		if (yid < 0) yid = 0;
		if (yid > t_ynum_-2) yid = t_ynum_-2;
		return rect_interpolate(xid*t_dx_+t_xmin_, yid*t_dy_+t_ymin_, t_dx_, t_dy_, qx, qy,
			RP_[row*t_ynum_+yid][col*t_xnum_+xid], RP_[row*t_ynum_+yid][col*t_xnum_+xid+1],
			RP_[row*t_ynum_+yid+1][col*t_xnum_+xid+1], RP_[row*t_ynum_+yid+1][col*t_xnum_+xid]);
	};

	unsigned int seed = std::chrono::system_clock::now().time_since_epoch().count();
	std::default_random_engine generator(seed);
	// The template argument must be spelled out: with CTAD, (0, 1) deduces
	// uniform_real_distribution<int>, and a non-floating-point RealType is not allowed.
	std::uniform_real_distribution<double> dist(0.0, 1.0);

	for (size_t i = 0; i < inx.size(); i++)
	{
		int idx = static_cast<int>(floor((inx[i] - i_xmin_)/i_dx_));
		int idy = static_cast<int>(floor((iny[i] - i_ymin_)/i_dy_));
		if (idx < 0) idx = 0;
		if (idy < 0) idy = 0;
		if (idx > i_xnum_-1) idx = i_xnum_-1;
		if (idy > i_ynum_-1) idy = i_ynum_-1;

		// Keep the current point if it passes the acceptance test at its own location.
		if (inx[i] >= t_xmin_ && inx[i] <= t_xmax_ &&
			iny[i] >= t_ymin_ && iny[i] <= t_ymax_)
		{
			const double l_val = block_value(idy, idx, inx[i], iny[i]);
			const double s_val = rp_maxi_[idy][idx]*dist(generator);
			if (s_val <= l_val) continue;
		}

		// Otherwise draw uniform candidates over the target domain until one is accepted.
		double loc_x, loc_y, l_val, s_val;
		do
		{
			loc_x = (t_xmax_ - t_xmin_)*dist(generator) + t_xmin_;
			loc_y = (t_ymax_ - t_ymin_)*dist(generator) + t_ymin_;
			l_val = block_value(idy, idx, loc_x, loc_y);
			s_val = rp_maxi_[idy][idx]*dist(generator);
		}
		while (s_val > l_val);
		inx[i] = loc_x;
		iny[i] = loc_y;
	}
}
/**
 * Mutable access to the flattened transport plan P_ built by make_plan_from():
 * py_num_ rows (target nodes) by px_num_ columns (input nodes). The rearranged RP_
 * is not returned here.
 */
gctl::matrix &gctl::sinkhorn2d::get_plan()
{
	return this->P_;
}
void gctl::sinkhorn2d::save_plan(std::string filename, int idx, int idy)
{
if (idx >= 0 && idx < i_xnum_ && idy >= 0 && idy < i_ynum_)
{
matrix loc_RP(t_ynum_, t_xnum_);
for (size_t p = 0; p < t_ynum_; p++)
{
for (size_t q = 0; q < t_xnum_; q++)
{
loc_RP[p][q] = RP_[idy*t_ynum_+p][idx*t_xnum_+q];
}
}
save_netcdf_grid(filename, loc_RP, t_xmin_, t_dx_, t_ymin_, t_dy_, "x", "y", "transports");
}
else if (idx == -1 && idy == -1)
{
save_netcdf_grid(filename, RP_, t_xmin_, t_dx_, t_ymin_, t_dy_, "x", "y", "transports");
}
else throw std::runtime_error("[GCTL Sinkhorn] Invalid inquiring indices for gctl::sinkhorn2d::save_plan(...)");
return;
}
// L1 (Manhattan) ground cost between (x, y) and (x2, y2): |x2 - x| + |y2 - y|.
double gctl::sinkhorn2d::L1_distance(double x, double y, double x2, double y2)
{
	const double along_x = fabs(x2 - x);
	const double along_y = fabs(y2 - y);
	return along_x + along_y;
}
// L2 ground cost between (x, y) and (x2, y2): power2(x2 - x) + power2(y2 - y),
// i.e. the squared Euclidean distance (no square root is taken).
double gctl::sinkhorn2d::L2_distance(double x, double y, double x2, double y2)
{
	const double along_x = power2(x2 - x);
	const double along_y = power2(y2 - y);
	return along_x + along_y;
}