From d518b3b6a5ec5b2963f462519d8fa9d7cebf81b1 Mon Sep 17 00:00:00 2001 From: Jagrit Digani Date: Tue, 5 Dec 2023 14:15:43 -0800 Subject: [PATCH] Fix gemv broadcasting bug (#6) * Fix broadcasting bug in gemv * Add relevant tests in test_blas.py --- mlx/backend/metal/kernels/gemv.metal | 657 +++++++++++++++++++-------- mlx/backend/metal/kernels/sort.metal | 2 +- mlx/backend/metal/matmul.cpp | 34 +- python/tests/test_blas.py | 2 + 4 files changed, 492 insertions(+), 203 deletions(-) diff --git a/mlx/backend/metal/kernels/gemv.metal b/mlx/backend/metal/kernels/gemv.metal index ea52ab57f..3b4c0a30a 100644 --- a/mlx/backend/metal/kernels/gemv.metal +++ b/mlx/backend/metal/kernels/gemv.metal @@ -3,8 +3,9 @@ #include #include -#include "mlx/backend/metal/kernels/defines.h" #include "mlx/backend/metal/kernels/bf16.h" +#include "mlx/backend/metal/kernels/defines.h" +#include "mlx/backend/metal/kernels/utils.h" using namespace metal; @@ -12,52 +13,57 @@ using namespace metal; /// Matrix vector multiplication /////////////////////////////////////////////////////////////////////////////// -static constant constexpr int SIMD_SIZE = 32; +#define MLX_MTL_CONST static constant constexpr const -template /* Thread cols (in elements) */ -[[kernel]] void gemv( - const device T* mat [[buffer(0)]], - const device T* in_vec [[buffer(1)]], - device T* out_vec [[buffer(2)]], - const constant int& in_vec_size [[buffer(3)]], - const constant int& out_vec_size [[buffer(4)]], - const constant int& vector_batch_stride [[buffer(5)]], - const constant int& matrix_batch_stride [[buffer(6)]], - uint3 tid [[threadgroup_position_in_grid]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { +MLX_MTL_CONST int SIMD_SIZE = 32; - static_assert(BN == SIMD_SIZE, "gemv block must have a width of SIMD_SIZE"); +template < + typename T, + const int BM, /* Threadgroup rows (in threads) */ + const int BN, /* Threadgroup cols (in threads) */ + const int TM, /* Thread rows (in elements) */ + const int TN > /* Thread cols (in elements) */ +struct GEMVKernel { - // - The matrix of size (M = out_vec_size, N = in_vec_size) is divided up - // into blocks of (BM * TM, BN * TN) divided amoung threadgroups - // - Every thread works on a block of (TM, TN) - // - We assume each thead group is launched with (BN, BM, 1) threads - // - // 1. A thread loads TN elements each from mat along TM contiguous rows - // and the corresponding scalar from the vector - // 2. The thread then multiplies and adds to accumulate its local result for the block - // 3. At the end, each thread has accumulated results over all blocks across the rows - // These are then summed up across the threadgroup - // 4. Each threadgroup writes its accumulated BN * TN outputs - // - // Edge case handling: - // - The threadgroup with the largest tid will have blocks that exceed the matrix - // * The blocks that start outside the matrix are never read (thread results remain zero) - // * The last thread that partialy overlaps with the matrix is shifted inwards - // such that the thread block fits exactly in the matrix + static_assert(BN == SIMD_SIZE, "gemv block must have a width of SIMD_SIZE"); + + // - The matrix of size (M = out_vec_size, N = in_vec_size) is divided up + // into blocks of (BM * TM, BN * TN) divided amoung threadgroups + // - Every thread works on a block of (TM, TN) + // - We assume each thead group is launched with (BN, BM, 1) threads + // + // 1. 
A thread loads TN elements each from mat along TM contiguous rows + // and the corresponding scalar from the vector + // 2. The thread then multiplies and adds to accumulate its local result for the block + // 3. At the end, each thread has accumulated results over all blocks across the rows + // These are then summed up across the threadgroup + // 4. Each threadgroup writes its accumulated BN * TN outputs + // + // Edge case handling: + // - The threadgroup with the largest tid will have blocks that exceed the matrix + // * The blocks that start outside the matrix are never read (thread results remain zero) + // * The last thread that partialy overlaps with the matrix is shifted inwards + // such that the thread block fits exactly in the matrix + + MLX_MTL_CONST short tgp_mem_size = BN * TN * 2; + + static METAL_FUNC void run( + const device T* mat, + const device T* in_vec, + device T* out_vec, + const constant int& in_vec_size [[buffer(3)]], + const constant int& out_vec_size [[buffer(4)]], + threadgroup T* tgp_memory [[threadgroup(0)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + + // Appease compiler + (void)lid; - // Update batch offsets - in_vec += tid.z * vector_batch_stride; - mat += tid.z * matrix_batch_stride; - out_vec += tid.z * out_vec_size; - // Threadgroup in_vec cache - threadgroup T in_vec_block[BN][TN * 2]; + threadgroup T* in_vec_block = tgp_memory + simd_lid * TN * 2; // Thread local accumulation results thread T result[TM] = {0}; @@ -69,7 +75,7 @@ template = out_vec_size) - return; + return; // Adjust tail simdgroup to ensure in bound reads out_row = out_row + TM <= out_vec_size ? out_row : out_vec_size - TM; @@ -79,62 +85,304 @@ template /* Thread cols (in elements) */ +struct GEMVTKernel { + + // - The matrix of size (M = in_vec_size, N = out_vec_size) is divided up + // into blocks of (BM * TM, BN * TN) divided amoung threadgroups + // - Every thread works on a block of (TM, TN) + // - We assume each thead group is launched with (BN, BM, 1) threads + // + // 1. A thread loads TN elements each from mat along TM contiguous rows + // and the corresponding scalar from the vector + // 2. The thread then multiplies and adds to accumulate its local result for the block + // 3. At the end, each thread has accumulated results over all blocks across the rows + // These are then summed up across the threadgroup + // 4. 
Each threadgroup writes its accumulated BN * TN outputs + // + // Edge case handling: + // - The threadgroup with the largest tid will have blocks that exceed the matrix + // * The blocks that start outside the matrix are never read (thread results remain zero) + // * The last thread that partialy overlaps with the matrix is shifted inwards + // such that the thread block fits exactly in the matrix + + + MLX_MTL_CONST short tgp_mem_size = BN * BM * TN; + + static METAL_FUNC void run( + const device T* mat, + const device T* in_vec, + device T* out_vec, + const constant int& in_vec_size [[buffer(3)]], + const constant int& out_vec_size [[buffer(4)]], + threadgroup T* tgp_memory [[threadgroup(0)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + + // Appease compiler + (void)simd_gid; + (void)simd_lid; + + // Thread local accumulation results + T result[TN] = {0}; + T inter[TN]; + T v_coeff[TM]; + + // Threadgroup accumulation results + threadgroup T* tgp_results = tgp_memory + lid.x * BM * TN; + + int out_col = (tid.x * BN + lid.x) * TN; + int in_row = lid.y * TM; + + // Edgecase handling + if (out_col < out_vec_size) { + + out_col = out_col + TN < out_vec_size ? out_col : out_vec_size - TN; + + // Per thread accumulation main loop + int bm = in_row; + for(; bm < in_vec_size; bm += BM * TM) { + // Adding a threadgroup_barrier improves performance slightly + // This is possibly it may help exploit cache better + threadgroup_barrier(mem_flags::mem_none); + + if(bm + TM <= in_vec_size) { + + #pragma clang loop unroll(full) + for(int tm = 0; tm < TM; tm++) { + v_coeff[tm] = in_vec[bm + tm]; + } + + #pragma clang loop unroll(full) + for(int tm = 0; tm < TM; tm++) { + for(int tn = 0; tn < TN; tn++) { + inter[tn] = mat[(bm + tm) * out_vec_size + out_col + tn]; + } + for(int tn = 0; tn < TN; tn++) { + result[tn] += v_coeff[tm] * inter[tn]; + } + } + + } else { // Edgecase handling + for(int tm = 0; bm + tm < in_vec_size; tm++) { + v_coeff[tm] = in_vec[bm + tm]; + + for(int tn = 0; tn < TN; tn++) { + inter[tn] = mat[(bm + tm) * out_vec_size + out_col + tn]; + } + for(int tn = 0; tn < TN; tn++) { + result[tn] += v_coeff[tm] * inter[tn]; + } + + } + } + } + + } + + // Threadgroup collection + + #pragma clang loop unroll(full) + for(int i = 0; i < TN; i++) { + tgp_results[lid.y * TN + i] = result[i]; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Threadgroup accumulation and writing out results + if(lid.y == 0 && out_col < out_vec_size) { + + #pragma clang loop unroll(full) + for(int i = 1; i < BM; i++) { + + #pragma clang loop unroll(full) + for(int j = 0; j < TN; j++) { + result[j] += tgp_results[i * TN + j]; + } + } + + #pragma clang loop unroll(full) + for(int j = 0; j < TN; j++) { + out_vec[out_col + j] = result[j]; + } + } + + } + + +}; + +/////////////////////////////////////////////////////////////////////////////// +/// Matrix vector multiplication +/////////////////////////////////////////////////////////////////////////////// + +template < + typename T, + const int BM, /* Threadgroup rows (in threads) */ + const int BN, /* Threadgroup cols (in threads) */ + const int TM, /* Thread rows (in elements) */ + const int TN> /* Thread cols (in elements) */ +[[kernel, max_total_threads_per_threadgroup(BM * BN)]] void gemv( + const device T* mat [[buffer(0)]], + const device T* in_vec [[buffer(1)]], + device T* out_vec [[buffer(2)]], + 
const constant int& in_vec_size [[buffer(3)]], + const constant int& out_vec_size [[buffer(4)]], + const constant int& vector_batch_stride [[buffer(5)]], + const constant int& matrix_batch_stride [[buffer(6)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + + using gemv_kernel = GEMVKernel; + threadgroup T tgp_memory[gemv_kernel::tgp_mem_size]; + + // Update batch offsets + in_vec += tid.z * vector_batch_stride; + mat += tid.z * matrix_batch_stride; + out_vec += tid.z * out_vec_size; + + gemv_kernel::run( + mat, + in_vec, + out_vec, + in_vec_size, + out_vec_size, + tgp_memory, + tid, + lid, + simd_gid, + simd_lid + ); + } -#define instantiate_gemv(name, itype, bm, bn, tm, tn) \ +template < + typename T, + const int BM, /* Threadgroup rows (in threads) */ + const int BN, /* Threadgroup cols (in threads) */ + const int TM, /* Thread rows (in elements) */ + const int TN> /* Thread cols (in elements) */ +[[kernel, max_total_threads_per_threadgroup(BM * BN)]] void gemv_nc( + const device T* mat [[buffer(0)]], + const device T* in_vec [[buffer(1)]], + device T* out_vec [[buffer(2)]], + const constant int& in_vec_size [[buffer(3)]], + const constant int& out_vec_size [[buffer(4)]], + const constant int& nc_dim [[buffer(5)]], + const device int* nc_shape [[buffer(6)]], + const device size_t* nc_strides_vec [[buffer(7)]], + const device size_t* nc_strides_mat [[buffer(8)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + + using gemv_kernel = GEMVKernel; + threadgroup T tgp_memory[gemv_kernel::tgp_mem_size]; + + // Update batch offsets + in_vec += elem_to_loc(tid.z, nc_shape, nc_strides_vec, nc_dim); + mat += elem_to_loc(tid.z, nc_shape, nc_strides_mat, nc_dim); + out_vec += tid.z * out_vec_size; + + gemv_kernel::run( + mat, + in_vec, + out_vec, + in_vec_size, + out_vec_size, + tgp_memory, + tid, + lid, + simd_gid, + simd_lid + ); + +} + + +#define instantiate_gemv_c(name, itype, bm, bn, tm, tn) \ template [[host_name("gemv_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn)]] \ [[kernel]] void gemv( \ const device itype* mat [[buffer(0)]], \ @@ -145,28 +393,51 @@ template ( \ + const device itype* mat [[buffer(0)]], \ + const device itype* vec [[buffer(1)]], \ + device itype* out [[buffer(2)]], \ + const constant int& in_vec_size [[buffer(3)]], \ + const constant int& out_vec_size [[buffer(4)]], \ + const constant int& nc_dim [[buffer(5)]], \ + const device int* nc_shape [[buffer(6)]], \ + const device size_t* nc_strides_vec [[buffer(7)]], \ + const device size_t* nc_strides_mat [[buffer(8)]], \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint3 lid [[thread_position_in_threadgroup]], \ + uint simd_gid [[simdgroup_index_in_threadgroup]], \ + uint simd_lid [[thread_index_in_simdgroup]]); -instantiate_gemv_blocks(float32, float) -instantiate_gemv_blocks(float16, half) -instantiate_gemv_blocks(bfloat16, bfloat16_t) +#define instantiate_gemv(name, itype, bm, bn, tm, tn) \ + instantiate_gemv_c(name, itype, bm, bn, tm, tn) \ + instantiate_gemv_nc(name, itype, bm, bn, tm, tn) + +#define instantiate_gemv_blocks(name, itype) \ + instantiate_gemv(name, itype, 4, 32, 1, 4) \ + instantiate_gemv(name, itype, 4, 32, 4, 4) \ + instantiate_gemv(name, itype, 8, 32, 4, 4) + +instantiate_gemv_blocks(float32, float); 
+instantiate_gemv_blocks(float16, half); +instantiate_gemv_blocks(bfloat16, bfloat16_t); /////////////////////////////////////////////////////////////////////////////// /// Vector matrix multiplication /////////////////////////////////////////////////////////////////////////////// -template /* Thread cols (in elements) */ -[[kernel]] void gemv_t( +template < + typename T, + const int BM, /* Threadgroup rows (in threads) */ + const int BN, /* Threadgroup cols (in threads) */ + const int TM, /* Thread rows (in elements) */ + const int TN> /* Thread cols (in elements) */ +[[kernel, max_total_threads_per_threadgroup(BM * BN)]] void gemv_t( const device T* mat [[buffer(0)]], const device T* in_vec [[buffer(1)]], device T* out_vec [[buffer(2)]], @@ -175,110 +446,77 @@ template ; + threadgroup T tgp_memory[gemv_kernel::tgp_mem_size]; - // Update batch offsets - in_vec += tid.z * vector_batch_stride; - mat += tid.z * matrix_batch_stride; - out_vec += tid.z * out_vec_size; - - // Thread local accumulation results - T result[TN] = {0}; - T inter[TN]; - T v_coeff[TM]; + // Update batch offsets + in_vec += tid.z * vector_batch_stride; + mat += tid.z * matrix_batch_stride; + out_vec += tid.z * out_vec_size; - // Threadgroup accumulation results - threadgroup T tgp_results[BN][BM][TM]; + gemv_kernel::run( + mat, + in_vec, + out_vec, + in_vec_size, + out_vec_size, + tgp_memory, + tid, + lid, + simd_gid, + simd_lid + ); +} - int out_col = (tid.x * BN + lid.x) * TN; - int in_row = lid.y * TM; +template < + typename T, + const int BM, /* Threadgroup rows (in threads) */ + const int BN, /* Threadgroup cols (in threads) */ + const int TM, /* Thread rows (in elements) */ + const int TN> /* Thread cols (in elements) */ +[[kernel, max_total_threads_per_threadgroup(BM * BN)]] void gemv_t_nc( + const device T* mat [[buffer(0)]], + const device T* in_vec [[buffer(1)]], + device T* out_vec [[buffer(2)]], + const constant int& in_vec_size [[buffer(3)]], + const constant int& out_vec_size [[buffer(4)]], + const constant int& nc_dim [[buffer(5)]], + const device int* nc_shape [[buffer(6)]], + const device size_t* nc_strides_vec [[buffer(7)]], + const device size_t* nc_strides_mat [[buffer(8)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { - // Edgecase handling - if (out_col < out_vec_size) { - // Edgecase handling - out_col = out_col + TN < out_vec_size ? 
out_col : out_vec_size - TN; + using gemv_kernel = GEMVTKernel; + threadgroup T tgp_memory[gemv_kernel::tgp_mem_size]; - // Per thread accumulation main loop - int bm = in_row; - for(; bm < in_vec_size; bm += BM * TM) { - // Adding a threadgroup_barrier improves performance slightly - // This is possibly it may help exploit cache better - threadgroup_barrier(mem_flags::mem_none); + // Update batch offsets + in_vec += elem_to_loc(tid.z, nc_shape, nc_strides_vec, nc_dim); + mat += elem_to_loc(tid.z, nc_shape, nc_strides_mat, nc_dim); + out_vec += tid.z * out_vec_size; - if(bm + TM <= in_vec_size) { - - #pragma clang loop unroll(full) - for(int tm = 0; tm < TM; tm++) { - v_coeff[tm] = in_vec[bm + tm]; - } - - #pragma clang loop unroll(full) - for(int tm = 0; tm < TM; tm++) { - for(int tn = 0; tn < TN; tn++) { - inter[tn] = mat[(bm + tm) * out_vec_size + out_col + tn]; - } - for(int tn = 0; tn < TN; tn++) { - result[tn] += v_coeff[tm] * inter[tn]; - } - } - - } else { // Edgecase handling - for(int tm = 0; bm + tm < in_vec_size; tm++) { - v_coeff[tm] = in_vec[bm + tm]; - - for(int tn = 0; tn < TN; tn++) { - inter[tn] = mat[(bm + tm) * out_vec_size + out_col + tn]; - } - for(int tn = 0; tn < TN; tn++) { - result[tn] += v_coeff[tm] * inter[tn]; - } - } - } - } - - } - - // Threadgroup collection - for(int i = 0; i < TN; i++) { - tgp_results[lid.x][lid.y][i] = result[i]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - - if(lid.y == 0 && out_col < out_vec_size) { - // Threadgroup accumulation - #pragma clang loop unroll(full) - for(int i = 1; i < BM; i++) { - for(int j = 0; j < TN; j++) { - result[j] += tgp_results[lid.x][i][j]; - } - } - - for(int j = 0; j < TN; j++) { - out_vec[out_col + j] = result[j]; - } - } + gemv_kernel::run( + mat, + in_vec, + out_vec, + in_vec_size, + out_vec_size, + tgp_memory, + tid, + lid, + simd_gid, + simd_lid + ); } -#define instantiate_gemv_t(name, itype, bm, bn, tm, tn) \ +#define instantiate_gemv_t_c(name, itype, bm, bn, tm, tn) \ template [[host_name("gemv_t_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn)]] \ [[kernel]] void gemv_t( \ const device itype* mat [[buffer(0)]], \ @@ -289,16 +527,39 @@ template ( \ + const device itype* mat [[buffer(0)]], \ + const device itype* vec [[buffer(1)]], \ + device itype* out [[buffer(2)]], \ + const constant int& in_vec_size [[buffer(3)]], \ + const constant int& out_vec_size [[buffer(4)]], \ + const constant int& nc_dim [[buffer(5)]], \ + const device int* nc_shape [[buffer(6)]], \ + const device size_t* nc_strides_vec [[buffer(7)]], \ + const device size_t* nc_strides_mat [[buffer(8)]], \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint3 lid [[thread_position_in_threadgroup]], \ + uint simd_gid [[simdgroup_index_in_threadgroup]], \ + uint simd_lid [[thread_index_in_simdgroup]]); + +#define instantiate_gemv_t(name, itype, bm, bn, tm, tn) \ + instantiate_gemv_t_c(name, itype, bm, bn, tm, tn) \ + instantiate_gemv_t_nc(name, itype, bm, bn, tm, tn) #define instantiate_gemv_t_blocks(name, itype) \ - instantiate_gemv_t(name, itype, 8, 8, 4, 1) \ - instantiate_gemv_t(name, itype, 8, 8, 4, 4) \ - instantiate_gemv_t(name, itype, 8, 16, 4, 4) \ - instantiate_gemv_t(name, itype, 8, 32, 4, 4) \ - instantiate_gemv_t(name, itype, 8, 64, 4, 4) \ - instantiate_gemv_t(name, itype, 8, 128, 4, 4) + instantiate_gemv_t(name, itype, 8, 8, 4, 1) \ + instantiate_gemv_t(name, itype, 8, 8, 4, 4) \ + instantiate_gemv_t(name, itype, 8, 16, 4, 4) \ + instantiate_gemv_t(name, itype, 8, 32, 4, 4) \ + instantiate_gemv_t(name, itype, 8, 
64, 4, 4) \ + instantiate_gemv_t(name, itype, 8, 128, 4, 4) -instantiate_gemv_t_blocks(float32, float) -instantiate_gemv_t_blocks(float16, half) -instantiate_gemv_t_blocks(bfloat16, bfloat16_t) +instantiate_gemv_t_blocks(float32, float); +instantiate_gemv_t_blocks(float16, half); +instantiate_gemv_t_blocks(bfloat16, bfloat16_t); diff --git a/mlx/backend/metal/kernels/sort.metal b/mlx/backend/metal/kernels/sort.metal index 7e2c0442a..3aa54de3e 100644 --- a/mlx/backend/metal/kernels/sort.metal +++ b/mlx/backend/metal/kernels/sort.metal @@ -9,7 +9,7 @@ #define MLX_MTL_CONST static constant constexpr const #define MLX_MTL_LOOP_UNROLL _Pragma("clang loop unroll(full)") -using namespace metal;\ +using namespace metal; // Based on GPU merge sort algorithm at https://github.com/NVIDIA/cccl/tree/main/cub/cub diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp index 22472c638..864181da9 100644 --- a/mlx/backend/metal/matmul.cpp +++ b/mlx/backend/metal/matmul.cpp @@ -343,10 +343,18 @@ void Matmul::eval_gpu(const std::vector& inputs, array& out) { int mat_rows = transpose_mat ? in_vector_len : out_vector_len; int batch_size_mat = mat.data_size() / (mat_cols * mat_rows); - int stride_mat = batch_size_mat == batch_size_out ? mat_cols * mat_rows : 0; + int stride_mat = batch_size_mat == 1 ? 0 : mat_cols * mat_rows; int batch_size_vec = vec.data_size() / in_vector_len; - int stride_vec = batch_size_vec == batch_size_out ? in_vector_len : 0; + int stride_vec = batch_size_vec == 1 ? 0 : in_vector_len; + + // Determine if inputs have simple batching / broadcasting + bool contiguous_kernel = + (batch_size_out == std::max(batch_size_mat, batch_size_vec) && + (batch_size_mat == batch_size_vec || + std::min(batch_size_mat, batch_size_vec) == 1)); + + int nc_dim = out.ndim() - 2; // Determine dispatch kernel int tm = 4, tn = 4; @@ -383,6 +391,10 @@ void Matmul::eval_gpu(const std::vector& inputs, array& out) { kname << "_bm" << bm << "_bn" << bn << "_tm" << tm << "_tn" << tn; + if (!contiguous_kernel) { + kname << "_nc"; + } + // Encode and dispatch kernel auto compute_encoder = d.get_command_encoder(s.index); auto kernel = d.get_kernel(kname.str()); @@ -398,8 +410,22 @@ void Matmul::eval_gpu(const std::vector& inputs, array& out) { compute_encoder->setBytes(&in_vector_len, sizeof(int), 3); compute_encoder->setBytes(&out_vector_len, sizeof(int), 4); - compute_encoder->setBytes(&stride_vec, sizeof(int), 5); - compute_encoder->setBytes(&stride_mat, sizeof(int), 6); + + if (contiguous_kernel) { + compute_encoder->setBytes(&stride_vec, sizeof(int), 5); + compute_encoder->setBytes(&stride_mat, sizeof(int), 6); + } else { + // In case of complex broadcasting, we consider the shape[:-2] and + // strides [:-2] to determine the location of a batch + // nc_dim = out.ndim() - 2 + compute_encoder->setBytes(&nc_dim, sizeof(int), 5); + compute_encoder->setBytes(out.shape().data(), nc_dim * sizeof(int), 6); + compute_encoder->setBytes( + vec.strides().data(), nc_dim * sizeof(size_t), 7); + compute_encoder->setBytes( + mat.strides().data(), nc_dim * sizeof(size_t), 8); + } + compute_encoder->dispatchThreadgroups(grid_dims, group_dims); d.get_command_buffer(s.index)->addCompletedHandler( diff --git a/python/tests/test_blas.py b/python/tests/test_blas.py index 4b310ed33..7a7133e3d 100644 --- a/python/tests/test_blas.py +++ b/python/tests/test_blas.py @@ -340,6 +340,7 @@ class TestBlas(mlx_tests.MLXTestCase): ((32, 128, 64), (32, 64, 1)), ((128, 64), (32, 64, 1)), ((32, 128, 64), (64, 1)), + ((2, 1, 8, 1, 6, 
128), (2, 1, 8, 4, 128, 1)), ): self.__gemv_test( shape_mat, shape_vec, mat_first=True, np_dtype=np_dtype @@ -350,6 +351,7 @@ class TestBlas(mlx_tests.MLXTestCase): ((32, 1, 128), (32, 128, 64)), ((32, 1, 128), (128, 64)), ((1, 128), (32, 128, 64)), + ((1, 8, 4, 1, 128), (1, 8, 1, 128, 6)), ): self.__gemv_test( shape_mat, shape_vec, mat_first=False, np_dtype=np_dtype
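
The comment block at the top of GEMVKernel/GEMVTKernel in gemv.metal describes the blocking: each threadgroup of (BN, BM) threads covers a (BM * TM) x (BN * TN) tile of the matrix, every thread owns a TM x TN sub-tile, accumulates over its column chunks, and the per-row partial sums are reduced across the threadgroup before writing out. For readers who want to check that arithmetic, here is a small NumPy reference of the non-transposed scheme. It is purely illustrative and is not the Metal code; loop indices stand in for threads and simdgroups.

import numpy as np

def gemv_blocked(mat, vec, BM=4, BN=32, TM=4, TN=4):
    # NumPy reference for the non-transposed gemv tiling described above:
    # each "threadgroup" covers BM * TM output rows; thread (bm, bn) owns a
    # TM x TN tile, accumulates over its TN-wide column chunks, and the BN
    # partial sums per row are reduced at the end (the simdgroup/threadgroup
    # reduction in the Metal kernel).
    M, N = mat.shape                      # M = out_vec_size, N = in_vec_size
    out = np.zeros(M, dtype=mat.dtype)
    for row0 in range(0, M, BM * TM):     # one "threadgroup" per iteration
        acc = np.zeros((BM, BN, TM), dtype=mat.dtype)
        for bm in range(BM):
            for bn in range(BN):
                for tm in range(TM):
                    r = row0 + bm * TM + tm
                    if r >= M:
                        continue          # the Metal kernel shifts the tail tile inwards instead
                    for col0 in range(bn * TN, N, BN * TN):
                        acc[bm, bn, tm] += mat[r, col0:col0 + TN] @ vec[col0:col0 + TN]
        for bm in range(BM):
            for tm in range(TM):
                r = row0 + bm * TM + tm
                if r < M:
                    out[r] = acc[bm, :, tm].sum()
    return out

# Sanity check against the dense product:
# a = np.random.rand(37, 91).astype(np.float32); x = np.random.rand(91).astype(np.float32)
# np.allclose(gemv_blocked(a, x), a @ x, atol=1e-4)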
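
The new gemv_nc / gemv_t_nc variants are what make arbitrary broadcasts work: instead of a single vector_batch_stride / matrix_batch_stride, the host passes the leading dims of out.shape() plus each input's leading strides, and the kernel turns the flat batch index tid.z into a per-input offset with elem_to_loc (presumably the reason mlx/backend/metal/kernels/utils.h is now included). The Metal definition of elem_to_loc is not part of this patch; the following is a plausible Python equivalent of the index arithmetic, for illustration only.

def elem_to_loc(elem, shape, strides, ndim):
    # Decompose the flat (row-major) batch index and dot it with the input's
    # own strides; broadcast dimensions carry stride 0, so every batch along
    # them resolves to the same underlying data.
    loc = 0
    for i in range(ndim - 1, -1, -1):
        loc += (elem % shape[i]) * strides[i]
        elem //= shape[i]
    return loc

# e.g. a vector broadcast along a batch axis of size 4 (illustrative strides):
# elem_to_loc(3, shape=[2, 1, 8, 4], strides=[1024, 0, 128, 0], ndim=4) == 0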
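
The host-side hunk in mlx/backend/metal/matmul.cpp is the core of the fix: the old code set an input's batch stride to zero whenever its batch count differed from the output's, so a matrix or vector that was broadcast over extra batch dims silently reused its first slice for every output batch. The patch zeroes the stride only for genuinely single-batch inputs and routes anything that is not a plain "same batch size or batch of one" pattern to the *_nc kernels. Below is a NumPy sketch of that dispatch decision, checked against the new test shapes; the helper name is illustrative, not MLX API.

import numpy as np

def uses_contiguous_gemv(batch_mat, batch_vec, batch_out):
    # Mirrors the new `contiguous_kernel` condition in matmul.cpp: the plain
    # gemv kernels are used only when batching is "equal or one"; everything
    # else goes to the *_nc kernels with explicit shapes and strides.
    return (batch_out == max(batch_mat, batch_vec)
            and (batch_mat == batch_vec or min(batch_mat, batch_vec) == 1))

# The new broadcast case from test_blas.py: (2, 1, 8, 1, 6, 128) @ (2, 1, 8, 4, 128, 1)
mat = np.random.rand(2, 1, 8, 1, 6, 128).astype(np.float32)
vec = np.random.rand(2, 1, 8, 4, 128, 1).astype(np.float32)
out = np.matmul(mat, vec)                    # reference result, shape (2, 1, 8, 4, 6, 1)

batch_mat = int(np.prod(mat.shape[:-2]))     # 16
batch_vec = int(np.prod(vec.shape[:-2]))     # 64
batch_out = int(np.prod(out.shape[:-2]))     # 64
assert not uses_contiguous_gemv(batch_mat, batch_vec, batch_out)

# A simply-batched case such as (32, 128, 64) @ (32, 64, 1) still takes the
# contiguous path:
assert uses_contiguous_gemv(32, 32, 32)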