mlx/mlx/backend/common/conv.cpp
Rifur13 c4a471c99d
Add groups to Conv1d (#948)
* Add conv1d grouped convs on CPU

* Add GPU support

* Parallelize inside metal kernel

* cleanup

* Update mlx/ops.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* New unfold kernel + remove unused code

* Remove copy and refactor

* Update vjp and reuse steel gemm

* Fixed groups on cpu

* Fix metal validation

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-04-27 06:24:57 -07:00

663 lines
21 KiB
C++

// Copyright © 2023-2024 Apple Inc.
#include <algorithm>
#include <cassert>
#include <numeric>
#ifdef ACCELERATE_NEW_LAPACK
#include <Accelerate/Accelerate.h>
#else
#include <cblas.h>
#endif
#include "mlx/backend/common/copy.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
namespace mlx::core {
namespace {
///////////////////////////////////////////////////////////////////////////////
// Naive reference conv
///////////////////////////////////////////////////////////////////////////////
// Direct (loop based) 1D convolution with support for groups, kernel
// dilation, input dilation and kernel flipping.
//
// `in` is addressed through its strides as [batch, position, channel], `wt`
// as [out channel, kernel position, channel within group] and `out` as
// [batch, position, out channel]. The number of groups is derived from the
// ratio between the input channels and the channels stored per filter.
// Products are accumulated in float and cast back to T when stored.
template <typename T>
void slow_conv_1D(
    const array& in,
    const array& wt,
    array out,
    const std::vector<int>& padding,
    const std::vector<int>& wt_strides,
    const std::vector<int>& wt_dilation,
    const std::vector<int>& in_dilation,
    bool flip) {
  const T* weights = wt.data<T>();
  const T* in_base = in.data<T>();
  T* out_base = out.data<T>();

  const int batch = in.shape(0);
  // Length of the input once in_dilation[0] - 1 gaps are placed between
  // consecutive samples
  const int dilated_len = 1 + in_dilation[0] * (in.shape(1) - 1);
  const int in_channels = in.shape(2);
  const int out_len = out.shape(1);
  const int out_channels = wt.shape(0);
  const int kernel_len = wt.shape(1);

  const int group_in = wt.shape(2);
  const int n_groups = in_channels / group_in;
  const int group_out = out_channels / n_groups;

  const size_t in_stride_N = in.strides()[0];
  const size_t in_stride_H = in.strides()[1];
  const size_t in_stride_C = in.strides()[2];

  const size_t wt_stride_O = wt.strides()[0];
  const size_t wt_stride_H = wt.strides()[1];
  const size_t wt_stride_C = wt.strides()[2];

  const size_t out_stride_N = out.strides()[0];
  const size_t out_stride_H = out.strides()[1];
  const size_t out_stride_O = out.strides()[2];

  for (int n = 0; n < batch; ++n) {
    const T* in_n = in_base + n * in_stride_N;
    T* out_n = out_base + n * out_stride_N;

    for (int oh = 0; oh < out_len; ++oh) {
      for (int g = 0; g < n_groups; ++g) {
        const int first_c = g * group_in;
        const int first_o = g * group_out;
        const int last_o = (g + 1) * group_out;

        for (int o = first_o; o < last_o; ++o) {
          const T* filter = weights + o * wt_stride_O;
          float acc = 0.;

          for (int k = 0; k < kernel_len; ++k) {
            // When flipping, tap k reads the input position that the
            // mirrored tap would read
            const int k_pos = flip ? (kernel_len - k - 1) : k;
            const int ih =
                oh * wt_strides[0] - padding[0] + k_pos * wt_dilation[0];
            const auto pos = std::div(ih, in_dilation[0]);

            // Skip taps that fall in the padding or in a dilation gap
            if (ih < 0 || ih >= dilated_len || pos.rem != 0) {
              continue;
            }

            const T* tap = filter + k * wt_stride_H;
            for (int ci = 0; ci < group_in; ++ci) {
              const int c = first_c + ci;
              acc += static_cast<float>(
                         in_n[pos.quot * in_stride_H + c * in_stride_C]) *
                  static_cast<float>(tap[ci * wt_stride_C]);
            }
          }

          out_n[oh * out_stride_H + o * out_stride_O] = static_cast<T>(acc);
        }
      }
    }
  }
}
// Direct (loop based) 2D convolution with support for kernel dilation, input
// dilation and kernel flipping.
//
// `in` is addressed through its strides as [batch, row, column, channel],
// `wt` as [out channel, kernel row, kernel column, channel] and `out` as
// [batch, row, column, out channel]. Products are accumulated in float and
// cast back to T when stored.
//
// The output plane is split into a border region, where kernel taps may fall
// outside the input and every tap is bounds checked, and an interior region
// that is computed without any checks.
template <typename T>
void slow_conv_2D(
    const array& in,
    const array& wt,
    array out,
    const std::vector<int>& padding,
    const std::vector<int>& wt_strides,
    const std::vector<int>& wt_dilation,
    const std::vector<int>& in_dilation,
    bool flip) {
  const T* st_wt_ptr = wt.data<T>();
  const T* st_in_ptr = in.data<T>();
  T* st_out_ptr = out.data<T>();

  const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
  // Input spatial dims after inserting the input dilation gaps
  const int iH = 1 + in_dilation[0] * (in.shape(1) - 1);
  const int iW = 1 + in_dilation[1] * (in.shape(2) - 1);
  const int oH = out.shape(1); // Output spatial dim
  const int oW = out.shape(2); // Output spatial dim
  const int O = wt.shape(0); // Out channels
  const int C = wt.shape(3); // In channels
  const int wH = wt.shape(1); // Weight spatial dim
  const int wW = wt.shape(2); // Weight spatial dim

  const size_t in_stride_N = in.strides()[0];
  const size_t in_stride_H = in.strides()[1];
  const size_t in_stride_W = in.strides()[2];
  const size_t in_stride_C = in.strides()[3];

  const size_t wt_stride_O = wt.strides()[0];
  const size_t wt_stride_H = wt.strides()[1];
  const size_t wt_stride_W = wt.strides()[2];
  const size_t wt_stride_C = wt.strides()[3];

  const size_t out_stride_N = out.strides()[0];
  const size_t out_stride_H = out.strides()[1];
  const size_t out_stride_W = out.strides()[2];
  const size_t out_stride_O = out.strides()[3];

  bool is_idil_one = in_dilation[0] == 1 && in_dilation[1] == 1;

  // Computes all output channels of one output pixel without any bounds
  // checks and without dividing by the input dilation. Only valid when every
  // kernel tap lands inside the input and the input dilation is one.
  auto pt_conv_no_checks =
      [&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int oh, int ow) {
        out_ptr += oh * out_stride_H + ow * out_stride_W;
        int ih_base = oh * wt_strides[0] - padding[0];
        int iw_base = ow * wt_strides[1] - padding[1];

        for (int o = 0; o < O; ++o) {
          float r = 0.;

          for (int wh = 0; wh < wH; ++wh) {
            for (int ww = 0; ww < wW; ++ww) {
              int wh_flip = flip ? wH - wh - 1 : wh;
              int ww_flip = flip ? wW - ww - 1 : ww;
              int ih = ih_base + wh_flip * wt_dilation[0];
              int iw = iw_base + ww_flip * wt_dilation[1];

              const T* wt_ptr_pt = wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
              const T* in_ptr_pt = in_ptr + ih * in_stride_H + iw * in_stride_W;

              for (int c = 0; c < C; ++c) {
                r += static_cast<float>(in_ptr_pt[0]) *
                    static_cast<float>(wt_ptr_pt[0]);
                in_ptr_pt += in_stride_C;
                wt_ptr_pt += wt_stride_C;
              } // c
            } // ww
          } // wh

          out_ptr[0] = static_cast<T>(r);
          out_ptr += out_stride_O;
          wt_ptr += wt_stride_O;
        } // o
      };

  // Step (and starting offset) of the input position as the kernel index
  // advances; negative when the kernel is flipped.
  int jump_h = flip ? -wt_dilation[0] : wt_dilation[0];
  int jump_w = flip ? -wt_dilation[1] : wt_dilation[1];

  int init_h = (flip ? (wH - 1) * wt_dilation[0] : 0);
  int init_w = (flip ? (wW - 1) * wt_dilation[1] : 0);

  // Kernel index step between two taps that both land on a real (non gap)
  // input sample, and the period after which the first such tap repeats
  // along the output.
  int f_wgt_jump_h = std::lcm(in_dilation[0], wt_dilation[0]) / wt_dilation[0];
  int f_wgt_jump_w = std::lcm(in_dilation[1], wt_dilation[1]) / wt_dilation[1];

  int f_out_jump_h = std::lcm(in_dilation[0], wt_strides[0]) / wt_strides[0];
  int f_out_jump_w = std::lcm(in_dilation[1], wt_strides[1]) / wt_strides[1];

  // For each output phase, the first kernel index whose input position is a
  // multiple of the input dilation (wH / wW if there is none).
  std::vector<int> base_h(f_out_jump_h);
  std::vector<int> base_w(f_out_jump_w);

  for (int i = 0; i < f_out_jump_h; ++i) {
    int ih_loop = i * wt_strides[0] - padding[0] + init_h;

    int wh_base = 0;
    while (wh_base < wH && ih_loop % in_dilation[0] != 0) {
      wh_base++;
      ih_loop += jump_h;
    }

    base_h[i] = wh_base;
  }

  for (int j = 0; j < f_out_jump_w; ++j) {
    int iw_loop = j * wt_strides[1] - padding[1] + init_w;

    int ww_base = 0;
    while (ww_base < wW && iw_loop % in_dilation[1] != 0) {
      ww_base++;
      iw_loop += jump_w;
    }

    base_w[j] = ww_base;
  }

  // Computes all output channels of one output pixel, visiting only the
  // kernel taps that can hit a real input sample and bounds checking each
  // of them against the dilated input extent.
  auto pt_conv_all_checks =
      [&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int oh, int ow) {
        out_ptr += oh * out_stride_H + ow * out_stride_W;

        int ih_base = oh * wt_strides[0] - padding[0];
        int iw_base = ow * wt_strides[1] - padding[1];

        int wh_base = base_h[oh % f_out_jump_h];
        int ww_base = base_w[ow % f_out_jump_w];

        for (int o = 0; o < O; ++o) {
          float r = 0.;

          for (int wh = wh_base; wh < wH; wh += f_wgt_jump_h) {
            for (int ww = ww_base; ww < wW; ww += f_wgt_jump_w) {
              int wh_flip = flip ? wH - wh - 1 : wh;
              int ww_flip = flip ? wW - ww - 1 : ww;
              int ih = ih_base + wh_flip * wt_dilation[0];
              int iw = iw_base + ww_flip * wt_dilation[1];

              if (ih >= 0 && ih < iH && iw >= 0 && iw < iW) {
                const T* wt_ptr_pt =
                    wt_ptr + wh * wt_stride_H + ww * wt_stride_W;

                // Map the dilated position back to the stored sample
                int ih_dil = !is_idil_one ? (ih / in_dilation[0]) : ih;
                int iw_dil = !is_idil_one ? (iw / in_dilation[1]) : iw;

                const T* in_ptr_pt =
                    in_ptr + ih_dil * in_stride_H + iw_dil * in_stride_W;

                for (int c = 0; c < C; ++c) {
                  r += static_cast<float>(in_ptr_pt[0]) *
                      static_cast<float>(wt_ptr_pt[0]);
                  in_ptr_pt += in_stride_C;
                  wt_ptr_pt += wt_stride_C;
                } // c
              } // ih, iw check
            } // ww
          } // wh

          out_ptr[0] = static_cast<T>(r);
          out_ptr += out_stride_O;
          wt_ptr += wt_stride_O;
        } // o
      };

  // Output rows/columns in [border_1, border_2) are computed without checks.
  // With input dilation everything goes through the checked path.
  //
  // ceil(padding / stride) may be larger than the output extent (large
  // padding with a small output), so the borders are clamped to oH / oW.
  // Without the clamp the checked loops below would run past the last
  // output row/column and write outside of the output.
  int oH_border_0 = 0;
  int oH_border_1 = is_idil_one
      ? std::min(oH, (padding[0] + wt_strides[0] - 1) / wt_strides[0])
      : oH;
  int oH_border_2 = std::min(
      oH,
      std::max(
          oH_border_1,
          (iH + padding[0] - wH * wt_dilation[0]) / wt_strides[0]));
  int oH_border_3 = oH;

  int oW_border_0 = 0;
  int oW_border_1 = is_idil_one
      ? std::min(oW, (padding[1] + wt_strides[1] - 1) / wt_strides[1])
      : oW;
  int oW_border_2 = std::min(
      oW,
      std::max(
          oW_border_1,
          (iW + padding[1] - wW * wt_dilation[1]) / wt_strides[1]));
  int oW_border_3 = oW;

  for (int n = 0; n < N; ++n) {
    // Case 1: oh might put us out of bounds
    for (int oh = oH_border_0; oh < oH_border_1; ++oh) {
      for (int ow = 0; ow < oW; ++ow) {
        pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
      } // ow
    } // oh

    // Case 2: oh in bounds
    for (int oh = oH_border_1; oh < oH_border_2; ++oh) {
      // Case a: ow might put us out of bounds
      for (int ow = oW_border_0; ow < oW_border_1; ++ow) {
        pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
      } // ow

      // Case b: ow in bounds
      for (int ow = oW_border_1; ow < oW_border_2; ++ow) {
        pt_conv_no_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
      } // ow

      // Case c: ow might put us out of bounds
      for (int ow = oW_border_2; ow < oW_border_3; ++ow) {
        pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
      } // ow
    } // oh

    // Case 3: oh might put us out of bounds
    for (int oh = oH_border_2; oh < oH_border_3; ++oh) {
      for (int ow = 0; ow < oW; ++ow) {
        pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
      } // ow
    } // oh

    st_in_ptr += in_stride_N;
    st_out_ptr += out_stride_N;
  } // n
}
// Selects the slow_conv_1D instantiation matching the dtype of `in`.
// Throws std::invalid_argument for any dtype other than float32, float16
// and bfloat16.
void dispatch_slow_conv_1D(
    const array& in,
    const array& wt,
    array out,
    const std::vector<int>& padding,
    const std::vector<int>& wt_strides,
    const std::vector<int>& wt_dilation,
    const std::vector<int>& in_dilation,
    bool flip) {
  const auto in_type = in.dtype();

  if (in_type == float32) {
    slow_conv_1D<float>(
        in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
    return;
  }
  if (in_type == float16) {
    slow_conv_1D<float16_t>(
        in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
    return;
  }
  if (in_type == bfloat16) {
    slow_conv_1D<bfloat16_t>(
        in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
    return;
  }

  throw std::invalid_argument(
      "[Convolution::eval] got unsupported data type.");
}
// Selects the slow_conv_2D instantiation matching the dtype of `in`.
// Throws std::invalid_argument for any dtype other than float32, float16
// and bfloat16.
void dispatch_slow_conv_2D(
    const array& in,
    const array& wt,
    array out,
    const std::vector<int>& padding,
    const std::vector<int>& wt_strides,
    const std::vector<int>& wt_dilation,
    const std::vector<int>& in_dilation,
    bool flip) {
  const auto in_type = in.dtype();

  if (in_type == float32) {
    slow_conv_2D<float>(
        in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
    return;
  }
  if (in_type == float16) {
    slow_conv_2D<float16_t>(
        in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
    return;
  }
  if (in_type == bfloat16) {
    slow_conv_2D<bfloat16_t>(
        in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
    return;
  }

  throw std::invalid_argument(
      "[Convolution::eval] got unsupported data type.");
}
///////////////////////////////////////////////////////////////////////////////
// Explicit gemm conv
///////////////////////////////////////////////////////////////////////////////
// 1D convolution computed as a matrix multiplication.
//
// The input is copied into a zero padded float32 buffer, the sliding windows
// are materialized into a [N * oH, wH * C] matrix through a strided view, and
// one cblas_sgemm call per group multiplies it with the (transposed) float32
// weights. When `out` is not float32 the result is accumulated in a temporary
// float32 array and converted once at the end.
//
// `wt_dilation` is not read here; conv_1D_cpu only routes to this function
// when the kernel dilation is one.
void explicit_gemm_conv_1D_cpu(
    const array& in,
    const array& wt,
    array out,
    const std::vector<int>& padding,
    const std::vector<int>& wt_strides,
    const std::vector<int>& wt_dilation) {
  const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
  const int iH = in.shape(1); // Input spatial dim
  const int C = in.shape(2); // Input channels
  const int oH = out.shape(1); // Output spatial dim
  const int O = wt.shape(0); // Out channels
  const int wH = wt.shape(1); // Weight spatial dim

  const int groups = C / wt.shape(2);
  const int C_per_group = wt.shape(2);
  const int O_per_group = O / groups;

  // sgemm works on floats, so all intermediate buffers are float32
  auto conv_dtype = float32;

  // Pad input
  std::vector<int> padded_shape = {N, iH + 2 * padding[0], C};
  array in_padded(padded_shape, conv_dtype, nullptr, {});

  // Fill with zeros
  copy(array(0, conv_dtype), in_padded, CopyType::Scalar);

  // Pick input slice from padded
  size_t data_offset = padding[0] * in_padded.strides()[1];
  array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
  in_padded_slice.copy_shared_buffer(
      in_padded,
      in_padded.strides(),
      in_padded.flags(),
      in_padded_slice.size(),
      data_offset);

  // Copy input values into the slice
  copy_inplace(in, in_padded_slice, CopyType::GeneralGeneral);

  // Make strided view
  std::vector<int> strided_shape = {N, oH, wH, C};

  std::vector<size_t> strided_strides = {
      in_padded.strides()[0],
      in_padded.strides()[1] * wt_strides[0],
      in_padded.strides()[1],
      in_padded.strides()[2]};
  auto flags = in_padded.flags();
  if (groups > 1) {
    // Transpose the last two dimensions for grouped convolutions so that the
    // channels of one group are contiguous in each row
    std::swap(strided_shape[2], strided_shape[3]);
    std::swap(strided_strides[2], strided_strides[3]);
  }

  array in_strided_view(strided_shape, in_padded.dtype(), nullptr, {});
  in_strided_view.copy_shared_buffer(
      in_padded, strided_strides, flags, in_strided_view.size(), 0);

  // Materialize strided view
  std::vector<int> strided_reshape = {N * oH, wH * C};
  array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
  copy(in_strided_view, in_strided, CopyType::General);

  // Check wt dtype and prepare
  auto gemm_wt = wt;
  auto gemm_out = out;

  if (groups > 1) {
    // Transpose the last two dimensions for grouped convolutions
    array wt_transpose(
        {wt.shape(0), wt.shape(2), wt.shape(1)}, wt.dtype(), nullptr, {});
    wt_transpose.copy_shared_buffer(
        wt,
        {wt.strides(0), wt.strides(2), wt.strides(1)},
        wt.flags(),
        wt.size(),
        0);
    gemm_wt = array(wt_transpose.shape(), float32, nullptr, {});
    copy(wt_transpose, gemm_wt, CopyType::General);
  } else if (wt.dtype() != float32 || !wt.flags().row_contiguous) {
    auto ctype =
        wt.flags().row_contiguous ? CopyType::Vector : CopyType::General;
    gemm_wt = array(wt.shape(), float32, nullptr, {});
    copy(wt, gemm_wt, ctype);
  }

  if (out.dtype() != float32) {
    gemm_out = array(out.shape(), float32, nullptr, {});
    gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes()));
  }

  // Each group multiplies its own column block of the input matrix with its
  // own row block of the weights and writes its own column block of the
  // output.
  for (int g = 0; g < groups; ++g) {
    // Perform gemm
    cblas_sgemm(
        CblasRowMajor,
        CblasNoTrans, // no trans A
        CblasTrans, // transB
        strided_reshape[0], // M
        O_per_group, // N
        C_per_group * wH, // K
        1.0f, // alpha
        in_strided.data<float>() + g * C_per_group * wH, // A
        wH * C, // lda
        gemm_wt.data<float>() + g * O_per_group * C_per_group * wH, // B
        wH * C_per_group, // ldb
        0.0f, // beta
        gemm_out.data<float>() + g * O_per_group, // C
        O // ldc
    );
  }

  // Copy results if needed. Done once after all groups have been written
  // instead of converting the whole output after every group.
  if (out.dtype() != float32) {
    copy(gemm_out, out, CopyType::Vector);
  }
}
void explicit_gemm_conv_2D_cpu(
const array& in,
const array& wt,
array out,
const std::vector<int>& padding,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation) {
const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
const int iH = in.shape(1); // Input spatial dim
const int iW = in.shape(2); // Input spatial dim
const int oH = out.shape(1); // Output spatial dim
const int oW = out.shape(2); // Output spatial dim
const int O = wt.shape(0); // Out channels
const int C = wt.shape(3); // In channels
const int wH = wt.shape(1); // Weight spatial dim
const int wW = wt.shape(2); // Weight spatial dim
auto conv_dtype = out.dtype();
// Pad input
std::vector<int> padded_shape = {
N, iH + 2 * padding[0], iW + 2 * padding[1], C};
array in_padded(padded_shape, conv_dtype, nullptr, {});
// Fill with zeros
copy(array(0, conv_dtype), in_padded, CopyType::Scalar);
// Pick input slice from padded
size_t data_offset =
padding[0] * in_padded.strides()[1] + padding[1] * in_padded.strides()[2];
array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
in_padded_slice.copy_shared_buffer(
in_padded,
in_padded.strides(),
in_padded.flags(),
in_padded_slice.size(),
data_offset);
// Copy input values into the slice
copy_inplace(in, in_padded_slice, CopyType::GeneralGeneral);
// Make strided view
std::vector<int> strided_shape = {N, oH, oW, wH, wW, C};
std::vector<size_t> strided_strides = {
in_padded.strides()[0],
in_padded.strides()[1] * wt_strides[0],
in_padded.strides()[2] * wt_strides[1],
in_padded.strides()[1],
in_padded.strides()[2],
in_padded.strides()[3]};
auto flags = in_padded.flags();
array in_strided_view(strided_shape, in_padded.dtype(), nullptr, {});
in_strided_view.copy_shared_buffer(
in_padded, strided_strides, flags, in_strided_view.size(), 0);
// Materialize strided view
std::vector<int> strided_reshape = {N * oH * oW, wH * wW * C};
array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
copy(in_strided_view, in_strided, CopyType::General);
// Check wt dtype and prepare
auto gemm_wt = wt;
auto gemm_out = out;
if (wt.dtype() != float32 || !wt.flags().row_contiguous) {
auto ctype =
wt.flags().row_contiguous ? CopyType::Vector : CopyType::General;
gemm_wt = array(wt.shape(), float32, nullptr, {});
copy(wt, gemm_wt, ctype);
}
if (out.dtype() != float32) {
gemm_out = array(out.shape(), float32, nullptr, {});
gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes()));
}
// Perform gemm
cblas_sgemm(
CblasRowMajor,
CblasNoTrans, // no trans A
CblasTrans, // transB
strided_reshape[0], // M
O, // N
strided_reshape[1], // K
1.0f, // alpha
in_strided.data<float>(),
strided_reshape[1], // lda
gemm_wt.data<float>(),
strided_reshape[1], // ldb
0.0f, // beta
gemm_out.data<float>(),
O // ldc
);
// Copy results if needed
if (out.dtype() != float32) {
copy(gemm_out, out, CopyType::Vector);
}
}
///////////////////////////////////////////////////////////////////////////////
// Conv routing
///////////////////////////////////////////////////////////////////////////////
// Picks the 1D implementation: the explicit gemm path handles the plain case
// (no kernel dilation, no input dilation, no flip); everything else goes to
// the loop based reference implementation.
void conv_1D_cpu(
    const array& in,
    const array& wt,
    array out,
    const std::vector<int>& padding,
    const std::vector<int>& wt_strides,
    const std::vector<int>& wt_dilation,
    const std::vector<int>& in_dilation,
    bool flip) {
  if (wt_dilation[0] != 1 || in_dilation[0] != 1 || flip) {
    dispatch_slow_conv_1D(
        in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
    return;
  }

  explicit_gemm_conv_1D_cpu(in, wt, out, padding, wt_strides, wt_dilation);
}
// Picks the 2D implementation; every 2D convolution is currently forwarded
// to the loop based reference implementation.
void conv_2D_cpu(
    const array& in,
    const array& wt,
    array out,
    const std::vector<int>& padding,
    const std::vector<int>& wt_strides,
    const std::vector<int>& wt_dilation,
    const std::vector<int>& in_dilation,
    bool flip) {
  dispatch_slow_conv_2D(
      in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
}
} // namespace
void Convolution::eval(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& in = inputs[0];
auto& wt = inputs[1];
// 2D convolution
if (in.ndim() == (2 + 2)) {
return conv_2D_cpu(
in,
wt,
out,
padding_,
kernel_strides_,
kernel_dilation_,
input_dilation_,
flip_);
}
// 1D convolution
else if (in.ndim() == (1 + 2)) {
return conv_1D_cpu(
in,
wt,
out,
padding_,
kernel_strides_,
kernel_dilation_,
input_dilation_,
flip_);
}
// Throw error
else {
std::ostringstream msg;
msg << "[Convolution::eval] Convolution currently only supports"
<< " 1D and 2D convolutions. Got inputs with " << in.ndim() - 2
<< " spatial dimensions";
throw std::invalid_argument(msg.str());
}
}
} // namespace mlx::core