mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Convolution update (#651)
* Init steel conv and update Conv primitive
* Update slow CPU implementation to support flipping and input dilation
  winograd conv routing

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <cassert>
|
||||
#include <numeric>
|
||||
|
||||
#ifdef ACCELERATE_NEW_LAPACK
|
||||
#include <Accelerate/Accelerate.h>
|
||||
@@ -27,14 +28,16 @@ void slow_conv_1D(
|
||||
array out,
|
||||
const std::vector<int>& padding,
|
||||
const std::vector<int>& wt_strides,
|
||||
const std::vector<int>& wt_dilation) {
|
||||
const std::vector<int>& wt_dilation,
|
||||
const std::vector<int>& in_dilation,
|
||||
bool flip) {
|
||||
const T* start_wt_ptr = wt.data<T>();
|
||||
|
||||
const T* in_ptr = in.data<T>();
|
||||
T* out_ptr = out.data<T>();
|
||||
|
||||
const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
|
||||
const int iH = in.shape(1); // Input spatial dim
|
||||
const int iH = 1 + in_dilation[0] * (in.shape(1) - 1); // Input spatial dim
|
||||
const int oH = out.shape(1); // Output spatial dim
|
||||
const int O = wt.shape(0); // Out channels
|
||||
const int C = wt.shape(2); // In channels
|
||||
@@ -61,12 +64,15 @@ void slow_conv_1D(
|
||||
for (int wh = 0; wh < wH; ++wh) {
|
||||
const T* wt_ptr = filter_wt_ptr + wh * wt_stride_H;
|
||||
|
||||
int ih = oh * wt_strides[0] - padding[0] + wh * wt_dilation[0];
|
||||
int wh_flip = flip ? (wH - wh - 1) : wh;
|
||||
int ih = oh * wt_strides[0] - padding[0] + wh_flip * wt_dilation[0];
|
||||
|
||||
if (ih >= 0 && ih < iH) {
|
||||
auto ih_div = std::div(ih, in_dilation[0]);
|
||||
|
||||
if (ih >= 0 && ih < iH && ih_div.rem == 0) {
|
||||
for (int c = 0; c < C; ++c) {
|
||||
r += static_cast<float>(
|
||||
in_ptr[ih * in_stride_H + c * in_stride_C]) *
|
||||
in_ptr[ih_div.quot * in_stride_H + c * in_stride_C]) *
|
||||
static_cast<float>(wt_ptr[c * wt_stride_C]);
|
||||
} // c
|
||||
|
||||
@@ -90,14 +96,16 @@ void slow_conv_2D(
|
||||
array out,
|
||||
const std::vector<int>& padding,
|
||||
const std::vector<int>& wt_strides,
|
||||
const std::vector<int>& wt_dilation) {
|
||||
const std::vector<int>& wt_dilation,
|
||||
const std::vector<int>& in_dilation,
|
||||
bool flip) {
|
||||
const T* st_wt_ptr = wt.data<T>();
|
||||
const T* st_in_ptr = in.data<T>();
|
||||
T* st_out_ptr = out.data<T>();
|
||||
|
||||
const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
|
||||
const int iH = in.shape(1); // Input spatial dim
|
||||
const int iW = in.shape(2); // Input spatial dim
|
||||
const int iH = 1 + in_dilation[0] * (in.shape(1) - 1); // Input spatial dim
|
||||
const int iW = 1 + in_dilation[1] * (in.shape(2) - 1); // Input spatial dim
|
||||
const int oH = out.shape(1); // Output spatial dim
|
||||
const int oW = out.shape(2); // Output spatial dim
|
||||
const int O = wt.shape(0); // Out channels
|
||||
@@ -120,6 +128,8 @@ void slow_conv_2D(
|
||||
const size_t out_stride_W = out.strides()[2];
|
||||
const size_t out_stride_O = out.strides()[3];
|
||||
|
||||
bool is_idil_one = in_dilation[0] == 1 && in_dilation[1] == 1;
|
||||
|
||||
auto pt_conv_no_checks =
|
||||
[&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int oh, int ow) {
|
||||
out_ptr += oh * out_stride_H + ow * out_stride_W;
|
||||
@@ -131,8 +141,10 @@ void slow_conv_2D(
|
||||
|
||||
for (int wh = 0; wh < wH; ++wh) {
|
||||
for (int ww = 0; ww < wW; ++ww) {
|
||||
int ih = ih_base + wh * wt_dilation[0];
|
||||
int iw = iw_base + ww * wt_dilation[1];
|
||||
int wh_flip = flip ? wH - wh - 1 : wh;
|
||||
int ww_flip = flip ? wW - ww - 1 : ww;
|
||||
int ih = ih_base + wh_flip * wt_dilation[0];
|
||||
int iw = iw_base + ww_flip * wt_dilation[1];
|
||||
|
||||
const T* wt_ptr_pt = wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
|
||||
const T* in_ptr_pt = in_ptr + ih * in_stride_H + iw * in_stride_W;
|
||||
@@ -153,25 +165,74 @@ void slow_conv_2D(
|
||||
} // o
|
||||
};
|
||||
|
||||
int jump_h = flip ? -wt_dilation[0] : wt_dilation[0];
|
||||
int jump_w = flip ? -wt_dilation[1] : wt_dilation[1];
|
||||
|
||||
int init_h = (flip ? (wH - 1) * wt_dilation[0] : 0);
|
||||
int init_w = (flip ? (wW - 1) * wt_dilation[1] : 0);
|
||||
|
||||
int f_wgt_jump_h = std::lcm(in_dilation[0], wt_dilation[0]) / wt_dilation[0];
|
||||
int f_wgt_jump_w = std::lcm(in_dilation[1], wt_dilation[1]) / wt_dilation[1];
|
||||
|
||||
int f_out_jump_h = std::lcm(in_dilation[0], wt_strides[0]) / wt_strides[0];
|
||||
int f_out_jump_w = std::lcm(in_dilation[1], wt_strides[1]) / wt_strides[1];
|
||||
|
||||
std::vector<int> base_h(f_out_jump_h);
|
||||
std::vector<int> base_w(f_out_jump_w);
|
||||
|
||||
for (int i = 0; i < f_out_jump_h; ++i) {
|
||||
int ih_loop = i * wt_strides[0] - padding[0] + init_h;
|
||||
|
||||
int wh_base = 0;
|
||||
while (wh_base < wH && ih_loop % in_dilation[0] != 0) {
|
||||
wh_base++;
|
||||
ih_loop += jump_h;
|
||||
}
|
||||
|
||||
base_h[i] = wh_base;
|
||||
}
|
||||
|
||||
for (int j = 0; j < f_out_jump_w; ++j) {
|
||||
int iw_loop = j * wt_strides[1] - padding[1] + init_w;
|
||||
|
||||
int ww_base = 0;
|
||||
while (ww_base < wW && iw_loop % in_dilation[1] != 0) {
|
||||
ww_base++;
|
||||
iw_loop += jump_w;
|
||||
}
|
||||
|
||||
base_w[j] = ww_base;
|
||||
}
|
||||
|
||||
auto pt_conv_all_checks =
|
||||
[&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int oh, int ow) {
|
||||
out_ptr += oh * out_stride_H + ow * out_stride_W;
|
||||
|
||||
int ih_base = oh * wt_strides[0] - padding[0];
|
||||
int iw_base = ow * wt_strides[1] - padding[1];
|
||||
|
||||
int wh_base = base_h[oh % f_out_jump_h];
|
||||
int ww_base = base_w[ow % f_out_jump_w];
|
||||
|
||||
for (int o = 0; o < O; ++o) {
|
||||
float r = 0.;
|
||||
|
||||
for (int wh = 0; wh < wH; ++wh) {
|
||||
for (int ww = 0; ww < wW; ++ww) {
|
||||
int ih = ih_base + wh * wt_dilation[0];
|
||||
int iw = iw_base + ww * wt_dilation[1];
|
||||
for (int wh = wh_base; wh < wH; wh += f_wgt_jump_h) {
|
||||
for (int ww = ww_base; ww < wW; ww += f_wgt_jump_w) {
|
||||
int wh_flip = flip ? wH - wh - 1 : wh;
|
||||
int ww_flip = flip ? wW - ww - 1 : ww;
|
||||
int ih = ih_base + wh_flip * wt_dilation[0];
|
||||
int iw = iw_base + ww_flip * wt_dilation[1];
|
||||
|
||||
if (ih >= 0 && ih < iH && iw >= 0 && iw < iW) {
|
||||
const T* wt_ptr_pt =
|
||||
wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
|
||||
|
||||
int ih_dil = !is_idil_one ? (ih / in_dilation[0]) : ih;
|
||||
int iw_dil = !is_idil_one ? (iw / in_dilation[1]) : iw;
|
||||
|
||||
const T* in_ptr_pt =
|
||||
in_ptr + ih * in_stride_H + iw * in_stride_W;
|
||||
in_ptr + ih_dil * in_stride_H + iw_dil * in_stride_W;
|
||||
|
||||
for (int c = 0; c < C; ++c) {
|
||||
r += static_cast<float>(in_ptr_pt[0]) *
|
||||
@@ -191,13 +252,17 @@ void slow_conv_2D(
|
||||
};
|
||||
|
||||
int oH_border_0 = 0;
|
||||
int oH_border_1 = (padding[0] + wt_strides[0] + 1) / wt_strides[0];
|
||||
int oH_border_2 = (iH + padding[0] - wH * wt_dilation[0]) / wt_strides[0];
|
||||
int oH_border_1 =
|
||||
is_idil_one ? ((padding[0] + wt_strides[0] - 1) / wt_strides[0]) : oH;
|
||||
int oH_border_2 = std::max(
|
||||
oH_border_1, (iH + padding[0] - wH * wt_dilation[0]) / wt_strides[0]);
|
||||
int oH_border_3 = oH;
|
||||
|
||||
int oW_border_0 = 0;
|
||||
int oW_border_1 = (padding[1] + wt_strides[0] + 1) / wt_strides[1];
|
||||
int oW_border_2 = (iW + padding[1] - wW * wt_dilation[1]) / wt_strides[1];
|
||||
int oW_border_1 =
|
||||
is_idil_one ? ((padding[1] + wt_strides[1] - 1) / wt_strides[1]) : oW;
|
||||
int oW_border_2 = std::max(
|
||||
oW_border_1, (iW + padding[1] - wW * wt_dilation[1]) / wt_strides[1]);
|
||||
int oW_border_3 = oW;
|
||||
|
||||
for (int n = 0; n < N; ++n) {
|
||||
@@ -246,15 +311,18 @@ void dispatch_slow_conv_1D(
|
||||
array out,
|
||||
const std::vector<int>& padding,
|
||||
const std::vector<int>& wt_strides,
|
||||
const std::vector<int>& wt_dilation) {
|
||||
const std::vector<int>& wt_dilation,
|
||||
const std::vector<int>& in_dilation,
|
||||
bool flip) {
|
||||
if (in.dtype() == float32) {
|
||||
return slow_conv_1D<float>(in, wt, out, padding, wt_strides, wt_dilation);
|
||||
return slow_conv_1D<float>(
|
||||
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
|
||||
} else if (in.dtype() == float16) {
|
||||
return slow_conv_1D<float16_t>(
|
||||
in, wt, out, padding, wt_strides, wt_dilation);
|
||||
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
|
||||
} else if (in.dtype() == bfloat16) {
|
||||
return slow_conv_1D<bfloat16_t>(
|
||||
in, wt, out, padding, wt_strides, wt_dilation);
|
||||
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[Convolution::eval] got unsupported data type.");
|
||||
@@ -267,15 +335,18 @@ void dispatch_slow_conv_2D(
|
||||
array out,
|
||||
const std::vector<int>& padding,
|
||||
const std::vector<int>& wt_strides,
|
||||
const std::vector<int>& wt_dilation) {
|
||||
const std::vector<int>& wt_dilation,
|
||||
const std::vector<int>& in_dilation,
|
||||
bool flip) {
|
||||
if (in.dtype() == float32) {
|
||||
return slow_conv_2D<float>(in, wt, out, padding, wt_strides, wt_dilation);
|
||||
return slow_conv_2D<float>(
|
||||
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
|
||||
} else if (in.dtype() == float16) {
|
||||
return slow_conv_2D<float16_t>(
|
||||
in, wt, out, padding, wt_strides, wt_dilation);
|
||||
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
|
||||
} else if (in.dtype() == bfloat16) {
|
||||
return slow_conv_2D<bfloat16_t>(
|
||||
in, wt, out, padding, wt_strides, wt_dilation);
|
||||
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[Convolution::eval] got unsupported data type.");
|
||||
@@ -493,13 +564,16 @@ void conv_1D_cpu(
|
||||
array out,
|
||||
const std::vector<int>& padding,
|
||||
const std::vector<int>& wt_strides,
|
||||
const std::vector<int>& wt_dilation) {
|
||||
if (wt_dilation[0] == 1) {
|
||||
const std::vector<int>& wt_dilation,
|
||||
const std::vector<int>& in_dilation,
|
||||
bool flip) {
|
||||
if (wt_dilation[0] == 1 && in_dilation[0] == 1 && !flip) {
|
||||
return explicit_gemm_conv_1D_cpu(
|
||||
in, wt, out, padding, wt_strides, wt_dilation);
|
||||
}
|
||||
|
||||
return dispatch_slow_conv_1D(in, wt, out, padding, wt_strides, wt_dilation);
|
||||
return dispatch_slow_conv_1D(
|
||||
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
|
||||
}
|
||||
|
||||
void conv_2D_cpu(
|
||||
@@ -508,8 +582,11 @@ void conv_2D_cpu(
|
||||
array out,
|
||||
const std::vector<int>& padding,
|
||||
const std::vector<int>& wt_strides,
|
||||
const std::vector<int>& wt_dilation) {
|
||||
return dispatch_slow_conv_2D(in, wt, out, padding, wt_strides, wt_dilation);
|
||||
const std::vector<int>& wt_dilation,
|
||||
const std::vector<int>& in_dilation,
|
||||
bool flip) {
|
||||
return dispatch_slow_conv_2D(
|
||||
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
@@ -523,12 +600,26 @@ void Convolution::eval(const std::vector<array>& inputs, array& out) {
|
||||
// 2D convolution
|
||||
if (in.ndim() == (2 + 2)) {
|
||||
return conv_2D_cpu(
|
||||
in, wt, out, padding_, kernel_strides_, kernel_dilation_);
|
||||
in,
|
||||
wt,
|
||||
out,
|
||||
padding_,
|
||||
kernel_strides_,
|
||||
kernel_dilation_,
|
||||
input_dilation_,
|
||||
flip_);
|
||||
}
|
||||
// 1D convolution
|
||||
else if (in.ndim() == (1 + 2)) {
|
||||
return conv_1D_cpu(
|
||||
in, wt, out, padding_, kernel_strides_, kernel_dilation_);
|
||||
in,
|
||||
wt,
|
||||
out,
|
||||
padding_,
|
||||
kernel_strides_,
|
||||
kernel_dilation_,
|
||||
input_dilation_,
|
||||
flip_);
|
||||
}
|
||||
// Throw error
|
||||
else {
|
||||
|
||||
Reference in New Issue
Block a user