mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-22 19:28:14 +08:00 
			
		
		
		
	Gemm update (#1518)
This commit is contained in:
		| @@ -181,6 +181,7 @@ Device::Device() { | |||||||
|   auto pool = new_scoped_memory_pool(); |   auto pool = new_scoped_memory_pool(); | ||||||
|   device_ = load_device(); |   device_ = load_device(); | ||||||
|   library_map_ = {{"mlx", load_library(device_)}}; |   library_map_ = {{"mlx", load_library(device_)}}; | ||||||
|  |   arch_ = std::string(device_->architecture()->name()->utf8String()); | ||||||
| } | } | ||||||
|  |  | ||||||
| Device::~Device() { | Device::~Device() { | ||||||
|   | |||||||
| @@ -136,6 +136,10 @@ class Device { | |||||||
|     return device_; |     return device_; | ||||||
|   }; |   }; | ||||||
|  |  | ||||||
|  |   const std::string& get_architecture() { | ||||||
|  |     return arch_; | ||||||
|  |   } | ||||||
|  |  | ||||||
|   void new_queue(int index); |   void new_queue(int index); | ||||||
|   MTL::CommandBuffer* get_command_buffer(int index); |   MTL::CommandBuffer* get_command_buffer(int index); | ||||||
|   int get_command_buffer_ops(int index); |   int get_command_buffer_ops(int index); | ||||||
| @@ -228,6 +232,7 @@ class Device { | |||||||
|   std::shared_mutex library_mtx_; |   std::shared_mutex library_mtx_; | ||||||
|   std::unordered_map<std::string, MTL::Library*> library_map_; |   std::unordered_map<std::string, MTL::Library*> library_map_; | ||||||
|   const MTL::ResidencySet* residency_set_{nullptr}; |   const MTL::ResidencySet* residency_set_{nullptr}; | ||||||
|  |   std::string arch_; | ||||||
| }; | }; | ||||||
|  |  | ||||||
| Device& device(mlx::core::Device); | Device& device(mlx::core::Device); | ||||||
|   | |||||||
| @@ -50,7 +50,9 @@ set(STEEL_HEADERS | |||||||
|     steel/gemm/transforms.h |     steel/gemm/transforms.h | ||||||
|     steel/gemm/kernels/steel_gemm_fused.h |     steel/gemm/kernels/steel_gemm_fused.h | ||||||
|     steel/gemm/kernels/steel_gemm_masked.h |     steel/gemm/kernels/steel_gemm_masked.h | ||||||
|     steel/gemm/kernels/steel_gemm_splitk.h) |     steel/gemm/kernels/steel_gemm_splitk.h | ||||||
|  |     steel/utils/type_traits.h | ||||||
|  |     steel/utils/integral_constant.h) | ||||||
|  |  | ||||||
| if(NOT MLX_METAL_JIT) | if(NOT MLX_METAL_JIT) | ||||||
|   build_kernel(arange arange.h) |   build_kernel(arange arange.h) | ||||||
|   | |||||||
| @@ -142,8 +142,8 @@ implicit_gemm_conv_2d_general( | |||||||
|   // Store results to device memory |   // Store results to device memory | ||||||
|   { |   { | ||||||
|     // Adjust for simdgroup and thread location |     // Adjust for simdgroup and thread location | ||||||
|     int offset_m = c_row + mma_op.sm + mma_op.tm; |     int offset_m = c_row + mma_op.sm; | ||||||
|     int offset_n = c_col + mma_op.sn + mma_op.tn; |     int offset_n = c_col + mma_op.sn; | ||||||
|     C += offset_n; |     C += offset_n; | ||||||
|  |  | ||||||
|     if (offset_n >= gemm_params->N) |     if (offset_n >= gemm_params->N) | ||||||
| @@ -169,17 +169,17 @@ implicit_gemm_conv_2d_general( | |||||||
|         STEEL_PRAGMA_UNROLL |         STEEL_PRAGMA_UNROLL | ||||||
|         for (int j = 0; j < mma_t::TN; j++) { |         for (int j = 0; j < mma_t::TN; j++) { | ||||||
|           // Get accumulated result and associated offset in C |           // Get accumulated result and associated offset in C | ||||||
|           thread const auto& accum = |           thread const auto& accum = mma_op.Ctile.frag_at(i, j); | ||||||
|               mma_op.results[i * mma_t::TN + j].thread_elements(); |  | ||||||
|           int offset = offset_cm + (j * mma_t::TN_stride); |           int offset = offset_cm + (j * mma_t::TN_stride); | ||||||
|  |  | ||||||
|           // Apply epilogue and output C |           constexpr short kelems = decltype(mma_op.Ctile)::kElemsPerFrag; | ||||||
|           if (j * mma_t::TN_stride < diff) { |  | ||||||
|             C[offset] = Epilogue::apply(accum[0]); |  | ||||||
|           } |  | ||||||
|  |  | ||||||
|           if (j * mma_t::TN_stride + 1 < diff) { |           // Apply epilogue and output C | ||||||
|             C[offset + 1] = Epilogue::apply(accum[1]); |           STEEL_PRAGMA_UNROLL | ||||||
|  |           for (short k = 0; k < kelems; k++) { | ||||||
|  |             if ((j * mma_t::TN_stride + k) < diff) { | ||||||
|  |               C[offset + k] = Epilogue::apply(accum[k]); | ||||||
|  |             } | ||||||
|           } |           } | ||||||
|         } |         } | ||||||
|       } |       } | ||||||
|   | |||||||
| @@ -36,11 +36,11 @@ | |||||||
|     instantiate_gemm(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) |     instantiate_gemm(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) | ||||||
|  |  | ||||||
| #define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \ | #define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \ | ||||||
|     instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \ |  | ||||||
|     instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \ |     instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \ | ||||||
|  |     instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 1, 2) \ | ||||||
|     instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \ |     instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \ | ||||||
|     instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 16, 2, 2) \ |     instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 1, 2) \ | ||||||
|     instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 2, 2) |     instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) | ||||||
|  |  | ||||||
| instantiate_gemm_shapes_helper(float16, half, float16, half); | instantiate_gemm_shapes_helper(float16, half, float16, half); | ||||||
| instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t); | instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t); | ||||||
|   | |||||||
| @@ -8,6 +8,7 @@ | |||||||
|  |  | ||||||
| #include "mlx/backend/metal/kernels/steel/defines.h" | #include "mlx/backend/metal/kernels/steel/defines.h" | ||||||
| #include "mlx/backend/metal/kernels/steel/gemm/transforms.h" | #include "mlx/backend/metal/kernels/steel/gemm/transforms.h" | ||||||
|  | #include "mlx/backend/metal/kernels/steel/utils/integral_constant.h" | ||||||
|  |  | ||||||
| using namespace metal; | using namespace metal; | ||||||
|  |  | ||||||
| @@ -18,6 +19,347 @@ using namespace metal; | |||||||
| namespace mlx { | namespace mlx { | ||||||
| namespace steel { | namespace steel { | ||||||
|  |  | ||||||
|  | template <typename T, int kFragRows_, int kFragCols_> | ||||||
|  | struct BaseMMAFrag { | ||||||
|  |   static_assert( | ||||||
|  |       kFragRows_ == 8, | ||||||
|  |       "Only 8 x 8 fragment matrices are currently supported"); | ||||||
|  |   static_assert( | ||||||
|  |       kFragCols_ == 8, | ||||||
|  |       "Only 8 x 8 fragment matrices are currently supported"); | ||||||
|  | }; | ||||||
|  |  | ||||||
|  | template <typename T> | ||||||
|  | struct BaseMMAFrag<T, 8, 8> { | ||||||
|  |   STEEL_CONST int kFragRows = 8; | ||||||
|  |   STEEL_CONST int kFragCols = 8; | ||||||
|  |  | ||||||
|  |   STEEL_CONST int kElemsPerFrag = (kFragRows * kFragCols) / 32; | ||||||
|  |  | ||||||
|  |   STEEL_CONST int kElemRows = 1; | ||||||
|  |   STEEL_CONST int kElemCols = 2; | ||||||
|  |  | ||||||
|  |   static_assert( | ||||||
|  |       kElemRows * kElemCols == kElemsPerFrag, | ||||||
|  |       "MMAFrag shape is not consistent with MMAFrag size"); | ||||||
|  |  | ||||||
|  |   typedef metal::simdgroup_matrix<T, kFragRows, kFragCols> mat_type; | ||||||
|  |   typedef metal::vec<T, kElemsPerFrag> frag_type; | ||||||
|  |  | ||||||
|  |   METAL_FUNC static constexpr short2 get_coord(ushort simd_lane_id | ||||||
|  |                                                [[thread_index_in_simdgroup]]) { | ||||||
|  |     const short qid = simd_lane_id / 4; | ||||||
|  |     const short fm = (qid & 4) + ((simd_lane_id / 2) % 4); | ||||||
|  |     const short fn = (qid & 2) * 2 + (simd_lane_id % 2) * 2; | ||||||
|  |     return short2{fn, fm}; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   template <typename SrcPtrType, typename StrX, typename StrY> | ||||||
|  |   METAL_FUNC static constexpr void | ||||||
|  |   load(thread frag_type& dst, SrcPtrType src, StrX str_x, StrY str_y) { | ||||||
|  |     STEEL_PRAGMA_UNROLL | ||||||
|  |     for (short i = 0; i < kElemRows; i++) { | ||||||
|  |       STEEL_PRAGMA_UNROLL | ||||||
|  |       for (short j = 0; j < kElemCols; j++) { | ||||||
|  |         dst[i * kElemCols + j] = static_cast<T>(src[i * str_x + j * str_y]); | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   // Bounds-checked counterpart of load(): fills dst with kElemRows x kElemCols | ||||||
|  |   // elements read from src at (off_x + i) * str_x + (off_y + j) * str_y and | ||||||
|  |   // converted to T. Any element whose coordinate (off_x + i, off_y + j) is | ||||||
|  |   // not strictly below (lim_x, lim_y) is set to T(0) and src is not read. | ||||||
|  |   template < | ||||||
|  |       typename SrcPtrType, | ||||||
|  |       typename StrX, | ||||||
|  |       typename StrY, | ||||||
|  |       typename LimX, | ||||||
|  |       typename LimY, | ||||||
|  |       typename OffX, | ||||||
|  |       typename OffY> | ||||||
|  |   METAL_FUNC static constexpr void load_safe( | ||||||
|  |       thread frag_type& dst, | ||||||
|  |       SrcPtrType src, | ||||||
|  |       StrX str_x, | ||||||
|  |       StrY str_y, | ||||||
|  |       LimX lim_x, | ||||||
|  |       LimY lim_y, | ||||||
|  |       OffX off_x = Int<0>{}, | ||||||
|  |       OffY off_y = Int<0>{}) { | ||||||
|  |     STEEL_PRAGMA_UNROLL | ||||||
|  |     for (short i = 0; i < kElemRows; i++) { | ||||||
|  |       STEEL_PRAGMA_UNROLL | ||||||
|  |       for (short j = 0; j < kElemCols; j++) { | ||||||
|  |         if ((off_x + i) < lim_x && (off_y + j) < lim_y) { | ||||||
|  |           // The column term must use off_y so the address read matches the | ||||||
|  |           // coordinate that was just bounds-checked (same indexing as | ||||||
|  |           // store_safe). | ||||||
|  |           dst[i * kElemCols + j] = | ||||||
|  |               static_cast<T>(src[(off_x + i) * str_x + (off_y + j) * str_y]); | ||||||
|  |         } else { | ||||||
|  |           dst[i * kElemCols + j] = T(0); | ||||||
|  |         } | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   template <typename DstPtrType, typename StrX, typename StrY> | ||||||
|  |   METAL_FUNC static constexpr void | ||||||
|  |   store(const thread frag_type& src, DstPtrType dst, StrX str_x, StrY str_y) { | ||||||
|  |     using U = pointer_element_t<DstPtrType>; | ||||||
|  |  | ||||||
|  |     STEEL_PRAGMA_UNROLL | ||||||
|  |     for (short i = 0; i < kElemRows; i++) { | ||||||
|  |       STEEL_PRAGMA_UNROLL | ||||||
|  |       for (short j = 0; j < kElemCols; j++) { | ||||||
|  |         dst[i * str_x + j * str_y] = static_cast<U>(src[i * kElemCols + j]); | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   template < | ||||||
|  |       typename DstPtrType, | ||||||
|  |       typename StrX, | ||||||
|  |       typename StrY, | ||||||
|  |       typename LimX, | ||||||
|  |       typename LimY, | ||||||
|  |       typename OffX, | ||||||
|  |       typename OffY> | ||||||
|  |   METAL_FUNC static constexpr void store_safe( | ||||||
|  |       const thread frag_type& src, | ||||||
|  |       DstPtrType dst, | ||||||
|  |       StrX str_x, | ||||||
|  |       StrY str_y, | ||||||
|  |       LimX lim_x, | ||||||
|  |       LimY lim_y, | ||||||
|  |       OffX off_x = Int<0>{}, | ||||||
|  |       OffY off_y = Int<0>{}) { | ||||||
|  |     using U = pointer_element_t<DstPtrType>; | ||||||
|  |  | ||||||
|  |     STEEL_PRAGMA_UNROLL | ||||||
|  |     for (short i = 0; i < kElemRows; i++) { | ||||||
|  |       STEEL_PRAGMA_UNROLL | ||||||
|  |       for (short j = 0; j < kElemCols; j++) { | ||||||
|  |         if ((off_x + i) < lim_x && (off_y + j) < lim_y) { | ||||||
|  |           dst[(off_x + i) * str_x + (off_y + j) * str_y] = | ||||||
|  |               static_cast<U>(src[i * kElemCols + j]); | ||||||
|  |         } | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   METAL_FUNC static constexpr void mma( | ||||||
|  |       thread frag_type& D, | ||||||
|  |       thread frag_type& A, | ||||||
|  |       thread frag_type& B, | ||||||
|  |       thread frag_type& C) { | ||||||
|  |     mat_type D_mat; | ||||||
|  |     mat_type A_mat; | ||||||
|  |     mat_type B_mat; | ||||||
|  |     mat_type C_mat; | ||||||
|  |  | ||||||
|  |     reinterpret_cast<thread frag_type&>(A_mat.thread_elements()) = A; | ||||||
|  |     reinterpret_cast<thread frag_type&>(B_mat.thread_elements()) = B; | ||||||
|  |     reinterpret_cast<thread frag_type&>(C_mat.thread_elements()) = C; | ||||||
|  |  | ||||||
|  |     mma(D_mat, A_mat, B_mat, C_mat); | ||||||
|  |  | ||||||
|  |     D = reinterpret_cast<thread frag_type&>(D_mat.thread_elements()); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   METAL_FUNC static constexpr void mma( | ||||||
|  |       thread mat_type& D, | ||||||
|  |       thread mat_type& A, | ||||||
|  |       thread mat_type& B, | ||||||
|  |       thread mat_type& C) { | ||||||
|  |     simdgroup_multiply_accumulate(D, A, B, C); | ||||||
|  |   } | ||||||
|  | }; | ||||||
|  |  | ||||||
|  | template < | ||||||
|  |     typename T, | ||||||
|  |     int kTileRows_, | ||||||
|  |     int kTileCols_, | ||||||
|  |     class MMAFrag_ = BaseMMAFrag<T, 8, 8>> | ||||||
|  | struct MMATile { | ||||||
|  |   using MMAFrag_t = MMAFrag_; | ||||||
|  |   using elem_type = T; | ||||||
|  |   STEEL_CONST int kFragRows = MMAFrag_t::kFragRows; | ||||||
|  |   STEEL_CONST int kFragCols = MMAFrag_t::kFragCols; | ||||||
|  |   STEEL_CONST int kElemsPerFrag = MMAFrag_t::kElemsPerFrag; | ||||||
|  |  | ||||||
|  |   STEEL_CONST int kTileRows = kTileRows_; | ||||||
|  |   STEEL_CONST int kTileCols = kTileCols_; | ||||||
|  |  | ||||||
|  |   STEEL_CONST int kRows = kTileRows * kFragRows; | ||||||
|  |   STEEL_CONST int kCols = kTileCols * kFragCols; | ||||||
|  |  | ||||||
|  |   STEEL_CONST int kNumFrags = kTileRows * kTileCols; | ||||||
|  |   STEEL_CONST int kElemsPerTile = kNumFrags * kElemsPerFrag; | ||||||
|  |  | ||||||
|  |   typedef typename MMAFrag_t::mat_type mat_type; | ||||||
|  |   typedef typename MMAFrag_t::frag_type frag_type; | ||||||
|  |  | ||||||
|  |   frag_type val_frags[kNumFrags] = {frag_type(0)}; | ||||||
|  |  | ||||||
|  |   METAL_FUNC MMATile() thread {} | ||||||
|  |  | ||||||
|  |   METAL_FUNC constexpr void clear() { | ||||||
|  |     STEEL_PRAGMA_UNROLL | ||||||
|  |     for (short i = 0; i < kNumFrags; ++i) { | ||||||
|  |       val_frags[i] = frag_type(0); | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   METAL_FUNC constexpr thread frag_type& frag_at(const short i, const short j) { | ||||||
|  |     return val_frags[i * kTileCols + j]; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   METAL_FUNC constexpr const thread frag_type& frag_at( | ||||||
|  |       const short i, | ||||||
|  |       const short j) const { | ||||||
|  |     return val_frags[i * kTileCols + j]; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   METAL_FUNC mat_type mat_at(const short i, const short j) { | ||||||
|  |     mat_type val_mat; | ||||||
|  |     STEEL_PRAGMA_UNROLL | ||||||
|  |     for (short ii = 0; ii < kElemsPerFrag; ++ii) { | ||||||
|  |       val_mat.thread_elements()[ii] = frag_at(i, j)[ii]; | ||||||
|  |     } | ||||||
|  |     return val_mat; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   METAL_FUNC thread elem_type* elems() { | ||||||
|  |     return reinterpret_cast<thread elem_type*>(val_frags); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   METAL_FUNC const thread elem_type* elems() const { | ||||||
|  |     return reinterpret_cast<const thread elem_type*>(val_frags); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   template <typename U, int w_x, int w_y, int str_x, int str_y> | ||||||
|  |   METAL_FUNC void load(const threadgroup U* src) { | ||||||
|  |     STEEL_PRAGMA_UNROLL | ||||||
|  |     for (short i = 0; i < kTileRows; ++i) { | ||||||
|  |       STEEL_PRAGMA_UNROLL | ||||||
|  |       for (short j = 0; j < kTileCols; ++j) { | ||||||
|  |         MMAFrag_t::load( | ||||||
|  |             frag_at(i, j), | ||||||
|  |             &( | ||||||
|  |                 src[(i * kFragRows) * w_x * str_x + | ||||||
|  |                     (j * kFragCols) * w_y * str_y]), | ||||||
|  |             Int<str_x>{}, | ||||||
|  |             Int<str_y>{}); | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   template <typename U, int w_x, int w_y, int str_x, int str_y> | ||||||
|  |   METAL_FUNC void store(threadgroup U* dst) const { | ||||||
|  |     STEEL_PRAGMA_UNROLL | ||||||
|  |     for (short i = 0; i < kTileRows; ++i) { | ||||||
|  |       STEEL_PRAGMA_UNROLL | ||||||
|  |       for (short j = 0; j < kTileCols; ++j) { | ||||||
|  |         MMAFrag_t::store( | ||||||
|  |             frag_at(i, j), | ||||||
|  |             &( | ||||||
|  |                 dst[(i * kFragRows) * w_x * str_x + | ||||||
|  |                     (j * kFragCols) * w_y * str_y]), | ||||||
|  |             Int<str_x>{}, | ||||||
|  |             Int<str_y>{}); | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   template <typename U, int w_x, int w_y> | ||||||
|  |   METAL_FUNC void load(const device U* src, const int ld) { | ||||||
|  |     STEEL_PRAGMA_UNROLL | ||||||
|  |     for (short i = 0; i < kTileRows; ++i) { | ||||||
|  |       STEEL_PRAGMA_UNROLL | ||||||
|  |       for (short j = 0; j < kTileCols; ++j) { | ||||||
|  |         MMAFrag_t::load( | ||||||
|  |             frag_at(i, j), | ||||||
|  |             &(src[(i * kFragRows) * w_x * ld + (j * kFragCols) * w_y]), | ||||||
|  |             ld, | ||||||
|  |             Int<1>{}); | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   template <typename U, int w_x, int w_y> | ||||||
|  |   METAL_FUNC void store(device U* dst, const int ld) const { | ||||||
|  |     STEEL_PRAGMA_UNROLL | ||||||
|  |     for (short i = 0; i < kTileRows; ++i) { | ||||||
|  |       STEEL_PRAGMA_UNROLL | ||||||
|  |       for (short j = 0; j < kTileCols; ++j) { | ||||||
|  |         MMAFrag_t::store( | ||||||
|  |             frag_at(i, j), | ||||||
|  |             &(dst[(i * kFragRows) * w_x * ld + (j * kFragCols) * w_y]), | ||||||
|  |             ld, | ||||||
|  |             Int<1>{}); | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   template <typename U, int w_x, int w_y> | ||||||
|  |   METAL_FUNC void | ||||||
|  |   load_safe(const device U* src, const int ld, const short2 src_tile_dims) { | ||||||
|  |     STEEL_PRAGMA_UNROLL | ||||||
|  |     for (int i = 0; i < kTileRows; ++i) { | ||||||
|  |       STEEL_PRAGMA_UNROLL | ||||||
|  |       for (int j = 0; j < kTileCols; ++j) { | ||||||
|  |         MMAFrag_t::load_safe( | ||||||
|  |             frag_at(i, j), | ||||||
|  |             src, | ||||||
|  |             ld, | ||||||
|  |             Int<1>{}, | ||||||
|  |             src_tile_dims.y, | ||||||
|  |             src_tile_dims.x, | ||||||
|  |             (i * kFragRows) * w_x, | ||||||
|  |             (j * kFragCols) * w_y); | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   template <typename U, int w_x, int w_y> | ||||||
|  |   METAL_FUNC void | ||||||
|  |   store_safe(device U* dst, const int ld, const short2 dst_tile_dims) const { | ||||||
|  |     STEEL_PRAGMA_UNROLL | ||||||
|  |     for (int i = 0; i < kTileRows; ++i) { | ||||||
|  |       STEEL_PRAGMA_UNROLL | ||||||
|  |       for (int j = 0; j < kTileCols; ++j) { | ||||||
|  |         MMAFrag_t::store_safe( | ||||||
|  |             frag_at(i, j), | ||||||
|  |             dst, | ||||||
|  |             ld, | ||||||
|  |             Int<1>{}, | ||||||
|  |             dst_tile_dims.y, | ||||||
|  |             dst_tile_dims.x, | ||||||
|  |             (i * kFragRows) * w_x, | ||||||
|  |             (j * kFragCols) * w_y); | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | }; | ||||||
|  |  | ||||||
|  | template <typename T, typename U, int M, int N, int K> | ||||||
|  | METAL_FUNC void tile_matmad( | ||||||
|  |     thread MMATile<T, M, N>& D, | ||||||
|  |     thread MMATile<U, M, K>& A, | ||||||
|  |     thread MMATile<U, K, N>& B, | ||||||
|  |     thread MMATile<T, M, N>& C) { | ||||||
|  |   STEEL_PRAGMA_UNROLL | ||||||
|  |   for (short m = 0; m < M; ++m) { | ||||||
|  |     STEEL_PRAGMA_UNROLL | ||||||
|  |     for (short n = 0; n < N; ++n) { | ||||||
|  |       short n_serp = (m % 2) ? (N - 1 - n) : n; | ||||||
|  |       STEEL_PRAGMA_UNROLL | ||||||
|  |       for (short k = 0; k < K; ++k) { | ||||||
|  |         MMATile<T, M, N>::MMAFrag_t::mma( | ||||||
|  |             D.frag_at(m, n_serp), | ||||||
|  |             A.frag_at(m, k), | ||||||
|  |             B.frag_at(k, n_serp), | ||||||
|  |             C.frag_at(m, n_serp)); | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | } | ||||||
|  |  | ||||||
| template < | template < | ||||||
|     typename T, |     typename T, | ||||||
|     typename U, |     typename U, | ||||||
| @@ -33,39 +375,38 @@ template < | |||||||
|     typename AccumType = float, |     typename AccumType = float, | ||||||
|     typename Epilogue = TransformNone<U, AccumType>> |     typename Epilogue = TransformNone<U, AccumType>> | ||||||
| struct BlockMMA { | struct BlockMMA { | ||||||
|  |   // MMAFrag size | ||||||
|  |   STEEL_CONST short kFragSize = 8; | ||||||
|  |   using MMAFrag_acc_t = BaseMMAFrag<AccumType, kFragSize, kFragSize>; | ||||||
|  |  | ||||||
|   // Warp tile simdgroup matrix strides along M |   // Warp tile simdgroup matrix strides along M | ||||||
|   STEEL_CONST short TM_stride = 8 * WM; |   STEEL_CONST short TM_stride = kFragSize * WM; | ||||||
|   // Warp tile simdgroup matrix strides along N |   // Warp tile simdgroup matrix strides along N | ||||||
|   STEEL_CONST short TN_stride = 8 * WN; |   STEEL_CONST short TN_stride = kFragSize * WN; | ||||||
|  |  | ||||||
|   // Warp tile size along M |   // Warp tile size along M | ||||||
|   STEEL_CONST short TM = BM / TM_stride; |   STEEL_CONST short TM = BM / TM_stride; | ||||||
|   // Warp tile size along N |   // Warp tile size along N | ||||||
|   STEEL_CONST short TN = BN / TN_stride; |   STEEL_CONST short TN = BN / TN_stride; | ||||||
|  |  | ||||||
|   // Strides of A, B along reduction axis |   // Threadgroup A strides | ||||||
|   STEEL_CONST short simd_stride_a = { |   STEEL_CONST short A_str_m = transpose_a ? 1 : lda_tgp; // M | ||||||
|       transpose_a ? TM_stride : TM_stride * lda_tgp}; |   STEEL_CONST short A_str_k = transpose_a ? lda_tgp : 1; // K | ||||||
|   STEEL_CONST short simd_stride_b = { |  | ||||||
|       transpose_b ? TN_stride * ldb_tgp : TN_stride}; |  | ||||||
|  |  | ||||||
|   // Jump between elements |   // Threadgroup B strides | ||||||
|   STEEL_CONST short jump_a = {transpose_a ? lda_tgp : 1}; |   STEEL_CONST short B_str_k = transpose_b ? 1 : ldb_tgp; // K | ||||||
|   STEEL_CONST short jump_b = {transpose_b ? ldb_tgp : 1}; |   STEEL_CONST short B_str_n = transpose_b ? ldb_tgp : 1; // N | ||||||
|  |  | ||||||
|   STEEL_CONST short tile_stride_a = {transpose_a ? 8 * lda_tgp : 8}; |   // Threadgroup strides along K | ||||||
|   STEEL_CONST short tile_stride_b = {transpose_b ? 8 : 8 * ldb_tgp}; |   STEEL_CONST short tile_stride_a = kFragSize * A_str_k; | ||||||
|  |   STEEL_CONST short tile_stride_b = kFragSize * B_str_k; | ||||||
|  |  | ||||||
|   // Simdgroup matrices |   // Simdgroup matrices | ||||||
|   simdgroup_matrix<AccumType, 8, 8> Asimd[TM]; |   MMATile<AccumType, TM, 1, MMAFrag_acc_t> Atile; | ||||||
|   simdgroup_matrix<AccumType, 8, 8> Bsimd[TN]; |   MMATile<AccumType, 1, TN, MMAFrag_acc_t> Btile; | ||||||
|   simdgroup_matrix<AccumType, 8, 8> results[TM * TN] = { |   MMATile<AccumType, TM, TN, MMAFrag_acc_t> Ctile; | ||||||
|       simdgroup_matrix<AccumType, 8, 8>(0)}; |  | ||||||
|  |  | ||||||
|   // Offsets within threadgroup |   // Offsets within threadgroup | ||||||
|   const short tm; |  | ||||||
|   const short tn; |  | ||||||
|  |  | ||||||
|   short sm; |   short sm; | ||||||
|   short sn; |   short sn; | ||||||
|  |  | ||||||
| @@ -75,18 +416,21 @@ struct BlockMMA { | |||||||
|   /* Constructor */ |   /* Constructor */ | ||||||
|   METAL_FUNC BlockMMA( |   METAL_FUNC BlockMMA( | ||||||
|       ushort simd_group_id [[simdgroup_index_in_threadgroup]], |       ushort simd_group_id [[simdgroup_index_in_threadgroup]], | ||||||
|       ushort simd_lane_id [[thread_index_in_simdgroup]]) |       ushort simd_lane_id [[thread_index_in_simdgroup]]) { | ||||||
|       : tm(8 * (simd_group_id / WN)), tn(8 * (simd_group_id % WN)) { |  | ||||||
|     // Determine thread position in simdgroup matrix |     // Determine thread position in simdgroup matrix | ||||||
|     short qid = simd_lane_id / 4; |     short tm = kFragSize * (simd_group_id / WN); | ||||||
|     sm = (qid & 4) + (simd_lane_id / 2) % 4; |     short tn = kFragSize * (simd_group_id % WN); | ||||||
|     sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2; |  | ||||||
|  |     short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id); | ||||||
|  |     sm = simd_coord.y; | ||||||
|  |     sn = simd_coord.x; | ||||||
|  |  | ||||||
|     // Determine thread and simdgroup offset |     // Determine thread and simdgroup offset | ||||||
|     As_offset = |     As_offset = (tm + sm) * A_str_m + (sn)*A_str_k; // M, K | ||||||
|         transpose_a ? ((sn)*lda_tgp + (tm + sm)) : ((sn) + (tm + sm) * lda_tgp); |     Bs_offset = (sm)*B_str_k + (tn + sn) * B_str_n; // K, N | ||||||
|     Bs_offset = |  | ||||||
|         transpose_b ? ((tn + sn) * ldb_tgp + (sm)) : ((sm)*ldb_tgp + (tn + sn)); |     sm += tm; | ||||||
|  |     sn += tn; | ||||||
|   } |   } | ||||||
|  |  | ||||||
|   /* (BM, BK) X (BK, BN) multiply accumulate function */ |   /* (BM, BK) X (BK, BN) multiply accumulate function */ | ||||||
| @@ -95,47 +439,20 @@ struct BlockMMA { | |||||||
|     As += As_offset; |     As += As_offset; | ||||||
|     Bs += Bs_offset; |     Bs += Bs_offset; | ||||||
|  |  | ||||||
|     // Iterate over BK in blocks of 8 |     // Iterate over BK in blocks of kFragSize | ||||||
|     STEEL_PRAGMA_UNROLL |     STEEL_PRAGMA_UNROLL | ||||||
|     for (short kk = 0; kk < BK; kk += 8) { |     for (short kk = 0; kk < BK; kk += kFragSize) { | ||||||
|       simdgroup_barrier(mem_flags::mem_none); |       simdgroup_barrier(mem_flags::mem_none); | ||||||
|  |  | ||||||
|       // Load elements from threadgroup A as simdgroup matrices |       Atile.template load<T, WM, 1, A_str_m, A_str_k>(As); | ||||||
|       STEEL_PRAGMA_UNROLL |  | ||||||
|       for (short i = 0; i < TM; i++) { |  | ||||||
|         Asimd[i].thread_elements()[0] = |  | ||||||
|             static_cast<AccumType>(As[i * simd_stride_a + 0]); |  | ||||||
|         Asimd[i].thread_elements()[1] = |  | ||||||
|             static_cast<AccumType>(As[i * simd_stride_a + jump_a]); |  | ||||||
|       } |  | ||||||
|  |  | ||||||
|       simdgroup_barrier(mem_flags::mem_none); |       simdgroup_barrier(mem_flags::mem_none); | ||||||
|  |  | ||||||
|       // Load elements from threadgroup B as simdgroup matrices |       Btile.template load<T, 1, WN, B_str_k, B_str_n>(Bs); | ||||||
|       STEEL_PRAGMA_UNROLL |  | ||||||
|       for (short j = 0; j < TN; j++) { |  | ||||||
|         Bsimd[j].thread_elements()[0] = |  | ||||||
|             static_cast<AccumType>(Bs[j * simd_stride_b + 0]); |  | ||||||
|         Bsimd[j].thread_elements()[1] = |  | ||||||
|             static_cast<AccumType>(Bs[j * simd_stride_b + jump_b]); |  | ||||||
|       } |  | ||||||
|  |  | ||||||
|       simdgroup_barrier(mem_flags::mem_none); |       simdgroup_barrier(mem_flags::mem_none); | ||||||
|  |  | ||||||
|       // Multiply and accumulate into result simdgroup matrices |       tile_matmad(Ctile, Atile, Btile, Ctile); | ||||||
|       STEEL_PRAGMA_UNROLL |  | ||||||
|       for (short i = 0; i < TM; i++) { |  | ||||||
|         STEEL_PRAGMA_UNROLL |  | ||||||
|         for (short j = 0; j < TN; j++) { |  | ||||||
|           short j_serp = (i % 2) ? (TN - 1 - j) : j; |  | ||||||
|  |  | ||||||
|           simdgroup_multiply_accumulate( |  | ||||||
|               results[i * TN + j_serp], |  | ||||||
|               Asimd[i], |  | ||||||
|               Bsimd[j_serp], |  | ||||||
|               results[i * TN + j_serp]); |  | ||||||
|         } |  | ||||||
|       } |  | ||||||
|  |  | ||||||
|       // Progress to next simdgroup tile |       // Progress to next simdgroup tile | ||||||
|       As += tile_stride_a; |       As += tile_stride_a; | ||||||
| @@ -144,58 +461,35 @@ struct BlockMMA { | |||||||
|   } |   } | ||||||
|  |  | ||||||
|   /* Store results from simdgroup_matrix results into device memory */ |   /* Store results from simdgroup_matrix results into device memory */ | ||||||
|   METAL_FUNC void store_result(device U* D, const int ldd) const { |   METAL_FUNC void store_result(device U* D, const int ldd) { | ||||||
|     // Adjust for simdgroup and thread location |     // Apply epilogue | ||||||
|     D += (sm + tm) * ldd + tn + sn; |  | ||||||
|  |  | ||||||
|     // Loop over all simdgroup tiles |  | ||||||
|     STEEL_PRAGMA_UNROLL |     STEEL_PRAGMA_UNROLL | ||||||
|     for (short i = 0; i < TM; i++) { |     for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) { | ||||||
|       STEEL_PRAGMA_UNROLL |       Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]); | ||||||
|       for (short j = 0; j < TN; j++) { |  | ||||||
|         // Get accumulated result and associated offset in C |  | ||||||
|         thread const auto& accum = results[i * TN + j].thread_elements(); |  | ||||||
|         int offset = (i * TM_stride) * ldd + (j * TN_stride); |  | ||||||
|  |  | ||||||
|         // Apply epilogue |  | ||||||
|         U outs[2] = {Epilogue::apply(accum[0]), Epilogue::apply(accum[1])}; |  | ||||||
|  |  | ||||||
|         // Write out D |  | ||||||
|         D[offset] = outs[0]; |  | ||||||
|         D[offset + 1] = outs[1]; |  | ||||||
|       } |  | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |     // Adjust for simdgroup and thread location | ||||||
|  |     D += sm * ldd + sn; | ||||||
|  |  | ||||||
|  |     Ctile.template store<U, WM, WN>(D, ldd); | ||||||
|   } |   } | ||||||
|  |  | ||||||
|   METAL_FUNC void |   METAL_FUNC void | ||||||
|   store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) const { |   store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) { | ||||||
|  |     // Apply epilogue | ||||||
|  |     STEEL_PRAGMA_UNROLL | ||||||
|  |     for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) { | ||||||
|  |       Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]); | ||||||
|  |     } | ||||||
|  |  | ||||||
|     // Adjust for simdgroup and thread location |     // Adjust for simdgroup and thread location | ||||||
|     D += (sm + tm) * ldd + (tn + sn); |     D += sm * ldd + sn; | ||||||
|     dst_tile_dims -= short2(tn + sn, sm + tm); |     dst_tile_dims -= short2(sn, sm); | ||||||
|  |  | ||||||
|     if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) |     if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) | ||||||
|       return; |       return; | ||||||
|  |  | ||||||
|     STEEL_PRAGMA_UNROLL |     Ctile.template store_safe<U, WM, WN>(D, ldd, dst_tile_dims); | ||||||
|     for (int i = 0; i < TM; i++) { |  | ||||||
|       if (i * TM_stride < dst_tile_dims.y) { |  | ||||||
|         STEEL_PRAGMA_UNROLL |  | ||||||
|         for (int j = 0; j < TN; j++) { |  | ||||||
|           // Get accumulated result and associated offset in C |  | ||||||
|           thread const auto& accum = results[i * TN + j].thread_elements(); |  | ||||||
|           int offset = (i * TM_stride) * ldd + (j * TN_stride); |  | ||||||
|  |  | ||||||
|           // Apply epilogue and output C |  | ||||||
|           if (j * TN_stride < dst_tile_dims.x) { |  | ||||||
|             D[offset] = Epilogue::apply(accum[0]); |  | ||||||
|           } |  | ||||||
|  |  | ||||||
|           if (j * TN_stride + 1 < dst_tile_dims.x) { |  | ||||||
|             D[offset + 1] = Epilogue::apply(accum[1]); |  | ||||||
|           } |  | ||||||
|         } |  | ||||||
|       } |  | ||||||
|     } |  | ||||||
|   } |   } | ||||||
|  |  | ||||||
|   /* Apply epilogue */ |   /* Apply epilogue */ | ||||||
| @@ -203,16 +497,8 @@ struct BlockMMA { | |||||||
|   METAL_FUNC void apply_epilogue(thread const UnaryEpilogue& epilogue_op) { |   METAL_FUNC void apply_epilogue(thread const UnaryEpilogue& epilogue_op) { | ||||||
|     // Loop over all simdgroup tiles |     // Loop over all simdgroup tiles | ||||||
|     STEEL_PRAGMA_UNROLL |     STEEL_PRAGMA_UNROLL | ||||||
|     for (short i = 0; i < TM; i++) { |     for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) { | ||||||
|       STEEL_PRAGMA_UNROLL |       Ctile.elems()[i] = epilogue_op.apply(Ctile.elems()[i]); | ||||||
|       for (short j = 0; j < TN; j++) { |  | ||||||
|         // Get accumulated result and associated offset in C |  | ||||||
|         thread auto& accum = results[i * TN + j].thread_elements(); |  | ||||||
|  |  | ||||||
|         // Apply epilogue |  | ||||||
|         accum[0] = epilogue_op.apply(accum[0]); |  | ||||||
|         accum[1] = epilogue_op.apply(accum[1]); |  | ||||||
|       } |  | ||||||
|     } |     } | ||||||
|   } |   } | ||||||
|  |  | ||||||
| @@ -224,7 +510,7 @@ struct BlockMMA { | |||||||
|       const int fdc, |       const int fdc, | ||||||
|       thread const BinaryEpilogue& epilogue_op) { |       thread const BinaryEpilogue& epilogue_op) { | ||||||
|     // Adjust for simdgroup and thread location |     // Adjust for simdgroup and thread location | ||||||
|     C += (sm + tm) * ldc + (tn + sn) * fdc; |     C += (sm)*ldc + (sn)*fdc; | ||||||
|  |  | ||||||
|     // Loop over all simdgroup tiles |     // Loop over all simdgroup tiles | ||||||
|     STEEL_PRAGMA_UNROLL |     STEEL_PRAGMA_UNROLL | ||||||
| @@ -232,12 +518,14 @@ struct BlockMMA { | |||||||
|       STEEL_PRAGMA_UNROLL |       STEEL_PRAGMA_UNROLL | ||||||
|       for (short j = 0; j < TN; j++) { |       for (short j = 0; j < TN; j++) { | ||||||
|         // Get accumulated result and associated offset in C |         // Get accumulated result and associated offset in C | ||||||
|         thread auto& accum = results[i * TN + j].thread_elements(); |         thread auto& accum = Ctile.frag_at(i, j); | ||||||
|         int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; |         int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; | ||||||
|  |  | ||||||
|         // Apply epilogue |         // Apply epilogue | ||||||
|         accum[0] = epilogue_op.apply(accum[0], C[offset_c]); |         STEEL_PRAGMA_UNROLL | ||||||
|         accum[1] = epilogue_op.apply(accum[1], C[offset_c + fdc]); |         for (short k = 0; k < decltype(Ctile)::kElemsPerFrag; k++) { | ||||||
|  |           accum[k] = epilogue_op.apply(accum[k], C[offset_c + k * fdc]); | ||||||
|  |         } | ||||||
|       } |       } | ||||||
|     } |     } | ||||||
|   } |   } | ||||||
| @@ -251,8 +539,8 @@ struct BlockMMA { | |||||||
|       short2 dst_tile_dims, |       short2 dst_tile_dims, | ||||||
|       thread const BinaryEpilogue& epilogue_op) { |       thread const BinaryEpilogue& epilogue_op) { | ||||||
|     // Adjust for simdgroup and thread location |     // Adjust for simdgroup and thread location | ||||||
|     C += (sm + tm) * ldc + (tn + sn) * fdc; |     C += (sm)*ldc + (sn)*fdc; | ||||||
|     dst_tile_dims -= short2(tn + sn, sm + tm); |     dst_tile_dims -= short2(sn, sm); | ||||||
|  |  | ||||||
|     if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) |     if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) | ||||||
|       return; |       return; | ||||||
| @@ -263,22 +551,26 @@ struct BlockMMA { | |||||||
|       STEEL_PRAGMA_UNROLL |       STEEL_PRAGMA_UNROLL | ||||||
|       for (short j = 0; j < TN; j++) { |       for (short j = 0; j < TN; j++) { | ||||||
|         // Get accumulated result and associated offset in C |         // Get accumulated result and associated offset in C | ||||||
|         thread auto& accum = results[i * TN + j].thread_elements(); |         thread auto& accum = Ctile.frag_at(i, j); | ||||||
|         int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; |         int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; | ||||||
|  |  | ||||||
|         // Read C |         constexpr short kelems = decltype(Ctile)::kElemsPerFrag; | ||||||
|         U c_elems[2] = {0}; |  | ||||||
|  |  | ||||||
|         if ((j * TN_stride + 1) < dst_tile_dims.x) { |         // Read C | ||||||
|           c_elems[0] = C[offset_c]; |         U c_elems[kelems] = {0}; | ||||||
|           c_elems[1] = C[offset_c + fdc]; |  | ||||||
|         } else if ((j * TN_stride) < dst_tile_dims.x) { |         STEEL_PRAGMA_UNROLL | ||||||
|           c_elems[0] = C[offset_c]; |         for (short k = 0; k < kelems; k++) { | ||||||
|  |           if ((j * TN_stride + k) < dst_tile_dims.x) { | ||||||
|  |             c_elems[k] = C[offset_c + k * fdc]; | ||||||
|  |           } | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         // Apply epilogue |         // Apply epilogue | ||||||
|         accum[0] = epilogue_op.apply(accum[0], c_elems[0]); |         STEEL_PRAGMA_UNROLL | ||||||
|         accum[1] = epilogue_op.apply(accum[1], c_elems[1]); |         for (short k = 0; k < kelems; k++) { | ||||||
|  |           accum[k] = epilogue_op.apply(accum[k], c_elems[k]); | ||||||
|  |         } | ||||||
|       } |       } | ||||||
|     } |     } | ||||||
|   } |   } | ||||||
| @@ -292,8 +584,10 @@ struct BlockMMA { | |||||||
|       const int fdc, |       const int fdc, | ||||||
|       thread const Epilogue& epilogue_op) const { |       thread const Epilogue& epilogue_op) const { | ||||||
|     // Adjust for simdgroup and thread location |     // Adjust for simdgroup and thread location | ||||||
|     C += (sm + tm) * ldc + (tn + sn) * fdc; |     C += (sm)*ldc + (sn)*fdc; | ||||||
|     D += (sm + tm) * ldd + tn + sn; |     D += (sm)*ldd + sn; | ||||||
|  |  | ||||||
|  |     constexpr short kelems = decltype(Ctile)::kElemsPerFrag; | ||||||
|  |  | ||||||
|     // Loop over all simdgroup tiles |     // Loop over all simdgroup tiles | ||||||
|     STEEL_PRAGMA_UNROLL |     STEEL_PRAGMA_UNROLL | ||||||
| @@ -301,18 +595,15 @@ struct BlockMMA { | |||||||
|       STEEL_PRAGMA_UNROLL |       STEEL_PRAGMA_UNROLL | ||||||
|       for (short j = 0; j < TN; j++) { |       for (short j = 0; j < TN; j++) { | ||||||
|         // Get accumulated result and associated offset in C |         // Get accumulated result and associated offset in C | ||||||
|         thread const auto& accum = results[i * TN + j].thread_elements(); |         thread const auto& accum = Ctile.frag_at(i, j); | ||||||
|         int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; |         int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; | ||||||
|         int offset_d = (i * TM_stride) * ldd + (j * TN_stride); |         int offset_d = (i * TM_stride) * ldd + (j * TN_stride); | ||||||
|  |  | ||||||
|         // Apply epilogue |         // Apply epilogue | ||||||
|         U outs[2] = { |         STEEL_PRAGMA_UNROLL | ||||||
|             epilogue_op.apply(accum[0], C[offset_c]), |         for (short k = 0; k < kelems; k++) { | ||||||
|             epilogue_op.apply(accum[1], C[offset_c + fdc])}; |           D[offset_d + k] = epilogue_op.apply(accum[k], C[offset_c + k * fdc]); | ||||||
|  |         } | ||||||
|         // Write out D |  | ||||||
|         D[offset_d] = outs[0]; |  | ||||||
|         D[offset_d + 1] = outs[1]; |  | ||||||
|       } |       } | ||||||
|     } |     } | ||||||
|   } |   } | ||||||
| @@ -326,30 +617,32 @@ struct BlockMMA { | |||||||
|       short2 dst_tile_dims, |       short2 dst_tile_dims, | ||||||
|       thread const Epilogue& epilogue_op) const { |       thread const Epilogue& epilogue_op) const { | ||||||
|     // Adjust for simdgroup and thread location |     // Adjust for simdgroup and thread location | ||||||
|     C += (sm + tm) * ldc + (tn + sn) * fdc; |     C += (sm)*ldc + (sn)*fdc; | ||||||
|     D += (sm + tm) * ldd + tn + sn; |     D += (sm)*ldd + sn; | ||||||
|     dst_tile_dims -= short2(tn + sn, sm + tm); |     dst_tile_dims -= short2(sn, sm); | ||||||
|  |  | ||||||
|     if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) |     if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) | ||||||
|       return; |       return; | ||||||
|  |  | ||||||
|  |     constexpr short kelems = decltype(Ctile)::kElemsPerFrag; | ||||||
|  |  | ||||||
|     STEEL_PRAGMA_UNROLL |     STEEL_PRAGMA_UNROLL | ||||||
|     for (int i = 0; i < TM; i++) { |     for (int i = 0; i < TM; i++) { | ||||||
|       if (i * TM_stride < dst_tile_dims.y) { |       if (i * TM_stride < dst_tile_dims.y) { | ||||||
|         STEEL_PRAGMA_UNROLL |         STEEL_PRAGMA_UNROLL | ||||||
|         for (int j = 0; j < TN; j++) { |         for (int j = 0; j < TN; j++) { | ||||||
|           // Get accumulated result and associated offset in C |           // Get accumulated result and associated offset in C | ||||||
|           thread const auto& accum = results[i * TN + j].thread_elements(); |           thread const auto& accum = Ctile.frag_at(i, j); | ||||||
|           int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; |           int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; | ||||||
|           int offset_d = (i * TM_stride) * ldd + (j * TN_stride); |           int offset_d = (i * TM_stride) * ldd + (j * TN_stride); | ||||||
|  |  | ||||||
|           // Apply epilogue and output C |           // Apply epilogue | ||||||
|           if (j * TN_stride < dst_tile_dims.x) { |           STEEL_PRAGMA_UNROLL | ||||||
|             D[offset_d] = epilogue_op.apply(accum[0], C[offset_c]); |           for (short k = 0; k < kelems; k++) { | ||||||
|           } |             if ((j * TN_stride + k) < dst_tile_dims.x) { | ||||||
|  |               D[offset_d + k] = | ||||||
|           if (j * TN_stride + 1 < dst_tile_dims.x) { |                   epilogue_op.apply(accum[k], C[offset_c + k * fdc]); | ||||||
|             D[offset_d + 1] = epilogue_op.apply(accum[1], C[offset_c + fdc]); |             } | ||||||
|           } |           } | ||||||
|         } |         } | ||||||
|       } |       } | ||||||
|   | |||||||
							
								
								
									
										96
									
								
								mlx/backend/metal/kernels/steel/utils/integral_constant.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										96
									
								
								mlx/backend/metal/kernels/steel/utils/integral_constant.h
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,96 @@ | |||||||
|  | // Copyright © 2024 Apple Inc. | ||||||
|  |  | ||||||
|  | #pragma once | ||||||
|  |  | ||||||
|  | #include <metal_stdlib> | ||||||
|  | #include "mlx/backend/metal/kernels/steel/utils/type_traits.h" | ||||||
|  |  | ||||||
|  | #pragma METAL internals : enable | ||||||
|  |  | ||||||
|  | namespace mlx { | ||||||
|  | namespace steel { | ||||||
|  |  | ||||||
|  | /////////////////////////////////////////////////////////////////////////////// | ||||||
|  | // Integral constant with casting | ||||||
|  | /////////////////////////////////////////////////////////////////////////////// | ||||||
|  |  | ||||||
// Compile-time constant: wraps the value `v` of type `T` in its own type.
// `value` exposes the constant, `value_type` names T, and `type` names this
// specialization itself.
|  | template <typename T, T v> | ||||||
|  | struct integral_constant { | ||||||
|  |   static constexpr constant T value = v; | ||||||
|  |   using value_type = T; | ||||||
|  |   using type = integral_constant; | ||||||
|  |  | ||||||
// Implicit conversion back to the wrapped type; always yields `value`.
|  |   METAL_FUNC constexpr operator value_type() const noexcept { | ||||||
|  |     return value; | ||||||
|  |   } | ||||||
|  |  | ||||||
// The call-operator accessor below is commented out, so only the conversion
// operator above is available for reading the value from an instance.
|  |   // METAL_FUNC constexpr value_type operator()() const noexcept { | ||||||
|  |   //   return value; | ||||||
|  |   // } | ||||||
|  | }; | ||||||
|  |  | ||||||
// Boolean shorthands for integral_constant: bool_constant<B> is
// integral_constant<bool, B>, and true_type / false_type are its two
// instantiations.
|  | template <bool B> | ||||||
|  | using bool_constant = integral_constant<bool, B>; | ||||||
|  | using true_type = bool_constant<true>; | ||||||
|  | using false_type = bool_constant<false>; | ||||||
|  |  | ||||||
// Integral-type trait. The primary template forwards to
// metal::is_integral<T>.
|  | template <class T> | ||||||
|  | struct is_integral : bool_constant<metal::is_integral<T>::value> {}; | ||||||
|  |  | ||||||
// Partial specialization: an integral_constant<T, v> reports the result for
// its underlying type T rather than for the wrapper type.
|  | template <class T, T v> | ||||||
|  | struct is_integral<integral_constant<T, v>> | ||||||
|  |     : bool_constant<metal::is_integral<T>::value> {}; | ||||||
|  |  | ||||||
// Variable-template shorthand for is_integral<T>::value.
|  | template <typename T> | ||||||
|  | constexpr constant bool is_integral_v = is_integral<T>::value; | ||||||
|  |  | ||||||
// Shorthand for an int-typed compile-time constant:
// Int<val> is integral_constant<int, val>.
|  | template <int val> | ||||||
|  | using Int = integral_constant<int, val>; | ||||||
|  |  | ||||||
|  | /////////////////////////////////////////////////////////////////////////////// | ||||||
|  | // Binary Operators on Integral constants | ||||||
|  | /////////////////////////////////////////////////////////////////////////////// | ||||||
|  |  | ||||||
// Generates one binary operator overload on integral_constant operands.
// `__op__` is the operator token and `__operator__` the function name (for
// example `+` and `operator+`). The operand values tv and uv are combined at
// compile time and the result is returned as a new integral_constant, so the
// function arguments themselves are never read.
// Note: `res` is declared constexpr, so decltype(res) is a const-qualified
// type and that is the value type of the returned integral_constant.
|  | #define integral_const_binop(__op__, __operator__)          \ | ||||||
|  |   template <typename T, T tv, typename U, U uv>             \ | ||||||
|  |   METAL_FUNC constexpr auto __operator__(                   \ | ||||||
|  |       integral_constant<T, tv>, integral_constant<U, uv>) { \ | ||||||
|  |     constexpr auto res = tv __op__ uv;                      \ | ||||||
|  |     return integral_constant<decltype(res), res>{};         \ | ||||||
|  |   } | ||||||
|  |  | ||||||
// Arithmetic operators.
|  | integral_const_binop(+, operator+); | ||||||
|  | integral_const_binop(-, operator-); | ||||||
|  | integral_const_binop(*, operator*); | ||||||
|  | integral_const_binop(/, operator/); | ||||||
|  |  | ||||||
// Comparison operators.
|  | integral_const_binop(==, operator==); | ||||||
|  | integral_const_binop(!=, operator!=); | ||||||
|  | integral_const_binop(<, operator<); | ||||||
|  | integral_const_binop(>, operator>); | ||||||
|  | integral_const_binop(<=, operator<=); | ||||||
|  | integral_const_binop(>=, operator>=); | ||||||
|  |  | ||||||
// Logical operators.
|  | integral_const_binop(&&, operator&&); | ||||||
|  | integral_const_binop(||, operator||); | ||||||
|  |  | ||||||
// The generator macro is removed again so it does not leak to includers.
|  | #undef integral_const_binop | ||||||
|  |  | ||||||
|  | /////////////////////////////////////////////////////////////////////////////// | ||||||
|  | // Reduction operators | ||||||
|  | /////////////////////////////////////////////////////////////////////////////// | ||||||
|  |  | ||||||
// Single-argument overload: returns x unchanged. It ends the recursion of the
// variadic overload below.
|  | template <typename T> | ||||||
|  | METAL_FUNC constexpr T sum(T x) { | ||||||
|  |   return x; | ||||||
|  | } | ||||||
|  |  | ||||||
// Variadic overload: returns x + sum(us...), i.e. the first argument added to
// the sum of the remaining ones. The return type is deduced from that `+`
// expression.
|  | template <typename T, typename... Us> | ||||||
|  | METAL_FUNC constexpr auto sum(T x, Us... us) { | ||||||
|  |   return x + sum(us...); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | } // namespace steel | ||||||
|  | } // namespace mlx | ||||||
|  |  | ||||||
|  | #pragma METAL internals : disable | ||||||
							
								
								
									
										55
									
								
								mlx/backend/metal/kernels/steel/utils/type_traits.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										55
									
								
								mlx/backend/metal/kernels/steel/utils/type_traits.h
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,55 @@ | |||||||
|  | // Copyright © 2024 Apple Inc. | ||||||
|  |  | ||||||
|  | #pragma once | ||||||
|  |  | ||||||
|  | #include <metal_stdlib> | ||||||
|  |  | ||||||
|  | #pragma METAL internals : enable | ||||||
|  |  | ||||||
|  | namespace metal { | ||||||
|  |  | ||||||
// Trait whose value is the result of the compiler builtin __is_empty(T),
// exposed through metal::bool_constant.
|  | template <typename T> | ||||||
|  | struct is_empty : metal::bool_constant<__is_empty(T)> {}; | ||||||
|  |  | ||||||
// Variable-template shorthand for is_empty<T>::value; only declared when the
// compiler defines __cpp_variable_templates.
|  | #ifdef __cpp_variable_templates | ||||||
|  | template <typename T> | ||||||
|  | constexpr constant bool is_empty_v = is_empty<T>::value; | ||||||
|  | #endif | ||||||
|  |  | ||||||
// Maps any list of types to `void` through its nested `type` member.
|  | template <typename... Ts> | ||||||
|  | struct make_void { | ||||||
|  |   typedef void type; | ||||||
|  | }; | ||||||
|  |  | ||||||
// Alias for make_void<Ts...>::type, i.e. always `void`.
|  | template <typename... Ts> | ||||||
|  | using void_t = typename make_void<Ts...>::type; | ||||||
|  |  | ||||||
// True when T, with its cv-qualifiers removed, is an empty type according to
// is_empty.
|  | template <class T> | ||||||
|  | struct is_static : metal::bool_constant<is_empty<remove_cv_t<T>>::value> {}; | ||||||
|  |  | ||||||
// Yields the pointee type of a pointer type. The primary template has no
// `type` member, so it is only usable for the pointer kinds specialized below.
|  | template <typename T> | ||||||
|  | struct pointer_element {}; | ||||||
|  |  | ||||||
// One specialization per address space (thread, device, constant,
// threadgroup); each reports the pointee with cv-qualifiers removed.
|  | template <typename T> | ||||||
|  | struct pointer_element<thread T*> { | ||||||
|  |   using type = remove_cv_t<T>; | ||||||
|  | }; | ||||||
|  | template <typename T> | ||||||
|  | struct pointer_element<device T*> { | ||||||
|  |   using type = remove_cv_t<T>; | ||||||
|  | }; | ||||||
|  | template <typename T> | ||||||
|  | struct pointer_element<constant T*> { | ||||||
|  |   using type = remove_cv_t<T>; | ||||||
|  | }; | ||||||
|  | template <typename T> | ||||||
|  | struct pointer_element<threadgroup T*> { | ||||||
|  |   using type = remove_cv_t<T>; | ||||||
|  | }; | ||||||
|  |  | ||||||
// Convenience alias: removes cv-qualifiers from the pointer type T itself
// before looking up its pointee type.
|  | template <typename T> | ||||||
|  | using pointer_element_t = typename pointer_element<remove_cv_t<T>>::type; | ||||||
|  |  | ||||||
|  | } // namespace metal | ||||||
|  |  | ||||||
|  | #pragma METAL internals : disable | ||||||
| @@ -88,6 +88,83 @@ inline auto collapse_batches(const array& a, const array& b, const array& c) { | |||||||
| // Steel matmul fallback | // Steel matmul fallback | ||||||
| /////////////////////////////////////////////////////////////////////////////// | /////////////////////////////////////////////////////////////////////////////// | ||||||
|  |  | ||||||
// Chooses GEMM tile parameters by assigning to the variables bm, bn, bk, wm
// and wn, which must already exist in the scope where the macro is expanded.
// Besides `devc`, the expansion also reads transpose_a, transpose_b, out,
// batch_size_out, M, N and K from that scope.
//
// `devc` is compared against 'g' ("Small device") and 'd' ("Large device");
// any other character takes the "Medium device" branch. The call sites
// visible in this diff pass the last character of d.get_architecture().
//
// Some paths assign nothing (for 'g': neither the nt layout nor a
// non-float32 output; for 'd': a large float32 matmul), so the values the
// caller initialised beforehand are kept in those cases.
|  | #define GEMM_TPARAM_MACRO(devc)                                           \ | ||||||
|  |   if (devc == 'g') { /* Small device */                                   \ | ||||||
|  |     if (!transpose_a && transpose_b) { /* nt */                           \ | ||||||
|  |       bm = 64;                                                            \ | ||||||
|  |       bn = 32;                                                            \ | ||||||
|  |       bk = 32;                                                            \ | ||||||
|  |       wm = 2;                                                             \ | ||||||
|  |       wn = 2;                                                             \ | ||||||
|  |     } else if (out.dtype() != float32) { /* half and bfloat */            \ | ||||||
|  |       bm = 64;                                                            \ | ||||||
|  |       bn = 64;                                                            \ | ||||||
|  |       bk = 16;                                                            \ | ||||||
|  |       wm = 1;                                                             \ | ||||||
|  |       wn = 2;                                                             \ | ||||||
|  |     }                                                                     \ | ||||||
|  |   } else if (devc == 'd') { /* Large device */                            \ | ||||||
|  |     if ((size_t)batch_size_out * M * N >= 1ul << 20) { /* large matmul */ \ | ||||||
|  |       if (out.dtype() != float32) { /* half and bfloat */                 \ | ||||||
|  |         if (2 * std::max(M, N) > K) { /* Reasonable K */                  \ | ||||||
|  |           bm = 64;                                                        \ | ||||||
|  |           bn = 64;                                                        \ | ||||||
|  |           bk = 16;                                                        \ | ||||||
|  |           wm = 1;                                                         \ | ||||||
|  |           wn = 2;                                                         \ | ||||||
|  |         } else if (!transpose_a && transpose_b) { /* nt with large k */   \ | ||||||
|  |           bm = 64;                                                        \ | ||||||
|  |           bn = 32;                                                        \ | ||||||
|  |           bk = 32;                                                        \ | ||||||
|  |           wm = 2;                                                         \ | ||||||
|  |           wn = 2;                                                         \ | ||||||
|  |         } else { /* nn with large K */                                    \ | ||||||
|  |           bm = 32;                                                        \ | ||||||
|  |           bn = 64;                                                        \ | ||||||
|  |           bk = 16;                                                        \ | ||||||
|  |           wm = 1;                                                         \ | ||||||
|  |           wn = 2;                                                         \ | ||||||
|  |         }                                                                 \ | ||||||
|  |       } /* float takes default */                                         \ | ||||||
|  |     } else { /* smaller matmul */                                         \ | ||||||
|  |       if (out.dtype() != float32) { /* half and bfloat */                 \ | ||||||
|  |         if (!transpose_a && transpose_b) { /* nt */                       \ | ||||||
|  |           bm = 64;                                                        \ | ||||||
|  |           bn = 32;                                                        \ | ||||||
|  |           bk = 32;                                                        \ | ||||||
|  |           wm = 2;                                                         \ | ||||||
|  |           wn = 2;                                                         \ | ||||||
|  |         } else { /* nn */                                                 \ | ||||||
|  |           bm = 64;                                                        \ | ||||||
|  |           bn = 64;                                                        \ | ||||||
|  |           bk = 16;                                                        \ | ||||||
|  |           wm = 1;                                                         \ | ||||||
|  |           wn = 2;                                                         \ | ||||||
|  |         }                                                                 \ | ||||||
|  |       } else { /* floats */                                               \ | ||||||
|  |         if (!transpose_a && transpose_b) { /* nt */                       \ | ||||||
|  |           bm = 32;                                                        \ | ||||||
|  |           bn = 64;                                                        \ | ||||||
|  |           bk = 16;                                                        \ | ||||||
|  |           wm = 1;                                                         \ | ||||||
|  |           wn = 2;                                                         \ | ||||||
|  |         } else { /* nn */                                                 \ | ||||||
|  |           bm = 64;                                                        \ | ||||||
|  |           bn = 32;                                                        \ | ||||||
|  |           bk = 32;                                                        \ | ||||||
|  |           wm = 2;                                                         \ | ||||||
|  |           wn = 2;                                                         \ | ||||||
|  |         }                                                                 \ | ||||||
|  |       }                                                                   \ | ||||||
|  |     }                                                                     \ | ||||||
|  |   } else { /* Medium device */                                            \ | ||||||
|  |     bm = 64;                                                              \ | ||||||
|  |     bn = 64;                                                              \ | ||||||
|  |     bk = 16;                                                              \ | ||||||
|  |     wm = 2;                                                               \ | ||||||
|  |     wn = 2;                                                               \ | ||||||
|  |   } | ||||||
|  |  | ||||||
| void steel_matmul_regular( | void steel_matmul_regular( | ||||||
|     const Stream& s, |     const Stream& s, | ||||||
|     metal::Device& d, |     metal::Device& d, | ||||||
| @@ -112,19 +189,11 @@ void steel_matmul_regular( | |||||||
|   using namespace mlx::steel; |   using namespace mlx::steel; | ||||||
|  |  | ||||||
|   // Determine dispatch kernel |   // Determine dispatch kernel | ||||||
|   int bm = 32, bn = 32, bk = 16; |   int bm = 64, bn = 64, bk = 16; | ||||||
|   int wm = 2, wn = 2; |   int wm = 2, wn = 2; | ||||||
|  |  | ||||||
|   if ((size_t)batch_size_out * M * N >= 1ul << 20) { |   char devc = d.get_architecture().back(); | ||||||
|     if (!transpose_a && transpose_b) { |   GEMM_TPARAM_MACRO(devc) | ||||||
|       bm = 64; |  | ||||||
|       bn = (out.dtype() == float32) ? 64 : 32; |  | ||||||
|       bk = (out.dtype() == float32) ? 16 : 32; |  | ||||||
|     } else { |  | ||||||
|       bm = 64; |  | ||||||
|       bn = 64; |  | ||||||
|     } |  | ||||||
|   } |  | ||||||
|  |  | ||||||
|   // Prepare kernel name |   // Prepare kernel name | ||||||
|   std::ostringstream kname; |   std::ostringstream kname; | ||||||
| @@ -903,19 +972,11 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) { | |||||||
|   // Regular addmm dispatch |   // Regular addmm dispatch | ||||||
|  |  | ||||||
|   // Determine dispatch kernel |   // Determine dispatch kernel | ||||||
|   int bm = 32, bn = 32, bk = 16; |   int bm = 64, bn = 64, bk = 16; | ||||||
|   int wm = 2, wn = 2; |   int wm = 2, wn = 2; | ||||||
|  |  | ||||||
|   if ((size_t)batch_size_out * M * N >= 1ul << 20) { |   char devc = d.get_architecture().back(); | ||||||
|     if (!transpose_a && transpose_b) { |   GEMM_TPARAM_MACRO(devc) | ||||||
|       bm = 64; |  | ||||||
|       bn = (out.dtype() == float32) ? 64 : 32; |  | ||||||
|       bk = (out.dtype() == float32) ? 16 : 32; |  | ||||||
|     } else { |  | ||||||
|       bm = 64; |  | ||||||
|       bn = 64; |  | ||||||
|     } |  | ||||||
|   } |  | ||||||
|  |  | ||||||
|   // Prepare kernel name |   // Prepare kernel name | ||||||
|   std::ostringstream kname; |   std::ostringstream kname; | ||||||
| @@ -1667,19 +1728,11 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) { | |||||||
|   // Regular kernel dispatch |   // Regular kernel dispatch | ||||||
|  |  | ||||||
|   // Determine dispatch kernel |   // Determine dispatch kernel | ||||||
|   int bm = 32, bn = 32, bk = 16; |   int bm = 64, bn = 64, bk = 16; | ||||||
|   int wm = 2, wn = 2; |   int wm = 2, wn = 2; | ||||||
|  |  | ||||||
|   if ((size_t)batch_size_out * M * N >= 1ul << 20) { |   char devc = d.get_architecture().back(); | ||||||
|     if (!transpose_a && transpose_b) { |   GEMM_TPARAM_MACRO(devc) | ||||||
|       bm = 64; |  | ||||||
|       bn = (out.dtype() == float32) ? 64 : 32; |  | ||||||
|       bk = (out.dtype() == float32) ? 16 : 32; |  | ||||||
|     } else { |  | ||||||
|       bm = 64; |  | ||||||
|       bn = 64; |  | ||||||
|     } |  | ||||||
|   } |  | ||||||
|  |  | ||||||
|   // Prepare kernel name |   // Prepare kernel name | ||||||
|   std::ostringstream kname; |   std::ostringstream kname; | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Jagrit Digani
					Jagrit Digani