diff --git a/mlx/backend/metal/conv.cpp b/mlx/backend/metal/conv.cpp index 426f6aefe..b954eb7e3 100644 --- a/mlx/backend/metal/conv.cpp +++ b/mlx/backend/metal/conv.cpp @@ -28,10 +28,12 @@ void explicit_gemm_conv_ND_gpu( const array& wt, array out, const MLXConvParams& conv_params) { + // Get gemm shapes + int implicit_M = out.size() / conv_params.O; + int implicit_K = wt.size() / conv_params.O; + int implicit_N = conv_params.O; // Prepare unfolding array - std::vector unfolded_shape = { - static_cast(out.size() / conv_params.O), - static_cast(wt.size() / conv_params.O)}; + std::vector unfolded_shape{implicit_M, implicit_K}; array in_unfolded(unfolded_shape, in.dtype(), nullptr, {}); in_unfolded.set_data(allocator::malloc_or_wait(in_unfolded.nbytes())); @@ -59,20 +61,29 @@ void explicit_gemm_conv_ND_gpu( compute_encoder->dispatchThreads(grid_dims, group_dims); + // Reshape weight + std::vector wt_reshape{implicit_K, implicit_N}; + std::vector wt_restride{1, static_cast(implicit_K)}; + array wt_reshaped(wt_reshape, wt.dtype(), nullptr, {}); + auto wt_flags = wt.flags(); + wt_flags.row_contiguous = false; + wt_flags.col_contiguous = true; + wt_reshaped.copy_shared_buffer(wt, wt_restride, wt_flags, wt.data_size()); + // Perform gemm - std::vector copies; + std::vector copies = {in_unfolded, wt_reshaped}; return steel_matmul( s, d, /*a = */ in_unfolded, - /*b = */ wt, + /*b = */ wt_reshaped, /*c = */ out, - /*M = */ unfolded_shape[0], - /*N = */ conv_params.O, - /*K = */ unfolded_shape[1], + /*M = */ implicit_M, + /*N = */ implicit_N, + /*K = */ implicit_K, /*batch_size_out = */ 1, - /*a_cols = */ unfolded_shape[1], - /*b_cols = */ unfolded_shape[1], + /*a_cols = */ implicit_K, + /*b_cols = */ implicit_K, /*a_transposed = */ false, /*b_transposed = */ true, /*copies = */ copies); diff --git a/mlx/backend/metal/kernels/gemv.metal b/mlx/backend/metal/kernels/gemv.metal index e0325a1f1..8b629ca6a 100644 --- a/mlx/backend/metal/kernels/gemv.metal +++ b/mlx/backend/metal/kernels/gemv.metal @@ -1,4 +1,4 @@ -// Copyright © 2023 Apple Inc. +// Copyright © 2023-2024 Apple Inc. #include #include @@ -22,7 +22,8 @@ template < const int BM, /* Threadgroup rows (in threads) */ const int BN, /* Threadgroup cols (in threads) */ const int TM, /* Thread rows (in elements) */ - const int TN > /* Thread cols (in elements) */ + const int TN , /* Thread cols (in elements) */ + const bool kDoAxpby> /* Do out = alpha * out + beta * bias */ struct GEMVKernel { static_assert(BN == SIMD_SIZE, "gemv block must have a width of SIMD_SIZE"); @@ -48,11 +49,16 @@ struct GEMVKernel { MLX_MTL_CONST short tgp_mem_size = BN * TN * 2; static METAL_FUNC void run( - const device T* mat, - const device T* in_vec, - device T* out_vec, - const constant int& in_vec_size [[buffer(3)]], - const constant int& out_vec_size [[buffer(4)]], + const device T* mat [[buffer(0)]], + const device T* in_vec [[buffer(1)]], + const device T* bias [[buffer(2)]], + device T* out_vec [[buffer(3)]], + const constant int& in_vec_size [[buffer(4)]], + const constant int& out_vec_size [[buffer(5)]], + const constant int& marix_ld [[buffer(6)]], + const constant float& alpha [[buffer(7)]], + const constant float& beta [[buffer(8)]], + const constant int& bias_stride [[buffer(14)]], threadgroup T* tgp_memory [[threadgroup(0)]], uint3 tid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]], @@ -81,7 +87,7 @@ struct GEMVKernel { out_row = out_row + TM <= out_vec_size ? 
out_row : out_vec_size - TM; // Advance matrix - mat += out_row * in_vec_size; + mat += out_row * marix_ld; // Loop over in_vec in blocks of BN * TN for(int bn = simd_lid * TN; bn < in_vec_size; bn += BN * TN) { @@ -124,14 +130,14 @@ struct GEMVKernel { if(bn + TN <= in_vec_size) { #pragma clang loop unroll(full) for(int tn = 0; tn < TN; tn++) { - inter[tn] = mat[tm * in_vec_size + bn + tn]; + inter[tn] = mat[tm * marix_ld + bn + tn]; } } else { // Edgecase #pragma clang loop unroll(full) for(int tn = 0; tn < TN; tn++) { int col_idx = (bn + tn) < in_vec_size ? (bn + tn) : (in_vec_size - 1); - inter[tn] = mat[tm * in_vec_size + col_idx]; + inter[tn] = mat[tm * marix_ld + col_idx]; } } @@ -154,7 +160,13 @@ struct GEMVKernel { #pragma clang loop unroll(full) for(int tm = 0; tm < TM; tm++) { - out_vec[out_row + tm] = result[tm]; + if(kDoAxpby) { + out_vec[out_row + tm] = + static_cast(alpha) * result[tm] + + static_cast(beta) * bias[(out_row + tm) * bias_stride]; + } else { + out_vec[out_row + tm] = result[tm]; + } } } @@ -172,7 +184,8 @@ template < const int BM, /* Threadgroup rows (in threads) */ const int BN, /* Threadgroup cols (in threads) */ const int TM, /* Thread rows (in elements) */ - const int TN > /* Thread cols (in elements) */ + const int TN, /* Thread cols (in elements) */ + const bool kDoAxpby> /* Do out = alpha * out + beta * bias */ struct GEMVTKernel { // - The matrix of size (M = in_vec_size, N = out_vec_size) is divided up @@ -197,11 +210,16 @@ struct GEMVTKernel { MLX_MTL_CONST short tgp_mem_size = BN * BM * TN; static METAL_FUNC void run( - const device T* mat, - const device T* in_vec, - device T* out_vec, - const constant int& in_vec_size [[buffer(3)]], - const constant int& out_vec_size [[buffer(4)]], + const device T* mat [[buffer(0)]], + const device T* in_vec [[buffer(1)]], + const device T* bias [[buffer(2)]], + device T* out_vec [[buffer(3)]], + const constant int& in_vec_size [[buffer(4)]], + const constant int& out_vec_size [[buffer(5)]], + const constant int& marix_ld [[buffer(6)]], + const constant float& alpha [[buffer(7)]], + const constant float& beta [[buffer(8)]], + const constant int& bias_stride [[buffer(14)]], threadgroup T* tgp_memory [[threadgroup(0)]], uint3 tid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]], @@ -245,7 +263,7 @@ struct GEMVTKernel { #pragma clang loop unroll(full) for(int tm = 0; tm < TM; tm++) { for(int tn = 0; tn < TN; tn++) { - inter[tn] = mat[(bm + tm) * out_vec_size + out_col + tn]; + inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn]; } for(int tn = 0; tn < TN; tn++) { result[tn] += v_coeff[tm] * inter[tn]; @@ -257,7 +275,7 @@ struct GEMVTKernel { v_coeff[tm] = in_vec[bm + tm]; for(int tn = 0; tn < TN; tn++) { - inter[tn] = mat[(bm + tm) * out_vec_size + out_col + tn]; + inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn]; } for(int tn = 0; tn < TN; tn++) { result[tn] += v_coeff[tm] * inter[tn]; @@ -292,13 +310,17 @@ struct GEMVTKernel { #pragma clang loop unroll(full) for(int j = 0; j < TN; j++) { - out_vec[out_col + j] = result[j]; + + if(kDoAxpby) { + out_vec[out_col + j] = + static_cast(alpha) * result[j] + + static_cast(beta) * bias[(out_col + j) * bias_stride]; + } else { + out_vec[out_col + j] = result[j]; + } } } - } - - }; /////////////////////////////////////////////////////////////////////////////// @@ -310,78 +332,64 @@ template < const int BM, /* Threadgroup rows (in threads) */ const int BN, /* Threadgroup cols (in threads) */ const int TM, /* Thread rows (in elements) */ - 
const int TN> /* Thread cols (in elements) */ + const int TN, /* Thread cols (in elements) */ + const bool kDoNCBatch, /* Batch ndim > 1 */ + const bool kDoAxpby> /* Do out = alpha * out + beta * bias */ [[kernel, max_total_threads_per_threadgroup(BM * BN)]] void gemv( const device T* mat [[buffer(0)]], const device T* in_vec [[buffer(1)]], - device T* out_vec [[buffer(2)]], - const constant int& in_vec_size [[buffer(3)]], - const constant int& out_vec_size [[buffer(4)]], - const constant int& vector_batch_stride [[buffer(5)]], - const constant int& matrix_batch_stride [[buffer(6)]], + const device T* bias [[buffer(2)]], + device T* out_vec [[buffer(3)]], + const constant int& in_vec_size [[buffer(4)]], + const constant int& out_vec_size [[buffer(5)]], + const constant int& marix_ld [[buffer(6)]], + const constant float& alpha [[buffer(7)]], + const constant float& beta [[buffer(8)]], + const constant int& batch_ndim [[buffer(9)]], + const constant int* batch_shape [[buffer(10)]], + const constant size_t* vector_batch_stride [[buffer(11)]], + const constant size_t* matrix_batch_stride [[buffer(12)]], + const constant size_t* bias_batch_stride [[buffer(13)]], + const constant int& bias_stride [[buffer(14)]], uint3 tid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { - using gemv_kernel = GEMVKernel; + using gemv_kernel = GEMVKernel; threadgroup T tgp_memory[gemv_kernel::tgp_mem_size]; // Update batch offsets - in_vec += tid.z * vector_batch_stride; - mat += tid.z * matrix_batch_stride; - out_vec += tid.z * out_vec_size; - - gemv_kernel::run( - mat, - in_vec, - out_vec, - in_vec_size, - out_vec_size, - tgp_memory, - tid, - lid, - simd_gid, - simd_lid - ); - -} - -template < - typename T, - const int BM, /* Threadgroup rows (in threads) */ - const int BN, /* Threadgroup cols (in threads) */ - const int TM, /* Thread rows (in elements) */ - const int TN> /* Thread cols (in elements) */ -[[kernel, max_total_threads_per_threadgroup(BM * BN)]] void gemv_nc( - const device T* mat [[buffer(0)]], - const device T* in_vec [[buffer(1)]], - device T* out_vec [[buffer(2)]], - const constant int& in_vec_size [[buffer(3)]], - const constant int& out_vec_size [[buffer(4)]], - const constant int& nc_dim [[buffer(5)]], - const device int* nc_shape [[buffer(6)]], - const device size_t* nc_strides_vec [[buffer(7)]], - const device size_t* nc_strides_mat [[buffer(8)]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - - using gemv_kernel = GEMVKernel; - threadgroup T tgp_memory[gemv_kernel::tgp_mem_size]; - - // Update batch offsets - in_vec += elem_to_loc(tid.z, nc_shape, nc_strides_vec, nc_dim); - mat += elem_to_loc(tid.z, nc_shape, nc_strides_mat, nc_dim); + if(kDoNCBatch) { + in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim); + mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim); + + if(kDoAxpby) { + bias += elem_to_loc(tid.z, batch_shape, bias_batch_stride, batch_ndim); + } + + } else { + in_vec += tid.z * vector_batch_stride[0]; + mat += tid.z * matrix_batch_stride[0]; + + if(kDoAxpby) { + bias += tid.z * bias_batch_stride[0]; + } + } + out_vec += tid.z * out_vec_size; gemv_kernel::run( mat, in_vec, + bias, out_vec, in_vec_size, out_vec_size, + marix_ld, + alpha, + beta, + bias_stride, 
tgp_memory, tid, lid, @@ -392,41 +400,34 @@ template < } -#define instantiate_gemv_c(name, itype, bm, bn, tm, tn) \ - template [[host_name("gemv_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn)]] \ - [[kernel]] void gemv( \ +#define instantiate_gemv_helper(name, itype, bm, bn, tm, tn, nc, axpby) \ + template [[host_name("gemv_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn "_nc" #nc "_axpby" #axpby)]] \ + [[kernel]] void gemv( \ const device itype* mat [[buffer(0)]], \ - const device itype* vec [[buffer(1)]], \ - device itype* out [[buffer(2)]], \ - const constant int& in_vec_size [[buffer(3)]], \ - const constant int& out_vec_size [[buffer(4)]], \ - const constant int& vector_batch_stride [[buffer(5)]], \ - const constant int& matrix_batch_stride [[buffer(6)]], \ - uint3 tid [[threadgroup_position_in_grid]], \ - uint3 lid [[thread_position_in_threadgroup]], \ - uint simd_gid [[simdgroup_index_in_threadgroup]], \ - uint simd_lid [[thread_index_in_simdgroup]]); - -#define instantiate_gemv_nc(name, itype, bm, bn, tm, tn) \ - template [[host_name("gemv_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn "_nc")]] \ - [[kernel]] void gemv_nc( \ - const device itype* mat [[buffer(0)]], \ - const device itype* vec [[buffer(1)]], \ - device itype* out [[buffer(2)]], \ - const constant int& in_vec_size [[buffer(3)]], \ - const constant int& out_vec_size [[buffer(4)]], \ - const constant int& nc_dim [[buffer(5)]], \ - const device int* nc_shape [[buffer(6)]], \ - const device size_t* nc_strides_vec [[buffer(7)]], \ - const device size_t* nc_strides_mat [[buffer(8)]], \ + const device itype* in_vec [[buffer(1)]], \ + const device itype* bias [[buffer(2)]], \ + device itype* out_vec [[buffer(3)]], \ + const constant int& in_vec_size [[buffer(4)]], \ + const constant int& out_vec_size [[buffer(5)]], \ + const constant int& marix_ld [[buffer(6)]], \ + const constant float& alpha [[buffer(7)]], \ + const constant float& beta [[buffer(8)]], \ + const constant int& batch_ndim [[buffer(9)]], \ + const constant int* batch_shape [[buffer(10)]], \ + const constant size_t* vector_batch_stride [[buffer(11)]], \ + const constant size_t* matrix_batch_stride [[buffer(12)]], \ + const constant size_t* bias_batch_stride [[buffer(13)]], \ + const constant int& bias_stride [[buffer(14)]], \ uint3 tid [[threadgroup_position_in_grid]], \ uint3 lid [[thread_position_in_threadgroup]], \ uint simd_gid [[simdgroup_index_in_threadgroup]], \ uint simd_lid [[thread_index_in_simdgroup]]); #define instantiate_gemv(name, itype, bm, bn, tm, tn) \ - instantiate_gemv_c(name, itype, bm, bn, tm, tn) \ - instantiate_gemv_nc(name, itype, bm, bn, tm, tn) + instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 0, 0) \ + instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 0, 1) \ + instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 1, 0) \ + instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 1, 1) #define instantiate_gemv_blocks(name, itype) \ instantiate_gemv(name, itype, 4, 32, 1, 4) \ @@ -446,77 +447,64 @@ template < const int BM, /* Threadgroup rows (in threads) */ const int BN, /* Threadgroup cols (in threads) */ const int TM, /* Thread rows (in elements) */ - const int TN> /* Thread cols (in elements) */ + const int TN, /* Thread cols (in elements) */ + const bool kDoNCBatch, /* Batch ndim > 1 */ + const bool kDoAxpby> /* Do out = alpha * out + beta * bias */ [[kernel, max_total_threads_per_threadgroup(BM * BN)]] void gemv_t( const device T* mat [[buffer(0)]], const device T* in_vec [[buffer(1)]], - device T* out_vec [[buffer(2)]], - 
const constant int& in_vec_size [[buffer(3)]], - const constant int& out_vec_size [[buffer(4)]], - const constant int& vector_batch_stride [[buffer(5)]], - const constant int& matrix_batch_stride [[buffer(6)]], + const device T* bias [[buffer(2)]], + device T* out_vec [[buffer(3)]], + const constant int& in_vec_size [[buffer(4)]], + const constant int& out_vec_size [[buffer(5)]], + const constant int& marix_ld [[buffer(6)]], + const constant float& alpha [[buffer(7)]], + const constant float& beta [[buffer(8)]], + const constant int& batch_ndim [[buffer(9)]], + const constant int* batch_shape [[buffer(10)]], + const constant size_t* vector_batch_stride [[buffer(11)]], + const constant size_t* matrix_batch_stride [[buffer(12)]], + const constant size_t* bias_batch_stride [[buffer(13)]], + const constant int& bias_stride [[buffer(14)]], uint3 tid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { - using gemv_kernel = GEMVTKernel; + using gemv_kernel = GEMVTKernel; threadgroup T tgp_memory[gemv_kernel::tgp_mem_size]; // Update batch offsets - in_vec += tid.z * vector_batch_stride; - mat += tid.z * matrix_batch_stride; - out_vec += tid.z * out_vec_size; - - gemv_kernel::run( - mat, - in_vec, - out_vec, - in_vec_size, - out_vec_size, - tgp_memory, - tid, - lid, - simd_gid, - simd_lid - ); -} - -template < - typename T, - const int BM, /* Threadgroup rows (in threads) */ - const int BN, /* Threadgroup cols (in threads) */ - const int TM, /* Thread rows (in elements) */ - const int TN> /* Thread cols (in elements) */ -[[kernel, max_total_threads_per_threadgroup(BM * BN)]] void gemv_t_nc( - const device T* mat [[buffer(0)]], - const device T* in_vec [[buffer(1)]], - device T* out_vec [[buffer(2)]], - const constant int& in_vec_size [[buffer(3)]], - const constant int& out_vec_size [[buffer(4)]], - const constant int& nc_dim [[buffer(5)]], - const device int* nc_shape [[buffer(6)]], - const device size_t* nc_strides_vec [[buffer(7)]], - const device size_t* nc_strides_mat [[buffer(8)]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - - using gemv_kernel = GEMVTKernel; - threadgroup T tgp_memory[gemv_kernel::tgp_mem_size]; - - // Update batch offsets - in_vec += elem_to_loc(tid.z, nc_shape, nc_strides_vec, nc_dim); - mat += elem_to_loc(tid.z, nc_shape, nc_strides_mat, nc_dim); + if(kDoNCBatch) { + in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim); + mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim); + + if(kDoAxpby) { + bias += elem_to_loc(tid.z, batch_shape, bias_batch_stride, batch_ndim); + } + + } else { + in_vec += tid.z * vector_batch_stride[0]; + mat += tid.z * matrix_batch_stride[0]; + + if(kDoAxpby) { + bias += tid.z * bias_batch_stride[0]; + } + } + out_vec += tid.z * out_vec_size; gemv_kernel::run( mat, in_vec, + bias, out_vec, in_vec_size, out_vec_size, + marix_ld, + alpha, + beta, + bias_stride, tgp_memory, tid, lid, @@ -526,41 +514,34 @@ template < } -#define instantiate_gemv_t_c(name, itype, bm, bn, tm, tn) \ - template [[host_name("gemv_t_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn)]] \ - [[kernel]] void gemv_t( \ +#define instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, nc, axpby) \ + template [[host_name("gemv_t_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" 
#tn "_nc" #nc "_axpby" #axpby)]] \ + [[kernel]] void gemv_t( \ const device itype* mat [[buffer(0)]], \ - const device itype* vec [[buffer(1)]], \ - device itype* out [[buffer(2)]], \ - const constant int& in_vec_size [[buffer(3)]], \ - const constant int& out_vec_size [[buffer(4)]], \ - const constant int& vector_batch_stride [[buffer(5)]], \ - const constant int& matrix_batch_stride [[buffer(6)]], \ - uint3 tid [[threadgroup_position_in_grid]], \ - uint3 lid [[thread_position_in_threadgroup]], \ - uint simd_gid [[simdgroup_index_in_threadgroup]], \ - uint simd_lid [[thread_index_in_simdgroup]]); - -#define instantiate_gemv_t_nc(name, itype, bm, bn, tm, tn) \ - template [[host_name("gemv_t_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn "_nc")]] \ - [[kernel]] void gemv_t_nc( \ - const device itype* mat [[buffer(0)]], \ - const device itype* vec [[buffer(1)]], \ - device itype* out [[buffer(2)]], \ - const constant int& in_vec_size [[buffer(3)]], \ - const constant int& out_vec_size [[buffer(4)]], \ - const constant int& nc_dim [[buffer(5)]], \ - const device int* nc_shape [[buffer(6)]], \ - const device size_t* nc_strides_vec [[buffer(7)]], \ - const device size_t* nc_strides_mat [[buffer(8)]], \ + const device itype* in_vec [[buffer(1)]], \ + const device itype* bias [[buffer(2)]], \ + device itype* out_vec [[buffer(3)]], \ + const constant int& in_vec_size [[buffer(4)]], \ + const constant int& out_vec_size [[buffer(5)]], \ + const constant int& marix_ld [[buffer(6)]], \ + const constant float& alpha [[buffer(7)]], \ + const constant float& beta [[buffer(8)]], \ + const constant int& batch_ndim [[buffer(9)]], \ + const constant int* batch_shape [[buffer(10)]], \ + const constant size_t* vector_batch_stride [[buffer(11)]], \ + const constant size_t* matrix_batch_stride [[buffer(12)]], \ + const constant size_t* bias_batch_stride [[buffer(13)]], \ + const constant int& bias_stride [[buffer(14)]], \ uint3 tid [[threadgroup_position_in_grid]], \ uint3 lid [[thread_position_in_threadgroup]], \ uint simd_gid [[simdgroup_index_in_threadgroup]], \ uint simd_lid [[thread_index_in_simdgroup]]); #define instantiate_gemv_t(name, itype, bm, bn, tm, tn) \ - instantiate_gemv_t_c(name, itype, bm, bn, tm, tn) \ - instantiate_gemv_t_nc(name, itype, bm, bn, tm, tn) + instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, 0, 0) \ + instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, 0, 1) \ + instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, 1, 0) \ + instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, 1, 1) #define instantiate_gemv_t_blocks(name, itype) \ instantiate_gemv_t(name, itype, 8, 8, 4, 1) \ diff --git a/mlx/backend/metal/kernels/steel/gemm/gemm.h b/mlx/backend/metal/kernels/steel/gemm/gemm.h index 2e2b0f838..3ce44f941 100644 --- a/mlx/backend/metal/kernels/steel/gemm/gemm.h +++ b/mlx/backend/metal/kernels/steel/gemm/gemm.h @@ -140,7 +140,7 @@ struct GEMMKernel { static METAL_FUNC void run( const device T* A [[buffer(0)]], const device T* B [[buffer(1)]], - device U* C [[buffer(2)]], + device U* D [[buffer(2)]], const constant GEMMParams* params [[buffer(3)]], threadgroup T* As [[threadgroup(0)]], threadgroup T* Bs [[threadgroup(1)]], @@ -167,7 +167,7 @@ struct GEMMKernel { A += transpose_a ? c_row : c_row * params->lda; B += transpose_b ? 
c_col * params->ldb : c_col; - C += c_row * params->ldc + c_col; + D += c_row * params->ldd + c_col; // Prepare threadgroup loading operations thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id); @@ -214,7 +214,7 @@ struct GEMMKernel { } // Store results to device memory - mma_op.store_result(C, params->ldc); + mma_op.store_result(D, params->ldd); return; } @@ -237,7 +237,7 @@ struct GEMMKernel { tgp_bn, leftover_bk); - mma_op.store_result(C, params->ldc); + mma_op.store_result(D, params->ldd); return; } else if (tgp_bn == BN) { @@ -252,7 +252,7 @@ struct GEMMKernel { tgp_bn, leftover_bk); - mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm)); + mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm)); return; } else if (tgp_bm == BM) { @@ -267,7 +267,7 @@ struct GEMMKernel { tgp_bn, leftover_bk); - mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm)); + mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm)); return; } else { @@ -282,7 +282,7 @@ struct GEMMKernel { tgp_bn, leftover_bk); - mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm)); + mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm)); return; } } diff --git a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm.metal b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm.metal index fb051131c..189042fbe 100644 --- a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm.metal +++ b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm.metal @@ -1,6 +1,7 @@ // Copyright © 2024 Apple Inc. #include "mlx/backend/metal/kernels/bf16.h" +#include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/steel/gemm/gemm.h" using namespace metal; @@ -23,8 +24,10 @@ template batch_stride_a * tid.z; - B += params->batch_stride_b * tid.z; - C += params->batch_stride_c * tid.z; + if(params->batch_ndim > 1) { + const constant size_t* A_bstrides = batch_strides; + const constant size_t* B_bstrides = batch_strides + params->batch_ndim; + + ulong2 batch_offsets = elem_to_loc_broadcast( + tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim); + + A += batch_offsets.x; + B += batch_offsets.y; + + } else { + A += params->batch_stride_a * tid.z; + B += params->batch_stride_b * tid.z; + } + + D += params->batch_stride_d * tid.z; gemm_kernel::run( - A, B, C, + A, B, D, params, As, Bs, simd_lane_id, simd_group_id, tid, lid @@ -57,8 +73,10 @@ template ( \ const device itype *A [[buffer(0)]], \ const device itype *B [[buffer(1)]], \ - device itype *C [[buffer(2)]], \ - const constant GEMMParams* params [[buffer(3)]], \ + device itype *D [[buffer(3)]], \ + const constant GEMMParams* params [[buffer(4)]], \ + const constant int* batch_shape [[buffer(6)]], \ + const constant size_t* batch_strides [[buffer(7)]], \ uint simd_lane_id [[thread_index_in_simdgroup]], \ uint simd_group_id [[simdgroup_index_in_threadgroup]], \ uint3 tid [[threadgroup_position_in_grid]], \ diff --git a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_addmm.metal b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_addmm.metal index b8e131f0e..ec6efa10e 100644 --- a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_addmm.metal +++ b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_addmm.metal @@ -27,7 +27,10 @@ template batch_stride_a * tid.z; - B += params->batch_stride_b * tid.z; - C += params->batch_stride_c * tid.z; + if(params->batch_ndim > 1) { + const constant size_t* A_bstrides = batch_strides; + const constant size_t* 
B_bstrides = batch_strides + params->batch_ndim; + const constant size_t* C_bstrides = B_bstrides + params->batch_ndim; + + ulong3 batch_offsets = elem_to_loc_broadcast( + tid.z, batch_shape, A_bstrides, B_bstrides, C_bstrides, params->batch_ndim); + + A += batch_offsets.x; + B += batch_offsets.y; + C += batch_offsets.z; + + } else { + A += params->batch_stride_a * tid.z; + B += params->batch_stride_b * tid.z; + C += addmm_params->batch_stride_c * tid.z; + } + D += params->batch_stride_d * tid.z; const int tid_y = ((tid.y) << params->swizzle_log) + @@ -71,9 +89,10 @@ template lda; B += transpose_b ? c_col * params->ldb : c_col; - C += c_row * params->ldc + c_col * params->fdc; D += c_row * params->ldd + c_col; + C += c_row * addmm_params->ldc + c_col * addmm_params->fdc; + // Prepare threadgroup loading operations thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id); thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id); @@ -83,7 +102,7 @@ template gemm_k_iterations_aligned; - const Epilogue epilogue_op(params->alpha, params->beta); + const Epilogue epilogue_op(addmm_params->alpha, addmm_params->beta); /////////////////////////////////////////////////////////////////////////////// // MNK aligned loop @@ -121,7 +140,7 @@ template ldd, C, params->ldc, params->fdc, epilogue_op); + mma_op.store_result(D, params->ldd, C, addmm_params->ldc, addmm_params->fdc, epilogue_op); return; } @@ -145,7 +164,7 @@ template {}); - mma_op.store_result(D, params->ldd, C, params->ldc, params->fdc, epilogue_op); + mma_op.store_result(D, params->ldd, C, addmm_params->ldc, addmm_params->fdc, epilogue_op); return; } else if (tgp_bn == BN) { @@ -163,7 +182,7 @@ template ldd, - C, params->ldc, params->fdc, + C, addmm_params->ldc, addmm_params->fdc, short2(tgp_bn, tgp_bm), epilogue_op); @@ -182,7 +201,7 @@ template ldd, - C, params->ldc, params->fdc, + C, addmm_params->ldc, addmm_params->fdc, short2(tgp_bn, tgp_bm), epilogue_op); @@ -201,7 +220,7 @@ template ldd, - C, params->ldc, params->fdc, + C, addmm_params->ldc, addmm_params->fdc, short2(tgp_bn, tgp_bm), epilogue_op); } @@ -219,7 +238,10 @@ template #define STEEL_CONST static constant constexpr const -#define STEEL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)") \ No newline at end of file +#define STEEL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)") + +METAL_FUNC ulong2 elem_to_loc_broadcast( + uint elem, + constant const int* shape, + constant const size_t* a_strides, + constant const size_t* b_strides, + int ndim) { + ulong loc_a{0}; + ulong loc_b{0}; + for (int i = ndim - 1; i >= 0 && elem > 0; --i) { + int pos_in_dim = (elem % shape[i]); + elem /= shape[i]; + loc_a += pos_in_dim * a_strides[i]; + loc_b += pos_in_dim * b_strides[i]; + } + return ulong2(loc_a, loc_b); +} + +METAL_FUNC ulong3 elem_to_loc_broadcast( + uint elem, + constant const int* shape, + constant const size_t* a_strides, + constant const size_t* b_strides, + constant const size_t* c_strides, + int ndim) { + ulong loc_a{0}; + ulong loc_b{0}; + ulong loc_c{0}; + for (int i = ndim - 1; i >= 0 && elem > 0; --i) { + int pos_in_dim = (elem % shape[i]); + elem /= shape[i]; + loc_a += pos_in_dim * a_strides[i]; + loc_b += pos_in_dim * b_strides[i]; + loc_c += pos_in_dim * c_strides[i]; + } + return ulong3(loc_a, loc_b, loc_c); +} \ No newline at end of file diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp index d8de3d832..2c127ed8e 100644 --- a/mlx/backend/metal/matmul.cpp +++ b/mlx/backend/metal/matmul.cpp @@ -191,6 +191,70 @@ 
inline void mps_matmul( }); } +inline auto collapse_batches(const array& a, const array& b) { + // Get and check the shape for the batched dims + std::vector A_bshape{a.shape().begin(), a.shape().end() - 2}; + std::vector B_bshape{b.shape().begin(), b.shape().end() - 2}; + if (A_bshape != B_bshape) { + std::ostringstream msg; + msg << "[matmul] Got matrices with incorrectly broadcasted shapes: " + << "A " << a.shape() << ", B " << b.shape() << "."; + throw std::runtime_error(msg.str()); + } + + std::vector A_bstride{a.strides().begin(), a.strides().end() - 2}; + std::vector B_bstride{b.strides().begin(), b.strides().end() - 2}; + + auto [batch_shape, batch_strides] = + collapse_contiguous_dims(A_bshape, {A_bstride, B_bstride}); + + auto A_batch_stride = batch_strides[0]; + auto B_batch_stride = batch_strides[1]; + + if (batch_shape.empty()) { + batch_shape.push_back(1); + A_batch_stride.push_back(0); + B_batch_stride.push_back(0); + } + + return std::make_tuple(batch_shape, A_batch_stride, B_batch_stride); +} + +inline auto collapse_batches(const array& a, const array& b, const array& c) { + // Get and check the shape for the batched dims + std::vector A_bshape{a.shape().begin(), a.shape().end() - 2}; + std::vector B_bshape{b.shape().begin(), b.shape().end() - 2}; + std::vector C_bshape{c.shape().begin(), c.shape().end() - 2}; + if (A_bshape != B_bshape || A_bshape != C_bshape) { + std::ostringstream msg; + msg << "[addmm] Got matrices with incorrectly broadcasted shapes: " + << "A " << a.shape() << ", B " << b.shape() << ", B " << c.shape() + << "."; + throw std::runtime_error(msg.str()); + } + + std::vector A_bstride{a.strides().begin(), a.strides().end() - 2}; + std::vector B_bstride{b.strides().begin(), b.strides().end() - 2}; + std::vector C_bstride{c.strides().begin(), c.strides().end() - 2}; + + auto [batch_shape, batch_strides] = + collapse_contiguous_dims(A_bshape, {A_bstride, B_bstride, C_bstride}); + + auto A_batch_stride = batch_strides[0]; + auto B_batch_stride = batch_strides[1]; + auto C_batch_stride = batch_strides[2]; + + if (batch_shape.empty()) { + batch_shape.push_back(1); + A_batch_stride.push_back(0); + B_batch_stride.push_back(0); + C_batch_stride.push_back(0); + } + + return std::make_tuple( + batch_shape, A_batch_stride, B_batch_stride, C_batch_stride); +} + } // namespace /////////////////////////////////////////////////////////////////////////////// @@ -211,22 +275,33 @@ void steel_matmul( int ldb, bool transpose_a, bool transpose_b, - std::vector& copies) { + std::vector& copies, + std::vector batch_shape /* = {} */, + std::vector A_batch_stride /* = {} */, + std::vector B_batch_stride /* = {} */) { using namespace mlx::steel; - // Coalesce (B, M, K) X (K, N) to (B*M, K) X (K, N) - if (batch_size_out > 1 && !transpose_a && - a.data_size() == batch_size_out * M * K && b.size() == K * N) { - M = M * batch_size_out; - batch_size_out = 1; + if (batch_shape.empty()) { + ///////////////////////////////////////////////////////////////////////////// + // Check and collapse batch dimensions + auto [batch_shape_, A_bstride_, B_bstride_] = collapse_batches(a, b); + + batch_shape = batch_shape_; + A_batch_stride = A_bstride_; + B_batch_stride = B_bstride_; + // Collapse batches into M if needed + if (batch_size_out > 1 && !transpose_a && batch_shape.size() == 1 && + a.strides()[a.ndim() - 2] == K && A_batch_stride.back() == M * K && + B_batch_stride.back() == 0) { + M *= batch_shape.back(); + batch_size_out = 1; + + A_batch_stride = {0}; + B_batch_stride = {0}; + 
batch_shape = {1}; + } } - // Account for batch sizes and basic broadcasting - int batch_size_a = a.data_size() / (M * K); - int batch_size_b = b.data_size() / (K * N); - - int matrix_stride_a = (batch_size_a == 1) ? 0 : M * K; - int matrix_stride_b = (batch_size_b == 1) ? 0 : K * N; int matrix_stride_out = M * N; ///////////////////////////////////////////////////////////////////////////// @@ -269,18 +344,18 @@ void steel_matmul( int tm = (M + bm - 1) / bm; GEMMSpiltKParams params{ - M, - N, - K, - lda, - ldb, - N, - tn, - tm, - split_k_partitions, - split_k_partition_stride, - split_k_partition_size, - gemm_k_iterations}; + /* const int M = */ M, + /* const int N = */ N, + /* const int K = */ K, + /* const int lda = */ lda, + /* const int ldb = */ ldb, + /* const int ldc = */ N, + /* const int tiles_n = */ tn, + /* const int tiles_m = */ tm, + /* const int split_k_partitions = */ split_k_partitions, + /* const int split_k_partition_stride = */ split_k_partition_stride, + /* const int split_k_partition_size = */ split_k_partition_size, + /* const int gemm_k_iterations_aligned = */ gemm_k_iterations}; MTL::Size group_dims = MTL::Size(32, wn, wm); MTL::Size grid_dims = MTL::Size(tn, tm, split_k_partitions); @@ -364,19 +439,20 @@ void steel_matmul( // Prepare steel matmul params GEMMParams params{ - M, - N, - K, - lda, - ldb, - N, - tn, - tm, - matrix_stride_a, - matrix_stride_b, - matrix_stride_out, - swizzle_log, - (K / bk)}; + /* const int M = */ M, + /* const int N = */ N, + /* const int K = */ K, + /* const int lda = */ lda, + /* const int ldb = */ ldb, + /* const int ldd = */ N, + /* const int tiles_n = */ tn, + /* const int tiles_m = */ tm, + /* const int batch_stride_a = */ int(A_batch_stride.back()), + /* const int batch_stride_b = */ int(B_batch_stride.back()), + /* const int batch_stride_d = */ matrix_stride_out, + /* const int swizzle_log = */ swizzle_log, + /* const int gemm_k_iterations_aligned = */ (K / bk), + /* const int batch_ndim = */ int(batch_shape.size())}; // Prepare launch grid params int tile = 1 << swizzle_log; @@ -386,37 +462,25 @@ void steel_matmul( MTL::Size group_dims = MTL::Size(32, wn, wm); MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out); - // Launch only 1 kernel in the case of simple batching / broadcasting - if (batch_size_out == std::max(batch_size_a, batch_size_b) && - (batch_size_a == batch_size_b || - std::min(batch_size_a, batch_size_b) == 1)) { - set_array_buffer(compute_encoder, a, 0); - set_array_buffer(compute_encoder, b, 1); - set_array_buffer(compute_encoder, out, 2); + std::vector batch_strides = A_batch_stride; + batch_strides.insert( + batch_strides.end(), B_batch_stride.begin(), B_batch_stride.end()); - compute_encoder->setBytes(¶ms, sizeof(GEMMParams), 3); - compute_encoder->dispatchThreadgroups(grid_dims, group_dims); - } else { // Otherwise launch kernels with set offsets + // Launch kernel + set_array_buffer(compute_encoder, a, 0); + set_array_buffer(compute_encoder, b, 1); + set_array_buffer(compute_encoder, out, 3); - MTL::Size grid_dims_single = MTL::Size(tn, tm, 1); + compute_encoder->setBytes(¶ms, sizeof(GEMMParams), 4); - for (int i = 0; i < batch_size_out; ++i) { - auto a_off = elem_to_loc(M * K * i, a.shape(), a.strides()); - auto b_off = elem_to_loc(K * N * i, b.shape(), b.strides()); + compute_encoder->setBytes( + batch_shape.data(), sizeof(int) * batch_shape.size(), 6); + compute_encoder->setBytes( + batch_strides.data(), sizeof(size_t) * batch_strides.size(), 7); - auto a_buf = static_cast(a.buffer().ptr()); - auto 
b_buf = static_cast(b.buffer().ptr()); - auto out_buf = static_cast(out.buffer().ptr()); - - compute_encoder->setBuffer(a_buf, a_off * a.itemsize(), 0); - compute_encoder->setBuffer(b_buf, b_off * b.itemsize(), 1); - compute_encoder->setBuffer(out_buf, i * M * N * out.itemsize(), 2); - - compute_encoder->setBytes(¶ms, sizeof(GEMMParams), 3); - compute_encoder->dispatchThreadgroups(grid_dims_single, group_dims); - } - } + compute_encoder->dispatchThreadgroups(grid_dims, group_dims); + // Clear copies d.get_command_buffer(s.index)->addCompletedHandler( [copies](MTL::CommandBuffer*) mutable { copies.clear(); }); return; @@ -453,9 +517,9 @@ void Matmul::eval_gpu(const std::vector& inputs, array& out) { auto check_transpose = [&copies, &s](const array& arr) { auto stx = arr.strides()[arr.ndim() - 2]; auto sty = arr.strides()[arr.ndim() - 1]; - if (stx == arr.shape(-1) && sty == 1) { + if (sty == 1) { return std::make_tuple(false, stx, arr); - } else if (stx == 1 && sty == arr.shape(-2)) { + } else if (stx == 1) { return std::make_tuple(true, sty, arr); } else { array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); @@ -473,8 +537,25 @@ void Matmul::eval_gpu(const std::vector& inputs, array& out) { int N = b.shape(-1); int K = a.shape(-1); + ///////////////////////////////////////////////////////////////////////////// + // Check and collapse batch dimensions + + auto [batch_shape, A_batch_stride, B_batch_stride] = collapse_batches(a, b); + auto batch_size_out = out.size() / (M * N); + // Collapse batches into M if needed + if (batch_size_out > 1 && !a_transposed && batch_shape.size() == 1 && + a.strides()[a.ndim() - 2] == K && A_batch_stride.back() == M * K && + B_batch_stride.back() == 0) { + M *= batch_shape.back(); + batch_size_out = 1; + + A_batch_stride = {0}; + B_batch_stride = {0}; + batch_shape = {1}; + } + ///////////////////////////////////////////////////////////////////////////// // Gemv specialization @@ -491,20 +572,18 @@ void Matmul::eval_gpu(const std::vector& inputs, array& out) { int mat_cols = transpose_mat ? out_vector_len : in_vector_len; int mat_rows = transpose_mat ? in_vector_len : out_vector_len; + int mat_ld = is_b_matrix ? b_cols : a_cols; - int batch_size_mat = mat.data_size() / (mat_cols * mat_rows); - int stride_mat = batch_size_mat == 1 ? 0 : mat_cols * mat_rows; + auto batch_strides_mat = is_b_matrix ? B_batch_stride : A_batch_stride; + auto batch_strides_vec = is_b_matrix ? A_batch_stride : B_batch_stride; - int batch_size_vec = vec.data_size() / in_vector_len; - int stride_vec = batch_size_vec == 1 ? 
0 : in_vector_len; + int stride_mat = batch_strides_mat.back(); + int stride_vec = batch_strides_vec.back(); // Determine if inputs have simple batching / broadcasting - bool contiguous_kernel = - (batch_size_out == std::max(batch_size_mat, batch_size_vec) && - (batch_size_mat == batch_size_vec || - std::min(batch_size_mat, batch_size_vec) == 1)); + bool contiguous_kernel = (batch_shape.size() == 1); - int nc_dim = out.ndim() - 2; + int batch_ndim = batch_shape.size(); // Determine dispatch kernel int tm = 4, tn = 4; @@ -540,10 +619,7 @@ void Matmul::eval_gpu(const std::vector& inputs, array& out) { } kname << "_bm" << bm << "_bn" << bn << "_tm" << tm << "_tn" << tn; - - if (!contiguous_kernel) { - kname << "_nc"; - } + kname << "_nc" << !contiguous_kernel << "_axpby0"; // Encode and dispatch kernel auto compute_encoder = d.get_command_encoder(s.index); @@ -556,25 +632,18 @@ void Matmul::eval_gpu(const std::vector& inputs, array& out) { set_array_buffer(compute_encoder, mat, 0); set_array_buffer(compute_encoder, vec, 1); - set_array_buffer(compute_encoder, out, 2); + set_array_buffer(compute_encoder, out, 3); - compute_encoder->setBytes(&in_vector_len, sizeof(int), 3); - compute_encoder->setBytes(&out_vector_len, sizeof(int), 4); + compute_encoder->setBytes(&in_vector_len, sizeof(int), 4); + compute_encoder->setBytes(&out_vector_len, sizeof(int), 5); + compute_encoder->setBytes(&mat_ld, sizeof(int), 6); - if (contiguous_kernel) { - compute_encoder->setBytes(&stride_vec, sizeof(int), 5); - compute_encoder->setBytes(&stride_mat, sizeof(int), 6); - } else { - // In case of complex broadcasting, we consider the shape[:-2] and - // strides [:-2] to determine the location of a batch - // nc_dim = out.ndim() - 2 - compute_encoder->setBytes(&nc_dim, sizeof(int), 5); - compute_encoder->setBytes(out.shape().data(), nc_dim * sizeof(int), 6); - compute_encoder->setBytes( - vec.strides().data(), nc_dim * sizeof(size_t), 7); - compute_encoder->setBytes( - mat.strides().data(), nc_dim * sizeof(size_t), 8); - } + compute_encoder->setBytes(&batch_ndim, sizeof(int), 9); + compute_encoder->setBytes(batch_shape.data(), batch_ndim * sizeof(int), 10); + compute_encoder->setBytes( + batch_strides_vec.data(), batch_ndim * sizeof(size_t), 11); + compute_encoder->setBytes( + batch_strides_mat.data(), batch_ndim * sizeof(size_t), 12); compute_encoder->dispatchThreadgroups(grid_dims, group_dims); @@ -606,20 +675,23 @@ void Matmul::eval_gpu(const std::vector& inputs, array& out) { } return steel_matmul( - s, - d, - a, - b, - out, - M, - N, - K, - batch_size_out, - a_cols, - b_cols, - a_transposed, - b_transposed, - copies); + /* const Stream& s = */ s, + /* metal::Device& d = */ d, + /* const array& a = */ a, + /* const array& b = */ b, + /* array& out = */ out, + /* int M = */ M, + /* int N = */ N, + /* int K = */ K, + /* int batch_size_out = */ batch_size_out, + /* int lda = */ a_cols, + /* int ldb = */ b_cols, + /* bool transpose_a = */ a_transposed, + /* bool transpose_b = */ b_transposed, + /* std::vector& = */ copies, + /* std::vector batch_shape = */ batch_shape, + /* std::vector A_batch_stride = */ A_batch_stride, + /* std::vector B_batch_stride = */ B_batch_stride); } void AddMM::eval_gpu(const std::vector& inputs, array& out) { @@ -645,9 +717,9 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) { auto check_transpose = [&copies, &s](const array& arr) { auto stx = arr.strides()[arr.ndim() - 2]; auto sty = arr.strides()[arr.ndim() - 1]; - if (stx == arr.shape(-1) && sty == 1) { + if (sty == 1) { 
return std::make_tuple(false, stx, arr); - } else if (stx == 1 && sty == arr.shape(-2)) { + } else if (stx == 1) { return std::make_tuple(true, sty, arr); } else { array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); @@ -665,33 +737,151 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) { int N = b.shape(-1); int K = a.shape(-1); - auto batch_size_out = out.size() / (M * N); - array c = c_pre; int ldc = c.strides()[c.ndim() - 2]; int fdc = c.strides()[c.ndim() - 1]; - int matrix_stride_c = c.ndim() <= 2 ? 0 : c.strides()[c.ndim() - 3]; int lda = a_cols; int ldb = b_cols; + int ldd = N; + + ///////////////////////////////////////////////////////////////////////////// + // Check and collapse batch dimensions + auto [batch_shape, A_batch_stride, B_batch_stride, C_batch_stride] = + collapse_batches(a, b, c); + + auto batch_size_out = out.size() / (M * N); + + // Collapse batches into M if needed + if (batch_size_out > 1 && !transpose_a && batch_shape.size() == 1 && + a.strides()[a.ndim() - 2] == K && A_batch_stride.back() == M * K && + C_batch_stride.back() == M * c.strides()[c.ndim() - 2] && + B_batch_stride.back() == 0) { + M *= batch_shape.back(); + batch_size_out = 1; + + A_batch_stride = {0}; + B_batch_stride = {0}; + C_batch_stride = {0}; + batch_shape = {1}; + } + + int matrix_stride_out = M * N; + + ///////////////////////////////////////////////////////////////////////////// + // Gemv specialization + + // Route to gemv if needed + if (std::min(M, N) == 1) { + // Collect problem info + bool is_b_matrix = N != 1; + + auto& mat = is_b_matrix ? b : a; + auto& vec = is_b_matrix ? a : b; + bool transpose_mat = is_b_matrix ? !transpose_b : transpose_a; + int in_vector_len = K; + int out_vector_len = is_b_matrix ? N : M; + + int mat_cols = transpose_mat ? out_vector_len : in_vector_len; + int mat_rows = transpose_mat ? in_vector_len : out_vector_len; + int mat_ld = is_b_matrix ? b_cols : a_cols; + + auto batch_strides_mat = is_b_matrix ? B_batch_stride : A_batch_stride; + auto batch_strides_vec = is_b_matrix ? A_batch_stride : B_batch_stride; + + int stride_mat = batch_strides_mat.back(); + int stride_vec = batch_strides_vec.back(); + + // Determine if inputs have simple batching / broadcasting + bool contiguous_kernel = (batch_shape.size() == 1); + + int batch_ndim = batch_shape.size(); + + // Determine dispatch kernel + int tm = 4, tn = 4; + int bm, bn, n_out_per_tgp; + std::ostringstream kname; + + if (transpose_mat) { + bm = 8; + bn = 8; + if (out_vector_len >= 24576) { + bn = 128; + } else if (out_vector_len >= 16384) { + bn = 64; + } else if (out_vector_len >= 8192) { + bn = 16; + } + + // Specialized kernel for very small outputs + tn = out_vector_len < tn ? 1 : tn; + + n_out_per_tgp = bn * tn; + kname << "gemv_t_" << type_to_name(out); + + } else { + bm = out_vector_len >= 4096 ? 8 : 4; + bn = 32; + + // Specialized kernel for very small outputs + tm = out_vector_len < tm ? 
1 : tm; + + n_out_per_tgp = bm * tm; + kname << "gemv_" << type_to_name(out); + } + + kname << "_bm" << bm << "_bn" << bn << "_tm" << tm << "_tn" << tn; + kname << "_nc" << !contiguous_kernel << "_axpby1"; + + // Encode and dispatch kernel + auto compute_encoder = d.get_command_encoder(s.index); + auto kernel = d.get_kernel(kname.str()); + compute_encoder->setComputePipelineState(kernel); + + int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp; + MTL::Size group_dims = MTL::Size(bn, bm, 1); + MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out); + + set_array_buffer(compute_encoder, mat, 0); + set_array_buffer(compute_encoder, vec, 1); + set_array_buffer(compute_encoder, c, 2); + set_array_buffer(compute_encoder, out, 3); + + compute_encoder->setBytes(&in_vector_len, sizeof(int), 4); + compute_encoder->setBytes(&out_vector_len, sizeof(int), 5); + compute_encoder->setBytes(&mat_ld, sizeof(int), 6); + + compute_encoder->setBytes(&alpha_, sizeof(float), 7); + compute_encoder->setBytes(&beta_, sizeof(float), 8); + + compute_encoder->setBytes(&batch_ndim, sizeof(int), 9); + compute_encoder->setBytes(batch_shape.data(), batch_ndim * sizeof(int), 10); + compute_encoder->setBytes( + batch_strides_vec.data(), batch_ndim * sizeof(size_t), 11); + compute_encoder->setBytes( + batch_strides_mat.data(), batch_ndim * sizeof(size_t), 12); + compute_encoder->setBytes( + C_batch_stride.data(), batch_ndim * sizeof(size_t), 13); + + int bias_stride = c.strides()[c.ndim() - 1]; + compute_encoder->setBytes(&bias_stride, sizeof(int), 14); + + compute_encoder->dispatchThreadgroups(grid_dims, group_dims); + + d.get_command_buffer(s.index)->addCompletedHandler( + [copies](MTL::CommandBuffer*) mutable { copies.clear(); }); + return; + } using namespace mlx::steel; - // Account for batch sizes and basic broadcasting - int batch_size_a = a.data_size() / (M * K); - int batch_size_b = b.data_size() / (K * N); - - int matrix_stride_a = (batch_size_a == 1) ? 0 : M * K; - int matrix_stride_b = (batch_size_b == 1) ? 0 : K * N; - int matrix_stride_out = M * N; + ///////////////////////////////////////////////////////////////////////////// + // Split K specialization int _tm = M / 16; int _tn = N / 16; int _tk = K / 16; - ///////////////////////////////////////////////////////////////////////////// - // Split K specialization - if (batch_size_out == 1 && (_tm * _tn) <= 32 && _tk >= 8) { int bm = M < 40 ? 16 : 32; int bn = N < 40 ? 16 : 32; @@ -817,25 +1007,29 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) { // TODO: Explore device-based tuning for swizzle int swizzle_log = 0; // tm >= 6 ? 3 : (tm <= 3 ? 
0 : 2); + // Prepare steel matmul params + GEMMParams gemm_params{ + /* const int M = */ M, + /* const int N = */ N, + /* const int K = */ K, + /* const int lda = */ lda, + /* const int ldb = */ ldb, + /* const int ldd = */ N, + /* const int tiles_n = */ tn, + /* const int tiles_m = */ tm, + /* const int batch_stride_a = */ int(A_batch_stride.back()), + /* const int batch_stride_b = */ int(B_batch_stride.back()), + /* const int batch_stride_d = */ matrix_stride_out, + /* const int swizzle_log = */ swizzle_log, + /* const int gemm_k_iterations_aligned = */ (K / bk), + /* const int batch_ndim = */ int(batch_shape.size())}; + GEMMAddMMParams params{ - M, - N, - K, - lda, - ldb, - ldc, - N, - tn, - tm, - matrix_stride_a, - matrix_stride_b, - matrix_stride_c, - matrix_stride_out, - swizzle_log, - (K / bk), - alpha_, - beta_, - fdc}; + /* const int ldc = */ ldc, + /* const int fdc = */ fdc, + /* const int batch_stride_c = */ int(C_batch_stride.back()), + /* const float alpha = */ alpha_, + /* const float beta = */ beta_}; int tile = 1 << swizzle_log; tm = (tm + tile - 1) / tile; @@ -844,40 +1038,27 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) { MTL::Size group_dims = MTL::Size(32, wn, wm); MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out); - // Launch only 1 kernel in the case of simple batching / broadcasting - if (batch_size_out == std::max(batch_size_a, batch_size_b) && - (batch_size_a == batch_size_b || - std::min(batch_size_a, batch_size_b) == 1)) { - set_array_buffer(compute_encoder, a, 0); - set_array_buffer(compute_encoder, b, 1); - set_array_buffer(compute_encoder, c, 2); - set_array_buffer(compute_encoder, out, 3); + std::vector batch_strides = A_batch_stride; + batch_strides.insert( + batch_strides.end(), B_batch_stride.begin(), B_batch_stride.end()); + batch_strides.insert( + batch_strides.end(), C_batch_stride.begin(), C_batch_stride.end()); - compute_encoder->setBytes(¶ms, sizeof(GEMMAddMMParams), 4); - compute_encoder->dispatchThreadgroups(grid_dims, group_dims); - } else { // Otherwise launch kernels with set offsets + // Launch kernel + set_array_buffer(compute_encoder, a, 0); + set_array_buffer(compute_encoder, b, 1); + set_array_buffer(compute_encoder, c, 2); + set_array_buffer(compute_encoder, out, 3); - MTL::Size grid_dims_single = MTL::Size(tn, tm, 1); + compute_encoder->setBytes(&gemm_params, sizeof(GEMMParams), 4); + compute_encoder->setBytes(¶ms, sizeof(GEMMAddMMParams), 5); - for (int i = 0; i < batch_size_out; ++i) { - auto a_off = elem_to_loc(M * K * i, a.shape(), a.strides()); - auto b_off = elem_to_loc(K * N * i, b.shape(), b.strides()); - auto c_off = elem_to_loc(M * N * i, c.shape(), c.strides()); + compute_encoder->setBytes( + batch_shape.data(), sizeof(int) * batch_shape.size(), 6); + compute_encoder->setBytes( + batch_strides.data(), sizeof(size_t) * batch_strides.size(), 7); - auto a_buf = static_cast(a.buffer().ptr()); - auto b_buf = static_cast(b.buffer().ptr()); - auto c_buf = static_cast(c.buffer().ptr()); - auto out_buf = static_cast(out.buffer().ptr()); - - compute_encoder->setBuffer(a_buf, a_off * a.itemsize(), 0); - compute_encoder->setBuffer(b_buf, b_off * b.itemsize(), 1); - compute_encoder->setBuffer(c_buf, c_off * c.itemsize(), 2); - compute_encoder->setBuffer(out_buf, i * M * N * out.itemsize(), 3); - - compute_encoder->setBytes(¶ms, sizeof(GEMMAddMMParams), 4); - compute_encoder->dispatchThreadgroups(grid_dims_single, group_dims); - } - } + compute_encoder->dispatchThreadgroups(grid_dims, group_dims); 
d.get_command_buffer(s.index)->addCompletedHandler( [copies](MTL::CommandBuffer*) mutable { copies.clear(); }); diff --git a/mlx/backend/metal/matmul.h b/mlx/backend/metal/matmul.h index 1ebccf0e1..a9c872235 100644 --- a/mlx/backend/metal/matmul.h +++ b/mlx/backend/metal/matmul.h @@ -26,6 +26,9 @@ void steel_matmul( int ldb, bool transpose_a, bool transpose_b, - std::vector& copies); + std::vector& copies, + std::vector batch_shape = {}, + std::vector A_batch_stride = {}, + std::vector B_batch_stride = {}); } // namespace mlx::core \ No newline at end of file diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 61a44c812..ad3efab71 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -3315,17 +3315,8 @@ array addmm( const float& alpha /* = 1.f */, const float& beta /* = 1.f */, StreamOrDevice s /* = {} */) { - // Divert in the case of vector-matrix multiplication - // TODO: Add the needed specializtion - if (a.ndim() == 1 || b.ndim() == 1) { - array X = matmul(a, b, s); - array alpha_arr = array(alpha, X.dtype()); - array aX = multiply(alpha_arr, X, s); - - array beta_arr = array(beta, c.dtype()); - array bY = multiply(beta_arr, c, s); - return add(aX, bY, s); - } + int in_a_ndim = a.ndim(); + int in_b_ndim = b.ndim(); if (a.ndim() == 0 || b.ndim() == 0) { throw std::invalid_argument( @@ -3333,6 +3324,15 @@ array addmm( "have at least one dimension."); } + if (a.ndim() == 1) { + // Insert a singleton dim in the beginning + a = reshape(a, {1, -1}, s); + } + if (b.ndim() == 1) { + // Insert a singleton dim at the end + b = reshape(b, {-1, 1}, s); + } + if (a.shape(-1) != b.shape(-2)) { std::ostringstream msg; msg << "[addmm] Last dimension of first input with shape " << a.shape() @@ -3361,7 +3361,13 @@ array addmm( std::vector out_shape = a.shape(); a = reshape(a, {-1, out_shape.back()}, s); out_shape.back() = b.shape(-1); + + if (in_b_ndim == 1) { + out_shape.pop_back(); + } + c = broadcast_to(c, {a.shape(0), b.shape(1)}, s); + auto out = array( {a.shape(0), b.shape(1)}, out_type, @@ -3389,15 +3395,42 @@ array addmm( auto out_shape = a.shape(); out_shape.back() = b.shape(-1); - auto c_broadcast_shape = broadcast_shapes(c.shape(), out_shape); + auto out_shape_adjusted = out_shape; + + if (in_a_ndim == 1 || in_b_ndim == 1) { + out_shape_adjusted.erase( + out_shape_adjusted.end() - ((in_a_ndim == 1) ? 2 : 1), + out_shape_adjusted.end() - ((in_b_ndim == 1) ? 
0 : 1)); + } + + auto c_broadcast_shape = broadcast_shapes(c.shape(), out_shape_adjusted); c = broadcast_to(c, c_broadcast_shape, s); + if (in_a_ndim == 1 || in_b_ndim == 1) { + auto c_reshape = c.shape(); + if (in_b_ndim == 1) { + c_reshape.push_back(1); + } + + if (in_a_ndim == 1) { + c_reshape.push_back(c_reshape.back()); + c_reshape[c_reshape.size() - 2] = 1; + } + + c = reshape(c, c_reshape, s); + } + auto out = array( out_shape, out_type, std::make_unique(to_stream(s), alpha, beta), {a, b, c}); + // Remove the possibly inserted singleton dimensions + if (in_a_ndim == 1 || in_b_ndim == 1) { + out = reshape(out, out_shape_adjusted, s); + } + return out; } diff --git a/python/tests/test_blas.py b/python/tests/test_blas.py index c2c1cc2a2..0d3417dc1 100644 --- a/python/tests/test_blas.py +++ b/python/tests/test_blas.py @@ -393,6 +393,77 @@ class TestBlas(mlx_tests.MLXTestCase): mlx_vec_f=lambda vec_mlx: mx.broadcast_to(vec_mlx, shape_vec), ) + def test_matrix_vector_attn(self): + # Multi-query style attention check + for dtype in self.dtypes: + # fmt: off + for (B, D, n_kv_heads, factor, qsl, ksl) in ( + (1, 16, 8, 4, 1, 256), + (1, 16, 8, 4, 32, 256), + (1, 16, 8, 4, 256, 1), + (4, 16, 8, 4, 1, 256), + (4, 16, 8, 4, 256, 1), + ): + # fmt: on + with self.subTest( + B=B, # Batch size + D=D, # Dimension of mm + n_kv_heads=n_kv_heads, # key-value heads + factor=factor, # factor to get query heads + qsl=qsl, # Query sequence length + ksl=ksl, # Key sequence length + dtype=dtype # Data type + ): + + np_dtype = getattr(np, dtype) + + # Fix shapes for kqv + n_q_heads = n_kv_heads * factor + Dk = D * n_kv_heads + Dq = D * n_q_heads + scale = 1. / math.sqrt(Dk) + + shape_queries = (B, qsl, Dq) + shape_keys = (B, ksl, Dk) + shape_values = (B, ksl, Dk) + + # Prepare numpy arrays + q_np = np.random.uniform(-scale, scale, size=shape_queries).astype(np_dtype) + k_np = np.random.uniform(-scale, scale, size=shape_keys).astype(np_dtype) + v_np = np.random.uniform(-scale, scale, size=shape_values).astype(np_dtype) + + # Rearrange to move heads up + q_np_reshape = q_np.reshape(B, qsl, n_kv_heads, factor, -1).transpose(0, 2, 3, 1, 4) + k_np_reshape = k_np.reshape(B, ksl, n_kv_heads, 1, -1).transpose(0, 2, 3, 4, 1) + v_np_reshape = v_np.reshape(B, ksl, n_kv_heads, 1, -1).transpose(0, 2, 3, 1, 4) + + # Do attn style matmul + s_np = q_np_reshape @ k_np_reshape + o_np = s_np @ v_np_reshape + o_np = o_np.transpose(0, 3, 1, 2, 4).reshape(B, qsl, -1) + + # Test mlx + q_mx = mx.array(q_np) + k_mx = mx.array(k_np) + v_mx = mx.array(v_np) + + # Rearrange to move heads up + q_mx_reshape = q_mx.reshape(B, qsl, n_kv_heads, factor, -1).transpose(0, 2, 3, 1, 4) + k_mx_reshape = k_mx.reshape(B, ksl, n_kv_heads, 1, -1).transpose(0, 2, 3, 4, 1) + v_mx_reshape = v_mx.reshape(B, ksl, n_kv_heads, 1, -1).transpose(0, 2, 3, 1, 4) + + # Do attn style matmul + s_mx = q_mx_reshape @ k_mx_reshape + o_mx = (s_mx @ v_mx_reshape) + o_mx = o_mx.transpose(0, 3, 1, 2, 4).reshape(B, qsl, -1) + + # Check against np + self.assertListEqual(list(s_np.shape), list(s_mx.shape)) + self.assertTrue(np.allclose(s_np, s_mx, atol=1e-4)) + + self.assertListEqual(list(o_np.shape), list(o_mx.shape)) + self.assertTrue(np.allclose(o_np, o_mx, atol=1e-4)) + def test_matrix_vector_edgecases(self): for dtype in self.dtypes: with self.subTest(dtype=dtype): @@ -503,16 +574,29 @@ class TestBlas(mlx_tests.MLXTestCase): self.assertListEqual(list(d_npy.shape), list(d_mlx.shape)) self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5)) + # Matmul with vector + a_npy = 
np.random.normal(0.0, 1.0 / 128, (16,)).astype(np.float32) + b_npy = np.random.normal(0.0, 1.0 / 128, (32, 16, 128)).astype(np.float32) + a_mlx = mx.array(a_npy) + b_mlx = mx.array(b_npy) + + for c_shape in ((1,), (128,), (32, 128)): + c_npy = np.ones(c_shape).astype(np.float32) + c_mlx = mx.array(c_npy) + + d_npy = alpha * (a_npy @ b_npy) + beta * c_npy + d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta) + + self.assertListEqual(list(d_npy.shape), list(d_mlx.shape)) + self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5)) + # Matmul with vector a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32) b_npy = np.random.normal(0.0, 1.0 / 128, (16,)).astype(np.float32) a_mlx = mx.array(a_npy) b_mlx = mx.array(b_npy) - for c_shape in ( - (1,), - (32, 128), - ): + for c_shape in ((1,), (32, 128)): c_npy = np.ones(c_shape).astype(np.float32) c_mlx = mx.array(c_npy) @@ -564,16 +648,12 @@ class TestBlas(mlx_tests.MLXTestCase): out_ref, dout_ref = mx.vjp( f_ref, [c, a, b], - [ - cotan, - ], + [cotan], ) out_test, dout_test = mx.vjp( f_test, [c, a, b], - [ - cotan, - ], + [cotan], ) self.assertTrue(mx.allclose(out_ref[0], out_test[0], atol=1e-4).item())
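
A minimal usage sketch (not part of the patch itself), distilled from the new test_matrix_vector_attn case above: it shows the multi-query-attention-style broadcast-batched matmul that the reworked gemv/steel kernels are exercised with, where a singleton head-repeat axis on the keys broadcasts against the queries' factor axis so no explicit tiling or copy of k is needed. Shapes and calls mirror the test; variable names are illustrative only.

    import math
    import numpy as np
    import mlx.core as mx

    # Multi-query attention shapes: n_kv_heads key/value heads are shared by
    # n_kv_heads * factor query heads via a broadcast (size-1) axis.
    B, D, n_kv_heads, factor, qsl, ksl = 1, 16, 8, 4, 1, 256
    n_q_heads = n_kv_heads * factor
    scale = 1.0 / math.sqrt(D * n_kv_heads)

    q = mx.array(np.random.uniform(-scale, scale, (B, qsl, D * n_q_heads)).astype(np.float32))
    k = mx.array(np.random.uniform(-scale, scale, (B, ksl, D * n_kv_heads)).astype(np.float32))

    # Move heads in front; keys keep a singleton axis that broadcasts against
    # the queries' `factor` axis inside the batched matmul.
    q_r = q.reshape(B, qsl, n_kv_heads, factor, -1).transpose(0, 2, 3, 1, 4)
    k_r = k.reshape(B, ksl, n_kv_heads, 1, -1).transpose(0, 2, 3, 4, 1)

    scores = q_r @ k_r  # shape (B, n_kv_heads, factor, qsl, ksl)
    print(scores.shape)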