mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
fix: conv_general differences between gpu, cpu (#2070)
* fix general_conv padding * fix bugs * add test --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent
0cae0bdac8
commit
a7fae8a176
@ -22,7 +22,8 @@ void slow_conv_1D(
|
|||||||
const array& in,
|
const array& in,
|
||||||
const array& wt,
|
const array& wt,
|
||||||
array out,
|
array out,
|
||||||
const std::vector<int>& padding,
|
const std::vector<int>& padding_lo,
|
||||||
|
const std::vector<int>& padding_hi,
|
||||||
const std::vector<int>& wt_strides,
|
const std::vector<int>& wt_strides,
|
||||||
const std::vector<int>& wt_dilation,
|
const std::vector<int>& wt_dilation,
|
||||||
const std::vector<int>& in_dilation,
|
const std::vector<int>& in_dilation,
|
||||||
@ -60,7 +61,8 @@ void slow_conv_1D(
|
|||||||
out_stride_O = out.strides()[2],
|
out_stride_O = out.strides()[2],
|
||||||
|
|
||||||
flip,
|
flip,
|
||||||
padding = padding[0],
|
padding_lo = padding_lo[0],
|
||||||
|
padding_hi = padding_hi[0],
|
||||||
wt_stride = wt_strides[0],
|
wt_stride = wt_strides[0],
|
||||||
wt_dilation = wt_dilation[0],
|
wt_dilation = wt_dilation[0],
|
||||||
in_dilation = in_dilation[0]]() mutable {
|
in_dilation = in_dilation[0]]() mutable {
|
||||||
@ -77,7 +79,7 @@ void slow_conv_1D(
|
|||||||
const T* wt_ptr = filter_wt_ptr + wh * wt_stride_H;
|
const T* wt_ptr = filter_wt_ptr + wh * wt_stride_H;
|
||||||
|
|
||||||
int wh_flip = flip ? (wH - wh - 1) : wh;
|
int wh_flip = flip ? (wH - wh - 1) : wh;
|
||||||
int ih = oh * wt_stride - padding + wh_flip * wt_dilation;
|
int ih = oh * wt_stride - padding_lo + wh_flip * wt_dilation;
|
||||||
|
|
||||||
auto ih_div = std::div(ih, in_dilation);
|
auto ih_div = std::div(ih, in_dilation);
|
||||||
|
|
||||||
@ -109,7 +111,8 @@ void slow_conv_2D(
|
|||||||
const array& in,
|
const array& in,
|
||||||
const array& wt,
|
const array& wt,
|
||||||
array out,
|
array out,
|
||||||
const std::vector<int>& padding,
|
const std::vector<int>& padding_lo,
|
||||||
|
const std::vector<int>& padding_hi,
|
||||||
const std::vector<int>& wt_strides,
|
const std::vector<int>& wt_strides,
|
||||||
const std::vector<int>& wt_dilation,
|
const std::vector<int>& wt_dilation,
|
||||||
const std::vector<int>& in_dilation,
|
const std::vector<int>& in_dilation,
|
||||||
@ -120,230 +123,235 @@ void slow_conv_2D(
|
|||||||
encoder.set_input_array(wt);
|
encoder.set_input_array(wt);
|
||||||
encoder.set_output_array(out);
|
encoder.set_output_array(out);
|
||||||
|
|
||||||
encoder.dispatch([st_wt_ptr = wt.data<T>(),
|
encoder.dispatch(
|
||||||
st_in_ptr = in.data<T>(),
|
[st_wt_ptr = wt.data<T>(),
|
||||||
st_out_ptr = out.data<T>(),
|
st_in_ptr = in.data<T>(),
|
||||||
|
st_out_ptr = out.data<T>(),
|
||||||
|
|
||||||
N = in.shape(
|
N = in.shape(0), // Batch size, should be the same as out.shape(0)
|
||||||
0), // Batch size, should be the same as out.shape(0)
|
iH = 1 + in_dilation[0] * (in.shape(1) - 1), // Input spatial dim
|
||||||
iH = 1 +
|
iW = 1 + in_dilation[1] * (in.shape(2) - 1), // Input spatial dim
|
||||||
in_dilation[0] * (in.shape(1) - 1), // Input spatial dim
|
C = in.shape(3), // In channels
|
||||||
iW = 1 +
|
oH = out.shape(1), // Output spatial dim
|
||||||
in_dilation[1] * (in.shape(2) - 1), // Input spatial dim
|
oW = out.shape(2), // Output spatial dim
|
||||||
C = in.shape(3), // In channels
|
O = wt.shape(0), // Out channels
|
||||||
oH = out.shape(1), // Output spatial dim
|
wH = wt.shape(1), // Weight spatial dim
|
||||||
oW = out.shape(2), // Output spatial dim
|
wW = wt.shape(2), // Weight spatial dim
|
||||||
O = wt.shape(0), // Out channels
|
|
||||||
wH = wt.shape(1), // Weight spatial dim
|
|
||||||
wW = wt.shape(2), // Weight spatial dim
|
|
||||||
|
|
||||||
groups = in.shape(3) / wt.shape(3),
|
groups = in.shape(3) / wt.shape(3),
|
||||||
C_per_group = wt.shape(3),
|
C_per_group = wt.shape(3),
|
||||||
|
|
||||||
in_stride_N = in.strides()[0],
|
in_stride_N = in.strides()[0],
|
||||||
in_stride_H = in.strides()[1],
|
in_stride_H = in.strides()[1],
|
||||||
in_stride_W = in.strides()[2],
|
in_stride_W = in.strides()[2],
|
||||||
in_stride_C = in.strides()[3],
|
in_stride_C = in.strides()[3],
|
||||||
|
|
||||||
wt_stride_O = wt.strides()[0],
|
wt_stride_O = wt.strides()[0],
|
||||||
wt_stride_H = wt.strides()[1],
|
wt_stride_H = wt.strides()[1],
|
||||||
wt_stride_W = wt.strides()[2],
|
wt_stride_W = wt.strides()[2],
|
||||||
wt_stride_C = wt.strides()[3],
|
wt_stride_C = wt.strides()[3],
|
||||||
|
|
||||||
out_stride_N = out.strides()[0],
|
out_stride_N = out.strides()[0],
|
||||||
out_stride_H = out.strides()[1],
|
out_stride_H = out.strides()[1],
|
||||||
out_stride_W = out.strides()[2],
|
out_stride_W = out.strides()[2],
|
||||||
out_stride_O = out.strides()[3],
|
out_stride_O = out.strides()[3],
|
||||||
|
|
||||||
padding,
|
padding_lo,
|
||||||
wt_strides,
|
padding_hi,
|
||||||
wt_dilation,
|
wt_strides,
|
||||||
in_dilation,
|
wt_dilation,
|
||||||
flip]() mutable {
|
in_dilation,
|
||||||
bool is_idil_one = in_dilation[0] == 1 && in_dilation[1] == 1;
|
flip]() mutable {
|
||||||
|
bool is_idil_one = in_dilation[0] == 1 && in_dilation[1] == 1;
|
||||||
|
|
||||||
const int O_per_group = O / groups;
|
const int O_per_group = O / groups;
|
||||||
auto pt_conv_no_checks = [&](const T* in_ptr,
|
auto pt_conv_no_checks =
|
||||||
const T* wt_ptr,
|
[&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int oh, int ow) {
|
||||||
T* out_ptr,
|
out_ptr += oh * out_stride_H + ow * out_stride_W;
|
||||||
int oh,
|
int ih_base = oh * wt_strides[0] - padding_lo[0];
|
||||||
int ow) {
|
int iw_base = ow * wt_strides[1] - padding_lo[1];
|
||||||
out_ptr += oh * out_stride_H + ow * out_stride_W;
|
|
||||||
int ih_base = oh * wt_strides[0] - padding[0];
|
|
||||||
int iw_base = ow * wt_strides[1] - padding[1];
|
|
||||||
|
|
||||||
for (int g = 0; g < groups; ++g) {
|
for (int g = 0; g < groups; ++g) {
|
||||||
for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) {
|
for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) {
|
||||||
float r = 0.;
|
float r = 0.;
|
||||||
|
|
||||||
for (int wh = 0; wh < wH; ++wh) {
|
for (int wh = 0; wh < wH; ++wh) {
|
||||||
for (int ww = 0; ww < wW; ++ww) {
|
for (int ww = 0; ww < wW; ++ww) {
|
||||||
int wh_flip = flip ? wH - wh - 1 : wh;
|
int wh_flip = flip ? wH - wh - 1 : wh;
|
||||||
int ww_flip = flip ? wW - ww - 1 : ww;
|
int ww_flip = flip ? wW - ww - 1 : ww;
|
||||||
int ih = ih_base + wh_flip * wt_dilation[0];
|
int ih = ih_base + wh_flip * wt_dilation[0];
|
||||||
int iw = iw_base + ww_flip * wt_dilation[1];
|
int iw = iw_base + ww_flip * wt_dilation[1];
|
||||||
|
|
||||||
const T* wt_ptr_pt = wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
|
const T* wt_ptr_pt =
|
||||||
const T* in_ptr_pt = in_ptr + ih * in_stride_H + iw * in_stride_W;
|
wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
|
||||||
|
const T* in_ptr_pt =
|
||||||
|
in_ptr + ih * in_stride_H + iw * in_stride_W;
|
||||||
|
|
||||||
for (int c = g * C_per_group; c < (g + 1) * C_per_group; ++c) {
|
for (int c = g * C_per_group; c < (g + 1) * C_per_group;
|
||||||
r += static_cast<float>(in_ptr_pt[c * in_stride_C]) *
|
++c) {
|
||||||
static_cast<float>(
|
r += static_cast<float>(in_ptr_pt[c * in_stride_C]) *
|
||||||
wt_ptr_pt[(c % C_per_group) * wt_stride_C]);
|
static_cast<float>(
|
||||||
} // c
|
wt_ptr_pt[(c % C_per_group) * wt_stride_C]);
|
||||||
} // ww
|
} // c
|
||||||
} // wh
|
} // ww
|
||||||
|
} // wh
|
||||||
|
|
||||||
out_ptr[0] = static_cast<T>(r);
|
out_ptr[0] = static_cast<T>(r);
|
||||||
out_ptr += out_stride_O;
|
out_ptr += out_stride_O;
|
||||||
wt_ptr += wt_stride_O;
|
wt_ptr += wt_stride_O;
|
||||||
} // o
|
} // o
|
||||||
} // g
|
} // g
|
||||||
};
|
};
|
||||||
|
|
||||||
int jump_h = flip ? -wt_dilation[0] : wt_dilation[0];
|
int jump_h = flip ? -wt_dilation[0] : wt_dilation[0];
|
||||||
int jump_w = flip ? -wt_dilation[1] : wt_dilation[1];
|
int jump_w = flip ? -wt_dilation[1] : wt_dilation[1];
|
||||||
|
|
||||||
int init_h = (flip ? (wH - 1) * wt_dilation[0] : 0);
|
int init_h = (flip ? (wH - 1) * wt_dilation[0] : 0);
|
||||||
int init_w = (flip ? (wW - 1) * wt_dilation[1] : 0);
|
int init_w = (flip ? (wW - 1) * wt_dilation[1] : 0);
|
||||||
|
|
||||||
int f_wgt_jump_h =
|
int f_wgt_jump_h =
|
||||||
std::lcm(in_dilation[0], wt_dilation[0]) / wt_dilation[0];
|
std::lcm(in_dilation[0], wt_dilation[0]) / wt_dilation[0];
|
||||||
int f_wgt_jump_w =
|
int f_wgt_jump_w =
|
||||||
std::lcm(in_dilation[1], wt_dilation[1]) / wt_dilation[1];
|
std::lcm(in_dilation[1], wt_dilation[1]) / wt_dilation[1];
|
||||||
|
|
||||||
int f_out_jump_h = std::lcm(in_dilation[0], wt_strides[0]) / wt_strides[0];
|
int f_out_jump_h =
|
||||||
int f_out_jump_w = std::lcm(in_dilation[1], wt_strides[1]) / wt_strides[1];
|
std::lcm(in_dilation[0], wt_strides[0]) / wt_strides[0];
|
||||||
|
int f_out_jump_w =
|
||||||
|
std::lcm(in_dilation[1], wt_strides[1]) / wt_strides[1];
|
||||||
|
|
||||||
std::vector<int> base_h(f_out_jump_h);
|
std::vector<int> base_h(f_out_jump_h);
|
||||||
std::vector<int> base_w(f_out_jump_w);
|
std::vector<int> base_w(f_out_jump_w);
|
||||||
|
|
||||||
for (int i = 0; i < f_out_jump_h; ++i) {
|
for (int i = 0; i < f_out_jump_h; ++i) {
|
||||||
int ih_loop = i * wt_strides[0] - padding[0] + init_h;
|
int ih_loop = i * wt_strides[0] - padding_lo[0] + init_h;
|
||||||
|
|
||||||
int wh_base = 0;
|
int wh_base = 0;
|
||||||
while (wh_base < wH && ih_loop % in_dilation[0] != 0) {
|
while (wh_base < wH && ih_loop % in_dilation[0] != 0) {
|
||||||
wh_base++;
|
wh_base++;
|
||||||
ih_loop += jump_h;
|
ih_loop += jump_h;
|
||||||
}
|
}
|
||||||
|
|
||||||
base_h[i] = wh_base;
|
base_h[i] = wh_base;
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int j = 0; j < f_out_jump_w; ++j) {
|
for (int j = 0; j < f_out_jump_w; ++j) {
|
||||||
int iw_loop = j * wt_strides[1] - padding[1] + init_w;
|
int iw_loop = j * wt_strides[1] - padding_lo[1] + init_w;
|
||||||
|
|
||||||
int ww_base = 0;
|
int ww_base = 0;
|
||||||
while (ww_base < wW && iw_loop % in_dilation[1] != 0) {
|
while (ww_base < wW && iw_loop % in_dilation[1] != 0) {
|
||||||
ww_base++;
|
ww_base++;
|
||||||
iw_loop += jump_w;
|
iw_loop += jump_w;
|
||||||
}
|
}
|
||||||
|
|
||||||
base_w[j] = ww_base;
|
base_w[j] = ww_base;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto pt_conv_all_checks =
|
auto pt_conv_all_checks =
|
||||||
[&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int oh, int ow) {
|
[&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int oh, int ow) {
|
||||||
out_ptr += oh * out_stride_H + ow * out_stride_W;
|
out_ptr += oh * out_stride_H + ow * out_stride_W;
|
||||||
|
|
||||||
int ih_base = oh * wt_strides[0] - padding[0];
|
int ih_base = oh * wt_strides[0] - padding_lo[0];
|
||||||
int iw_base = ow * wt_strides[1] - padding[1];
|
int iw_base = ow * wt_strides[1] - padding_lo[1];
|
||||||
|
|
||||||
int wh_base = base_h[oh % f_out_jump_h];
|
int wh_base = base_h[oh % f_out_jump_h];
|
||||||
int ww_base = base_w[ow % f_out_jump_w];
|
int ww_base = base_w[ow % f_out_jump_w];
|
||||||
|
|
||||||
for (int g = 0; g < groups; ++g) {
|
for (int g = 0; g < groups; ++g) {
|
||||||
for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) {
|
for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) {
|
||||||
float r = 0.;
|
float r = 0.;
|
||||||
|
|
||||||
for (int wh = wh_base; wh < wH; wh += f_wgt_jump_h) {
|
for (int wh = wh_base; wh < wH; wh += f_wgt_jump_h) {
|
||||||
for (int ww = ww_base; ww < wW; ww += f_wgt_jump_w) {
|
for (int ww = ww_base; ww < wW; ww += f_wgt_jump_w) {
|
||||||
int wh_flip = flip ? wH - wh - 1 : wh;
|
int wh_flip = flip ? wH - wh - 1 : wh;
|
||||||
int ww_flip = flip ? wW - ww - 1 : ww;
|
int ww_flip = flip ? wW - ww - 1 : ww;
|
||||||
int ih = ih_base + wh_flip * wt_dilation[0];
|
int ih = ih_base + wh_flip * wt_dilation[0];
|
||||||
int iw = iw_base + ww_flip * wt_dilation[1];
|
int iw = iw_base + ww_flip * wt_dilation[1];
|
||||||
|
|
||||||
if (ih >= 0 && ih < iH && iw >= 0 && iw < iW) {
|
if (ih >= 0 && ih < iH && iw >= 0 && iw < iW) {
|
||||||
const T* wt_ptr_pt =
|
const T* wt_ptr_pt =
|
||||||
wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
|
wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
|
||||||
|
|
||||||
int ih_dil = !is_idil_one ? (ih / in_dilation[0]) : ih;
|
int ih_dil = !is_idil_one ? (ih / in_dilation[0]) : ih;
|
||||||
int iw_dil = !is_idil_one ? (iw / in_dilation[1]) : iw;
|
int iw_dil = !is_idil_one ? (iw / in_dilation[1]) : iw;
|
||||||
|
|
||||||
const T* in_ptr_pt =
|
const T* in_ptr_pt = in_ptr + ih_dil * in_stride_H +
|
||||||
in_ptr + ih_dil * in_stride_H + iw_dil * in_stride_W;
|
iw_dil * in_stride_W;
|
||||||
|
|
||||||
for (int c = g * C_per_group; c < (g + 1) * C_per_group;
|
for (int c = g * C_per_group; c < (g + 1) * C_per_group;
|
||||||
++c) {
|
++c) {
|
||||||
r += static_cast<float>(in_ptr_pt[c * in_stride_C]) *
|
r += static_cast<float>(in_ptr_pt[c * in_stride_C]) *
|
||||||
static_cast<float>(
|
static_cast<float>(
|
||||||
wt_ptr_pt[(c % C_per_group) * wt_stride_C]);
|
wt_ptr_pt[(c % C_per_group) * wt_stride_C]);
|
||||||
} // c
|
} // c
|
||||||
|
|
||||||
} // ih, iw check
|
} // ih, iw check
|
||||||
} // ww
|
} // ww
|
||||||
} // wh
|
} // wh
|
||||||
|
|
||||||
out_ptr[0] = static_cast<T>(r);
|
out_ptr[0] = static_cast<T>(r);
|
||||||
out_ptr += out_stride_O;
|
out_ptr += out_stride_O;
|
||||||
wt_ptr += wt_stride_O;
|
wt_ptr += wt_stride_O;
|
||||||
} // o
|
} // o
|
||||||
} // g
|
} // g
|
||||||
};
|
};
|
||||||
|
|
||||||
int oH_border_0 = 0;
|
int oH_border_0 = 0;
|
||||||
int oH_border_1 =
|
int oH_border_1 = is_idil_one
|
||||||
is_idil_one ? ((padding[0] + wt_strides[0] - 1) / wt_strides[0]) : oH;
|
? ((padding_lo[0] + wt_strides[0] - 1) / wt_strides[0])
|
||||||
int oH_border_2 = std::max(
|
: oH;
|
||||||
oH_border_1, (iH + padding[0] - wH * wt_dilation[0]) / wt_strides[0]);
|
int oH_border_2 = std::max(
|
||||||
int oH_border_3 = oH;
|
oH_border_1,
|
||||||
|
(iH + padding_lo[0] - wH * wt_dilation[0]) / wt_strides[0]);
|
||||||
|
int oH_border_3 = oH;
|
||||||
|
|
||||||
int oW_border_0 = 0;
|
int oW_border_0 = 0;
|
||||||
int oW_border_1 =
|
int oW_border_1 = is_idil_one
|
||||||
is_idil_one ? ((padding[1] + wt_strides[1] - 1) / wt_strides[1]) : oW;
|
? ((padding_lo[1] + wt_strides[1] - 1) / wt_strides[1])
|
||||||
int oW_border_2 = std::max(
|
: oW;
|
||||||
oW_border_1, (iW + padding[1] - wW * wt_dilation[1]) / wt_strides[1]);
|
int oW_border_2 = std::max(
|
||||||
int oW_border_3 = oW;
|
oW_border_1,
|
||||||
|
(iW + padding_lo[1] - wW * wt_dilation[1]) / wt_strides[1]);
|
||||||
|
int oW_border_3 = oW;
|
||||||
|
|
||||||
for (int n = 0; n < N; ++n) {
|
for (int n = 0; n < N; ++n) {
|
||||||
// Case 1: oh might put us out of bounds
|
// Case 1: oh might put us out of bounds
|
||||||
for (int oh = oH_border_0; oh < oH_border_1; ++oh) {
|
for (int oh = oH_border_0; oh < oH_border_1; ++oh) {
|
||||||
for (int ow = 0; ow < oW; ++ow) {
|
for (int ow = 0; ow < oW; ++ow) {
|
||||||
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
|
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
|
||||||
} // ow
|
} // ow
|
||||||
} // oh
|
} // oh
|
||||||
|
|
||||||
// Case 2: oh in bounds
|
// Case 2: oh in bounds
|
||||||
for (int oh = oH_border_1; oh < oH_border_2; ++oh) {
|
for (int oh = oH_border_1; oh < oH_border_2; ++oh) {
|
||||||
// Case a: ow might put us out of bounds
|
// Case a: ow might put us out of bounds
|
||||||
for (int ow = oW_border_0; ow < oW_border_1; ++ow) {
|
for (int ow = oW_border_0; ow < oW_border_1; ++ow) {
|
||||||
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
|
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
|
||||||
} // ow
|
} // ow
|
||||||
|
|
||||||
// Case b: ow in bounds
|
// Case b: ow in bounds
|
||||||
for (int ow = oW_border_1; ow < oW_border_2; ++ow) {
|
for (int ow = oW_border_1; ow < oW_border_2; ++ow) {
|
||||||
pt_conv_no_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
|
pt_conv_no_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
|
||||||
} // ow
|
} // ow
|
||||||
|
|
||||||
// Case c: ow might put us out of bounds
|
// Case c: ow might put us out of bounds
|
||||||
for (int ow = oW_border_2; ow < oW_border_3; ++ow) {
|
for (int ow = oW_border_2; ow < oW_border_3; ++ow) {
|
||||||
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
|
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
|
||||||
} // ow
|
} // ow
|
||||||
|
|
||||||
} // oh
|
} // oh
|
||||||
|
|
||||||
// Case 3: oh might put us out of bounds
|
// Case 3: oh might put us out of bounds
|
||||||
for (int oh = oH_border_2; oh < oH_border_3; ++oh) {
|
for (int oh = oH_border_2; oh < oH_border_3; ++oh) {
|
||||||
for (int ow = 0; ow < oW; ++ow) {
|
for (int ow = 0; ow < oW; ++ow) {
|
||||||
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
|
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
|
||||||
} // ow
|
} // ow
|
||||||
} // oh
|
} // oh
|
||||||
|
|
||||||
st_in_ptr += in_stride_N;
|
st_in_ptr += in_stride_N;
|
||||||
st_out_ptr += out_stride_N;
|
st_out_ptr += out_stride_N;
|
||||||
|
|
||||||
} // n
|
} // n
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
@ -351,7 +359,8 @@ void slow_conv_3D(
|
|||||||
const array& in,
|
const array& in,
|
||||||
const array& wt,
|
const array& wt,
|
||||||
array out,
|
array out,
|
||||||
const std::vector<int>& padding,
|
const std::vector<int>& padding_lo,
|
||||||
|
const std::vector<int>& padding_hi,
|
||||||
const std::vector<int>& wt_strides,
|
const std::vector<int>& wt_strides,
|
||||||
const std::vector<int>& wt_dilation,
|
const std::vector<int>& wt_dilation,
|
||||||
const std::vector<int>& in_dilation,
|
const std::vector<int>& in_dilation,
|
||||||
@ -400,7 +409,8 @@ void slow_conv_3D(
|
|||||||
out_stride_H = out.strides()[2],
|
out_stride_H = out.strides()[2],
|
||||||
out_stride_W = out.strides()[3],
|
out_stride_W = out.strides()[3],
|
||||||
out_stride_O = out.strides()[4],
|
out_stride_O = out.strides()[4],
|
||||||
padding,
|
padding_lo,
|
||||||
|
padding_hi,
|
||||||
wt_strides,
|
wt_strides,
|
||||||
wt_dilation,
|
wt_dilation,
|
||||||
in_dilation,
|
in_dilation,
|
||||||
@ -415,9 +425,9 @@ void slow_conv_3D(
|
|||||||
int oh,
|
int oh,
|
||||||
int ow) {
|
int ow) {
|
||||||
out_ptr += od * out_stride_D + oh * out_stride_H + ow * out_stride_W;
|
out_ptr += od * out_stride_D + oh * out_stride_H + ow * out_stride_W;
|
||||||
int id_base = od * wt_strides[0] - padding[0];
|
int id_base = od * wt_strides[0] - padding_lo[0];
|
||||||
int ih_base = oh * wt_strides[1] - padding[1];
|
int ih_base = oh * wt_strides[1] - padding_lo[1];
|
||||||
int iw_base = ow * wt_strides[2] - padding[2];
|
int iw_base = ow * wt_strides[2] - padding_lo[2];
|
||||||
|
|
||||||
for (int o = 0; o < O; ++o) {
|
for (int o = 0; o < O; ++o) {
|
||||||
float r = 0.;
|
float r = 0.;
|
||||||
@ -478,7 +488,7 @@ void slow_conv_3D(
|
|||||||
std::vector<int> base_w(f_out_jump_w);
|
std::vector<int> base_w(f_out_jump_w);
|
||||||
|
|
||||||
for (int i = 0; i < f_out_jump_d; ++i) {
|
for (int i = 0; i < f_out_jump_d; ++i) {
|
||||||
int id_loop = i * wt_strides[0] - padding[0] + init_d;
|
int id_loop = i * wt_strides[0] - padding_lo[0] + init_d;
|
||||||
|
|
||||||
int wd_base = 0;
|
int wd_base = 0;
|
||||||
while (wd_base < wD && id_loop % in_dilation[0] != 0) {
|
while (wd_base < wD && id_loop % in_dilation[0] != 0) {
|
||||||
@ -490,7 +500,7 @@ void slow_conv_3D(
|
|||||||
}
|
}
|
||||||
|
|
||||||
for (int i = 0; i < f_out_jump_h; ++i) {
|
for (int i = 0; i < f_out_jump_h; ++i) {
|
||||||
int ih_loop = i * wt_strides[1] - padding[1] + init_h;
|
int ih_loop = i * wt_strides[1] - padding_lo[1] + init_h;
|
||||||
|
|
||||||
int wh_base = 0;
|
int wh_base = 0;
|
||||||
while (wh_base < wH && ih_loop % in_dilation[1] != 0) {
|
while (wh_base < wH && ih_loop % in_dilation[1] != 0) {
|
||||||
@ -502,7 +512,7 @@ void slow_conv_3D(
|
|||||||
}
|
}
|
||||||
|
|
||||||
for (int j = 0; j < f_out_jump_w; ++j) {
|
for (int j = 0; j < f_out_jump_w; ++j) {
|
||||||
int iw_loop = j * wt_strides[2] - padding[2] + init_w;
|
int iw_loop = j * wt_strides[2] - padding_lo[2] + init_w;
|
||||||
|
|
||||||
int ww_base = 0;
|
int ww_base = 0;
|
||||||
while (ww_base < wW && iw_loop % in_dilation[2] != 0) {
|
while (ww_base < wW && iw_loop % in_dilation[2] != 0) {
|
||||||
@ -521,9 +531,9 @@ void slow_conv_3D(
|
|||||||
int ow) {
|
int ow) {
|
||||||
out_ptr += od * out_stride_D + oh * out_stride_H + ow * out_stride_W;
|
out_ptr += od * out_stride_D + oh * out_stride_H + ow * out_stride_W;
|
||||||
|
|
||||||
int id_base = od * wt_strides[0] - padding[0];
|
int id_base = od * wt_strides[0] - padding_lo[0];
|
||||||
int ih_base = oh * wt_strides[1] - padding[1];
|
int ih_base = oh * wt_strides[1] - padding_lo[1];
|
||||||
int iw_base = ow * wt_strides[2] - padding[2];
|
int iw_base = ow * wt_strides[2] - padding_lo[2];
|
||||||
|
|
||||||
int wd_base = base_d[od % f_out_jump_d];
|
int wd_base = base_d[od % f_out_jump_d];
|
||||||
int wh_base = base_h[oh % f_out_jump_h];
|
int wh_base = base_h[oh % f_out_jump_h];
|
||||||
@ -573,24 +583,30 @@ void slow_conv_3D(
|
|||||||
};
|
};
|
||||||
|
|
||||||
int oD_border_0 = 0;
|
int oD_border_0 = 0;
|
||||||
int oD_border_1 =
|
int oD_border_1 = is_idil_one
|
||||||
is_idil_one ? ((padding[0] + wt_strides[0] - 1) / wt_strides[0]) : oD;
|
? ((padding_lo[0] + wt_strides[0] - 1) / wt_strides[0])
|
||||||
|
: oD;
|
||||||
int oD_border_2 = std::max(
|
int oD_border_2 = std::max(
|
||||||
oD_border_1, (iD + padding[0] - wD * wt_dilation[0]) / wt_strides[0]);
|
oD_border_1,
|
||||||
|
(iD + padding_lo[0] - wD * wt_dilation[0]) / wt_strides[0]);
|
||||||
int oD_border_3 = oD;
|
int oD_border_3 = oD;
|
||||||
|
|
||||||
int oH_border_0 = 0;
|
int oH_border_0 = 0;
|
||||||
int oH_border_1 =
|
int oH_border_1 = is_idil_one
|
||||||
is_idil_one ? ((padding[1] + wt_strides[1] - 1) / wt_strides[1]) : oH;
|
? ((padding_lo[1] + wt_strides[1] - 1) / wt_strides[1])
|
||||||
|
: oH;
|
||||||
int oH_border_2 = std::max(
|
int oH_border_2 = std::max(
|
||||||
oH_border_1, (iH + padding[1] - wH * wt_dilation[1]) / wt_strides[1]);
|
oH_border_1,
|
||||||
|
(iH + padding_lo[1] - wH * wt_dilation[1]) / wt_strides[1]);
|
||||||
int oH_border_3 = oH;
|
int oH_border_3 = oH;
|
||||||
|
|
||||||
int oW_border_0 = 0;
|
int oW_border_0 = 0;
|
||||||
int oW_border_1 =
|
int oW_border_1 = is_idil_one
|
||||||
is_idil_one ? ((padding[2] + wt_strides[2] - 1) / wt_strides[2]) : oW;
|
? ((padding_lo[2] + wt_strides[2] - 1) / wt_strides[2])
|
||||||
|
: oW;
|
||||||
int oW_border_2 = std::max(
|
int oW_border_2 = std::max(
|
||||||
oW_border_1, (iW + padding[2] - wW * wt_dilation[2]) / wt_strides[2]);
|
oW_border_1,
|
||||||
|
(iW + padding_lo[2] - wW * wt_dilation[2]) / wt_strides[2]);
|
||||||
int oW_border_3 = oW;
|
int oW_border_3 = oW;
|
||||||
|
|
||||||
for (int n = 0; n < N; ++n) {
|
for (int n = 0; n < N; ++n) {
|
||||||
@ -658,7 +674,8 @@ void dispatch_slow_conv_1D(
|
|||||||
const array& in,
|
const array& in,
|
||||||
const array& wt,
|
const array& wt,
|
||||||
array out,
|
array out,
|
||||||
const std::vector<int>& padding,
|
const std::vector<int>& padding_lo,
|
||||||
|
const std::vector<int>& padding_hi,
|
||||||
const std::vector<int>& wt_strides,
|
const std::vector<int>& wt_strides,
|
||||||
const std::vector<int>& wt_dilation,
|
const std::vector<int>& wt_dilation,
|
||||||
const std::vector<int>& in_dilation,
|
const std::vector<int>& in_dilation,
|
||||||
@ -669,7 +686,8 @@ void dispatch_slow_conv_1D(
|
|||||||
in,
|
in,
|
||||||
wt,
|
wt,
|
||||||
out,
|
out,
|
||||||
padding,
|
padding_lo,
|
||||||
|
padding_hi,
|
||||||
wt_strides,
|
wt_strides,
|
||||||
wt_dilation,
|
wt_dilation,
|
||||||
in_dilation,
|
in_dilation,
|
||||||
@ -680,7 +698,8 @@ void dispatch_slow_conv_1D(
|
|||||||
in,
|
in,
|
||||||
wt,
|
wt,
|
||||||
out,
|
out,
|
||||||
padding,
|
padding_lo,
|
||||||
|
padding_hi,
|
||||||
wt_strides,
|
wt_strides,
|
||||||
wt_dilation,
|
wt_dilation,
|
||||||
in_dilation,
|
in_dilation,
|
||||||
@ -691,7 +710,8 @@ void dispatch_slow_conv_1D(
|
|||||||
in,
|
in,
|
||||||
wt,
|
wt,
|
||||||
out,
|
out,
|
||||||
padding,
|
padding_lo,
|
||||||
|
padding_hi,
|
||||||
wt_strides,
|
wt_strides,
|
||||||
wt_dilation,
|
wt_dilation,
|
||||||
in_dilation,
|
in_dilation,
|
||||||
@ -707,7 +727,8 @@ void dispatch_slow_conv_2D(
|
|||||||
const array& in,
|
const array& in,
|
||||||
const array& wt,
|
const array& wt,
|
||||||
array out,
|
array out,
|
||||||
const std::vector<int>& padding,
|
const std::vector<int>& padding_lo,
|
||||||
|
const std::vector<int>& padding_hi,
|
||||||
const std::vector<int>& wt_strides,
|
const std::vector<int>& wt_strides,
|
||||||
const std::vector<int>& wt_dilation,
|
const std::vector<int>& wt_dilation,
|
||||||
const std::vector<int>& in_dilation,
|
const std::vector<int>& in_dilation,
|
||||||
@ -718,7 +739,8 @@ void dispatch_slow_conv_2D(
|
|||||||
in,
|
in,
|
||||||
wt,
|
wt,
|
||||||
out,
|
out,
|
||||||
padding,
|
padding_lo,
|
||||||
|
padding_hi,
|
||||||
wt_strides,
|
wt_strides,
|
||||||
wt_dilation,
|
wt_dilation,
|
||||||
in_dilation,
|
in_dilation,
|
||||||
@ -729,7 +751,8 @@ void dispatch_slow_conv_2D(
|
|||||||
in,
|
in,
|
||||||
wt,
|
wt,
|
||||||
out,
|
out,
|
||||||
padding,
|
padding_lo,
|
||||||
|
padding_hi,
|
||||||
wt_strides,
|
wt_strides,
|
||||||
wt_dilation,
|
wt_dilation,
|
||||||
in_dilation,
|
in_dilation,
|
||||||
@ -740,7 +763,8 @@ void dispatch_slow_conv_2D(
|
|||||||
in,
|
in,
|
||||||
wt,
|
wt,
|
||||||
out,
|
out,
|
||||||
padding,
|
padding_lo,
|
||||||
|
padding_hi,
|
||||||
wt_strides,
|
wt_strides,
|
||||||
wt_dilation,
|
wt_dilation,
|
||||||
in_dilation,
|
in_dilation,
|
||||||
@ -756,7 +780,8 @@ void dispatch_slow_conv_3D(
|
|||||||
const array& in,
|
const array& in,
|
||||||
const array& wt,
|
const array& wt,
|
||||||
array out,
|
array out,
|
||||||
const std::vector<int>& padding,
|
const std::vector<int>& padding_lo,
|
||||||
|
const std::vector<int>& padding_hi,
|
||||||
const std::vector<int>& wt_strides,
|
const std::vector<int>& wt_strides,
|
||||||
const std::vector<int>& wt_dilation,
|
const std::vector<int>& wt_dilation,
|
||||||
const std::vector<int>& in_dilation,
|
const std::vector<int>& in_dilation,
|
||||||
@ -767,7 +792,8 @@ void dispatch_slow_conv_3D(
|
|||||||
in,
|
in,
|
||||||
wt,
|
wt,
|
||||||
out,
|
out,
|
||||||
padding,
|
padding_lo,
|
||||||
|
padding_hi,
|
||||||
wt_strides,
|
wt_strides,
|
||||||
wt_dilation,
|
wt_dilation,
|
||||||
in_dilation,
|
in_dilation,
|
||||||
@ -778,7 +804,8 @@ void dispatch_slow_conv_3D(
|
|||||||
in,
|
in,
|
||||||
wt,
|
wt,
|
||||||
out,
|
out,
|
||||||
padding,
|
padding_lo,
|
||||||
|
padding_hi,
|
||||||
wt_strides,
|
wt_strides,
|
||||||
wt_dilation,
|
wt_dilation,
|
||||||
in_dilation,
|
in_dilation,
|
||||||
@ -789,7 +816,8 @@ void dispatch_slow_conv_3D(
|
|||||||
in,
|
in,
|
||||||
wt,
|
wt,
|
||||||
out,
|
out,
|
||||||
padding,
|
padding_lo,
|
||||||
|
padding_hi,
|
||||||
wt_strides,
|
wt_strides,
|
||||||
wt_dilation,
|
wt_dilation,
|
||||||
in_dilation,
|
in_dilation,
|
||||||
@ -829,7 +857,8 @@ void explicit_gemm_conv_1D_cpu(
|
|||||||
const array& in,
|
const array& in,
|
||||||
const array& wt,
|
const array& wt,
|
||||||
array out,
|
array out,
|
||||||
const std::vector<int>& padding,
|
const std::vector<int>& padding_lo,
|
||||||
|
const std::vector<int>& padding_hi,
|
||||||
const std::vector<int>& wt_strides,
|
const std::vector<int>& wt_strides,
|
||||||
const std::vector<int>& wt_dilation,
|
const std::vector<int>& wt_dilation,
|
||||||
Stream stream) {
|
Stream stream) {
|
||||||
@ -848,7 +877,7 @@ void explicit_gemm_conv_1D_cpu(
|
|||||||
auto& encoder = cpu::get_command_encoder(stream);
|
auto& encoder = cpu::get_command_encoder(stream);
|
||||||
|
|
||||||
// Pad input
|
// Pad input
|
||||||
Shape padded_shape = {N, iH + 2 * padding[0], C};
|
Shape padded_shape = {N, iH + padding_lo[0] + padding_hi[0], C};
|
||||||
array in_padded(padded_shape, conv_dtype, nullptr, {});
|
array in_padded(padded_shape, conv_dtype, nullptr, {});
|
||||||
|
|
||||||
// Fill with zeros
|
// Fill with zeros
|
||||||
@ -857,7 +886,7 @@ void explicit_gemm_conv_1D_cpu(
|
|||||||
copy(temps.back(), in_padded, CopyType::Scalar, stream);
|
copy(temps.back(), in_padded, CopyType::Scalar, stream);
|
||||||
|
|
||||||
// Pick input slice from padded
|
// Pick input slice from padded
|
||||||
size_t data_offset = padding[0] * in_padded.strides()[1];
|
size_t data_offset = padding_lo[0] * in_padded.strides()[1];
|
||||||
array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
|
array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
|
||||||
in_padded_slice.copy_shared_buffer(
|
in_padded_slice.copy_shared_buffer(
|
||||||
in_padded,
|
in_padded,
|
||||||
@ -971,7 +1000,8 @@ void explicit_gemm_conv_2D_cpu(
|
|||||||
const array& in,
|
const array& in,
|
||||||
const array& wt,
|
const array& wt,
|
||||||
array out,
|
array out,
|
||||||
const std::vector<int>& padding,
|
const std::vector<int>& padding_lo,
|
||||||
|
const std::vector<int>& padding_hi,
|
||||||
const std::vector<int>& wt_strides,
|
const std::vector<int>& wt_strides,
|
||||||
const std::vector<int>& wt_dilation,
|
const std::vector<int>& wt_dilation,
|
||||||
Stream stream) {
|
Stream stream) {
|
||||||
@ -989,7 +1019,11 @@ void explicit_gemm_conv_2D_cpu(
|
|||||||
auto& encoder = cpu::get_command_encoder(stream);
|
auto& encoder = cpu::get_command_encoder(stream);
|
||||||
|
|
||||||
// Pad input
|
// Pad input
|
||||||
Shape padded_shape = {N, iH + 2 * padding[0], iW + 2 * padding[1], C};
|
Shape padded_shape = {
|
||||||
|
N,
|
||||||
|
iH + padding_lo[0] + padding_hi[0],
|
||||||
|
iW + padding_lo[1] + padding_hi[1],
|
||||||
|
C};
|
||||||
array in_padded(padded_shape, conv_dtype, nullptr, {});
|
array in_padded(padded_shape, conv_dtype, nullptr, {});
|
||||||
|
|
||||||
// Fill with zeros
|
// Fill with zeros
|
||||||
@ -998,8 +1032,8 @@ void explicit_gemm_conv_2D_cpu(
|
|||||||
copy(temps.back(), in_padded, CopyType::Scalar, stream);
|
copy(temps.back(), in_padded, CopyType::Scalar, stream);
|
||||||
|
|
||||||
// Pick input slice from padded
|
// Pick input slice from padded
|
||||||
size_t data_offset =
|
size_t data_offset = padding_lo[0] * in_padded.strides()[1] +
|
||||||
padding[0] * in_padded.strides()[1] + padding[1] * in_padded.strides()[2];
|
padding_lo[1] * in_padded.strides()[2];
|
||||||
array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
|
array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
|
||||||
in_padded_slice.copy_shared_buffer(
|
in_padded_slice.copy_shared_buffer(
|
||||||
in_padded,
|
in_padded,
|
||||||
@ -1091,7 +1125,8 @@ void explicit_gemm_conv_ND_cpu(
|
|||||||
const array& in,
|
const array& in,
|
||||||
const array& wt,
|
const array& wt,
|
||||||
array out,
|
array out,
|
||||||
const std::vector<int>& padding,
|
const std::vector<int>& padding_lo,
|
||||||
|
const std::vector<int>& padding_hi,
|
||||||
const std::vector<int>& wt_strides,
|
const std::vector<int>& wt_strides,
|
||||||
const std::vector<int>& wt_dilation,
|
const std::vector<int>& wt_dilation,
|
||||||
const bool flip,
|
const bool flip,
|
||||||
@ -1114,7 +1149,7 @@ void explicit_gemm_conv_ND_cpu(
|
|||||||
Shape padded_shape(in.shape().size());
|
Shape padded_shape(in.shape().size());
|
||||||
padded_shape.front() = N;
|
padded_shape.front() = N;
|
||||||
for (size_t i = 0; i < iDim.size(); i++) {
|
for (size_t i = 0; i < iDim.size(); i++) {
|
||||||
padded_shape[i + 1] = iDim[i] + 2 * padding[i];
|
padded_shape[i + 1] = iDim[i] + padding_lo[i] + padding_hi[i];
|
||||||
}
|
}
|
||||||
padded_shape.back() = C;
|
padded_shape.back() = C;
|
||||||
array in_padded(padded_shape, conv_dtype, nullptr, {});
|
array in_padded(padded_shape, conv_dtype, nullptr, {});
|
||||||
@ -1125,9 +1160,10 @@ void explicit_gemm_conv_ND_cpu(
|
|||||||
|
|
||||||
// Pick input slice from padded
|
// Pick input slice from padded
|
||||||
size_t data_offset = 0;
|
size_t data_offset = 0;
|
||||||
for (size_t i = 0; i < padding.size(); i++) {
|
for (size_t i = 0; i < padding_lo.size(); i++) {
|
||||||
data_offset += padding[i] * in_padded.strides()[i + 1];
|
data_offset += padding_lo[i] * in_padded.strides()[i + 1];
|
||||||
}
|
}
|
||||||
|
|
||||||
array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
|
array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
|
||||||
in_padded_slice.copy_shared_buffer(
|
in_padded_slice.copy_shared_buffer(
|
||||||
in_padded,
|
in_padded,
|
||||||
@ -1261,7 +1297,8 @@ void conv_1D_cpu(
|
|||||||
const array& in,
|
const array& in,
|
||||||
const array& wt,
|
const array& wt,
|
||||||
array out,
|
array out,
|
||||||
const std::vector<int>& padding,
|
const std::vector<int>& padding_lo,
|
||||||
|
const std::vector<int>& padding_hi,
|
||||||
const std::vector<int>& wt_strides,
|
const std::vector<int>& wt_strides,
|
||||||
const std::vector<int>& wt_dilation,
|
const std::vector<int>& wt_dilation,
|
||||||
const std::vector<int>& in_dilation,
|
const std::vector<int>& in_dilation,
|
||||||
@ -1270,22 +1307,40 @@ void conv_1D_cpu(
|
|||||||
const int groups = in.shape().back() / wt.shape().back();
|
const int groups = in.shape().back() / wt.shape().back();
|
||||||
if (wt_dilation[0] == 1 && in_dilation[0] == 1 && !flip) {
|
if (wt_dilation[0] == 1 && in_dilation[0] == 1 && !flip) {
|
||||||
return explicit_gemm_conv_1D_cpu(
|
return explicit_gemm_conv_1D_cpu(
|
||||||
in, wt, out, padding, wt_strides, wt_dilation, stream);
|
in, wt, out, padding_lo, padding_hi, wt_strides, wt_dilation, stream);
|
||||||
}
|
}
|
||||||
if (wt_dilation[0] == 1 && in_dilation[0] == 1 && groups == 1) {
|
if (wt_dilation[0] == 1 && in_dilation[0] == 1 && groups == 1) {
|
||||||
return explicit_gemm_conv_ND_cpu(
|
return explicit_gemm_conv_ND_cpu(
|
||||||
in, wt, out, padding, wt_strides, wt_dilation, flip, stream);
|
in,
|
||||||
|
wt,
|
||||||
|
out,
|
||||||
|
padding_lo,
|
||||||
|
padding_hi,
|
||||||
|
wt_strides,
|
||||||
|
wt_dilation,
|
||||||
|
flip,
|
||||||
|
stream);
|
||||||
}
|
}
|
||||||
|
|
||||||
return dispatch_slow_conv_1D(
|
return dispatch_slow_conv_1D(
|
||||||
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip, stream);
|
in,
|
||||||
|
wt,
|
||||||
|
out,
|
||||||
|
padding_lo,
|
||||||
|
padding_hi,
|
||||||
|
wt_strides,
|
||||||
|
wt_dilation,
|
||||||
|
in_dilation,
|
||||||
|
flip,
|
||||||
|
stream);
|
||||||
}
|
}
|
||||||
|
|
||||||
void conv_2D_cpu(
|
void conv_2D_cpu(
|
||||||
const array& in,
|
const array& in,
|
||||||
const array& wt,
|
const array& wt,
|
||||||
array out,
|
array out,
|
||||||
const std::vector<int>& padding,
|
const std::vector<int>& padding_lo,
|
||||||
|
const std::vector<int>& padding_hi,
|
||||||
const std::vector<int>& wt_strides,
|
const std::vector<int>& wt_strides,
|
||||||
const std::vector<int>& wt_dilation,
|
const std::vector<int>& wt_dilation,
|
||||||
const std::vector<int>& in_dilation,
|
const std::vector<int>& in_dilation,
|
||||||
@ -1295,18 +1350,35 @@ void conv_2D_cpu(
|
|||||||
if (wt_dilation[0] == 1 && wt_dilation[1] == 1 && in_dilation[0] == 1 &&
|
if (wt_dilation[0] == 1 && wt_dilation[1] == 1 && in_dilation[0] == 1 &&
|
||||||
in_dilation[1] == 1 && groups == 1) {
|
in_dilation[1] == 1 && groups == 1) {
|
||||||
return explicit_gemm_conv_ND_cpu(
|
return explicit_gemm_conv_ND_cpu(
|
||||||
in, wt, out, padding, wt_strides, wt_dilation, flip, stream);
|
in,
|
||||||
|
wt,
|
||||||
|
out,
|
||||||
|
padding_lo,
|
||||||
|
padding_hi,
|
||||||
|
wt_strides,
|
||||||
|
wt_dilation,
|
||||||
|
flip,
|
||||||
|
stream);
|
||||||
}
|
}
|
||||||
|
|
||||||
return dispatch_slow_conv_2D(
|
return dispatch_slow_conv_2D(
|
||||||
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip, stream);
|
in,
|
||||||
|
wt,
|
||||||
|
out,
|
||||||
|
padding_lo,
|
||||||
|
padding_hi,
|
||||||
|
wt_strides,
|
||||||
|
wt_dilation,
|
||||||
|
in_dilation,
|
||||||
|
flip,
|
||||||
|
stream);
|
||||||
}
|
}
|
||||||
|
|
||||||
void conv_3D_cpu(
|
void conv_3D_cpu(
|
||||||
const array& in,
|
const array& in,
|
||||||
const array& wt,
|
const array& wt,
|
||||||
array out,
|
array out,
|
||||||
const std::vector<int>& padding,
|
const std::vector<int>& padding_lo,
|
||||||
|
const std::vector<int>& padding_hi,
|
||||||
const std::vector<int>& wt_strides,
|
const std::vector<int>& wt_strides,
|
||||||
const std::vector<int>& wt_dilation,
|
const std::vector<int>& wt_dilation,
|
||||||
const std::vector<int>& in_dilation,
|
const std::vector<int>& in_dilation,
|
||||||
@ -1317,11 +1389,28 @@ void conv_3D_cpu(
|
|||||||
in_dilation[0] == 1 && in_dilation[1] == 1 && in_dilation[2] == 1 &&
|
in_dilation[0] == 1 && in_dilation[1] == 1 && in_dilation[2] == 1 &&
|
||||||
groups == 1) {
|
groups == 1) {
|
||||||
return explicit_gemm_conv_ND_cpu(
|
return explicit_gemm_conv_ND_cpu(
|
||||||
in, wt, out, padding, wt_strides, wt_dilation, flip, stream);
|
in,
|
||||||
|
wt,
|
||||||
|
out,
|
||||||
|
padding_lo,
|
||||||
|
padding_hi,
|
||||||
|
wt_strides,
|
||||||
|
wt_dilation,
|
||||||
|
flip,
|
||||||
|
stream);
|
||||||
}
|
}
|
||||||
|
|
||||||
return dispatch_slow_conv_3D(
|
return dispatch_slow_conv_3D(
|
||||||
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip, stream);
|
in,
|
||||||
|
wt,
|
||||||
|
out,
|
||||||
|
padding_lo,
|
||||||
|
padding_hi,
|
||||||
|
wt_strides,
|
||||||
|
wt_dilation,
|
||||||
|
in_dilation,
|
||||||
|
flip,
|
||||||
|
stream);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
@ -1338,7 +1427,8 @@ void Convolution::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
in,
|
in,
|
||||||
wt,
|
wt,
|
||||||
out,
|
out,
|
||||||
padding_,
|
padding_lo_,
|
||||||
|
padding_hi_,
|
||||||
kernel_strides_,
|
kernel_strides_,
|
||||||
kernel_dilation_,
|
kernel_dilation_,
|
||||||
input_dilation_,
|
input_dilation_,
|
||||||
@ -1351,7 +1441,8 @@ void Convolution::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
in,
|
in,
|
||||||
wt,
|
wt,
|
||||||
out,
|
out,
|
||||||
padding_,
|
padding_lo_,
|
||||||
|
padding_hi_,
|
||||||
kernel_strides_,
|
kernel_strides_,
|
||||||
kernel_dilation_,
|
kernel_dilation_,
|
||||||
input_dilation_,
|
input_dilation_,
|
||||||
@ -1364,7 +1455,8 @@ void Convolution::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
in,
|
in,
|
||||||
wt,
|
wt,
|
||||||
out,
|
out,
|
||||||
padding_,
|
padding_lo_,
|
||||||
|
padding_hi_,
|
||||||
kernel_strides_,
|
kernel_strides_,
|
||||||
kernel_dilation_,
|
kernel_dilation_,
|
||||||
input_dilation_,
|
input_dilation_,
|
||||||
|
@ -952,7 +952,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
in,
|
in,
|
||||||
wt,
|
wt,
|
||||||
out,
|
out,
|
||||||
padding_,
|
padding_lo_,
|
||||||
kernel_strides_,
|
kernel_strides_,
|
||||||
kernel_dilation_,
|
kernel_dilation_,
|
||||||
input_dilation_,
|
input_dilation_,
|
||||||
@ -967,7 +967,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
in,
|
in,
|
||||||
wt,
|
wt,
|
||||||
out,
|
out,
|
||||||
padding_,
|
padding_lo_,
|
||||||
kernel_strides_,
|
kernel_strides_,
|
||||||
kernel_dilation_,
|
kernel_dilation_,
|
||||||
input_dilation_,
|
input_dilation_,
|
||||||
@ -983,7 +983,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
in,
|
in,
|
||||||
wt,
|
wt,
|
||||||
out,
|
out,
|
||||||
padding_,
|
padding_lo_,
|
||||||
kernel_strides_,
|
kernel_strides_,
|
||||||
kernel_dilation_,
|
kernel_dilation_,
|
||||||
input_dilation_,
|
input_dilation_,
|
||||||
|
@ -3974,6 +3974,7 @@ array conv_general(
|
|||||||
to_stream(s),
|
to_stream(s),
|
||||||
stride,
|
stride,
|
||||||
padding_lo,
|
padding_lo,
|
||||||
|
padding_hi,
|
||||||
kernel_dilation,
|
kernel_dilation,
|
||||||
input_dilation,
|
input_dilation,
|
||||||
groups,
|
groups,
|
||||||
|
@ -1055,7 +1055,8 @@ array conv_weight_backward_patches(
|
|||||||
const array& wt,
|
const array& wt,
|
||||||
const array& cotan,
|
const array& cotan,
|
||||||
const std::vector<int>& kernel_strides,
|
const std::vector<int>& kernel_strides,
|
||||||
const std::vector<int>& padding,
|
const std::vector<int>& padding_lo,
|
||||||
|
const std::vector<int>& padding_hi,
|
||||||
StreamOrDevice s) {
|
StreamOrDevice s) {
|
||||||
// Resolve Padded input shapes and strides
|
// Resolve Padded input shapes and strides
|
||||||
Shape padding_starts(in.ndim(), 0);
|
Shape padding_starts(in.ndim(), 0);
|
||||||
@ -1064,9 +1065,9 @@ array conv_weight_backward_patches(
|
|||||||
|
|
||||||
// padded shape
|
// padded shape
|
||||||
for (int i = 1; i < in.ndim() - 1; i++) {
|
for (int i = 1; i < in.ndim() - 1; i++) {
|
||||||
in_padded_shape[i] += 2 * padding[i - 1];
|
in_padded_shape[i] += padding_lo[i - 1] + padding_hi[i - 1];
|
||||||
padding_ends[i] += padding[i - 1];
|
padding_ends[i] += padding_lo[i - 1];
|
||||||
padding_starts[i] += padding[i - 1];
|
padding_starts[i] += padding_lo[i - 1];
|
||||||
}
|
}
|
||||||
|
|
||||||
// padded strides (contiguous)
|
// padded strides (contiguous)
|
||||||
@ -1078,9 +1079,16 @@ array conv_weight_backward_patches(
|
|||||||
// Pad input
|
// Pad input
|
||||||
std::vector<int> padded_axes(in.ndim() - 2, 0);
|
std::vector<int> padded_axes(in.ndim() - 2, 0);
|
||||||
std::iota(padded_axes.begin(), padded_axes.end(), 1);
|
std::iota(padded_axes.begin(), padded_axes.end(), 1);
|
||||||
Shape padding_(padding.begin(), padding.end());
|
Shape padding_lo_(padding_lo.begin(), padding_lo.end());
|
||||||
auto in_padded = pad(
|
Shape padding_hi_(padding_hi.begin(), padding_hi.end());
|
||||||
in, padded_axes, padding_, padding_, array(0, in.dtype()), "constant", s);
|
auto in_padded =
|
||||||
|
pad(in,
|
||||||
|
padded_axes,
|
||||||
|
padding_lo_,
|
||||||
|
padding_hi_,
|
||||||
|
array(0, in.dtype()),
|
||||||
|
"constant",
|
||||||
|
s);
|
||||||
|
|
||||||
// Resolve strided patches
|
// Resolve strided patches
|
||||||
|
|
||||||
@ -1147,16 +1155,16 @@ std::vector<array> Convolution::vjp(
|
|||||||
for (int a : argnums) {
|
for (int a : argnums) {
|
||||||
// Grads for input
|
// Grads for input
|
||||||
if (a == 0) {
|
if (a == 0) {
|
||||||
std::vector<int> padding_lo = padding_;
|
std::vector<int> padding_lo = padding_lo_;
|
||||||
std::vector<int> padding_hi = padding_;
|
std::vector<int> padding_hi = padding_hi_;
|
||||||
|
|
||||||
for (int i = 0; i < padding_lo.size(); ++i) {
|
for (int i = 0; i < padding_lo.size(); ++i) {
|
||||||
int wt_size = 1 + kernel_dilation_[i] * (wt.shape(1 + i) - 1);
|
int wt_size = 1 + kernel_dilation_[i] * (wt.shape(1 + i) - 1);
|
||||||
padding_lo[i] = wt_size - padding_[i] - 1;
|
padding_lo[i] = wt_size - padding_lo_[i] - 1;
|
||||||
|
|
||||||
int in_size = 1 + input_dilation_[i] * (in.shape(1 + i) - 1);
|
int in_size = 1 + input_dilation_[i] * (in.shape(1 + i) - 1);
|
||||||
int out_size = 1 + kernel_strides_[i] * (cotan.shape(1 + i) - 1);
|
int out_size = 1 + kernel_strides_[i] * (cotan.shape(1 + i) - 1);
|
||||||
padding_hi[i] = in_size - out_size + padding_[i];
|
padding_hi[i] = in_size - out_size + padding_hi_[i];
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check for negative padding
|
// Check for negative padding
|
||||||
@ -1226,18 +1234,12 @@ std::vector<array> Convolution::vjp(
|
|||||||
|
|
||||||
if (no_dilation && !flip_ && groups_ == 1) {
|
if (no_dilation && !flip_ && groups_ == 1) {
|
||||||
auto grad = conv_weight_backward_patches(
|
auto grad = conv_weight_backward_patches(
|
||||||
in, wt, cotan, kernel_strides_, padding_, stream());
|
in, wt, cotan, kernel_strides_, padding_lo_, padding_hi_, stream());
|
||||||
grads.push_back(grad);
|
grads.push_back(grad);
|
||||||
} else {
|
} else {
|
||||||
std::vector<int> padding_lo = padding_;
|
std::vector<int> padding_lo = padding_lo_;
|
||||||
std::vector<int> padding_hi = padding_;
|
std::vector<int> padding_hi = padding_hi_;
|
||||||
|
|
||||||
for (int i = 0; i < padding_hi.size(); ++i) {
|
|
||||||
int in_size = 1 + input_dilation_[i] * (in.shape(1 + i) - 1);
|
|
||||||
int out_size = 1 + kernel_strides_[i] * (cotan.shape(1 + i) - 1);
|
|
||||||
int wt_size = 1 + kernel_dilation_[i] * (wt.shape(1 + i) - 1);
|
|
||||||
padding_hi[i] = out_size - in_size + wt_size - padding_[i] - 1;
|
|
||||||
}
|
|
||||||
auto cotan_trans = swapaxes(cotan, 0, -1, stream());
|
auto cotan_trans = swapaxes(cotan, 0, -1, stream());
|
||||||
auto in_trans = group_transpose(in, -1, 0, -1);
|
auto in_trans = group_transpose(in, -1, 0, -1);
|
||||||
|
|
||||||
@ -1283,7 +1285,8 @@ std::pair<std::vector<array>, std::vector<int>> Convolution::vmap(
|
|||||||
in,
|
in,
|
||||||
w,
|
w,
|
||||||
kernel_strides_,
|
kernel_strides_,
|
||||||
padding_,
|
padding_lo_,
|
||||||
|
padding_hi_,
|
||||||
kernel_dilation_,
|
kernel_dilation_,
|
||||||
input_dilation_,
|
input_dilation_,
|
||||||
groups,
|
groups,
|
||||||
@ -1332,7 +1335,8 @@ std::pair<std::vector<array>, std::vector<int>> Convolution::vmap(
|
|||||||
|
|
||||||
bool Convolution::is_equivalent(const Primitive& other) const {
|
bool Convolution::is_equivalent(const Primitive& other) const {
|
||||||
const Convolution& c_other = static_cast<const Convolution&>(other);
|
const Convolution& c_other = static_cast<const Convolution&>(other);
|
||||||
return padding_ == c_other.padding_ &&
|
return padding_lo_ == c_other.padding_lo_ &&
|
||||||
|
padding_hi_ == c_other.padding_hi_ &&
|
||||||
kernel_strides_ == c_other.kernel_strides_ &&
|
kernel_strides_ == c_other.kernel_strides_ &&
|
||||||
kernel_dilation_ == c_other.kernel_dilation_ &&
|
kernel_dilation_ == c_other.kernel_dilation_ &&
|
||||||
input_dilation_ == c_other.input_dilation_ &&
|
input_dilation_ == c_other.input_dilation_ &&
|
||||||
|
@ -689,13 +689,15 @@ class Convolution : public UnaryPrimitive {
|
|||||||
explicit Convolution(
|
explicit Convolution(
|
||||||
Stream stream,
|
Stream stream,
|
||||||
const std::vector<int>& kernel_strides,
|
const std::vector<int>& kernel_strides,
|
||||||
const std::vector<int>& padding,
|
const std::vector<int>& padding_lo,
|
||||||
|
const std::vector<int>& padding_hi,
|
||||||
const std::vector<int>& kernel_dilation,
|
const std::vector<int>& kernel_dilation,
|
||||||
const std::vector<int>& input_dilation,
|
const std::vector<int>& input_dilation,
|
||||||
const int groups = 1,
|
const int groups = 1,
|
||||||
const bool flip = false)
|
const bool flip = false)
|
||||||
: UnaryPrimitive(stream),
|
: UnaryPrimitive(stream),
|
||||||
padding_(padding),
|
padding_lo_(padding_lo),
|
||||||
|
padding_hi_(padding_hi),
|
||||||
kernel_strides_(kernel_strides),
|
kernel_strides_(kernel_strides),
|
||||||
kernel_dilation_(kernel_dilation),
|
kernel_dilation_(kernel_dilation),
|
||||||
input_dilation_(input_dilation),
|
input_dilation_(input_dilation),
|
||||||
@ -716,7 +718,8 @@ class Convolution : public UnaryPrimitive {
|
|||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
auto state() const {
|
auto state() const {
|
||||||
return std::make_tuple(
|
return std::make_tuple(
|
||||||
padding_,
|
padding_lo_,
|
||||||
|
padding_hi_,
|
||||||
kernel_strides_,
|
kernel_strides_,
|
||||||
kernel_dilation_,
|
kernel_dilation_,
|
||||||
input_dilation_,
|
input_dilation_,
|
||||||
@ -725,7 +728,8 @@ class Convolution : public UnaryPrimitive {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::vector<int> padding_;
|
std::vector<int> padding_lo_;
|
||||||
|
std::vector<int> padding_hi_;
|
||||||
std::vector<int> kernel_strides_;
|
std::vector<int> kernel_strides_;
|
||||||
std::vector<int> kernel_dilation_;
|
std::vector<int> kernel_dilation_;
|
||||||
std::vector<int> input_dilation_;
|
std::vector<int> input_dilation_;
|
||||||
|
@ -1088,6 +1088,48 @@ class TestConv(mlx_tests.MLXTestCase):
|
|||||||
atol=2e-5 if dtype == np.float32 else 5e-4,
|
atol=2e-5 if dtype == np.float32 else 5e-4,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@unittest.skipIf(not has_torch, "requires Torch")
|
||||||
|
def test_asymmetric_padding(self):
|
||||||
|
inputs = np.random.normal(size=(2, 8, 8, 8, 3)).astype(np.float32)
|
||||||
|
kernel = np.random.normal(size=(2, 3, 3, 3, 3)).astype(np.float32)
|
||||||
|
strides = (2, 2, 2)
|
||||||
|
|
||||||
|
pt_out = torch.conv3d(
|
||||||
|
torch.permute(torch.tensor(inputs), (0, 4, 1, 2, 3)),
|
||||||
|
torch.permute(torch.tensor(kernel), (0, 4, 1, 2, 3)),
|
||||||
|
stride=strides,
|
||||||
|
padding=2,
|
||||||
|
)
|
||||||
|
pt_out = torch.permute(pt_out, (0, 2, 3, 4, 1))[:, 1:, 1:, 1:, :].numpy()
|
||||||
|
|
||||||
|
mx_out = mx.conv_general(
|
||||||
|
mx.array(inputs),
|
||||||
|
mx.array(kernel),
|
||||||
|
stride=strides,
|
||||||
|
padding=([0, 0, 0], [1, 1, 1]),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertTrue(mx.allclose(mx_out, mx.array(pt_out), atol=1e-3, rtol=1e-3))
|
||||||
|
|
||||||
|
inputs = np.random.normal(size=(2, 10, 10, 3)).astype(np.float32)
|
||||||
|
kernel = np.random.normal(size=(2, 2, 2, 3)).astype(np.float32)
|
||||||
|
|
||||||
|
pt_out = torch.conv2d(
|
||||||
|
torch.permute(torch.tensor(inputs), (0, 3, 1, 2)),
|
||||||
|
torch.permute(torch.tensor(kernel), (0, 3, 1, 2)),
|
||||||
|
stride=1,
|
||||||
|
padding=(1, 0),
|
||||||
|
)
|
||||||
|
pt_out = torch.permute(pt_out, (0, 2, 3, 1))[:, 1:].numpy()
|
||||||
|
|
||||||
|
mx_out = mx.conv_general(
|
||||||
|
mx.array(inputs),
|
||||||
|
mx.array(kernel),
|
||||||
|
stride=1,
|
||||||
|
padding=([0, 0], [1, 0]),
|
||||||
|
)
|
||||||
|
self.assertTrue(mx.allclose(mx_out, mx.array(pt_out), atol=1e-3, rtol=1e-3))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user