diff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt index 5158d5812..4d36c6538 100644 --- a/mlx/backend/metal/kernels/CMakeLists.txt +++ b/mlx/backend/metal/kernels/CMakeLists.txt @@ -37,6 +37,7 @@ endfunction(build_kernel) build_kernel(arg_reduce) build_kernel(conv steel/conv/params.h) build_kernel(gemv steel/utils.h) +build_kernel(gemv_masked steel/utils.h) build_kernel(layer_norm) build_kernel(random) build_kernel(rms_norm) diff --git a/mlx/backend/metal/kernels/gemv.metal b/mlx/backend/metal/kernels/gemv.metal index 0398c2cd1..de63dbff6 100644 --- a/mlx/backend/metal/kernels/gemv.metal +++ b/mlx/backend/metal/kernels/gemv.metal @@ -17,29 +17,250 @@ using namespace metal; #define MLX_MTL_CONST static constant constexpr const -MLX_MTL_CONST int SIMD_SIZE = 32; - template < typename T, - const int BM, /* Threadgroup rows (in threads) */ - const int BN, /* Threadgroup cols (in threads) */ + const int BM, /* Threadgroup rows (in simdgroups) */ + const int BN, /* Threadgroup cols (in simdgroups) */ + const int SM, /* Simdgroup rows (in threads) */ + const int SN, /* Simdgroup cols (in threads) */ const int TM, /* Thread rows (in elements) */ const int TN, /* Thread cols (in elements) */ const bool kDoAxpby> /* Do out = alpha * out + beta * bias */ struct GEMVKernel { - static_assert(BN == SIMD_SIZE, "gemv block must have a width of SIMD_SIZE"); + MLX_MTL_CONST int threadsM = BM * SM; + MLX_MTL_CONST int threadsN = BN * SN; - // - The matrix of size (M = out_vec_size, N = in_vec_size) is divided up - // into blocks of (BM * TM, BN * TN) divided among threadgroups + MLX_MTL_CONST int blockM = threadsM * TM; + MLX_MTL_CONST int blockN = threadsN * TN; + + static_assert(SM * SN == 32, "simdgroup can only have 32 threads"); + + static_assert( + SN == 8 || SN == 16 || SN == 32, + "gemv block must have a width of 8, 16, or 32"); + + // - The matrix of size (M = out_vec_size, K = in_vec_size) is divided up + // into blocks of (blockM, blockN) divided among threadgroups // - Every thread works on a block of (TM, TN) - // - We assume each thead group is launched with (BN, BM, 1) threads + // - We assume each threadgroup has (threadsN, threadsM, 1) threads // - // 1. A thread loads TN elements each from mat along TM contiguous rows + // 1. A thread loads TN elements each from mat along TM rows // and the corresponding scalar from the vector // 2. The thread then multiplies and adds to accumulate its local result for // the block // 3. At the end, each thread has accumulated results over all blocks across // the rows. These are then summed up across the threadgroup + // 4. Each threadgroup writes its accumulated blockM outputs + // + // Edge case handling: + // - The threadgroup with the largest tid has blocks that exceed the matrix + // * The blocks that start outside the matrix are never read (thread results + // remain zero) + // * The last thread that partially overlaps with the matrix is shifted + // inwards such that the thread block fits exactly in the matrix + + MLX_MTL_CONST short tgp_mem_size = BN > 1 ? 
BN*(blockM + TM) : 0; + MLX_MTL_CONST bool needs_tgp_reduction = BN > 1; + + static METAL_FUNC void + load_unsafe(const device T* src, thread T dst[TN], const int src_offset = 0) { + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + dst[tn] = src[src_offset + tn]; + } + } + + static METAL_FUNC void load_safe( + const device T* src, + thread T dst[TN], + const int src_offset = 0, + const int src_size = TN) { + if (src_offset + TN <= src_size) { + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + dst[tn] = src[src_offset + tn]; + } + } else { // Edgecase + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + dst[tn] = src_offset + tn < src_size ? src[src_offset + tn] : 0; + } + } + } + + static METAL_FUNC void run( + const device T* mat [[buffer(0)]], + const device T* in_vec [[buffer(1)]], + const device T* bias [[buffer(2)]], + device T* out_vec [[buffer(3)]], + const constant int& in_vec_size [[buffer(4)]], + const constant int& out_vec_size [[buffer(5)]], + const constant int& matrix_ld [[buffer(6)]], + const constant float& alpha [[buffer(7)]], + const constant float& beta [[buffer(8)]], + const constant int& bias_stride [[buffer(14)]], + threadgroup T* tgp_memory [[threadgroup(0)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + // Appease compiler + (void)lid; + + // Thread local accumulation results + thread T result[TM] = {0}; + thread T inter[TN]; + thread T v_coeff[TN]; + + const int thrM = SN != 32 ? simd_lid / SN : 0; + const int thrN = SN != 32 ? simd_lid % SN : int(simd_lid); + + const int sgN = BN != 1 ? (simd_gid % BN) : 0; + + const int simdM = BN != 1 ? SM * (simd_gid / BN) : int(SM * simd_gid); + const int simdN = BN != 1 ? SN * (simd_gid % BN) : 0; + + int bm = (simdM + thrM) * TM; + int bn = (simdN + thrN) * TN; + + // Block position + int out_row = tid.x * blockM + bm; + + // Exit simdgroup if rows out of bound + if (out_row >= out_vec_size) + return; + + // Adjust tail simdgroup to ensure in bound reads + out_row = out_row + TM <= out_vec_size ? 
out_row : out_vec_size - TM; + + // Advance matrix + mat += out_row * matrix_ld; + + constexpr const uniform loop_stride = make_uniform(blockN); + const uniform in_size = make_uniform(in_vec_size); + const uniform n_iter = in_size / loop_stride; + const uniform last_iter = loop_stride * n_iter; + const uniform leftover = in_size - last_iter; + + // Loop over in_vec in blocks of blockN + for (int i = 0; i < n_iter; ++i) { + load_unsafe(in_vec, v_coeff, bn); + + // Per thread work loop + int mat_offset = 0; + MLX_MTL_PRAGMA_UNROLL + for (int tm = 0; tm < TM; tm++) { + // Load for the row + load_unsafe(mat, inter, mat_offset + bn); + + // Accumulate results + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + result[tm] += inter[tn] * v_coeff[tn]; + } + + mat_offset += matrix_ld; + } + + bn += blockN; + } + + if (leftover > 0) { + load_safe(in_vec, v_coeff, bn, in_size); + + // Per thread work loop + MLX_MTL_PRAGMA_UNROLL + for (int tm = 0; tm < TM; tm++) { + // Load for the row + load_safe(&mat[tm * matrix_ld], inter, bn, in_size); + + // Accumulate results + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + result[tm] += inter[tn] * v_coeff[tn]; + } + } + } + + // Simdgroup accumulations + MLX_MTL_PRAGMA_UNROLL + for (int tm = 0; tm < TM; tm++) { + MLX_MTL_PRAGMA_UNROLL + for (ushort sn = (SN / 2); sn >= 1; sn >>= 1) { + result[tm] += simd_shuffle_down(result[tm], sn); + } + } + + // Threadgroup accumulation results + if (needs_tgp_reduction) { + threadgroup T* tgp_results = tgp_memory + sgN * (blockM + TM) + bm; + if (thrN == 0) { + MLX_MTL_PRAGMA_UNROLL + for (int tm = 0; tm < TM; tm++) { + tgp_results[tm] = result[tm]; + } + + threadgroup_barrier(mem_flags::mem_none); + + if (sgN == 0) { + MLX_MTL_PRAGMA_UNROLL + for (int sgn = 1; sgn < BN; sgn++) { + MLX_MTL_PRAGMA_UNROLL + for (int tm = 0; tm < TM; tm++) { + result[tm] += tgp_results[sgn * (blockM + TM) + tm]; + } + } + } + } + } + + // Write outputs + if (simdN == 0 && thrN == 0) { + MLX_MTL_PRAGMA_UNROLL + for (int tm = 0; tm < TM; tm++) { + if (kDoAxpby) { + out_vec[out_row + tm] = static_cast(alpha) * result[tm] + + static_cast(beta) * bias[(out_row + tm) * bias_stride]; + } else { + out_vec[out_row + tm] = result[tm]; + } + } + } + } +}; + +/////////////////////////////////////////////////////////////////////////////// +/// Vector matrix multiplication +/////////////////////////////////////////////////////////////////////////////// + +template < + typename T, + const int BM, /* Threadgroup rows (in simdgroups) */ + const int BN, /* Threadgroup cols (in simdgroups) */ + const int SM, /* Simdgroup rows (in threads) */ + const int SN, /* Simdgroup cols (in threads) */ + const int TM, /* Thread rows (in elements) */ + const int TN, /* Thread cols (in elements) */ + const bool kDoAxpby> /* Do out = alpha * out + beta * bias */ +struct GEMVTKernel { + MLX_MTL_CONST int threadsM = BM * SM; + MLX_MTL_CONST int threadsN = BN * SN; + + MLX_MTL_CONST int blockM = threadsM * TM; + MLX_MTL_CONST int blockN = threadsN * TN; + + static_assert(SM * SN == 32, "simdgroup can only have 32 threads"); + + // - The matrix of size (M = in_vec_size, N = out_vec_size) is divided up + // into blocks of (blockM, blockN) divided among threadgroups + // - Every thread works on a block of (TM, TN) + // - We assume each threadgroup has (threadsN, threadsM, 1) threads + // + // 1. A thread loads TN elements each from mat along TM contiguous rows + // and the corresponding scalar from the vector + // 2. 
The thread then accumulates its local result for the block + // 3. At the end, each thread has accumulated results over all blocks across + // the rows. These are then summed up across the threadgroup // 4. Each threadgroup writes its accumulated BN * TN outputs // // Edge case handling: @@ -49,7 +270,8 @@ struct GEMVKernel { // * The last thread that partially overlaps with the matrix is shifted // inwards such that the thread block fits exactly in the matrix - MLX_MTL_CONST short tgp_mem_size = BN * TN * 2; + MLX_MTL_CONST short tgp_mem_size = BM > 1 ? BM*(blockN + TN) : 0; + MLX_MTL_CONST bool needs_tgp_reduction = BM > 1; static METAL_FUNC void run( const device T* mat [[buffer(0)]], @@ -70,230 +292,113 @@ struct GEMVKernel { // Appease compiler (void)lid; - // Threadgroup in_vec cache - threadgroup T* in_vec_block = tgp_memory + simd_lid * TN * 2; - - // Thread local accumulation results - thread T result[TM] = {0}; - thread T inter[TN]; - thread T v_coeff[TN]; - - // Block position - int out_row = (tid.x * BM + simd_gid) * TM; - - // Exit simdgroup if rows out of bound - if (out_row >= out_vec_size) - return; - - // Adjust tail simdgroup to ensure in bound reads - out_row = out_row + TM <= out_vec_size ? out_row : out_vec_size - TM; - - // Advance matrix - mat += out_row * marix_ld; - - // Loop over in_vec in blocks of BN * TN - for (int bn = simd_lid * TN; bn < in_vec_size; bn += BN * TN) { - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Prefetch in_vector for threadgroup use - if (simd_gid == 0) { - // Main load loop - if (bn + TN <= in_vec_size) { - MLX_MTL_PRAGMA_UNROLL - for (int tn = 0; tn < TN; tn++) { - in_vec_block[tn] = in_vec[bn + tn]; - } - - } else { // Edgecase - - MLX_MTL_PRAGMA_UNROLL - for (int tn = 0; tn < TN; tn++) { - in_vec_block[tn] = bn + tn < in_vec_size ? in_vec[bn + tn] : 0; - } - } - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Load for all rows - MLX_MTL_PRAGMA_UNROLL - for (int tn = 0; tn < TN; tn++) { - v_coeff[tn] = in_vec_block[tn]; - } - - // Per thread work loop - MLX_MTL_PRAGMA_UNROLL - for (int tm = 0; tm < TM; tm++) { - // Load for the row - if (bn + TN <= in_vec_size) { - MLX_MTL_PRAGMA_UNROLL - for (int tn = 0; tn < TN; tn++) { - inter[tn] = mat[tm * marix_ld + bn + tn]; - } - - } else { // Edgecase - MLX_MTL_PRAGMA_UNROLL - for (int tn = 0; tn < TN; tn++) { - int col_idx = - (bn + tn) < in_vec_size ? 
(bn + tn) : (in_vec_size - 1); - inter[tn] = mat[tm * marix_ld + col_idx]; - } - } - - // Accumulate results - MLX_MTL_PRAGMA_UNROLL - for (int tn = 0; tn < TN; tn++) { - result[tm] += inter[tn] * v_coeff[tn]; - } - } - } - - // Simdgroup accumulations - MLX_MTL_PRAGMA_UNROLL - for (int tm = 0; tm < TM; tm++) { - result[tm] = simd_sum(result[tm]); - } - - // Write outputs - if (simd_lid == 0) { - MLX_MTL_PRAGMA_UNROLL - for (int tm = 0; tm < TM; tm++) { - if (kDoAxpby) { - out_vec[out_row + tm] = static_cast(alpha) * result[tm] + - static_cast(beta) * bias[(out_row + tm) * bias_stride]; - } else { - out_vec[out_row + tm] = result[tm]; - } - } - } - } -}; - -/////////////////////////////////////////////////////////////////////////////// -/// Vector matrix multiplication -/////////////////////////////////////////////////////////////////////////////// - -template < - typename T, - const int BM, /* Threadgroup rows (in threads) */ - const int BN, /* Threadgroup cols (in threads) */ - const int TM, /* Thread rows (in elements) */ - const int TN, /* Thread cols (in elements) */ - const bool kDoAxpby> /* Do out = alpha * out + beta * bias */ -struct GEMVTKernel { - // - The matrix of size (M = in_vec_size, N = out_vec_size) is divided up - // into blocks of (BM * TM, BN * TN) divided among threadgroups - // - Every thread works on a block of (TM, TN) - // - We assume each thead group is launched with (BN, BM, 1) threads - // - // 1. A thread loads TN elements each from mat along TM contiguous rows - // and the corresponding scalar from the vector - // 2. The thread then accumulates its local result for the block - // 3. At the end, each thread has accumulated results over all blocks across - // the rows. These are then summed up across the threadgroup - // 4. Each threadgroup writes its accumulated BN * TN outputs - // - // Edge case handling: - // - The threadgroup with the largest tid has blocks that exceed the matrix - // * The blocks that start outside the matrix are never read (thread results - // remain zero) - // * The last thread that partially overlaps with the matrix is shifted - // inwards such that the thread block fits exactly in the matrix - - MLX_MTL_CONST short tgp_mem_size = BN * BM * TN; - - static METAL_FUNC void run( - const device T* mat [[buffer(0)]], - const device T* in_vec [[buffer(1)]], - const device T* bias [[buffer(2)]], - device T* out_vec [[buffer(3)]], - const constant int& in_vec_size [[buffer(4)]], - const constant int& out_vec_size [[buffer(5)]], - const constant int& marix_ld [[buffer(6)]], - const constant float& alpha [[buffer(7)]], - const constant float& beta [[buffer(8)]], - const constant int& bias_stride [[buffer(14)]], - threadgroup T* tgp_memory [[threadgroup(0)]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - // Appease compiler - (void)simd_gid; - (void)simd_lid; - // Thread local accumulation results T result[TN] = {0}; T inter[TN]; T v_coeff[TM]; - // Threadgroup accumulation results - threadgroup T* tgp_results = tgp_memory + lid.x * BM * TN; + const int thrM = SN != 32 ? simd_lid / SN : 0; + const int thrN = SN != 32 ? simd_lid % SN : int(simd_lid); - int out_col = (tid.x * BN + lid.x) * TN; - int in_row = lid.y * TM; + const int sgM = BN != 1 ? (simd_gid / BN) : int(simd_gid); + const int sgN = BN != 1 ? 
(simd_gid % BN) : 0; + + const int simdM = SM * sgM; + const int simdN = SN * sgN; + + int cm = (simdM + thrM); + int cn = (simdN + thrN); + + int bm = cm * TM; + int bn = cn * TN; + + int out_col = tid.x * blockN + bn; + + constexpr const uniform loop_stride = make_uniform(blockM); + const uniform in_size = make_uniform(in_vec_size); + const uniform n_iter = in_size / loop_stride; + const uniform last_iter = loop_stride * n_iter; + const uniform leftover = in_size - last_iter; // Edgecase handling if (out_col < out_vec_size) { out_col = out_col + TN < out_vec_size ? out_col : out_vec_size - TN; // Per thread accumulation main loop - int bm = in_row; - for (; bm < in_vec_size; bm += BM * TM) { + for (int i = 0; i < n_iter; ++i) { // Adding a threadgroup_barrier improves performance slightly // This is possibly it may help exploit cache better threadgroup_barrier(mem_flags::mem_none); - if (bm + TM <= in_vec_size) { + MLX_MTL_PRAGMA_UNROLL + for (int tm = 0; tm < TM; tm++) { + v_coeff[tm] = in_vec[bm + tm]; + } + + MLX_MTL_PRAGMA_UNROLL + for (int tm = 0; tm < TM; tm++) { + for (int tn = 0; tn < TN; tn++) { + inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn]; + } + for (int tn = 0; tn < TN; tn++) { + result[tn] += v_coeff[tm] * inter[tn]; + } + } + + bm += blockM; + } + + if (leftover > 0) { + for (int tm = 0; tm < TM && bm + tm < in_vec_size; tm++) { + v_coeff[tm] = in_vec[bm + tm]; + MLX_MTL_PRAGMA_UNROLL - for (int tm = 0; tm < TM; tm++) { - v_coeff[tm] = in_vec[bm + tm]; + for (int tn = 0; tn < TN; tn++) { + inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn]; } MLX_MTL_PRAGMA_UNROLL - for (int tm = 0; tm < TM; tm++) { - for (int tn = 0; tn < TN; tn++) { - inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn]; - } - for (int tn = 0; tn < TN; tn++) { - result[tn] += v_coeff[tm] * inter[tn]; - } - } - - } else { // Edgecase handling - for (int tm = 0; bm + tm < in_vec_size; tm++) { - v_coeff[tm] = in_vec[bm + tm]; - - for (int tn = 0; tn < TN; tn++) { - inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn]; - } - for (int tn = 0; tn < TN; tn++) { - result[tn] += v_coeff[tm] * inter[tn]; - } + for (int tn = 0; tn < TN; tn++) { + result[tn] += v_coeff[tm] * inter[tn]; } } } } - // Threadgroup collection - + // Simdgroup accumulations MLX_MTL_PRAGMA_UNROLL - for (int i = 0; i < TN; i++) { - tgp_results[lid.y * TN + i] = result[i]; + for (int tn = 0; tn < TN; tn++) { + MLX_MTL_PRAGMA_UNROLL + for (ushort sm = (SM / 2); sm >= 1; sm >>= 1) { + result[tn] += simd_shuffle_down(result[tn], SN * sm); + } } - threadgroup_barrier(mem_flags::mem_threadgroup); + // Threadgroup accumulation results + if (needs_tgp_reduction) { + threadgroup T* tgp_results = tgp_memory + sgM * (blockN + TN) + bn; + if (thrM == 0) { + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + tgp_results[tn] = result[tn]; + } + + threadgroup_barrier(mem_flags::mem_none); + + if (sgM == 0) { + MLX_MTL_PRAGMA_UNROLL + for (int sgm = 1; sgm < BM; sgm++) { + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + result[tn] += tgp_results[sgm * (blockN + TN) + tn]; + } + } + } + } + } // Threadgroup accumulation and writing out results - if (lid.y == 0 && out_col < out_vec_size) { - MLX_MTL_PRAGMA_UNROLL - for (int i = 1; i < BM; i++) { - MLX_MTL_PRAGMA_UNROLL - for (int j = 0; j < TN; j++) { - result[j] += tgp_results[i * TN + j]; - } - } - + if (cm == 0 && out_col < out_vec_size) { MLX_MTL_PRAGMA_UNROLL for (int j = 0; j < TN; j++) { if (kDoAxpby) { @@ -313,13 +418,15 @@ struct GEMVTKernel { template < typename T, - 
const int BM, /* Threadgroup rows (in threads) */ - const int BN, /* Threadgroup cols (in threads) */ + const int BM, /* Threadgroup rows (in simdgroups) */ + const int BN, /* Threadgroup cols (in simdgroups) */ + const int SM, /* Simdgroup rows (in threads) */ + const int SN, /* Simdgroup cols (in threads) */ const int TM, /* Thread rows (in elements) */ const int TN, /* Thread cols (in elements) */ const bool kDoNCBatch, /* Batch ndim > 1 */ const bool kDoAxpby> /* Do out = alpha * out + beta * bias */ -[[kernel, max_total_threads_per_threadgroup(BM* BN)]] void gemv( +[[kernel, max_total_threads_per_threadgroup(BM* BN * 32)]] void gemv( const device T* mat [[buffer(0)]], const device T* in_vec [[buffer(1)]], const device T* bias [[buffer(2)]], @@ -339,8 +446,9 @@ template < uint3 lid [[thread_position_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { - using gemv_kernel = GEMVKernel; - threadgroup T tgp_memory[gemv_kernel::tgp_mem_size]; + using gemv_kernel = GEMVKernel; + threadgroup T tgp_memory + [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size]; // Update batch offsets if (kDoNCBatch) { @@ -373,17 +481,19 @@ template < alpha, beta, bias_stride, - tgp_memory, + gemv_kernel::tgp_mem_size == 0 ? nullptr : tgp_memory, tid, lid, simd_gid, simd_lid); } -#define instantiate_gemv_helper(name, itype, bm, bn, tm, tn, nc, axpby) \ - template [[host_name("gemv_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn \ - "_nc" #nc "_axpby" #axpby)]] [[kernel]] void \ - gemv( \ +#define instantiate_gemv_helper( \ + name, itype, bm, bn, sm, sn, tm, tn, nc, axpby) \ + template [[host_name("gemv_" #name "_bm" #bm "_bn" #bn "_sm" #sm "_sn" #sn \ + "_tm" #tm "_tn" #tn "_nc" #nc \ + "_axpby" #axpby)]] [[kernel]] void \ + gemv( \ const device itype* mat [[buffer(0)]], \ const device itype* in_vec [[buffer(1)]], \ const device itype* bias [[buffer(2)]], \ @@ -405,11 +515,11 @@ template < uint simd_lid [[thread_index_in_simdgroup]]); // clang-format off -#define instantiate_gemv(name, itype, bm, bn, tm, tn) \ - instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 0, 0) \ - instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 0, 1) \ - instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 1, 0) \ - instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 1, 1) // clang-format on +#define instantiate_gemv(name, itype, bm, bn, tm, tn) \ + instantiate_gemv_helper(name, itype, bm, 1, 1, bn, tm, tn, 0, 0) \ + instantiate_gemv_helper(name, itype, bm, 1, 1, bn, tm, tn, 0, 1) \ + instantiate_gemv_helper(name, itype, bm, 1, 1, bn, tm, tn, 1, 0) \ + instantiate_gemv_helper(name, itype, bm, 1, 1, bn, tm, tn, 1, 1) // clang-format on // clang-format off #define instantiate_gemv_blocks(name, itype) \ @@ -423,11 +533,13 @@ instantiate_gemv_blocks(bfloat16, bfloat16_t); template < typename T, - const int BM, /* Threadgroup rows (in threads) */ - const int BN, /* Threadgroup cols (in threads) */ + const int BM, /* Threadgroup rows (in simdgroups) */ + const int BN, /* Threadgroup cols (in simdgroups) */ + const int SM, /* Simdgroup rows (in threads) */ + const int SN, /* Simdgroup cols (in threads) */ const int TM, /* Thread rows (in elements) */ const int TN> /* Thread cols (in elements) */ -[[kernel, max_total_threads_per_threadgroup(BM* BN)]] void gemv_bs( +[[kernel, max_total_threads_per_threadgroup(BM* BN * 32)]] void gemv_gather( const device T* mat [[buffer(0)]], const device T* in_vec [[buffer(1)]], const device T* bias [[buffer(2)]], @@ -452,8 
+564,9 @@ template < uint3 lid [[thread_position_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { - using gemv_kernel = GEMVKernel; - threadgroup T tgp_memory[gemv_kernel::tgp_mem_size]; + using gemv_kernel = GEMVKernel; + threadgroup T tgp_memory + [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size]; uint32_t indx_vec; uint32_t indx_mat; @@ -501,47 +614,47 @@ template < alpha, beta, batch_ndim, // Not used - tgp_memory, + gemv_kernel::tgp_mem_size == 0 ? nullptr : tgp_memory, tid, lid, simd_gid, simd_lid); } -#define instantiate_gemv_bs_helper(nm, itype, bm, bn, tm, tn) \ - template [[host_name("gemv_bs_" #nm "_bm" #bm "_bn" #bn "_tm" #tm \ - "_tn" #tn)]] [[kernel]] void \ - gemv_bs( \ - const device itype* mat [[buffer(0)]], \ - const device itype* in_vec [[buffer(1)]], \ - const device itype* bias [[buffer(2)]], \ - device itype* out_vec [[buffer(3)]], \ - const constant int& in_vec_size [[buffer(4)]], \ - const constant int& out_vec_size [[buffer(5)]], \ - const constant int& marix_ld [[buffer(6)]], \ - const constant float& alpha [[buffer(7)]], \ - const constant float& beta [[buffer(8)]], \ - const constant int& batch_ndim [[buffer(9)]], \ - const constant int* batch_shape [[buffer(10)]], \ - const constant size_t* index_batch_strides [[buffer(11)]], \ - const constant int& vector_batch_ndim [[buffer(12)]], \ - const constant int* vector_batch_shape [[buffer(13)]], \ - const constant size_t* vector_batch_stride [[buffer(14)]], \ - const constant int& matrix_batch_ndim [[buffer(15)]], \ - const constant int* matrix_batch_shape [[buffer(16)]], \ - const constant size_t* matrix_batch_stride [[buffer(17)]], \ - const constant uint32_t* vec_indices [[buffer(18)]], \ - const constant uint32_t* mat_indices [[buffer(19)]], \ - uint3 tid [[threadgroup_position_in_grid]], \ - uint3 lid [[thread_position_in_threadgroup]], \ - uint simd_gid [[simdgroup_index_in_threadgroup]], \ +#define instantiate_gemv_bs_helper(nm, itype, bm, bn, sm, sn, tm, tn) \ + template [[host_name("gemv_gather_" #nm "_bm" #bm "_bn" #bn "_sm" #sm \ + "_sn" #sn "_tm" #tm "_tn" #tn)]] [[kernel]] void \ + gemv_gather( \ + const device itype* mat [[buffer(0)]], \ + const device itype* in_vec [[buffer(1)]], \ + const device itype* bias [[buffer(2)]], \ + device itype* out_vec [[buffer(3)]], \ + const constant int& in_vec_size [[buffer(4)]], \ + const constant int& out_vec_size [[buffer(5)]], \ + const constant int& marix_ld [[buffer(6)]], \ + const constant float& alpha [[buffer(7)]], \ + const constant float& beta [[buffer(8)]], \ + const constant int& batch_ndim [[buffer(9)]], \ + const constant int* batch_shape [[buffer(10)]], \ + const constant size_t* index_batch_strides [[buffer(11)]], \ + const constant int& vector_batch_ndim [[buffer(12)]], \ + const constant int* vector_batch_shape [[buffer(13)]], \ + const constant size_t* vector_batch_stride [[buffer(14)]], \ + const constant int& matrix_batch_ndim [[buffer(15)]], \ + const constant int* matrix_batch_shape [[buffer(16)]], \ + const constant size_t* matrix_batch_stride [[buffer(17)]], \ + const constant uint32_t* vec_indices [[buffer(18)]], \ + const constant uint32_t* mat_indices [[buffer(19)]], \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint3 lid [[thread_position_in_threadgroup]], \ + uint simd_gid [[simdgroup_index_in_threadgroup]], \ uint simd_lid [[thread_index_in_simdgroup]]); // clang-format off #define instantiate_gemv_bs_blocks(name, itype) \ - 
instantiate_gemv_bs_helper(name, itype, 4, 32, 1, 4) \ - instantiate_gemv_bs_helper(name, itype, 4, 32, 4, 4) \ - instantiate_gemv_bs_helper(name, itype, 8, 32, 4, 4) // clang-format on + instantiate_gemv_bs_helper(name, itype, 4, 1, 1, 32, 1, 4) \ + instantiate_gemv_bs_helper(name, itype, 4, 1, 1, 32, 4, 4) \ + instantiate_gemv_bs_helper(name, itype, 8, 1, 1, 32, 4, 4) // clang-format on instantiate_gemv_bs_blocks(float32, float); instantiate_gemv_bs_blocks(float16, half); @@ -553,13 +666,15 @@ instantiate_gemv_bs_blocks(bfloat16, bfloat16_t); template < typename T, - const int BM, /* Threadgroup rows (in threads) */ - const int BN, /* Threadgroup cols (in threads) */ + const int BM, /* Threadgroup rows (in simdgroups) */ + const int BN, /* Threadgroup cols (in simdgroups) */ + const int SM, /* Simdgroup rows (in threads) */ + const int SN, /* Simdgroup cols (in threads) */ const int TM, /* Thread rows (in elements) */ const int TN, /* Thread cols (in elements) */ const bool kDoNCBatch, /* Batch ndim > 1 */ const bool kDoAxpby> /* Do out = alpha * out + beta * bias */ -[[kernel, max_total_threads_per_threadgroup(BM* BN)]] void gemv_t( +[[kernel, max_total_threads_per_threadgroup(BM* BN * 32)]] void gemv_t( const device T* mat [[buffer(0)]], const device T* in_vec [[buffer(1)]], const device T* bias [[buffer(2)]], @@ -579,8 +694,9 @@ template < uint3 lid [[thread_position_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { - using gemv_kernel = GEMVTKernel; - threadgroup T tgp_memory[gemv_kernel::tgp_mem_size]; + using gemv_kernel = GEMVTKernel; + threadgroup T tgp_memory + [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size]; // Update batch offsets if (kDoNCBatch) { @@ -613,17 +729,19 @@ template < alpha, beta, bias_stride, - tgp_memory, + gemv_kernel::tgp_mem_size == 0 ? 
nullptr : tgp_memory, tid, lid, simd_gid, simd_lid); } -#define instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, nc, axpby) \ - template [[host_name("gemv_t_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn \ - "_nc" #nc "_axpby" #axpby)]] [[kernel]] void \ - gemv_t( \ +#define instantiate_gemv_t_helper( \ + name, itype, bm, bn, sm, sn, tm, tn, nc, axpby) \ + template [[host_name("gemv_t_" #name "_bm" #bm "_bn" #bn "_sm" #sm "_sn" #sn \ + "_tm" #tm "_tn" #tn "_nc" #nc \ + "_axpby" #axpby)]] [[kernel]] void \ + gemv_t( \ const device itype* mat [[buffer(0)]], \ const device itype* in_vec [[buffer(1)]], \ const device itype* bias [[buffer(2)]], \ @@ -645,20 +763,19 @@ template < uint simd_lid [[thread_index_in_simdgroup]]); // clang-format off -#define instantiate_gemv_t(name, itype, bm, bn, tm, tn) \ - instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, 0, 0) \ - instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, 0, 1) \ - instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, 1, 0) \ - instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, 1, 1) // clang-format on +#define instantiate_gemv_t(name, itype, bm, bn, sm, sn, tm, tn) \ + instantiate_gemv_t_helper(name, itype, bm, bn, sm, sn, tm, tn, 0, 0) \ + instantiate_gemv_t_helper(name, itype, bm, bn, sm, sn, tm, tn, 0, 1) \ + instantiate_gemv_t_helper(name, itype, bm, bn, sm, sn, tm, tn, 1, 0) \ + instantiate_gemv_t_helper(name, itype, bm, bn, sm, sn, tm, tn, 1, 1) // clang-format on // clang-format off #define instantiate_gemv_t_blocks(name, itype) \ - instantiate_gemv_t(name, itype, 8, 8, 4, 1) \ - instantiate_gemv_t(name, itype, 8, 8, 4, 4) \ - instantiate_gemv_t(name, itype, 8, 16, 4, 4) \ - instantiate_gemv_t(name, itype, 8, 32, 4, 4) \ - instantiate_gemv_t(name, itype, 8, 64, 4, 4) \ - instantiate_gemv_t(name, itype, 8, 128, 4, 4) // clang-format on + instantiate_gemv_t(name, itype, 1, 2, 8, 4, 4, 1) \ + instantiate_gemv_t(name, itype, 1, 2, 8, 4, 4, 4) \ + instantiate_gemv_t(name, itype, 1, 4, 8, 4, 4, 4) \ + instantiate_gemv_t(name, itype, 1, 16, 8, 4, 4, 4) \ + instantiate_gemv_t(name, itype, 1, 16, 4, 8, 4, 4) // clang-format on // clang-format off instantiate_gemv_t_blocks(float32, float); @@ -667,11 +784,13 @@ instantiate_gemv_t_blocks(bfloat16, bfloat16_t); // clang-format on template < typename T, - const int BM, /* Threadgroup rows (in threads) */ - const int BN, /* Threadgroup cols (in threads) */ + const int BM, /* Threadgroup rows (in simdgroups) */ + const int BN, /* Threadgroup cols (in simdgroups) */ + const int SM, /* Simdgroup rows (in threads) */ + const int SN, /* Simdgroup cols (in threads) */ const int TM, /* Thread rows (in elements) */ const int TN> /* Thread cols (in elements) */ -[[kernel, max_total_threads_per_threadgroup(BM* BN)]] void gemv_t_bs( +[[kernel, max_total_threads_per_threadgroup(BM* BN * 32)]] void gemv_t_gather( const device T* mat [[buffer(0)]], const device T* in_vec [[buffer(1)]], const device T* bias [[buffer(2)]], @@ -696,8 +815,9 @@ template < uint3 lid [[thread_position_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { - using gemv_kernel = GEMVTKernel; - threadgroup T tgp_memory[gemv_kernel::tgp_mem_size]; + using gemv_kernel = GEMVTKernel; + threadgroup T tgp_memory + [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size]; uint32_t indx_vec; uint32_t indx_mat; @@ -745,50 +865,49 @@ template < alpha, beta, batch_ndim, // Not used, - tgp_memory, + gemv_kernel::tgp_mem_size == 0 ? 
nullptr : tgp_memory, tid, lid, simd_gid, simd_lid); } -#define instantiate_gemv_t_bs_helper(nm, itype, bm, bn, tm, tn) \ - template [[host_name("gemv_t_bs_" #nm "_bm" #bm "_bn" #bn "_tm" #tm \ - "_tn" #tn)]] [[kernel]] void \ - gemv_t_bs( \ - const device itype* mat [[buffer(0)]], \ - const device itype* in_vec [[buffer(1)]], \ - const device itype* bias [[buffer(2)]], \ - device itype* out_vec [[buffer(3)]], \ - const constant int& in_vec_size [[buffer(4)]], \ - const constant int& out_vec_size [[buffer(5)]], \ - const constant int& marix_ld [[buffer(6)]], \ - const constant float& alpha [[buffer(7)]], \ - const constant float& beta [[buffer(8)]], \ - const constant int& batch_ndim [[buffer(9)]], \ - const constant int* batch_shape [[buffer(10)]], \ - const constant size_t* index_batch_strides [[buffer(11)]], \ - const constant int& vector_batch_ndim [[buffer(12)]], \ - const constant int* vector_batch_shape [[buffer(13)]], \ - const constant size_t* vector_batch_stride [[buffer(14)]], \ - const constant int& matrix_batch_ndim [[buffer(15)]], \ - const constant int* matrix_batch_shape [[buffer(16)]], \ - const constant size_t* matrix_batch_stride [[buffer(17)]], \ - const constant uint32_t* vec_indices [[buffer(18)]], \ - const constant uint32_t* mat_indices [[buffer(19)]], \ - uint3 tid [[threadgroup_position_in_grid]], \ - uint3 lid [[thread_position_in_threadgroup]], \ - uint simd_gid [[simdgroup_index_in_threadgroup]], \ +#define instantiate_gemv_t_bs_helper(nm, itype, bm, bn, sm, sn, tm, tn) \ + template [[host_name("gemv_t_gather_" #nm "_bm" #bm "_bn" #bn "_sm" #sm \ + "_sn" #sn "_tm" #tm "_tn" #tn)]] [[kernel]] void \ + gemv_t_gather( \ + const device itype* mat [[buffer(0)]], \ + const device itype* in_vec [[buffer(1)]], \ + const device itype* bias [[buffer(2)]], \ + device itype* out_vec [[buffer(3)]], \ + const constant int& in_vec_size [[buffer(4)]], \ + const constant int& out_vec_size [[buffer(5)]], \ + const constant int& marix_ld [[buffer(6)]], \ + const constant float& alpha [[buffer(7)]], \ + const constant float& beta [[buffer(8)]], \ + const constant int& batch_ndim [[buffer(9)]], \ + const constant int* batch_shape [[buffer(10)]], \ + const constant size_t* index_batch_strides [[buffer(11)]], \ + const constant int& vector_batch_ndim [[buffer(12)]], \ + const constant int* vector_batch_shape [[buffer(13)]], \ + const constant size_t* vector_batch_stride [[buffer(14)]], \ + const constant int& matrix_batch_ndim [[buffer(15)]], \ + const constant int* matrix_batch_shape [[buffer(16)]], \ + const constant size_t* matrix_batch_stride [[buffer(17)]], \ + const constant uint32_t* vec_indices [[buffer(18)]], \ + const constant uint32_t* mat_indices [[buffer(19)]], \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint3 lid [[thread_position_in_threadgroup]], \ + uint simd_gid [[simdgroup_index_in_threadgroup]], \ uint simd_lid [[thread_index_in_simdgroup]]); // clang-format off -#define instantiate_gemv_t_bs_blocks(name, itype) \ - instantiate_gemv_t_bs_helper(name, itype, 8, 8, 4, 1) \ - instantiate_gemv_t_bs_helper(name, itype, 8, 8, 4, 4) \ - instantiate_gemv_t_bs_helper(name, itype, 8, 16, 4, 4) \ - instantiate_gemv_t_bs_helper(name, itype, 8, 32, 4, 4) \ - instantiate_gemv_t_bs_helper(name, itype, 8, 64, 4, 4) \ - instantiate_gemv_t_bs_helper(name, itype, 8, 128, 4, 4) // clang-format on +#define instantiate_gemv_t_bs_blocks(name, itype) \ + instantiate_gemv_t_bs_helper(name, itype, 1, 2, 8, 4, 4, 1) \ + instantiate_gemv_t_bs_helper(name, itype, 1, 2, 8, 4, 4, 4) \ + 
instantiate_gemv_t_bs_helper(name, itype, 1, 4, 8, 4, 4, 4) \ + instantiate_gemv_t_bs_helper(name, itype, 1, 16, 8, 4, 4, 4) \ + instantiate_gemv_t_bs_helper(name, itype, 1, 16, 4, 8, 4, 4) // clang-format on // clang-format off instantiate_gemv_t_bs_blocks(float32, float); diff --git a/mlx/backend/metal/kernels/gemv_masked.metal b/mlx/backend/metal/kernels/gemv_masked.metal new file mode 100644 index 000000000..2d63a6e40 --- /dev/null +++ b/mlx/backend/metal/kernels/gemv_masked.metal @@ -0,0 +1,939 @@ +// Copyright © 2023-2024 Apple Inc. + +#include +#include + +#include "mlx/backend/metal/kernels/bf16.h" +#include "mlx/backend/metal/kernels/defines.h" +#include "mlx/backend/metal/kernels/utils.h" + +#include "mlx/backend/metal/kernels/steel/utils.h" + +using namespace metal; + +/////////////////////////////////////////////////////////////////////////////// +/// Matrix vector multiplication +/////////////////////////////////////////////////////////////////////////////// + +#define MLX_MTL_CONST static constant constexpr const + +struct _NoMask { + char x; + + constexpr METAL_FUNC operator bool() { + return true; + } + constexpr METAL_FUNC operator bool() const threadgroup { + return true; + } + constexpr METAL_FUNC operator bool() const device { + return true; + } + constexpr METAL_FUNC operator bool() const constant { + return true; + } +}; + +typedef struct _NoMask nomask_t; + +template +struct ScaleOp { + OutT scale; + + METAL_FUNC OutT apply(InT x) const { + return static_cast(x) * scale; + } +}; + +template < + typename T, + typename out_mask_t, + typename op_mask_t, + const int BM, /* Threadgroup rows (in simdgroups) */ + const int BN, /* Threadgroup cols (in simdgroups) */ + const int SM, /* Simdgroup rows (in threads) */ + const int SN, /* Simdgroup cols (in threads) */ + const int TM, /* Thread rows (in elements) */ + const int TN> /* Thread cols (in elements) */ +struct GEMVKernel { + MLX_MTL_CONST int threadsM = BM * SM; + MLX_MTL_CONST int threadsN = BN * SN; + + MLX_MTL_CONST int blockM = threadsM * TM; + MLX_MTL_CONST int blockN = threadsN * TN; + + static_assert(SM * SN == 32, "simdgroup can only have 32 threads"); + + static_assert( + SN == 8 || SN == 16 || SN == 32, + "gemv block must have a width of 8, 16, or 32"); + + static_assert(blockN >= blockM, "Masked gemv must have blockN >= blockM"); + + MLX_MTL_CONST bool has_operand_mask = !metal::is_same_v; + MLX_MTL_CONST bool has_output_mask = !metal::is_same_v; + + MLX_MTL_CONST bool has_mul_operand_mask = + has_operand_mask && !metal::is_same_v; + MLX_MTL_CONST bool has_mul_output_mask = + has_output_mask && !metal::is_same_v; + + // - The matrix of size (M = out_vec_size, K = in_vec_size) is divided up + // into blocks of (blockM, blockN) divided among threadgroups + // - Every thread works on a block of (TM, TN) + // - We assume each threadgroup has (threadsN, threadsM, 1) threads + // + // 1. A thread loads TN elements each from mat along TM rows + // and the corresponding scalar from the vector + // 2. The thread then multiplies and adds to accumulate its local result for + // the block + // 3. At the end, each thread has accumulated results over all blocks across + // the rows. These are then summed up across the threadgroup + // 4. 
Each threadgroup writes its accumulated blockM outputs + // + // Edge case handling: + // - The threadgroup with the largest tid has blocks that exceed the matrix + // * The blocks that start outside the matrix are never read (thread results + // remain zero) + // * The last thread that partially overlaps with the matrix is shifted + // inwards such that the thread block fits exactly in the matrix + + MLX_MTL_CONST short tgp_mem_size = BN > 1 ? BN*(blockM + TM) : 0; + MLX_MTL_CONST bool needs_tgp_reduction = BN > 1; + + static METAL_FUNC void + load_unsafe(const device T* src, thread T dst[TN], const int src_offset = 0) { + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + dst[tn] = src[src_offset + tn]; + } + } + + static METAL_FUNC void load_safe( + const device T* src, + thread T dst[TN], + const int src_offset = 0, + const int src_size = TN) { + if (src_offset + TN <= src_size) { + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + dst[tn] = src[src_offset + tn]; + } + } else { // Edgecase + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + dst[tn] = src_offset + tn < src_size ? src[src_offset + tn] : 0; + } + } + } + + static METAL_FUNC void run( + const device T* mat [[buffer(0)]], + const device T* in_vec [[buffer(1)]], + device T* out_vec [[buffer(3)]], + const constant int& in_vec_size [[buffer(4)]], + const constant int& out_vec_size [[buffer(5)]], + const constant int& matrix_ld [[buffer(6)]], + const device out_mask_t* out_mask [[buffer(20)]], + const device op_mask_t* mat_mask [[buffer(21)]], + const device op_mask_t* vec_mask [[buffer(22)]], + const constant int* mask_strides [[buffer(23)]], + threadgroup T* tgp_memory [[threadgroup(0)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + // Appease compiler + (void)lid; + + // Thread local accumulation results + thread T result[TM] = {0}; + thread T inter[TN]; + thread T v_coeff[TN]; + + const int thrM = SN != 32 ? simd_lid / SN : 0; + const int thrN = SN != 32 ? simd_lid % SN : int(simd_lid); + + const int sgN = BN != 1 ? (simd_gid % BN) : 0; + + const int simdM = BN != 1 ? SM * (simd_gid / BN) : int(SM * simd_gid); + const int simdN = BN != 1 ? SN * (simd_gid % BN) : 0; + + int bm = (simdM + thrM) * TM; + int bn = (simdN + thrN) * TN; + + // Block position + int out_row = tid.x * blockM + bm; + + // Exit simdgroup if rows out of bound + if (out_row >= out_vec_size) + return; + + // Adjust tail simdgroup to ensure in bound reads + out_row = out_row + TM <= out_vec_size ? out_row : out_vec_size - TM; + + // Prepare mask offsets + const constant int* out_mask_strides = mask_strides; + const constant int* mat_mask_strides = + mask_strides + (has_output_mask ? 2 : 0); + const constant int* vec_mask_strides = + mat_mask_strides + (has_operand_mask ? 2 : 0); + + const int m_block_idx = blockN > blockM ? out_row / blockN : int(tid.x); + + const int out_mask_offset = + !has_output_mask ? 0 : m_block_idx * out_mask_strides[1]; + + int mat_mask_offset = + !has_operand_mask ? 0 : m_block_idx * mat_mask_strides[1]; + int vec_mask_offset = 0; + const int mat_mask_step = !has_operand_mask ? 0 : mat_mask_strides[0]; + const int vec_mask_step = !has_operand_mask ? 
0 : vec_mask_strides[1]; + + T out_scale{1}; + + // Check output mask + if (has_output_mask) { + auto mask_out = out_mask[out_mask_offset]; + + // Write zeros and return if mask is 0 + if (!mask_out) { + if (simdN == 0 && thrN == 0) { + MLX_MTL_PRAGMA_UNROLL + for (int tm = 0; tm < TM; tm++) { + out_vec[out_row + tm] = T(0.); + } + } + + return; + } + + // Store scalar if multiplicative mask + if (has_mul_output_mask) { + out_scale = T(mask_out); + } + } + + // Advance matrix + mat += out_row * matrix_ld; + + // Prepare for loop + constexpr const uniform loop_stride = make_uniform(blockN); + const uniform in_size = make_uniform(in_vec_size); + const uniform n_iter = in_size / loop_stride; + const uniform last_iter = loop_stride * n_iter; + const uniform leftover = in_size - last_iter; + + // Loop over in_vec in blocks of blockN + for (int i = 0; i < n_iter; ++i) { + if (!has_operand_mask || + (bool(mat_mask[mat_mask_offset]) && + bool(vec_mask[vec_mask_offset]))) { + T block_scale{1}; + if (has_mul_operand_mask) { + block_scale = + T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]); + } + + load_unsafe(in_vec, v_coeff, bn); + + // Apply scale + if (has_mul_operand_mask) { + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + v_coeff[tn] *= block_scale; + } + } + + // Per thread work loop + int mat_offset = 0; + MLX_MTL_PRAGMA_UNROLL + for (int tm = 0; tm < TM; tm++) { + // Load for the row + load_unsafe(mat, inter, mat_offset + bn); + + // Accumulate results + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + result[tm] += inter[tn] * v_coeff[tn]; + } + + mat_offset += matrix_ld; + } + } + + bn += blockN; + mat_mask_offset += mat_mask_step; + vec_mask_offset += vec_mask_step; + } + + if (leftover > 0 && + (!has_operand_mask || + (bool(mat_mask[mat_mask_offset]) && + bool(vec_mask[vec_mask_offset])))) { + T block_scale{1}; + if (has_mul_operand_mask) { + block_scale = + T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]); + } + + load_safe(in_vec, v_coeff, bn, in_size); + + // Apply scale + if (has_mul_operand_mask) { + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + v_coeff[tn] *= block_scale; + } + } + + // Per thread work loop + MLX_MTL_PRAGMA_UNROLL + for (int tm = 0; tm < TM; tm++) { + // Load for the row + load_safe(&mat[tm * matrix_ld], inter, bn, in_size); + + // Accumulate results + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + result[tm] += inter[tn] * v_coeff[tn]; + } + } + } + + // Apply out scale + if (has_mul_output_mask) { + MLX_MTL_PRAGMA_UNROLL + for (int tm = 0; tm < TM; tm++) { + result[tm] *= out_scale; + } + } + + // Simdgroup accumulations + MLX_MTL_PRAGMA_UNROLL + for (int tm = 0; tm < TM; tm++) { + MLX_MTL_PRAGMA_UNROLL + for (ushort sn = (SN / 2); sn >= 1; sn >>= 1) { + result[tm] += simd_shuffle_down(result[tm], sn); + } + } + + // Threadgroup accumulation results + if (needs_tgp_reduction) { + threadgroup T* tgp_results = tgp_memory + sgN * (blockM + TM) + bm; + if (thrN == 0) { + MLX_MTL_PRAGMA_UNROLL + for (int tm = 0; tm < TM; tm++) { + tgp_results[tm] = result[tm]; + } + + threadgroup_barrier(mem_flags::mem_none); + + if (sgN == 0) { + MLX_MTL_PRAGMA_UNROLL + for (int sgn = 1; sgn < BN; sgn++) { + MLX_MTL_PRAGMA_UNROLL + for (int tm = 0; tm < TM; tm++) { + result[tm] += tgp_results[sgn * (blockM + TM) + tm]; + } + } + } + } + } + + // Write outputs + if (simdN == 0 && thrN == 0) { + MLX_MTL_PRAGMA_UNROLL + for (int tm = 0; tm < TM; tm++) { + out_vec[out_row + tm] = result[tm]; + } + } + } +}; + 
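// ---------------------------------------------------------------------------
// [Editorial sketch, not part of this patch] Both the masked and unmasked
// kernels above fold each thread's partial sums across the SN lanes of its
// simdgroup row with a shuffle-based halving loop (the "Simdgroup
// accumulations" step). A minimal standalone illustration of that pattern,
// assuming one row of SN lanes inside a 32-thread simdgroup (the helper and
// variable names below are illustrative, not taken from the diff):
//
//   template <short SN> // lanes per row: 8, 16, or 32
//   METAL_FUNC float row_sum(float v) {
//     for (ushort delta = SN / 2; delta >= 1; delta >>= 1) {
//       v += simd_shuffle_down(v, delta); // pull from lane (simd_lid + delta)
//     }
//     // Only the lane with thrN == 0 now holds its row's complete sum; lanes
//     // further right may have mixed in values from the next row, which is
//     // why only the thrN == 0 (resp. cm == 0) threads ever write results.
//     return v;
//   }
//
// The transposed kernels (GEMVTKernel) instead shuffle by SN * sm so that each
// step combines lanes that sit in the same column of the simdgroup.
// ---------------------------------------------------------------------------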
+/////////////////////////////////////////////////////////////////////////////// +/// Vector matrix multiplication +/////////////////////////////////////////////////////////////////////////////// + +template < + typename T, + typename out_mask_t, + typename op_mask_t, + const int BM, /* Threadgroup rows (in simdgroups) */ + const int BN, /* Threadgroup cols (in simdgroups) */ + const int SM, /* Simdgroup rows (in threads) */ + const int SN, /* Simdgroup cols (in threads) */ + const int TM, /* Thread rows (in elements) */ + const int TN> /* Thread cols (in elements) */ +struct GEMVTKernel { + MLX_MTL_CONST int threadsM = BM * SM; + MLX_MTL_CONST int threadsN = BN * SN; + + MLX_MTL_CONST int blockM = threadsM * TM; + MLX_MTL_CONST int blockN = threadsN * TN; + + static_assert(SM * SN == 32, "simdgroup can only have 32 threads"); + + MLX_MTL_CONST bool has_operand_mask = !metal::is_same_v; + MLX_MTL_CONST bool has_output_mask = !metal::is_same_v; + + MLX_MTL_CONST bool has_mul_operand_mask = + has_operand_mask && !metal::is_same_v; + MLX_MTL_CONST bool has_mul_output_mask = + has_output_mask && !metal::is_same_v; + + // - The matrix of size (M = in_vec_size, N = out_vec_size) is divided up + // into blocks of (blockM, blockN) divided among threadgroups + // - Every thread works on a block of (TM, TN) + // - We assume each threadgroup has (threadsN, threadsM, 1) threads + // + // 1. A thread loads TN elements each from mat along TM contiguous rows + // and the corresponding scalar from the vector + // 2. The thread then accumulates its local result for the block + // 3. At the end, each thread has accumulated results over all blocks across + // the rows. These are then summed up across the threadgroup + // 4. Each threadgroup writes its accumulated BN * TN outputs + // + // Edge case handling: + // - The threadgroup with the largest tid has blocks that exceed the matrix + // * The blocks that start outside the matrix are never read (thread results + // remain zero) + // * The last thread that partially overlaps with the matrix is shifted + // inwards such that the thread block fits exactly in the matrix + + MLX_MTL_CONST short tgp_mem_size = BM > 1 ? BM*(blockN + TN) : 0; + MLX_MTL_CONST bool needs_tgp_reduction = BM > 1; + + static METAL_FUNC void run( + const device T* mat [[buffer(0)]], + const device T* in_vec [[buffer(1)]], + device T* out_vec [[buffer(3)]], + const constant int& in_vec_size [[buffer(4)]], + const constant int& out_vec_size [[buffer(5)]], + const constant int& marix_ld [[buffer(6)]], + const device out_mask_t* out_mask [[buffer(20)]], + const device op_mask_t* mat_mask [[buffer(21)]], + const device op_mask_t* vec_mask [[buffer(22)]], + const constant int* mask_strides [[buffer(23)]], + threadgroup T* tgp_memory [[threadgroup(0)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + // Appease compiler + (void)lid; + + // Thread local accumulation results + T result[TN] = {0}; + T inter[TN]; + T v_coeff[TM]; + + const int thrM = SN != 32 ? simd_lid / SN : 0; + const int thrN = SN != 32 ? simd_lid % SN : int(simd_lid); + + const int sgM = BN != 1 ? (simd_gid / BN) : int(simd_gid); + const int sgN = BN != 1 ? 
(simd_gid % BN) : 0; + + const int simdM = SM * sgM; + const int simdN = SN * sgN; + + int cm = (simdM + thrM); + int cn = (simdN + thrN); + + int bm = cm * TM; + int bn = cn * TN; + + int out_col = tid.x * blockN + bn; + + // Prepare mask offsets + const constant int* out_mask_strides = mask_strides; + const constant int* mat_mask_strides = + out_mask_strides + (has_output_mask ? 2 : 0); + const constant int* vec_mask_strides = + mat_mask_strides + (has_operand_mask ? 2 : 0); + + const int n_block_idx = blockM > blockN ? out_col / blockM : int(tid.x); + + const int out_mask_offset = + !has_output_mask ? 0 : n_block_idx; // * out_mask_strides[0]; + + int mat_mask_offset = + !has_operand_mask ? 0 : n_block_idx * mat_mask_strides[0]; + int vec_mask_offset = 0; + const int mat_mask_step = !has_operand_mask ? 0 : mat_mask_strides[1]; + const int vec_mask_step = !has_operand_mask ? 0 : vec_mask_strides[0]; + + T out_scale{1}; + + // Check output mask + if (has_output_mask) { + auto mask_out = out_mask[out_mask_offset]; + + // Write zeros and return if mask is 0 + if (!mask_out) { + if (cm == 0 && out_col < out_vec_size) { + if (out_col + TN <= out_vec_size) { + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + out_vec[out_col + tn] = T(0.); + } + } else { + for (int tn = 0; tn < TN && (out_col + tn) < out_vec_size; tn++) { + out_vec[out_col + tn] = T(0.); + } + } + } + + return; + } + + // Store scalar if multiplicative mask + if (has_mul_output_mask) { + out_scale = T(mask_out); + } + } + + // Prepare for loop + constexpr const uniform loop_stride = make_uniform(blockM); + const uniform in_size = make_uniform(in_vec_size); + const uniform n_iter = in_size / loop_stride; + const uniform last_iter = loop_stride * n_iter; + const uniform leftover = in_size - last_iter; + + // Edgecase handling + if (out_col < out_vec_size) { + out_col = (out_col + TN) <= out_vec_size ? 
out_col : out_vec_size - TN; + + // Per thread accumulation main loop + for (int i = 0; i < n_iter; ++i) { + // Adding a threadgroup_barrier improves performance slightly + // This is possibly it may help exploit cache better + threadgroup_barrier(mem_flags::mem_none); + + if (!has_operand_mask || + (bool(mat_mask[mat_mask_offset]) && + bool(vec_mask[vec_mask_offset]))) { + T block_scale{1}; + if (has_mul_operand_mask) { + block_scale = + T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]); + } + + MLX_MTL_PRAGMA_UNROLL + for (int tm = 0; tm < TM; tm++) { + v_coeff[tm] = in_vec[bm + tm]; + } + + // Apply scale + if (has_mul_operand_mask) { + MLX_MTL_PRAGMA_UNROLL + for (int tm = 0; tm < TM; tm++) { + v_coeff[tm] *= block_scale; + } + } + + MLX_MTL_PRAGMA_UNROLL + for (int tm = 0; tm < TM; tm++) { + for (int tn = 0; tn < TN; tn++) { + inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn]; + } + for (int tn = 0; tn < TN; tn++) { + result[tn] += v_coeff[tm] * inter[tn]; + } + } + } + + bm += blockM; + mat_mask_offset += mat_mask_step; + vec_mask_offset += vec_mask_step; + } + + if (leftover > 0 && + (!has_operand_mask || + (bool(mat_mask[mat_mask_offset]) && + bool(vec_mask[vec_mask_offset])))) { + T block_scale{1}; + if (has_mul_operand_mask) { + block_scale = + T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]); + } + + for (int tm = 0; tm < TM && bm + tm < in_vec_size; tm++) { + v_coeff[tm] = in_vec[bm + tm]; + + if (has_mul_operand_mask) { + v_coeff[tm] *= block_scale; + } + + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn]; + } + + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + result[tn] += v_coeff[tm] * inter[tn]; + } + } + } + } + + // Apply out scale + if (has_mul_output_mask) { + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + result[tn] *= out_scale; + } + } + + // Simdgroup accumulations + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + MLX_MTL_PRAGMA_UNROLL + for (ushort sm = (SM / 2); sm >= 1; sm >>= 1) { + result[tn] += simd_shuffle_down(result[tn], SN * sm); + } + } + + // Threadgroup accumulation results + if (needs_tgp_reduction) { + threadgroup T* tgp_results = tgp_memory + sgM * (blockN + TN) + bn; + if (thrM == 0) { + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + tgp_results[tn] = result[tn]; + } + + threadgroup_barrier(mem_flags::mem_none); + + if (sgM == 0) { + MLX_MTL_PRAGMA_UNROLL + for (int sgm = 1; sgm < BM; sgm++) { + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + result[tn] += tgp_results[sgm * (blockN + TN) + tn]; + } + } + } + } + } + + // Threadgroup accumulation and writing out results + if (cm == 0 && out_col < out_vec_size) { + MLX_MTL_PRAGMA_UNROLL + for (int j = 0; j < TN; j++) { + out_vec[out_col + j] = result[j]; + } + } + } +}; + +/////////////////////////////////////////////////////////////////////////////// +/// Matrix vector multiplication +/////////////////////////////////////////////////////////////////////////////// + +template < + typename T, + typename out_mask_t, + typename op_mask_t, + const int BM, /* Threadgroup rows (in simdgroups) */ + const int BN, /* Threadgroup cols (in simdgroups) */ + const int SM, /* Simdgroup rows (in threads) */ + const int SN, /* Simdgroup cols (in threads) */ + const int TM, /* Thread rows (in elements) */ + const int TN, /* Thread cols (in elements) */ + const bool kDoNCBatch> /* Batch ndim > 1 */ +[[kernel, max_total_threads_per_threadgroup(BM* BN * 32)]] void 
gemv_masked( + const device T* mat [[buffer(0)]], + const device T* in_vec [[buffer(1)]], + device T* out_vec [[buffer(3)]], + const constant int& in_vec_size [[buffer(4)]], + const constant int& out_vec_size [[buffer(5)]], + const constant int& marix_ld [[buffer(6)]], + const constant int& batch_ndim [[buffer(9)]], + const constant int* batch_shape [[buffer(10)]], + const constant size_t* vector_batch_stride [[buffer(11)]], + const constant size_t* matrix_batch_stride [[buffer(12)]], + const device out_mask_t* out_mask [[buffer(20)]], + const device op_mask_t* mat_mask [[buffer(21)]], + const device op_mask_t* vec_mask [[buffer(22)]], + const constant int* mask_strides [[buffer(23)]], + const constant size_t* mask_batch_strides [[buffer(24)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + using gemv_kernel = + GEMVKernel; + threadgroup T tgp_memory + [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size]; + + constexpr bool has_operand_mask = !metal::is_same_v; + constexpr bool has_output_mask = !metal::is_same_v; + + // Update batch offsets + if (kDoNCBatch) { + in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim); + mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim); + + if (has_output_mask) { + out_mask += + elem_to_loc(tid.z, batch_shape, mask_batch_strides, batch_ndim); + mask_batch_strides += batch_ndim; + } + + if (has_operand_mask) { + const constant size_t* mask_strides_mat = mask_batch_strides; + const constant size_t* mask_strides_vec = mask_strides_mat + batch_ndim; + + ulong2 batch_offsets = elem_to_loc_broadcast( + tid.z, batch_shape, mask_strides_mat, mask_strides_vec, batch_ndim); + + mat_mask += batch_offsets.x; + vec_mask += batch_offsets.y; + } + + } else { + in_vec += tid.z * vector_batch_stride[0]; + mat += tid.z * matrix_batch_stride[0]; + + if (has_output_mask) { + out_mask += tid.z * mask_batch_strides[0]; + mask_batch_strides += batch_ndim; + } + + if (has_operand_mask) { + mat_mask += tid.z * mask_batch_strides[0]; + vec_mask += tid.z * mask_batch_strides[batch_ndim]; + } + } + + out_vec += tid.z * out_vec_size; + + gemv_kernel::run( + mat, + in_vec, + out_vec, + in_vec_size, + out_vec_size, + marix_ld, + out_mask, + mat_mask, + vec_mask, + mask_strides, + gemv_kernel::tgp_mem_size == 0 ? 
nullptr : tgp_memory, + tid, + lid, + simd_gid, + simd_lid); +} + +#define instantiate_gemv_helper( \ + outm_n, outm_t, opm_n, opm_t, name, itype, bm, bn, sm, sn, tm, tn, nc) \ + template [[host_name("gemv_outmask_" #outm_n "_opmask_" #opm_n "_" #name \ + "_bm" #bm "_bn" #bn "_sm" #sm "_sn" #sn "_tm" #tm \ + "_tn" #tn "_nc" #nc)]] [[kernel]] void \ + gemv_masked( \ + const device itype* mat [[buffer(0)]], \ + const device itype* in_vec [[buffer(1)]], \ + device itype* out_vec [[buffer(3)]], \ + const constant int& in_vec_size [[buffer(4)]], \ + const constant int& out_vec_size [[buffer(5)]], \ + const constant int& marix_ld [[buffer(6)]], \ + const constant int& batch_ndim [[buffer(9)]], \ + const constant int* batch_shape [[buffer(10)]], \ + const constant size_t* vector_batch_stride [[buffer(11)]], \ + const constant size_t* matrix_batch_stride [[buffer(12)]], \ + const device outm_t* out_mask [[buffer(20)]], \ + const device opm_t* mat_mask [[buffer(21)]], \ + const device opm_t* vec_mask [[buffer(22)]], \ + const constant int* mask_strides [[buffer(23)]], \ + const constant size_t* mask_batch_strides [[buffer(24)]], \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint3 lid [[thread_position_in_threadgroup]], \ + uint simd_gid [[simdgroup_index_in_threadgroup]], \ + uint simd_lid [[thread_index_in_simdgroup]]); + +// clang-format off +#define instantiate_gemv_base(name, itype, bm, bn, sm, sn, tm, tn, nc) \ + instantiate_gemv_helper(bool_, bool, bool_, bool, name, itype, bm, bn, sm, sn, tm, tn, nc) \ + instantiate_gemv_helper(name, itype, name, itype, name, itype, bm, bn, sm, sn, tm, tn, nc) \ + instantiate_gemv_helper(bool_, bool, name, itype, name, itype, bm, bn, sm, sn, tm, tn, nc) \ + instantiate_gemv_helper(name, itype, bool_, bool, name, itype, bm, bn, sm, sn, tm, tn, nc) \ + instantiate_gemv_helper(nomask, nomask_t, name, itype, name, itype, bm, bn, sm, sn, tm, tn, nc) \ + instantiate_gemv_helper(nomask, nomask_t, bool_, bool, name, itype, bm, bn, sm, sn, tm, tn, nc) \ + instantiate_gemv_helper(bool_, bool, nomask, nomask_t, name, itype, bm, bn, sm, sn, tm, tn, nc) \ + instantiate_gemv_helper(name, itype, nomask, nomask_t, name, itype, bm, bn, sm, sn, tm, tn, nc) // clang-format on + +// clang-format off +#define instantiate_gemv(name, itype, bm, bn, sm, sn, tm, tn) \ + instantiate_gemv_base(name, itype, bm, bn, sm, sn, tm, tn, 0) \ + instantiate_gemv_base(name, itype, bm, bn, sm, sn, tm, tn, 1) // clang-format on + +// clang-format off +#define instantiate_gemv_blocks(name, itype) \ + instantiate_gemv(name, itype, 2, 1, 4, 8, 1, 4) \ + instantiate_gemv(name, itype, 2, 1, 4, 8, 4, 4) \ + instantiate_gemv(name, itype, 2, 1, 2, 16, 1, 4) \ + instantiate_gemv(name, itype, 2, 1, 2, 16, 4, 4) \ + instantiate_gemv(name, itype, 4, 1, 2, 16, 4, 4) // clang-format on + +instantiate_gemv_blocks(float32, float); +instantiate_gemv_blocks(float16, half); +instantiate_gemv_blocks(bfloat16, bfloat16_t); + +/////////////////////////////////////////////////////////////////////////////// +/// Vector matrix multiplication +/////////////////////////////////////////////////////////////////////////////// + +template < + typename T, + typename out_mask_t, + typename op_mask_t, + const int BM, /* Threadgroup rows (in simdgroups) */ + const int BN, /* Threadgroup cols (in simdgroups) */ + const int SM, /* Simdgroup rows (in threads) */ + const int SN, /* Simdgroup cols (in threads) */ + const int TM, /* Thread rows (in elements) */ + const int TN, /* Thread cols (in elements) */ + const bool 
kDoNCBatch> /* Batch ndim > 1 */ +[[kernel, max_total_threads_per_threadgroup(BM* BN * 32)]] void gemv_t_masked( + const device T* mat [[buffer(0)]], + const device T* in_vec [[buffer(1)]], + device T* out_vec [[buffer(3)]], + const constant int& in_vec_size [[buffer(4)]], + const constant int& out_vec_size [[buffer(5)]], + const constant int& marix_ld [[buffer(6)]], + const constant int& batch_ndim [[buffer(9)]], + const constant int* batch_shape [[buffer(10)]], + const constant size_t* vector_batch_stride [[buffer(11)]], + const constant size_t* matrix_batch_stride [[buffer(12)]], + const device out_mask_t* out_mask [[buffer(20)]], + const device op_mask_t* mat_mask [[buffer(21)]], + const device op_mask_t* vec_mask [[buffer(22)]], + const constant int* mask_strides [[buffer(23)]], + const constant size_t* mask_batch_strides [[buffer(24)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + using gemv_kernel = + GEMVTKernel<T, out_mask_t, op_mask_t, BM, BN, SM, SN, TM, TN>; + threadgroup T tgp_memory + [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size]; + + constexpr bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>; + constexpr bool has_output_mask = !metal::is_same_v<out_mask_t, nomask_t>; + + // Update batch offsets + if (kDoNCBatch) { + in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim); + mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim); + + if (has_output_mask) { + out_mask += + elem_to_loc(tid.z, batch_shape, mask_batch_strides, batch_ndim); + mask_batch_strides += batch_ndim; + } + + if (has_operand_mask) { + const constant size_t* mask_strides_mat = mask_batch_strides; + const constant size_t* mask_strides_vec = mask_strides_mat + batch_ndim; + + ulong2 batch_offsets = elem_to_loc_broadcast( + tid.z, batch_shape, mask_strides_mat, mask_strides_vec, batch_ndim); + + mat_mask += batch_offsets.x; + vec_mask += batch_offsets.y; + } + + } else { + in_vec += tid.z * vector_batch_stride[0]; + mat += tid.z * matrix_batch_stride[0]; + + if (has_output_mask) { + out_mask += tid.z * mask_batch_strides[0]; + mask_batch_strides += batch_ndim; + } + + if (has_operand_mask) { + mat_mask += tid.z * mask_batch_strides[0]; + vec_mask += tid.z * mask_batch_strides[batch_ndim]; + } + } + + out_vec += tid.z * out_vec_size; + + gemv_kernel::run( + mat, + in_vec, + out_vec, + in_vec_size, + out_vec_size, + marix_ld, + out_mask, + mat_mask, + vec_mask, + mask_strides, + gemv_kernel::tgp_mem_size == 0 ?
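/* As in gemv_masked above: nullptr is passed whenever this configuration
   reports tgp_mem_size == 0, i.e. no threadgroup reduction buffer is
   required. */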
nullptr : tgp_memory, + tid, + lid, + simd_gid, + simd_lid); +} + +#define instantiate_gemv_t_helper( \ + outm_n, outm_t, opm_n, opm_t, name, itype, bm, bn, sm, sn, tm, tn, nc) \ + template [[host_name("gemv_t_outmask_" #outm_n "_opmask_" #opm_n "_" #name \ + "_bm" #bm "_bn" #bn "_sm" #sm "_sn" #sn "_tm" #tm \ + "_tn" #tn "_nc" #nc)]] [[kernel]] void \ + gemv_t_masked( \ + const device itype* mat [[buffer(0)]], \ + const device itype* in_vec [[buffer(1)]], \ + device itype* out_vec [[buffer(3)]], \ + const constant int& in_vec_size [[buffer(4)]], \ + const constant int& out_vec_size [[buffer(5)]], \ + const constant int& marix_ld [[buffer(6)]], \ + const constant int& batch_ndim [[buffer(9)]], \ + const constant int* batch_shape [[buffer(10)]], \ + const constant size_t* vector_batch_stride [[buffer(11)]], \ + const constant size_t* matrix_batch_stride [[buffer(12)]], \ + const device outm_t* out_mask [[buffer(20)]], \ + const device opm_t* mat_mask [[buffer(21)]], \ + const device opm_t* vec_mask [[buffer(22)]], \ + const constant int* mask_strides [[buffer(23)]], \ + const constant size_t* mask_batch_strides [[buffer(24)]], \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint3 lid [[thread_position_in_threadgroup]], \ + uint simd_gid [[simdgroup_index_in_threadgroup]], \ + uint simd_lid [[thread_index_in_simdgroup]]); + +// clang-format off +#define instantiate_gemv_t_base(name, itype, bm, bn, sm, sn, tm, tn, nc) \ + instantiate_gemv_t_helper(bool_, bool, bool_, bool, name, itype, bm, bn, sm, sn, tm, tn, nc) \ + instantiate_gemv_t_helper(name, itype, name, itype, name, itype, bm, bn, sm, sn, tm, tn, nc) \ + instantiate_gemv_t_helper(bool_, bool, name, itype, name, itype, bm, bn, sm, sn, tm, tn, nc) \ + instantiate_gemv_t_helper(name, itype, bool_, bool, name, itype, bm, bn, sm, sn, tm, tn, nc) \ + instantiate_gemv_t_helper(nomask, nomask_t, name, itype, name, itype, bm, bn, sm, sn, tm, tn, nc) \ + instantiate_gemv_t_helper(nomask, nomask_t, bool_, bool, name, itype, bm, bn, sm, sn, tm, tn, nc) \ + instantiate_gemv_t_helper(bool_, bool, nomask, nomask_t, name, itype, bm, bn, sm, sn, tm, tn, nc) \ + instantiate_gemv_t_helper(name, itype, nomask, nomask_t, name, itype, bm, bn, sm, sn, tm, tn, nc) // clang-format on + +// clang-format off +#define instantiate_gemv_t(name, itype, bm, bn, sm, sn, tm, tn) \ + instantiate_gemv_t_base(name, itype, bm, bn, sm, sn, tm, tn, 0) \ + instantiate_gemv_t_base(name, itype, bm, bn, sm, sn, tm, tn, 1) // clang-format on + +// clang-format off +#define instantiate_gemv_t_blocks(name, itype) \ + instantiate_gemv_t(name, itype, 1, 1, 8, 4, 4, 1) \ + instantiate_gemv_t(name, itype, 1, 2, 8, 4, 4, 4) \ + instantiate_gemv_t(name, itype, 1, 1, 8, 4, 8, 1) \ + instantiate_gemv_t(name, itype, 1, 1, 8, 4, 8, 4) \ + instantiate_gemv_t(name, itype, 1, 2, 8, 4, 8, 4) \ + instantiate_gemv_t(name, itype, 1, 4, 8, 4, 8, 4) // clang-format on + +// clang-format off +instantiate_gemv_t_blocks(float32, float); +instantiate_gemv_t_blocks(float16, half); +instantiate_gemv_t_blocks(bfloat16, bfloat16_t); // clang-format on \ No newline at end of file diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp index 1858715c2..4dd6cd715 100644 --- a/mlx/backend/metal/matmul.cpp +++ b/mlx/backend/metal/matmul.cpp @@ -786,38 +786,47 @@ void Matmul::eval_gpu(const std::vector& inputs, array& out) { // Determine dispatch kernel int tm = 4, tn = 4; - int bm, bn, n_out_per_tgp; + int sm = 1, sn = 32; + int bm = 1, bn = 1; + int n_out_per_tgp; std::ostringstream kname; 
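  /* Illustrative sizing (numbers not from the source): for the transposed
     path with in_vector_len = 4096 and out_vector_len = 1024 the branches
     below pick sm = 8, sn = 4, bn = 4, tn = 4, so
       n_out_per_tgp = bn * sn * tn = 4 * 4 * 4 = 64 outputs per threadgroup,
       n_tgp = ceil(1024 / 64) = 16 threadgroups along the output, and
       group_dims = (32, bn, bm) = (32, 4, 1), i.e. 128 threads covering
       bm * bn = 4 simdgroups. */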
if (transpose_mat) { - bm = 8; - bn = 8; - if (out_vector_len >= 24576) { - bn = 128; - } else if (out_vector_len >= 16384) { - bn = 64; - } else if (out_vector_len >= 8192) { + if (in_vector_len >= 8192 && out_vector_len >= 2048) { + sm = 4; + sn = 8; + } else { + sm = 8; + sn = 4; + } + + if (out_vector_len >= 2048) { bn = 16; + } else if (out_vector_len >= 512) { + bn = 4; + } else { + bn = 2; } // Specialized kernel for very small outputs tn = out_vector_len < tn ? 1 : tn; - n_out_per_tgp = bn * tn; + n_out_per_tgp = bn * sn * tn; kname << "gemv_t_" << type_to_name(out); } else { bm = out_vector_len >= 4096 ? 8 : 4; - bn = 32; + sn = 32; // Specialized kernel for very small outputs tm = out_vector_len < tm ? 1 : tm; - n_out_per_tgp = bm * tm; + n_out_per_tgp = bm * sm * tm; kname << "gemv_" << type_to_name(out); } - kname << "_bm" << bm << "_bn" << bn << "_tm" << tm << "_tn" << tn; + kname << "_bm" << bm << "_bn" << bn << "_sm" << sm << "_sn" << sn << "_tm" + << tm << "_tn" << tn; kname << "_nc" << !contiguous_kernel << "_axpby0"; // Encode and dispatch kernel @@ -826,7 +835,7 @@ void Matmul::eval_gpu(const std::vector& inputs, array& out) { compute_encoder->setComputePipelineState(kernel); int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp; - MTL::Size group_dims = MTL::Size(bn, bm, 1); + MTL::Size group_dims = MTL::Size(32, bn, bm); MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out); compute_encoder.set_input_array(mat, 0); @@ -838,11 +847,9 @@ void Matmul::eval_gpu(const std::vector& inputs, array& out) { compute_encoder->setBytes(&mat_ld, sizeof(int), 6); compute_encoder->setBytes(&batch_ndim, sizeof(int), 9); - compute_encoder->setBytes(batch_shape.data(), batch_ndim * sizeof(int), 10); - compute_encoder->setBytes( - batch_strides_vec.data(), batch_ndim * sizeof(size_t), 11); - compute_encoder->setBytes( - batch_strides_mat.data(), batch_ndim * sizeof(size_t), 12); + set_vector_bytes(compute_encoder, batch_shape, 10); + set_vector_bytes(compute_encoder, batch_strides_vec, 11); + set_vector_bytes(compute_encoder, batch_strides_mat, 12); compute_encoder.dispatchThreadgroups(grid_dims, group_dims); @@ -910,15 +917,19 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) { ///////////////////////////////////////////////////////////////////////////// // Init checks and prep + int M = a_pre.shape(-2); + int N = b_pre.shape(-1); + int K = a_pre.shape(-1); + // Keep a vector with copies to be cleared in the completed buffer to release // the arrays std::vector copies; - auto check_transpose = [&copies, &s](const array& arr) { + auto check_transpose = [&copies, &s](const array& arr, bool is_vector) { auto stx = arr.strides()[arr.ndim() - 2]; auto sty = arr.strides()[arr.ndim() - 1]; - if (sty == 1) { + if (sty == 1 && (!is_vector || stx == arr.shape(-1))) { return std::make_tuple(false, stx, arr); - } else if (stx == 1) { + } else if (stx == 1 && (!is_vector || sty == arr.shape(-2))) { return std::make_tuple(true, sty, arr); } else { array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); @@ -929,12 +940,8 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) { } }; - auto [transpose_a, a_cols, a] = check_transpose(a_pre); - auto [transpose_b, b_cols, b] = check_transpose(b_pre); - - int M = a.shape(-2); - int N = b.shape(-1); - int K = a.shape(-1); + auto [transpose_a, a_cols, a] = check_transpose(a_pre, M == 1); + auto [transpose_b, b_cols, b] = check_transpose(b_pre, N == 1); array c = c_pre; int ldc = c.strides()[c.ndim() - 2]; @@ -997,38 +1004,47 
@@ void AddMM::eval_gpu(const std::vector& inputs, array& out) { // Determine dispatch kernel int tm = 4, tn = 4; - int bm, bn, n_out_per_tgp; + int sm = 1, sn = 32; + int bm = 1, bn = 1; + int n_out_per_tgp; std::ostringstream kname; if (transpose_mat) { - bm = 8; - bn = 8; - if (out_vector_len >= 24576) { - bn = 128; - } else if (out_vector_len >= 16384) { - bn = 64; - } else if (out_vector_len >= 8192) { + if (in_vector_len >= 8192 && out_vector_len >= 2048) { + sm = 4; + sn = 8; + } else { + sm = 8; + sn = 4; + } + + if (out_vector_len >= 2048) { bn = 16; + } else if (out_vector_len >= 512) { + bn = 4; + } else { + bn = 2; } // Specialized kernel for very small outputs tn = out_vector_len < tn ? 1 : tn; - n_out_per_tgp = bn * tn; + n_out_per_tgp = bn * sn * tn; kname << "gemv_t_" << type_to_name(out); } else { bm = out_vector_len >= 4096 ? 8 : 4; - bn = 32; + sn = 32; // Specialized kernel for very small outputs tm = out_vector_len < tm ? 1 : tm; - n_out_per_tgp = bm * tm; + n_out_per_tgp = bm * sm * tm; kname << "gemv_" << type_to_name(out); } - kname << "_bm" << bm << "_bn" << bn << "_tm" << tm << "_tn" << tn; + kname << "_bm" << bm << "_bn" << bn << "_sm" << sm << "_sn" << sn << "_tm" + << tm << "_tn" << tn; kname << "_nc" << !contiguous_kernel << "_axpby1"; // Encode and dispatch kernel @@ -1037,7 +1053,7 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) { compute_encoder->setComputePipelineState(kernel); int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp; - MTL::Size group_dims = MTL::Size(bn, bm, 1); + MTL::Size group_dims = MTL::Size(32, bn, bm); MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out); compute_encoder.set_input_array(mat, 0); @@ -1344,15 +1360,19 @@ void BlockMaskedMM::eval_gpu(const std::vector& inputs, array& out) { ///////////////////////////////////////////////////////////////////////////// // Init checks and prep + int M = a_pre.shape(-2); + int N = b_pre.shape(-1); + int K = a_pre.shape(-1); + // Keep a vector with copies to be cleared in the completed buffer to release // the arrays std::vector copies; - auto check_transpose = [&copies, &s](const array& arr) { + auto check_transpose = [&copies, &s](const array& arr, bool is_vector) { auto stx = arr.strides()[arr.ndim() - 2]; auto sty = arr.strides()[arr.ndim() - 1]; - if (sty == 1) { + if (sty == 1 && (!is_vector || stx == arr.shape(-1))) { return std::make_tuple(false, stx, arr); - } else if (stx == 1) { + } else if (stx == 1 && (!is_vector || sty == arr.shape(-2))) { return std::make_tuple(true, sty, arr); } else { array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); @@ -1363,33 +1383,38 @@ void BlockMaskedMM::eval_gpu(const std::vector& inputs, array& out) { } }; - auto [transpose_a, a_cols, a] = check_transpose(a_pre); - auto [transpose_b, b_cols, b] = check_transpose(b_pre); + auto [transpose_a, a_cols, a] = check_transpose(a_pre, M == 1); + auto [transpose_b, b_cols, b] = check_transpose(b_pre, N == 1); int lda = a_cols; int ldb = b_cols; - int M = a.shape(-2); - int N = b.shape(-1); - int K = a.shape(-1); - ///////////////////////////////////////////////////////////////////////////// // Check and collapse batch dimensions bool has_op_mask = inputs.size() > 3; bool has_out_mask = inputs.size() == 3 || inputs.size() == 5; + // Prepare kernel name + std::string out_mask_nm = has_out_mask ? type_to_name(inputs[2]) : "nomask"; + std::string op_mask_nm = has_op_mask ? 
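/* the mask dtypes are baked into the kernel name (outmask_* / opmask_*);
     "nomask" selects the specialization compiled without that mask */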
type_to_name(inputs.back()) : "nomask"; + + auto get_batch_dims = [](const auto& v) { + return decltype(v){v.begin(), v.end() - 2}; + }; + std::vector batch_shape{1}; + std::vector A_batch_stride{0}; + std::vector B_batch_stride{0}; + std::vector outmask_bstride{0}; + std::vector Amask_bstride{0}; + std::vector Bmask_bstride{0}; size_t A_batch_str = 0; size_t B_batch_str = 0; std::vector batch_strides; if (out.ndim() > 2) { - auto get_batch_dims = [](const auto& v) { - return decltype(v){v.begin(), v.end() - 2}; - }; - std::vector bshape{out.shape().begin(), out.shape().end() - 2}; std::vector> bstrides; @@ -1397,14 +1422,26 @@ void BlockMaskedMM::eval_gpu(const std::vector& inputs, array& out) { bstrides.emplace_back(arr.strides().begin(), arr.strides().end() - 2); } - auto [bshape_c, bstrides_c] = collapse_contiguous_dims(bshape, bstrides); - batch_shape = bshape_c; - A_batch_str = bstrides_c[0].back(); - B_batch_str = bstrides_c[1].back(); + // auto [bshape_c, bstrides_c] = collapse_contiguous_dims(bshape, bstrides); + batch_shape = bshape; + A_batch_str = bstrides[0].back(); + B_batch_str = bstrides[1].back(); - for (auto& bstr : bstrides_c) { + for (auto& bstr : bstrides) { batch_strides.insert(batch_strides.end(), bstr.begin(), bstr.end()); } + + A_batch_stride = bstrides[0]; + B_batch_stride = bstrides[1]; + + if (has_out_mask) { + outmask_bstride = bstrides[2]; + } + if (has_op_mask) { + Amask_bstride = bstrides[has_out_mask + 2]; + Bmask_bstride = bstrides[has_out_mask + 3]; + } + } else { batch_strides = std::vector(inputs.size(), 0); } @@ -1412,6 +1449,174 @@ void BlockMaskedMM::eval_gpu(const std::vector& inputs, array& out) { size_t matrix_stride_out = size_t(M) * N; size_t batch_size_out = out.size() / (matrix_stride_out); + ///////////////////////////////////////////////////////////////////////////// + // Gemv specialization + + // Route to gemv if needed + if (std::min(M, N) == 1) { + // Collect problem info + bool is_b_matrix = N != 1; + + auto& mat = is_b_matrix ? b : a; + auto& vec = is_b_matrix ? a : b; + bool transpose_mat = is_b_matrix ? !transpose_b : transpose_a; + int in_vector_len = K; + int out_vector_len = is_b_matrix ? N : M; + + int mat_cols = transpose_mat ? out_vector_len : in_vector_len; + int mat_rows = transpose_mat ? in_vector_len : out_vector_len; + int mat_ld = is_b_matrix ? b_cols : a_cols; + + auto batch_strides_mat = is_b_matrix ? B_batch_stride : A_batch_stride; + auto batch_strides_vec = is_b_matrix ? A_batch_stride : B_batch_stride; + + auto mask_bstrides_mat = is_b_matrix ? Bmask_bstride : Amask_bstride; + auto mask_bstrides_vec = is_b_matrix ? Amask_bstride : Bmask_bstride; + + auto mat_mask_idx = int(has_out_mask) + (is_b_matrix ? 3 : 2); + auto vec_mask_idx = int(has_out_mask) + (is_b_matrix ? 2 : 3); + + // Determine if inputs have simple batching / broadcasting + bool contiguous_kernel = (batch_shape.size() == 1); + + int batch_ndim = batch_shape.size(); + + // Determine dispatch kernel + int tm = 4, tn = 4; + int sm = 1, sn = 32; + int bm = 1, bn = 1; + int n_out_per_tgp; + std::ostringstream kname; + + if (transpose_mat) { + sm = 8; + sn = 4; + bm = 1; + bn = (block_size_ == 64 && out_vector_len >= 2048) ? 4 : 2; + tm = block_size_ == 32 ? 4 : 8; + tn = 4; + + // Specialized kernel for very small outputs + tn = out_vector_len < tn ? 1 : tn; + + n_out_per_tgp = bn * sn * tn; + kname << "gemv_t"; + + } else { + if (block_size_ == 32) { + sm = 4; + sn = 8; + bm = 2; + } else { + sm = 2; + sn = 16; + bm = out_vector_len >= 512 ? 
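/* larger outputs get more simdgroup rows per threadgroup: bm = 4 when
           out_vector_len >= 512, otherwise bm = 2 */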
4 : 2; + } + + // Specialized kernel for very small outputs + tm = out_vector_len < tm ? 1 : tm; + + n_out_per_tgp = bm * sm * tm; + kname << "gemv"; + } + + kname << "_outmask_" << out_mask_nm; + kname << "_opmask_" << op_mask_nm; + kname << "_" << type_to_name(out); + kname << "_bm" << bm << "_bn" << bn; + kname << "_sm" << sm << "_sn" << sn; + kname << "_tm" << tm << "_tn" << tn; + kname << "_nc" << !contiguous_kernel; + + // Encode and dispatch kernel + auto& compute_encoder = d.get_command_encoder(s.index); + auto kernel = d.get_kernel(kname.str()); + compute_encoder->setComputePipelineState(kernel); + + int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp; + MTL::Size group_dims = MTL::Size(32, bn, bm); + MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out); + + // Get mask params + std::vector mask_strides; + std::vector mask_batch_strides; + if (has_out_mask) { + auto& out_mask = inputs[2]; + + if (transpose_mat) { + mask_strides.push_back(out_mask.strides(out.shape(-2) == 1 ? -1 : -2)); + mask_strides.push_back(out_mask.strides(out.shape(-2) == 1 ? -2 : -1)); + } else { + mask_strides.push_back(out_mask.strides(out.shape(-1) == 1 ? -1 : -2)); + mask_strides.push_back(out_mask.strides(out.shape(-1) == 1 ? -2 : -1)); + } + + mask_batch_strides.insert( + mask_batch_strides.end(), + outmask_bstride.begin(), + outmask_bstride.end()); + + compute_encoder.set_input_array(out_mask, 20); + } + + if (has_op_mask) { + auto& mat_mask = inputs[mat_mask_idx]; + + if (transpose_mat) { + mask_strides.push_back(mat_mask.strides(!is_b_matrix ? -2 : -1)); + mask_strides.push_back(mat_mask.strides(!is_b_matrix ? -1 : -2)); + } else { + mask_strides.push_back(mat_mask.strides(is_b_matrix ? -2 : -1)); + mask_strides.push_back(mat_mask.strides(is_b_matrix ? -1 : -2)); + } + + mask_batch_strides.insert( + mask_batch_strides.end(), + mask_bstrides_mat.begin(), + mask_bstrides_mat.end()); + + compute_encoder.set_input_array(mat_mask, 21); + + auto& vec_mask = inputs[vec_mask_idx]; + if (transpose_mat) { + mask_strides.push_back(vec_mask.strides(vec.shape(-2) == 1 ? -1 : -2)); + mask_strides.push_back(vec_mask.strides(vec.shape(-2) == 1 ? -2 : -1)); + } else { + mask_strides.push_back(vec_mask.strides(vec.shape(-1) == 1 ? -1 : -2)); + mask_strides.push_back(vec_mask.strides(vec.shape(-1) == 1 ? 
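/* second stride entry is the vector mask's remaining axis: -2 when the
           vector is a column (shape(-1) == 1), -1 when it is a row */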
-2 : -1)); + } + + mask_batch_strides.insert( + mask_batch_strides.end(), + mask_bstrides_vec.begin(), + mask_bstrides_vec.end()); + + compute_encoder.set_input_array(vec_mask, 22); + } + + // Get gemv params + compute_encoder.set_input_array(mat, 0); + compute_encoder.set_input_array(vec, 1); + compute_encoder.set_output_array(out, 3); + + compute_encoder->setBytes(&in_vector_len, sizeof(int), 4); + compute_encoder->setBytes(&out_vector_len, sizeof(int), 5); + compute_encoder->setBytes(&mat_ld, sizeof(int), 6); + compute_encoder->setBytes(&batch_ndim, sizeof(int), 9); + set_vector_bytes(compute_encoder, batch_shape, 10); + set_vector_bytes(compute_encoder, batch_strides_vec, 11); + set_vector_bytes(compute_encoder, batch_strides_mat, 12); + + set_vector_bytes(compute_encoder, mask_strides, 23); + set_vector_bytes(compute_encoder, mask_batch_strides, 24); + + compute_encoder.dispatchThreadgroups(grid_dims, group_dims); + + d.get_command_buffer(s.index)->addCompletedHandler( + [copies](MTL::CommandBuffer*) mutable { copies.clear(); }); + return; + } + ///////////////////////////////////////////////////////////////////////////// // Regular kernel dispatch @@ -1421,10 +1626,6 @@ void BlockMaskedMM::eval_gpu(const std::vector& inputs, array& out) { bool mn_aligned = M % bm == 0 && N % bn == 0; bool k_aligned = K % bk == 0; - // Prepare kernel name - std::string out_mask_nm = has_out_mask ? type_to_name(inputs[2]) : "nomask"; - std::string op_mask_nm = has_op_mask ? type_to_name(inputs.back()) : "nomask"; - std::ostringstream kname; kname << "steel_gemm_block_outmask_" << out_mask_nm << "_opmask_" << op_mask_nm << "_" << (transpose_a ? 't' : 'n') @@ -1554,15 +1755,19 @@ void GatherMM::eval_gpu(const std::vector& inputs, array& out) { ///////////////////////////////////////////////////////////////////////////// // Init checks and prep + int M = a_pre.shape(-2); + int N = b_pre.shape(-1); + int K = a_pre.shape(-1); + // Keep a vector with copies to be cleared in the completed buffer to release // the arrays std::vector copies; - auto check_transpose = [&copies, &s](const array& arr) { + auto check_transpose = [&copies, &s](const array& arr, bool is_vector) { auto stx = arr.strides()[arr.ndim() - 2]; auto sty = arr.strides()[arr.ndim() - 1]; - if (sty == 1) { + if (sty == 1 && (!is_vector || stx == arr.shape(-1))) { return std::make_tuple(false, stx, arr); - } else if (stx == 1) { + } else if (stx == 1 && (!is_vector || sty == arr.shape(-2))) { return std::make_tuple(true, sty, arr); } else { array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); @@ -1573,16 +1778,12 @@ void GatherMM::eval_gpu(const std::vector& inputs, array& out) { } }; - auto [transpose_a, a_cols, a] = check_transpose(a_pre); - auto [transpose_b, b_cols, b] = check_transpose(b_pre); + auto [transpose_a, a_cols, a] = check_transpose(a_pre, M == 1); + auto [transpose_b, b_cols, b] = check_transpose(b_pre, N == 1); int lda = a_cols; int ldb = b_cols; - int M = a.shape(-2); - int N = b.shape(-1); - int K = a.shape(-1); - ///////////////////////////////////////////////////////////////////////////// // Check and collapse batch dimensions @@ -1673,38 +1874,47 @@ void GatherMM::eval_gpu(const std::vector& inputs, array& out) { // Determine dispatch kernel int tm = 4, tn = 4; - int bm, bn, n_out_per_tgp; + int sm = 1, sn = 32; + int bm = 1, bn = 1; + int n_out_per_tgp; std::ostringstream kname; if (transpose_mat) { - bm = 8; - bn = 8; - if (out_vector_len >= 24576) { - bn = 128; - } else if (out_vector_len >= 16384) { - bn = 64; - } 
else if (out_vector_len >= 8192) { + if (in_vector_len >= 8192 && out_vector_len >= 2048) { + sm = 4; + sn = 8; + } else { + sm = 8; + sn = 4; + } + + if (out_vector_len >= 2048) { bn = 16; + } else if (out_vector_len >= 512) { + bn = 4; + } else { + bn = 2; } // Specialized kernel for very small outputs tn = out_vector_len < tn ? 1 : tn; - n_out_per_tgp = bn * tn; - kname << "gemv_t_bs_" << type_to_name(out); + n_out_per_tgp = bn * sn * tn; + kname << "gemv_t_gather_" << type_to_name(out); } else { bm = out_vector_len >= 4096 ? 8 : 4; - bn = 32; + sn = 32; // Specialized kernel for very small outputs tm = out_vector_len < tm ? 1 : tm; - n_out_per_tgp = bm * tm; - kname << "gemv_bs_" << type_to_name(out); + n_out_per_tgp = bm * sm * tm; + kname << "gemv_gather_" << type_to_name(out); } - kname << "_bm" << bm << "_bn" << bn << "_tm" << tm << "_tn" << tn; + kname << "_bm" << bm << "_bn" << bn << "_sm" << sm << "_sn" << sn << "_tm" + << tm << "_tn" << tn; // Encode and dispatch kernel auto& compute_encoder = d.get_command_encoder(s.index); @@ -1712,7 +1922,7 @@ void GatherMM::eval_gpu(const std::vector& inputs, array& out) { compute_encoder->setComputePipelineState(kernel); int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp; - MTL::Size group_dims = MTL::Size(bn, bm, 1); + MTL::Size group_dims = MTL::Size(32, bn, bm); MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out); compute_encoder.set_input_array(mat, 0); diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index a49e9c312..b5cd9ac25 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -3634,16 +3634,19 @@ std::vector BlockMaskedMM::vjp( }; // Prepare for padding if needed - int M = cotan.shape(-2); - int N = cotan.shape(-1); - int K = primals[0].shape(-1); - int align_M = (M % block_size_); - int align_N = (N % block_size_); - int align_K = (K % block_size_); + const int M = cotan.shape(-2); + const int N = cotan.shape(-1); + const int K = primals[0].shape(-1); + const int tm = (M + block_size_ - 1) / block_size_; + const int tn = (N + block_size_ - 1) / block_size_; + const int tk = (K + block_size_ - 1) / block_size_; + const int align_M = tm * block_size_ - M; + const int align_N = tn * block_size_ - N; + const int align_K = tk * block_size_ - K; // Potential intermediates - auto unmasked_lhs_grad = primals[0]; - auto unmasked_rhs_grad = primals[1]; + array unmasked_lhs_grad = primals[0]; + array unmasked_rhs_grad = primals[1]; bool unmasked_lhs_grad_calculated = false; bool unmasked_rhs_grad_calculated = false; diff --git a/python/tests/test_blas.py b/python/tests/test_blas.py index b645011b2..fdeaea98a 100644 --- a/python/tests/test_blas.py +++ b/python/tests/test_blas.py @@ -263,7 +263,9 @@ class TestBlas(mlx_tests.MLXTestCase): mlx_mat_f=lambda x: x, mlx_vec_f=lambda x: x, ): - with self.subTest(shape=shape_mat): + with self.subTest( + shape_mat=shape_mat, shape_vec=shape_vec, mat_first=mat_first + ): np.random.seed(42) scale = max(np.sum(shape_mat), 32) mat_npy = np.random.normal(0.0, 1.0 / scale, shape_mat).astype(np_dtype) @@ -794,10 +796,12 @@ class TestBlas(mlx_tests.MLXTestCase): out_ref, dout_ref = mx.vjp(f_ref, [a, b], [cotan]) out_test, dout_test = mx.vjp(f_test, [a, b], [cotan]) - mx.eval((out_ref, dout_ref, out_test, dout_test)) - self.assertTrue(mx.allclose(out_ref[0], out_test[0], atol=1e-5).item()) + for r, t in zip(dout_ref, dout_test): + self.assertEqual(r.shape, t.shape) + self.assertTrue(mx.allclose(r, t, atol=1e-4).item()) + def run_test_mask_vjp(a, b, block_size, out_mask, a_mask, 
b_mask, cotan): def f_ref(a_, b_, a_mask_, b_mask_): return ref_block_masked_mm( @@ -896,6 +900,8 @@ class TestBlas(mlx_tests.MLXTestCase): (64, 64, 16, 32), (128, 128, 128, 32), (256, 256, 128, 64), + (1, 128, 128, 32), + (256, 1, 128, 64), ) for M, N, K, block_size in shapes: @@ -903,21 +909,51 @@ class TestBlas(mlx_tests.MLXTestCase): # Test broadcasting test_shape(64, 64, 64, 32, batch_A=(1, 2), batch_B=(2, 2)) + test_shape(1, 128, 128, 32, batch_A=(1, 2), batch_B=(2, 2)) + test_shape(128, 1, 128, 32, batch_A=(1, 2), batch_B=(2, 2)) - # Test gemv - a_np = np.random.normal(size=(64, 64)).astype(np.float32) - b_np = np.random.normal(size=(64,)).astype(np.float32) - mask_np = np.array([True, False]).astype(np.bool_) + a_np = np.ones((128, 256)).astype(np.float32) + b_np = np.ones((128, 1)).astype(np.float32) + d_np = np.ones((1, 256)).astype(np.float32) + a_mask_np = np.random.normal(size=(4, 8)).astype(np.float32) + b_mask_np = np.ones((4, 1)).astype(np.bool_) + d_mask_np = np.ones((1, 8)).astype(np.bool_) + c_mask_np = np.random.normal(size=(8, 1)).astype(np.float32) + e_mask_np = np.random.normal(size=(1, 4)).astype(np.float32) + + a_mask_np[a_mask_np < 0.0] = 0.0 + e_mask_np[e_mask_np < 0.0] = 0.0 + c_mask_np[c_mask_np < 0.0] = 0.0 a_mx = mx.array(a_np) b_mx = mx.array(b_np) - mask_mx = mx.array(mask_np) + d_mx = mx.array(d_np) + a_mask_mx = mx.array(a_mask_np) + b_mask_mx = mx.array(b_mask_np) + d_mask_mx = mx.array(d_mask_np) + e_mask_mx = mx.array(e_mask_np) + c_mask_mx = mx.array(c_mask_np) - c_mx = mx.block_masked_mm(a_mx, b_mx, 32, mask_mx) - c_np = a_np @ b_np - c_np[32:] = 0.0 + c_mx = mx.block_masked_mm(a_mx.T, b_mx, 32, c_mask_mx, a_mask_mx.T, b_mask_mx) + e_mx = mx.block_masked_mm(d_mx, a_mx.T, 32, e_mask_mx, d_mask_mx, a_mask_mx.T) + + a_mask_np = np.broadcast_to(np.expand_dims(a_mask_np, (-3, -1)), (4, 32, 8, 32)) + a_mask_np = a_mask_np.reshape((128, 256)) + a_np *= a_mask_np + + c_np = a_np.T @ b_np + e_np = d_np @ a_np.T + + c_mask_np = np.broadcast_to(np.expand_dims(c_mask_np, (-2)), (8, 32, 1)) + c_mask_np = c_mask_np.reshape((256, 1)) + c_np *= c_mask_np + + e_mask_np = np.broadcast_to(np.expand_dims(e_mask_np, (-1)), (1, 4, 32)) + e_mask_np = e_mask_np.reshape((1, 128)) + e_np *= e_mask_np self.assertTrue(np.allclose(c_mx, c_np, atol=1e-5)) + self.assertTrue(np.allclose(e_mx, e_np, atol=1e-5)) def test_gather_matmul(self): def np_gather_mm(a, b, lhs_indices=None, rhs_indices=None):
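            # Presumed behavior of this NumPy reference for mx.gather_mm:
            # gather the batch entries of a and b via lhs_indices / rhs_indices
            # when given, then matmul the gathered pairs.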