From 8590c0941e5c034b56dea3b33efa108668de540c Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Tue, 10 Jun 2025 20:58:16 -0700
Subject: [PATCH] Add load_safe to the general conv loaders (#2258)

---
 benchmarks/python/conv_unaligned_bench.py   | 107 ++++++++++++++++++
 mlx/backend/metal/conv.cpp                  |  36 ++++--
 mlx/backend/metal/jit_kernels.cpp           |   4 +-
 mlx/backend/metal/kernels.h                 |   2 +
 .../steel/conv/kernels/steel_conv_general.h |  63 ++++++++---
 .../steel/conv/loaders/loader_general.h     |  95 ++++++++++++++++
 mlx/backend/metal/nojit_kernels.cpp         |   4 +-
 python/tests/test_conv.py                   |  13 +++
 8 files changed, 302 insertions(+), 22 deletions(-)
 create mode 100644 benchmarks/python/conv_unaligned_bench.py

diff --git a/benchmarks/python/conv_unaligned_bench.py b/benchmarks/python/conv_unaligned_bench.py
new file mode 100644
index 000000000..981d7b48b
--- /dev/null
+++ b/benchmarks/python/conv_unaligned_bench.py
@@ -0,0 +1,107 @@
+import math
+import time
+
+import mlx.core as mx
+import numpy as np
+import torch
+
+N_warmup = 10
+N_iter_bench = 100
+N_iter_func = 5
+
+
+def bench(f, a, b):
+    for i in range(N_warmup):
+        f(a, b)
+    torch.mps.synchronize()
+
+    s = time.perf_counter_ns()
+    for i in range(N_iter_bench):
+        f(a, b)
+    e = time.perf_counter_ns()
+    return (e - s) * 1e-9
+
+
+def make_mx_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
+    def mx_conv_2D(a, b):
+        ys = []
+        for i in range(N_iter_func):
+            y = mx.conv2d(a, b, stride=strides, padding=padding, groups=groups)
+            ys.append(y)
+        mx.eval(ys)
+        return ys
+
+    return mx_conv_2D
+
+
+def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
+    @torch.no_grad()
+    def pt_conv_2D(a, b):
+        ys = []
+        for i in range(N_iter_func):
+            y = torch.conv2d(a, b, stride=strides, padding=padding, groups=groups)
+            ys.append(y)
+        torch.mps.synchronize()
+        return ys
+
+    return pt_conv_2D
+
+
+def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
+    scale = 1.0 / math.sqrt(kH * kW * C)
+    a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
+    b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(
+        np_dtype
+    )
+
+    a_mx = mx.array(a_np)
+    b_mx = mx.array(b_np)
+
+    a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("mps")
+    b_pt = torch.from_numpy(b_np.transpose((0, 3, 1, 2))).to("mps")
+
+    torch.mps.synchronize()
+
+    f_mx = make_mx_conv_2D(strides, padding, groups)
+    f_pt = make_pt_conv_2D(strides, padding, groups)
+
+    time_torch = bench(f_pt, a_pt, b_pt)
+    time_mlx = bench(f_mx, a_mx, b_mx)
+
+    out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
+    out_pt = torch.conv2d(
+        a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
+    )
+    out_pt = torch.permute(out_pt, (0, 2, 3, 1))
+    out_pt = out_pt.numpy(force=True)
+
+    atol = 2e-5 if np_dtype == np.float32 else 1e-4
+
+    if not np.allclose(out_pt, out_mx, atol=atol):
+        print(
+            f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
+        )
+
+    return time_mlx, time_torch
+
+
+if __name__ == "__main__":
+    dtype = "float32"
+    shapes = (
+        (4, 32, 32, 21, 3, 3, 128),
+        (4, 32, 32, 21, 3, 3, 37),
+        (4, 32, 32, 370, 3, 3, 370),
+        (4, 32, 32, 370, 7, 7, 128),
+        (2, 320, 640, 21, 7, 7, 21),
+    )
+    for N, H, W, C, kh, kw, O in shapes:
+        time_mlx, time_torch = bench_shape(
+            N, H, W, C, kh, kw, O, (1, 1), (0, 0), 1, dtype
+        )
+        diff = time_torch / time_mlx - 1.0
+
+        print(
+            f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kh:2d}, {kw:2d}, {C:3d}), {dtype}, {100. * diff:+5.2f}%"
+        )
+        if time_mlx >= 2.0 * time_torch:
+            print("ATTENTION ^^^^^^^")
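
The benchmark above deliberately picks channel counts (21, 37, 370) that are not multiples of 16, so every shape exercises the new unaligned path rather than the aligned fast path. A rough Python sketch of the routing condition (the helper name is illustrative, not part of the patch; the real rule, including the new out_large escape hatch, is in dispatch_conv_2D_gpu below):

    # Illustrative: mirrors the C % 16 / O % 16 alignment test used for dispatch.
    def needs_unaligned_path(C, O):
        return C % 16 != 0 or O % 16 != 0

    for C, O in [(21, 128), (21, 37), (370, 370), (370, 128), (21, 21)]:
        print((C, O), needs_unaligned_path(C, O))  # True for every benchmark shape
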
diff --git a/mlx/backend/metal/conv.cpp b/mlx/backend/metal/conv.cpp
index 593b79384..697afa6a1 100644
--- a/mlx/backend/metal/conv.cpp
+++ b/mlx/backend/metal/conv.cpp
@@ -391,6 +391,7 @@ void implicit_gemm_conv_2D_general_gpu(
   // Get channel iteration info
   int channel_k_iters = ((conv_params.C + bk - 1) / bk);
   int gemm_k_iters = channel_k_iters;
+  bool align_C = conv_params.C % bk == 0;
 
   // Fix host side helper params
   int sign = (conv_params.flip ? -1 : 1);
@@ -419,14 +420,33 @@
       /* const int swizzle_log = */ swizzle_log};
 
   // Determine kernel
-  std::ostringstream kname;
-  kname << "implicit_gemm_conv_2d_general_" << type_to_name(out) << "_bm" << bm
-        << "_bn" << bn << "_bk" << bk << "_wm" << wm << "_wn" << wn;
+  std::string kname;
+  kname.reserve(64);
+  concatenate(
+      kname,
+      "implicit_gemm_conv_2d_general_",
+      type_to_name(out),
+      "_bm",
+      bm,
+      "_bn",
+      bn,
+      "_bk",
+      bk,
+      "_wm",
+      wm,
+      "_wn",
+      wn);
+  std::string hash_name;
+  hash_name.reserve(64);
+  concatenate(hash_name, kname, "_alC_", align_C);
+  metal::MTLFCList func_consts = {
+      {&align_C, MTL::DataType::DataTypeBool, 200},
+  };
 
   // Encode and dispatch kernel
   auto& compute_encoder = d.get_command_encoder(s.index);
-  auto kernel =
-      get_steel_conv_general_kernel(d, kname.str(), out, bm, bn, bk, wm, wn);
+  auto kernel = get_steel_conv_general_kernel(
+      d, kname, hash_name, func_consts, out, bm, bn, bk, wm, wn);
   compute_encoder.set_compute_pipeline_state(kernel);
 
   // Deduce grid launch dimensions
@@ -728,8 +748,10 @@ void dispatch_conv_2D_gpu(
 
   // Direct to winograd conv
   bool inp_large =
-      (conv_params.N * conv_params.iS[0] * conv_params.iS[1]) >= 1ul << 12;
+      (conv_params.N * conv_params.iS[0] * conv_params.iS[1]) >= 4096;
   bool channels_large = (conv_params.C + conv_params.O) >= 256;
+  bool out_large =
+      (conv_params.N * conv_params.oS[0] * conv_params.oS[1]) >= 256;
   if (!conv_params.flip && is_stride_one && is_kdil_one && is_idil_one &&
       conv_params.wS[0] == 3 && conv_params.wS[1] == 3 &&
       conv_params.C % 32 == 0 && conv_params.O % 32 == 0 && inp_large &&
@@ -743,7 +765,7 @@
     return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
   }
 
-  else if (conv_params.C % 16 == 0 && conv_params.O % 16 == 0) {
+  else if ((conv_params.C % 16 == 0 && conv_params.O % 16 == 0) || out_large) {
     return implicit_gemm_conv_2D_general_gpu(s, d, in, wt, out, conv_params);
   }
 
diff --git a/mlx/backend/metal/jit_kernels.cpp b/mlx/backend/metal/jit_kernels.cpp
index 15e21af6c..467380c3a 100644
--- a/mlx/backend/metal/jit_kernels.cpp
+++ b/mlx/backend/metal/jit_kernels.cpp
@@ -727,6 +727,8 @@ MTL::ComputePipelineState* get_steel_conv_kernel(
 MTL::ComputePipelineState* get_steel_conv_general_kernel(
     metal::Device& d,
     const std::string& kernel_name,
+    const std::string& hash_name,
+    const metal::MTLFCList& func_consts,
     const array& out,
     int bm,
     int bn,
@@ -749,7 +751,7 @@ MTL::ComputePipelineState* get_steel_conv_general_kernel(
         wn);
     return kernel_source.str();
   });
-  return d.get_kernel(kernel_name, lib);
+  return d.get_kernel(kernel_name, lib, hash_name, func_consts);
 }
 
 MTL::ComputePipelineState* get_fft_kernel(
diff --git a/mlx/backend/metal/kernels.h b/mlx/backend/metal/kernels.h
index 6d8864385..1de5fa47c 100644
--- a/mlx/backend/metal/kernels.h
+++ b/mlx/backend/metal/kernels.h
@@ -205,6 +205,8 @@ MTL::ComputePipelineState* get_gemv_masked_kernel(
 MTL::ComputePipelineState* get_steel_conv_general_kernel(
     metal::Device& d,
     const std::string& kernel_name,
+    const std::string& hash_name,
+    const metal::MTLFCList& func_consts,
     const array& out,
     int bm,
     int bn,
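
A note on the host-side change above: the Metal source is now specialized at pipeline-creation time by the align_C function constant (index 200), so the bare kernel name no longer identifies a unique pipeline and a separate hash name is passed down as the cache key. A minimal Python model of that cache behavior (the dictionary and all names here are illustrative assumptions, not MLX API):

    # Illustrative: one kernel source, two cached specializations.
    pipelines = {}

    def get_pipeline(kname, align_C):
        hash_name = f"{kname}_alC_{int(align_C)}"
        if hash_name not in pipelines:
            # Compile kname with function constant 200 set to align_C.
            pipelines[hash_name] = (kname, align_C)
        return pipelines[hash_name]

    get_pipeline("implicit_gemm_conv_2d_general_float32_bm32_bn32_bk16_wm2_wn2", True)
    get_pipeline("implicit_gemm_conv_2d_general_float32_bm32_bn32_bk16_wm2_wn2", False)
    print(len(pipelines))  # 2: same kernel name, two pipeline specializations
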
diff --git a/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h b/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h
index 8253638f1..9afebd307 100644
--- a/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h
+++ b/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h
@@ -2,6 +2,8 @@
 
 #include "mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h"
 
+constant bool align_C [[function_constant(200)]];
+
 template <
     typename T,
     int BM,
@@ -118,23 +120,58 @@ implicit_gemm_conv_2d_general(
   // Prepare threadgroup mma operation
   mma_t mma_op(simd_gid, simd_lid);
 
-  int gemm_k_iterations =
-      base_wh_size * base_ww_size * gemm_params->gemm_k_iterations;
+  if (align_C) {
+    int gemm_k_iterations =
+        base_wh_size * base_ww_size * gemm_params->gemm_k_iterations;
 
-  for (int k = 0; k < gemm_k_iterations; k++) {
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    // Load elements into threadgroup
-    loader_a.load_unsafe();
-    loader_b.load_unsafe();
+    for (int k = 0; k < gemm_k_iterations; k++) {
+      threadgroup_barrier(mem_flags::mem_threadgroup);
+      // Load elements into threadgroup
+      loader_a.load_unsafe();
+      loader_b.load_unsafe();
 
-    threadgroup_barrier(mem_flags::mem_threadgroup);
+      threadgroup_barrier(mem_flags::mem_threadgroup);
 
-    // Multiply and accumulate threadgroup elements
-    mma_op.mma(As, Bs);
+      // Multiply and accumulate threadgroup elements
+      mma_op.mma(As, Bs);
 
-    // Prepare for next iteration
-    loader_a.next();
-    loader_b.next();
+      // Prepare for next iteration
+      loader_a.next();
+      loader_b.next();
+    }
+  }
+
+  else {
+    for (int k = 1; k < gemm_params->gemm_k_iterations; k++) {
+      for (int j = 0; j < base_wh_size * base_ww_size; j++) {
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        // Load elements into threadgroup
+        loader_a.load_unsafe();
+        loader_b.load_unsafe();
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        // Multiply and accumulate threadgroup elements
+        mma_op.mma(As, Bs);
+
+        // Prepare for next iteration
+        loader_a.next();
+        loader_b.next();
+      }
+    }
+    const short remaining_k = params->C % BK;
+    for (int j = 0; j < base_wh_size * base_ww_size; j++) {
+      // Load elements into threadgroup
+      threadgroup_barrier(mem_flags::mem_threadgroup);
+      loader_a.load_safe(remaining_k);
+      loader_b.load_safe(remaining_k);
+      threadgroup_barrier(mem_flags::mem_threadgroup);
+      // Multiply and accumulate threadgroup elements
+      mma_op.mma(As, Bs);
+      // Prepare for next iteration
+      loader_a.next();
+      loader_b.next();
+    }
   }
 
   threadgroup_barrier(mem_flags::mem_none);
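
The kernel change splits the K loop: when align_C is false, the first gemm_k_iterations - 1 blocks still use load_unsafe, and a final masked pass loads the remaining C % BK channels with load_safe. A small sketch of that schedule (the function name is illustrative), using the host-side channel_k_iters = ceil(C / bk) from conv.cpp:

    # Illustrative: how K is split between unsafe full blocks and the safe tail.
    def k_schedule(C, BK):
        gemm_k_iterations = (C + BK - 1) // BK  # host-side channel_k_iters
        if C % BK == 0:
            return gemm_k_iterations, 0  # align_C: every block uses load_unsafe
        return gemm_k_iterations - 1, C % BK  # full blocks + load_safe(remaining_k)

    print(k_schedule(370, 16))  # (23, 2): 23 full blocks, then a 2-channel tail
    print(k_schedule(21, 16))   # (1, 5):  1 full block, then a 5-channel tail

Each full block and the tail still iterate over the base_wh_size * base_ww_size filter taps, which is why the unaligned branch nests the j loop inside the k loop.
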
diff --git a/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h b/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h
index 72335e698..9b7ddc2ee 100644
--- a/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h
+++ b/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h
@@ -137,6 +137,52 @@ struct Conv2DInputBlockLoaderGeneral {
     }
   }
 
+  METAL_FUNC void load_safe(const short remaining_k) const {
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0, is = 0; i < n_rows; ++i, is += TROWS) {
+      // Find bounds
+      int n = read_n[i];
+
+      int h_flip = params->flip ? params->wS[0] - weight_h - 1 : weight_h;
+      int w_flip = params->flip ? params->wS[1] - weight_w - 1 : weight_w;
+
+      int ih_dil = read_ih[i] + h_flip * params->kdil[0];
+      int iw_dil = read_iw[i] + w_flip * params->kdil[1];
+
+      int ih = ih_dil / params->idil[0];
+      int iw = iw_dil / params->idil[1];
+
+      size_t offset = ih * params->in_strides[1] + iw * params->in_strides[2];
+
+      // Read from input if in bounds
+      if ((n < params->N) && (ih_dil >= 0 && ih < params->iS[0]) &&
+          (iw_dil >= 0 && iw < params->iS[1])) {
+        if (bj + vec_size <= remaining_k) {
+          STEEL_PRAGMA_UNROLL
+          for (short j = 0; j < vec_size; ++j) {
+            dst[is * dst_ld + j] = (src[i])[offset + j];
+          }
+        } else {
+          for (short j = 0; j < vec_size; ++j) {
+            if (bj + j < remaining_k) {
+              dst[is * dst_ld + j] = (src[i])[offset + j];
+            } else {
+              dst[is * dst_ld + j] = T(0);
+            }
+          }
+        }
+      }
+
+      // Zero pad otherwise
+      else {
+        STEEL_PRAGMA_UNROLL
+        for (short j = 0; j < vec_size; ++j) {
+          dst[is * dst_ld + j] = T(0);
+        }
+      }
+    }
+  }
+
   /* Iteration helper */
   METAL_FUNC void next() {
     weight_w += jump_params->f_wgt_jump_w;
@@ -262,6 +308,55 @@ struct Conv2DWeightBlockLoaderGeneral {
     }
   }
 
+  METAL_FUNC void load_safe(const short remaining_k) const {
+    const device T* curr_src = src + weight_h * params->wt_strides[1] +
+        weight_w * params->wt_strides[2];
+
+    if ((start_row + BN <= params->O)) {
+      STEEL_PRAGMA_UNROLL
+      for (short i = 0; i < BN; i += TROWS) {
+        if (bj + vec_size <= remaining_k) {
+          STEEL_PRAGMA_UNROLL
+          for (short j = 0; j < vec_size; j++) {
+            dst[i * dst_ld + j] = curr_src[i * src_ld + j];
+          }
+        } else {
+          for (short j = 0; j < vec_size; j++) {
+            if (bj + j < remaining_k) {
+              dst[i * dst_ld + j] = curr_src[i * src_ld + j];
+            } else {
+              dst[i * dst_ld + j] = T(0);
+            }
+          }
+        }
+      }
+    } else {
+      for (short i = 0; i < BN; i += TROWS) {
+        if ((start_row + i) < params->O) {
+          if (bj + vec_size <= remaining_k) {
+            STEEL_PRAGMA_UNROLL
+            for (short j = 0; j < vec_size; j++) {
+              dst[i * dst_ld + j] = curr_src[i * src_ld + j];
+            }
+          } else {
+            for (short j = 0; j < vec_size; j++) {
+              if (bj + j < remaining_k) {
+                dst[i * dst_ld + j] = curr_src[i * src_ld + j];
+              } else {
+                dst[i * dst_ld + j] = T(0);
+              }
+            }
+          }
+        } else {
+          STEEL_PRAGMA_UNROLL
+          for (short j = 0; j < vec_size; j++) {
+            dst[i * dst_ld + j] = T(0);
+          }
+        }
+      }
+    }
+  }
+
   /* Iteration helper */
   METAL_FUNC void next() {
     weight_w += jump_params->f_wgt_jump_w;
diff --git a/mlx/backend/metal/nojit_kernels.cpp b/mlx/backend/metal/nojit_kernels.cpp
index b1478d33b..b0375e37f 100644
--- a/mlx/backend/metal/nojit_kernels.cpp
+++ b/mlx/backend/metal/nojit_kernels.cpp
@@ -244,13 +244,15 @@ MTL::ComputePipelineState* get_steel_conv_kernel(
 MTL::ComputePipelineState* get_steel_conv_general_kernel(
     metal::Device& d,
     const std::string& kernel_name,
+    const std::string& hash_name,
+    const metal::MTLFCList& func_consts,
     const array&,
     int,
     int,
     int,
     int,
     int) {
-  return d.get_kernel(kernel_name);
+  return d.get_kernel(kernel_name, hash_name, func_consts);
 }
 
 MTL::ComputePipelineState* get_fft_kernel(
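
Both load_safe implementations above share one masking pattern: a thread whose vec_size-wide column span fits entirely below remaining_k keeps the unrolled fast copy, while a straddling thread copies element by element and zero-fills past the boundary so the subsequent mma reads well-defined zeros. A minimal Python model of the per-thread tail copy (names are illustrative):

    # Illustrative: copy vec_size values starting at column bj, zero-filling
    # everything at or beyond remaining_k, as load_safe does.
    def load_tail(row, bj, vec_size, remaining_k):
        if bj + vec_size <= remaining_k:  # whole vector in bounds: fast path
            return row[bj : bj + vec_size]
        return [row[bj + j] if bj + j < remaining_k else 0 for j in range(vec_size)]

    print(load_tail(list(range(16)), 4, 4, 6))  # [4, 5, 0, 0]
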
diff --git a/python/tests/test_conv.py b/python/tests/test_conv.py
index 9fe11286d..c68315a5d 100644
--- a/python/tests/test_conv.py
+++ b/python/tests/test_conv.py
@@ -1173,6 +1173,19 @@ class TestConv(mlx_tests.MLXTestCase):
 
         self.assertTrue(mx.allclose(out, out_2d.squeeze(2)))
 
+    def test_conv2d_unaligned_channels(self):
+        x = mx.random.uniform(shape=(2, 16, 16, 21))
+        w = mx.random.uniform(shape=(32, 3, 3, 21))
+        y = mx.conv2d(x, w, stream=mx.cpu)
+        y_hat = mx.conv2d(x, w)
+        self.assertTrue(mx.allclose(y, y_hat))
+
+        x = mx.random.uniform(shape=(2, 16, 16, 21))
+        w = mx.random.uniform(shape=(21, 3, 3, 21))
+        y = mx.conv2d(x, w, stream=mx.cpu)
+        y_hat = mx.conv2d(x, w)
+        self.assertTrue(mx.allclose(y, y_hat))
+
 
 if __name__ == "__main__":
     unittest.main()
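
The new test can also be reproduced by hand; a quick check along these lines (the shapes here are chosen arbitrarily) compares the GPU path, which with this patch routes unaligned channel counts through the general implicit-GEMM kernel, against the CPU reference:

    import mlx.core as mx

    # 21 input channels and 37 output channels are not multiples of 16.
    x = mx.random.uniform(shape=(2, 16, 16, 21))
    w = mx.random.uniform(shape=(37, 3, 3, 21))
    y_cpu = mx.conv2d(x, w, stream=mx.cpu)
    y_gpu = mx.conv2d(x, w)  # default stream (GPU on Apple silicon)
    print(mx.allclose(y_cpu, y_gpu).item())
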