From 4abb218d21e36f0d3dd9fc47e51387637e96c1ee Mon Sep 17 00:00:00 2001 From: Cheng Date: Sat, 16 Aug 2025 07:57:30 +0900 Subject: [PATCH] The naive_conv_2d is no longer used (#2496) --- mlx/backend/metal/kernels/conv.metal | 109 --------------------------- 1 file changed, 109 deletions(-) diff --git a/mlx/backend/metal/kernels/conv.metal b/mlx/backend/metal/kernels/conv.metal index 620352144..e169ade71 100644 --- a/mlx/backend/metal/kernels/conv.metal +++ b/mlx/backend/metal/kernels/conv.metal @@ -166,115 +166,6 @@ instantiate_naive_unfold_nd_dims(float32, float); instantiate_naive_unfold_nd_dims(float16, half); instantiate_naive_unfold_nd_dims(bfloat16, bfloat16_t); -/////////////////////////////////////////////////////////////////////////////// -/// Slow and naive conv2d kernels -/////////////////////////////////////////////////////////////////////////////// - -template < - typename T, - const int BM, /* Threadgroup rows (in threads) */ - const int BN, /* Threadgroup cols (in threads) */ - const int TM, /* Thread rows (in elements) */ - const int TN, /* Thread cols (in elements) */ - const int BC = 16> -[[kernel]] void naive_conv_2d( - const device T* in [[buffer(0)]], - const device T* wt [[buffer(1)]], - device T* out [[buffer(2)]], - const constant MLXConvParams<2>& params [[buffer(3)]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - (void)simd_gid; - (void)simd_lid; - - out += tid.z * params.out_strides[0]; - in += tid.z * params.in_strides[0]; - - int out_o = tid.y * BN * TN + lid.y * TN; - int out_hw = tid.x * BM * TM + lid.x * TM; - - int out_h[TM]; - int out_w[TN]; - - for (int m = 0; m < TM; ++m) { - int mm = (out_hw + m); - out_h[m] = mm / params.oS[1]; - out_w[m] = mm % params.oS[1]; - } - - T in_local[TM]; - T wt_local[TN]; - T out_local[TM * TN] = {T(0)}; - - for (int h = 0; h < params.wS[0]; ++h) { - for (int w = 0; w < params.wS[1]; ++w) { - for (int c = 0; c < params.C; ++c) { - // Local in - for (int m = 0; m < TM; m++) { - int i = out_h[m] * params.str[0] - params.pad[0] + h * params.kdil[0]; - int j = out_w[m] * params.str[1] - params.pad[1] + w * params.kdil[1]; - - bool valid = i >= 0 && i < params.iS[0] && j >= 0 && j < params.iS[1]; - in_local[m] = valid - ? in[i * params.in_strides[1] + j * params.in_strides[2] + c] - : T(0); - } - - // Load weight - for (int n = 0; n < TN; ++n) { - int o = out_o + n; - wt_local[n] = o < params.O - ? wt[o * params.wt_strides[0] + h * params.wt_strides[1] + - w * params.wt_strides[2] + c] - : T(0); - } - - // Accumulate - for (int m = 0; m < TM; ++m) { - for (int n = 0; n < TN; ++n) { - out_local[m * TN + n] += in_local[m] * wt_local[n]; - } - } - } - } - } - - for (int m = 0; m < TM; ++m) { - for (int n = 0; n < TN; ++n) { - if (out_h[m] < params.oS[0] && out_w[m] < params.oS[1] && - (out_o + n) < params.O) - out[out_h[m] * params.out_strides[1] + - out_w[m] * params.out_strides[2] + out_o + n] = - out_local[m * TN + n]; - } - } -} - -// Instantiations - -#define instantiate_naive_conv_2d(name, itype, bm, bn, tm, tn) \ - template [[host_name("naive_conv_2d_" #name "_bm" #bm "_bn" #bn "_tm" #tm \ - "_tn" #tn)]] [[kernel]] void \ - naive_conv_2d( \ - const device itype* in [[buffer(0)]], \ - const device itype* wt [[buffer(1)]], \ - device itype* out [[buffer(2)]], \ - const constant MLXConvParams<2>& params [[buffer(3)]], \ - uint3 tid [[threadgroup_position_in_grid]], \ - uint3 lid [[thread_position_in_threadgroup]], \ - uint simd_gid [[simdgroup_index_in_threadgroup]], \ - uint simd_lid [[thread_index_in_simdgroup]]); - -#define instantiate_naive_conv_2d_blocks(name, itype) \ - instantiate_naive_conv_2d(name, itype, 16, 8, 4, 4) \ - instantiate_naive_conv_2d(name, itype, 16, 8, 2, 4) - -instantiate_naive_conv_2d_blocks(float32, float); -instantiate_naive_conv_2d_blocks(float16, half); -instantiate_naive_conv_2d_blocks(bfloat16, bfloat16_t); - /////////////////////////////////////////////////////////////////////////////// /// Depthwise convolution kernels ///////////////////////////////////////////////////////////////////////////////