Mirror of https://github.com/ml-explore/mlx.git
Add groups to Conv1d (#948)

* Add conv1d grouped convs on CPU
* Add GPU support
* Parallelize inside metal kernel
* Cleanup
* Update mlx/ops.cpp
* New unfold kernel + remove unused code
* Remove copy and refactor
* Update vjp and reuse steel gemm
* Fix groups on CPU
* Fix metal validation

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
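For context, a grouped convolution splits the channels into `groups` independent convolutions: filter `o` only sees the `C / groups` input channels of its group, so the weight array carries `C / groups` channels per filter. A minimal sketch of exercising the new argument from the C++ API, assuming `groups` follows `dilation` in `conv1d`'s signature as in current MLX (shapes are illustrative):

```cpp
#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  // NLC input: batch 2, length 10, 16 channels.
  array x = random::normal({2, 10, 16});
  // 32 filters of width 3, each seeing 16 / 4 = 4 channels.
  array w = random::normal({32, 3, 4});
  // groups = 4 -> output shape (2, 8, 32).
  array y = conv1d(x, w, /* stride */ 1, /* padding */ 0, /* dilation */ 1, /* groups */ 4);
  eval(y);
}
```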
@@ -38,11 +38,15 @@ void slow_conv_1D(
 
   const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
   const int iH = 1 + in_dilation[0] * (in.shape(1) - 1); // Input spatial dim
+  const int C = in.shape(2); // Input channels
   const int oH = out.shape(1); // Output spatial dim
   const int O = wt.shape(0); // Out channels
-  const int C = wt.shape(2); // In channels
   const int wH = wt.shape(1); // Weight spatial dim
 
+  const int groups = C / wt.shape(2);
+  const int C_per_group = wt.shape(2);
+  const int O_per_group = O / groups;
+
   const size_t in_stride_N = in.strides()[0];
   const size_t in_stride_H = in.strides()[1];
   const size_t in_stride_C = in.strides()[2];
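Note how the grouping is inferred rather than passed down: the weight array stores only the per-group channel count in its last dimension, so `groups` falls out of the channel ratio. A worked example with illustrative numbers (not from the diff):

```cpp
// Say C = 16 input channels, O = 32 output channels, wt.shape(2) = 4.
const int groups = 16 / 4;      // 4 groups
const int C_per_group = 4;      // channels each filter actually reads
const int O_per_group = 32 / 4; // 8 filters per group
// Filters o in [g * 8, (g + 1) * 8) consume input channels [g * 4, (g + 1) * 4).
```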
@@ -57,35 +61,36 @@ void slow_conv_1D(
 
   for (int n = 0; n < N; ++n) {
     for (int oh = 0; oh < oH; ++oh) {
-      for (int o = 0; o < O; ++o) {
-        const T* filter_wt_ptr = start_wt_ptr + o * wt_stride_O;
-        float r = 0.;
+      for (int g = 0; g < groups; ++g) {
+        for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) {
+          const T* filter_wt_ptr = start_wt_ptr + o * wt_stride_O;
+          float r = 0.;
 
-        for (int wh = 0; wh < wH; ++wh) {
-          const T* wt_ptr = filter_wt_ptr + wh * wt_stride_H;
+          for (int wh = 0; wh < wH; ++wh) {
+            const T* wt_ptr = filter_wt_ptr + wh * wt_stride_H;
 
-          int wh_flip = flip ? (wH - wh - 1) : wh;
-          int ih = oh * wt_strides[0] - padding[0] + wh_flip * wt_dilation[0];
+            int wh_flip = flip ? (wH - wh - 1) : wh;
+            int ih = oh * wt_strides[0] - padding[0] + wh_flip * wt_dilation[0];
 
-          auto ih_div = std::div(ih, in_dilation[0]);
+            auto ih_div = std::div(ih, in_dilation[0]);
 
-          if (ih >= 0 && ih < iH && ih_div.rem == 0) {
-            for (int c = 0; c < C; ++c) {
-              r += static_cast<float>(
-                       in_ptr[ih_div.quot * in_stride_H + c * in_stride_C]) *
-                  static_cast<float>(wt_ptr[c * wt_stride_C]);
-            } // c
+            if (ih >= 0 && ih < iH && ih_div.rem == 0) {
+              for (int c = g * C_per_group; c < (g + 1) * C_per_group; ++c) {
+                r += static_cast<float>(
+                         in_ptr[ih_div.quot * in_stride_H + c * in_stride_C]) *
+                    static_cast<float>(wt_ptr[(c % C_per_group) * wt_stride_C]);
+              } // c
 
-          } // ih check
-        } // wh
+            } // ih check
+          } // wh
 
-        out_ptr[oh * out_stride_H + o * out_stride_O] = static_cast<T>(r);
-      } // o
+          out_ptr[oh * out_stride_H + o * out_stride_O] = static_cast<T>(r);
+        } // o
+      } // g
     } // oh
 
     in_ptr += in_stride_N;
     out_ptr += out_stride_N;
 
   } // n
 }
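The restructured loop nest is easier to follow in isolation. Below is a minimal standalone sketch of the same indexing over contiguous row-major buffers, with padding, dilation, and flipping stripped out for brevity; all names are illustrative and none of this is MLX code. The `c % C_per_group` mirrors the diff: `c` runs over the group's absolute channel range, and the modulo maps it back to the weight's local channel index.

```cpp
#include <vector>

// Reference grouped 1D convolution over contiguous NLC buffers:
// in is (N, iH, C), wt is (O, wH, C / groups), out is (N, oH, O) with
// oH = iH - wH + 1 (stride 1, no padding, no dilation).
void grouped_conv1d_ref(
    const std::vector<float>& in,
    const std::vector<float>& wt,
    std::vector<float>& out,
    int N, int iH, int C, int O, int wH, int groups) {
  const int C_per_group = C / groups;
  const int O_per_group = O / groups;
  const int oH = iH - wH + 1;
  for (int n = 0; n < N; ++n) {
    for (int oh = 0; oh < oH; ++oh) {
      for (int g = 0; g < groups; ++g) {
        for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) {
          float r = 0.f;
          for (int wh = 0; wh < wH; ++wh) {
            for (int c = g * C_per_group; c < (g + 1) * C_per_group; ++c) {
              // Weights store only C_per_group channels, hence c % C_per_group.
              r += in[(n * iH + oh + wh) * C + c] *
                  wt[(o * wH + wh) * C_per_group + (c % C_per_group)];
            }
          }
          out[(n * oH + oh) * O + o] = r;
        }
      }
    }
  }
}
```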
@@ -366,11 +371,15 @@ void explicit_gemm_conv_1D_cpu(
     const std::vector<int>& wt_dilation) {
   const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
   const int iH = in.shape(1); // Input spatial dim
+  const int C = in.shape(2); // Input channels
   const int oH = out.shape(1); // Output spatial dim
   const int O = wt.shape(0); // Out channels
-  const int C = wt.shape(2); // In channels
   const int wH = wt.shape(1); // Weight spatial dim
 
+  const int groups = C / wt.shape(2);
+  const int C_per_group = wt.shape(2);
+  const int O_per_group = O / groups;
+
   auto conv_dtype = float32;
 
   // Pad input
@@ -402,6 +411,11 @@ void explicit_gemm_conv_1D_cpu(
       in_padded.strides()[1],
       in_padded.strides()[2]};
   auto flags = in_padded.flags();
+  if (groups > 1) {
+    // Transpose the last two dimensions for grouped convolutions
+    std::swap(strided_shape[2], strided_shape[3]);
+    std::swap(strided_strides[2], strided_strides[3]);
+  }
 
   array in_strided_view(strided_shape, in_padded.dtype(), nullptr, {});
   in_strided_view.copy_shared_buffer(
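The swap matters because the unfolded view is flattened into an `(N * oH) x (wH * C)` matrix for the GEMM. In the default `(wH, C)` ordering each group's channels are interleaved across the K dimension; after swapping to `(C, wH)`, group `g` owns one contiguous slab of `C_per_group * wH` entries, which is what lets the per-group GEMM below address it with a plain pointer offset. A sketch of the index arithmetic (illustrative, not MLX code):

```cpp
// Position of (channel c, filter tap wh) inside one unfolded row of K = wH * C:
int k_interleaved = wh * C + c; // (wH, C) order: a group's channels are scattered
int k_grouped = c * wH + wh;    // (C, wH) order: group g is the contiguous range
                                // [g * C_per_group * wH, (g + 1) * C_per_group * wH)
```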
@@ -416,7 +430,19 @@ void explicit_gemm_conv_1D_cpu(
   auto gemm_wt = wt;
   auto gemm_out = out;
 
-  if (wt.dtype() != float32 || !wt.flags().row_contiguous) {
+  if (groups > 1) {
+    // Transpose the last two dimensions for grouped convolutions
+    array wt_transpose(
+        {wt.shape(0), wt.shape(2), wt.shape(1)}, wt.dtype(), nullptr, {});
+    wt_transpose.copy_shared_buffer(
+        wt,
+        {wt.strides(0), wt.strides(2), wt.strides(1)},
+        wt.flags(),
+        wt.size(),
+        0);
+    gemm_wt = array(wt_transpose.shape(), float32, nullptr, {});
+    copy(wt_transpose, gemm_wt, CopyType::General);
+  } else if (wt.dtype() != float32 || !wt.flags().row_contiguous) {
     auto ctype =
         wt.flags().row_contiguous ? CopyType::Vector : CopyType::General;
     gemm_wt = array(wt.shape(), float32, nullptr, {});
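The weights get the matching treatment: each filter is stored as `(wH, C_per_group)`, but the unfolded input rows are now channel-major, so the filter is viewed as `(C_per_group, wH)` and copied to a contiguous float buffer. After both transposes, one output element is a plain dot product of two K-vectors enumerated in the same (channel, tap) order, e.g. (illustrative indices, not MLX code):

```cpp
// Flattened positions of (group-local channel c, tap wh) in the two operands:
int a_idx = g * C_per_group * wH + c * wH + wh; // unfolded input row, length wH * C
int b_idx = (o * C_per_group + c) * wH + wh;    // transposed weight row for filter o
```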
@@ -428,27 +454,29 @@ void explicit_gemm_conv_1D_cpu(
     gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes()));
   }
 
-  // Perform gemm
-  cblas_sgemm(
-      CblasRowMajor,
-      CblasNoTrans, // no trans A
-      CblasTrans, // transB
-      strided_reshape[0], // M
-      O, // N
-      strided_reshape[1], // K
-      1.0f, // alpha
-      in_strided.data<float>(),
-      strided_reshape[1], // lda
-      gemm_wt.data<float>(),
-      strided_reshape[1], // ldb
-      0.0f, // beta
-      gemm_out.data<float>(),
-      O // ldc
-  );
+  for (int g = 0; g < groups; ++g) {
+    // Perform gemm
+    cblas_sgemm(
+        CblasRowMajor,
+        CblasNoTrans, // no trans A
+        CblasTrans, // transB
+        strided_reshape[0], // M
+        O_per_group, // N
+        C_per_group * wH, // K
+        1.0f, // alpha
+        in_strided.data<float>() + g * C_per_group * wH, // A
+        wH * C, // lda
+        gemm_wt.data<float>() + g * O_per_group * C_per_group * wH, // B
+        wH * C_per_group, // ldb
+        0.0f, // beta
+        gemm_out.data<float>() + g * O_per_group, // C
+        O // ldc
+    );
 
-  // Copy results if needed
-  if (out.dtype() != float32) {
-    copy(gemm_out, out, CopyType::Vector);
+    // Copy results if needed
+    if (out.dtype() != float32) {
+      copy(gemm_out, out, CopyType::Vector);
+    }
   }
 }
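The leading-dimension choices encode the grouped layout: `A` walks full unfolded rows of length `wH * C` (hence `lda = wH * C`, with a `g * C_per_group * wH` column offset), each transposed filter is one contiguous row of `K = C_per_group * wH` values (hence `ldb`), and every group writes an `O_per_group`-wide column panel of the `M x O` output (hence `ldc = O`). A naive-loop restatement of the `g`-th call, as a minimal sketch with placeholder names (not MLX code):

```cpp
#include <vector>

// Computes the same panel as the g-th cblas_sgemm call above:
// out[m][g * O_per_group + n] = sum_k A[m][g * K + k] * B[g * O_per_group + n][k]
void group_gemm_ref(
    const std::vector<float>& A, // unfolded input, M x (wH * C)
    const std::vector<float>& B, // transposed weights, O x (C_per_group * wH)
    std::vector<float>& out,     // output, M x O
    int M, int O, int C, int wH, int C_per_group, int O_per_group, int g) {
  const int K = C_per_group * wH;
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < O_per_group; ++n) {
      float acc = 0.0f;
      for (int k = 0; k < K; ++k) {
        acc += A[m * (wH * C) + g * K + k] * // lda = wH * C
            B[(g * O_per_group + n) * K + k]; // ldb = K
      }
      out[m * O + g * O_per_group + n] = acc; // ldc = O
    }
  }
}
```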