diff --git a/mlx/backend/metal/CMakeLists.txt b/mlx/backend/metal/CMakeLists.txt index 6d46b03c7..671b45dca 100644 --- a/mlx/backend/metal/CMakeLists.txt +++ b/mlx/backend/metal/CMakeLists.txt @@ -1,4 +1,4 @@ -function(make_jit_source SRC_NAME) +function(make_jit_source SRC_FILE) # This function takes a metal header file, # runs the C preprocessesor on it, and makes # the processed contents available as a string in a C++ function @@ -9,6 +9,7 @@ function(make_jit_source SRC_NAME) # # Additional arguments to this function are treated as dependencies # in the Cmake build system. + get_filename_component(SRC_NAME ${SRC_FILE} NAME) add_custom_command( OUTPUT jit/${SRC_NAME}.cpp COMMAND /bin/bash @@ -16,10 +17,10 @@ function(make_jit_source SRC_NAME) ${CMAKE_CURRENT_BINARY_DIR}/jit ${CMAKE_C_COMPILER} ${PROJECT_SOURCE_DIR} - ${SRC_NAME} + ${SRC_FILE} "-D${MLX_METAL_VERSION}" DEPENDS make_compiled_preamble.sh - kernels/${SRC_NAME}.h + kernels/${SRC_FILE}.h ${ARGN} ) add_custom_target(${SRC_NAME} DEPENDS jit/${SRC_NAME}.cpp) @@ -73,6 +74,39 @@ if (MLX_METAL_JIT) kernels/reduction/reduce_col.h kernels/reduction/reduce_row.h ) + make_jit_source( + steel/gemm/gemm + kernels/steel/utils.h + kernels/steel/gemm/loader.h + kernels/steel/gemm/mma.h + kernels/steel/gemm/params.h + kernels/steel/gemm/transforms.h + ) + make_jit_source(steel/gemm/kernels/steel_gemm_fused) + make_jit_source( + steel/gemm/kernels/steel_gemm_masked + kernels/steel/defines.h + ) + make_jit_source(steel/gemm/kernels/steel_gemm_splitk) + make_jit_source( + steel/conv/conv + kernels/steel/utils.h + kernels/steel/defines.h + kernels/steel/gemm/mma.h + kernels/steel/gemm/transforms.h + kernels/steel/conv/params.h + kernels/steel/conv/loader.h + kernels/steel/conv/loaders/loader_channel_l.h + kernels/steel/conv/loaders/loader_channel_n.h + ) + make_jit_source( + steel/conv/kernels/steel_conv + ) + make_jit_source( + steel/conv/kernels/steel_conv_general + kernels/steel/defines.h + kernels/steel/conv/loaders/loader_general.h + ) else() target_sources( mlx diff --git a/mlx/backend/metal/conv.cpp b/mlx/backend/metal/conv.cpp index 03fda47e4..395488ead 100644 --- a/mlx/backend/metal/conv.cpp +++ b/mlx/backend/metal/conv.cpp @@ -7,6 +7,7 @@ #include "mlx/backend/metal/copy.h" #include "mlx/backend/metal/device.h" +#include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/kernels/defines.h" #include "mlx/backend/metal/kernels/steel/conv/params.h" #include "mlx/backend/metal/matmul.h" @@ -335,7 +336,17 @@ void implicit_gemm_conv_2D_gpu( // Encode and dispatch kernel auto& compute_encoder = d.get_command_encoder(s.index); - auto kernel = d.get_kernel(kname.str()); + auto kernel = get_steel_conv_kernel( + d, + kname.str(), + out, + bm, + bn, + bk, + wm, + wn, + n_channel_specialization, + small_filter); compute_encoder->setComputePipelineState(kernel); // Deduce grid launch dimensions @@ -488,7 +499,8 @@ void implicit_gemm_conv_2D_general_gpu( // Encode and dispatch kernel auto& compute_encoder = d.get_command_encoder(s.index); - auto kernel = d.get_kernel(kname.str()); + auto kernel = + get_steel_conv_general_kernel(d, kname.str(), out, bm, bn, bk, wm, wn); compute_encoder->setComputePipelineState(kernel); // Deduce grid launch dimensions diff --git a/mlx/backend/metal/jit/includes.h b/mlx/backend/metal/jit/includes.h index 7bfcbbedd..4bb4ac38d 100644 --- a/mlx/backend/metal/jit/includes.h +++ b/mlx/backend/metal/jit/includes.h @@ -23,4 +23,12 @@ const char* softmax(); const char* sort(); const char* reduce(); +const char* gemm(); 
+const char* steel_gemm_fused(); +const char* steel_gemm_masked(); +const char* steel_gemm_splitk(); +const char* conv(); +const char* steel_conv(); +const char* steel_conv_general(); + } // namespace mlx::core::metal diff --git a/mlx/backend/metal/jit/steel_conv.h b/mlx/backend/metal/jit/steel_conv.h new file mode 100644 index 000000000..44b8f95b9 --- /dev/null +++ b/mlx/backend/metal/jit/steel_conv.h @@ -0,0 +1,32 @@ +// Copyright © 2024 Apple Inc. + +constexpr std::string_view steel_conv_kernels = R"( +template [[host_name("{name}")]] [[kernel]] void +implicit_gemm_conv_2d<{itype}, {bm}, {bn}, {bk}, {wm}, {wn}, {n_channels}, {small_filter}>( + const device {itype}* A [[buffer(0)]], + const device {itype}* B [[buffer(1)]], + device {itype}* C [[buffer(2)]], + const constant MLXConvParams<2>* params [[buffer(3)]], + const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]); +)"; + +constexpr std::string_view steel_conv_general_kernels = R"( +template [[host_name("{name}")]] [[kernel]] void + implicit_gemm_conv_2d_general<{itype}, {bm}, {bn}, {bk}, {wm}, {wn}>( + const device {itype}* A [[buffer(0)]], + const device {itype}* B [[buffer(1)]], + device {itype}* C [[buffer(2)]], + const constant MLXConvParams<2>* params [[buffer(3)]], + const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]], + const constant Conv2DGeneralJumpParams* jump_params [[buffer(5)]], + const constant Conv2DGeneralBaseInfo* base_h [[buffer(6)]], + const constant Conv2DGeneralBaseInfo* base_w [[buffer(7)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]); +)"; diff --git a/mlx/backend/metal/jit/steel_gemm.h b/mlx/backend/metal/jit/steel_gemm.h new file mode 100644 index 000000000..d1a2378bf --- /dev/null +++ b/mlx/backend/metal/jit/steel_gemm.h @@ -0,0 +1,106 @@ +// Copyright © 2024 Apple Inc. 
+ +constexpr std::string_view steel_gemm_fused_kernels = R"( +template [[host_name("{name}")]] +[[kernel]] void gemm<{itype}, {bm}, {bn}, {bk}, {wm}, {wn}, {trans_a}, {trans_b}, float>( + const device {itype} *A [[buffer(0)]], + const device {itype} *B [[buffer(1)]], + const device {itype} *C [[buffer(2), function_constant(use_out_source)]], + device {itype} *D [[buffer(3)]], + const constant GEMMParams* params [[buffer(4)]], + const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]], + const constant int* batch_shape [[buffer(6)]], + const constant size_t* batch_strides [[buffer(7)]], + const constant uint32_t* lhs_indices [[buffer(10), function_constant(do_gather)]], + const constant uint32_t* rhs_indices [[buffer(11), function_constant(do_gather)]], + const constant uint32_t* C_indices [[buffer(12), function_constant(gather_bias)]], + const constant int* operand_shape [[buffer(13), function_constant(do_gather)]], + const constant size_t* operand_strides [[buffer(14), function_constant(do_gather)]], + const constant packed_int3& operand_batch_ndim [[buffer(15), function_constant(do_gather)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]); +)"; + +constexpr std::string_view steel_gemm_masked_kernels = R"( +template [[host_name("{name}")]] [[kernel]] void +block_masked_gemm< + {itype}, + {outmasktype}, + {opmasktype}, + {bm}, + {bn}, + {bk}, + {wm}, + {wn}, + {trans_a}, + {trans_b}, + {mn_aligned}, + {k_aligned}>( + const device {itype}* A [[buffer(0)]], + const device {itype}* B [[buffer(1)]], + device {itype}* D [[buffer(3)]], + const constant GEMMParams* params [[buffer(4)]], + const constant int* batch_shape [[buffer(6)]], + const constant size_t* batch_strides [[buffer(7)]], + const device {outmasktype}* out_mask [[buffer(10)]], + const device {opmasktype}* lhs_mask [[buffer(11)]], + const device {opmasktype}* rhs_mask [[buffer(12)]], + const constant int* mask_strides [[buffer(13)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]); +)"; + +constexpr std::string_view steel_gemm_splitk_kernels = R"( +template [[host_name("{name}")]] [[kernel]] void +gemm_splitk< + {itype}, + {otype}, + {bm}, + {bn}, + {bk}, + {wm}, + {wn}, + {trans_a}, + {trans_b}, + {mn_aligned}, + {k_aligned}>( + const device {itype}* A [[buffer(0)]], + const device {itype}* B [[buffer(1)]], + device {otype}* C [[buffer(2)]], + const constant GEMMSpiltKParams* params [[buffer(3)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]); +)"; + +constexpr std::string_view steel_gemm_splitk_accum_kernels = R"( +template [[host_name("{name}")]] [[kernel]] void +gemm_splitk_accum<{atype}, {otype}>( + const device {atype}* C_split [[buffer(0)]], + device {otype}* D [[buffer(1)]], + const constant int& k_partitions [[buffer(2)]], + const constant int& partition_stride [[buffer(3)]], + const constant int& ldd [[buffer(4)]], + uint2 gid [[thread_position_in_grid]]); +)"; + +constexpr std::string_view steel_gemm_splitk_accum_axbpy_kernels = R"( +template [[host_name("{name}")]] [[kernel]] void +gemm_splitk_accum_axpby<{atype}, {otype}>( + 
    const device {atype}* C_split [[buffer(0)]],
+    device {otype}* D [[buffer(1)]],
+    const constant int& k_partitions [[buffer(2)]],
+    const constant int& partition_stride [[buffer(3)]],
+    const constant int& ldd [[buffer(4)]],
+    const device {otype}* C [[buffer(5)]],
+    const constant int& ldc [[buffer(6)]],
+    const constant int& fdc [[buffer(7)]],
+    const constant float& alpha [[buffer(8)]],
+    const constant float& beta [[buffer(9)]],
+    uint2 gid [[thread_position_in_grid]]);
+)";
diff --git a/mlx/backend/metal/jit_kernels.cpp b/mlx/backend/metal/jit_kernels.cpp
index 5d46adbc5..813b5c392 100644
--- a/mlx/backend/metal/jit_kernels.cpp
+++ b/mlx/backend/metal/jit_kernels.cpp
@@ -1,5 +1,4 @@
 // Copyright © 2024 Apple Inc.
-
 #include <fmt/format.h>
 
 #include "mlx/backend/common/compiled.h"
@@ -12,11 +11,15 @@
 #include "mlx/backend/metal/jit/scan.h"
 #include "mlx/backend/metal/jit/softmax.h"
 #include "mlx/backend/metal/jit/sort.h"
+#include "mlx/backend/metal/jit/steel_conv.h"
+#include "mlx/backend/metal/jit/steel_gemm.h"
 #include "mlx/backend/metal/jit/ternary.h"
 #include "mlx/backend/metal/jit/unary.h"
 #include "mlx/backend/metal/kernels.h"
 #include "mlx/backend/metal/utils.h"
 
+using namespace fmt::literals;
+
 namespace mlx::core {
 
 std::string op_name(const array& arr) {
@@ -276,4 +279,208 @@ MTL::ComputePipelineState* get_reduce_kernel(
   return d.get_kernel(kernel_name, lib);
 }
 
+MTL::ComputePipelineState* get_steel_gemm_fused_kernel(
+    metal::Device& d,
+    const std::string& kernel_name,
+    const std::string& hash_name,
+    const metal::MTLFCList& func_consts,
+    const array& out,
+    bool transpose_a,
+    bool transpose_b,
+    int bm,
+    int bn,
+    int bk,
+    int wm,
+    int wn) {
+  const auto& lib_name = kernel_name;
+  auto lib = d.get_library(lib_name);
+  if (lib == nullptr) {
+    std::ostringstream kernel_source;
+    kernel_source << metal::utils() << metal::gemm()
+                  << metal::steel_gemm_fused()
+                  << fmt::format(
+                         steel_gemm_fused_kernels,
+                         "name"_a = lib_name,
+                         "itype"_a = get_type_string(out.dtype()),
+                         "bm"_a = bm,
+                         "bn"_a = bn,
+                         "bk"_a = bk,
+                         "wm"_a = wm,
+                         "wn"_a = wn,
+                         "trans_a"_a = transpose_a,
+                         "trans_b"_a = transpose_b);
+    lib = d.get_library(lib_name, kernel_source.str());
+  }
+  return d.get_kernel(kernel_name, lib, hash_name, func_consts);
+}
+
+MTL::ComputePipelineState* get_steel_gemm_splitk_kernel(
+    metal::Device& d,
+    const std::string& kernel_name,
+    const array& in,
+    const array& out,
+    bool transpose_a,
+    bool transpose_b,
+    int bm,
+    int bn,
+    int bk,
+    int wm,
+    int wn,
+    bool mn_aligned,
+    bool k_aligned) {
+  const auto& lib_name = kernel_name;
+  auto lib = d.get_library(lib_name);
+  if (lib == nullptr) {
+    std::ostringstream kernel_source;
+    kernel_source << metal::utils() << metal::gemm()
+                  << metal::steel_gemm_splitk()
+                  << fmt::format(
+                         steel_gemm_splitk_kernels,
+                         "name"_a = lib_name,
+                         "itype"_a = get_type_string(in.dtype()),
+                         "otype"_a = get_type_string(out.dtype()),
+                         "bm"_a = bm,
+                         "bn"_a = bn,
+                         "bk"_a = bk,
+                         "wm"_a = wm,
+                         "wn"_a = wn,
+                         "trans_a"_a = transpose_a,
+                         "trans_b"_a = transpose_b,
+                         "mn_aligned"_a = mn_aligned,
+                         "k_aligned"_a = k_aligned);
+    lib = d.get_library(lib_name, kernel_source.str());
+  }
+  return d.get_kernel(kernel_name, lib);
+}
+
+MTL::ComputePipelineState* get_steel_gemm_splitk_accum_kernel(
+    metal::Device& d,
+    const std::string& kernel_name,
+    const array& in,
+    const array& out,
+    bool axbpy) {
+  const auto& lib_name = kernel_name;
+  auto lib = d.get_library(lib_name);
+  if (lib == nullptr) {
+    std::ostringstream kernel_source;
+    kernel_source << metal::utils() << metal::gemm()
+                  << metal::steel_gemm_splitk()
+                  << fmt::format(
+                         axbpy ? steel_gemm_splitk_accum_axbpy_kernels
+                               : steel_gemm_splitk_accum_kernels,
+                         "name"_a = lib_name,
+                         "atype"_a = get_type_string(in.dtype()),
+                         "otype"_a = get_type_string(out.dtype()));
+    lib = d.get_library(lib_name, kernel_source.str());
+  }
+  return d.get_kernel(kernel_name, lib);
+}
+
+MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
+    metal::Device& d,
+    const std::string& kernel_name,
+    const array& out,
+    const std::optional<array>& mask_out,
+    const std::optional<array>& mask_op,
+    bool transpose_a,
+    bool transpose_b,
+    int bm,
+    int bn,
+    int bk,
+    int wm,
+    int wn,
+    bool mn_aligned,
+    bool k_aligned) {
+  const auto& lib_name = kernel_name;
+  auto lib = d.get_library(lib_name);
+  if (lib == nullptr) {
+    std::ostringstream kernel_source;
+    auto out_mask_type = mask_out.has_value()
+        ? get_type_string((*mask_out).dtype())
+        : "nomask_t";
+    auto op_mask_type =
+        mask_op.has_value() ? get_type_string((*mask_op).dtype()) : "nomask_t";
+    kernel_source << metal::utils() << metal::gemm()
+                  << metal::steel_gemm_masked()
+                  << fmt::format(
+                         steel_gemm_masked_kernels,
+                         "name"_a = lib_name,
+                         "itype"_a = get_type_string(out.dtype()),
+                         "outmasktype"_a = out_mask_type,
+                         "opmasktype"_a = op_mask_type,
+                         "bm"_a = bm,
+                         "bn"_a = bn,
+                         "bk"_a = bk,
+                         "wm"_a = wm,
+                         "wn"_a = wn,
+                         "trans_a"_a = transpose_a,
+                         "trans_b"_a = transpose_b,
+                         "mn_aligned"_a = mn_aligned,
+                         "k_aligned"_a = k_aligned);
+    lib = d.get_library(lib_name, kernel_source.str());
+  }
+  return d.get_kernel(kernel_name, lib);
+}
+
+MTL::ComputePipelineState* get_steel_conv_kernel(
+    metal::Device& d,
+    const std::string& kernel_name,
+    const array& out,
+    int bm,
+    int bn,
+    int bk,
+    int wm,
+    int wn,
+    int n_channel_specialization,
+    bool small_filter) {
+  const auto& lib_name = kernel_name;
+  auto lib = d.get_library(lib_name);
+  if (lib == nullptr) {
+    std::ostringstream kernel_source;
+    kernel_source << metal::utils() << metal::conv() << metal::steel_conv()
+                  << fmt::format(
+                         steel_conv_kernels,
+                         "name"_a = lib_name,
+                         "itype"_a = get_type_string(out.dtype()),
+                         "bm"_a = bm,
+                         "bn"_a = bn,
+                         "bk"_a = bk,
+                         "wm"_a = wm,
+                         "wn"_a = wn,
+                         "n_channels"_a = n_channel_specialization,
+                         "small_filter"_a = small_filter);
+    lib = d.get_library(lib_name, kernel_source.str());
+  }
+  return d.get_kernel(kernel_name, lib);
+}
+
+MTL::ComputePipelineState* get_steel_conv_general_kernel(
+    metal::Device& d,
+    const std::string& kernel_name,
+    const array& out,
+    int bm,
+    int bn,
+    int bk,
+    int wm,
+    int wn) {
+  const auto& lib_name = kernel_name;
+  auto lib = d.get_library(lib_name);
+  if (lib == nullptr) {
+    std::ostringstream kernel_source;
+    kernel_source << metal::utils() << metal::conv()
+                  << metal::steel_conv_general()
+                  << fmt::format(
+                         steel_conv_general_kernels,
+                         "name"_a = lib_name,
+                         "itype"_a = get_type_string(out.dtype()),
+                         "bm"_a = bm,
+                         "bn"_a = bn,
+                         "bk"_a = bk,
+                         "wm"_a = wm,
+                         "wn"_a = wn);
+    lib = d.get_library(lib_name, kernel_source.str());
+  }
+  return d.get_kernel(kernel_name, lib);
+}
+
 } // namespace mlx::core
diff --git a/mlx/backend/metal/kernels.h b/mlx/backend/metal/kernels.h
index 82ce8e5ea..e1df0521b 100644
--- a/mlx/backend/metal/kernels.h
+++ b/mlx/backend/metal/kernels.h
@@ -79,4 +79,78 @@ MTL::ComputePipelineState* get_reduce_kernel(
     const array& in,
     const array& out);
 
+MTL::ComputePipelineState* get_steel_gemm_fused_kernel(
+    metal::Device& d,
+    const std::string& kernel_name,
+    const std::string& hash_name,
+    const metal::MTLFCList& func_consts,
+    const array& out,
+    bool transpose_a,
+    bool transpose_b,
+    int bm,
+    int bn,
+    int bk,
+    int wm,
+    int wn);
+
+MTL::ComputePipelineState* get_steel_gemm_splitk_kernel(
+    metal::Device& d,
+    const std::string& kernel_name,
+    const array& in,
+    const array& out,
+    bool transpose_a,
+    bool transpose_b,
+    int bm,
+    int bn,
+    int bk,
+    int wm,
+    int wn,
+    bool mn_aligned,
+    bool k_aligned);
+
+MTL::ComputePipelineState* get_steel_gemm_splitk_accum_kernel(
+    metal::Device& d,
+    const std::string& kernel_name,
+    const array& in,
+    const array& out,
+    bool axbpy);
+
+MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
+    metal::Device& d,
+    const std::string& kernel_name,
+    const array& out,
+    const std::optional<array>& mask_out,
+    const std::optional<array>& mask_op,
+    bool transpose_a,
+    bool transpose_b,
+    int bm,
+    int bn,
+    int bk,
+    int wm,
+    int wn,
+    bool mn_aligned,
+    bool k_aligned);
+
+MTL::ComputePipelineState* get_steel_conv_kernel(
+    metal::Device& d,
+    const std::string& kernel_name,
+    const array& out,
+    int bm,
+    int bn,
+    int bk,
+    int wm,
+    int wn,
+    int n_channel_specialization,
+    bool small_filter);
+
+MTL::ComputePipelineState* get_steel_conv_general_kernel(
+    metal::Device& d,
+    const std::string& kernel_name,
+    const array& out,
+    int bm,
+    int bn,
+    int bk,
+    int wm,
+    int wn);
+
 } // namespace mlx::core
diff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt
index 841c4800d..9d99d39a5 100644
--- a/mlx/backend/metal/kernels/CMakeLists.txt
+++ b/mlx/backend/metal/kernels/CMakeLists.txt
@@ -1,14 +1,13 @@
 set(
   HEADERS
-  ${CMAKE_CURRENT_SOURCE_DIR}/atomic.h
-  ${CMAKE_CURRENT_SOURCE_DIR}/bf16.h
-  ${CMAKE_CURRENT_SOURCE_DIR}/bf16_math.h
-  ${CMAKE_CURRENT_SOURCE_DIR}/complex.h
-  ${CMAKE_CURRENT_SOURCE_DIR}/defines.h
-  ${CMAKE_CURRENT_SOURCE_DIR}/utils.h
+  bf16.h
+  bf16_math.h
+  complex.h
+  defines.h
+  utils.h
+  steel/conv/params.h
 )
-
 set(
   KERNELS
   "arg_reduce"
@@ -41,6 +40,7 @@ set(
 set(
   HEADERS
   ${HEADERS}
+  atomic.h
   arange.h
   unary_ops.h
   unary.h
@@ -89,14 +89,40 @@ foreach(KERNEL ${KERNELS})
   set(KERNEL_AIR ${KERNEL}.air ${KERNEL_AIR})
 endforeach()
 
-file(GLOB_RECURSE STEEL_KERNELS ${CMAKE_CURRENT_SOURCE_DIR}/steel/*.metal)
-file(GLOB_RECURSE STEEL_HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/steel/*.h)
-
-foreach(KERNEL ${STEEL_KERNELS})
-  cmake_path(GET KERNEL STEM TARGET)
-  build_kernel_base(${TARGET} ${KERNEL} "${STEEL_HEADERS}")
-  set(KERNEL_AIR ${TARGET}.air ${KERNEL_AIR})
-endforeach()
+if (NOT MLX_METAL_JIT)
+  set(
+    STEEL_KERNELS
+    ${CMAKE_CURRENT_SOURCE_DIR}/steel/conv/kernels/steel_conv.metal
+    ${CMAKE_CURRENT_SOURCE_DIR}/steel/conv/kernels/steel_conv_general.metal
+    ${CMAKE_CURRENT_SOURCE_DIR}/steel/gemm/kernels/steel_gemm_fused.metal
+    ${CMAKE_CURRENT_SOURCE_DIR}/steel/gemm/kernels/steel_gemm_masked.metal
+    ${CMAKE_CURRENT_SOURCE_DIR}/steel/gemm/kernels/steel_gemm_splitk.metal
+  )
+  set(
+    STEEL_HEADERS
+    steel/defines.h
+    steel/utils.h
+    steel/conv/conv.h
+    steel/conv/loader.h
+    steel/conv/loaders/loader_channel_l.h
+    steel/conv/loaders/loader_channel_n.h
+    steel/conv/loaders/loader_general.h
+    steel/conv/kernels/steel_conv.h
+    steel/conv/kernels/steel_conv_general.h
+    steel/gemm/gemm.h
+    steel/gemm/mma.h
+    steel/gemm/loader.h
+    steel/gemm/transforms.h
+    steel/gemm/kernels/steel_gemm_fused.h
+    steel/gemm/kernels/steel_gemm_masked.h
+    steel/gemm/kernels/steel_gemm_splitk.h
+  )
+  foreach(KERNEL ${STEEL_KERNELS})
+    cmake_path(GET KERNEL STEM TARGET)
+    build_kernel_base(${TARGET} ${KERNEL} "${STEEL_HEADERS}")
+ set(KERNEL_AIR ${TARGET}.air ${KERNEL_AIR}) + endforeach() +endif() add_custom_command( OUTPUT ${MLX_METAL_PATH}/mlx.metallib diff --git a/mlx/backend/metal/kernels/arg_reduce.metal b/mlx/backend/metal/kernels/arg_reduce.metal index 6adeb12ab..61df1273e 100644 --- a/mlx/backend/metal/kernels/arg_reduce.metal +++ b/mlx/backend/metal/kernels/arg_reduce.metal @@ -1,6 +1,5 @@ // Copyright © 2023 Apple Inc. -#include #include #include "mlx/backend/metal/kernels/utils.h" @@ -194,4 +193,4 @@ instantiate_arg_reduce(int32, int32_t) instantiate_arg_reduce(int64, int64_t) instantiate_arg_reduce(float16, half) instantiate_arg_reduce(float32, float) -instantiate_arg_reduce(bfloat16, bfloat16_t) // clang-format on \ No newline at end of file +instantiate_arg_reduce(bfloat16, bfloat16_t) // clang-format on diff --git a/mlx/backend/metal/kernels/conv.metal b/mlx/backend/metal/kernels/conv.metal index 92e91505d..e67acd93a 100644 --- a/mlx/backend/metal/kernels/conv.metal +++ b/mlx/backend/metal/kernels/conv.metal @@ -650,4 +650,4 @@ winograd_conv_2d_output_transform( // clang-format off instantiate_winograd_conv_2d(float32, float); -instantiate_winograd_conv_2d(float16, half); // clang-format on \ No newline at end of file +instantiate_winograd_conv_2d(float16, half); // clang-format on diff --git a/mlx/backend/metal/kernels/steel/conv/conv.h b/mlx/backend/metal/kernels/steel/conv/conv.h index e5065cea2..d2e718f2e 100644 --- a/mlx/backend/metal/kernels/steel/conv/conv.h +++ b/mlx/backend/metal/kernels/steel/conv/conv.h @@ -2,10 +2,12 @@ #pragma once +#include "mlx/backend/metal/kernels/steel/defines.h" #include "mlx/backend/metal/kernels/steel/utils.h" #include "mlx/backend/metal/kernels/steel/conv/loader.h" #include "mlx/backend/metal/kernels/steel/conv/params.h" +#include "mlx/backend/metal/kernels/steel/gemm/mma.h" using namespace metal; -using namespace mlx::steel; \ No newline at end of file +using namespace mlx::steel; diff --git a/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h b/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h new file mode 100644 index 000000000..6f822c1dd --- /dev/null +++ b/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h @@ -0,0 +1,176 @@ +// Copyright © 2024 Apple Inc. + +#include + +using namespace metal; + +template < + typename T, + int BM, + int BN, + int BK, + int WM, + int WN, + int N_CHANNELS = 0, + bool SMALL_FILTER = false> +[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void +implicit_gemm_conv_2d( + const device T* A [[buffer(0)]], + const device T* B [[buffer(1)]], + device T* C [[buffer(2)]], + const constant MLXConvParams<2>* params [[buffer(3)]], + const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + using namespace mlx::steel; + + (void)lid; + + constexpr bool transpose_a = false; + constexpr bool transpose_b = true; + constexpr short tgp_padding_a = 16 / sizeof(T); + constexpr short tgp_padding_b = 16 / sizeof(T); + + constexpr short shape_a_cols = (transpose_a ? BM : BK) + tgp_padding_a; + constexpr short shape_b_cols = (transpose_b ? BK : BN) + tgp_padding_b; + constexpr short shape_a_rows = (transpose_a ? BK : BM); + constexpr short shape_b_rows = (transpose_b ? 
BN : BK); + constexpr short tgp_mem_size_a = shape_a_cols * shape_a_rows; + constexpr short tgp_mem_size_b = shape_b_cols * shape_b_rows; + + constexpr short tgp_size = WM * WN * 32; + + // Input loader + + using loader_a_t = typename metal::conditional_t< + // Check for small channel specialization + N_CHANNELS != 0 && N_CHANNELS <= 4, + + // Go to small channel specialization + Conv2DInputBlockLoaderSmallChannels< + T, + BM, + BN, + BK, + tgp_size, + N_CHANNELS, + tgp_padding_a>, + + // Else go to general loader + typename metal::conditional_t< + // Check if filter size is small enough + SMALL_FILTER, + + // Go to small filter specialization + Conv2DInputBlockLoaderSmallFilter< + T, + BM, + BN, + BK, + tgp_size, + tgp_padding_a>, + + // Else go to large filter generalization + Conv2DInputBlockLoaderLargeFilter< + T, + BM, + BN, + BK, + tgp_size, + tgp_padding_a>>>; + + // Weight loader + using loader_b_t = typename metal::conditional_t< + // Check for small channel specialization + N_CHANNELS != 0 && N_CHANNELS <= 4, + + // Go to small channel specialization + Conv2DWeightBlockLoaderSmallChannels< + T, + BM, + BN, + BK, + tgp_size, + N_CHANNELS, + tgp_padding_b>, + + // Else go to general loader + Conv2DWeightBlockLoader>; + + using mma_t = BlockMMA< + T, + T, + BM, + BN, + BK, + WM, + WN, + transpose_a, + transpose_b, + shape_a_cols, + shape_b_cols>; + + threadgroup T As[tgp_mem_size_a]; + threadgroup T Bs[tgp_mem_size_b]; + + const int tid_y = ((tid.y) << gemm_params->swizzle_log) + + ((tid.x) & ((1 << gemm_params->swizzle_log) - 1)); + const int tid_x = (tid.x) >> gemm_params->swizzle_log; + + if (gemm_params->tiles_n <= tid_x || gemm_params->tiles_m <= tid_y) { + return; + } + + const int c_row = tid_y * BM; + const int c_col = tid_x * BN; + const int K = gemm_params->K; + const int N = gemm_params->N; + const int C_per_group = params->C / params->groups; + + // Groups + A += tid.z * C_per_group; + B += tid.z * N * K; + C += tid.z * N; + + B += c_col * K; + C += c_row * (N * params->groups) + c_col; + + const int2 offsets_a(0, c_row); + const int2 offsets_b(0, c_col); + + // Prepare threadgroup loading operations + loader_a_t loader_a( + A, As, offsets_a, params, gemm_params, simd_gid, simd_lid); + loader_b_t loader_b( + B, Bs, offsets_b, params, gemm_params, simd_gid, simd_lid); + + // Prepare threadgroup mma operation + mma_t mma_op(simd_gid, simd_lid); + + int gemm_k_iterations = gemm_params->gemm_k_iterations; + for (int k = 0; k < gemm_k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + + threadgroup_barrier(mem_flags::mem_none); + + // Store results to device memory + short tgp_bm = min(BM, gemm_params->M - c_row); + short tgp_bn = min(BN, gemm_params->N - c_col); + const int ldc = N * params->groups; + mma_op.store_result_safe(C, ldc, short2(tgp_bn, tgp_bm)); +} diff --git a/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal b/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal index ee5bcb285..1bc99ffb0 100644 --- a/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal +++ b/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal @@ -2,184 +2,13 @@ #include +// clang-format off #include "mlx/backend/metal/kernels/steel/gemm/mma.h" 
#include "mlx/backend/metal/kernels/bf16.h" #include "mlx/backend/metal/kernels/steel/conv/conv.h" #include "mlx/backend/metal/kernels/steel/conv/params.h" - -using namespace metal; - -template < - typename T, - int BM, - int BN, - int BK, - int WM, - int WN, - int N_CHANNELS = 0, - bool SMALL_FILTER = false> -[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void -implicit_gemm_conv_2d( - const device T* A [[buffer(0)]], - const device T* B [[buffer(1)]], - device T* C [[buffer(2)]], - const constant MLXConvParams<2>* params [[buffer(3)]], - const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - using namespace mlx::steel; - - (void)lid; - - constexpr bool transpose_a = false; - constexpr bool transpose_b = true; - constexpr short tgp_padding_a = 16 / sizeof(T); - constexpr short tgp_padding_b = 16 / sizeof(T); - - constexpr short shape_a_cols = (transpose_a ? BM : BK) + tgp_padding_a; - constexpr short shape_b_cols = (transpose_b ? BK : BN) + tgp_padding_b; - constexpr short shape_a_rows = (transpose_a ? BK : BM); - constexpr short shape_b_rows = (transpose_b ? BN : BK); - constexpr short tgp_mem_size_a = shape_a_cols * shape_a_rows; - constexpr short tgp_mem_size_b = shape_b_cols * shape_b_rows; - - constexpr short tgp_size = WM * WN * 32; - - // Input loader - - using loader_a_t = typename metal::conditional_t< - // Check for small channel specialization - N_CHANNELS != 0 && N_CHANNELS <= 4, - - // Go to small channel specialization - Conv2DInputBlockLoaderSmallChannels< - T, - BM, - BN, - BK, - tgp_size, - N_CHANNELS, - tgp_padding_a>, - - // Else go to general loader - typename metal::conditional_t< - // Check if filter size is small enough - SMALL_FILTER, - - // Go to small filter specialization - Conv2DInputBlockLoaderSmallFilter< - T, - BM, - BN, - BK, - tgp_size, - tgp_padding_a>, - - // Else go to large filter generalization - Conv2DInputBlockLoaderLargeFilter< - T, - BM, - BN, - BK, - tgp_size, - tgp_padding_a>>>; - - // Weight loader - using loader_b_t = typename metal::conditional_t< - // Check for small channel specialization - N_CHANNELS != 0 && N_CHANNELS <= 4, - - // Go to small channel specialization - Conv2DWeightBlockLoaderSmallChannels< - T, - BM, - BN, - BK, - tgp_size, - N_CHANNELS, - tgp_padding_b>, - - // Else go to general loader - Conv2DWeightBlockLoader>; - - using mma_t = BlockMMA< - T, - T, - BM, - BN, - BK, - WM, - WN, - transpose_a, - transpose_b, - shape_a_cols, - shape_b_cols>; - - threadgroup T As[tgp_mem_size_a]; - threadgroup T Bs[tgp_mem_size_b]; - - const int tid_y = ((tid.y) << gemm_params->swizzle_log) + - ((tid.x) & ((1 << gemm_params->swizzle_log) - 1)); - const int tid_x = (tid.x) >> gemm_params->swizzle_log; - - if (gemm_params->tiles_n <= tid_x || gemm_params->tiles_m <= tid_y) { - return; - } - - const int c_row = tid_y * BM; - const int c_col = tid_x * BN; - const int K = gemm_params->K; - const int N = gemm_params->N; - const int C_per_group = params->C / params->groups; - - // Groups - A += tid.z * C_per_group; - B += tid.z * N * K; - C += tid.z * N; - - B += c_col * K; - C += c_row * (N * params->groups) + c_col; - - const int2 offsets_a(0, c_row); - const int2 offsets_b(0, c_col); - - // Prepare threadgroup loading operations - loader_a_t loader_a( - A, As, offsets_a, params, gemm_params, simd_gid, simd_lid); - loader_b_t 
loader_b( - B, Bs, offsets_b, params, gemm_params, simd_gid, simd_lid); - - // Prepare threadgroup mma operation - mma_t mma_op(simd_gid, simd_lid); - - int gemm_k_iterations = gemm_params->gemm_k_iterations; - for (int k = 0; k < gemm_k_iterations; k++) { - threadgroup_barrier(mem_flags::mem_threadgroup); - // Load elements into threadgroup - loader_a.load_unsafe(); - loader_b.load_unsafe(); - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Multiply and accumulate threadgroup elements - mma_op.mma(As, Bs); - - // Prepare for next iteration - loader_a.next(); - loader_b.next(); - } - - threadgroup_barrier(mem_flags::mem_none); - - // Store results to device memory - short tgp_bm = min(BM, gemm_params->M - c_row); - short tgp_bn = min(BN, gemm_params->N - c_col); - const int ldc = N * params->groups; - mma_op.store_result_safe(C, ldc, short2(tgp_bn, tgp_bm)); -} +#include "mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h" #define instantiate_implicit_conv_2d( \ name, \ @@ -207,25 +36,22 @@ implicit_gemm_conv_2d( uint simd_gid [[simdgroup_index_in_threadgroup]], \ uint simd_lid [[thread_index_in_simdgroup]]); -// clang-format off #define instantiate_implicit_2d_filter(name, itype, bm, bn, bk, wm, wn) \ instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, l, 0, s, true) \ instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, l, 0, l, false) \ instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 1, 1, l, false) \ instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 2, 2, l, false) \ instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 3, 3, l, false) \ - instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 4, 4, l, false) // clang-format on + instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 4, 4, l, false) -// clang-format off #define instantiate_implicit_2d_blocks(name, itype) \ instantiate_implicit_2d_filter(name, itype, 32, 8, 16, 4, 1) \ instantiate_implicit_2d_filter(name, itype, 64, 8, 16, 4, 1) \ instantiate_implicit_2d_filter(name, itype, 32, 32, 16, 2, 2) \ instantiate_implicit_2d_filter(name, itype, 32, 64, 16, 2, 2) \ instantiate_implicit_2d_filter(name, itype, 64, 32, 16, 2, 2) \ - instantiate_implicit_2d_filter(name, itype, 64, 64, 16, 2, 2) // clang-format on + instantiate_implicit_2d_filter(name, itype, 64, 64, 16, 2, 2) -// clang-format off instantiate_implicit_2d_blocks(float32, float); instantiate_implicit_2d_blocks(float16, half); -instantiate_implicit_2d_blocks(bfloat16, bfloat16_t); // clang-format on \ No newline at end of file +instantiate_implicit_2d_blocks(bfloat16, bfloat16_t); // clang-format on diff --git a/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h b/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h new file mode 100644 index 000000000..f5430d590 --- /dev/null +++ b/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h @@ -0,0 +1,188 @@ +// Copyright © 2024 Apple Inc. 
+ +#include "mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h" + +template < + typename T, + int BM, + int BN, + int BK, + int WM, + int WN, + typename AccumType = float, + typename Epilogue = TransformNone> +[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void +implicit_gemm_conv_2d_general( + const device T* A [[buffer(0)]], + const device T* B [[buffer(1)]], + device T* C [[buffer(2)]], + const constant MLXConvParams<2>* params [[buffer(3)]], + const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]], + const constant Conv2DGeneralJumpParams* jump_params [[buffer(5)]], + const constant Conv2DGeneralBaseInfo* base_h [[buffer(6)]], + const constant Conv2DGeneralBaseInfo* base_w [[buffer(7)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + (void)lid; + + constexpr bool transpose_a = false; + constexpr bool transpose_b = true; + constexpr short tgp_padding_a = 16 / sizeof(T); + constexpr short tgp_padding_b = 16 / sizeof(T); + + constexpr short shape_a_cols = (transpose_a ? BM : BK) + tgp_padding_a; + constexpr short shape_b_cols = (transpose_b ? BK : BN) + tgp_padding_b; + constexpr short shape_a_rows = (transpose_a ? BK : BM); + constexpr short shape_b_rows = (transpose_b ? BN : BK); + constexpr short tgp_mem_size_a = shape_a_cols * shape_a_rows; + constexpr short tgp_mem_size_b = shape_b_cols * shape_b_rows; + + constexpr short tgp_size = WM * WN * 32; + + // Input loader + using loader_a_t = + Conv2DInputBlockLoaderGeneral; + + // Weight loader + using loader_b_t = + Conv2DWeightBlockLoaderGeneral; + + using mma_t = BlockMMA< + T, + T, + BM, + BN, + BK, + WM, + WN, + transpose_a, + transpose_b, + shape_a_cols, + shape_b_cols>; + + threadgroup T As[tgp_mem_size_a]; + threadgroup T Bs[tgp_mem_size_b]; + + const int tid_y = ((tid.y) << gemm_params->swizzle_log) + + ((tid.x) & ((1 << gemm_params->swizzle_log) - 1)); + const int tid_x = (tid.x) >> gemm_params->swizzle_log; + + if (gemm_params->tiles_n <= tid_x || gemm_params->tiles_m <= tid_y) { + return; + } + + const int tid_z = tid.z; + + const int base_oh = tid_z / jump_params->f_out_jump_w; + const int base_ow = tid_z % jump_params->f_out_jump_w; + + const int base_wh = base_h[base_oh].weight_base; + const int base_ww = base_w[base_ow].weight_base; + + const int base_wh_size = base_h[base_oh].weight_size; + const int base_ww_size = base_w[base_ow].weight_size; + + const int c_row = tid_y * BM; + const int c_col = tid_x * BN; + const int K = gemm_params->K; + + B += c_col * K; + + const int4 offsets_a(0, c_row, base_oh, base_ow); + const int2 offsets_b(0, c_col); + + // Prepare threadgroup loading operations + loader_a_t loader_a( + A, + As, + offsets_a, + params, + jump_params, + base_wh, + base_ww, + simd_gid, + simd_lid); + loader_b_t loader_b( + B, + Bs, + offsets_b, + params, + jump_params, + base_wh, + base_ww, + simd_gid, + simd_lid); + + // Prepare threadgroup mma operation + mma_t mma_op(simd_gid, simd_lid); + + int gemm_k_iterations = + base_wh_size * base_ww_size * gemm_params->gemm_k_iterations; + + for (int k = 0; k < gemm_k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + 
loader_a.next(); + loader_b.next(); + } + + threadgroup_barrier(mem_flags::mem_none); + + // Store results to device memory + { + // Adjust for simdgroup and thread locatio + int offset_m = c_row + mma_op.sm + mma_op.tm; + int offset_n = c_col + mma_op.sn + mma_op.tn; + C += offset_n; + + if (offset_n >= gemm_params->N) + return; + + short diff = gemm_params->N - offset_n; + + STEEL_PRAGMA_UNROLL + for (int i = 0; i < mma_t::TM; i++) { + int cm = offset_m + i * mma_t::TM_stride; + + int n = cm / jump_params->adj_out_hw; + int hw = cm % jump_params->adj_out_hw; + int oh = + (hw / jump_params->adj_out_w) * jump_params->f_out_jump_h + base_oh; + int ow = + (hw % jump_params->adj_out_w) * jump_params->f_out_jump_w + base_ow; + + if (n < params->N && oh < params->oS[0] && ow < params->oS[1]) { + int offset_cm = n * params->out_strides[0] + + oh * params->out_strides[1] + ow * params->out_strides[2]; + + STEEL_PRAGMA_UNROLL + for (int j = 0; j < mma_t::TN; j++) { + // Get accumulated result and associated offset in C + thread const auto& accum = + mma_op.results[i * mma_t::TN + j].thread_elements(); + int offset = offset_cm + (j * mma_t::TN_stride); + + // Apply epilogue and output C + if (j * mma_t::TN_stride < diff) { + C[offset] = Epilogue::apply(accum[0]); + } + + if (j * mma_t::TN_stride + 1 < diff) { + C[offset + 1] = Epilogue::apply(accum[1]); + } + } + } + } + } +} diff --git a/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal b/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal index e902918f9..099822c04 100644 --- a/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal +++ b/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal @@ -2,201 +2,18 @@ #include +// clang-format off #include "mlx/backend/metal/kernels/steel/gemm/mma.h" #include "mlx/backend/metal/kernels/bf16.h" #include "mlx/backend/metal/kernels/steel/conv/conv.h" -#include "mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h" #include "mlx/backend/metal/kernels/steel/conv/params.h" +#include "mlx/backend/metal/kernels/steel/utils.h" +#include "mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h" using namespace metal; using namespace mlx::steel; -template < - typename T, - int BM, - int BN, - int BK, - int WM, - int WN, - typename AccumType = float, - typename Epilogue = TransformNone> -[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void -implicit_gemm_conv_2d_general( - const device T* A [[buffer(0)]], - const device T* B [[buffer(1)]], - device T* C [[buffer(2)]], - const constant MLXConvParams<2>* params [[buffer(3)]], - const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]], - const constant Conv2DGeneralJumpParams* jump_params [[buffer(5)]], - const constant Conv2DGeneralBaseInfo* base_h [[buffer(6)]], - const constant Conv2DGeneralBaseInfo* base_w [[buffer(7)]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - (void)lid; - - constexpr bool transpose_a = false; - constexpr bool transpose_b = true; - constexpr short tgp_padding_a = 16 / sizeof(T); - constexpr short tgp_padding_b = 16 / sizeof(T); - - constexpr short shape_a_cols = (transpose_a ? BM : BK) + tgp_padding_a; - constexpr short shape_b_cols = (transpose_b ? BK : BN) + tgp_padding_b; - constexpr short shape_a_rows = (transpose_a ? BK : BM); - constexpr short shape_b_rows = (transpose_b ? 
BN : BK); - constexpr short tgp_mem_size_a = shape_a_cols * shape_a_rows; - constexpr short tgp_mem_size_b = shape_b_cols * shape_b_rows; - - constexpr short tgp_size = WM * WN * 32; - - // Input loader - using loader_a_t = - Conv2DInputBlockLoaderGeneral; - - // Weight loader - using loader_b_t = - Conv2DWeightBlockLoaderGeneral; - - using mma_t = BlockMMA< - T, - T, - BM, - BN, - BK, - WM, - WN, - transpose_a, - transpose_b, - shape_a_cols, - shape_b_cols>; - - threadgroup T As[tgp_mem_size_a]; - threadgroup T Bs[tgp_mem_size_b]; - - const int tid_y = ((tid.y) << gemm_params->swizzle_log) + - ((tid.x) & ((1 << gemm_params->swizzle_log) - 1)); - const int tid_x = (tid.x) >> gemm_params->swizzle_log; - - if (gemm_params->tiles_n <= tid_x || gemm_params->tiles_m <= tid_y) { - return; - } - - const int tid_z = tid.z; - - const int base_oh = tid_z / jump_params->f_out_jump_w; - const int base_ow = tid_z % jump_params->f_out_jump_w; - - const int base_wh = base_h[base_oh].weight_base; - const int base_ww = base_w[base_ow].weight_base; - - const int base_wh_size = base_h[base_oh].weight_size; - const int base_ww_size = base_w[base_ow].weight_size; - - const int c_row = tid_y * BM; - const int c_col = tid_x * BN; - const int K = gemm_params->K; - - B += c_col * K; - - const int4 offsets_a(0, c_row, base_oh, base_ow); - const int2 offsets_b(0, c_col); - - // Prepare threadgroup loading operations - loader_a_t loader_a( - A, - As, - offsets_a, - params, - jump_params, - base_wh, - base_ww, - simd_gid, - simd_lid); - loader_b_t loader_b( - B, - Bs, - offsets_b, - params, - jump_params, - base_wh, - base_ww, - simd_gid, - simd_lid); - - // Prepare threadgroup mma operation - mma_t mma_op(simd_gid, simd_lid); - - int gemm_k_iterations = - base_wh_size * base_ww_size * gemm_params->gemm_k_iterations; - - for (int k = 0; k < gemm_k_iterations; k++) { - threadgroup_barrier(mem_flags::mem_threadgroup); - // Load elements into threadgroup - loader_a.load_unsafe(); - loader_b.load_unsafe(); - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Multiply and accumulate threadgroup elements - mma_op.mma(As, Bs); - - // Prepare for next iteration - loader_a.next(); - loader_b.next(); - } - - threadgroup_barrier(mem_flags::mem_none); - - // Store results to device memory - { - // Adjust for simdgroup and thread locatio - int offset_m = c_row + mma_op.sm + mma_op.tm; - int offset_n = c_col + mma_op.sn + mma_op.tn; - C += offset_n; - - if (offset_n >= gemm_params->N) - return; - - short diff = gemm_params->N - offset_n; - - STEEL_PRAGMA_UNROLL - for (int i = 0; i < mma_t::TM; i++) { - int cm = offset_m + i * mma_t::TM_stride; - - int n = cm / jump_params->adj_out_hw; - int hw = cm % jump_params->adj_out_hw; - int oh = - (hw / jump_params->adj_out_w) * jump_params->f_out_jump_h + base_oh; - int ow = - (hw % jump_params->adj_out_w) * jump_params->f_out_jump_w + base_ow; - - if (n < params->N && oh < params->oS[0] && ow < params->oS[1]) { - int offset_cm = n * params->out_strides[0] + - oh * params->out_strides[1] + ow * params->out_strides[2]; - - STEEL_PRAGMA_UNROLL - for (int j = 0; j < mma_t::TN; j++) { - // Get accumulated result and associated offset in C - thread const auto& accum = - mma_op.results[i * mma_t::TN + j].thread_elements(); - int offset = offset_cm + (j * mma_t::TN_stride); - - // Apply epilogue and output C - if (j * mma_t::TN_stride < diff) { - C[offset] = Epilogue::apply(accum[0]); - } - - if (j * mma_t::TN_stride + 1 < diff) { - C[offset + 1] = Epilogue::apply(accum[1]); - } - } - } - } 
- } -} - #define instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn) \ template \ [[host_name("implicit_gemm_conv_2d_general_" #name "_bm" #bm "_bn" #bn \ @@ -218,16 +35,14 @@ implicit_gemm_conv_2d_general( #define instantiate_implicit_2d_filter(name, itype, bm, bn, bk, wm, wn) \ instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn) -// clang-format off #define instantiate_implicit_2d_blocks(name, itype) \ instantiate_implicit_2d_filter(name, itype, 32, 8, 16, 4, 1) \ instantiate_implicit_2d_filter(name, itype, 64, 8, 16, 4, 1) \ instantiate_implicit_2d_filter(name, itype, 32, 32, 16, 2, 2) \ instantiate_implicit_2d_filter(name, itype, 32, 64, 16, 2, 2) \ instantiate_implicit_2d_filter(name, itype, 64, 32, 16, 2, 2) \ - instantiate_implicit_2d_filter(name, itype, 64, 64, 16, 2, 2) // clang-format on + instantiate_implicit_2d_filter(name, itype, 64, 64, 16, 2, 2) -// clang-format off instantiate_implicit_2d_blocks(float32, float); instantiate_implicit_2d_blocks(float16, half); -instantiate_implicit_2d_blocks(bfloat16, bfloat16_t); // clang-format on \ No newline at end of file +instantiate_implicit_2d_blocks(bfloat16, bfloat16_t); // clang-format on diff --git a/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h b/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h index b93fd927a..0c89dd51c 100644 --- a/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +++ b/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h @@ -2,9 +2,7 @@ #pragma once -#include "mlx/backend/metal/kernels/steel/utils.h" - -#include "mlx/backend/metal/kernels/steel/conv/params.h" +#include "mlx/backend/metal/kernels/steel/defines.h" /////////////////////////////////////////////////////////////////////////////// // Loading helper @@ -285,4 +283,4 @@ struct Conv2DWeightBlockLoaderGeneral { }; } // namespace steel -} // namespace mlx \ No newline at end of file +} // namespace mlx diff --git a/mlx/backend/metal/kernels/steel/defines.h b/mlx/backend/metal/kernels/steel/defines.h new file mode 100644 index 000000000..6c3bfcf4e --- /dev/null +++ b/mlx/backend/metal/kernels/steel/defines.h @@ -0,0 +1,4 @@ +// Copyright © 2024 Apple Inc. + +#define STEEL_CONST static constant constexpr const +#define STEEL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)") diff --git a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h new file mode 100644 index 000000000..5e1d2f231 --- /dev/null +++ b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h @@ -0,0 +1,415 @@ +// Copyright © 2024 Apple Inc. 
+ +using namespace mlx::steel; + +/////////////////////////////////////////////////////////////////////////////// +// GEMM kernels +/////////////////////////////////////////////////////////////////////////////// + +constant bool has_batch [[function_constant(10)]]; + +constant bool use_out_source [[function_constant(100)]]; +constant bool do_axpby [[function_constant(110)]]; + +constant bool align_M [[function_constant(200)]]; +constant bool align_N [[function_constant(201)]]; +constant bool align_K [[function_constant(202)]]; + +constant bool do_gather [[function_constant(300)]]; + +constant bool gather_bias = do_gather && use_out_source; + +// clang-format off +template < + typename T, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + typename AccumType = float> +[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gemm( + const device T* A [[buffer(0)]], + const device T* B [[buffer(1)]], + const device T* C [[buffer(2), function_constant(use_out_source)]], + device T* D [[buffer(3)]], + const constant GEMMParams* params [[buffer(4)]], + const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]], + const constant int* batch_shape [[buffer(6)]], + const constant size_t* batch_strides [[buffer(7)]], + const constant uint32_t* lhs_indices [[buffer(10), function_constant(do_gather)]], + const constant uint32_t* rhs_indices [[buffer(11), function_constant(do_gather)]], + const constant uint32_t* C_indices [[buffer(12), function_constant(gather_bias)]], + const constant int* operand_shape [[buffer(13), function_constant(do_gather)]], + const constant size_t* operand_strides [[buffer(14), function_constant(do_gather)]], + const constant packed_int3& operand_batch_ndim [[buffer(15), function_constant(do_gather)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { // clang-format on + // Pacifying compiler + (void)lid; + + using gemm_kernel = GEMMKernel< + T, + T, + BM, + BN, + BK, + WM, + WN, + transpose_a, + transpose_b, + true, + true, + AccumType>; + + using loader_a_t = typename gemm_kernel::loader_a_t; + using loader_b_t = typename gemm_kernel::loader_b_t; + using mma_t = typename gemm_kernel::mma_t; + + // Find block + const int tid_y = ((tid.y) << params->swizzle_log) + + ((tid.x) & ((1 << params->swizzle_log) - 1)); + const int tid_x = (tid.x) >> params->swizzle_log; + + // Exit early if out of bounds + if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) { + return; + } + + // Adjust for batch + + // Handle gather + if (do_gather) { + // Read indices + uint32_t indx_A, indx_B, indx_C; + + if (has_batch) { + const constant size_t* indx_A_bstrides = batch_strides; + const constant size_t* indx_B_bstrides = + batch_strides + params->batch_ndim; + + ulong2 indx_offsets = elem_to_loc_broadcast( + tid.z, + batch_shape, + indx_A_bstrides, + indx_B_bstrides, + params->batch_ndim); + indx_A = lhs_indices[indx_offsets.x]; + indx_B = rhs_indices[indx_offsets.y]; + + if (use_out_source) { + const constant size_t* indx_C_bstrides = + indx_B_bstrides + params->batch_ndim; + auto indx_offset_C = elem_to_loc( + tid.z, batch_shape, indx_C_bstrides, params->batch_ndim); + indx_C = C_indices[indx_offset_C]; + } + } else { + indx_A = lhs_indices[params->batch_stride_a * tid.z]; + indx_B = rhs_indices[params->batch_stride_b * tid.z]; + + if (use_out_source) { + 
indx_C = C_indices[addmm_params->batch_stride_c * tid.z]; + } + } + + // Translate indices to offsets + int batch_ndim_A = operand_batch_ndim.x; + const constant int* batch_shape_A = operand_shape; + const constant size_t* batch_strides_A = operand_strides; + A += elem_to_loc(indx_A, batch_shape_A, batch_strides_A, batch_ndim_A); + + int batch_ndim_B = operand_batch_ndim.y; + const constant int* batch_shape_B = batch_shape_A + batch_ndim_A; + const constant size_t* batch_strides_B = batch_strides_A + batch_ndim_A; + B += elem_to_loc(indx_B, batch_shape_B, batch_strides_B, batch_ndim_B); + + if (use_out_source) { + int batch_ndim_C = operand_batch_ndim.z; + const constant int* batch_shape_C = batch_shape_B + batch_ndim_B; + const constant size_t* batch_strides_C = batch_strides_B + batch_ndim_B; + C += elem_to_loc(indx_C, batch_shape_C, batch_strides_C, batch_ndim_C); + } + + } + + // Handle regular batch + else { + if (has_batch) { + const constant size_t* A_bstrides = batch_strides; + const constant size_t* B_bstrides = batch_strides + params->batch_ndim; + + ulong2 batch_offsets = elem_to_loc_broadcast( + tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim); + + A += batch_offsets.x; + B += batch_offsets.y; + + if (use_out_source) { + const constant size_t* C_bstrides = B_bstrides + params->batch_ndim; + C += elem_to_loc(tid.z, batch_shape, C_bstrides, params->batch_ndim); + } + } else { + A += params->batch_stride_a * tid.z; + B += params->batch_stride_b * tid.z; + + if (use_out_source) { + C += addmm_params->batch_stride_c * tid.z; + } + } + } + + D += params->batch_stride_d * tid.z; + + // Prepare threadgroup memory + threadgroup T As[gemm_kernel::tgp_mem_size_a]; + threadgroup T Bs[gemm_kernel::tgp_mem_size_b]; + + threadgroup_barrier(mem_flags::mem_none); + + // Find block in A, B, C + const int c_row = tid_y * BM; + const int c_col = tid_x * BN; + const size_t c_row_long = size_t(c_row); + const size_t c_col_long = size_t(c_col); + + A += transpose_a ? c_row_long : c_row_long * params->lda; + B += transpose_b ? c_col_long * params->ldb : c_col_long; + D += c_row_long * params->ldd + c_col_long; + + if (use_out_source) { + C += c_row_long * addmm_params->ldc + c_col_long * addmm_params->fdc; + } + + // Prepare threadgroup mma operation + thread mma_t mma_op(simd_group_id, simd_lane_id); + + // Prepare threadgroup loading operations + thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id); + thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id); + + // Prepare threadgroup bounds + const short tgp_bm = align_M ? BM : short(min(BM, params->M - c_row)); + const short tgp_bn = align_N ? BN : short(min(BN, params->N - c_col)); + + // Prepare iterations + int gemm_k_iterations = params->gemm_k_iterations_aligned; + + // Do unaligned K iterations first + if (!align_K) { + const int k_last = params->gemm_k_iterations_aligned * BK; + const int k_remain = params->K - k_last; + const size_t k_jump_a = + transpose_a ? params->lda * size_t(k_last) : size_t(k_last); + const size_t k_jump_b = + transpose_b ? size_t(k_last) : params->ldb * size_t(k_last); + + // Move loader source ahead to end + loader_a.src += k_jump_a; + loader_b.src += k_jump_b; + + // Load tile + const short2 tile_dims_A = + transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm); + const short2 tile_dims_B = + transpose_b ? 
short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); + + loader_a.load_safe(tile_dims_A); + loader_b.load_safe(tile_dims_B); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Do matmul + mma_op.mma(As, Bs); + + // Reset source back to start + loader_a.src -= k_jump_a; + loader_b.src -= k_jump_b; + } + + const TransformAdd epilogue_op_add( + addmm_params->alpha, addmm_params->beta); + const TransformAxpby epilogue_op_axpby( + addmm_params->alpha, addmm_params->beta); + + /////////////////////////////////////////////////////////////////////////////// + // MNK aligned loop + if (align_M && align_N) { + // Do gemm + for (int k = 0; k < gemm_k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + + threadgroup_barrier(mem_flags::mem_none); + + // Do epilogue + if (use_out_source) { + if (do_axpby) { + mma_op.apply_epilogue( + C, addmm_params->ldc, addmm_params->fdc, epilogue_op_axpby); + } else { + mma_op.apply_epilogue( + C, addmm_params->ldc, addmm_params->fdc, epilogue_op_add); + } + } + + // Store results to device memory + return mma_op.store_result(D, params->ldd); + + } + /////////////////////////////////////////////////////////////////////////////// + // MN unaligned loop + else { // Loop over K - unaligned case + const int leftover_bk = 0; + + if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) { + // Do gemm + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk, + LoopAlignment{}); + + // Do epilogue + if (use_out_source) { + if (do_axpby) { + mma_op.apply_epilogue( + C, addmm_params->ldc, addmm_params->fdc, epilogue_op_axpby); + } else { + mma_op.apply_epilogue( + C, addmm_params->ldc, addmm_params->fdc, epilogue_op_add); + } + } + + // Store results to device memory + return mma_op.store_result(D, params->ldd); + + } else if (align_N || tgp_bn == BN) { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk, + LoopAlignment{}); + + // Do epilogue + if (use_out_source) { + if (do_axpby) { + mma_op.apply_epilogue_safe( + C, + addmm_params->ldc, + addmm_params->fdc, + short2(tgp_bn, tgp_bm), + epilogue_op_axpby); + } else { + mma_op.apply_epilogue_safe( + C, + addmm_params->ldc, + addmm_params->fdc, + short2(tgp_bn, tgp_bm), + epilogue_op_add); + } + } + + // Store results to device memory + return mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm)); + + } else if (align_M || tgp_bm == BM) { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk, + LoopAlignment{}); + + // Do epilogue + if (use_out_source) { + if (do_axpby) { + mma_op.apply_epilogue_safe( + C, + addmm_params->ldc, + addmm_params->fdc, + short2(tgp_bn, tgp_bm), + epilogue_op_axpby); + } else { + mma_op.apply_epilogue_safe( + C, + addmm_params->ldc, + addmm_params->fdc, + short2(tgp_bn, tgp_bm), + epilogue_op_add); + } + } + + // Store results to device memory + return mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm)); + + } else { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk, + 
LoopAlignment{}); + + // Do epilogue + if (use_out_source) { + if (do_axpby) { + mma_op.apply_epilogue_safe( + C, + addmm_params->ldc, + addmm_params->fdc, + short2(tgp_bn, tgp_bm), + epilogue_op_axpby); + } else { + mma_op.apply_epilogue_safe( + C, + addmm_params->ldc, + addmm_params->fdc, + short2(tgp_bn, tgp_bm), + epilogue_op_add); + } + } + + // Store results to device memory + return mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm)); + } + } +} diff --git a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal index b4304a551..0665cb6f3 100644 --- a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal +++ b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal @@ -1,430 +1,12 @@ // Copyright © 2024 Apple Inc. +// clang-format off #include "mlx/backend/metal/kernels/bf16.h" #include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/steel/gemm/gemm.h" +#include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h" -using namespace metal; -using namespace mlx::steel; - -/////////////////////////////////////////////////////////////////////////////// -// GEMM kernels -/////////////////////////////////////////////////////////////////////////////// - -constant bool has_batch [[function_constant(10)]]; - -constant bool use_out_source [[function_constant(100)]]; -constant bool do_axpby [[function_constant(110)]]; - -constant bool align_M [[function_constant(200)]]; -constant bool align_N [[function_constant(201)]]; -constant bool align_K [[function_constant(202)]]; - -constant bool do_gather [[function_constant(300)]]; - -constant bool gather_bias = do_gather && use_out_source; - -// clang-format off -template < - typename T, - int BM, - int BN, - int BK, - int WM, - int WN, - bool transpose_a, - bool transpose_b, - typename AccumType = float> -[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gemm( - const device T* A [[buffer(0)]], - const device T* B [[buffer(1)]], - const device T* C [[buffer(2), function_constant(use_out_source)]], - device T* D [[buffer(3)]], - const constant GEMMParams* params [[buffer(4)]], - const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]], - const constant int* batch_shape [[buffer(6)]], - const constant size_t* batch_strides [[buffer(7)]], - const constant uint32_t* lhs_indices [[buffer(10), function_constant(do_gather)]], - const constant uint32_t* rhs_indices [[buffer(11), function_constant(do_gather)]], - const constant uint32_t* C_indices [[buffer(12), function_constant(gather_bias)]], - const constant int* operand_shape [[buffer(13), function_constant(do_gather)]], - const constant size_t* operand_strides [[buffer(14), function_constant(do_gather)]], - const constant packed_int3& operand_batch_ndim [[buffer(15), function_constant(do_gather)]], - uint simd_lane_id [[thread_index_in_simdgroup]], - uint simd_group_id [[simdgroup_index_in_threadgroup]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]]) { // clang-format on - // Pacifying compiler - (void)lid; - - using gemm_kernel = GEMMKernel< - T, - T, - BM, - BN, - BK, - WM, - WN, - transpose_a, - transpose_b, - true, - true, - AccumType>; - - using loader_a_t = typename gemm_kernel::loader_a_t; - using loader_b_t = typename gemm_kernel::loader_b_t; - using mma_t = typename gemm_kernel::mma_t; - - // Find block - const int tid_y = ((tid.y) << 
params->swizzle_log) + - ((tid.x) & ((1 << params->swizzle_log) - 1)); - const int tid_x = (tid.x) >> params->swizzle_log; - - // Exit early if out of bounds - if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) { - return; - } - - // Adjust for batch - - // Handle gather - if (do_gather) { - // Read indices - uint32_t indx_A, indx_B, indx_C; - - if (has_batch) { - const constant size_t* indx_A_bstrides = batch_strides; - const constant size_t* indx_B_bstrides = - batch_strides + params->batch_ndim; - - ulong2 indx_offsets = elem_to_loc_broadcast( - tid.z, - batch_shape, - indx_A_bstrides, - indx_B_bstrides, - params->batch_ndim); - indx_A = lhs_indices[indx_offsets.x]; - indx_B = rhs_indices[indx_offsets.y]; - - if (use_out_source) { - const constant size_t* indx_C_bstrides = - indx_B_bstrides + params->batch_ndim; - auto indx_offset_C = elem_to_loc( - tid.z, batch_shape, indx_C_bstrides, params->batch_ndim); - indx_C = C_indices[indx_offset_C]; - } - } else { - indx_A = lhs_indices[params->batch_stride_a * tid.z]; - indx_B = rhs_indices[params->batch_stride_b * tid.z]; - - if (use_out_source) { - indx_C = C_indices[addmm_params->batch_stride_c * tid.z]; - } - } - - // Translate indices to offsets - int batch_ndim_A = operand_batch_ndim.x; - const constant int* batch_shape_A = operand_shape; - const constant size_t* batch_strides_A = operand_strides; - A += elem_to_loc(indx_A, batch_shape_A, batch_strides_A, batch_ndim_A); - - int batch_ndim_B = operand_batch_ndim.y; - const constant int* batch_shape_B = batch_shape_A + batch_ndim_A; - const constant size_t* batch_strides_B = batch_strides_A + batch_ndim_A; - B += elem_to_loc(indx_B, batch_shape_B, batch_strides_B, batch_ndim_B); - - if (use_out_source) { - int batch_ndim_C = operand_batch_ndim.z; - const constant int* batch_shape_C = batch_shape_B + batch_ndim_B; - const constant size_t* batch_strides_C = batch_strides_B + batch_ndim_B; - C += elem_to_loc(indx_C, batch_shape_C, batch_strides_C, batch_ndim_C); - } - - } - - // Handle regular batch - else { - if (has_batch) { - const constant size_t* A_bstrides = batch_strides; - const constant size_t* B_bstrides = batch_strides + params->batch_ndim; - - ulong2 batch_offsets = elem_to_loc_broadcast( - tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim); - - A += batch_offsets.x; - B += batch_offsets.y; - - if (use_out_source) { - const constant size_t* C_bstrides = B_bstrides + params->batch_ndim; - C += elem_to_loc(tid.z, batch_shape, C_bstrides, params->batch_ndim); - } - } else { - A += params->batch_stride_a * tid.z; - B += params->batch_stride_b * tid.z; - - if (use_out_source) { - C += addmm_params->batch_stride_c * tid.z; - } - } - } - - D += params->batch_stride_d * tid.z; - - // Prepare threadgroup memory - threadgroup T As[gemm_kernel::tgp_mem_size_a]; - threadgroup T Bs[gemm_kernel::tgp_mem_size_b]; - - threadgroup_barrier(mem_flags::mem_none); - - // Find block in A, B, C - const int c_row = tid_y * BM; - const int c_col = tid_x * BN; - const size_t c_row_long = size_t(c_row); - const size_t c_col_long = size_t(c_col); - - A += transpose_a ? c_row_long : c_row_long * params->lda; - B += transpose_b ? 
c_col_long * params->ldb : c_col_long; - D += c_row_long * params->ldd + c_col_long; - - if (use_out_source) { - C += c_row_long * addmm_params->ldc + c_col_long * addmm_params->fdc; - } - - // Prepare threadgroup mma operation - thread mma_t mma_op(simd_group_id, simd_lane_id); - - // Prepare threadgroup loading operations - thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id); - thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id); - - // Prepare threadgroup bounds - const short tgp_bm = align_M ? BM : short(min(BM, params->M - c_row)); - const short tgp_bn = align_N ? BN : short(min(BN, params->N - c_col)); - - // Prepare iterations - int gemm_k_iterations = params->gemm_k_iterations_aligned; - - // Do unaligned K iterations first - if (!align_K) { - const int k_last = params->gemm_k_iterations_aligned * BK; - const int k_remain = params->K - k_last; - const size_t k_jump_a = - transpose_a ? params->lda * size_t(k_last) : size_t(k_last); - const size_t k_jump_b = - transpose_b ? size_t(k_last) : params->ldb * size_t(k_last); - - // Move loader source ahead to end - loader_a.src += k_jump_a; - loader_b.src += k_jump_b; - - // Load tile - const short2 tile_dims_A = - transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm); - const short2 tile_dims_B = - transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); - - loader_a.load_safe(tile_dims_A); - loader_b.load_safe(tile_dims_B); - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Do matmul - mma_op.mma(As, Bs); - - // Reset source back to start - loader_a.src -= k_jump_a; - loader_b.src -= k_jump_b; - } - - const TransformAdd epilogue_op_add( - addmm_params->alpha, addmm_params->beta); - const TransformAxpby epilogue_op_axpby( - addmm_params->alpha, addmm_params->beta); - - /////////////////////////////////////////////////////////////////////////////// - // MNK aligned loop - if (align_M && align_N) { - // Do gemm - for (int k = 0; k < gemm_k_iterations; k++) { - threadgroup_barrier(mem_flags::mem_threadgroup); - // Load elements into threadgroup - loader_a.load_unsafe(); - loader_b.load_unsafe(); - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Multiply and accumulate threadgroup elements - mma_op.mma(As, Bs); - - // Prepare for next iteration - loader_a.next(); - loader_b.next(); - } - - threadgroup_barrier(mem_flags::mem_none); - - // Do epilogue - if (use_out_source) { - if (do_axpby) { - mma_op.apply_epilogue( - C, addmm_params->ldc, addmm_params->fdc, epilogue_op_axpby); - } else { - mma_op.apply_epilogue( - C, addmm_params->ldc, addmm_params->fdc, epilogue_op_add); - } - } - - // Store results to device memory - return mma_op.store_result(D, params->ldd); - - } - /////////////////////////////////////////////////////////////////////////////// - // MN unaligned loop - else { // Loop over K - unaligned case - const int leftover_bk = 0; - - if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) { - // Do gemm - gemm_kernel::gemm_loop( - As, - Bs, - gemm_k_iterations, - loader_a, - loader_b, - mma_op, - tgp_bm, - tgp_bn, - leftover_bk, - LoopAlignment{}); - - // Do epilogue - if (use_out_source) { - if (do_axpby) { - mma_op.apply_epilogue( - C, addmm_params->ldc, addmm_params->fdc, epilogue_op_axpby); - } else { - mma_op.apply_epilogue( - C, addmm_params->ldc, addmm_params->fdc, epilogue_op_add); - } - } - - // Store results to device memory - return mma_op.store_result(D, params->ldd); - - } else if (align_N || tgp_bn == BN) { - 
gemm_kernel::gemm_loop( - As, - Bs, - gemm_k_iterations, - loader_a, - loader_b, - mma_op, - tgp_bm, - tgp_bn, - leftover_bk, - LoopAlignment{}); - - // Do epilogue - if (use_out_source) { - if (do_axpby) { - mma_op.apply_epilogue_safe( - C, - addmm_params->ldc, - addmm_params->fdc, - short2(tgp_bn, tgp_bm), - epilogue_op_axpby); - } else { - mma_op.apply_epilogue_safe( - C, - addmm_params->ldc, - addmm_params->fdc, - short2(tgp_bn, tgp_bm), - epilogue_op_add); - } - } - - // Store results to device memory - return mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm)); - - } else if (align_M || tgp_bm == BM) { - gemm_kernel::gemm_loop( - As, - Bs, - gemm_k_iterations, - loader_a, - loader_b, - mma_op, - tgp_bm, - tgp_bn, - leftover_bk, - LoopAlignment{}); - - // Do epilogue - if (use_out_source) { - if (do_axpby) { - mma_op.apply_epilogue_safe( - C, - addmm_params->ldc, - addmm_params->fdc, - short2(tgp_bn, tgp_bm), - epilogue_op_axpby); - } else { - mma_op.apply_epilogue_safe( - C, - addmm_params->ldc, - addmm_params->fdc, - short2(tgp_bn, tgp_bm), - epilogue_op_add); - } - } - - // Store results to device memory - return mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm)); - - } else { - gemm_kernel::gemm_loop( - As, - Bs, - gemm_k_iterations, - loader_a, - loader_b, - mma_op, - tgp_bm, - tgp_bn, - leftover_bk, - LoopAlignment{}); - - // Do epilogue - if (use_out_source) { - if (do_axpby) { - mma_op.apply_epilogue_safe( - C, - addmm_params->ldc, - addmm_params->fdc, - short2(tgp_bn, tgp_bm), - epilogue_op_axpby); - } else { - mma_op.apply_epilogue_safe( - C, - addmm_params->ldc, - addmm_params->fdc, - short2(tgp_bn, tgp_bm), - epilogue_op_add); - } - } - - // Store results to device memory - return mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm)); - } - } -} - -/////////////////////////////////////////////////////////////////////////////// -// GEMM kernel initializations -/////////////////////////////////////////////////////////////////////////////// - -// clang-format off #define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ template [[host_name("steel_gemm_fused_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn)]] \ [[kernel]] void gemm( \ @@ -445,24 +27,23 @@ template < uint simd_lane_id [[thread_index_in_simdgroup]], \ uint simd_group_id [[simdgroup_index_in_threadgroup]], \ uint3 tid [[threadgroup_position_in_grid]], \ - uint3 lid [[thread_position_in_threadgroup]]); // clang-format on + uint3 lid [[thread_position_in_threadgroup]]); -// clang-format off #define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \ instantiate_gemm(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ instantiate_gemm(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \ instantiate_gemm(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ - instantiate_gemm(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) // clang-format on + instantiate_gemm(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) -// clang-format off #define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \ instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \ instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \ instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \ instantiate_gemm_transpose_helper(iname, itype, oname, 
otype, 64, 32, 16, 2, 2) \ - instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 2, 2) // clang-format on + instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 2, 2) instantiate_gemm_shapes_helper(float16, half, float16, half); instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t); -instantiate_gemm_shapes_helper(float32, float, float32, float); \ No newline at end of file +instantiate_gemm_shapes_helper(float32, float, float32, float); +// clang-format on diff --git a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h new file mode 100644 index 000000000..702e13152 --- /dev/null +++ b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h @@ -0,0 +1,719 @@ +// Copyright © 2024 Apple Inc. + +#include "mlx/backend/metal/kernels/steel/defines.h" +using namespace metal; +using namespace mlx::steel; + +/////////////////////////////////////////////////////////////////////////////// +// GEMM kernels +/////////////////////////////////////////////////////////////////////////////// + +struct _NoMask { + char x; + + constexpr METAL_FUNC operator bool() { + return true; + } + constexpr METAL_FUNC operator bool() const threadgroup { + return true; + } + constexpr METAL_FUNC operator bool() const device { + return true; + } + constexpr METAL_FUNC operator bool() const constant { + return true; + } +}; + +template <typename OutT, typename InT = OutT> +struct ScaleOp { + OutT scale; + + METAL_FUNC OutT apply(InT x) const { + return static_cast<OutT>(x) * scale; + } +}; + +typedef struct _NoMask nomask_t; + +template < + typename T, + typename out_mask_t, + typename op_mask_t, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + bool MN_aligned, + bool K_aligned> +[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void +block_masked_gemm( + const device T* A [[buffer(0)]], + const device T* B [[buffer(1)]], + device T* D [[buffer(3)]], + const constant GEMMParams* params [[buffer(4)]], + const constant int* batch_shape [[buffer(6)]], + const constant size_t* batch_strides [[buffer(7)]], + const device out_mask_t* out_mask [[buffer(10)]], + const device op_mask_t* lhs_mask [[buffer(11)]], + const device op_mask_t* rhs_mask [[buffer(12)]], + const constant int* mask_strides [[buffer(13)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + // Appease the compiler + (void)lid; + + static_assert( + BM == BN, + "block_masked_gemm must have the same block M and block N size"); + static_assert(BM % BK == 0, "block_masked_gemm must have BM % BK == 0"); + + constexpr bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>; + constexpr bool has_output_mask = !metal::is_same_v<out_mask_t, nomask_t>; + + constexpr bool has_mul_operand_mask = + has_operand_mask && !metal::is_same_v<op_mask_t, bool>; + constexpr bool has_mul_output_mask = + has_output_mask && !metal::is_same_v<out_mask_t, bool>; + + constexpr short k_mask_factor = short(BM / BK); + + using gemm_kernel = GEMMKernel< + T, + T, + BM, + BN, + BK, + WM, + WN, + transpose_a, + transpose_b, + MN_aligned, + K_aligned>; + + const int tid_y = ((tid.y) << params->swizzle_log) + + ((tid.x) & ((1 << params->swizzle_log) - 1)); + const int tid_x = (tid.x) >> params->swizzle_log; + + if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) { + return; + } + + const constant size_t* mask_batch_strides = + batch_strides + 2 * 
params->batch_ndim; + + if (params->batch_ndim > 1) { + if (has_output_mask) { + out_mask += elem_to_loc( + tid.z, batch_shape, mask_batch_strides, params->batch_ndim); + + mask_batch_strides += params->batch_ndim; + } + + if (has_operand_mask) { + const constant size_t* mask_strides_lhs = mask_batch_strides; + const constant size_t* mask_strides_rhs = + mask_strides_lhs + params->batch_ndim; + + ulong2 batch_offsets = elem_to_loc_broadcast( + tid.z, + batch_shape, + mask_strides_lhs, + mask_strides_rhs, + params->batch_ndim); + + lhs_mask += batch_offsets.x; + rhs_mask += batch_offsets.y; + } + } else { + if (has_output_mask) { + out_mask += tid.z * mask_batch_strides[0]; + mask_batch_strides += params->batch_ndim; + } + + if (has_operand_mask) { + lhs_mask += tid.z * mask_batch_strides[0]; + rhs_mask += tid.z * mask_batch_strides[params->batch_ndim]; + } + } + + // Adjust for batch + if (params->batch_ndim > 1) { + const constant size_t* A_bstrides = batch_strides; + const constant size_t* B_bstrides = batch_strides + params->batch_ndim; + + ulong2 batch_offsets = elem_to_loc_broadcast( + tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim); + + A += batch_offsets.x; + B += batch_offsets.y; + + } else { + A += params->batch_stride_a * tid.z; + B += params->batch_stride_b * tid.z; + } + + D += params->batch_stride_d * tid.z; + + // Find block in A, B, C + const int c_row = tid_y * BM; + const int c_col = tid_x * BN; + const size_t c_row_long = size_t(c_row); + const size_t c_col_long = size_t(c_col); + + A += transpose_a ? c_row_long : c_row_long * params->lda; + B += transpose_b ? c_col_long * params->ldb : c_col_long; + D += c_row_long * params->ldd + c_col_long; + + const constant int* out_mask_strides = mask_strides; + const constant int* lhs_mask_strides = + mask_strides + (has_output_mask ? 2 : 0); + const constant int* rhs_mask_strides = + lhs_mask_strides + (has_operand_mask ? 2 : 0); + + const int out_mask_offset = !has_output_mask + ? 0 + : tid_y * out_mask_strides[1] + tid_x * out_mask_strides[0]; + int lhs_mask_offset = !has_operand_mask ? 0 : tid_y * lhs_mask_strides[1]; + int rhs_mask_offset = !has_operand_mask ? 0 : tid_x * rhs_mask_strides[0]; + const int lhs_mask_step = !has_operand_mask ? 0 : lhs_mask_strides[0]; + const int rhs_mask_step = !has_operand_mask ? 0 : rhs_mask_strides[1]; + short k_factor_cnt = k_mask_factor; + + ScaleOp out_mask_op; + ScaleOp lhs_mask_op; + ScaleOp rhs_mask_op; + + if (has_output_mask) { + auto mask_out = out_mask[out_mask_offset]; + + if (has_mul_output_mask) { + out_mask_op.scale = float(mask_out); + } + + // Write zeros and return + if (!mask_out) { + constexpr short tgp_size = WM * WN * 32; + constexpr short vec_size = 4; + + // Tile threads in threadgroup + constexpr short TN = BN / vec_size; + constexpr short TM = tgp_size / TN; + + const short thread_idx = simd_group_id * 32 + simd_lane_id; + const short bi = thread_idx / TN; + const short bj = vec_size * (thread_idx % TN); + + D += bi * params->ldd + bj; + + short tgp_bm = min(BM, params->M - c_row); + short tgp_bn = min(BN, params->N - c_col); + + if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) { + for (short ti = 0; ti < BM; ti += TM) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + D[ti * params->ldd + j] = T(0.); + } + } + } else { + short jmax = tgp_bn - bj; + jmax = jmax < vec_size ? 
jmax : vec_size; + for (short ti = 0; (bi + ti) < tgp_bm; ti += TM) { + for (short j = 0; j < jmax; j++) { + D[ti * params->ldd + j] = T(0.); + } + } + } + + return; + } + } + + threadgroup_barrier(mem_flags::mem_none); + + // Prepare threadgroup mma operation + thread typename gemm_kernel::mma_t mma_op(simd_group_id, simd_lane_id); + + threadgroup T As[gemm_kernel::tgp_mem_size_a]; + threadgroup T Bs[gemm_kernel::tgp_mem_size_b]; + + // Prepare threadgroup loading operations + thread typename gemm_kernel::loader_a_t loader_a( + A, params->lda, As, simd_group_id, simd_lane_id); + thread typename gemm_kernel::loader_b_t loader_b( + B, params->ldb, Bs, simd_group_id, simd_lane_id); + + // Prepare threadgroup bounds + const short tgp_bm = + MN_aligned ? short(BM) : short(min(BM, params->M - c_row)); + const short tgp_bn = + MN_aligned ? short(BN) : short(min(BN, params->N - c_col)); + + int gemm_k_iterations = params->gemm_k_iterations_aligned; + + /////////////////////////////////////////////////////////////////////////////// + // Do unaligned K iterations first + if (!K_aligned) { + const int k_last = params->gemm_k_iterations_aligned * BK; + const int mask_idx_last = k_last / BM; + + if (!has_operand_mask || + (bool(lhs_mask[lhs_mask_offset + mask_idx_last * lhs_mask_step]) && + bool(rhs_mask[rhs_mask_offset + mask_idx_last * rhs_mask_step]))) { + if (has_mul_operand_mask) { + lhs_mask_op.scale = + lhs_mask[lhs_mask_offset + mask_idx_last * lhs_mask_step]; + rhs_mask_op.scale = + rhs_mask[rhs_mask_offset + mask_idx_last * rhs_mask_step]; + } + + // Move loader source ahead to end + const int k_remain = params->K - k_last; + const size_t k_jump_a = + transpose_a ? params->lda * size_t(k_last) : size_t(k_last); + const size_t k_jump_b = + transpose_b ? size_t(k_last) : params->ldb * size_t(k_last); + + loader_a.src += k_jump_a; + loader_b.src += k_jump_b; + + // Load tile + const short2 tile_dims_A = + transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm); + const short2 tile_dims_B = + transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); + + loader_a.load_safe(tile_dims_A); + loader_b.load_safe(tile_dims_B); + + if (has_mul_operand_mask) { + loader_a.apply_inplace_op(lhs_mask_op); + loader_b.apply_inplace_op(rhs_mask_op); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Do matmul + mma_op.mma(As, Bs); + + // Reset source back to start + loader_a.src -= k_jump_a; + loader_b.src -= k_jump_b; + } + } + + /////////////////////////////////////////////////////////////////////////////// + // MNK aligned loop + if (MN_aligned) { + for (; gemm_k_iterations > 0; gemm_k_iterations--) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (!has_operand_mask || + (bool(lhs_mask[lhs_mask_offset]) && + bool(rhs_mask[rhs_mask_offset]))) { + if (has_mul_operand_mask) { + lhs_mask_op.scale = lhs_mask[lhs_mask_offset]; + rhs_mask_op.scale = rhs_mask[rhs_mask_offset]; + } + + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_unsafe(); + + if (has_mul_operand_mask) { + loader_a.apply_inplace_op(lhs_mask_op); + loader_b.apply_inplace_op(rhs_mask_op); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + } + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + + k_factor_cnt--; + lhs_mask_offset += k_factor_cnt == 0 ? lhs_mask_step : 0; + rhs_mask_offset += k_factor_cnt == 0 ? rhs_mask_step : 0; + k_factor_cnt = k_factor_cnt == 0 ? 
k_mask_factor : k_factor_cnt; + } + + if (has_mul_output_mask) { + mma_op.apply_epilogue(out_mask_op); + } + + // Store results to device memory + mma_op.store_result(D, params->ldd); + return; + + } + /////////////////////////////////////////////////////////////////////////////// + // MN unaligned loop + else { + const bool M_aligned = (tgp_bm == BM); + const bool N_aligned = (tgp_bn == BN); + + const short2 tile_dims_A = + transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm); + const short2 tile_dims_B = + transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK); + + for (; gemm_k_iterations > 0; gemm_k_iterations--) { + threadgroup_barrier(mem_flags::mem_threadgroup); + if (!has_operand_mask || + (bool(lhs_mask[lhs_mask_offset]) && + bool(rhs_mask[rhs_mask_offset]))) { + if (has_mul_operand_mask) { + lhs_mask_op.scale = lhs_mask[lhs_mask_offset]; + rhs_mask_op.scale = rhs_mask[rhs_mask_offset]; + } + + // Load elements into threadgroup + if (M_aligned) { + loader_a.load_unsafe(); + } else { + loader_a.load_safe(tile_dims_A); + } + + if (N_aligned) { + loader_b.load_unsafe(); + } else { + loader_b.load_safe(tile_dims_B); + } + + if (has_mul_operand_mask) { + loader_a.apply_inplace_op(lhs_mask_op); + loader_b.apply_inplace_op(rhs_mask_op); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + } + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + + k_factor_cnt--; + lhs_mask_offset += k_factor_cnt == 0 ? lhs_mask_step : 0; + rhs_mask_offset += k_factor_cnt == 0 ? rhs_mask_step : 0; + k_factor_cnt = k_factor_cnt == 0 ? k_mask_factor : k_factor_cnt; + } + + if (has_mul_output_mask) { + mma_op.apply_epilogue(out_mask_op); + } + + if (M_aligned && N_aligned) { + mma_op.store_result(D, params->ldd); + } else { + mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm)); + } + } +} + +template < + typename T, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + bool MN_aligned, + bool K_aligned, + bool has_operand_mask = false> +[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void +block_masked_gemm( + const device T* A [[buffer(0)]], + const device T* B [[buffer(1)]], + device T* D [[buffer(3)]], + const constant GEMMParams* params [[buffer(4)]], + const constant int* batch_shape [[buffer(6)]], + const constant size_t* batch_strides [[buffer(7)]], + const device bool* out_mask [[buffer(10)]], + const device bool* lhs_mask [[buffer(11)]], + const device bool* rhs_mask [[buffer(12)]], + const constant int* mask_strides [[buffer(13)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + // Appease the compiler + (void)lid; + + using gemm_kernel = GEMMKernel< + T, + T, + BM, + BN, + BK, + WM, + WN, + transpose_a, + transpose_b, + MN_aligned, + K_aligned>; + + const int tid_y = ((tid.y) << params->swizzle_log) + + ((tid.x) & ((1 << params->swizzle_log) - 1)); + const int tid_x = (tid.x) >> params->swizzle_log; + + if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) { + return; + } + + if (params->batch_ndim > 1) { + const constant size_t* mask_batch_strides = + batch_strides + 2 * params->batch_ndim; + out_mask += + elem_to_loc(tid.z, batch_shape, mask_batch_strides, params->batch_ndim); + + if (has_operand_mask) { + const constant size_t* mask_strides_lhs = + mask_batch_strides + 
params->batch_ndim; + const constant size_t* mask_strides_rhs = + mask_strides_lhs + params->batch_ndim; + + ulong2 batch_offsets = elem_to_loc_broadcast( + tid.z, + batch_shape, + mask_strides_lhs, + mask_strides_rhs, + params->batch_ndim); + + lhs_mask += batch_offsets.x; + rhs_mask += batch_offsets.y; + } + } else { + out_mask += tid.z * batch_strides[2 * params->batch_ndim]; + if (has_operand_mask) { + lhs_mask += tid.z * batch_strides[3 * params->batch_ndim]; + rhs_mask += tid.z * batch_strides[4 * params->batch_ndim]; + } + } + + // Adjust for batch + if (params->batch_ndim > 1) { + const constant size_t* A_bstrides = batch_strides; + const constant size_t* B_bstrides = batch_strides + params->batch_ndim; + + ulong2 batch_offsets = elem_to_loc_broadcast( + tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim); + + A += batch_offsets.x; + B += batch_offsets.y; + + } else { + A += params->batch_stride_a * tid.z; + B += params->batch_stride_b * tid.z; + } + + D += params->batch_stride_d * tid.z; + + // Find block in A, B, C + const int c_row = tid_y * BM; + const int c_col = tid_x * BN; + const size_t c_row_long = size_t(c_row); + const size_t c_col_long = size_t(c_col); + + A += transpose_a ? c_row_long : c_row_long * params->lda; + B += transpose_b ? c_col_long * params->ldb : c_col_long; + D += c_row_long * params->ldd + c_col_long; + + bool mask_out = out_mask[tid_y * mask_strides[1] + tid_x * mask_strides[0]]; + + // Write zeros and return + if (!mask_out) { + constexpr short tgp_size = WM * WN * 32; + constexpr short vec_size = 4; + + // Tile threads in threadgroup + constexpr short TN = BN / vec_size; + constexpr short TM = tgp_size / TN; + + const short thread_idx = simd_group_id * 32 + simd_lane_id; + const short bi = thread_idx / TN; + const short bj = vec_size * (thread_idx % TN); + + D += bi * params->ldd + bj; + + short tgp_bm = min(BM, params->M - c_row); + short tgp_bn = min(BN, params->N - c_col); + + if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) { + for (short ti = 0; ti < BM; ti += TM) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + D[ti * params->ldd + j] = T(0.); + } + } + } else { + short jmax = tgp_bn - bj; + jmax = jmax < vec_size ? 
jmax : vec_size; + for (short ti = 0; (bi + ti) < tgp_bm; ti += TM) { + for (short j = 0; j < jmax; j++) { + D[ti * params->ldd + j] = T(0.); + } + } + } + + return; + } + + threadgroup_barrier(mem_flags::mem_none); + + // Prepare threadgroup mma operation + thread typename gemm_kernel::mma_t mma_op(simd_group_id, simd_lane_id); + + int gemm_k_iterations = params->gemm_k_iterations_aligned; + + threadgroup T As[gemm_kernel::tgp_mem_size_a]; + threadgroup T Bs[gemm_kernel::tgp_mem_size_b]; + + // Prepare threadgroup loading operations + thread typename gemm_kernel::loader_a_t loader_a( + A, params->lda, As, simd_group_id, simd_lane_id); + thread typename gemm_kernel::loader_b_t loader_b( + B, params->ldb, Bs, simd_group_id, simd_lane_id); + + /////////////////////////////////////////////////////////////////////////////// + // MNK aligned loop + if (MN_aligned) { + for (int k = 0; k < gemm_k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (!has_operand_mask || + (lhs_mask + [tid_y * mask_strides[3] + ((k * BK) / BM) * mask_strides[2]] && + rhs_mask + [((k * BK) / BM) * mask_strides[5] + tid_x * mask_strides[4]])) { + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + } + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + + threadgroup_barrier(mem_flags::mem_none); + + // Loop tail + if (!K_aligned) { + if (!has_operand_mask || + (lhs_mask + [tid_y * mask_strides[3] + (params->K / BM) * mask_strides[2]] && + rhs_mask + [(params->K / BM) * mask_strides[5] + + tid_x * mask_strides[4]])) { + int lbk = params->K - params->gemm_k_iterations_aligned * BK; + short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM); + short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk); + + loader_a.load_safe(tile_dims_A); + loader_b.load_safe(tile_dims_B); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + mma_op.mma(As, Bs); + } + } + + // Store results to device memory + mma_op.store_result(D, params->ldd); + return; + + } + /////////////////////////////////////////////////////////////////////////////// + // MN unaligned loop + else { // Loop over K - unaligned case + short tgp_bm = min(BM, params->M - c_row); + short tgp_bn = min(BN, params->N - c_col); + short lbk = params->K - params->gemm_k_iterations_aligned * BK; + + bool M_aligned = (tgp_bm == BM); + bool N_aligned = (tgp_bn == BN); + + short2 tile_dims_A = transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm); + short2 tile_dims_B = transpose_b ? 
short2(BK, tgp_bn) : short2(tgp_bn, BK); + + for (int k = 0; k < gemm_k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + if (!has_operand_mask || + (lhs_mask + [tid_y * mask_strides[3] + ((k * BK) / BM) * mask_strides[2]] && + rhs_mask + [((k * BK) / BM) * mask_strides[5] + tid_x * mask_strides[4]])) { + // Load elements into threadgroup + if (M_aligned) { + loader_a.load_unsafe(); + } else { + loader_a.load_safe(tile_dims_A); + } + + if (N_aligned) { + loader_b.load_unsafe(); + } else { + loader_b.load_safe(tile_dims_B); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + } + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + + if (!K_aligned) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (!has_operand_mask || + (lhs_mask + [tid_y * mask_strides[3] + (params->K / BM) * mask_strides[2]] && + rhs_mask + [(params->K / BM) * mask_strides[5] + + tid_x * mask_strides[4]])) { + short2 tile_dims_A_last = + transpose_a ? short2(tgp_bm, lbk) : short2(lbk, tgp_bm); + short2 tile_dims_B_last = + transpose_b ? short2(lbk, tgp_bn) : short2(tgp_bn, lbk); + + loader_a.load_safe(tile_dims_A_last); + loader_b.load_safe(tile_dims_B_last); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + mma_op.mma(As, Bs); + } + } + + if (M_aligned && N_aligned) { + mma_op.store_result(D, params->ldd); + } else { + mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm)); + } + } +} diff --git a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal index d39b4b005..52bb8bb41 100644 --- a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal +++ b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal @@ -1,434 +1,10 @@ // Copyright © 2024 Apple Inc. 
+// clang-format off #include "mlx/backend/metal/kernels/bf16.h" #include "mlx/backend/metal/kernels/steel/gemm/gemm.h" #include "mlx/backend/metal/kernels/utils.h" - -using namespace metal; -using namespace mlx::steel; - -/////////////////////////////////////////////////////////////////////////////// -// GEMM kernels -/////////////////////////////////////////////////////////////////////////////// - -struct _NoMask { - char x; - - constexpr METAL_FUNC operator bool() { - return true; - } - constexpr METAL_FUNC operator bool() const threadgroup { - return true; - } - constexpr METAL_FUNC operator bool() const device { - return true; - } - constexpr METAL_FUNC operator bool() const constant { - return true; - } -}; - -template -struct ScaleOp { - OutT scale; - - METAL_FUNC OutT apply(InT x) const { - return static_cast(x) * scale; - } -}; - -typedef struct _NoMask nomask_t; - -template < - typename T, - typename out_mask_t, - typename op_mask_t, - int BM, - int BN, - int BK, - int WM, - int WN, - bool transpose_a, - bool transpose_b, - bool MN_aligned, - bool K_aligned> -[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void -block_masked_gemm( - const device T* A [[buffer(0)]], - const device T* B [[buffer(1)]], - device T* D [[buffer(3)]], - const constant GEMMParams* params [[buffer(4)]], - const constant int* batch_shape [[buffer(6)]], - const constant size_t* batch_strides [[buffer(7)]], - const device out_mask_t* out_mask [[buffer(10)]], - const device op_mask_t* lhs_mask [[buffer(11)]], - const device op_mask_t* rhs_mask [[buffer(12)]], - const constant int* mask_strides [[buffer(13)]], - uint simd_lane_id [[thread_index_in_simdgroup]], - uint simd_group_id [[simdgroup_index_in_threadgroup]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]]) { - // Appease the compiler - (void)lid; - - static_assert( - BM == BN, - "block_masked_gemm must have the same block M and block N size"); - static_assert(BM % BK == 0, "block_masked_gemm must have BM % BK == 0"); - - constexpr bool has_operand_mask = !metal::is_same_v; - constexpr bool has_output_mask = !metal::is_same_v; - - constexpr bool has_mul_operand_mask = - has_operand_mask && !metal::is_same_v; - constexpr bool has_mul_output_mask = - has_output_mask && !metal::is_same_v; - - constexpr short k_mask_factor = short(BM / BK); - - using gemm_kernel = GEMMKernel< - T, - T, - BM, - BN, - BK, - WM, - WN, - transpose_a, - transpose_b, - MN_aligned, - K_aligned>; - - const int tid_y = ((tid.y) << params->swizzle_log) + - ((tid.x) & ((1 << params->swizzle_log) - 1)); - const int tid_x = (tid.x) >> params->swizzle_log; - - if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) { - return; - } - - const constant size_t* mask_batch_strides = - batch_strides + 2 * params->batch_ndim; - - if (params->batch_ndim > 1) { - if (has_output_mask) { - out_mask += elem_to_loc( - tid.z, batch_shape, mask_batch_strides, params->batch_ndim); - - mask_batch_strides += params->batch_ndim; - } - - if (has_operand_mask) { - const constant size_t* mask_strides_lhs = mask_batch_strides; - const constant size_t* mask_strides_rhs = - mask_strides_lhs + params->batch_ndim; - - ulong2 batch_offsets = elem_to_loc_broadcast( - tid.z, - batch_shape, - mask_strides_lhs, - mask_strides_rhs, - params->batch_ndim); - - lhs_mask += batch_offsets.x; - rhs_mask += batch_offsets.y; - } - } else { - if (has_output_mask) { - out_mask += tid.z * mask_batch_strides[0]; - mask_batch_strides += params->batch_ndim; - } - - if 
(has_operand_mask) { - lhs_mask += tid.z * mask_batch_strides[0]; - rhs_mask += tid.z * mask_batch_strides[params->batch_ndim]; - } - } - - // Adjust for batch - if (params->batch_ndim > 1) { - const constant size_t* A_bstrides = batch_strides; - const constant size_t* B_bstrides = batch_strides + params->batch_ndim; - - ulong2 batch_offsets = elem_to_loc_broadcast( - tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim); - - A += batch_offsets.x; - B += batch_offsets.y; - - } else { - A += params->batch_stride_a * tid.z; - B += params->batch_stride_b * tid.z; - } - - D += params->batch_stride_d * tid.z; - - // Find block in A, B, C - const int c_row = tid_y * BM; - const int c_col = tid_x * BN; - const size_t c_row_long = size_t(c_row); - const size_t c_col_long = size_t(c_col); - - A += transpose_a ? c_row_long : c_row_long * params->lda; - B += transpose_b ? c_col_long * params->ldb : c_col_long; - D += c_row_long * params->ldd + c_col_long; - - const constant int* out_mask_strides = mask_strides; - const constant int* lhs_mask_strides = - mask_strides + (has_output_mask ? 2 : 0); - const constant int* rhs_mask_strides = - lhs_mask_strides + (has_operand_mask ? 2 : 0); - - const int out_mask_offset = !has_output_mask - ? 0 - : tid_y * out_mask_strides[1] + tid_x * out_mask_strides[0]; - int lhs_mask_offset = !has_operand_mask ? 0 : tid_y * lhs_mask_strides[1]; - int rhs_mask_offset = !has_operand_mask ? 0 : tid_x * rhs_mask_strides[0]; - const int lhs_mask_step = !has_operand_mask ? 0 : lhs_mask_strides[0]; - const int rhs_mask_step = !has_operand_mask ? 0 : rhs_mask_strides[1]; - short k_factor_cnt = k_mask_factor; - - ScaleOp out_mask_op; - ScaleOp lhs_mask_op; - ScaleOp rhs_mask_op; - - if (has_output_mask) { - auto mask_out = out_mask[out_mask_offset]; - - if (has_mul_output_mask) { - out_mask_op.scale = float(mask_out); - } - - // Write zeros and return - if (!mask_out) { - constexpr short tgp_size = WM * WN * 32; - constexpr short vec_size = 4; - - // Tile threads in threadgroup - constexpr short TN = BN / vec_size; - constexpr short TM = tgp_size / TN; - - const short thread_idx = simd_group_id * 32 + simd_lane_id; - const short bi = thread_idx / TN; - const short bj = vec_size * (thread_idx % TN); - - D += bi * params->ldd + bj; - - short tgp_bm = min(BM, params->M - c_row); - short tgp_bn = min(BN, params->N - c_col); - - if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) { - for (short ti = 0; ti < BM; ti += TM) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < vec_size; j++) { - D[ti * params->ldd + j] = T(0.); - } - } - } else { - short jmax = tgp_bn - bj; - jmax = jmax < vec_size ? jmax : vec_size; - for (short ti = 0; (bi + ti) < tgp_bm; ti += TM) { - for (short j = 0; j < jmax; j++) { - D[ti * params->ldd + j] = T(0.); - } - } - } - - return; - } - } - - threadgroup_barrier(mem_flags::mem_none); - - // Prepare threadgroup mma operation - thread typename gemm_kernel::mma_t mma_op(simd_group_id, simd_lane_id); - - threadgroup T As[gemm_kernel::tgp_mem_size_a]; - threadgroup T Bs[gemm_kernel::tgp_mem_size_b]; - - // Prepare threadgroup loading operations - thread typename gemm_kernel::loader_a_t loader_a( - A, params->lda, As, simd_group_id, simd_lane_id); - thread typename gemm_kernel::loader_b_t loader_b( - B, params->ldb, Bs, simd_group_id, simd_lane_id); - - // Prepare threadgroup bounds - const short tgp_bm = - MN_aligned ? short(BM) : short(min(BM, params->M - c_row)); - const short tgp_bn = - MN_aligned ? 
short(BN) : short(min(BN, params->N - c_col)); - - int gemm_k_iterations = params->gemm_k_iterations_aligned; - - /////////////////////////////////////////////////////////////////////////////// - // Do unaligned K iterations first - if (!K_aligned) { - const int k_last = params->gemm_k_iterations_aligned * BK; - const int mask_idx_last = k_last / BM; - - if (!has_operand_mask || - (bool(lhs_mask[lhs_mask_offset + mask_idx_last * lhs_mask_step]) && - bool(rhs_mask[rhs_mask_offset + mask_idx_last * rhs_mask_step]))) { - if (has_mul_operand_mask) { - lhs_mask_op.scale = - lhs_mask[lhs_mask_offset + mask_idx_last * lhs_mask_step]; - rhs_mask_op.scale = - rhs_mask[rhs_mask_offset + mask_idx_last * rhs_mask_step]; - } - - // Move loader source ahead to end - const int k_remain = params->K - k_last; - const size_t k_jump_a = - transpose_a ? params->lda * size_t(k_last) : size_t(k_last); - const size_t k_jump_b = - transpose_b ? size_t(k_last) : params->ldb * size_t(k_last); - - loader_a.src += k_jump_a; - loader_b.src += k_jump_b; - - // Load tile - const short2 tile_dims_A = - transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm); - const short2 tile_dims_B = - transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); - - loader_a.load_safe(tile_dims_A); - loader_b.load_safe(tile_dims_B); - - if (has_mul_operand_mask) { - loader_a.apply_inplace_op(lhs_mask_op); - loader_b.apply_inplace_op(rhs_mask_op); - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Do matmul - mma_op.mma(As, Bs); - - // Reset source back to start - loader_a.src -= k_jump_a; - loader_b.src -= k_jump_b; - } - } - - /////////////////////////////////////////////////////////////////////////////// - // MNK aligned loop - if (MN_aligned) { - for (; gemm_k_iterations > 0; gemm_k_iterations--) { - threadgroup_barrier(mem_flags::mem_threadgroup); - - if (!has_operand_mask || - (bool(lhs_mask[lhs_mask_offset]) && - bool(rhs_mask[rhs_mask_offset]))) { - if (has_mul_operand_mask) { - lhs_mask_op.scale = lhs_mask[lhs_mask_offset]; - rhs_mask_op.scale = rhs_mask[rhs_mask_offset]; - } - - // Load elements into threadgroup - loader_a.load_unsafe(); - loader_b.load_unsafe(); - - if (has_mul_operand_mask) { - loader_a.apply_inplace_op(lhs_mask_op); - loader_b.apply_inplace_op(rhs_mask_op); - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Multiply and accumulate threadgroup elements - mma_op.mma(As, Bs); - } - - // Prepare for next iteration - loader_a.next(); - loader_b.next(); - - k_factor_cnt--; - lhs_mask_offset += k_factor_cnt == 0 ? lhs_mask_step : 0; - rhs_mask_offset += k_factor_cnt == 0 ? rhs_mask_step : 0; - k_factor_cnt = k_factor_cnt == 0 ? k_mask_factor : k_factor_cnt; - } - - if (has_mul_output_mask) { - mma_op.apply_epilogue(out_mask_op); - } - - // Store results to device memory - mma_op.store_result(D, params->ldd); - return; - - } - /////////////////////////////////////////////////////////////////////////////// - // MN unaligned loop - else { - const bool M_aligned = (tgp_bm == BM); - const bool N_aligned = (tgp_bn == BN); - - const short2 tile_dims_A = - transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm); - const short2 tile_dims_B = - transpose_b ? 
short2(BK, tgp_bn) : short2(tgp_bn, BK); - - for (; gemm_k_iterations > 0; gemm_k_iterations--) { - threadgroup_barrier(mem_flags::mem_threadgroup); - if (!has_operand_mask || - (bool(lhs_mask[lhs_mask_offset]) && - bool(rhs_mask[rhs_mask_offset]))) { - if (has_mul_operand_mask) { - lhs_mask_op.scale = lhs_mask[lhs_mask_offset]; - rhs_mask_op.scale = rhs_mask[rhs_mask_offset]; - } - - // Load elements into threadgroup - if (M_aligned) { - loader_a.load_unsafe(); - } else { - loader_a.load_safe(tile_dims_A); - } - - if (N_aligned) { - loader_b.load_unsafe(); - } else { - loader_b.load_safe(tile_dims_B); - } - - if (has_mul_operand_mask) { - loader_a.apply_inplace_op(lhs_mask_op); - loader_b.apply_inplace_op(rhs_mask_op); - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Multiply and accumulate threadgroup elements - mma_op.mma(As, Bs); - } - - // Prepare for next iteration - loader_a.next(); - loader_b.next(); - - k_factor_cnt--; - lhs_mask_offset += k_factor_cnt == 0 ? lhs_mask_step : 0; - rhs_mask_offset += k_factor_cnt == 0 ? rhs_mask_step : 0; - k_factor_cnt = k_factor_cnt == 0 ? k_mask_factor : k_factor_cnt; - } - - if (has_mul_output_mask) { - mma_op.apply_epilogue(out_mask_op); - } - - if (M_aligned && N_aligned) { - mma_op.store_result(D, params->ldd); - } else { - mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm)); - } - } -} - -/////////////////////////////////////////////////////////////////////////////// -// GEMM kernel initializations -/////////////////////////////////////////////////////////////////////////////// +#include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h" #define instantiate_gemm( \ outmaskname, \ @@ -483,7 +59,6 @@ block_masked_gemm( uint3 tid [[threadgroup_position_in_grid]], \ uint3 lid [[thread_position_in_threadgroup]]); -// clang-format off #define instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \ instantiate_gemm(bool_, bool, bool_, bool, tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \ instantiate_gemm(iname, itype, iname, itype, tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \ @@ -492,28 +67,24 @@ block_masked_gemm( instantiate_gemm(nomask, nomask_t, bool_, bool, tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \ instantiate_gemm(nomask, nomask_t, iname, itype, tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \ instantiate_gemm(bool_, bool, nomask, nomask_t, tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \ - instantiate_gemm(iname, itype, nomask, nomask_t, tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) // clang-format on + instantiate_gemm(iname, itype, nomask, nomask_t, tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) -// clang-format off #define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \ instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, 
taligned, true, naligned, false) \ instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \ - instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false) // clang-format on + instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false) -// clang-format off #define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \ instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \ instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ - instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) // clang-format on + instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) -// clang-format off #define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \ instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \ - instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) // clang-format on + instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) -// clang-format off instantiate_gemm_shapes_helper(float16, half, float16, half); instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t); instantiate_gemm_shapes_helper(float32, float, float32, float); // clang-format on diff --git a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h new file mode 100644 index 000000000..1ff97ea48 --- /dev/null +++ b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h @@ -0,0 +1,227 @@ +// Copyright © 2024 Apple Inc. 
+ +using namespace mlx::steel; + +/////////////////////////////////////////////////////////////////////////////// +// GEMM kernels +/////////////////////////////////////////////////////////////////////////////// + +template < + typename T, + typename U, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + bool MN_aligned, + bool K_aligned> +[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gemm_splitk( + const device T* A [[buffer(0)]], + const device T* B [[buffer(1)]], + device U* C [[buffer(2)]], + const constant GEMMSpiltKParams* params [[buffer(3)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + (void)lid; + + using gemm_kernel = GEMMKernel< + T, + U, + BM, + BN, + BK, + WM, + WN, + transpose_a, + transpose_b, + MN_aligned, + K_aligned>; + using loader_a_t = typename gemm_kernel::loader_a_t; + using loader_b_t = typename gemm_kernel::loader_b_t; + using mma_t = typename gemm_kernel::mma_t; + + threadgroup T As[gemm_kernel::tgp_mem_size_a]; + threadgroup T Bs[gemm_kernel::tgp_mem_size_b]; + + const int tid_x = tid.x; + const int tid_y = tid.y; + const int tid_z = tid.z; + + if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) { + return; + } + + // Find block in A, B, C + const int c_row = tid_y * BM; + const int c_col = tid_x * BN; + const int k_start = params->split_k_partition_size * tid_z; + + const size_t c_row_long = size_t(c_row); + const size_t c_col_long = size_t(c_col); + const size_t k_start_long = size_t(k_start); + + A += transpose_a ? (c_row_long + k_start_long * params->lda) + : (k_start_long + c_row_long * params->lda); + B += transpose_b ? 
(k_start_long + c_col_long * params->ldb) + : (c_col_long + k_start_long * params->ldb); + C += (size_t(params->split_k_partition_stride) * tid_z) + + (c_row_long * params->ldc + c_col_long); + + // Prepare threadgroup loading operations + thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id); + thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id); + + // Prepare threadgroup mma operation + thread mma_t mma_op(simd_group_id, simd_lane_id); + + int gemm_k_iterations = params->gemm_k_iterations_aligned; + + short tgp_bm = min(BM, params->M - c_row); + short tgp_bn = min(BN, params->N - c_col); + short leftover_bk = params->K % BK; + + if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk, + LoopAlignment{}); + } else if (tgp_bn == BN) { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk, + LoopAlignment{}); + } else if (tgp_bm == BM) { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk, + LoopAlignment{}); + } else { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk, + LoopAlignment{}); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if ((tid_z + 1) == (params->split_k_partitions)) { + int gemm_k_iter_remaining = + (params->K - (k_start + params->split_k_partition_size)) / BK; + if (!K_aligned || gemm_k_iter_remaining > 0) + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iter_remaining, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk, + LoopAlignment{}); + } + + if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) { + mma_op.store_result(C, params->ldc); + } else { + mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm)); + } +} + +/////////////////////////////////////////////////////////////////////////////// +// Split k accumulation kernel +/////////////////////////////////////////////////////////////////////////////// + +template < + typename AccT, + typename OutT, + typename Epilogue = TransformNone<OutT, AccT>> +[[kernel]] void gemm_splitk_accum( + const device AccT* C_split [[buffer(0)]], + device OutT* D [[buffer(1)]], + const constant int& k_partitions [[buffer(2)]], + const constant int& partition_stride [[buffer(3)]], + const constant int& ldd [[buffer(4)]], + uint2 gid [[thread_position_in_grid]]) { + // Adjust D and C + D += gid.x + gid.y * size_t(ldd); + C_split += gid.x + gid.y * size_t(ldd); + + size_t offset = 0; + AccT out = 0; + + for (int i = 0; i < k_partitions; i++) { + out += C_split[offset]; + offset += partition_stride; + } + + // Write output + D[0] = Epilogue::apply(out); +} + +template < + typename AccT, + typename OutT, + typename Epilogue = TransformAxpby<OutT, AccT>> +[[kernel]] void gemm_splitk_accum_axpby( + const device AccT* C_split [[buffer(0)]], + device OutT* D [[buffer(1)]], + const constant int& k_partitions [[buffer(2)]], + const constant int& partition_stride [[buffer(3)]], + const constant int& ldd [[buffer(4)]], + const device OutT* C [[buffer(5)]], + const constant int& ldc [[buffer(6)]], + const constant int& fdc [[buffer(7)]], + const constant float& alpha [[buffer(8)]], + const constant float& beta [[buffer(9)]], + uint2 gid [[thread_position_in_grid]]) { + // Adjust D and C + C += gid.x * size_t(fdc) + gid.y * size_t(ldc); + D += gid.x + 
gid.y * size_t(ldd); + C_split += gid.x + gid.y * size_t(ldd); + + size_t offset = 0; + AccT out = 0; + + for (int i = 0; i < k_partitions; i++) { + out += C_split[offset]; + offset += partition_stride; + } + + // Write output + Epilogue op(alpha, beta); + D[0] = op.apply(out, *C); +} diff --git a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal index e5b279f4e..9def75cda 100644 --- a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal +++ b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal @@ -1,173 +1,9 @@ // Copyright © 2024 Apple Inc. +// clang-format off #include "mlx/backend/metal/kernels/bf16.h" #include "mlx/backend/metal/kernels/steel/gemm/gemm.h" - -using namespace metal; -using namespace mlx::steel; - -/////////////////////////////////////////////////////////////////////////////// -// GEMM kernels -/////////////////////////////////////////////////////////////////////////////// - -template < - typename T, - typename U, - int BM, - int BN, - int BK, - int WM, - int WN, - bool transpose_a, - bool transpose_b, - bool MN_aligned, - bool K_aligned> -[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gemm_splitk( - const device T* A [[buffer(0)]], - const device T* B [[buffer(1)]], - device U* C [[buffer(2)]], - const constant GEMMSpiltKParams* params [[buffer(3)]], - uint simd_lane_id [[thread_index_in_simdgroup]], - uint simd_group_id [[simdgroup_index_in_threadgroup]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]]) { - (void)lid; - - using gemm_kernel = GEMMKernel< - T, - U, - BM, - BN, - BK, - WM, - WN, - transpose_a, - transpose_b, - MN_aligned, - K_aligned>; - using loader_a_t = typename gemm_kernel::loader_a_t; - using loader_b_t = typename gemm_kernel::loader_b_t; - using mma_t = typename gemm_kernel::mma_t; - - threadgroup T As[gemm_kernel::tgp_mem_size_a]; - threadgroup T Bs[gemm_kernel::tgp_mem_size_b]; - - const int tid_x = tid.x; - const int tid_y = tid.y; - const int tid_z = tid.z; - - if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) { - return; - } - - // Find block in A, B, C - const int c_row = tid_y * BM; - const int c_col = tid_x * BN; - const int k_start = params->split_k_partition_size * tid_z; - - const size_t c_row_long = size_t(c_row); - const size_t c_col_long = size_t(c_col); - const size_t k_start_long = size_t(k_start); - - A += transpose_a ? (c_row_long + k_start_long * params->lda) - : (k_start_long + c_row_long * params->lda); - B += transpose_b ? 
(k_start_long + c_col_long * params->ldb) - : (c_col_long + k_start_long * params->ldb); - C += (size_t(params->split_k_partition_stride) * tid_z) + - (c_row_long * params->ldc + c_col_long); - - // Prepare threadgroup loading operations - thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id); - thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id); - - // Prepare threadgroup mma operation - thread mma_t mma_op(simd_group_id, simd_lane_id); - - int gemm_k_iterations = params->gemm_k_iterations_aligned; - - short tgp_bm = min(BM, params->M - c_row); - short tgp_bn = min(BN, params->N - c_col); - short leftover_bk = params->K % BK; - - if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) { - gemm_kernel::gemm_loop( - As, - Bs, - gemm_k_iterations, - loader_a, - loader_b, - mma_op, - tgp_bm, - tgp_bn, - leftover_bk, - LoopAlignment{}); - } else if (tgp_bn == BN) { - gemm_kernel::gemm_loop( - As, - Bs, - gemm_k_iterations, - loader_a, - loader_b, - mma_op, - tgp_bm, - tgp_bn, - leftover_bk, - LoopAlignment{}); - } else if (tgp_bm == BM) { - gemm_kernel::gemm_loop( - As, - Bs, - gemm_k_iterations, - loader_a, - loader_b, - mma_op, - tgp_bm, - tgp_bn, - leftover_bk, - LoopAlignment{}); - } else { - gemm_kernel::gemm_loop( - As, - Bs, - gemm_k_iterations, - loader_a, - loader_b, - mma_op, - tgp_bm, - tgp_bn, - leftover_bk, - LoopAlignment{}); - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - if ((tid_z + 1) == (params->split_k_partitions)) { - int gemm_k_iter_remaining = - (params->K - (k_start + params->split_k_partition_size)) / BK; - if (!K_aligned || gemm_k_iter_remaining > 0) - gemm_kernel::gemm_loop( - As, - Bs, - gemm_k_iter_remaining, - loader_a, - loader_b, - mma_op, - tgp_bm, - tgp_bn, - leftover_bk, - LoopAlignment{}); - } - - if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) { - mma_op.store_result(C, params->ldc); - } else { - mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm)); - } -} - -/////////////////////////////////////////////////////////////////////////////// -// GEMM kernel initializations -/////////////////////////////////////////////////////////////////////////////// +#include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h" #define instantiate_gemm( \ tname, \ @@ -210,97 +46,27 @@ template < uint3 tid [[threadgroup_position_in_grid]], \ uint3 lid [[thread_position_in_threadgroup]]); -// clang-format off #define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \ instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \ instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \ - instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false) // clang-format on + instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false) -// clang-format off #define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \ instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \ instantiate_gemm_aligned_helper(tn, 
true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ - instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) // clang-format on + instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) -// clang-format off #define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \ instantiate_gemm_transpose_helper(iname, itype, oname, otype, 16, 16, 16, 2, 2) \ instantiate_gemm_transpose_helper(iname, itype, oname, otype, 16, 32, 16, 2, 2) \ instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 16, 16, 2, 2) \ - instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) // clang-format on + instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) -// clang-format off instantiate_gemm_shapes_helper(float16, half, float32, float); instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, float32, float); - -instantiate_gemm_shapes_helper(float32, float, float32, float); // clang-format on - -/////////////////////////////////////////////////////////////////////////////// -// Split k accumulation kernel -/////////////////////////////////////////////////////////////////////////////// - -template < - typename AccT, - typename OutT, - typename Epilogue = TransformNone> -[[kernel]] void gemm_splitk_accum( - const device AccT* C_split [[buffer(0)]], - device OutT* D [[buffer(1)]], - const constant int& k_partitions [[buffer(2)]], - const constant int& partition_stride [[buffer(3)]], - const constant int& ldd [[buffer(4)]], - uint2 gid [[thread_position_in_grid]]) { - // Ajust D and C - D += gid.x + gid.y * size_t(ldd); - C_split += gid.x + gid.y * size_t(ldd); - - size_t offset = 0; - AccT out = 0; - - for (int i = 0; i < k_partitions; i++) { - out += C_split[offset]; - offset += partition_stride; - } - - // Write output - D[0] = Epilogue::apply(out); -} - -template < - typename AccT, - typename OutT, - typename Epilogue = TransformAxpby> -[[kernel]] void gemm_splitk_accum_axpby( - const device AccT* C_split [[buffer(0)]], - device OutT* D [[buffer(1)]], - const constant int& k_partitions [[buffer(2)]], - const constant int& partition_stride [[buffer(3)]], - const constant int& ldd [[buffer(4)]], - const device OutT* C [[buffer(5)]], - const constant int& ldc [[buffer(6)]], - const constant int& fdc [[buffer(7)]], - const constant float& alpha [[buffer(8)]], - const constant float& beta [[buffer(9)]], - uint2 gid [[thread_position_in_grid]]) { - // Ajust D and C - C += gid.x * size_t(fdc) + gid.y * size_t(ldc); - D += gid.x + gid.y * size_t(ldd); - C_split += gid.x + gid.y * size_t(ldd); - - size_t offset = 0; - AccT out = 0; - - for (int i = 0; i < k_partitions; i++) { - out += C_split[offset]; - offset += partition_stride; - } - - // Write output - Epilogue op(alpha, beta); - D[0] = op.apply(out, *C); -} +instantiate_gemm_shapes_helper(float32, float, float32, float); #define instantiate_accum(oname, otype, aname, atype) \ template [[host_name("steel_gemm_splitk_accum_" #oname \ @@ -313,7 +79,7 @@ template < const constant int& ldd [[buffer(4)]], \ uint2 gid [[thread_position_in_grid]]); \ template [[host_name("steel_gemm_splitk_accum_" #oname "_" #aname \ - "_axpby")]] [[kernel]] void \ + "_axbpy")]] [[kernel]] void \ gemm_splitk_accum_axpby( \ const device atype* C_split [[buffer(0)]], \ device otype* D [[buffer(1)]], \ @@ -327,7 +93,6 @@ template < const constant float& beta [[buffer(9)]], \ uint2 gid [[thread_position_in_grid]]); -// clang-format off 
instantiate_accum(bfloat16, bfloat16_t, float32, float); instantiate_accum(float16, half, float32, float); -instantiate_accum(float32, float, float32, float); // clang-format on \ No newline at end of file +instantiate_accum(float32, float, float32, float); // clang-format on diff --git a/mlx/backend/metal/kernels/steel/gemm/loader.h b/mlx/backend/metal/kernels/steel/gemm/loader.h index aa6e8107d..3f084d8ec 100644 --- a/mlx/backend/metal/kernels/steel/gemm/loader.h +++ b/mlx/backend/metal/kernels/steel/gemm/loader.h @@ -2,7 +2,7 @@ #pragma once -#include "mlx/backend/metal/kernels/steel/utils.h" +#include "mlx/backend/metal/kernels/steel/defines.h" /////////////////////////////////////////////////////////////////////////////// // Loading helper @@ -134,4 +134,4 @@ struct BlockLoader { }; } // namespace steel -} // namespace mlx \ No newline at end of file +} // namespace mlx diff --git a/mlx/backend/metal/kernels/steel/gemm/mma.h b/mlx/backend/metal/kernels/steel/gemm/mma.h index 8214ad723..dbd425ef0 100644 --- a/mlx/backend/metal/kernels/steel/gemm/mma.h +++ b/mlx/backend/metal/kernels/steel/gemm/mma.h @@ -6,8 +6,8 @@ #include #include +#include "mlx/backend/metal/kernels/steel/defines.h" #include "mlx/backend/metal/kernels/steel/gemm/transforms.h" -#include "mlx/backend/metal/kernels/steel/utils.h" using namespace metal; @@ -358,4 +358,4 @@ struct BlockMMA { }; } // namespace steel -} // namespace mlx \ No newline at end of file +} // namespace mlx diff --git a/mlx/backend/metal/kernels/steel/utils.h b/mlx/backend/metal/kernels/steel/utils.h index cc6c24260..322b22503 100644 --- a/mlx/backend/metal/kernels/steel/utils.h +++ b/mlx/backend/metal/kernels/steel/utils.h @@ -4,9 +4,6 @@ #include -#define STEEL_CONST static constant constexpr const -#define STEEL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)") - METAL_FUNC ulong2 elem_to_loc_broadcast( uint elem, constant const int* shape, @@ -42,4 +39,4 @@ METAL_FUNC ulong3 elem_to_loc_broadcast( loc_c += pos_in_dim * c_strides[i]; } return ulong3(loc_a, loc_b, loc_c); -} \ No newline at end of file +} diff --git a/mlx/backend/metal/make_compiled_preamble.sh b/mlx/backend/metal/make_compiled_preamble.sh index e397042e3..40e3852bc 100644 --- a/mlx/backend/metal/make_compiled_preamble.sh +++ b/mlx/backend/metal/make_compiled_preamble.sh @@ -8,9 +8,10 @@ OUTPUT_DIR=$1 CC=$2 SRC_DIR=$3 -SRC_NAME=$4 +SRC_FILE=$4 CFLAGS=$5 -INPUT_FILE=${SRC_DIR}/mlx/backend/metal/kernels/${SRC_NAME}.h +SRC_NAME=$(basename -- "${SRC_FILE}") +INPUT_FILE=${SRC_DIR}/mlx/backend/metal/kernels/${SRC_FILE}.h OUTPUT_FILE=${OUTPUT_DIR}/${SRC_NAME}.cpp mkdir -p $OUTPUT_DIR diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp index c9273532e..b7596d415 100644 --- a/mlx/backend/metal/matmul.cpp +++ b/mlx/backend/metal/matmul.cpp @@ -7,6 +7,7 @@ #include "mlx/backend/metal/copy.h" #include "mlx/backend/metal/device.h" +#include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/kernels/defines.h" #include "mlx/backend/metal/kernels/steel/gemm/params.h" #include "mlx/backend/metal/matmul.h" @@ -336,7 +337,19 @@ void steel_matmul_conv_groups( // Encode and dispatch kernel auto& compute_encoder = d.get_command_encoder(s.index); - auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts); + auto kernel = get_steel_gemm_fused_kernel( + d, + base_name, + hash_name, + func_consts, + out, + transpose_a, + transpose_b, + bm, + bn, + bk, + wm, + wn); compute_encoder->setComputePipelineState(kernel); @@ -458,17 +471,31 @@ void steel_matmul( 
C_split.set_data(allocator::malloc_or_wait(C_split.nbytes())); copies.push_back(C_split); + bool mn_aligned = M % bm == 0 && N % bn == 0; + bool k_aligned = K % bk == 0; std::ostringstream kname; kname << "steel_gemm_splitk_" << (transpose_a ? 't' : 'n') << (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_" << type_to_name(C_split) << "_bm" << bm << "_bn" << bn << "_bk" << bk - << "_wm" << wm << "_wn" << wn << "_MN_" - << ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned" << "_K_" - << ((K % bk == 0) ? "t" : "n") << "aligned"; + << "_wm" << wm << "_wn" << wn << "_MN_" << (mn_aligned ? "t" : "n") + << "aligned" << "_K_" << (k_aligned ? "t" : "n") << "aligned"; // Encode and dispatch gemm kernel auto& compute_encoder = d.get_command_encoder(s.index); - auto kernel = d.get_kernel(kname.str()); + auto kernel = get_steel_gemm_splitk_kernel( + d, + kname.str(), + a, + C_split, + transpose_a, + transpose_b, + bm, + bn, + bk, + wm, + wn, + mn_aligned, + k_aligned); compute_encoder->setComputePipelineState(kernel); int tn = (N + bn - 1) / bn; @@ -504,10 +531,11 @@ void steel_matmul( static_cast(C_split.buffer().ptr()); const class MTL::Resource* const resources[1] = {c_split_buf}; compute_encoder->memoryBarrier(resources, 1); + auto kernel_name = "steel_gemm_splitk_accum_" + type_to_name(out) + "_" + + type_to_name(C_split); - auto kernel = d.get_kernel( - "steel_gemm_splitk_accum_" + type_to_name(out) + "_" + - type_to_name(C_split)); + auto kernel = get_steel_gemm_splitk_accum_kernel( + d, kernel_name, C_split, out, false); compute_encoder->setComputePipelineState(kernel); // Set the arguments for the kernel @@ -587,7 +615,19 @@ void steel_matmul( // Encode and dispatch kernel auto& compute_encoder = d.get_command_encoder(s.index); - auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts); + auto kernel = get_steel_gemm_fused_kernel( + d, + base_name, + hash_name, + func_consts, + out, + transpose_a, + transpose_b, + bm, + bn, + bk, + wm, + wn); compute_encoder->setComputePipelineState(kernel); @@ -1053,17 +1093,33 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) { C_split.set_data(allocator::malloc_or_wait(C_split.nbytes())); copies.push_back(C_split); + bool mn_aligned = M % bm == 0 && N % bn == 0; + bool k_aligned = K % bk == 0; + std::ostringstream kname; kname << "steel_gemm_splitk_" << (transpose_a ? 't' : 'n') << (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_" << type_to_name(C_split) << "_bm" << bm << "_bn" << bn << "_bk" << bk - << "_wm" << wm << "_wn" << wn << "_MN_" - << ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned" << "_K_" - << ((K % bk == 0) ? "t" : "n") << "aligned"; + << "_wm" << wm << "_wn" << wn << "_MN_" << (mn_aligned ? "t" : "n") + << "aligned" << "_K_" << (k_aligned ? 
"t" : "n") << "aligned"; // Encode and dispatch gemm kernel auto& compute_encoder = d.get_command_encoder(s.index); - auto kernel = d.get_kernel(kname.str()); + auto kernel = get_steel_gemm_splitk_kernel( + d, + kname.str(), + a, + C_split, + transpose_a, + transpose_b, + bm, + bn, + bk, + wm, + wn, + mn_aligned, + k_aligned); + compute_encoder->setComputePipelineState(kernel); int tn = (N + bn - 1) / bn; @@ -1095,9 +1151,11 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) { // Do accum kernel { - auto kernel = d.get_kernel( - "steel_gemm_splitk_accum_" + type_to_name(out) + "_" + - type_to_name(C_split) + "_axpby"); + auto kernel_name = "steel_gemm_splitk_accum_" + type_to_name(out) + "_" + + type_to_name(C_split) + "_axbpy"; + auto kernel = get_steel_gemm_splitk_accum_kernel( + d, kernel_name, C_split, out, true); + compute_encoder->setComputePipelineState(kernel); // Set the arguments for the kernel @@ -1182,7 +1240,19 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) { // Encode and dispatch kernel auto& compute_encoder = d.get_command_encoder(s.index); - auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts); + auto kernel = get_steel_gemm_fused_kernel( + d, + base_name, + hash_name, + func_consts, + out, + transpose_a, + transpose_b, + bm, + bn, + bk, + wm, + wn); compute_encoder->setComputePipelineState(kernel); @@ -1348,6 +1418,8 @@ void BlockMaskedMM::eval_gpu(const std::vector& inputs, array& out) { // Determine dispatch kernel int bm = block_size_, bn = block_size_, bk = 16; int wm = 2, wn = 2; + bool mn_aligned = M % bm == 0 && N % bn == 0; + bool k_aligned = K % bk == 0; // Prepare kernel name std::string out_mask_nm = has_out_mask ? type_to_name(inputs[2]) : "nomask"; @@ -1358,13 +1430,26 @@ void BlockMaskedMM::eval_gpu(const std::vector& inputs, array& out) { << op_mask_nm << "_" << (transpose_a ? 't' : 'n') << (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_" << type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk - << "_wm" << wm << "_wn" << wn << "_MN_" - << ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned" << "_K_" - << ((K % bk == 0) ? "t" : "n") << "aligned"; + << "_wm" << wm << "_wn" << wn << "_MN_" << (mn_aligned ? "t" : "n") + << "aligned" << "_K_" << (k_aligned ? "t" : "n") << "aligned"; // Encode and dispatch kernel auto& compute_encoder = d.get_command_encoder(s.index); - auto kernel = d.get_kernel(kname.str()); + auto kernel = get_steel_gemm_masked_kernel( + d, + kname.str(), + out, + has_out_mask ? std::optional{inputs[2]} : std::nullopt, + has_op_mask ? 
std::optional{inputs.back()} : std::nullopt, + transpose_a, + transpose_b, + bm, + bn, + bk, + wm, + wn, + mn_aligned, + k_aligned); compute_encoder->setComputePipelineState(kernel); // Use problem size to determine threadblock swizzle @@ -1720,7 +1805,19 @@ void GatherMM::eval_gpu(const std::vector& inputs, array& out) { // Encode and dispatch kernel auto& compute_encoder = d.get_command_encoder(s.index); - auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts); + auto kernel = get_steel_gemm_fused_kernel( + d, + base_name, + hash_name, + func_consts, + out, + transpose_a, + transpose_b, + bm, + bn, + bk, + wm, + wn); compute_encoder->setComputePipelineState(kernel); diff --git a/mlx/backend/metal/nojit_kernels.cpp b/mlx/backend/metal/nojit_kernels.cpp index 2dfd7a2ca..e14d099d3 100644 --- a/mlx/backend/metal/nojit_kernels.cpp +++ b/mlx/backend/metal/nojit_kernels.cpp @@ -103,4 +103,90 @@ MTL::ComputePipelineState* get_reduce_kernel( return d.get_kernel(kernel_name); } +MTL::ComputePipelineState* get_steel_gemm_fused_kernel( + metal::Device& d, + const std::string& kernel_name, + const std::string& hash_name, + const metal::MTLFCList& func_consts, + const array&, + bool, + bool, + int, + int, + int, + int, + int) { + return d.get_kernel(kernel_name, "mlx", hash_name, func_consts); +} + +MTL::ComputePipelineState* get_steel_gemm_splitk_kernel( + metal::Device& d, + const std::string& kernel_name, + const array&, + const array&, + bool, + bool, + int, + int, + int, + int, + int, + bool, + bool) { + return d.get_kernel(kernel_name); +} + +MTL::ComputePipelineState* get_steel_gemm_splitk_accum_kernel( + metal::Device& d, + const std::string& kernel_name, + const array&, + const array&, + bool) { + return d.get_kernel(kernel_name); +} + +MTL::ComputePipelineState* get_steel_gemm_masked_kernel( + metal::Device& d, + const std::string& kernel_name, + const array&, + const std::optional& mask_out, + const std::optional& mask_op, + bool, + bool, + int, + int, + int, + int, + int, + bool, + bool) { + return d.get_kernel(kernel_name); +} + +MTL::ComputePipelineState* get_steel_conv_kernel( + metal::Device& d, + const std::string& kernel_name, + const array&, + int, + int, + int, + int, + int, + int, + bool) { + return d.get_kernel(kernel_name); +} + +MTL::ComputePipelineState* get_steel_conv_general_kernel( + metal::Device& d, + const std::string& kernel_name, + const array&, + int, + int, + int, + int, + int) { + return d.get_kernel(kernel_name); +} + } // namespace mlx::core diff --git a/python/src/ops.cpp b/python/src/ops.cpp index c601ca0c5..b919faed6 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -3968,9 +3968,9 @@ void init_ops(nb::module_& m) { a (array): Input array or scalar. b (array): Input array or scalar. block_size (int): Size of blocks to be masked. 
Must be ``32`` or ``64`` (default: ``64``) - mask_out (array, optional): Boolean mask for output (default: ``None``) - mask_lhs (array, optional): Boolean mask for a (default: ``None``) - mask_rhs (array, optional): Boolean mask for b (default: ``None``) + mask_out (array, optional): Mask for output (default: ``None``) + mask_lhs (array, optional): Mask for a (default: ``None``) + mask_rhs (array, optional): Mask for b (default: ``None``) )pbdoc"); m.def( diff --git a/python/tests/test_blas.py b/python/tests/test_blas.py index c5ae5eaf5..de69d00dd 100644 --- a/python/tests/test_blas.py +++ b/python/tests/test_blas.py @@ -556,8 +556,7 @@ class TestBlas(mlx_tests.MLXTestCase): self.assertListEqual(list(d_npy.shape), list(d_mlx.shape)) self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5)) - - # # Batched matmul with simple broadcast + # Batched matmul with simple broadcast a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32) b_npy = np.random.normal(0.0, 1.0 / 128, (16, 16)).astype(np.float32) @@ -573,7 +572,6 @@ class TestBlas(mlx_tests.MLXTestCase): self.assertListEqual(list(d_npy.shape), list(d_mlx.shape)) self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5)) - # Matmul with vector a_npy = np.random.normal(0.0, 1.0 / 128, (16,)).astype(np.float32) b_npy = np.random.normal(0.0, 1.0 / 128, (32, 16, 128)).astype(np.float32)
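
For orientation only (this sketch is not part of the patch): the steel_gemm_splitk path above splits the K dimension of C = A @ B across params->split_k_partitions, has each partition (one tid_z slice) write its partial product into its own region of C_split, and then runs gemm_splitk_accum, or the alpha/beta variant dispatched under the "_axbpy" suffix, to sum the partitions into the final output. A minimal NumPy equivalent of that scheme is sketched below; the function and variable names (splitk_matmul_reference, c_split, part) are illustrative and not MLX APIs.

import numpy as np

def splitk_matmul_reference(a, b, k_partitions=4, alpha=1.0, beta=0.0, c=None):
    # Split K into k_partitions contiguous chunks (analogous to split_k_partition_size).
    M, K = a.shape
    _, N = b.shape
    part = (K + k_partitions - 1) // k_partitions
    c_split = np.zeros((k_partitions, M, N), dtype=np.float32)
    for p in range(k_partitions):
        k0, k1 = p * part, min((p + 1) * part, K)
        # Each partition computes a partial GEMM over its K slice,
        # mirroring one tid_z slice of C_split on the GPU.
        c_split[p] = a[:, k0:k1].astype(np.float32) @ b[k0:k1, :].astype(np.float32)
    # gemm_splitk_accum: sum the partial results across partitions.
    out = c_split.sum(axis=0)
    if c is not None:
        # gemm_splitk_accum_axpby epilogue: alpha * out + beta * C.
        out = alpha * out + beta * c.astype(np.float32)
    return out

a = np.random.normal(0.0, 1.0 / 128, (33, 130)).astype(np.float32)
b = np.random.normal(0.0, 1.0 / 128, (130, 35)).astype(np.float32)
assert np.allclose(splitk_matmul_reference(a, b), a @ b, atol=1e-6)

The GPU version differs mainly in that each partition is a separate threadgroup along tid_z writing at split_k_partition_stride offsets into C_split, and the accumulation runs as a second kernel dispatch after a memory barrier.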