diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp index 5106dfa210..8074547cac 100644 --- a/mlx/backend/metal/device.cpp +++ b/mlx/backend/metal/device.cpp @@ -181,6 +181,7 @@ Device::Device() { auto pool = new_scoped_memory_pool(); device_ = load_device(); library_map_ = {{"mlx", load_library(device_)}}; + arch_ = std::string(device_->architecture()->name()->utf8String()); } Device::~Device() { diff --git a/mlx/backend/metal/device.h b/mlx/backend/metal/device.h index 8737b479a3..fe32cc7381 100644 --- a/mlx/backend/metal/device.h +++ b/mlx/backend/metal/device.h @@ -136,6 +136,10 @@ class Device { return device_; }; + const std::string& get_architecture() { + return arch_; + } + void new_queue(int index); MTL::CommandBuffer* get_command_buffer(int index); int get_command_buffer_ops(int index); @@ -228,6 +232,7 @@ class Device { std::shared_mutex library_mtx_; std::unordered_map library_map_; const MTL::ResidencySet* residency_set_{nullptr}; + std::string arch_; }; Device& device(mlx::core::Device); diff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt index 28f34535c9..9a8b9b7b4e 100644 --- a/mlx/backend/metal/kernels/CMakeLists.txt +++ b/mlx/backend/metal/kernels/CMakeLists.txt @@ -50,7 +50,9 @@ set(STEEL_HEADERS steel/gemm/transforms.h steel/gemm/kernels/steel_gemm_fused.h steel/gemm/kernels/steel_gemm_masked.h - steel/gemm/kernels/steel_gemm_splitk.h) + steel/gemm/kernels/steel_gemm_splitk.h + steel/utils/type_traits.h + steel/utils/integral_constant.h) if(NOT MLX_METAL_JIT) build_kernel(arange arange.h) diff --git a/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h b/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h index f5430d5903..e4b662cd34 100644 --- a/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +++ b/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h @@ -142,8 +142,8 @@ implicit_gemm_conv_2d_general( // Store results to device memory { // Adjust for simdgroup and thread locatio - int offset_m = c_row + mma_op.sm + mma_op.tm; - int offset_n = c_col + mma_op.sn + mma_op.tn; + int offset_m = c_row + mma_op.sm; + int offset_n = c_col + mma_op.sn; C += offset_n; if (offset_n >= gemm_params->N) @@ -169,17 +169,17 @@ implicit_gemm_conv_2d_general( STEEL_PRAGMA_UNROLL for (int j = 0; j < mma_t::TN; j++) { // Get accumulated result and associated offset in C - thread const auto& accum = - mma_op.results[i * mma_t::TN + j].thread_elements(); + thread const auto& accum = mma_op.Ctile.frag_at(i, j); int offset = offset_cm + (j * mma_t::TN_stride); - // Apply epilogue and output C - if (j * mma_t::TN_stride < diff) { - C[offset] = Epilogue::apply(accum[0]); - } + constexpr short kelems = decltype(mma_op.Ctile)::kElemsPerFrag; - if (j * mma_t::TN_stride + 1 < diff) { - C[offset + 1] = Epilogue::apply(accum[1]); + // Apply epilogue and output C + STEEL_PRAGMA_UNROLL + for (short k = 0; k < kelems; k++) { + if ((j * mma_t::TN_stride + k) < diff) { + C[offset + k] = Epilogue::apply(accum[k]); + } } } } diff --git a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal index 0665cb6f30..9f33a2bc60 100644 --- a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal +++ b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal @@ -36,11 +36,11 @@ instantiate_gemm(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) #define 
instantiate_gemm_shapes_helper(iname, itype, oname, otype) \ - instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \ instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \ + instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 1, 2) \ instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \ - instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 16, 2, 2) \ - instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 2, 2) + instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 1, 2) \ + instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) instantiate_gemm_shapes_helper(float16, half, float16, half); instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t); diff --git a/mlx/backend/metal/kernels/steel/gemm/mma.h b/mlx/backend/metal/kernels/steel/gemm/mma.h index dbd425ef03..7ad6c0fe65 100644 --- a/mlx/backend/metal/kernels/steel/gemm/mma.h +++ b/mlx/backend/metal/kernels/steel/gemm/mma.h @@ -8,6 +8,7 @@ #include "mlx/backend/metal/kernels/steel/defines.h" #include "mlx/backend/metal/kernels/steel/gemm/transforms.h" +#include "mlx/backend/metal/kernels/steel/utils/integral_constant.h" using namespace metal; @@ -18,6 +19,347 @@ using namespace metal; namespace mlx { namespace steel { +template +struct BaseMMAFrag { + static_assert( + kFragRows_ == 8, + "Only 8 x 8 fragment matrices are currently supported"); + static_assert( + kFragCols_ == 8, + "Only 8 x 8 fragment matrices are currently supported"); +}; + +template +struct BaseMMAFrag { + STEEL_CONST int kFragRows = 8; + STEEL_CONST int kFragCols = 8; + + STEEL_CONST int kElemsPerFrag = (kFragRows * kFragCols) / 32; + + STEEL_CONST int kElemRows = 1; + STEEL_CONST int kElemCols = 2; + + static_assert( + kElemRows * kElemCols == kElemsPerFrag, + "MMAFrag shape is not consistent with MMAFrag size"); + + typedef metal::simdgroup_matrix mat_type; + typedef metal::vec frag_type; + + METAL_FUNC static constexpr short2 get_coord(ushort simd_lane_id + [[thread_index_in_simdgroup]]) { + const short qid = simd_lane_id / 4; + const short fm = (qid & 4) + ((simd_lane_id / 2) % 4); + const short fn = (qid & 2) * 2 + (simd_lane_id % 2) * 2; + return short2{fn, fm}; + } + + template + METAL_FUNC static constexpr void + load(thread frag_type& dst, SrcPtrType src, StrX str_x, StrY str_y) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[i * kElemCols + j] = static_cast(src[i * str_x + j * str_y]); + } + } + } + + template < + typename SrcPtrType, + typename StrX, + typename StrY, + typename LimX, + typename LimY, + typename OffX, + typename OffY> + METAL_FUNC static constexpr void load_safe( + thread frag_type& dst, + SrcPtrType src, + StrX str_x, + StrY str_y, + LimX lim_x, + LimY lim_y, + OffX off_x = Int<0>{}, + OffY off_y = Int<0>{}) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + if ((off_x + i) < lim_x && (off_y + j) < lim_y) { + dst[i * kElemCols + j] = + static_cast(src[(off_x + i) * str_x + (off_x + j) * str_y]); + } else { + dst[i * kElemCols + j] = T(0); + } + } + } + } + + template + METAL_FUNC static constexpr void + store(const thread frag_type& src, DstPtrType dst, StrX str_x, StrY str_y) { + using U = pointer_element_t; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; 
i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[i * str_x + j * str_y] = static_cast(src[i * kElemCols + j]); + } + } + } + + template < + typename DstPtrType, + typename StrX, + typename StrY, + typename LimX, + typename LimY, + typename OffX, + typename OffY> + METAL_FUNC static constexpr void store_safe( + const thread frag_type& src, + DstPtrType dst, + StrX str_x, + StrY str_y, + LimX lim_x, + LimY lim_y, + OffX off_x = Int<0>{}, + OffY off_y = Int<0>{}) { + using U = pointer_element_t; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + if ((off_x + i) < lim_x && (off_y + j) < lim_y) { + dst[(off_x + i) * str_x + (off_y + j) * str_y] = + static_cast(src[i * kElemCols + j]); + } + } + } + } + + METAL_FUNC static constexpr void mma( + thread frag_type& D, + thread frag_type& A, + thread frag_type& B, + thread frag_type& C) { + mat_type D_mat; + mat_type A_mat; + mat_type B_mat; + mat_type C_mat; + + reinterpret_cast(A_mat.thread_elements()) = A; + reinterpret_cast(B_mat.thread_elements()) = B; + reinterpret_cast(C_mat.thread_elements()) = C; + + mma(D_mat, A_mat, B_mat, C_mat); + + D = reinterpret_cast(D_mat.thread_elements()); + } + + METAL_FUNC static constexpr void mma( + thread mat_type& D, + thread mat_type& A, + thread mat_type& B, + thread mat_type& C) { + simdgroup_multiply_accumulate(D, A, B, C); + } +}; + +template < + typename T, + int kTileRows_, + int kTileCols_, + class MMAFrag_ = BaseMMAFrag> +struct MMATile { + using MMAFrag_t = MMAFrag_; + using elem_type = T; + STEEL_CONST int kFragRows = MMAFrag_t::kFragRows; + STEEL_CONST int kFragCols = MMAFrag_t::kFragCols; + STEEL_CONST int kElemsPerFrag = MMAFrag_t::kElemsPerFrag; + + STEEL_CONST int kTileRows = kTileRows_; + STEEL_CONST int kTileCols = kTileCols_; + + STEEL_CONST int kRows = kTileRows * kFragRows; + STEEL_CONST int kCols = kTileCols * kFragCols; + + STEEL_CONST int kNumFrags = kTileRows * kTileCols; + STEEL_CONST int kElemsPerTile = kNumFrags * kElemsPerFrag; + + typedef typename MMAFrag_t::mat_type mat_type; + typedef typename MMAFrag_t::frag_type frag_type; + + frag_type val_frags[kNumFrags] = {frag_type(0)}; + + METAL_FUNC MMATile() thread {} + + METAL_FUNC constexpr void clear() { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kNumFrags; ++i) { + val_frags[i] = frag_type(0); + } + } + + METAL_FUNC constexpr thread frag_type& frag_at(const short i, const short j) { + return val_frags[i * kTileCols + j]; + } + + METAL_FUNC constexpr const thread frag_type& frag_at( + const short i, + const short j) const { + return val_frags[i * kTileCols + j]; + } + + METAL_FUNC mat_type mat_at(const short i, const short j) { + mat_type val_mat; + STEEL_PRAGMA_UNROLL + for (short ii = 0; ii < kElemsPerFrag; ++ii) { + val_mat.thread_elements()[ii] = frag_at(i, j)[ii]; + } + return val_mat; + } + + METAL_FUNC thread elem_type* elems() { + return reinterpret_cast(val_frags); + } + + METAL_FUNC const thread elem_type* elems() const { + return reinterpret_cast(val_frags); + } + + template + METAL_FUNC void load(const threadgroup U* src) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + MMAFrag_t::load( + frag_at(i, j), + &( + src[(i * kFragRows) * w_x * str_x + + (j * kFragCols) * w_y * str_y]), + Int{}, + Int{}); + } + } + } + + template + METAL_FUNC void store(threadgroup U* dst) const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < 
kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + MMAFrag_t::store( + frag_at(i, j), + &( + dst[(i * kFragRows) * w_x * str_x + + (j * kFragCols) * w_y * str_y]), + Int{}, + Int{}); + } + } + } + + template + METAL_FUNC void load(const device U* src, const int ld) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + MMAFrag_t::load( + frag_at(i, j), + &(src[(i * kFragRows) * w_x * ld + (j * kFragCols) * w_y]), + ld, + Int<1>{}); + } + } + } + + template + METAL_FUNC void store(device U* dst, const int ld) const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + MMAFrag_t::store( + frag_at(i, j), + &(dst[(i * kFragRows) * w_x * ld + (j * kFragCols) * w_y]), + ld, + Int<1>{}); + } + } + } + + template + METAL_FUNC void + load_safe(const device U* src, const int ld, const short2 src_tile_dims) { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kTileCols; ++j) { + MMAFrag_t::load_safe( + frag_at(i, j), + src, + ld, + Int<1>{}, + src_tile_dims.y, + src_tile_dims.x, + (i * kFragRows) * w_x, + (j * kFragCols) * w_y); + } + } + } + + template + METAL_FUNC void + store_safe(device U* dst, const int ld, const short2 dst_tile_dims) const { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kTileCols; ++j) { + MMAFrag_t::store_safe( + frag_at(i, j), + dst, + ld, + Int<1>{}, + dst_tile_dims.y, + dst_tile_dims.x, + (i * kFragRows) * w_x, + (j * kFragCols) * w_y); + } + } + } +}; + +template +METAL_FUNC void tile_matmad( + thread MMATile& D, + thread MMATile& A, + thread MMATile& B, + thread MMATile& C) { + STEEL_PRAGMA_UNROLL + for (short m = 0; m < M; ++m) { + STEEL_PRAGMA_UNROLL + for (short n = 0; n < N; ++n) { + short n_serp = (m % 2) ? (N - 1 - n) : n; + STEEL_PRAGMA_UNROLL + for (short k = 0; k < K; ++k) { + MMATile::MMAFrag_t::mma( + D.frag_at(m, n_serp), + A.frag_at(m, k), + B.frag_at(k, n_serp), + C.frag_at(m, n_serp)); + } + } + } +} + template < typename T, typename U, @@ -33,39 +375,38 @@ template < typename AccumType = float, typename Epilogue = TransformNone> struct BlockMMA { + // MMAFrag size + STEEL_CONST short kFragSize = 8; + using MMAFrag_acc_t = BaseMMAFrag; + // Warp tile simdgroup matrix strides along M - STEEL_CONST short TM_stride = 8 * WM; + STEEL_CONST short TM_stride = kFragSize * WM; // Warp tile simdgroup matrix strides along M - STEEL_CONST short TN_stride = 8 * WN; + STEEL_CONST short TN_stride = kFragSize * WN; // Warp tile size along M STEEL_CONST short TM = BM / TM_stride; // Warp tile size along N STEEL_CONST short TN = BN / TN_stride; - // Strides of A, B along reduction axis - STEEL_CONST short simd_stride_a = { - transpose_a ? TM_stride : TM_stride * lda_tgp}; - STEEL_CONST short simd_stride_b = { - transpose_b ? TN_stride * ldb_tgp : TN_stride}; + // Threadgroup A strides + STEEL_CONST short A_str_m = transpose_a ? 1 : lda_tgp; // M + STEEL_CONST short A_str_k = transpose_a ? lda_tgp : 1; // K - // Jump between elements - STEEL_CONST short jump_a = {transpose_a ? lda_tgp : 1}; - STEEL_CONST short jump_b = {transpose_b ? ldb_tgp : 1}; + // Threadgroup B strides + STEEL_CONST short B_str_k = transpose_b ? 1 : ldb_tgp; // K + STEEL_CONST short B_str_n = transpose_b ? ldb_tgp : 1; // N - STEEL_CONST short tile_stride_a = {transpose_a ? 
8 * lda_tgp : 8}; - STEEL_CONST short tile_stride_b = {transpose_b ? 8 : 8 * ldb_tgp}; + // Threadgroup strides along K + STEEL_CONST short tile_stride_a = kFragSize * A_str_k; + STEEL_CONST short tile_stride_b = kFragSize * B_str_k; // Simdgroup matrices - simdgroup_matrix Asimd[TM]; - simdgroup_matrix Bsimd[TN]; - simdgroup_matrix results[TM * TN] = { - simdgroup_matrix(0)}; + MMATile Atile; + MMATile Btile; + MMATile Ctile; // Offsets within threadgroup - const short tm; - const short tn; - short sm; short sn; @@ -75,18 +416,21 @@ struct BlockMMA { /* Constructor */ METAL_FUNC BlockMMA( ushort simd_group_id [[simdgroup_index_in_threadgroup]], - ushort simd_lane_id [[thread_index_in_simdgroup]]) - : tm(8 * (simd_group_id / WN)), tn(8 * (simd_group_id % WN)) { + ushort simd_lane_id [[thread_index_in_simdgroup]]) { // Determine thread position in simdgroup matrix - short qid = simd_lane_id / 4; - sm = (qid & 4) + (simd_lane_id / 2) % 4; - sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2; + short tm = kFragSize * (simd_group_id / WN); + short tn = kFragSize * (simd_group_id % WN); + + short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id); + sm = simd_coord.y; + sn = simd_coord.x; // Determine thread and simdgroup offset - As_offset = - transpose_a ? ((sn)*lda_tgp + (tm + sm)) : ((sn) + (tm + sm) * lda_tgp); - Bs_offset = - transpose_b ? ((tn + sn) * ldb_tgp + (sm)) : ((sm)*ldb_tgp + (tn + sn)); + As_offset = (tm + sm) * A_str_m + (sn)*A_str_k; // M, K + Bs_offset = (sm)*B_str_k + (tn + sn) * B_str_n; // K, N + + sm += tm; + sn += tn; } /* (BM, BK) X (BK, BN) multiply accumulate function */ @@ -95,47 +439,20 @@ struct BlockMMA { As += As_offset; Bs += Bs_offset; - // Iterate over BK in blocks of 8 + // Iterate over BK in blocks of kFragSize STEEL_PRAGMA_UNROLL - for (short kk = 0; kk < BK; kk += 8) { + for (short kk = 0; kk < BK; kk += kFragSize) { simdgroup_barrier(mem_flags::mem_none); - // Load elements from threadgroup A as simdgroup matrices - STEEL_PRAGMA_UNROLL - for (short i = 0; i < TM; i++) { - Asimd[i].thread_elements()[0] = - static_cast(As[i * simd_stride_a + 0]); - Asimd[i].thread_elements()[1] = - static_cast(As[i * simd_stride_a + jump_a]); - } + Atile.template load(As); simdgroup_barrier(mem_flags::mem_none); - // Load elements from threadgroup B as simdgroup matrices - STEEL_PRAGMA_UNROLL - for (short j = 0; j < TN; j++) { - Bsimd[j].thread_elements()[0] = - static_cast(Bs[j * simd_stride_b + 0]); - Bsimd[j].thread_elements()[1] = - static_cast(Bs[j * simd_stride_b + jump_b]); - } + Btile.template load(Bs); simdgroup_barrier(mem_flags::mem_none); - // Multiply and accumulate into result simdgroup matrices - STEEL_PRAGMA_UNROLL - for (short i = 0; i < TM; i++) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < TN; j++) { - short j_serp = (i % 2) ? 
(TN - 1 - j) : j; - - simdgroup_multiply_accumulate( - results[i * TN + j_serp], - Asimd[i], - Bsimd[j_serp], - results[i * TN + j_serp]); - } - } + tile_matmad(Ctile, Atile, Btile, Ctile); // Progress to next simdgroup tile As += tile_stride_a; @@ -144,58 +461,35 @@ struct BlockMMA { } /* Store results from simdgroup_matrix results into device memory */ - METAL_FUNC void store_result(device U* D, const int ldd) const { - // Adjust for simdgroup and thread location - D += (sm + tm) * ldd + tn + sn; - - // Loop over all simdgroup tiles + METAL_FUNC void store_result(device U* D, const int ldd) { + // Apply epilogue STEEL_PRAGMA_UNROLL - for (short i = 0; i < TM; i++) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < TN; j++) { - // Get accumulated result and associated offset in C - thread const auto& accum = results[i * TN + j].thread_elements(); - int offset = (i * TM_stride) * ldd + (j * TN_stride); - - // Apply epilogue - U outs[2] = {Epilogue::apply(accum[0]), Epilogue::apply(accum[1])}; - - // Write out D - D[offset] = outs[0]; - D[offset + 1] = outs[1]; - } + for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) { + Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]); } + + // Adjust for simdgroup and thread location + D += sm * ldd + sn; + + Ctile.template store(D, ldd); } METAL_FUNC void - store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) const { + store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) { + // Apply epilogue + STEEL_PRAGMA_UNROLL + for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) { + Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]); + } + // Adjust for simdgroup and thread location - D += (sm + tm) * ldd + (tn + sn); - dst_tile_dims -= short2(tn + sn, sm + tm); + D += sm * ldd + sn; + dst_tile_dims -= short2(sn, sm); if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) return; - STEEL_PRAGMA_UNROLL - for (int i = 0; i < TM; i++) { - if (i * TM_stride < dst_tile_dims.y) { - STEEL_PRAGMA_UNROLL - for (int j = 0; j < TN; j++) { - // Get accumulated result and associated offset in C - thread const auto& accum = results[i * TN + j].thread_elements(); - int offset = (i * TM_stride) * ldd + (j * TN_stride); - - // Apply epilogue and output C - if (j * TN_stride < dst_tile_dims.x) { - D[offset] = Epilogue::apply(accum[0]); - } - - if (j * TN_stride + 1 < dst_tile_dims.x) { - D[offset + 1] = Epilogue::apply(accum[1]); - } - } - } - } + Ctile.template store_safe(D, ldd, dst_tile_dims); } /* Apply epilogue */ @@ -203,16 +497,8 @@ struct BlockMMA { METAL_FUNC void apply_epilogue(thread const UnaryEpilogue& epilogue_op) { // Loop over all simdgroup tiles STEEL_PRAGMA_UNROLL - for (short i = 0; i < TM; i++) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < TN; j++) { - // Get accumulated result and associated offset in C - thread auto& accum = results[i * TN + j].thread_elements(); - - // Apply epilogue - accum[0] = epilogue_op.apply(accum[0]); - accum[1] = epilogue_op.apply(accum[1]); - } + for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) { + Ctile.elems()[i] = epilogue_op.apply(Ctile.elems()[i]); } } @@ -224,7 +510,7 @@ struct BlockMMA { const int fdc, thread const BinaryEpilogue& epilogue_op) { // Adjust for simdgroup and thread location - C += (sm + tm) * ldc + (tn + sn) * fdc; + C += (sm)*ldc + (sn)*fdc; // Loop over all simdgroup tiles STEEL_PRAGMA_UNROLL @@ -232,12 +518,14 @@ struct BlockMMA { STEEL_PRAGMA_UNROLL for (short j = 0; j < TN; j++) { // Get accumulated result and associated offset in C - thread auto& accum = 
results[i * TN + j].thread_elements(); + thread auto& accum = Ctile.frag_at(i, j); int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; // Apply epilogue - accum[0] = epilogue_op.apply(accum[0], C[offset_c]); - accum[1] = epilogue_op.apply(accum[1], C[offset_c + fdc]); + STEEL_PRAGMA_UNROLL + for (short k = 0; k < decltype(Ctile)::kElemsPerFrag; k++) { + accum[k] = epilogue_op.apply(accum[k], C[offset_c + k * fdc]); + } } } } @@ -251,8 +539,8 @@ struct BlockMMA { short2 dst_tile_dims, thread const BinaryEpilogue& epilogue_op) { // Adjust for simdgroup and thread location - C += (sm + tm) * ldc + (tn + sn) * fdc; - dst_tile_dims -= short2(tn + sn, sm + tm); + C += (sm)*ldc + (sn)*fdc; + dst_tile_dims -= short2(sn, sm); if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) return; @@ -263,22 +551,26 @@ struct BlockMMA { STEEL_PRAGMA_UNROLL for (short j = 0; j < TN; j++) { // Get accumulated result and associated offset in C - thread auto& accum = results[i * TN + j].thread_elements(); + thread auto& accum = Ctile.frag_at(i, j); int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; - // Read C - U c_elems[2] = {0}; + constexpr short kelems = decltype(Ctile)::kElemsPerFrag; - if ((j * TN_stride + 1) < dst_tile_dims.x) { - c_elems[0] = C[offset_c]; - c_elems[1] = C[offset_c + fdc]; - } else if ((j * TN_stride) < dst_tile_dims.x) { - c_elems[0] = C[offset_c]; + // Read C + U c_elems[kelems] = {0}; + + STEEL_PRAGMA_UNROLL + for (short k = 0; k < kelems; k++) { + if ((j * TN_stride + k) < dst_tile_dims.x) { + c_elems[k] = C[offset_c + k * fdc]; + } } // Apply epilogue - accum[0] = epilogue_op.apply(accum[0], c_elems[0]); - accum[1] = epilogue_op.apply(accum[1], c_elems[1]); + STEEL_PRAGMA_UNROLL + for (short k = 0; k < kelems; k++) { + accum[k] = epilogue_op.apply(accum[k], c_elems[k]); + } } } } @@ -292,8 +584,10 @@ struct BlockMMA { const int fdc, thread const Epilogue& epilogue_op) const { // Adjust for simdgroup and thread location - C += (sm + tm) * ldc + (tn + sn) * fdc; - D += (sm + tm) * ldd + tn + sn; + C += (sm)*ldc + (sn)*fdc; + D += (sm)*ldd + sn; + + constexpr short kelems = decltype(Ctile)::kElemsPerFrag; // Loop over all simdgroup tiles STEEL_PRAGMA_UNROLL @@ -301,18 +595,15 @@ struct BlockMMA { STEEL_PRAGMA_UNROLL for (short j = 0; j < TN; j++) { // Get accumulated result and associated offset in C - thread const auto& accum = results[i * TN + j].thread_elements(); + thread const auto& accum = Ctile.frag_at(i, j); int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; int offset_d = (i * TM_stride) * ldd + (j * TN_stride); // Apply epilogue - U outs[2] = { - epilogue_op.apply(accum[0], C[offset_c]), - epilogue_op.apply(accum[1], C[offset_c + fdc])}; - - // Write out D - D[offset_d] = outs[0]; - D[offset_d + 1] = outs[1]; + STEEL_PRAGMA_UNROLL + for (short k = 0; k < kelems; k++) { + D[offset_d + k] = epilogue_op.apply(accum[k], C[offset_c + k * fdc]); + } } } } @@ -326,30 +617,32 @@ struct BlockMMA { short2 dst_tile_dims, thread const Epilogue& epilogue_op) const { // Adjust for simdgroup and thread location - C += (sm + tm) * ldc + (tn + sn) * fdc; - D += (sm + tm) * ldd + tn + sn; - dst_tile_dims -= short2(tn + sn, sm + tm); + C += (sm)*ldc + (sn)*fdc; + D += (sm)*ldd + sn; + dst_tile_dims -= short2(sn, sm); if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) return; + constexpr short kelems = decltype(Ctile)::kElemsPerFrag; + STEEL_PRAGMA_UNROLL for (int i = 0; i < TM; i++) { if (i * TM_stride < dst_tile_dims.y) { STEEL_PRAGMA_UNROLL for (int j = 0; j < TN; 
j++) { // Get accumulated result and associated offset in C - thread const auto& accum = results[i * TN + j].thread_elements(); + thread const auto& accum = Ctile.frag_at(i, j); int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; int offset_d = (i * TM_stride) * ldd + (j * TN_stride); - // Apply epilogue and output C - if (j * TN_stride < dst_tile_dims.x) { - D[offset_d] = epilogue_op.apply(accum[0], C[offset_c]); - } - - if (j * TN_stride + 1 < dst_tile_dims.x) { - D[offset_d + 1] = epilogue_op.apply(accum[1], C[offset_c + fdc]); + // Apply epilogue + STEEL_PRAGMA_UNROLL + for (short k = 0; k < kelems; k++) { + if ((j * TN_stride + k) < dst_tile_dims.x) { + D[offset_d + k] = + epilogue_op.apply(accum[k], C[offset_c + k * fdc]); + } } } } diff --git a/mlx/backend/metal/kernels/steel/utils/integral_constant.h b/mlx/backend/metal/kernels/steel/utils/integral_constant.h new file mode 100644 index 0000000000..b616acc676 --- /dev/null +++ b/mlx/backend/metal/kernels/steel/utils/integral_constant.h @@ -0,0 +1,96 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include +#include "mlx/backend/metal/kernels/steel/utils/type_traits.h" + +#pragma METAL internals : enable + +namespace mlx { +namespace steel { + +/////////////////////////////////////////////////////////////////////////////// +// Integral constant with casting +/////////////////////////////////////////////////////////////////////////////// + +template +struct integral_constant { + static constexpr constant T value = v; + using value_type = T; + using type = integral_constant; + + METAL_FUNC constexpr operator value_type() const noexcept { + return value; + } + + // METAL_FUNC constexpr value_type operator()() const noexcept { + // return value; + // } +}; + +template +using bool_constant = integral_constant; +using true_type = bool_constant; +using false_type = bool_constant; + +template +struct is_integral : bool_constant::value> {}; + +template +struct is_integral> + : bool_constant::value> {}; + +template +constexpr constant bool is_integral_v = is_integral::value; + +template +using Int = integral_constant; + +/////////////////////////////////////////////////////////////////////////////// +// Binary Operators on Integral constants +/////////////////////////////////////////////////////////////////////////////// + +#define integral_const_binop(__op__, __operator__) \ + template \ + METAL_FUNC constexpr auto __operator__( \ + integral_constant, integral_constant) { \ + constexpr auto res = tv __op__ uv; \ + return integral_constant{}; \ + } + +integral_const_binop(+, operator+); +integral_const_binop(-, operator-); +integral_const_binop(*, operator*); +integral_const_binop(/, operator/); + +integral_const_binop(==, operator==); +integral_const_binop(!=, operator!=); +integral_const_binop(<, operator<); +integral_const_binop(>, operator>); +integral_const_binop(<=, operator<=); +integral_const_binop(>=, operator>=); + +integral_const_binop(&&, operator&&); +integral_const_binop(||, operator||); + +#undef integral_const_binop + +/////////////////////////////////////////////////////////////////////////////// +// Reduction operators +/////////////////////////////////////////////////////////////////////////////// + +template +METAL_FUNC constexpr T sum(T x) { + return x; +} + +template +METAL_FUNC constexpr auto sum(T x, Us... 
us) { + return x + sum(us...); +} + +} // namespace steel +} // namespace mlx + +#pragma METAL internals : disable \ No newline at end of file diff --git a/mlx/backend/metal/kernels/steel/utils/type_traits.h b/mlx/backend/metal/kernels/steel/utils/type_traits.h new file mode 100644 index 0000000000..f004dc836a --- /dev/null +++ b/mlx/backend/metal/kernels/steel/utils/type_traits.h @@ -0,0 +1,55 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include + +#pragma METAL internals : enable + +namespace metal { + +template +struct is_empty : metal::bool_constant<__is_empty(T)> {}; + +#ifdef __cpp_variable_templates +template +constexpr constant bool is_empty_v = is_empty::value; +#endif + +template +struct make_void { + typedef void type; +}; + +template +using void_t = typename make_void::type; + +template +struct is_static : metal::bool_constant>::value> {}; + +template +struct pointer_element {}; + +template +struct pointer_element { + using type = remove_cv_t; +}; +template +struct pointer_element { + using type = remove_cv_t; +}; +template +struct pointer_element { + using type = remove_cv_t; +}; +template +struct pointer_element { + using type = remove_cv_t; +}; + +template +using pointer_element_t = typename pointer_element>::type; + +} // namespace metal + +#pragma METAL internals : disable \ No newline at end of file diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp index 3d8973fae1..0614fadc7b 100644 --- a/mlx/backend/metal/matmul.cpp +++ b/mlx/backend/metal/matmul.cpp @@ -88,6 +88,83 @@ inline auto collapse_batches(const array& a, const array& b, const array& c) { // Steel matmul fallback /////////////////////////////////////////////////////////////////////////////// +#define GEMM_TPARAM_MACRO(devc) \ + if (devc == 'g') { /* Small device */ \ + if (!transpose_a && transpose_b) { /* nt */ \ + bm = 64; \ + bn = 32; \ + bk = 32; \ + wm = 2; \ + wn = 2; \ + } else if (out.dtype() != float32) { /* half and bfloat */ \ + bm = 64; \ + bn = 64; \ + bk = 16; \ + wm = 1; \ + wn = 2; \ + } \ + } else if (devc == 'd') { /* Large device */ \ + if ((size_t)batch_size_out * M * N >= 1ul << 20) { /* large matmul */ \ + if (out.dtype() != float32) { /* half and bfloat */ \ + if (2 * std::max(M, N) > K) { /* Reasonable K */ \ + bm = 64; \ + bn = 64; \ + bk = 16; \ + wm = 1; \ + wn = 2; \ + } else if (!transpose_a && transpose_b) { /* nt with large k */ \ + bm = 64; \ + bn = 32; \ + bk = 32; \ + wm = 2; \ + wn = 2; \ + } else { /* nn with large K */ \ + bm = 32; \ + bn = 64; \ + bk = 16; \ + wm = 1; \ + wn = 2; \ + } \ + } /* float takes default */ \ + } else { /* smaller matmul */ \ + if (out.dtype() != float32) { /* half and bfloat */ \ + if (!transpose_a && transpose_b) { /* nt */ \ + bm = 64; \ + bn = 32; \ + bk = 32; \ + wm = 2; \ + wn = 2; \ + } else { /* nn */ \ + bm = 64; \ + bn = 64; \ + bk = 16; \ + wm = 1; \ + wn = 2; \ + } \ + } else { /* floats */ \ + if (!transpose_a && transpose_b) { /* nt */ \ + bm = 32; \ + bn = 64; \ + bk = 16; \ + wm = 1; \ + wn = 2; \ + } else { /* nn */ \ + bm = 64; \ + bn = 32; \ + bk = 32; \ + wm = 2; \ + wn = 2; \ + } \ + } \ + } \ + } else { /* Medium device */ \ + bm = 64; \ + bn = 64; \ + bk = 16; \ + wm = 2; \ + wn = 2; \ + } + void steel_matmul_regular( const Stream& s, metal::Device& d, @@ -112,19 +189,11 @@ void steel_matmul_regular( using namespace mlx::steel; // Determine dispatch kernel - int bm = 32, bn = 32, bk = 16; + int bm = 64, bn = 64, bk = 16; int wm = 2, wn = 2; - if ((size_t)batch_size_out * M * N >= 1ul << 20) { 
- if (!transpose_a && transpose_b) { - bm = 64; - bn = (out.dtype() == float32) ? 64 : 32; - bk = (out.dtype() == float32) ? 16 : 32; - } else { - bm = 64; - bn = 64; - } - } + char devc = d.get_architecture().back(); + GEMM_TPARAM_MACRO(devc) // Prepare kernel name std::ostringstream kname; @@ -903,19 +972,11 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) { // Regular addmm dispatch // Determine dispatch kernel - int bm = 32, bn = 32, bk = 16; + int bm = 64, bn = 64, bk = 16; int wm = 2, wn = 2; - if ((size_t)batch_size_out * M * N >= 1ul << 20) { - if (!transpose_a && transpose_b) { - bm = 64; - bn = (out.dtype() == float32) ? 64 : 32; - bk = (out.dtype() == float32) ? 16 : 32; - } else { - bm = 64; - bn = 64; - } - } + char devc = d.get_architecture().back(); + GEMM_TPARAM_MACRO(devc) // Prepare kernel name std::ostringstream kname; @@ -1667,19 +1728,11 @@ void GatherMM::eval_gpu(const std::vector& inputs, array& out) { // Regular kernel dispatch // Determine dispatch kernel - int bm = 32, bn = 32, bk = 16; + int bm = 64, bn = 64, bk = 16; int wm = 2, wn = 2; - if ((size_t)batch_size_out * M * N >= 1ul << 20) { - if (!transpose_a && transpose_b) { - bm = 64; - bn = (out.dtype() == float32) ? 64 : 32; - bk = (out.dtype() == float32) ? 16 : 32; - } else { - bm = 64; - bn = 64; - } - } + char devc = d.get_architecture().back(); + GEMM_TPARAM_MACRO(devc) // Prepare kernel name std::ostringstream kname;
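
The lane-to-fragment mapping introduced by BaseMMAFrag::get_coord above is easiest to see by tabulating it. Below is a small standalone C++ program, not Metal and not part of this patch, that reproduces the same index arithmetic on the host so the ownership pattern of an 8 x 8 fragment can be printed and checked; the 8 x 8 shape comes from the static_asserts in the patch (the diff text drops angle-bracket contents, so the full BaseMMAFrag<T, 8, 8> spelling is an inference), and each lane holds a 1 x 2 sliver (kElemRows = 1, kElemCols = 2).

// Host-side check of the get_coord arithmetic from BaseMMAFrag in mma.h.
// Each of the 32 simdgroup lanes owns columns fn and fn+1 of row fm.
#include <cstdio>

int main() {
  for (int lane = 0; lane < 32; ++lane) {
    int qid = lane / 4;
    int fm = (qid & 4) + ((lane / 2) % 4);   // row of the 1x2 sliver
    int fn = (qid & 2) * 2 + (lane % 2) * 2; // starting column (lane holds fn, fn+1)
    std::printf("lane %2d -> row %d, cols %d-%d\n", lane, fm, fn, fn + 1);
  }
  return 0;
}

Running it shows the standard Apple simdgroup 8 x 8 layout that load/store/load_safe/store_safe in the patch index against with str_x/str_y strides.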
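
GEMM_TPARAM_MACRO packs the device-dependent tile tuning into one macro, which makes its decision tree easy to misread. The sketch below restates the same branches as an ordinary C++ function; GemmTile and pick_gemm_tile are hypothetical names used only for illustration. devc is the last character of Device::get_architecture(), with 'g' treated as a small device, 'd' as a large one, and anything else as medium, exactly as the macro's comments say.

// Illustrative restatement of GEMM_TPARAM_MACRO (not part of the patch).
#include <algorithm>
#include <cstddef>

struct GemmTile {
  int bm, bn, bk, wm, wn;
};

inline GemmTile pick_gemm_tile(
    char devc,        // 'g' ~ small device, 'd' ~ large device, else medium
    bool transpose_a,
    bool transpose_b,
    bool is_float32,  // out.dtype() == float32 in the macro
    size_t batch_size_out,
    int M, int N, int K) {
  GemmTile t{64, 64, 16, 2, 2}; // defaults set at the call sites before the macro
  bool nt = !transpose_a && transpose_b;
  if (devc == 'g') { // small device
    if (nt) {
      t = {64, 32, 32, 2, 2};
    } else if (!is_float32) { // half and bfloat
      t = {64, 64, 16, 1, 2};
    }
  } else if (devc == 'd') { // large device
    if (batch_size_out * M * N >= (1ul << 20)) { // large matmul
      if (!is_float32) {
        if (2 * std::max(M, N) > K) { // reasonable K
          t = {64, 64, 16, 1, 2};
        } else if (nt) { // nt with large K
          t = {64, 32, 32, 2, 2};
        } else { // nn with large K
          t = {32, 64, 16, 1, 2};
        }
      } // float32 keeps the defaults
    } else { // smaller matmul
      if (!is_float32) {
        t = nt ? GemmTile{64, 32, 32, 2, 2} : GemmTile{64, 64, 16, 1, 2};
      } else {
        t = nt ? GemmTile{32, 64, 16, 1, 2} : GemmTile{64, 32, 32, 2, 2};
      }
    }
  } else { // medium device
    t = {64, 64, 16, 2, 2};
  }
  return t;
}

The {64, 64, 16, 2, 2} defaults mirror the values the three call sites (steel_matmul_regular, AddMM::eval_gpu, GatherMM::eval_gpu) now assign before invoking the macro, so any branch that falls through leaves them untouched.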