From e4534dac1742ab29114764e1f779d55c6a90392a Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Sun, 6 Oct 2024 07:08:53 -0700
Subject: [PATCH] Conv grad with groups + bugfix (#1449)

* fix bug in flipped conv with groups, start of grad for groups

* fix

* fix

* fix + test
---
 mlx/backend/metal/conv.cpp           |  35 +++---
 mlx/backend/metal/kernels/conv.metal |   2 +-
 mlx/backend/metal/matmul.cpp         | 174 ++++++---------------------
 mlx/backend/metal/matmul.h           |   9 +-
 mlx/primitives.cpp                   |  53 +++++---
 python/tests/test_conv.py            | 100 +++++++++++++++
 6 files changed, 197 insertions(+), 176 deletions(-)

diff --git a/mlx/backend/metal/conv.cpp b/mlx/backend/metal/conv.cpp
index 71662c8ae..a7a62644d 100644
--- a/mlx/backend/metal/conv.cpp
+++ b/mlx/backend/metal/conv.cpp
@@ -72,7 +72,7 @@ void explicit_gemm_conv_ND_gpu(
   wt_reshaped.copy_shared_buffer(wt, wt_restride, wt_flags, wt.data_size());
 
   // Perform gemm
-  std::vector<array> copies = {in_unfolded, wt_reshaped};
+  std::vector<array> copies = {in_unfolded};
   return steel_matmul(
       s,
       d,
@@ -155,22 +155,27 @@ void explicit_gemm_conv_group_ND_gpu(
   copy_gpu(wt_view, wt_transpose, CopyType::General, s);
 
   // Perform gemm
-  std::vector<array> copies = {in_unfolded, wt_view, wt_transpose};
-  return steel_matmul_conv_groups(
+  std::vector<array> copies = {in_unfolded, wt_transpose};
+  return steel_matmul_regular(
       s,
       d,
-      /*a = */ in_unfolded,
-      /*b = */ wt_transpose,
-      /*c = */ out,
-      /*M = */ implicit_M,
-      /*N = */ implicit_N,
-      /*K = */ implicit_K,
-      /*a_cols = */ implicit_K * groups,
-      /*b_cols = */ implicit_K,
-      /*out_cols = */ implicit_N * groups,
-      /*a_transposed = */ false,
-      /*b_transposed = */ true,
-      /* groups = */ groups,
+      /* a = */ in_unfolded,
+      /* b = */ wt_transpose,
+      /* c = */ out,
+      /* M = */ implicit_M,
+      /* N = */ implicit_N,
+      /* K = */ implicit_K,
+      /* batch_size_out = */ groups,
+      /* a_cols = */ implicit_K * groups,
+      /* b_cols = */ implicit_K,
+      /* out_cols = */ implicit_N * groups,
+      /* a_transposed = */ false,
+      /* b_transposed = */ true,
+      /* batch_shape = */ {1},
+      /* batch_strides = */ {0},
+      /* A_batch_strides = */ size_t(implicit_K),
+      /* B_batch_strides = */ size_t(implicit_N) * implicit_K,
+      /* matrix_stride_out = */ size_t(implicit_N),
       /*copies = */ copies);
 }
 
diff --git a/mlx/backend/metal/kernels/conv.metal b/mlx/backend/metal/kernels/conv.metal
index fd43aa371..4798460df 100644
--- a/mlx/backend/metal/kernels/conv.metal
+++ b/mlx/backend/metal/kernels/conv.metal
@@ -113,6 +113,7 @@ template <typename T, int N>
   for (int i = N - 1; i >= 0; --i) {
     int os_ = (oS % params->oS[i]);
     int ws_ = (wS % params->wS[i]);
+    out += ws_ * kernel_stride;
     ws_ = params->flip ? params->wS[i] - ws_ - 1 : ws_;
@@ -126,7 +127,6 @@ template <typename T, int N>
     oS /= params->oS[i];
     wS /= params->wS[i];
 
-    out += ws_ * kernel_stride;
     kernel_stride *= params->wS[i];
   }
diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp
index dbcf1b4df..3620a5676 100644
--- a/mlx/backend/metal/matmul.cpp
+++ b/mlx/backend/metal/matmul.cpp
@@ -88,7 +88,7 @@ inline auto collapse_batches(const array& a, const array& b, const array& c) {
 // Steel matmul fallback
 ///////////////////////////////////////////////////////////////////////////////
 
-void steel_matmul_conv_groups(
+void steel_matmul_regular(
     const Stream& s,
     metal::Device& d,
     const array& a,
@@ -97,23 +97,25 @@ void steel_matmul_conv_groups(
     int M,
     int N,
     int K,
+    int batch_size_out,
     int lda,
     int ldb,
     int ldd,
     bool transpose_a,
     bool transpose_b,
-    int groups,
+    std::vector<int> batch_shape,
+    std::vector<size_t> batch_strides,
+    size_t A_batch_stride,
+    size_t B_batch_stride,
+    size_t matrix_stride_out,
     std::vector<array>& copies) {
   using namespace mlx::steel;
 
-  /////////////////////////////////////////////////////////////////////////////
-  // Regular kernel dispatch
-
   // Determine dispatch kernel
   int bm = 32, bn = 32, bk = 16;
   int wm = 2, wn = 2;
 
-  if ((size_t)M * N >= 1ul << 20) {
+  if ((size_t)batch_size_out * M * N >= 1ul << 20) {
     if (!transpose_a && transpose_b) {
       bm = 64;
       bn = (out.dtype() == float32) ? 64 : 32;
@@ -133,7 +135,7 @@ void steel_matmul_conv_groups(
 
   std::string base_name = kname.str();
 
-  const bool has_batch = false;
+  const bool has_batch = (batch_shape.size() > 1);
   const bool use_out_source = false;
   const bool do_axpby = false;
   const bool align_M = (M % bm) == 0;
@@ -197,12 +199,12 @@ void steel_matmul_conv_groups(
       /* const int ldd = */ ldd,
       /* const int tiles_n = */ tn,
       /* const int tiles_m = */ tm,
-      /* const size_t batch_stride_a = */ size_t(K),
-      /* const size_t batch_stride_b = */ size_t(N) * K,
-      /* const size_t batch_stride_d = */ size_t(N),
+      /* const size_t batch_stride_a = */ A_batch_stride,
+      /* const size_t batch_stride_b = */ B_batch_stride,
+      /* const size_t batch_stride_d = */ matrix_stride_out,
       /* const int swizzle_log = */ swizzle_log,
       /* const int gemm_k_iterations_aligned = */ (K / bk),
-      /* const int batch_ndim = */ 1};
+      /* const int batch_ndim = */ int(batch_shape.size())};
 
   // Prepare launch grid params
   int tile = 1 << swizzle_log;
@@ -210,15 +212,13 @@ void steel_matmul_conv_groups(
   tn = tn * tile;
 
   MTL::Size group_dims = MTL::Size(32, wn, wm);
-  MTL::Size grid_dims = MTL::Size(tn, tm, groups);
-
-  std::vector<int> batch_shape = {1};
-  std::vector<size_t> batch_strides = {0};
+  MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out);
 
   // Launch kernel
   compute_encoder.set_input_array(a, 0);
   compute_encoder.set_input_array(b, 1);
   compute_encoder.set_output_array(out, 3);
+
   compute_encoder->setBytes(&params, sizeof(GEMMParams), 4);
 
   set_vector_bytes(compute_encoder, batch_shape, 6);
@@ -393,133 +393,31 @@ void steel_matmul(
 
   /////////////////////////////////////////////////////////////////////////////
   // Regular kernel dispatch
-
-  // Determine dispatch kernel
-  int bm = 32, bn = 32, bk = 16;
-  int wm = 2, wn = 2;
-
-  if ((size_t)batch_size_out * M * N >= 1ul << 20) {
-    if (!transpose_a && transpose_b) {
-      bm = 64;
-      bn = (out.dtype() == float32) ? 64 : 32;
-      bk = (out.dtype() == float32) ? 16 : 32;
-    } else {
-      bm = 64;
-      bn = 64;
-    }
-  }
-
-  // Prepare kernel name
-  std::ostringstream kname;
-  kname << "steel_gemm_fused_" << (transpose_a ? 't' : 'n')
-        << (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
-        << type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk
-        << "_wm" << wm << "_wn" << wn;
-
-  std::string base_name = kname.str();
-
-  const bool has_batch = (batch_shape.size() > 1);
-  const bool use_out_source = false;
-  const bool do_axpby = false;
-  const bool align_M = (M % bm) == 0;
-  const bool align_N = (N % bn) == 0;
-  const bool align_K = (K % bk) == 0;
-  const bool do_gather = false;
-
-  metal::MTLFCList func_consts = {
-      {&has_batch, MTL::DataType::DataTypeBool, 10},
-      {&use_out_source, MTL::DataType::DataTypeBool, 100},
-      {&do_axpby, MTL::DataType::DataTypeBool, 110},
-      {&align_M, MTL::DataType::DataTypeBool, 200},
-      {&align_N, MTL::DataType::DataTypeBool, 201},
-      {&align_K, MTL::DataType::DataTypeBool, 202},
-      {&do_gather, MTL::DataType::DataTypeBool, 300},
-  };
-
-  // clang-format off
-  kname << "_has_batch_" << (has_batch ? 't' : 'n')
-        << "_use_out_source_" << (use_out_source ? 't' : 'n')
-        << "_do_axpby_" << (do_axpby ? 't' : 'n')
-        << "_align_M_" << (align_M ? 't' : 'n')
-        << "_align_N_" << (align_N ? 't' : 'n')
-        << "_align_K_" << (align_K ? 't' : 'n')
-        << "_do_gather_" << (do_gather ? 't' : 'n'); // clang-format on
-
-  std::string hash_name = kname.str();
-
-  // Encode and dispatch kernel
-  auto& compute_encoder = d.get_command_encoder(s.index);
-  auto kernel = get_steel_gemm_fused_kernel(
-      d,
-      base_name,
-      hash_name,
-      func_consts,
-      out,
-      transpose_a,
-      transpose_b,
-      bm,
-      bn,
-      bk,
-      wm,
-      wn);
-
-  compute_encoder->setComputePipelineState(kernel);
-
-  // Use problem size to determine threadblock swizzle
-  int tn = (N + bn - 1) / bn;
-  int tm = (M + bm - 1) / bm;
-
-  // TODO: Explore device-based tuning for swizzle
-  int swizzle_log = 0; // tm >= 6 ? 3 : (tm <= 3 ? 0 : 2);
-
-  // Prepare steel matmul params
-  GEMMParams params{
-      /* const int M = */ M,
-      /* const int N = */ N,
-      /* const int K = */ K,
-      /* const int lda = */ lda,
-      /* const int ldb = */ ldb,
-      /* const int ldd = */ N,
-      /* const int tiles_n = */ tn,
-      /* const int tiles_m = */ tm,
-      /* const size_t batch_stride_a = */ A_batch_stride.back(),
-      /* const size_t batch_stride_b = */ B_batch_stride.back(),
-      /* const size_t batch_stride_d = */ matrix_stride_out,
-      /* const int swizzle_log = */ swizzle_log,
-      /* const int gemm_k_iterations_aligned = */ (K / bk),
-      /* const int batch_ndim = */ int(batch_shape.size())};
-
-  // Prepare launch grid params
-  int tile = 1 << swizzle_log;
-  tm = (tm + tile - 1) / tile;
-  tn = tn * tile;
-
-  MTL::Size group_dims = MTL::Size(32, wn, wm);
-  MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out);
-
   std::vector<size_t> batch_strides = A_batch_stride;
   batch_strides.insert(
       batch_strides.end(), B_batch_stride.begin(), B_batch_stride.end());
 
-  // Launch kernel
-  compute_encoder.set_input_array(a, 0);
-  compute_encoder.set_input_array(b, 1);
-  compute_encoder.set_output_array(out, 3);
-
-  compute_encoder->setBytes(&params, sizeof(GEMMParams), 4);
-
-  set_vector_bytes(compute_encoder, batch_shape, 6);
-  set_vector_bytes(compute_encoder, batch_strides, 7);
-
-  compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
-
-  // Clear copies
-  if (!copies.empty()) {
-    d.get_command_buffer(s.index)->addCompletedHandler(
-        [copies = std::move(copies)](MTL::CommandBuffer*) mutable {
-          copies.clear();
-        });
-  }
+  steel_matmul_regular(
+      s,
+      d,
+      a,
+      b,
+      out,
+      M,
+      N,
+      K,
+      batch_size_out,
+      lda,
+      ldb,
+      N,
+      transpose_a,
+      transpose_b,
+      std::move(batch_shape),
+      std::move(batch_strides),
+      A_batch_stride.back(),
+      B_batch_stride.back(),
+      matrix_stride_out,
+      copies);
 }
 
 void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
diff --git a/mlx/backend/metal/matmul.h b/mlx/backend/metal/matmul.h
index 3edf34f66..c771bb8b4 100644
--- a/mlx/backend/metal/matmul.h
+++ b/mlx/backend/metal/matmul.h
@@ -4,7 +4,7 @@
 
 namespace mlx::core {
 
-void steel_matmul_conv_groups(
+void steel_matmul_regular(
     const Stream& s,
     metal::Device& d,
     const array& a,
@@ -13,12 +13,17 @@ void steel_matmul_conv_groups(
     int M,
     int N,
    int K,
+    int batch_size_out,
    int lda,
    int ldb,
    int ldd,
    bool transpose_a,
    bool transpose_b,
-    int groups,
+    std::vector<int> batch_shape,
+    std::vector<size_t> batch_strides,
+    size_t A_batch_stride,
+    size_t B_batch_stride,
+    size_t matrix_stride_out,
     std::vector<array>& copies);
 
 void steel_matmul(
diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp
index c28a945a3..ddcc7d938 100644
--- a/mlx/primitives.cpp
+++ b/mlx/primitives.cpp
@@ -929,16 +929,28 @@ std::vector<array> Convolution::vjp(
   assert(primals.size() == 2);
   std::vector<array> grads;
 
-  if (groups_ != 1) {
-    throw std::invalid_argument(
-        "[Convolution] Backward pass not implemented for groups > 1.");
-  }
-
   // Collect info
   auto& in = primals[0];
   auto& wt = primals[1];
   auto& cotan = cotangents[0];
 
+  auto group_transpose =
+      [this](const array& x, int group_dim, int ax_a, int ax_b) {
+        if (groups_ > 1) {
+          auto shape = x.shape();
+          if (group_dim < 0) {
+            group_dim += shape.size();
+          }
+          shape.insert(shape.begin() + group_dim, groups_);
+          shape[group_dim + 1] = shape[group_dim + 1] / groups_;
+          auto x_trans = swapaxes(
+              reshape(x, std::move(shape), stream()), ax_a, ax_b, stream());
+          return flatten(x_trans, group_dim, group_dim + 1, stream());
+        } else {
+          return swapaxes(x, 0, -1, stream());
+        }
+      };
+
   for (int a : argnums) {
     // Grads for input
     if (a == 0) {
@@ -976,8 +988,7 @@ std::vector<array> Convolution::vjp(
         }
       }
 
-      auto wt_trans = swapaxes(wt, 0, -1, stream());
-
+      auto wt_trans = group_transpose(wt, 0, 1, -1);
       auto grad = conv_general(
           /* const array& input = */ cotan,
           /* const array& weight = */ wt_trans,
@@ -986,7 +997,7 @@ std::vector<array> Convolution::vjp(
           /* std::vector<int> padding_hi = */ padding_hi,
           /* std::vector<int> kernel_dilation = */ kernel_dilation_,
           /* std::vector<int> input_dilation = */ kernel_strides_,
-          /* int groups = */ 1,
+          /* int groups = */ groups_,
           /* bool flip = */ !flip_,
           stream());
 
@@ -1020,14 +1031,11 @@ std::vector<array> Convolution::vjp(
         no_dilation &= (input_dilation_[i] == 1) && (kernel_dilation_[i] == 1);
       }
 
-      if (no_dilation && !flip_) {
+      if (no_dilation && !flip_ && groups_ == 1) {
         auto grad = conv_weight_backward_patches(
             in, wt, cotan, kernel_strides_, padding_, stream());
         grads.push_back(grad);
       } else {
-        auto cotan_trans = swapaxes(cotan, 0, -1, stream());
-        auto in_trans = swapaxes(in, 0, -1, stream());
-
         if (flip_) {
           auto padding = padding_;
           for (int i = 0; i < padding.size(); i++) {
@@ -1035,6 +1043,9 @@ std::vector<array> Convolution::vjp(
             padding[i] = wt_size - padding_[i] - 1;
           }
 
+          auto cotan_trans = group_transpose(cotan, -1, 0, -1);
+          auto in_trans = swapaxes(in, 0, -1, stream());
+
           auto grad_trans = conv_general(
               /* const array& input = */ cotan_trans,
               /* const array& weight = */ in_trans,
@@ -1043,11 +1054,14 @@ std::vector<array> Convolution::vjp(
               /* std::vector<int> padding_hi = */ padding,
               /* std::vector<int> kernel_dilation = */ input_dilation_,
               /* std::vector<int> input_dilation = */ kernel_strides_,
-              /* int groups = */ 1,
+              /* int groups = */ groups_,
               /* bool flip = */ false,
               stream());
-          auto grad = swapaxes(grad_trans, 0, -1, stream());
-          grads.push_back(grad_trans);
+          if (groups_ > 1) {
+            grads.push_back(group_transpose(grad_trans, -1, 0, -2));
+          } else {
+            grads.push_back(grad_trans);
+          }
         } else {
           std::vector<int> padding_lo = padding_;
           std::vector<int> padding_hi = padding_;
@@ -1058,9 +1072,9 @@ std::vector<array> Convolution::vjp(
             int wt_size = 1 + kernel_dilation_[i] * (wt.shape(1 + i) - 1);
             padding_hi[i] = out_size - in_size + wt_size - padding_[i] - 1;
           }
-
-          auto in_trans = swapaxes(in, 0, -1, stream());
           auto cotan_trans = swapaxes(cotan, 0, -1, stream());
+          auto in_trans = group_transpose(in, -1, 0, -1);
+
           auto grad_trans = conv_general(
               /* const array& input = */ in_trans,
               /* const array& weight = */ cotan_trans,
@@ -1069,11 +1083,10 @@ std::vector<array> Convolution::vjp(
               /* std::vector<int> padding_hi = */ padding_hi,
               /* std::vector<int> kernel_dilation = */ kernel_strides_,
               /* std::vector<int> input_dilation = */ input_dilation_,
-              /* int groups = */ 1,
+              /* int groups = */ groups_,
               /* bool flip = */ false,
               stream());
-          auto grad = swapaxes(grad_trans, 0, -1, stream());
-          grads.push_back(grad);
+          grads.push_back(swapaxes(grad_trans, 0, -1, stream()));
         }
       }
     }
diff --git a/python/tests/test_conv.py b/python/tests/test_conv.py
index 46291cf6d..f6bff01cb 100644
--- a/python/tests/test_conv.py
+++ b/python/tests/test_conv.py
@@ -47,6 +47,13 @@ class TestConv(mlx_tests.MLXTestCase):
             self.assertEqual(c_mx.shape, c_np.shape)
             self.assertTrue(np.allclose(c_mx, c_np, atol=atol))
 
+    def test_conv_1d_groups_flipped(self):
+        x = mx.broadcast_to(mx.arange(5).astype(mx.float32), (2, 5)).T
+        w = mx.broadcast_to(mx.arange(4).astype(mx.float32), (2, 4))
+        out = mx.conv_general(x[None], w[..., None], flip=True, groups=2)
+        expected = mx.array([4.0, 4.0, 10.0, 10.0]).reshape(1, 2, 2)
+        self.assertTrue(mx.allclose(out, expected))
+
     @unittest.skipIf(not has_torch, "requires Torch")
     def test_torch_conv_1D(self):
         def run_conv1D(
@@ -897,6 +904,99 @@ class TestConv(mlx_tests.MLXTestCase):
         expected = mx.array([[dw00, dw01], [dw10, dw11]])
         self.assertTrue(mx.allclose(dw, expected))
 
+    def test_conv_groups_grad(self):
+        def fn(x, w):
+            num_groups = x.shape[-1] // w.shape[-1]
+            return mx.conv1d(x, w, groups=num_groups)
+
+        def fn_gt(x, w):
+            num_groups = x.shape[-1] // w.shape[-1]
+            group_size = w.shape[-1]
+            ws = w.reshape(num_groups, -1, *w.shape[1:]).split(num_groups)
+            xs = x.reshape(*x.shape[:-1], num_groups, -1).split(num_groups, axis=-2)
+            return mx.concatenate(
+                [mx.conv_general(x.squeeze(-2), w.squeeze(0)) for x, w in zip(xs, ws)],
+                axis=-1,
+            )
+
+        mx.random.seed(3)
+
+        w = mx.random.normal(shape=(2, 3, 1))
+        x = mx.random.normal(shape=(1, 5, 2))
+        cotans = (mx.ones(shape=(1, 3, 2)),)
+        grads = mx.vjp(fn, (x, w), cotans)[1]
+        expected = mx.vjp(fn_gt, (x, w), cotans)[1]
+        self.assertTrue(mx.allclose(expected[0], grads[0]))
+        self.assertTrue(mx.allclose(expected[1], grads[1]))
+
+        w = mx.random.normal(shape=(2, 3, 2))
+        x = mx.random.normal(shape=(1, 5, 4))
+        cotans = (mx.ones(shape=(1, 3, 2)),)
+        grads = mx.vjp(fn, (x, w), cotans)[1]
+        expected = mx.vjp(fn_gt, (x, w), cotans)[1]
+        self.assertTrue(mx.allclose(expected[0], grads[0]))
+        self.assertTrue(mx.allclose(expected[1], grads[1]))
+
+        w = mx.random.normal(shape=(6, 3, 2))
+        x = mx.random.normal(shape=(1, 5, 4))
+        cotans = (mx.ones(shape=(1, 3, 6)),)
+        grads = mx.vjp(fn, (x, w), cotans)[1]
+        expected = mx.vjp(fn_gt, (x, w), cotans)[1]
+        self.assertTrue(mx.allclose(expected[0], grads[0]))
+        self.assertTrue(mx.allclose(expected[1], grads[1]))
+
+        # Test 2D
+        w = mx.random.normal(shape=(2, 3, 3, 1))
+        x = mx.random.normal(shape=(1, 5, 5, 2))
+        cotans = (mx.ones(shape=(1, 3, 3, 2)),)
+        grads = mx.vjp(fn, (x, w), cotans)[1]
+        expected = mx.vjp(fn_gt, (x, w), cotans)[1]
+        self.assertTrue(mx.allclose(expected[0], grads[0]))
+        self.assertTrue(mx.allclose(expected[1], grads[1]))
+
+        # Test with flip
+        def fn(x, w):
+            num_groups = x.shape[-1] // w.shape[-1]
+            return mx.conv_general(x, w, groups=num_groups, flip=True)
+
+        def fn_gt(x, w):
+            num_groups = x.shape[-1] // w.shape[-1]
+            group_size = w.shape[-1]
+            ws = w.reshape(num_groups, -1, *w.shape[1:]).split(num_groups)
+            xs = x.reshape(*x.shape[:-1], num_groups, -1).split(num_groups, axis=-2)
+            return mx.concatenate(
+                [
+                    mx.conv_general(x.squeeze(-2), w.squeeze(0), flip=True)
+                    for x, w in zip(xs, ws)
+                ],
+                axis=-1,
+            )
+
+        w = mx.random.normal(shape=(2, 3, 1))
+        x = mx.random.normal(shape=(1, 5, 2))
+        cotans = (mx.ones(shape=(1, 3, 2)),)
+        grads = mx.vjp(fn, (x, w), cotans)[1]
+        expected = mx.vjp(fn_gt, (x, w), cotans)[1]
+        self.assertTrue(mx.allclose(expected[0], grads[0]))
+        self.assertTrue(mx.allclose(expected[1], grads[1]))
+
+        w = mx.random.normal(shape=(2, 3, 2))
+        x = mx.random.normal(shape=(1, 5, 4))
+        cotans = (mx.ones(shape=(1, 3, 2)),)
+        grads = mx.vjp(fn, (x, w), cotans)[1]
+        expected = mx.vjp(fn_gt, (x, w), cotans)[1]
+        self.assertTrue(mx.allclose(expected[0], grads[0]))
+        self.assertTrue(mx.allclose(expected[1], grads[1]))
+
+        # Test 2D
+        w = mx.random.normal(shape=(2, 3, 3, 1))
+        x = mx.random.normal(shape=(1, 5, 5, 2))
+        cotans = (mx.ones(shape=(1, 3, 3, 2)),)
+        grads = mx.vjp(fn, (x, w), cotans)[1]
+        expected = mx.vjp(fn_gt, (x, w), cotans)[1]
+        self.assertTrue(mx.allclose(expected[0], grads[0]))
+        self.assertTrue(mx.allclose(expected[1], grads[1]))
+
 
 if __name__ == "__main__":
     unittest.main()
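A minimal usage sketch of what this change covers, assuming an MLX build that includes the patch (the shapes and calls mirror the new tests above; the gradient check via mx.grad is illustrative and not part of the patch):

    import mlx.core as mx

    # Grouped, flipped 1D convolution: 2 groups, each convolving one length-5
    # input channel with one flipped length-4 filter. This is the path fixed
    # by the conv.metal change above.
    x = mx.broadcast_to(mx.arange(5).astype(mx.float32), (2, 5)).T  # (5, 2)
    w = mx.broadcast_to(mx.arange(4).astype(mx.float32), (2, 4))    # (2, 4)
    out = mx.conv_general(x[None], w[..., None], flip=True, groups=2)
    print(out)  # expected [[4, 4], [10, 10]], as in test_conv_1d_groups_flipped

    # Gradients through a grouped convolution, enabled by the Convolution::vjp
    # changes above (previously an error for groups > 1).
    def loss(x, w):
        return mx.conv_general(x[None], w[..., None], flip=True, groups=2).sum()

    dx, dw = mx.grad(loss, argnums=[0, 1])(x, w)
    print(dx.shape, dw.shape)  # (5, 2) (2, 4)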