diff --git a/benchmarks/python/conv1d_bench.py b/benchmarks/python/conv1d_bench.py
new file mode 100644
index 000000000..0306a3e08
--- /dev/null
+++ b/benchmarks/python/conv1d_bench.py
@@ -0,0 +1,123 @@
+import argparse
+import math
+import os
+import subprocess
+import time
+
+import mlx.core as mx
+import numpy as np
+import torch
+
+device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
+device_name = device_name.decode("utf-8").strip("\n")
+
+N_warmup = 10
+N_iter_bench = 100
+N_iter_func = 5
+
+
+def bench(f, a, b):
+    for i in range(N_warmup):
+        f(a, b)
+    torch.mps.synchronize()
+
+    s = time.perf_counter_ns()
+    for i in range(N_iter_bench):
+        f(a, b)
+    e = time.perf_counter_ns()
+    return (e - s) * 1e-9
+
+
+def make_mx_conv_1D(strides=1, padding=0, groups=1):
+    def mx_conv_1D(a, b):
+        ys = []
+        for _ in range(N_iter_func):
+            y = mx.conv1d(a, b, stride=strides, padding=padding, groups=groups)
+            ys.append(y)
+        mx.eval(ys)
+        return ys
+
+    return mx_conv_1D
+
+
+def make_pt_conv_1D(strides=1, padding=0, groups=1):
+    @torch.no_grad()
+    def pt_conv_1D(a, b):
+        ys = []
+        for _ in range(N_iter_func):
+            y = torch.conv1d(a, b, stride=strides, padding=padding, groups=groups)
+            ys.append(y)
+        torch.mps.synchronize()
+        return ys
+
+    return pt_conv_1D
+
+
+def bench_shape(N, iH, C, wH, O, strides, padding, np_dtype, groups):
+    scale = 1.0 / math.sqrt(wH * C)
+    a_np = np.random.uniform(0, 0.5, (N, iH, C)).astype(np_dtype)
+    b_np = np.random.uniform(-scale, scale, (O, wH, int(C / groups))).astype(np_dtype)
+
+    a_mx = mx.array(a_np)
+    b_mx = mx.array(b_np)
+
+    a_pt = torch.from_numpy(a_np.transpose((0, 2, 1))).to("mps")
+    b_pt = torch.from_numpy(b_np.transpose((0, 2, 1))).to("mps")
+
+    torch.mps.synchronize()
+
+    f_mx = make_mx_conv_1D(strides, padding, groups)
+    f_pt = make_pt_conv_1D(strides, padding, groups)
+
+    time_torch = bench(f_pt, a_pt, b_pt)
+    time_mlx = bench(f_mx, a_mx, b_mx)
+
+    out_mx = mx.conv1d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
+    out_pt = torch.conv1d(
+        a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
+    )
+    out_pt = torch.permute(out_pt, (0, 2, 1))
+    out_pt = out_pt.numpy(force=True)
+
+    atol = 2e-5 if np_dtype == np.float32 else 1e-4
+
+    if not np.allclose(out_pt, out_mx, atol=atol):
+        print(
+            f"Failed at {(N, iH, C)}, {(O, wH, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
+        )
+
+    return time_mlx, time_torch
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Run conv benchmarks")
+
+    dtypes = ("float32",)
+    shapes = (
+        (4, 32, 32, 5, 32, 1, 2, 1),
+        (4, 32, 32, 5, 32, 1, 2, 2),
+        (4, 32, 32, 5, 32, 1, 2, 4),
+        (4, 32, 32, 5, 32, 1, 2, 8),
+        (4, 32, 32, 5, 32, 1, 2, 8),
+        (4, 32, 32, 5, 32, 1, 2, 16),
+        (4, 32, 32, 5, 32, 1, 2, 32),
+        (4, 32, 256, 5, 512, 1, 2, 2),
+        (4, 32, 256, 5, 512, 1, 2, 128),
+        (4, 32, 256, 5, 512, 1, 2, 256),
+    )
+
+    for dtype in dtypes:
+        print("(N, iH, C), (O, wH, C), dtype, stride, pads, groups, diff%")
+        for N, iH, C, wH, O, strides, padding, groups in shapes:
+            np_dtype = getattr(np, dtype)
+            time_mlx, time_torch = bench_shape(
+                N, iH, C, wH, O, strides, padding, np_dtype, groups
+            )
+            diff = time_torch / time_mlx - 1.0
+
+            print(
+                f"({N}, {iH:3d}, {C:3d}), ({O:3d}, {wH:2d}, {C:3d}), {dtype}, {strides:5d}, {padding:4d}, {groups:6d}, {100. * diff:+5.2f}%"
+            )
+
+            if time_mlx >= 2.0 * time_torch:
+                print("ATTENTION ^^^^^^^")
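For reference (not part of the patch): the benchmark relies on MLX's channels-last layout, input (N, H, C) and weight (O, kH, C // groups), versus torch's channels-first, hence the transposes above. A minimal sketch of what the grouped semantics mean, with shapes chosen arbitrarily here, checked against running each group as an independent convolution:

import mlx.core as mx

N, H, C, O, kH, groups = 2, 16, 8, 16, 3, 4
x = mx.random.normal((N, H, C))
w = mx.random.normal((O, kH, C // groups))

y = mx.conv1d(x, w, stride=1, padding=1, groups=groups)

# Reference: slice the channels per group, convolve, and concatenate.
C_g, O_g = C // groups, O // groups
ys = [
    mx.conv1d(
        x[..., g * C_g : (g + 1) * C_g],
        w[g * O_g : (g + 1) * O_g],
        stride=1,
        padding=1,
    )
    for g in range(groups)
]
assert mx.allclose(y, mx.concatenate(ys, axis=-1), atol=1e-5).item()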
diff --git a/mlx/array.h b/mlx/array.h
index aeb76d9c8..42ebca35d 100644
--- a/mlx/array.h
+++ b/mlx/array.h
@@ -114,6 +114,15 @@ class array {
     return array_desc_->strides;
   };
 
+  /**
+   * Get the stride of the given dimension.
+   *
+   * This function supports negative indexing and provides
+   * bounds checking. */
+  size_t strides(int dim) const {
+    return strides().at(dim < 0 ? dim + ndim() : dim);
+  };
+
   /** Get the arrays data type. */
   Dtype dtype() const {
     return array_desc_->dtype;
diff --git a/mlx/backend/common/conv.cpp b/mlx/backend/common/conv.cpp
index 5a8495040..f3162c056 100644
--- a/mlx/backend/common/conv.cpp
+++ b/mlx/backend/common/conv.cpp
@@ -38,11 +38,15 @@ void slow_conv_1D(
   const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
   const int iH = 1 + in_dilation[0] * (in.shape(1) - 1); // Input spatial dim
+  const int C = in.shape(2); // Input channels
   const int oH = out.shape(1); // Output spatial dim
   const int O = wt.shape(0); // Out channels
-  const int C = wt.shape(2); // In channels
   const int wH = wt.shape(1); // Weight spatial dim
 
+  const int groups = C / wt.shape(2);
+  const int C_per_group = wt.shape(2);
+  const int O_per_group = O / groups;
+
   const size_t in_stride_N = in.strides()[0];
   const size_t in_stride_H = in.strides()[1];
   const size_t in_stride_C = in.strides()[2];
@@ -57,35 +61,36 @@ void slow_conv_1D(
 
   for (int n = 0; n < N; ++n) {
     for (int oh = 0; oh < oH; ++oh) {
-      for (int o = 0; o < O; ++o) {
-        const T* filter_wt_ptr = start_wt_ptr + o * wt_stride_O;
-        float r = 0.;
+      for (int g = 0; g < groups; ++g) {
+        for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) {
+          const T* filter_wt_ptr = start_wt_ptr + o * wt_stride_O;
+          float r = 0.;
 
-        for (int wh = 0; wh < wH; ++wh) {
-          const T* wt_ptr = filter_wt_ptr + wh * wt_stride_H;
+          for (int wh = 0; wh < wH; ++wh) {
+            const T* wt_ptr = filter_wt_ptr + wh * wt_stride_H;
 
-          int wh_flip = flip ? (wH - wh - 1) : wh;
-          int ih = oh * wt_strides[0] - padding[0] + wh_flip * wt_dilation[0];
+            int wh_flip = flip ? (wH - wh - 1) : wh;
+            int ih = oh * wt_strides[0] - padding[0] + wh_flip * wt_dilation[0];
 
-          auto ih_div = std::div(ih, in_dilation[0]);
+            auto ih_div = std::div(ih, in_dilation[0]);
 
-          if (ih >= 0 && ih < iH && ih_div.rem == 0) {
-            for (int c = 0; c < C; ++c) {
-              r += static_cast<float>(
-                       in_ptr[ih_div.quot * in_stride_H + c * in_stride_C]) *
-                  static_cast<float>(wt_ptr[c * wt_stride_C]);
-            } // c
+            if (ih >= 0 && ih < iH && ih_div.rem == 0) {
+              for (int c = g * C_per_group; c < (g + 1) * C_per_group; ++c) {
+                r += static_cast<float>(
+                         in_ptr[ih_div.quot * in_stride_H + c * in_stride_C]) *
+                    static_cast<float>(wt_ptr[(c % C_per_group) * wt_stride_C]);
+              } // c
 
-          } // ih check
-        } // wh
+            } // ih check
+          } // wh
 
-        out_ptr[oh * out_stride_H + o * out_stride_O] = static_cast<T>(r);
-      } // o
+          out_ptr[oh * out_stride_H + o * out_stride_O] = static_cast<T>(r);
+        } // o
+      } // g
     } // oh
 
     in_ptr += in_stride_N;
     out_ptr += out_stride_N;
-
   } // n
 }
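For reference (not part of the patch): the grouped indexing above is easier to read without the stride and dilation bookkeeping. A pure NumPy sketch, assuming stride 1, no padding, no dilation, and no flip, with names mirroring the C++:

import numpy as np

def slow_conv_1d_grouped(x, w, groups):
    N, iH, C = x.shape            # input is (N, iH, C)
    O, wH, C_per_group = w.shape  # weight is (O, wH, C / groups)
    O_per_group = O // groups
    oH = iH - wH + 1
    out = np.zeros((N, oH, O), x.dtype)
    for n in range(N):
        for oh in range(oH):
            for g in range(groups):
                # Output channel o in group g only sees input channels
                # [g * C_per_group, (g + 1) * C_per_group); the weight is
                # indexed with c % C_per_group, exactly as in the C++ loop.
                for o in range(g * O_per_group, (g + 1) * O_per_group):
                    r = 0.0
                    for wh in range(wH):
                        for c in range(g * C_per_group, (g + 1) * C_per_group):
                            r += x[n, oh + wh, c] * w[o, wh, c % C_per_group]
                    out[n, oh, o] = r
    return out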
@@ -366,11 +371,15 @@ void explicit_gemm_conv_1D_cpu(
     const std::vector<int>& wt_dilation) {
   const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
   const int iH = in.shape(1); // Input spatial dim
+  const int C = in.shape(2); // Input channels
   const int oH = out.shape(1); // Output spatial dim
   const int O = wt.shape(0); // Out channels
-  const int C = wt.shape(2); // In channels
   const int wH = wt.shape(1); // Weight spatial dim
 
+  const int groups = C / wt.shape(2);
+  const int C_per_group = wt.shape(2);
+  const int O_per_group = O / groups;
+
   auto conv_dtype = float32;
 
   // Pad input
@@ -402,6 +411,11 @@ void explicit_gemm_conv_1D_cpu(
       in_padded.strides()[1],
       in_padded.strides()[2]};
   auto flags = in_padded.flags();
+  if (groups > 1) {
+    // Transpose the last two dimensions for grouped convolutions
+    std::swap(strided_shape[2], strided_shape[3]);
+    std::swap(strided_strides[2], strided_strides[3]);
+  }
 
   array in_strided_view(strided_shape, in_padded.dtype(), nullptr, {});
   in_strided_view.copy_shared_buffer(
@@ -416,7 +430,19 @@ void explicit_gemm_conv_1D_cpu(
   auto gemm_wt = wt;
   auto gemm_out = out;
 
-  if (wt.dtype() != float32 || !wt.flags().row_contiguous) {
+  if (groups > 1) {
+    // Transpose the last two dimensions for grouped convolutions
+    array wt_transpose(
+        {wt.shape(0), wt.shape(2), wt.shape(1)}, wt.dtype(), nullptr, {});
+    wt_transpose.copy_shared_buffer(
+        wt,
+        {wt.strides(0), wt.strides(2), wt.strides(1)},
+        wt.flags(),
+        wt.size(),
+        0);
+    gemm_wt = array(wt_transpose.shape(), float32, nullptr, {});
+    copy(wt_transpose, gemm_wt, CopyType::General);
+  } else if (wt.dtype() != float32 || !wt.flags().row_contiguous) {
    auto ctype =
        wt.flags().row_contiguous ? CopyType::Vector : CopyType::General;
    gemm_wt = array(wt.shape(), float32, nullptr, {});
@@ -428,27 +454,29 @@ void explicit_gemm_conv_1D_cpu(
     gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes()));
   }
 
-  // Perform gemm
-  cblas_sgemm(
-      CblasRowMajor,
-      CblasNoTrans, // no trans A
-      CblasTrans, // transB
-      strided_reshape[0], // M
-      O, // N
-      strided_reshape[1], // K
-      1.0f, // alpha
-      in_strided.data<float>(),
-      strided_reshape[1], // lda
-      gemm_wt.data<float>(),
-      strided_reshape[1], // ldb
-      0.0f, // beta
-      gemm_out.data<float>(),
-      O // ldc
-  );
+  for (int g = 0; g < groups; ++g) {
+    // Perform gemm
+    cblas_sgemm(
+        CblasRowMajor,
+        CblasNoTrans, // no trans A
+        CblasTrans, // transB
+        strided_reshape[0], // M
+        O_per_group, // N
+        C_per_group * wH, // K
+        1.0f, // alpha
+        in_strided.data<float>() + g * C_per_group * wH, // A
+        wH * C, // lda
+        gemm_wt.data<float>() + g * O_per_group * C_per_group * wH, // B
+        wH * C_per_group, // ldb
+        0.0f, // beta
+        gemm_out.data<float>() + g * O_per_group, // C
+        O // ldc
+    );
 
-  // Copy results if needed
-  if (out.dtype() != float32) {
-    copy(gemm_out, out, CopyType::Vector);
+    // Copy results if needed
+    if (out.dtype() != float32) {
+      copy(gemm_out, out, CopyType::Vector);
+    }
   }
 }
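For reference (not part of the patch): the groups > 1 branch swaps the last two shape/stride entries so the unfolded input reads as (N, oH, C, wH) rather than (N, oH, wH, C) without moving any data; the general copy then materializes rows in which each group's C_per_group * wH columns are contiguous, which is what the g * C_per_group * wH pointer offsets into A and B above rely on. A NumPy analogue of the transpose-by-view trick:

import numpy as np
from numpy.lib.stride_tricks import as_strided

# A toy (N, oH, wH, C) unfolded input.
unfolded = np.arange(2 * 3 * 4 * 5, dtype=np.float32).reshape(2, 3, 4, 5)
s = unfolded.strides

# Swap the last two shape and stride entries, as the code above does.
view = as_strided(unfolded, shape=(2, 3, 5, 4), strides=(s[0], s[1], s[3], s[2]))
assert np.array_equal(view, unfolded.transpose(0, 1, 3, 2))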
diff --git a/mlx/backend/metal/conv.cpp b/mlx/backend/metal/conv.cpp
index f2b3553f7..c8fd95c1a 100644
--- a/mlx/backend/metal/conv.cpp
+++ b/mlx/backend/metal/conv.cpp
@@ -89,6 +89,90 @@ void explicit_gemm_conv_ND_gpu(
       /*copies = */ copies);
 }
 
+template <int N>
+void explicit_gemm_conv_group_ND_gpu(
+    const Stream& s,
+    metal::Device& d,
+    const array& in,
+    const array& wt,
+    array out,
+    const MLXConvParams<N>& conv_params) {
+  const int groups = conv_params.groups;
+  const int C_per_group = conv_params.C / conv_params.groups;
+  const int O_per_group = conv_params.O / conv_params.groups;
+
+  // Get gemm shapes
+  const int implicit_M = out.size() / conv_params.O;
+  const int implicit_K = wt.size() / conv_params.O;
+  const int implicit_N = O_per_group;
+
+  int kernel_size = 1;
+  for (int i = 0; i < N; ++i) {
+    kernel_size *= conv_params.wS[i];
+  }
+
+  // Prepare unfolding array
+  std::vector<int> unfolded_shape{implicit_M, implicit_K * groups};
+  array in_unfolded(unfolded_shape, in.dtype(), nullptr, {});
+  in_unfolded.set_data(allocator::malloc_or_wait(in_unfolded.nbytes()));
+
+  // Prepare unfolding kernel
+  std::ostringstream kname;
+  kname << "naive_unfold_transpose_nd_" << type_to_name(in_unfolded) << "_"
+        << N;
+  auto& compute_encoder = d.get_command_encoder(s.index);
+  auto kernel = d.get_kernel(kname.str());
+  compute_encoder->setComputePipelineState(kernel);
+
+  compute_encoder.set_input_array(in, 0);
+  compute_encoder.set_output_array(in_unfolded, 1);
+
+  compute_encoder->setBytes(&conv_params, sizeof(conv_params), 2);
+
+  // Launch unfolding kernel
+  int tgp_x = std::min(conv_params.C, 64);
+  tgp_x = 32 * ((tgp_x + 32 - 1) / 32);
+  int tgp_y = 256 / tgp_x;
+
+  MTL::Size group_dims = MTL::Size(tgp_x, tgp_y, 1);
+  MTL::Size grid_dims = MTL::Size(
+      conv_params.C, unfolded_shape[1] / conv_params.C, unfolded_shape[0]);
+
+  compute_encoder->dispatchThreads(grid_dims, group_dims);
+
+  // Transpose kernel weights so that we can slice them by contiguous chunks
+  // of channel groups.
+  array wt_view(
+      {wt.shape(0), C_per_group, kernel_size}, wt.dtype(), nullptr, {});
+  wt_view.copy_shared_buffer(
+      wt,
+      {wt.strides(0), 1, static_cast<size_t>(C_per_group)},
+      wt.flags(),
+      wt.size());
+
+  // Materialize
+  auto wt_transpose = array(wt_view.shape(), wt_view.dtype(), nullptr, {});
+  copy_gpu(wt_view, wt_transpose, CopyType::General, s);
+
+  // Perform gemm
+  std::vector<array> copies = {in_unfolded, wt_view, wt_transpose};
+  return steel_matmul_conv_groups(
+      s,
+      d,
+      /*a = */ in_unfolded,
+      /*b = */ wt_transpose,
+      /*c = */ out,
+      /*M = */ implicit_M,
+      /*N = */ implicit_N,
+      /*K = */ implicit_K,
+      /*a_cols = */ implicit_K * groups,
+      /*b_cols = */ implicit_K,
+      /*out_cols = */ implicit_N * groups,
+      /*a_transposed = */ false,
+      /*b_transposed = */ true,
+      /* groups = */ groups,
+      /*copies = */ copies);
+}
+
 void conv_1D_gpu(
     const Stream& s,
     metal::Device& d,
@@ -99,6 +183,7 @@ void conv_1D_gpu(
     const std::vector<int>& wt_strides,
     const std::vector<int>& wt_dilation,
     const std::vector<int>& in_dilation,
+    int groups,
     bool flip) {
   // Make conv params
   MLXConvParams<1> conv_params{
@@ -118,11 +203,15 @@ void conv_1D_gpu(
       {wt.strides()[0], wt.strides()[1], wt.strides()[2]},
       /* const size_t out_strides[NDIM + 2] = */
       {out.strides()[0], out.strides()[1], out.strides()[2]},
-      /* const int groups = */ 1,
+      /* const int groups = */ groups,
       /* const bool flip = */ flip};
 
   // Direct to explicit gemm conv
-  return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params);
+  if (groups > 1) {
+    return explicit_gemm_conv_group_ND_gpu(s, d, in, wt, out, conv_params);
+  } else {
+    return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params);
+  }
 }
 
 void slow_conv_2D_gpu(
@@ -721,6 +810,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
         kernel_strides_,
         kernel_dilation_,
         input_dilation_,
+        groups_,
         flip_);
   }
   // Throw error
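For reference (not part of the patch): the implicit GEMM sizes computed above, worked through for one of the benchmark shapes. This is plain arithmetic, not an MLX call:

# N=4, iH=32, C=32, wH=5, O=32, stride=1, padding=2, groups=4
N, iH, C, wH, O, stride, padding, groups = 4, 32, 32, 5, 32, 1, 2, 4
oH = (iH + 2 * padding - wH) // stride + 1  # 32 output positions
implicit_M = N * oH                         # out.size() / O = 128
implicit_K = wH * (C // groups)             # wt.size() / O  = 40
implicit_N = O // groups                    # O_per_group    = 8
a_cols = implicit_K * groups                # unfolded row width = 160
out_cols = implicit_N * groups              # == O           = 32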
diff --git a/mlx/backend/metal/kernels/conv.metal b/mlx/backend/metal/kernels/conv.metal
index b977876ff..563002997 100644
--- a/mlx/backend/metal/kernels/conv.metal
+++ b/mlx/backend/metal/kernels/conv.metal
@@ -33,7 +33,7 @@ template <typename T, int N>
   // Set out
   out += gid.z * filter_size + gid.y * (params->C);
 
-  // Corrdinates in input
+  // Coordinates in input
   int is[N] = {0};
 
   // gid.z: N oS (Batch and row in unfolded output)
@@ -75,12 +75,81 @@ template <typename T, int N>
   } else {
     out[gid.x] = T(0);
   }
+}
 
+// This kernel unfolds the input array of size (N, *spatial_dims, C)
+// into an array of size (N x *spatial_dims, C x *kernel_dims).
+template <typename T, int N>
+[[kernel]] void naive_unfold_transpose_Nd(
+    const device T* in [[buffer(0)]],
+    device T* out [[buffer(1)]],
+    const constant MLXConvParams<N>* params [[buffer(2)]],
+    uint3 gid [[thread_position_in_grid]]) {
+  int filter_size = params->C;
+  for (short i = 0; i < N; i++)
+    filter_size *= params->wS[i];
+
+  int out_pixels = 1;
+  for (short i = 0; i < N; i++)
+    out_pixels *= params->oS[i];
+
+  // Set out
+  out += gid.z * filter_size + gid.x * (filter_size / params->C);
+
+  // Coordinates in input
+  int is[N] = {0};
+
+  // gid.z: N oS (Batch and row in unfolded output)
+  // gid.y: wS (Filter location to unfold input)
+  // gid.x: C (channel)
+
+  int n = (gid.z) / out_pixels;
+  int oS = (gid.z) % out_pixels;
+  int wS = gid.y;
+
+  bool valid = n < params->N;
+
+  // Unroll dimensions
+  for (int i = N - 1; i >= 0; --i) {
+    int os_ = (oS % params->oS[i]);
+    int ws_ = (wS % params->wS[i]);
+
+    ws_ = params->flip ? params->wS[i] - ws_ - 1 : ws_;
+
+    int is_ = os_ * params->str[i] - params->pad[i] + ws_ * params->kdil[i];
+    int is_max = 1 + params->idil[i] * (params->iS[i] - 1);
+
+    valid &= is_ >= 0 && is_ < is_max && (is_ % params->idil[i] == 0);
+
+    is[i] = is_ / params->idil[i];
+
+    oS /= params->oS[i];
+    wS /= params->wS[i];
+
+    out += ws_ * params->str[i];
+  }
+
+  if (valid) {
+    size_t in_offset = n * params->in_strides[0];
+
+    for (int i = 0; i < N; ++i) {
+      in_offset += is[i] * params->in_strides[i + 1];
+    }
+
+    out[0] = in[in_offset + gid.x];
+  } else {
+    out[0] = T(0);
+  }
 }
+
 #define instantiate_naive_unfold_nd(name, itype, n) \
   template [[host_name("naive_unfold_nd_" #name "_" #n)]] \
   [[kernel]] void naive_unfold_Nd<itype, n>( \
+      const device itype* in [[buffer(0)]], \
+      device itype* out [[buffer(1)]], \
+      const constant MLXConvParams<n>* params [[buffer(2)]], \
+      uint3 gid [[thread_position_in_grid]]); \
+  template [[host_name("naive_unfold_transpose_nd_" #name "_" #n)]] \
+  [[kernel]] void naive_unfold_transpose_Nd<itype, n>( \
       const device itype* in [[buffer(0)]], \
       device itype* out [[buffer(1)]], \
       const constant MLXConvParams<n>* params [[buffer(2)]], \
       uint3 gid [[thread_position_in_grid]]);
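For reference (not part of the patch): what the new unfold-transpose kernel produces in the 1D case, as a NumPy sketch assuming stride 1, no padding, no dilation, and no flip. Each row is one (batch, output position) pair, and within a row the C channel blocks each hold the wH kernel taps contiguously, matching the column offset gid.x * (filter_size / C) + ws_ in the kernel:

import numpy as np

def unfold_transpose_1d(x, wH):
    N, iH, C = x.shape
    oH = iH - wH + 1
    out = np.zeros((N * oH, C * wH), x.dtype)
    for n in range(N):
        for oh in range(oH):
            for c in range(C):
                for wh in range(wH):
                    out[n * oH + oh, c * wH + wh] = x[n, oh + wh, c]
    return out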
diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp
index fb72b8bcd..3e550cd3d 100644
--- a/mlx/backend/metal/matmul.cpp
+++ b/mlx/backend/metal/matmul.cpp
@@ -260,6 +260,110 @@ inline auto collapse_batches(const array& a, const array& b, const array& c) {
 // Steel matmul fallback
 ///////////////////////////////////////////////////////////////////////////////
 
+void steel_matmul_conv_groups(
+    const Stream& s,
+    metal::Device& d,
+    const array& a,
+    const array& b,
+    array& out,
+    int M,
+    int N,
+    int K,
+    int lda,
+    int ldb,
+    int ldd,
+    bool transpose_a,
+    bool transpose_b,
+    int groups,
+    std::vector<array>& copies) {
+  using namespace mlx::steel;
+
+  /////////////////////////////////////////////////////////////////////////////
+  // Regular kernel dispatch
+
+  // Determine dispatch kernel
+  int bm = 32, bn = 32, bk = 16;
+  int wm = 2, wn = 2;
+
+  if ((size_t)M * N >= 1ul << 20) {
+    if (!transpose_a && transpose_b) {
+      bm = 64;
+      bn = (out.dtype() == float32) ? 64 : 32;
+      bk = (out.dtype() == float32) ? 16 : 32;
+    } else {
+      bm = 64;
+      bn = 64;
+    }
+  }
+
+  // Prepare kernel name
+  std::ostringstream kname;
+  kname << "steel_gemm_" << (transpose_a ? 't' : 'n')
+        << (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
+        << type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk
+        << "_wm" << wm << "_wn" << wn << "_MN_"
+        << ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned" << "_K_"
+        << ((K % bk == 0) ? "t" : "n") << "aligned";
+
+  // Encode and dispatch kernel
+  auto& compute_encoder = d.get_command_encoder(s.index);
+  auto kernel = d.get_kernel(kname.str());
+  compute_encoder->setComputePipelineState(kernel);
+
+  // Use problem size to determine threadblock swizzle
+  int tn = (N + bn - 1) / bn;
+  int tm = (M + bm - 1) / bm;
+
+  // TODO: Explore device-based tuning for swizzle
+  int swizzle_log = 0; // tm >= 6 ? 3 : (tm <= 3 ? 0 : 2);
+
+  // Prepare steel matmul params
+  GEMMParams params{
+      /* const int M = */ M,
+      /* const int N = */ N,
+      /* const int K = */ K,
+      /* const int lda = */ lda,
+      /* const int ldb = */ ldb,
+      /* const int ldd = */ ldd,
+      /* const int tiles_n = */ tn,
+      /* const int tiles_m = */ tm,
+      /* const int batch_stride_a = */ K,
+      /* const int batch_stride_b = */ N * K,
+      /* const int batch_stride_d = */ N,
+      /* const int swizzle_log = */ swizzle_log,
+      /* const int gemm_k_iterations_aligned = */ (K / bk),
+      /* const int batch_ndim = */ 1};
+
+  // Prepare launch grid params
+  int tile = 1 << swizzle_log;
+  tm = (tm + tile - 1) / tile;
+  tn = tn * tile;
+
+  MTL::Size group_dims = MTL::Size(32, wn, wm);
+  MTL::Size grid_dims = MTL::Size(tn, tm, groups);
+
+  std::vector<int> batch_shape = {1};
+  std::vector<size_t> batch_strides = {0};
+
+  // Launch kernel
+  compute_encoder.set_input_array(a, 0);
+  compute_encoder.set_input_array(b, 1);
+  compute_encoder.set_output_array(out, 3);
+  compute_encoder->setBytes(&params, sizeof(GEMMParams), 4);
+
+  compute_encoder->setBytes(
+      batch_shape.data(), sizeof(int) * batch_shape.size(), 6);
+  compute_encoder->setBytes(
+      batch_strides.data(), sizeof(size_t) * batch_strides.size(), 7);
+
+  compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+
+  // Clear copies
+  d.get_command_buffer(s.index)->addCompletedHandler(
+      [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
+  return;
+}
+
 void steel_matmul(
     const Stream& s,
     metal::Device& d,
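For reference (not part of the patch, and inferred from the GEMMParams above): the group index rides the z dimension of the launch grid, and the "batch" strides K, N * K, and N locate each group's operands. Group g reads A columns [g*K, (g+1)*K), B rows [g*N, (g+1)*N) (B is (groups * N, K), used transposed), and writes D columns [g*N, (g+1)*N). A NumPy analogue:

import numpy as np

def grouped_gemm(a, b, groups):
    M, KG = a.shape            # a: (M, K * groups), the unfolded input
    K = KG // groups
    Ng = b.shape[0] // groups  # b: (groups * Ng, K), the transposed weights
    d = np.zeros((M, Ng * groups), a.dtype)
    for g in range(groups):
        d[:, g * Ng : (g + 1) * Ng] = (
            a[:, g * K : (g + 1) * K] @ b[g * Ng : (g + 1) * Ng].T
        )
    return d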
diff --git a/mlx/backend/metal/matmul.h b/mlx/backend/metal/matmul.h
index a9c872235..bbe41ea8d 100644
--- a/mlx/backend/metal/matmul.h
+++ b/mlx/backend/metal/matmul.h
@@ -12,6 +12,23 @@
 
 namespace mlx::core {
 
+void steel_matmul_conv_groups(
+    const Stream& s,
+    metal::Device& d,
+    const array& a,
+    const array& b,
+    array& out,
+    int M,
+    int N,
+    int K,
+    int lda,
+    int ldb,
+    int ldd,
+    bool transpose_a,
+    bool transpose_b,
+    int groups,
+    std::vector<array>& copies);
+
 void steel_matmul(
     const Stream& s,
     metal::Device& d,
diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index 1ec08fa56..5ef3f4724 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -320,7 +320,7 @@ array reshape(
         "[reshape] Cannot infer the shape of an empty array");
   }
 
-  // Check the the reshaping is valid
+  // Check that the reshaping is valid
   if (a.size() != size) {
     std::ostringstream msg;
     msg << "[reshape] Cannot reshape array of size " << a.size()
@@ -2947,7 +2947,8 @@ inline std::vector<int> conv_out_shape(
   return out_shape;
 }
 
-inline void run_conv_checks(const array& in, const array& wt, int n_dim) {
+inline void
+run_conv_checks(const array& in, const array& wt, int n_dim, int groups) {
   if (!issubdtype(in.dtype(), floating)) {
     std::ostringstream msg;
     msg << "[conv] Invalid input array with type " << in.dtype() << "."
@@ -2972,11 +2973,35 @@ inline void run_conv_checks(const array& in, const array& wt, int n_dim) {
     throw std::invalid_argument(msg.str());
   }
 
-  if (in.shape(n_dim + 1) != wt.shape(n_dim + 1)) {
+  if (in.shape(n_dim + 1) % groups != 0) {
     std::ostringstream msg;
-    msg << "[conv] Expect the input channels in the input"
-        << " and weight array to match but got shapes -"
-        << " input: " << in.shape() << " and weight: " << wt.shape();
+    msg << "[conv] The input channels must be divisible by the number"
+        << " of groups. Got input with shape " << in.shape() << " and "
+        << groups << " groups.";
+    throw std::invalid_argument(msg.str());
+  }
+
+  if (groups > 1 && wt.shape(0) % groups != 0) {
+    std::ostringstream msg;
+    msg << "[conv] If groups > 1, the output channels must be divisible by"
+        << " the number of groups. Got " << wt.shape(0)
+        << " output channels and " << groups << " groups.";
+    throw std::invalid_argument(msg.str());
+  }
+
+  if (in.shape(n_dim + 1) != (groups * wt.shape(n_dim + 1))) {
+    std::ostringstream msg;
+    if (groups == 1) {
+      msg << "[conv] Expect the input channels in the input"
+          << " and weight array to match but got shapes -"
+          << " input: " << in.shape() << " and weight: " << wt.shape();
+    } else {
+      msg << "[conv] Given groups=" << groups << " and weights of shape "
+          << wt.shape() << ", expected to have "
+          << (groups * wt.shape(n_dim + 1)) << " input channels but got "
+          << in.shape(n_dim + 1) << " input channels instead.";
+    }
     throw std::invalid_argument(msg.str());
   }
 }
@@ -3039,8 +3064,9 @@ array conv_general(
     bool flip /* = false */,
     StreamOrDevice s /* = {} */) {
   // Run checks
-  if (groups != 1) {
-    throw std::invalid_argument("[conv] Cannot handle groups != 1 yet");
+  if (groups != 1 && in.ndim() != 3) {
+    throw std::invalid_argument(
+        "[conv] Can only handle groups != 1 in 1D convolutions.");
   }
 
   int spatial_dims = in.ndim() - 2;
@@ -3052,7 +3078,7 @@ array conv_general(
   }
 
   // Run checks
-  run_conv_checks(in, wt, spatial_dims);
+  run_conv_checks(in, wt, spatial_dims, groups);
 
   // Type promotion
   auto out_type = promote_types(in.dtype(), wt.dtype());
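For reference (not part of the patch): a quick exercise of the new checks from Python, with hypothetical shapes. The std::invalid_argument raised here surfaces in Python as a ValueError:

import mlx.core as mx

x = mx.zeros((1, 10, 6))  # 6 input channels
w = mx.zeros((6, 3, 2))   # 6 filters, 2 channels each

mx.conv1d(x, w, groups=3)  # ok: 6 % 3 == 0 and 6 == 3 * 2

try:
    mx.conv1d(x, w, groups=4)  # 6 input channels are not divisible by 4
except ValueError as e:
    print(e)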
diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp
index 8543daa22..85e56ac54 100644
--- a/mlx/primitives.cpp
+++ b/mlx/primitives.cpp
@@ -831,6 +831,11 @@ std::vector<array> Convolution::vjp(
   assert(primals.size() == 2);
   std::vector<array> grads;
 
+  if (groups_ != 1) {
+    throw std::invalid_argument(
+        "[Convolution] Backward pass not implemented for groups > 1.");
+  }
+
   // Collect info
   auto& in = primals[0];
   auto& wt = primals[1];
diff --git a/python/tests/test_conv.py b/python/tests/test_conv.py
index ef1800638..8c5585126 100644
--- a/python/tests/test_conv.py
+++ b/python/tests/test_conv.py
@@ -77,7 +77,9 @@ class TestConv(mlx_tests.MLXTestCase):
             np_dtype = getattr(np, dtype)
             np.random.seed(0)
             in_np = np.random.normal(0, 1.0 / C, (N, iH, C)).astype(np_dtype)
-            wt_np = np.random.normal(0, 1.0 / C, (O, kH, C)).astype(np_dtype)
+            wt_np = np.random.normal(0, 1.0 / C, (O, kH, int(C / groups))).astype(
+                np_dtype
+            )
 
             in_mx, wt_mx = map(mx.array, (in_np, wt_np))
             in_pt, wt_pt = map(
@@ -119,6 +121,12 @@ class TestConv(mlx_tests.MLXTestCase):
         ):
             run_conv1D(N, C, O, iH, kH, stride, padding, dtype=dtype)
 
+        # Groups tests
+        N, C, O = (4, 32, 64)
+        iH, kH, stride, padding = (31, 5, 1, 2)
+        for group in (1, 2, 4, 8, 16, 32):
+            run_conv1D(N, C, O, iH, kH, stride=1, padding=1, groups=group, dtype=dtype)
+
         # Strided inputs tests
         for tpose_in, tpose_wt in (
            ((0, 2, 1), (0, 1, 2)),
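For reference (not part of the patch): one way to reproduce the expected values in the groups=2 case of the C++ test below, using PyTorch as the reference, as the Python tests do. MLX is channels-last (input (N, H, C), weight (O, kH, C / groups)) while torch is channels-first, hence the transposes; agreement is up to float16 rounding of the input:

import numpy as np
import torch

in_np = np.array(
    [0.5488135, 0.71518937, 0.60276338, 0.54488318, 0.4236548, 0.64589411],
    dtype=np.float32,
).reshape(1, 3, 2)
wt_np = np.array(
    [0.43758721, 0.891773, 0.96366276, 0.38344152, 0.79172504, 0.52889492,
     0.56804456, 0.92559664, 0.07103606, 0.0871293, 0.0202184, 0.83261985],
    dtype=np.float32,
).reshape(4, 3, 1)

# The C++ test casts the input to float16 first.
in_np = in_np.astype(np.float16).astype(np.float32)

out = torch.conv1d(
    torch.from_numpy(in_np.transpose(0, 2, 1)),
    torch.from_numpy(wt_np.transpose(0, 2, 1)),
    stride=1,
    padding=1,
    groups=2,
)
print(out.permute(0, 2, 1))  # ~ the expected (1, 3, 4) array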
diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp
index 26abfcfb4..eb49cbc46 100644
--- a/tests/ops_tests.cpp
+++ b/tests/ops_tests.cpp
@@ -3228,3 +3228,102 @@ TEST_CASE("test meshgrid") {
   CHECK(array_equal(out[0], expected_zero).item<bool>());
   CHECK(array_equal(out[1], expected_one).item<bool>());
 }
+
+TEST_CASE("test conv1d") {
+  auto in = astype(
+      array(
+          {0.5488135,
+           0.71518937,
+           0.60276338,
+           0.54488318,
+           0.4236548,
+           0.64589411},
+          {1, 3, 2}),
+      float16);
+
+  int kernel = 3;
+  int stride = 1;
+  int padding = 1;
+
+  {
+    int groups = 1;
+    auto wt = astype(
+        array(
+            {
+
+                0.43758721, 0.891773,   0.96366276, 0.38344152,
+                0.79172504, 0.52889492,
+
+                0.56804456, 0.92559664, 0.07103606, 0.0871293,
+                0.0202184,  0.83261985,
+
+                0.77815675, 0.87001215, 0.97861834, 0.79915856,
+                0.46147936, 0.78052918,
+
+                0.11827443, 0.63992102, 0.14335329, 0.94466892,
+                0.52184832, 0.41466194
+
+            },
+            {4, 3, 2}),
+        float16);
+
+    auto expected = array(
+        {1.5685,
+         0.5672,
+         1.8121,
+         1.2948,
+         2.3448,
+         1.6104,
+         2.7743,
+         1.6126,
+         1.4056,
+         0.9331,
+         1.8739,
+         1.0909},
+        {1, 3, 4});
+
+    auto out = conv1d(in, wt, stride, padding, /* dilation= */ 1, groups);
+    CHECK(allclose(out, expected, /* rtol = */ 1.0e-3).item<bool>());
+  }
+
+  {
+    int groups = 2;
+    auto wt = array(
+        {0.43758721,
+         0.891773,
+         0.96366276,
+
+         0.38344152,
+         0.79172504,
+         0.52889492,
+
+         0.56804456,
+         0.92559664,
+         0.07103606,
+
+         0.0871293,
+         0.0202184,
+         0.83261985
+
+        },
+        {4, 3, 1});
+
+    auto expected = array(
+        {1.0703,
+         0.7533,
+         0.7007,
+         0.4681,
+         1.1859,
+         0.9117,
+         0.9565,
+         0.6111,
+         0.6416,
+         0.5665,
+         0.9074,
+         0.0605},
+        {1, 3, 4});
+
+    auto out = conv1d(in, wt, stride, padding, /* dilation= */ 1, groups);
+    CHECK(allclose(out, expected, /* rtol = */ 1.0e-3).item<bool>());
+  }
+}
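For reference (not part of the patch): the guard added in Convolution::vjp means gradients through grouped convolutions raise for now. A quick repro with hypothetical shapes:

import mlx.core as mx

x = mx.random.normal((1, 8, 4))
w = mx.random.normal((8, 3, 2))

def loss(x, w):
    return mx.conv1d(x, w, padding=1, groups=2).sum()

try:
    mx.grad(loss, argnums=(0, 1))(x, w)
except ValueError as e:
    print(e)  # [Convolution] Backward pass not implemented for groups > 1.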