Option to JIT steel gemm / conv (#1139)

Awni Hannun 2024-05-23 18:07:34 -07:00 committed by GitHub
parent eab2685c67
commit 7e26fd8032
31 changed files with 2504 additions and 1540 deletions

View File

@@ -1,4 +1,4 @@
-function(make_jit_source SRC_NAME)
+function(make_jit_source SRC_FILE)
  # This function takes a metal header file,
  # runs the C preprocessor on it, and makes
  # the processed contents available as a string in a C++ function
@@ -9,6 +9,7 @@ function(make_jit_source SRC_NAME)
  #
  # Additional arguments to this function are treated as dependencies
  # in the CMake build system.
+ get_filename_component(SRC_NAME ${SRC_FILE} NAME)
  add_custom_command(
    OUTPUT jit/${SRC_NAME}.cpp
    COMMAND /bin/bash
@@ -16,10 +17,10 @@ function(make_jit_source SRC_NAME)
    ${CMAKE_CURRENT_BINARY_DIR}/jit
    ${CMAKE_C_COMPILER}
    ${PROJECT_SOURCE_DIR}
-   ${SRC_NAME}
+   ${SRC_FILE}
    "-D${MLX_METAL_VERSION}"
    DEPENDS make_compiled_preamble.sh
-   kernels/${SRC_NAME}.h
+   kernels/${SRC_FILE}.h
    ${ARGN}
  )
  add_custom_target(${SRC_NAME} DEPENDS jit/${SRC_NAME}.cpp)
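The const char* declarations added to kernels.h later in this diff suggest the shape of each generated jit/<name>.cpp: a function in mlx::core::metal that returns the preprocessed Metal source as a string. A minimal sketch of what such a generated file might look like; this is an illustration only, since the real contents are emitted by make_compiled_preamble.sh and are not part of this diff:

// Hypothetical sketch of a generated jit/steel_gemm_fused.cpp.
namespace mlx::core::metal {

const char* steel_gemm_fused() {
  // Raw string holding the C-preprocessed Metal header, ready to be
  // concatenated with other preambles and JIT-compiled at run time.
  return R"preamble(
    /* ... preprocessed contents of steel_gemm_fused.h ... */
  )preamble";
}

} // namespace mlx::core::metal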
@@ -73,6 +74,39 @@ if (MLX_METAL_JIT)
    kernels/reduction/reduce_col.h
    kernels/reduction/reduce_row.h
  )
make_jit_source(
steel/gemm/gemm
kernels/steel/utils.h
kernels/steel/gemm/loader.h
kernels/steel/gemm/mma.h
kernels/steel/gemm/params.h
kernels/steel/gemm/transforms.h
)
make_jit_source(steel/gemm/kernels/steel_gemm_fused)
make_jit_source(
steel/gemm/kernels/steel_gemm_masked
kernels/steel/defines.h
)
make_jit_source(steel/gemm/kernels/steel_gemm_splitk)
make_jit_source(
steel/conv/conv
kernels/steel/utils.h
kernels/steel/defines.h
kernels/steel/gemm/mma.h
kernels/steel/gemm/transforms.h
kernels/steel/conv/params.h
kernels/steel/conv/loader.h
kernels/steel/conv/loaders/loader_channel_l.h
kernels/steel/conv/loaders/loader_channel_n.h
)
make_jit_source(
steel/conv/kernels/steel_conv
)
make_jit_source(
steel/conv/kernels/steel_conv_general
kernels/steel/defines.h
kernels/steel/conv/loaders/loader_general.h
)
else()
  target_sources(
    mlx

View File

@@ -7,6 +7,7 @@
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
+#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/steel/conv/params.h"
#include "mlx/backend/metal/matmul.h"
@@ -335,7 +336,17 @@ void implicit_gemm_conv_2D_gpu(
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
-auto kernel = d.get_kernel(kname.str());
+auto kernel = get_steel_conv_kernel(
+    d,
+    kname.str(),
+    out,
+    bm,
+    bn,
+    bk,
+    wm,
+    wn,
+    n_channel_specialization,
+    small_filter);
compute_encoder->setComputePipelineState(kernel);
// Deduce grid launch dimensions
@@ -488,7 +499,8 @@ void implicit_gemm_conv_2D_general_gpu(
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
-auto kernel = d.get_kernel(kname.str());
+auto kernel =
+    get_steel_conv_general_kernel(d, kname.str(), out, bm, bn, bk, wm, wn);
compute_encoder->setComputePipelineState(kernel);
// Deduce grid launch dimensions

View File

@@ -23,4 +23,12 @@ const char* softmax();
const char* sort();
const char* reduce();
+const char* gemm();
+const char* steel_gemm_fused();
+const char* steel_gemm_masked();
+const char* steel_gemm_splitk();
+const char* conv();
+const char* steel_conv();
+const char* steel_conv_general();
} // namespace mlx::core::metal

View File

@@ -0,0 +1,32 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view steel_conv_kernels = R"(
template [[host_name("{name}")]] [[kernel]] void
implicit_gemm_conv_2d<{itype}, {bm}, {bn}, {bk}, {wm}, {wn}, {n_channels}, {small_filter}>(
const device {itype}* A [[buffer(0)]],
const device {itype}* B [[buffer(1)]],
device {itype}* C [[buffer(2)]],
const constant MLXConvParams<2>* params [[buffer(3)]],
const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]);
)";
constexpr std::string_view steel_conv_general_kernels = R"(
template [[host_name("{name}")]] [[kernel]] void
implicit_gemm_conv_2d_general<{itype}, {bm}, {bn}, {bk}, {wm}, {wn}>(
const device {itype}* A [[buffer(0)]],
const device {itype}* B [[buffer(1)]],
device {itype}* C [[buffer(2)]],
const constant MLXConvParams<2>* params [[buffer(3)]],
const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]],
const constant Conv2DGeneralJumpParams* jump_params [[buffer(5)]],
const constant Conv2DGeneralBaseInfo* base_h [[buffer(6)]],
const constant Conv2DGeneralBaseInfo* base_w [[buffer(7)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]);
)";

View File

@@ -0,0 +1,106 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view steel_gemm_fused_kernels = R"(
template [[host_name("{name}")]]
[[kernel]] void gemm<{itype}, {bm}, {bn}, {bk}, {wm}, {wn}, {trans_a}, {trans_b}, float>(
const device {itype} *A [[buffer(0)]],
const device {itype} *B [[buffer(1)]],
const device {itype} *C [[buffer(2), function_constant(use_out_source)]],
device {itype} *D [[buffer(3)]],
const constant GEMMParams* params [[buffer(4)]],
const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]],
const constant int* batch_shape [[buffer(6)]],
const constant size_t* batch_strides [[buffer(7)]],
const constant uint32_t* lhs_indices [[buffer(10), function_constant(do_gather)]],
const constant uint32_t* rhs_indices [[buffer(11), function_constant(do_gather)]],
const constant uint32_t* C_indices [[buffer(12), function_constant(gather_bias)]],
const constant int* operand_shape [[buffer(13), function_constant(do_gather)]],
const constant size_t* operand_strides [[buffer(14), function_constant(do_gather)]],
const constant packed_int3& operand_batch_ndim [[buffer(15), function_constant(do_gather)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]);
)";
constexpr std::string_view steel_gemm_masked_kernels = R"(
template [[host_name("{name}")]] [[kernel]] void
block_masked_gemm<
{itype},
{outmasktype},
{opmasktype},
{bm},
{bn},
{bk},
{wm},
{wn},
{trans_a},
{trans_b},
{mn_aligned},
{k_aligned}>(
const device {itype}* A [[buffer(0)]],
const device {itype}* B [[buffer(1)]],
device {itype}* D [[buffer(3)]],
const constant GEMMParams* params [[buffer(4)]],
const constant int* batch_shape [[buffer(6)]],
const constant size_t* batch_strides [[buffer(7)]],
const device {outmasktype}* out_mask [[buffer(10)]],
const device {opmasktype}* lhs_mask [[buffer(11)]],
const device {opmasktype}* rhs_mask [[buffer(12)]],
const constant int* mask_strides [[buffer(13)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]);
)";
constexpr std::string_view steel_gemm_splitk_kernels = R"(
template [[host_name("{name}")]] [[kernel]] void
gemm_splitk<
{itype},
{otype},
{bm},
{bn},
{bk},
{wm},
{wn},
{trans_a},
{trans_b},
{mn_aligned},
{k_aligned}>(
const device {itype}* A [[buffer(0)]],
const device {itype}* B [[buffer(1)]],
device {otype}* C [[buffer(2)]],
const constant GEMMSpiltKParams* params [[buffer(3)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]);
)";
constexpr std::string_view steel_gemm_splitk_accum_kernels = R"(
template [[host_name("{name}")]] [[kernel]] void
gemm_splitk_accum<{atype}, {otype}>(
const device {atype}* C_split [[buffer(0)]],
device {otype}* D [[buffer(1)]],
const constant int& k_partitions [[buffer(2)]],
const constant int& partition_stride [[buffer(3)]],
const constant int& ldd [[buffer(4)]],
uint2 gid [[thread_position_in_grid]]);
)";
constexpr std::string_view steel_gemm_splitk_accum_axbpy_kernels = R"(
template [[host_name("{name}")]] [[kernel]] void
gemm_splitk_accum_axpby<{atype}, {otype}>(
const device {atype}* C_split [[buffer(0)]],
device {otype}* D [[buffer(1)]],
const constant int& k_partitions [[buffer(2)]],
const constant int& partition_stride [[buffer(3)]],
const constant int& ldd [[buffer(4)]],
const device {otype}* C [[buffer(5)]],
const constant int& ldc [[buffer(6)]],
const constant int& fdc [[buffer(7)]],
const constant float& alpha [[buffer(8)]],
const constant float& beta [[buffer(9)]],
uint2 gid [[thread_position_in_grid]]);
)";

View File

@@ -1,5 +1,4 @@
// Copyright © 2024 Apple Inc.
#include <fmt/format.h>
#include "mlx/backend/common/compiled.h"
@@ -12,11 +11,15 @@
#include "mlx/backend/metal/jit/scan.h"
#include "mlx/backend/metal/jit/softmax.h"
#include "mlx/backend/metal/jit/sort.h"
+#include "mlx/backend/metal/jit/steel_conv.h"
+#include "mlx/backend/metal/jit/steel_gemm.h"
#include "mlx/backend/metal/jit/ternary.h"
#include "mlx/backend/metal/jit/unary.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h"
+using namespace fmt::literals;
namespace mlx::core {
std::string op_name(const array& arr) {
@@ -276,4 +279,208 @@ MTL::ComputePipelineState* get_reduce_kernel(
  return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_steel_gemm_fused_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& hash_name,
const metal::MTLFCList& func_consts,
const array& out,
bool transpose_a,
bool transpose_b,
int bm,
int bn,
int bk,
int wm,
int wn) {
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::gemm()
<< metal::steel_gemm_fused()
<< fmt::format(
steel_gemm_fused_kernels,
"name"_a = lib_name,
"itype"_a = get_type_string(out.dtype()),
"bm"_a = bm,
"bn"_a = bn,
"bk"_a = bk,
"wm"_a = wm,
"wn"_a = wn,
"trans_a"_a = transpose_a,
"trans_b"_a = transpose_b);
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib, hash_name, func_consts);
}
MTL::ComputePipelineState* get_steel_gemm_splitk_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& in,
const array& out,
bool transpose_a,
bool transpose_b,
int bm,
int bn,
int bk,
int wm,
int wn,
bool mn_aligned,
bool k_aligned) {
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::gemm()
<< metal::steel_gemm_splitk()
<< fmt::format(
steel_gemm_splitk_kernels,
"name"_a = lib_name,
"itype"_a = get_type_string(in.dtype()),
"otype"_a = get_type_string(out.dtype()),
"bm"_a = bm,
"bn"_a = bn,
"bk"_a = bk,
"wm"_a = wm,
"wn"_a = wn,
"trans_a"_a = transpose_a,
"trans_b"_a = transpose_b,
"mn_aligned"_a = mn_aligned,
"k_aligned"_a = k_aligned);
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_steel_gemm_splitk_accum_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& in,
const array& out,
bool axbpy) {
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::gemm()
<< metal::steel_gemm_splitk()
<< fmt::format(
axbpy ? steel_gemm_splitk_accum_axbpy_kernels
: steel_gemm_splitk_accum_kernels,
"name"_a = lib_name,
"atype"_a = get_type_string(in.dtype()),
"otype"_a = get_type_string(out.dtype()));
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& out,
const std::optional<array>& mask_out,
const std::optional<array>& mask_op,
bool transpose_a,
bool transpose_b,
int bm,
int bn,
int bk,
int wm,
int wn,
bool mn_aligned,
bool k_aligned) {
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
auto out_mask_type = mask_out.has_value()
? get_type_string((*mask_out).dtype())
: "nomask_t";
auto op_mask_type =
mask_op.has_value() ? get_type_string((*mask_op).dtype()) : "nomask_t";
kernel_source << metal::utils() << metal::gemm()
<< metal::steel_gemm_masked()
<< fmt::format(
steel_gemm_masked_kernels,
"name"_a = lib_name,
"itype"_a = get_type_string(out.dtype()),
"outmasktype"_a = out_mask_type,
"opmasktype"_a = op_mask_type,
"bm"_a = bm,
"bn"_a = bn,
"bk"_a = bk,
"wm"_a = wm,
"wn"_a = wn,
"trans_a"_a = transpose_a,
"trans_b"_a = transpose_b,
"mn_aligned"_a = mn_aligned,
"k_aligned"_a = k_aligned);
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_steel_conv_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& out,
int bm,
int bn,
int bk,
int wm,
int wn,
int n_channel_specialization,
bool small_filter) {
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::conv() << metal::steel_conv()
<< fmt::format(
steel_conv_kernels,
"name"_a = lib_name,
"itype"_a = get_type_string(out.dtype()),
"bm"_a = bm,
"bn"_a = bn,
"bk"_a = bk,
"wm"_a = wm,
"wn"_a = wn,
"n_channels"_a = n_channel_specialization,
"small_filter"_a = small_filter);
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_steel_conv_general_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& out,
int bm,
int bn,
int bk,
int wm,
int wn) {
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::conv()
<< metal::steel_conv_general()
<< fmt::format(
steel_conv_general_kernels,
"name"_a = lib_name,
"itype"_a = get_type_string(out.dtype()),
"bm"_a = bm,
"bn"_a = bn,
"bk"_a = bk,
"wm"_a = wm,
"wn"_a = wn);
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
}
} // namespace mlx::core
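All of these get_steel_* helpers share one lookup-or-build pattern: the fully specialized kernel name doubles as the library name, so each template specialization is generated and Metal-compiled at most once and afterwards served from the device's library cache. A standalone sketch of that pattern (ordinary C++ with a toy cache, not MLX's actual Device API):

// Toy illustration of the name-keyed JIT cache used by the get_steel_* helpers.
// Library stands in for MTL::Library; "compiling" is just a print statement.
#include <iostream>
#include <string>
#include <unordered_map>

struct Library {
  std::string source;
};

class Device {
 public:
  // Return the cached library for `name`, or nullptr if it was never built.
  Library* get_library(const std::string& name) {
    auto it = cache_.find(name);
    return it == cache_.end() ? nullptr : &it->second;
  }
  // Build (once) and cache a library from source.
  Library* get_library(const std::string& name, const std::string& source) {
    auto [it, inserted] = cache_.emplace(name, Library{source});
    if (inserted) {
      std::cout << "jit-compiling " << name << "\n";
    }
    return &it->second;
  }

 private:
  std::unordered_map<std::string, Library> cache_;
};

Library* get_steel_kernel(Device& d, const std::string& kernel_name,
                          const std::string& specialization) {
  const auto& lib_name = kernel_name;
  auto lib = d.get_library(lib_name);
  if (lib == nullptr) {
    // The source is assembled from the shared preambles only on a cache miss,
    // mirroring the structure of jit_kernels.cpp above.
    lib = d.get_library(lib_name, "// preamble\n" + specialization);
  }
  return lib;
}

int main() {
  Device d;
  get_steel_kernel(d, "steel_gemm_fused_nn_float32_bm32_bn32", "// spec A");
  get_steel_kernel(d, "steel_gemm_fused_nn_float32_bm32_bn32", "// spec A");  // cache hit
  return 0;
}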

View File

@@ -79,4 +79,78 @@ MTL::ComputePipelineState* get_reduce_kernel(
    const array& in,
    const array& out);
MTL::ComputePipelineState* get_steel_gemm_fused_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& hash_name,
const metal::MTLFCList& func_consts,
const array& out,
bool transpose_a,
bool transpose_b,
int bm,
int bn,
int bk,
int wm,
int wn);
MTL::ComputePipelineState* get_steel_gemm_splitk_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& in,
const array& out,
bool transpose_a,
bool transpose_b,
int bm,
int bn,
int bk,
int wm,
int wn,
bool mn_aligned,
bool k_aligned);
MTL::ComputePipelineState* get_steel_gemm_splitk_accum_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& in,
const array& out,
bool axbpy);
MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& out,
const std::optional<array>& mask_out,
const std::optional<array>& mask_op,
bool transpose_a,
bool transpose_b,
int bm,
int bn,
int bk,
int wm,
int wn,
bool mn_aligned,
bool k_aligned);
MTL::ComputePipelineState* get_steel_conv_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& out,
int bm,
int bn,
int bk,
int wm,
int wn,
int n_channel_specialization,
bool small_filter);
MTL::ComputePipelineState* get_steel_conv_general_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& out,
int bm,
int bn,
int bk,
int wm,
int wn);
} // namespace mlx::core

View File

@@ -1,14 +1,13 @@
set(
  HEADERS
- ${CMAKE_CURRENT_SOURCE_DIR}/atomic.h
- ${CMAKE_CURRENT_SOURCE_DIR}/bf16.h
- ${CMAKE_CURRENT_SOURCE_DIR}/bf16_math.h
- ${CMAKE_CURRENT_SOURCE_DIR}/complex.h
- ${CMAKE_CURRENT_SOURCE_DIR}/defines.h
- ${CMAKE_CURRENT_SOURCE_DIR}/utils.h
+ bf16.h
+ bf16_math.h
+ complex.h
+ defines.h
+ utils.h
+ steel/conv/params.h
)
set(
  KERNELS
  "arg_reduce"
@@ -41,6 +40,7 @@ set(
set(
  HEADERS
  ${HEADERS}
+ atomic.h
  arange.h
  unary_ops.h
  unary.h
@@ -89,14 +89,40 @@ foreach(KERNEL ${KERNELS})
  set(KERNEL_AIR ${KERNEL}.air ${KERNEL_AIR})
endforeach()
-file(GLOB_RECURSE STEEL_KERNELS ${CMAKE_CURRENT_SOURCE_DIR}/steel/*.metal)
-file(GLOB_RECURSE STEEL_HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/steel/*.h)
-foreach(KERNEL ${STEEL_KERNELS})
-  cmake_path(GET KERNEL STEM TARGET)
-  build_kernel_base(${TARGET} ${KERNEL} "${STEEL_HEADERS}")
-  set(KERNEL_AIR ${TARGET}.air ${KERNEL_AIR})
-endforeach()
+if (NOT MLX_METAL_JIT)
+  set(
+    STEEL_KERNELS
+    ${CMAKE_CURRENT_SOURCE_DIR}/steel/conv/kernels/steel_conv.metal
+    ${CMAKE_CURRENT_SOURCE_DIR}/steel/conv/kernels/steel_conv_general.metal
+    ${CMAKE_CURRENT_SOURCE_DIR}/steel/gemm/kernels/steel_gemm_fused.metal
+    ${CMAKE_CURRENT_SOURCE_DIR}/steel/gemm/kernels/steel_gemm_masked.metal
+    ${CMAKE_CURRENT_SOURCE_DIR}/steel/gemm/kernels/steel_gemm_splitk.metal
+  )
+  set(
+    STEEL_HEADERS
+    steel/defines.h
+    steel/utils.h
+    steel/conv/conv.h
+    steel/conv/loader.h
+    steel/conv/loaders/loader_channel_l.h
+    steel/conv/loaders/loader_channel_n.h
+    steel/conv/loaders/loader_general.h
+    steel/conv/kernels/steel_conv.h
+    steel/conv/kernels/steel_conv_general.h
+    steel/gemm/gemm.h
+    steel/gemm/mma.h
+    steel/gemm/loader.h
+    steel/gemm/transforms.h
+    steel/gemm/kernels/steel_gemm_fused.h
+    steel/gemm/kernels/steel_gemm_masked.h
+    steel/gemm/kernels/steel_gemm_splitk.h
+  )
+  foreach(KERNEL ${STEEL_KERNELS})
+    cmake_path(GET KERNEL STEM TARGET)
+    build_kernel_base(${TARGET} ${KERNEL} "${STEEL_HEADERS}")
+    set(KERNEL_AIR ${TARGET}.air ${KERNEL_AIR})
+  endforeach()
+endif()
add_custom_command(
  OUTPUT ${MLX_METAL_PATH}/mlx.metallib

View File

@@ -1,6 +1,5 @@
// Copyright © 2023 Apple Inc.
-#include <metal_atomic>
#include <metal_simdgroup>
#include "mlx/backend/metal/kernels/utils.h"
@@ -194,4 +193,4 @@ instantiate_arg_reduce(int32, int32_t)
instantiate_arg_reduce(int64, int64_t)
instantiate_arg_reduce(float16, half)
instantiate_arg_reduce(float32, float)
instantiate_arg_reduce(bfloat16, bfloat16_t) // clang-format on

View File

@@ -650,4 +650,4 @@ winograd_conv_2d_output_transform(
// clang-format off
instantiate_winograd_conv_2d(float32, float);
instantiate_winograd_conv_2d(float16, half); // clang-format on

View File

@@ -2,10 +2,12 @@
#pragma once
+#include "mlx/backend/metal/kernels/steel/defines.h"
#include "mlx/backend/metal/kernels/steel/utils.h"
#include "mlx/backend/metal/kernels/steel/conv/loader.h"
#include "mlx/backend/metal/kernels/steel/conv/params.h"
+#include "mlx/backend/metal/kernels/steel/gemm/mma.h"
using namespace metal;
using namespace mlx::steel;

View File

@@ -0,0 +1,176 @@
// Copyright © 2024 Apple Inc.
#include <metal_stdlib>
using namespace metal;
template <
typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
int N_CHANNELS = 0,
bool SMALL_FILTER = false>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void
implicit_gemm_conv_2d(
const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]],
device T* C [[buffer(2)]],
const constant MLXConvParams<2>* params [[buffer(3)]],
const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
using namespace mlx::steel;
(void)lid;
constexpr bool transpose_a = false;
constexpr bool transpose_b = true;
constexpr short tgp_padding_a = 16 / sizeof(T);
constexpr short tgp_padding_b = 16 / sizeof(T);
constexpr short shape_a_cols = (transpose_a ? BM : BK) + tgp_padding_a;
constexpr short shape_b_cols = (transpose_b ? BK : BN) + tgp_padding_b;
constexpr short shape_a_rows = (transpose_a ? BK : BM);
constexpr short shape_b_rows = (transpose_b ? BN : BK);
constexpr short tgp_mem_size_a = shape_a_cols * shape_a_rows;
constexpr short tgp_mem_size_b = shape_b_cols * shape_b_rows;
constexpr short tgp_size = WM * WN * 32;
// Input loader
using loader_a_t = typename metal::conditional_t<
// Check for small channel specialization
N_CHANNELS != 0 && N_CHANNELS <= 4,
// Go to small channel specialization
Conv2DInputBlockLoaderSmallChannels<
T,
BM,
BN,
BK,
tgp_size,
N_CHANNELS,
tgp_padding_a>,
// Else go to general loader
typename metal::conditional_t<
// Check if filter size is small enough
SMALL_FILTER,
// Go to small filter specialization
Conv2DInputBlockLoaderSmallFilter<
T,
BM,
BN,
BK,
tgp_size,
tgp_padding_a>,
// Else go to large filter generalization
Conv2DInputBlockLoaderLargeFilter<
T,
BM,
BN,
BK,
tgp_size,
tgp_padding_a>>>;
// Weight loader
using loader_b_t = typename metal::conditional_t<
// Check for small channel specialization
N_CHANNELS != 0 && N_CHANNELS <= 4,
// Go to small channel specialization
Conv2DWeightBlockLoaderSmallChannels<
T,
BM,
BN,
BK,
tgp_size,
N_CHANNELS,
tgp_padding_b>,
// Else go to general loader
Conv2DWeightBlockLoader<T, BM, BN, BK, tgp_size, tgp_padding_b>>;
using mma_t = BlockMMA<
T,
T,
BM,
BN,
BK,
WM,
WN,
transpose_a,
transpose_b,
shape_a_cols,
shape_b_cols>;
threadgroup T As[tgp_mem_size_a];
threadgroup T Bs[tgp_mem_size_b];
const int tid_y = ((tid.y) << gemm_params->swizzle_log) +
((tid.x) & ((1 << gemm_params->swizzle_log) - 1));
const int tid_x = (tid.x) >> gemm_params->swizzle_log;
if (gemm_params->tiles_n <= tid_x || gemm_params->tiles_m <= tid_y) {
return;
}
const int c_row = tid_y * BM;
const int c_col = tid_x * BN;
const int K = gemm_params->K;
const int N = gemm_params->N;
const int C_per_group = params->C / params->groups;
// Groups
A += tid.z * C_per_group;
B += tid.z * N * K;
C += tid.z * N;
B += c_col * K;
C += c_row * (N * params->groups) + c_col;
const int2 offsets_a(0, c_row);
const int2 offsets_b(0, c_col);
// Prepare threadgroup loading operations
loader_a_t loader_a(
A, As, offsets_a, params, gemm_params, simd_gid, simd_lid);
loader_b_t loader_b(
B, Bs, offsets_b, params, gemm_params, simd_gid, simd_lid);
// Prepare threadgroup mma operation
mma_t mma_op(simd_gid, simd_lid);
int gemm_k_iterations = gemm_params->gemm_k_iterations;
for (int k = 0; k < gemm_k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
threadgroup_barrier(mem_flags::mem_none);
// Store results to device memory
short tgp_bm = min(BM, gemm_params->M - c_row);
short tgp_bn = min(BN, gemm_params->N - c_col);
const int ldc = N * params->groups;
mma_op.store_result_safe(C, ldc, short2(tgp_bn, tgp_bm));
}

View File

@@ -2,184 +2,13 @@
#include <metal_stdlib>
+// clang-format off
#include "mlx/backend/metal/kernels/steel/gemm/mma.h"
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/steel/conv/conv.h"
#include "mlx/backend/metal/kernels/steel/conv/params.h"
+#include "mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h"
using namespace metal;
template <
typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
int N_CHANNELS = 0,
bool SMALL_FILTER = false>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void
implicit_gemm_conv_2d(
const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]],
device T* C [[buffer(2)]],
const constant MLXConvParams<2>* params [[buffer(3)]],
const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
using namespace mlx::steel;
(void)lid;
constexpr bool transpose_a = false;
constexpr bool transpose_b = true;
constexpr short tgp_padding_a = 16 / sizeof(T);
constexpr short tgp_padding_b = 16 / sizeof(T);
constexpr short shape_a_cols = (transpose_a ? BM : BK) + tgp_padding_a;
constexpr short shape_b_cols = (transpose_b ? BK : BN) + tgp_padding_b;
constexpr short shape_a_rows = (transpose_a ? BK : BM);
constexpr short shape_b_rows = (transpose_b ? BN : BK);
constexpr short tgp_mem_size_a = shape_a_cols * shape_a_rows;
constexpr short tgp_mem_size_b = shape_b_cols * shape_b_rows;
constexpr short tgp_size = WM * WN * 32;
// Input loader
using loader_a_t = typename metal::conditional_t<
// Check for small channel specialization
N_CHANNELS != 0 && N_CHANNELS <= 4,
// Go to small channel specialization
Conv2DInputBlockLoaderSmallChannels<
T,
BM,
BN,
BK,
tgp_size,
N_CHANNELS,
tgp_padding_a>,
// Else go to general loader
typename metal::conditional_t<
// Check if filter size is small enough
SMALL_FILTER,
// Go to small filter specialization
Conv2DInputBlockLoaderSmallFilter<
T,
BM,
BN,
BK,
tgp_size,
tgp_padding_a>,
// Else go to large filter generalization
Conv2DInputBlockLoaderLargeFilter<
T,
BM,
BN,
BK,
tgp_size,
tgp_padding_a>>>;
// Weight loader
using loader_b_t = typename metal::conditional_t<
// Check for small channel specialization
N_CHANNELS != 0 && N_CHANNELS <= 4,
// Go to small channel specialization
Conv2DWeightBlockLoaderSmallChannels<
T,
BM,
BN,
BK,
tgp_size,
N_CHANNELS,
tgp_padding_b>,
// Else go to general loader
Conv2DWeightBlockLoader<T, BM, BN, BK, tgp_size, tgp_padding_b>>;
using mma_t = BlockMMA<
T,
T,
BM,
BN,
BK,
WM,
WN,
transpose_a,
transpose_b,
shape_a_cols,
shape_b_cols>;
threadgroup T As[tgp_mem_size_a];
threadgroup T Bs[tgp_mem_size_b];
const int tid_y = ((tid.y) << gemm_params->swizzle_log) +
((tid.x) & ((1 << gemm_params->swizzle_log) - 1));
const int tid_x = (tid.x) >> gemm_params->swizzle_log;
if (gemm_params->tiles_n <= tid_x || gemm_params->tiles_m <= tid_y) {
return;
}
const int c_row = tid_y * BM;
const int c_col = tid_x * BN;
const int K = gemm_params->K;
const int N = gemm_params->N;
const int C_per_group = params->C / params->groups;
// Groups
A += tid.z * C_per_group;
B += tid.z * N * K;
C += tid.z * N;
B += c_col * K;
C += c_row * (N * params->groups) + c_col;
const int2 offsets_a(0, c_row);
const int2 offsets_b(0, c_col);
// Prepare threadgroup loading operations
loader_a_t loader_a(
A, As, offsets_a, params, gemm_params, simd_gid, simd_lid);
loader_b_t loader_b(
B, Bs, offsets_b, params, gemm_params, simd_gid, simd_lid);
// Prepare threadgroup mma operation
mma_t mma_op(simd_gid, simd_lid);
int gemm_k_iterations = gemm_params->gemm_k_iterations;
for (int k = 0; k < gemm_k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
threadgroup_barrier(mem_flags::mem_none);
// Store results to device memory
short tgp_bm = min(BM, gemm_params->M - c_row);
short tgp_bn = min(BN, gemm_params->N - c_col);
const int ldc = N * params->groups;
mma_op.store_result_safe(C, ldc, short2(tgp_bn, tgp_bm));
}
#define instantiate_implicit_conv_2d( \
    name, \
@@ -207,25 +36,22 @@ implicit_gemm_conv_2d(
    uint simd_gid [[simdgroup_index_in_threadgroup]], \
    uint simd_lid [[thread_index_in_simdgroup]]);
-// clang-format off
#define instantiate_implicit_2d_filter(name, itype, bm, bn, bk, wm, wn) \
  instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, l, 0, s, true) \
  instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, l, 0, l, false) \
  instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 1, 1, l, false) \
  instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 2, 2, l, false) \
  instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 3, 3, l, false) \
-  instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 4, 4, l, false) // clang-format on
+  instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 4, 4, l, false)
-// clang-format off
#define instantiate_implicit_2d_blocks(name, itype) \
  instantiate_implicit_2d_filter(name, itype, 32, 8, 16, 4, 1) \
  instantiate_implicit_2d_filter(name, itype, 64, 8, 16, 4, 1) \
  instantiate_implicit_2d_filter(name, itype, 32, 32, 16, 2, 2) \
  instantiate_implicit_2d_filter(name, itype, 32, 64, 16, 2, 2) \
  instantiate_implicit_2d_filter(name, itype, 64, 32, 16, 2, 2) \
-  instantiate_implicit_2d_filter(name, itype, 64, 64, 16, 2, 2) // clang-format on
+  instantiate_implicit_2d_filter(name, itype, 64, 64, 16, 2, 2)
-// clang-format off
instantiate_implicit_2d_blocks(float32, float);
instantiate_implicit_2d_blocks(float16, half);
instantiate_implicit_2d_blocks(bfloat16, bfloat16_t); // clang-format on

View File

@@ -0,0 +1,188 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h"
template <
typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
typename AccumType = float,
typename Epilogue = TransformNone<T, AccumType>>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void
implicit_gemm_conv_2d_general(
const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]],
device T* C [[buffer(2)]],
const constant MLXConvParams<2>* params [[buffer(3)]],
const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]],
const constant Conv2DGeneralJumpParams* jump_params [[buffer(5)]],
const constant Conv2DGeneralBaseInfo* base_h [[buffer(6)]],
const constant Conv2DGeneralBaseInfo* base_w [[buffer(7)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
(void)lid;
constexpr bool transpose_a = false;
constexpr bool transpose_b = true;
constexpr short tgp_padding_a = 16 / sizeof(T);
constexpr short tgp_padding_b = 16 / sizeof(T);
constexpr short shape_a_cols = (transpose_a ? BM : BK) + tgp_padding_a;
constexpr short shape_b_cols = (transpose_b ? BK : BN) + tgp_padding_b;
constexpr short shape_a_rows = (transpose_a ? BK : BM);
constexpr short shape_b_rows = (transpose_b ? BN : BK);
constexpr short tgp_mem_size_a = shape_a_cols * shape_a_rows;
constexpr short tgp_mem_size_b = shape_b_cols * shape_b_rows;
constexpr short tgp_size = WM * WN * 32;
// Input loader
using loader_a_t =
Conv2DInputBlockLoaderGeneral<T, BM, BN, BK, tgp_size, tgp_padding_a>;
// Weight loader
using loader_b_t =
Conv2DWeightBlockLoaderGeneral<T, BM, BN, BK, tgp_size, tgp_padding_b>;
using mma_t = BlockMMA<
T,
T,
BM,
BN,
BK,
WM,
WN,
transpose_a,
transpose_b,
shape_a_cols,
shape_b_cols>;
threadgroup T As[tgp_mem_size_a];
threadgroup T Bs[tgp_mem_size_b];
const int tid_y = ((tid.y) << gemm_params->swizzle_log) +
((tid.x) & ((1 << gemm_params->swizzle_log) - 1));
const int tid_x = (tid.x) >> gemm_params->swizzle_log;
if (gemm_params->tiles_n <= tid_x || gemm_params->tiles_m <= tid_y) {
return;
}
const int tid_z = tid.z;
const int base_oh = tid_z / jump_params->f_out_jump_w;
const int base_ow = tid_z % jump_params->f_out_jump_w;
const int base_wh = base_h[base_oh].weight_base;
const int base_ww = base_w[base_ow].weight_base;
const int base_wh_size = base_h[base_oh].weight_size;
const int base_ww_size = base_w[base_ow].weight_size;
const int c_row = tid_y * BM;
const int c_col = tid_x * BN;
const int K = gemm_params->K;
B += c_col * K;
const int4 offsets_a(0, c_row, base_oh, base_ow);
const int2 offsets_b(0, c_col);
// Prepare threadgroup loading operations
loader_a_t loader_a(
A,
As,
offsets_a,
params,
jump_params,
base_wh,
base_ww,
simd_gid,
simd_lid);
loader_b_t loader_b(
B,
Bs,
offsets_b,
params,
jump_params,
base_wh,
base_ww,
simd_gid,
simd_lid);
// Prepare threadgroup mma operation
mma_t mma_op(simd_gid, simd_lid);
int gemm_k_iterations =
base_wh_size * base_ww_size * gemm_params->gemm_k_iterations;
for (int k = 0; k < gemm_k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
threadgroup_barrier(mem_flags::mem_none);
// Store results to device memory
{
// Adjust for simdgroup and thread location
int offset_m = c_row + mma_op.sm + mma_op.tm;
int offset_n = c_col + mma_op.sn + mma_op.tn;
C += offset_n;
if (offset_n >= gemm_params->N)
return;
short diff = gemm_params->N - offset_n;
STEEL_PRAGMA_UNROLL
for (int i = 0; i < mma_t::TM; i++) {
int cm = offset_m + i * mma_t::TM_stride;
int n = cm / jump_params->adj_out_hw;
int hw = cm % jump_params->adj_out_hw;
int oh =
(hw / jump_params->adj_out_w) * jump_params->f_out_jump_h + base_oh;
int ow =
(hw % jump_params->adj_out_w) * jump_params->f_out_jump_w + base_ow;
if (n < params->N && oh < params->oS[0] && ow < params->oS[1]) {
int offset_cm = n * params->out_strides[0] +
oh * params->out_strides[1] + ow * params->out_strides[2];
STEEL_PRAGMA_UNROLL
for (int j = 0; j < mma_t::TN; j++) {
// Get accumulated result and associated offset in C
thread const auto& accum =
mma_op.results[i * mma_t::TN + j].thread_elements();
int offset = offset_cm + (j * mma_t::TN_stride);
// Apply epilogue and output C
if (j * mma_t::TN_stride < diff) {
C[offset] = Epilogue::apply(accum[0]);
}
if (j * mma_t::TN_stride + 1 < diff) {
C[offset + 1] = Epilogue::apply(accum[1]);
}
}
}
}
}
}

View File

@@ -2,201 +2,18 @@
#include <metal_stdlib>
+// clang-format off
#include "mlx/backend/metal/kernels/steel/gemm/mma.h"
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/steel/conv/conv.h"
#include "mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h"
#include "mlx/backend/metal/kernels/steel/conv/params.h"
#include "mlx/backend/metal/kernels/steel/utils.h"
+#include "mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h"
using namespace metal;
using namespace mlx::steel;
template <
typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
typename AccumType = float,
typename Epilogue = TransformNone<T, AccumType>>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void
implicit_gemm_conv_2d_general(
const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]],
device T* C [[buffer(2)]],
const constant MLXConvParams<2>* params [[buffer(3)]],
const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]],
const constant Conv2DGeneralJumpParams* jump_params [[buffer(5)]],
const constant Conv2DGeneralBaseInfo* base_h [[buffer(6)]],
const constant Conv2DGeneralBaseInfo* base_w [[buffer(7)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
(void)lid;
constexpr bool transpose_a = false;
constexpr bool transpose_b = true;
constexpr short tgp_padding_a = 16 / sizeof(T);
constexpr short tgp_padding_b = 16 / sizeof(T);
constexpr short shape_a_cols = (transpose_a ? BM : BK) + tgp_padding_a;
constexpr short shape_b_cols = (transpose_b ? BK : BN) + tgp_padding_b;
constexpr short shape_a_rows = (transpose_a ? BK : BM);
constexpr short shape_b_rows = (transpose_b ? BN : BK);
constexpr short tgp_mem_size_a = shape_a_cols * shape_a_rows;
constexpr short tgp_mem_size_b = shape_b_cols * shape_b_rows;
constexpr short tgp_size = WM * WN * 32;
// Input loader
using loader_a_t =
Conv2DInputBlockLoaderGeneral<T, BM, BN, BK, tgp_size, tgp_padding_a>;
// Weight loader
using loader_b_t =
Conv2DWeightBlockLoaderGeneral<T, BM, BN, BK, tgp_size, tgp_padding_b>;
using mma_t = BlockMMA<
T,
T,
BM,
BN,
BK,
WM,
WN,
transpose_a,
transpose_b,
shape_a_cols,
shape_b_cols>;
threadgroup T As[tgp_mem_size_a];
threadgroup T Bs[tgp_mem_size_b];
const int tid_y = ((tid.y) << gemm_params->swizzle_log) +
((tid.x) & ((1 << gemm_params->swizzle_log) - 1));
const int tid_x = (tid.x) >> gemm_params->swizzle_log;
if (gemm_params->tiles_n <= tid_x || gemm_params->tiles_m <= tid_y) {
return;
}
const int tid_z = tid.z;
const int base_oh = tid_z / jump_params->f_out_jump_w;
const int base_ow = tid_z % jump_params->f_out_jump_w;
const int base_wh = base_h[base_oh].weight_base;
const int base_ww = base_w[base_ow].weight_base;
const int base_wh_size = base_h[base_oh].weight_size;
const int base_ww_size = base_w[base_ow].weight_size;
const int c_row = tid_y * BM;
const int c_col = tid_x * BN;
const int K = gemm_params->K;
B += c_col * K;
const int4 offsets_a(0, c_row, base_oh, base_ow);
const int2 offsets_b(0, c_col);
// Prepare threadgroup loading operations
loader_a_t loader_a(
A,
As,
offsets_a,
params,
jump_params,
base_wh,
base_ww,
simd_gid,
simd_lid);
loader_b_t loader_b(
B,
Bs,
offsets_b,
params,
jump_params,
base_wh,
base_ww,
simd_gid,
simd_lid);
// Prepare threadgroup mma operation
mma_t mma_op(simd_gid, simd_lid);
int gemm_k_iterations =
base_wh_size * base_ww_size * gemm_params->gemm_k_iterations;
for (int k = 0; k < gemm_k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
threadgroup_barrier(mem_flags::mem_none);
// Store results to device memory
{
// Adjust for simdgroup and thread location
int offset_m = c_row + mma_op.sm + mma_op.tm;
int offset_n = c_col + mma_op.sn + mma_op.tn;
C += offset_n;
if (offset_n >= gemm_params->N)
return;
short diff = gemm_params->N - offset_n;
STEEL_PRAGMA_UNROLL
for (int i = 0; i < mma_t::TM; i++) {
int cm = offset_m + i * mma_t::TM_stride;
int n = cm / jump_params->adj_out_hw;
int hw = cm % jump_params->adj_out_hw;
int oh =
(hw / jump_params->adj_out_w) * jump_params->f_out_jump_h + base_oh;
int ow =
(hw % jump_params->adj_out_w) * jump_params->f_out_jump_w + base_ow;
if (n < params->N && oh < params->oS[0] && ow < params->oS[1]) {
int offset_cm = n * params->out_strides[0] +
oh * params->out_strides[1] + ow * params->out_strides[2];
STEEL_PRAGMA_UNROLL
for (int j = 0; j < mma_t::TN; j++) {
// Get accumulated result and associated offset in C
thread const auto& accum =
mma_op.results[i * mma_t::TN + j].thread_elements();
int offset = offset_cm + (j * mma_t::TN_stride);
// Apply epilogue and output C
if (j * mma_t::TN_stride < diff) {
C[offset] = Epilogue::apply(accum[0]);
}
if (j * mma_t::TN_stride + 1 < diff) {
C[offset + 1] = Epilogue::apply(accum[1]);
}
}
}
}
}
}
#define instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn) \
  template \
  [[host_name("implicit_gemm_conv_2d_general_" #name "_bm" #bm "_bn" #bn \
@@ -218,16 +35,14 @@ implicit_gemm_conv_2d_general(
#define instantiate_implicit_2d_filter(name, itype, bm, bn, bk, wm, wn) \
  instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn)
-// clang-format off
#define instantiate_implicit_2d_blocks(name, itype) \
  instantiate_implicit_2d_filter(name, itype, 32, 8, 16, 4, 1) \
  instantiate_implicit_2d_filter(name, itype, 64, 8, 16, 4, 1) \
  instantiate_implicit_2d_filter(name, itype, 32, 32, 16, 2, 2) \
  instantiate_implicit_2d_filter(name, itype, 32, 64, 16, 2, 2) \
  instantiate_implicit_2d_filter(name, itype, 64, 32, 16, 2, 2) \
-  instantiate_implicit_2d_filter(name, itype, 64, 64, 16, 2, 2) // clang-format on
+  instantiate_implicit_2d_filter(name, itype, 64, 64, 16, 2, 2)
-// clang-format off
instantiate_implicit_2d_blocks(float32, float);
instantiate_implicit_2d_blocks(float16, half);
instantiate_implicit_2d_blocks(bfloat16, bfloat16_t); // clang-format on

View File

@@ -2,9 +2,7 @@
#pragma once
-#include "mlx/backend/metal/kernels/steel/utils.h"
+#include "mlx/backend/metal/kernels/steel/defines.h"
-#include "mlx/backend/metal/kernels/steel/conv/params.h"
///////////////////////////////////////////////////////////////////////////////
// Loading helper
@@ -285,4 +283,4 @@ struct Conv2DWeightBlockLoaderGeneral {
};
} // namespace steel
} // namespace mlx

View File

@@ -0,0 +1,4 @@
// Copyright © 2024 Apple Inc.
#define STEEL_CONST static constant constexpr const
#define STEEL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)")

View File

@@ -0,0 +1,415 @@
// Copyright © 2024 Apple Inc.
using namespace mlx::steel;
///////////////////////////////////////////////////////////////////////////////
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////
constant bool has_batch [[function_constant(10)]];
constant bool use_out_source [[function_constant(100)]];
constant bool do_axpby [[function_constant(110)]];
constant bool align_M [[function_constant(200)]];
constant bool align_N [[function_constant(201)]];
constant bool align_K [[function_constant(202)]];
constant bool do_gather [[function_constant(300)]];
constant bool gather_bias = do_gather && use_out_source;
// clang-format off
template <
typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
typename AccumType = float>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gemm(
const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]],
const device T* C [[buffer(2), function_constant(use_out_source)]],
device T* D [[buffer(3)]],
const constant GEMMParams* params [[buffer(4)]],
const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]],
const constant int* batch_shape [[buffer(6)]],
const constant size_t* batch_strides [[buffer(7)]],
const constant uint32_t* lhs_indices [[buffer(10), function_constant(do_gather)]],
const constant uint32_t* rhs_indices [[buffer(11), function_constant(do_gather)]],
const constant uint32_t* C_indices [[buffer(12), function_constant(gather_bias)]],
const constant int* operand_shape [[buffer(13), function_constant(do_gather)]],
const constant size_t* operand_strides [[buffer(14), function_constant(do_gather)]],
const constant packed_int3& operand_batch_ndim [[buffer(15), function_constant(do_gather)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) { // clang-format on
// Pacifying compiler
(void)lid;
using gemm_kernel = GEMMKernel<
T,
T,
BM,
BN,
BK,
WM,
WN,
transpose_a,
transpose_b,
true,
true,
AccumType>;
using loader_a_t = typename gemm_kernel::loader_a_t;
using loader_b_t = typename gemm_kernel::loader_b_t;
using mma_t = typename gemm_kernel::mma_t;
// Find block
const int tid_y = ((tid.y) << params->swizzle_log) +
((tid.x) & ((1 << params->swizzle_log) - 1));
const int tid_x = (tid.x) >> params->swizzle_log;
// Exit early if out of bounds
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
return;
}
// Adjust for batch
// Handle gather
if (do_gather) {
// Read indices
uint32_t indx_A, indx_B, indx_C;
if (has_batch) {
const constant size_t* indx_A_bstrides = batch_strides;
const constant size_t* indx_B_bstrides =
batch_strides + params->batch_ndim;
ulong2 indx_offsets = elem_to_loc_broadcast(
tid.z,
batch_shape,
indx_A_bstrides,
indx_B_bstrides,
params->batch_ndim);
indx_A = lhs_indices[indx_offsets.x];
indx_B = rhs_indices[indx_offsets.y];
if (use_out_source) {
const constant size_t* indx_C_bstrides =
indx_B_bstrides + params->batch_ndim;
auto indx_offset_C = elem_to_loc(
tid.z, batch_shape, indx_C_bstrides, params->batch_ndim);
indx_C = C_indices[indx_offset_C];
}
} else {
indx_A = lhs_indices[params->batch_stride_a * tid.z];
indx_B = rhs_indices[params->batch_stride_b * tid.z];
if (use_out_source) {
indx_C = C_indices[addmm_params->batch_stride_c * tid.z];
}
}
// Translate indices to offsets
int batch_ndim_A = operand_batch_ndim.x;
const constant int* batch_shape_A = operand_shape;
const constant size_t* batch_strides_A = operand_strides;
A += elem_to_loc(indx_A, batch_shape_A, batch_strides_A, batch_ndim_A);
int batch_ndim_B = operand_batch_ndim.y;
const constant int* batch_shape_B = batch_shape_A + batch_ndim_A;
const constant size_t* batch_strides_B = batch_strides_A + batch_ndim_A;
B += elem_to_loc(indx_B, batch_shape_B, batch_strides_B, batch_ndim_B);
if (use_out_source) {
int batch_ndim_C = operand_batch_ndim.z;
const constant int* batch_shape_C = batch_shape_B + batch_ndim_B;
const constant size_t* batch_strides_C = batch_strides_B + batch_ndim_B;
C += elem_to_loc(indx_C, batch_shape_C, batch_strides_C, batch_ndim_C);
}
}
// Handle regular batch
else {
if (has_batch) {
const constant size_t* A_bstrides = batch_strides;
const constant size_t* B_bstrides = batch_strides + params->batch_ndim;
ulong2 batch_offsets = elem_to_loc_broadcast(
tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);
A += batch_offsets.x;
B += batch_offsets.y;
if (use_out_source) {
const constant size_t* C_bstrides = B_bstrides + params->batch_ndim;
C += elem_to_loc(tid.z, batch_shape, C_bstrides, params->batch_ndim);
}
} else {
A += params->batch_stride_a * tid.z;
B += params->batch_stride_b * tid.z;
if (use_out_source) {
C += addmm_params->batch_stride_c * tid.z;
}
}
}
D += params->batch_stride_d * tid.z;
// Prepare threadgroup memory
threadgroup T As[gemm_kernel::tgp_mem_size_a];
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
threadgroup_barrier(mem_flags::mem_none);
// Find block in A, B, C
const int c_row = tid_y * BM;
const int c_col = tid_x * BN;
const size_t c_row_long = size_t(c_row);
const size_t c_col_long = size_t(c_col);
A += transpose_a ? c_row_long : c_row_long * params->lda;
B += transpose_b ? c_col_long * params->ldb : c_col_long;
D += c_row_long * params->ldd + c_col_long;
if (use_out_source) {
C += c_row_long * addmm_params->ldc + c_col_long * addmm_params->fdc;
}
// Prepare threadgroup mma operation
thread mma_t mma_op(simd_group_id, simd_lane_id);
// Prepare threadgroup loading operations
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
// Prepare threadgroup bounds
const short tgp_bm = align_M ? BM : short(min(BM, params->M - c_row));
const short tgp_bn = align_N ? BN : short(min(BN, params->N - c_col));
// Prepare iterations
int gemm_k_iterations = params->gemm_k_iterations_aligned;
// Do unaligned K iterations first
if (!align_K) {
const int k_last = params->gemm_k_iterations_aligned * BK;
const int k_remain = params->K - k_last;
const size_t k_jump_a =
transpose_a ? params->lda * size_t(k_last) : size_t(k_last);
const size_t k_jump_b =
transpose_b ? size_t(k_last) : params->ldb * size_t(k_last);
// Move loader source ahead to end
loader_a.src += k_jump_a;
loader_b.src += k_jump_b;
// Load tile
const short2 tile_dims_A =
transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
const short2 tile_dims_B =
transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
loader_a.load_safe(tile_dims_A);
loader_b.load_safe(tile_dims_B);
threadgroup_barrier(mem_flags::mem_threadgroup);
// Do matmul
mma_op.mma(As, Bs);
// Reset source back to start
loader_a.src -= k_jump_a;
loader_b.src -= k_jump_b;
}
const TransformAdd<AccumType, AccumType> epilogue_op_add(
addmm_params->alpha, addmm_params->beta);
const TransformAxpby<AccumType, AccumType> epilogue_op_axpby(
addmm_params->alpha, addmm_params->beta);
///////////////////////////////////////////////////////////////////////////////
// MNK aligned loop
if (align_M && align_N) {
// Do gemm
for (int k = 0; k < gemm_k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
threadgroup_barrier(mem_flags::mem_none);
// Do epilogue
if (use_out_source) {
if (do_axpby) {
mma_op.apply_epilogue(
C, addmm_params->ldc, addmm_params->fdc, epilogue_op_axpby);
} else {
mma_op.apply_epilogue(
C, addmm_params->ldc, addmm_params->fdc, epilogue_op_add);
}
}
// Store results to device memory
return mma_op.store_result(D, params->ldd);
}
///////////////////////////////////////////////////////////////////////////////
// MN unaligned loop
else { // Loop over K - unaligned case
const int leftover_bk = 0;
if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) {
// Do gemm
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<true, true, true>{});
// Do epilogue
if (use_out_source) {
if (do_axpby) {
mma_op.apply_epilogue(
C, addmm_params->ldc, addmm_params->fdc, epilogue_op_axpby);
} else {
mma_op.apply_epilogue(
C, addmm_params->ldc, addmm_params->fdc, epilogue_op_add);
}
}
// Store results to device memory
return mma_op.store_result(D, params->ldd);
} else if (align_N || tgp_bn == BN) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<false, true, true>{});
// Do epilogue
if (use_out_source) {
if (do_axpby) {
mma_op.apply_epilogue_safe(
C,
addmm_params->ldc,
addmm_params->fdc,
short2(tgp_bn, tgp_bm),
epilogue_op_axpby);
} else {
mma_op.apply_epilogue_safe(
C,
addmm_params->ldc,
addmm_params->fdc,
short2(tgp_bn, tgp_bm),
epilogue_op_add);
}
}
// Store results to device memory
return mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
} else if (align_M || tgp_bm == BM) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<true, false, true>{});
// Do epilogue
if (use_out_source) {
if (do_axpby) {
mma_op.apply_epilogue_safe(
C,
addmm_params->ldc,
addmm_params->fdc,
short2(tgp_bn, tgp_bm),
epilogue_op_axpby);
} else {
mma_op.apply_epilogue_safe(
C,
addmm_params->ldc,
addmm_params->fdc,
short2(tgp_bn, tgp_bm),
epilogue_op_add);
}
}
// Store results to device memory
return mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
} else {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<false, false, true>{});
// Do epilogue
if (use_out_source) {
if (do_axpby) {
mma_op.apply_epilogue_safe(
C,
addmm_params->ldc,
addmm_params->fdc,
short2(tgp_bn, tgp_bm),
epilogue_op_axpby);
} else {
mma_op.apply_epilogue_safe(
C,
addmm_params->ldc,
addmm_params->fdc,
short2(tgp_bn, tgp_bm),
epilogue_op_add);
}
}
// Store results to device memory
return mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
}
}
}

View File

@@ -1,430 +1,12 @@
// Copyright © 2024 Apple Inc.
+// clang-format off
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
+#include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h"
using namespace metal;
using namespace mlx::steel;
///////////////////////////////////////////////////////////////////////////////
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////
constant bool has_batch [[function_constant(10)]];
constant bool use_out_source [[function_constant(100)]];
constant bool do_axpby [[function_constant(110)]];
constant bool align_M [[function_constant(200)]];
constant bool align_N [[function_constant(201)]];
constant bool align_K [[function_constant(202)]];
constant bool do_gather [[function_constant(300)]];
constant bool gather_bias = do_gather && use_out_source;
// clang-format off
template <
typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
typename AccumType = float>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gemm(
const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]],
const device T* C [[buffer(2), function_constant(use_out_source)]],
device T* D [[buffer(3)]],
const constant GEMMParams* params [[buffer(4)]],
const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]],
const constant int* batch_shape [[buffer(6)]],
const constant size_t* batch_strides [[buffer(7)]],
const constant uint32_t* lhs_indices [[buffer(10), function_constant(do_gather)]],
const constant uint32_t* rhs_indices [[buffer(11), function_constant(do_gather)]],
const constant uint32_t* C_indices [[buffer(12), function_constant(gather_bias)]],
const constant int* operand_shape [[buffer(13), function_constant(do_gather)]],
const constant size_t* operand_strides [[buffer(14), function_constant(do_gather)]],
const constant packed_int3& operand_batch_ndim [[buffer(15), function_constant(do_gather)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) { // clang-format on
// Pacifying compiler
(void)lid;
using gemm_kernel = GEMMKernel<
T,
T,
BM,
BN,
BK,
WM,
WN,
transpose_a,
transpose_b,
true,
true,
AccumType>;
using loader_a_t = typename gemm_kernel::loader_a_t;
using loader_b_t = typename gemm_kernel::loader_b_t;
using mma_t = typename gemm_kernel::mma_t;
// Find block
const int tid_y = ((tid.y) << params->swizzle_log) +
((tid.x) & ((1 << params->swizzle_log) - 1));
const int tid_x = (tid.x) >> params->swizzle_log;
// Exit early if out of bounds
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
return;
}
// Adjust for batch
// Handle gather
if (do_gather) {
// Read indices
uint32_t indx_A, indx_B, indx_C;
if (has_batch) {
const constant size_t* indx_A_bstrides = batch_strides;
const constant size_t* indx_B_bstrides =
batch_strides + params->batch_ndim;
ulong2 indx_offsets = elem_to_loc_broadcast(
tid.z,
batch_shape,
indx_A_bstrides,
indx_B_bstrides,
params->batch_ndim);
indx_A = lhs_indices[indx_offsets.x];
indx_B = rhs_indices[indx_offsets.y];
if (use_out_source) {
const constant size_t* indx_C_bstrides =
indx_B_bstrides + params->batch_ndim;
auto indx_offset_C = elem_to_loc(
tid.z, batch_shape, indx_C_bstrides, params->batch_ndim);
indx_C = C_indices[indx_offset_C];
}
} else {
indx_A = lhs_indices[params->batch_stride_a * tid.z];
indx_B = rhs_indices[params->batch_stride_b * tid.z];
if (use_out_source) {
indx_C = C_indices[addmm_params->batch_stride_c * tid.z];
}
}
// Translate indices to offsets
int batch_ndim_A = operand_batch_ndim.x;
const constant int* batch_shape_A = operand_shape;
const constant size_t* batch_strides_A = operand_strides;
A += elem_to_loc(indx_A, batch_shape_A, batch_strides_A, batch_ndim_A);
int batch_ndim_B = operand_batch_ndim.y;
const constant int* batch_shape_B = batch_shape_A + batch_ndim_A;
const constant size_t* batch_strides_B = batch_strides_A + batch_ndim_A;
B += elem_to_loc(indx_B, batch_shape_B, batch_strides_B, batch_ndim_B);
if (use_out_source) {
int batch_ndim_C = operand_batch_ndim.z;
const constant int* batch_shape_C = batch_shape_B + batch_ndim_B;
const constant size_t* batch_strides_C = batch_strides_B + batch_ndim_B;
C += elem_to_loc(indx_C, batch_shape_C, batch_strides_C, batch_ndim_C);
}
}
// Handle regular batch
else {
if (has_batch) {
const constant size_t* A_bstrides = batch_strides;
const constant size_t* B_bstrides = batch_strides + params->batch_ndim;
ulong2 batch_offsets = elem_to_loc_broadcast(
tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);
A += batch_offsets.x;
B += batch_offsets.y;
if (use_out_source) {
const constant size_t* C_bstrides = B_bstrides + params->batch_ndim;
C += elem_to_loc(tid.z, batch_shape, C_bstrides, params->batch_ndim);
}
} else {
A += params->batch_stride_a * tid.z;
B += params->batch_stride_b * tid.z;
if (use_out_source) {
C += addmm_params->batch_stride_c * tid.z;
}
}
}
D += params->batch_stride_d * tid.z;
// Prepare threadgroup memory
threadgroup T As[gemm_kernel::tgp_mem_size_a];
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
threadgroup_barrier(mem_flags::mem_none);
// Find block in A, B, C
const int c_row = tid_y * BM;
const int c_col = tid_x * BN;
const size_t c_row_long = size_t(c_row);
const size_t c_col_long = size_t(c_col);
A += transpose_a ? c_row_long : c_row_long * params->lda;
B += transpose_b ? c_col_long * params->ldb : c_col_long;
D += c_row_long * params->ldd + c_col_long;
if (use_out_source) {
C += c_row_long * addmm_params->ldc + c_col_long * addmm_params->fdc;
}
// Prepare threadgroup mma operation
thread mma_t mma_op(simd_group_id, simd_lane_id);
// Prepare threadgroup loading operations
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
// Prepare threadgroup bounds
const short tgp_bm = align_M ? BM : short(min(BM, params->M - c_row));
const short tgp_bn = align_N ? BN : short(min(BN, params->N - c_col));
// Prepare iterations
int gemm_k_iterations = params->gemm_k_iterations_aligned;
// Do unaligned K iterations first
if (!align_K) {
const int k_last = params->gemm_k_iterations_aligned * BK;
const int k_remain = params->K - k_last;
const size_t k_jump_a =
transpose_a ? params->lda * size_t(k_last) : size_t(k_last);
const size_t k_jump_b =
transpose_b ? size_t(k_last) : params->ldb * size_t(k_last);
// Move loader source ahead to end
loader_a.src += k_jump_a;
loader_b.src += k_jump_b;
// Load tile
const short2 tile_dims_A =
transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
const short2 tile_dims_B =
transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
loader_a.load_safe(tile_dims_A);
loader_b.load_safe(tile_dims_B);
threadgroup_barrier(mem_flags::mem_threadgroup);
// Do matmul
mma_op.mma(As, Bs);
// Reset source back to start
loader_a.src -= k_jump_a;
loader_b.src -= k_jump_b;
}
const TransformAdd<AccumType, AccumType> epilogue_op_add(
addmm_params->alpha, addmm_params->beta);
const TransformAxpby<AccumType, AccumType> epilogue_op_axpby(
addmm_params->alpha, addmm_params->beta);
///////////////////////////////////////////////////////////////////////////////
// MNK aligned loop
if (align_M && align_N) {
// Do gemm
for (int k = 0; k < gemm_k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
threadgroup_barrier(mem_flags::mem_none);
// Do epilogue
if (use_out_source) {
if (do_axpby) {
mma_op.apply_epilogue(
C, addmm_params->ldc, addmm_params->fdc, epilogue_op_axpby);
} else {
mma_op.apply_epilogue(
C, addmm_params->ldc, addmm_params->fdc, epilogue_op_add);
}
}
// Store results to device memory
return mma_op.store_result(D, params->ldd);
}
///////////////////////////////////////////////////////////////////////////////
// MN unaligned loop
else { // Loop over K - unaligned case
const int leftover_bk = 0;
if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) {
// Do gemm
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<true, true, true>{});
// Do epilogue
if (use_out_source) {
if (do_axpby) {
mma_op.apply_epilogue(
C, addmm_params->ldc, addmm_params->fdc, epilogue_op_axpby);
} else {
mma_op.apply_epilogue(
C, addmm_params->ldc, addmm_params->fdc, epilogue_op_add);
}
}
// Store results to device memory
return mma_op.store_result(D, params->ldd);
} else if (align_N || tgp_bn == BN) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<false, true, true>{});
// Do epilogue
if (use_out_source) {
if (do_axpby) {
mma_op.apply_epilogue_safe(
C,
addmm_params->ldc,
addmm_params->fdc,
short2(tgp_bn, tgp_bm),
epilogue_op_axpby);
} else {
mma_op.apply_epilogue_safe(
C,
addmm_params->ldc,
addmm_params->fdc,
short2(tgp_bn, tgp_bm),
epilogue_op_add);
}
}
// Store results to device memory
return mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
} else if (align_M || tgp_bm == BM) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<true, false, true>{});
// Do epilogue
if (use_out_source) {
if (do_axpby) {
mma_op.apply_epilogue_safe(
C,
addmm_params->ldc,
addmm_params->fdc,
short2(tgp_bn, tgp_bm),
epilogue_op_axpby);
} else {
mma_op.apply_epilogue_safe(
C,
addmm_params->ldc,
addmm_params->fdc,
short2(tgp_bn, tgp_bm),
epilogue_op_add);
}
}
// Store results to device memory
return mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
} else {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<false, false, true>{});
// Do epilogue
if (use_out_source) {
if (do_axpby) {
mma_op.apply_epilogue_safe(
C,
addmm_params->ldc,
addmm_params->fdc,
short2(tgp_bn, tgp_bm),
epilogue_op_axpby);
} else {
mma_op.apply_epilogue_safe(
C,
addmm_params->ldc,
addmm_params->fdc,
short2(tgp_bn, tgp_bm),
epilogue_op_add);
}
}
// Store results to device memory
return mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
}
}
}
///////////////////////////////////////////////////////////////////////////////
// GEMM kernel initializations
///////////////////////////////////////////////////////////////////////////////
// clang-format off
#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  template [[host_name("steel_gemm_fused_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn)]] \
  [[kernel]] void gemm<itype, bm, bn, bk, wm, wn, trans_a, trans_b, float>( \
@ -445,24 +27,23 @@ template <
      uint simd_lane_id [[thread_index_in_simdgroup]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]], \
      uint3 tid [[threadgroup_position_in_grid]], \
      uint3 lid [[thread_position_in_threadgroup]]); // clang-format on
      uint3 lid [[thread_position_in_threadgroup]]);
// clang-format off
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_gemm(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_gemm(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_gemm(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_gemm(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) // clang-format on
  instantiate_gemm(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)
// clang-format off
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
  instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
  instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \
  instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
  instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 16, 2, 2) \
  instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 2, 2) // clang-format on
  instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 2, 2)
instantiate_gemm_shapes_helper(float16, half, float16, half);
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);
instantiate_gemm_shapes_helper(float32, float, float32, float);
// clang-format on


@ -0,0 +1,719 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/metal/kernels/steel/defines.h"
using namespace metal;
using namespace mlx::steel;
///////////////////////////////////////////////////////////////////////////////
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////
struct _NoMask {
char x;
constexpr METAL_FUNC operator bool() {
return true;
}
constexpr METAL_FUNC operator bool() const threadgroup {
return true;
}
constexpr METAL_FUNC operator bool() const device {
return true;
}
constexpr METAL_FUNC operator bool() const constant {
return true;
}
};
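// _NoMask is a one-byte stand-in used when a mask operand is not provided;
// it evaluates to true in every address space, so mask checks become no-ops.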
template <typename OutT, typename InT = OutT>
struct ScaleOp {
OutT scale;
METAL_FUNC OutT apply(InT x) const {
return static_cast<OutT>(x) * scale;
}
};
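// ScaleOp multiplies elements by a mask value; it is applied to the loaded
// tiles (and to the accumulated output) when masks are non-boolean.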
typedef struct _NoMask nomask_t;
template <
typename T,
typename out_mask_t,
typename op_mask_t,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
bool MN_aligned,
bool K_aligned>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void
block_masked_gemm(
const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]],
device T* D [[buffer(3)]],
const constant GEMMParams* params [[buffer(4)]],
const constant int* batch_shape [[buffer(6)]],
const constant size_t* batch_strides [[buffer(7)]],
const device out_mask_t* out_mask [[buffer(10)]],
const device op_mask_t* lhs_mask [[buffer(11)]],
const device op_mask_t* rhs_mask [[buffer(12)]],
const constant int* mask_strides [[buffer(13)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
// Appease the compiler
(void)lid;
static_assert(
BM == BN,
"block_masked_gemm must have the same block M and block N size");
static_assert(BM % BK == 0, "block_masked_gemm must have BM % BK == 0");
constexpr bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>;
constexpr bool has_output_mask = !metal::is_same_v<out_mask_t, nomask_t>;
constexpr bool has_mul_operand_mask =
has_operand_mask && !metal::is_same_v<op_mask_t, bool>;
constexpr bool has_mul_output_mask =
has_output_mask && !metal::is_same_v<out_mask_t, bool>;
constexpr short k_mask_factor = short(BM / BK);
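  // Masks are tiled at BM x BM granularity, so a single mask entry covers
  // k_mask_factor (= BM / BK) iterations along K.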
using gemm_kernel = GEMMKernel<
T,
T,
BM,
BN,
BK,
WM,
WN,
transpose_a,
transpose_b,
MN_aligned,
K_aligned>;
const int tid_y = ((tid.y) << params->swizzle_log) +
((tid.x) & ((1 << params->swizzle_log) - 1));
const int tid_x = (tid.x) >> params->swizzle_log;
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
return;
}
const constant size_t* mask_batch_strides =
batch_strides + 2 * params->batch_ndim;
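  // batch_strides is packed as [A | B | out mask | lhs mask | rhs mask] blocks,
  // each params->batch_ndim entries long; mask blocks are present only for the
  // masks actually in use.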
if (params->batch_ndim > 1) {
if (has_output_mask) {
out_mask += elem_to_loc(
tid.z, batch_shape, mask_batch_strides, params->batch_ndim);
mask_batch_strides += params->batch_ndim;
}
if (has_operand_mask) {
const constant size_t* mask_strides_lhs = mask_batch_strides;
const constant size_t* mask_strides_rhs =
mask_strides_lhs + params->batch_ndim;
ulong2 batch_offsets = elem_to_loc_broadcast(
tid.z,
batch_shape,
mask_strides_lhs,
mask_strides_rhs,
params->batch_ndim);
lhs_mask += batch_offsets.x;
rhs_mask += batch_offsets.y;
}
} else {
if (has_output_mask) {
out_mask += tid.z * mask_batch_strides[0];
mask_batch_strides += params->batch_ndim;
}
if (has_operand_mask) {
lhs_mask += tid.z * mask_batch_strides[0];
rhs_mask += tid.z * mask_batch_strides[params->batch_ndim];
}
}
// Adjust for batch
if (params->batch_ndim > 1) {
const constant size_t* A_bstrides = batch_strides;
const constant size_t* B_bstrides = batch_strides + params->batch_ndim;
ulong2 batch_offsets = elem_to_loc_broadcast(
tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);
A += batch_offsets.x;
B += batch_offsets.y;
} else {
A += params->batch_stride_a * tid.z;
B += params->batch_stride_b * tid.z;
}
D += params->batch_stride_d * tid.z;
// Find block in A, B, C
const int c_row = tid_y * BM;
const int c_col = tid_x * BN;
const size_t c_row_long = size_t(c_row);
const size_t c_col_long = size_t(c_col);
A += transpose_a ? c_row_long : c_row_long * params->lda;
B += transpose_b ? c_col_long * params->ldb : c_col_long;
D += c_row_long * params->ldd + c_col_long;
const constant int* out_mask_strides = mask_strides;
const constant int* lhs_mask_strides =
mask_strides + (has_output_mask ? 2 : 0);
const constant int* rhs_mask_strides =
lhs_mask_strides + (has_operand_mask ? 2 : 0);
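  // mask_strides holds an (x, y) tile-stride pair per provided mask,
  // ordered out, lhs, rhs.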
const int out_mask_offset = !has_output_mask
? 0
: tid_y * out_mask_strides[1] + tid_x * out_mask_strides[0];
int lhs_mask_offset = !has_operand_mask ? 0 : tid_y * lhs_mask_strides[1];
int rhs_mask_offset = !has_operand_mask ? 0 : tid_x * rhs_mask_strides[0];
const int lhs_mask_step = !has_operand_mask ? 0 : lhs_mask_strides[0];
const int rhs_mask_step = !has_operand_mask ? 0 : rhs_mask_strides[1];
short k_factor_cnt = k_mask_factor;
ScaleOp<float> out_mask_op;
ScaleOp<T> lhs_mask_op;
ScaleOp<T> rhs_mask_op;
if (has_output_mask) {
auto mask_out = out_mask[out_mask_offset];
if (has_mul_output_mask) {
out_mask_op.scale = float(mask_out);
}
// Write zeros and return
if (!mask_out) {
constexpr short tgp_size = WM * WN * 32;
constexpr short vec_size = 4;
// Tile threads in threadgroup
constexpr short TN = BN / vec_size;
constexpr short TM = tgp_size / TN;
const short thread_idx = simd_group_id * 32 + simd_lane_id;
const short bi = thread_idx / TN;
const short bj = vec_size * (thread_idx % TN);
D += bi * params->ldd + bj;
short tgp_bm = min(BM, params->M - c_row);
short tgp_bn = min(BN, params->N - c_col);
if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
for (short ti = 0; ti < BM; ti += TM) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
D[ti * params->ldd + j] = T(0.);
}
}
} else {
short jmax = tgp_bn - bj;
jmax = jmax < vec_size ? jmax : vec_size;
for (short ti = 0; (bi + ti) < tgp_bm; ti += TM) {
for (short j = 0; j < jmax; j++) {
D[ti * params->ldd + j] = T(0.);
}
}
}
return;
}
}
threadgroup_barrier(mem_flags::mem_none);
// Prepare threadgroup mma operation
thread typename gemm_kernel::mma_t mma_op(simd_group_id, simd_lane_id);
threadgroup T As[gemm_kernel::tgp_mem_size_a];
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
// Prepare threadgroup loading operations
thread typename gemm_kernel::loader_a_t loader_a(
A, params->lda, As, simd_group_id, simd_lane_id);
thread typename gemm_kernel::loader_b_t loader_b(
B, params->ldb, Bs, simd_group_id, simd_lane_id);
// Prepare threadgroup bounds
const short tgp_bm =
MN_aligned ? short(BM) : short(min(BM, params->M - c_row));
const short tgp_bn =
MN_aligned ? short(BN) : short(min(BN, params->N - c_col));
int gemm_k_iterations = params->gemm_k_iterations_aligned;
///////////////////////////////////////////////////////////////////////////////
// Do unaligned K iterations first
if (!K_aligned) {
const int k_last = params->gemm_k_iterations_aligned * BK;
const int mask_idx_last = k_last / BM;
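    // The K tail uses the mask entry of the BM-wide block that contains it
    // (index k_last / BM).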
if (!has_operand_mask ||
(bool(lhs_mask[lhs_mask_offset + mask_idx_last * lhs_mask_step]) &&
bool(rhs_mask[rhs_mask_offset + mask_idx_last * rhs_mask_step]))) {
if (has_mul_operand_mask) {
lhs_mask_op.scale =
lhs_mask[lhs_mask_offset + mask_idx_last * lhs_mask_step];
rhs_mask_op.scale =
rhs_mask[rhs_mask_offset + mask_idx_last * rhs_mask_step];
}
// Move loader source ahead to end
const int k_remain = params->K - k_last;
const size_t k_jump_a =
transpose_a ? params->lda * size_t(k_last) : size_t(k_last);
const size_t k_jump_b =
transpose_b ? size_t(k_last) : params->ldb * size_t(k_last);
loader_a.src += k_jump_a;
loader_b.src += k_jump_b;
// Load tile
const short2 tile_dims_A =
transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
const short2 tile_dims_B =
transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
loader_a.load_safe(tile_dims_A);
loader_b.load_safe(tile_dims_B);
if (has_mul_operand_mask) {
loader_a.apply_inplace_op(lhs_mask_op);
loader_b.apply_inplace_op(rhs_mask_op);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Do matmul
mma_op.mma(As, Bs);
// Reset source back to start
loader_a.src -= k_jump_a;
loader_b.src -= k_jump_b;
}
}
///////////////////////////////////////////////////////////////////////////////
// MNK aligned loop
if (MN_aligned) {
for (; gemm_k_iterations > 0; gemm_k_iterations--) {
threadgroup_barrier(mem_flags::mem_threadgroup);
if (!has_operand_mask ||
(bool(lhs_mask[lhs_mask_offset]) &&
bool(rhs_mask[rhs_mask_offset]))) {
if (has_mul_operand_mask) {
lhs_mask_op.scale = lhs_mask[lhs_mask_offset];
rhs_mask_op.scale = rhs_mask[rhs_mask_offset];
}
// Load elements into threadgroup
loader_a.load_unsafe();
loader_b.load_unsafe();
if (has_mul_operand_mask) {
loader_a.apply_inplace_op(lhs_mask_op);
loader_b.apply_inplace_op(rhs_mask_op);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
}
// Prepare for next iteration
loader_a.next();
loader_b.next();
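      // Advance the mask offsets to the next BM-wide block once k_mask_factor
      // BK-iterations have consumed the current one.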
k_factor_cnt--;
lhs_mask_offset += k_factor_cnt == 0 ? lhs_mask_step : 0;
rhs_mask_offset += k_factor_cnt == 0 ? rhs_mask_step : 0;
k_factor_cnt = k_factor_cnt == 0 ? k_mask_factor : k_factor_cnt;
}
if (has_mul_output_mask) {
mma_op.apply_epilogue(out_mask_op);
}
// Store results to device memory
mma_op.store_result(D, params->ldd);
return;
}
///////////////////////////////////////////////////////////////////////////////
// MN unaligned loop
else {
const bool M_aligned = (tgp_bm == BM);
const bool N_aligned = (tgp_bn == BN);
const short2 tile_dims_A =
transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm);
const short2 tile_dims_B =
transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK);
for (; gemm_k_iterations > 0; gemm_k_iterations--) {
threadgroup_barrier(mem_flags::mem_threadgroup);
if (!has_operand_mask ||
(bool(lhs_mask[lhs_mask_offset]) &&
bool(rhs_mask[rhs_mask_offset]))) {
if (has_mul_operand_mask) {
lhs_mask_op.scale = lhs_mask[lhs_mask_offset];
rhs_mask_op.scale = rhs_mask[rhs_mask_offset];
}
// Load elements into threadgroup
if (M_aligned) {
loader_a.load_unsafe();
} else {
loader_a.load_safe(tile_dims_A);
}
if (N_aligned) {
loader_b.load_unsafe();
} else {
loader_b.load_safe(tile_dims_B);
}
if (has_mul_operand_mask) {
loader_a.apply_inplace_op(lhs_mask_op);
loader_b.apply_inplace_op(rhs_mask_op);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
}
// Prepare for next iteration
loader_a.next();
loader_b.next();
k_factor_cnt--;
lhs_mask_offset += k_factor_cnt == 0 ? lhs_mask_step : 0;
rhs_mask_offset += k_factor_cnt == 0 ? rhs_mask_step : 0;
k_factor_cnt = k_factor_cnt == 0 ? k_mask_factor : k_factor_cnt;
}
if (has_mul_output_mask) {
mma_op.apply_epilogue(out_mask_op);
}
if (M_aligned && N_aligned) {
mma_op.store_result(D, params->ldd);
} else {
mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
}
}
}
template <
typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
bool MN_aligned,
bool K_aligned,
bool has_operand_mask = false>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void
block_masked_gemm(
const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]],
device T* D [[buffer(3)]],
const constant GEMMParams* params [[buffer(4)]],
const constant int* batch_shape [[buffer(6)]],
const constant size_t* batch_strides [[buffer(7)]],
const device bool* out_mask [[buffer(10)]],
const device bool* lhs_mask [[buffer(11)]],
const device bool* rhs_mask [[buffer(12)]],
const constant int* mask_strides [[buffer(13)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
// Appease the compiler
(void)lid;
using gemm_kernel = GEMMKernel<
T,
T,
BM,
BN,
BK,
WM,
WN,
transpose_a,
transpose_b,
MN_aligned,
K_aligned>;
const int tid_y = ((tid.y) << params->swizzle_log) +
((tid.x) & ((1 << params->swizzle_log) - 1));
const int tid_x = (tid.x) >> params->swizzle_log;
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
return;
}
if (params->batch_ndim > 1) {
const constant size_t* mask_batch_strides =
batch_strides + 2 * params->batch_ndim;
out_mask +=
elem_to_loc(tid.z, batch_shape, mask_batch_strides, params->batch_ndim);
if (has_operand_mask) {
const constant size_t* mask_strides_lhs =
mask_batch_strides + params->batch_ndim;
const constant size_t* mask_strides_rhs =
mask_strides_lhs + params->batch_ndim;
ulong2 batch_offsets = elem_to_loc_broadcast(
tid.z,
batch_shape,
mask_strides_lhs,
mask_strides_rhs,
params->batch_ndim);
lhs_mask += batch_offsets.x;
rhs_mask += batch_offsets.y;
}
} else {
out_mask += tid.z * batch_strides[2 * params->batch_ndim];
if (has_operand_mask) {
lhs_mask += tid.z * batch_strides[3 * params->batch_ndim];
rhs_mask += tid.z * batch_strides[4 * params->batch_ndim];
}
}
// Adjust for batch
if (params->batch_ndim > 1) {
const constant size_t* A_bstrides = batch_strides;
const constant size_t* B_bstrides = batch_strides + params->batch_ndim;
ulong2 batch_offsets = elem_to_loc_broadcast(
tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);
A += batch_offsets.x;
B += batch_offsets.y;
} else {
A += params->batch_stride_a * tid.z;
B += params->batch_stride_b * tid.z;
}
D += params->batch_stride_d * tid.z;
// Find block in A, B, C
const int c_row = tid_y * BM;
const int c_col = tid_x * BN;
const size_t c_row_long = size_t(c_row);
const size_t c_col_long = size_t(c_col);
A += transpose_a ? c_row_long : c_row_long * params->lda;
B += transpose_b ? c_col_long * params->ldb : c_col_long;
D += c_row_long * params->ldd + c_col_long;
bool mask_out = out_mask[tid_y * mask_strides[1] + tid_x * mask_strides[0]];
// Write zeros and return
if (!mask_out) {
constexpr short tgp_size = WM * WN * 32;
constexpr short vec_size = 4;
// Tile threads in threadgroup
constexpr short TN = BN / vec_size;
constexpr short TM = tgp_size / TN;
const short thread_idx = simd_group_id * 32 + simd_lane_id;
const short bi = thread_idx / TN;
const short bj = vec_size * (thread_idx % TN);
D += bi * params->ldd + bj;
short tgp_bm = min(BM, params->M - c_row);
short tgp_bn = min(BN, params->N - c_col);
if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
for (short ti = 0; ti < BM; ti += TM) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
D[ti * params->ldd + j] = T(0.);
}
}
} else {
short jmax = tgp_bn - bj;
jmax = jmax < vec_size ? jmax : vec_size;
for (short ti = 0; (bi + ti) < tgp_bm; ti += TM) {
for (short j = 0; j < jmax; j++) {
D[ti * params->ldd + j] = T(0.);
}
}
}
return;
}
threadgroup_barrier(mem_flags::mem_none);
// Prepare threadgroup mma operation
thread typename gemm_kernel::mma_t mma_op(simd_group_id, simd_lane_id);
int gemm_k_iterations = params->gemm_k_iterations_aligned;
threadgroup T As[gemm_kernel::tgp_mem_size_a];
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
// Prepare threadgroup loading operations
thread typename gemm_kernel::loader_a_t loader_a(
A, params->lda, As, simd_group_id, simd_lane_id);
thread typename gemm_kernel::loader_b_t loader_b(
B, params->ldb, Bs, simd_group_id, simd_lane_id);
///////////////////////////////////////////////////////////////////////////////
// MNK aligned loop
if (MN_aligned) {
for (int k = 0; k < gemm_k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
if (!has_operand_mask ||
(lhs_mask
[tid_y * mask_strides[3] + ((k * BK) / BM) * mask_strides[2]] &&
rhs_mask
[((k * BK) / BM) * mask_strides[5] + tid_x * mask_strides[4]])) {
// Load elements into threadgroup
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
}
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
threadgroup_barrier(mem_flags::mem_none);
// Loop tail
if (!K_aligned) {
if (!has_operand_mask ||
(lhs_mask
[tid_y * mask_strides[3] + (params->K / BM) * mask_strides[2]] &&
rhs_mask
[(params->K / BM) * mask_strides[5] +
tid_x * mask_strides[4]])) {
int lbk = params->K - params->gemm_k_iterations_aligned * BK;
short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);
short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk);
loader_a.load_safe(tile_dims_A);
loader_b.load_safe(tile_dims_B);
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(As, Bs);
}
}
// Store results to device memory
mma_op.store_result(D, params->ldd);
return;
}
///////////////////////////////////////////////////////////////////////////////
// MN unaligned loop
else { // Loop over K - unaligned case
short tgp_bm = min(BM, params->M - c_row);
short tgp_bn = min(BN, params->N - c_col);
short lbk = params->K - params->gemm_k_iterations_aligned * BK;
bool M_aligned = (tgp_bm == BM);
bool N_aligned = (tgp_bn == BN);
short2 tile_dims_A = transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm);
short2 tile_dims_B = transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK);
for (int k = 0; k < gemm_k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
if (!has_operand_mask ||
(lhs_mask
[tid_y * mask_strides[3] + ((k * BK) / BM) * mask_strides[2]] &&
rhs_mask
[((k * BK) / BM) * mask_strides[5] + tid_x * mask_strides[4]])) {
// Load elements into threadgroup
if (M_aligned) {
loader_a.load_unsafe();
} else {
loader_a.load_safe(tile_dims_A);
}
if (N_aligned) {
loader_b.load_unsafe();
} else {
loader_b.load_safe(tile_dims_B);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
}
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
if (!K_aligned) {
threadgroup_barrier(mem_flags::mem_threadgroup);
if (!has_operand_mask ||
(lhs_mask
[tid_y * mask_strides[3] + (params->K / BM) * mask_strides[2]] &&
rhs_mask
[(params->K / BM) * mask_strides[5] +
tid_x * mask_strides[4]])) {
short2 tile_dims_A_last =
transpose_a ? short2(tgp_bm, lbk) : short2(lbk, tgp_bm);
short2 tile_dims_B_last =
transpose_b ? short2(lbk, tgp_bn) : short2(tgp_bn, lbk);
loader_a.load_safe(tile_dims_A_last);
loader_b.load_safe(tile_dims_B_last);
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(As, Bs);
}
}
if (M_aligned && N_aligned) {
mma_op.store_result(D, params->ldd);
} else {
mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
}
}
}


@ -1,434 +1,10 @@
// Copyright © 2024 Apple Inc.
// clang-format off
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h"
using namespace metal;
using namespace mlx::steel;
///////////////////////////////////////////////////////////////////////////////
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////
struct _NoMask {
char x;
constexpr METAL_FUNC operator bool() {
return true;
}
constexpr METAL_FUNC operator bool() const threadgroup {
return true;
}
constexpr METAL_FUNC operator bool() const device {
return true;
}
constexpr METAL_FUNC operator bool() const constant {
return true;
}
};
template <typename OutT, typename InT = OutT>
struct ScaleOp {
OutT scale;
METAL_FUNC OutT apply(InT x) const {
return static_cast<OutT>(x) * scale;
}
};
typedef struct _NoMask nomask_t;
template <
typename T,
typename out_mask_t,
typename op_mask_t,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
bool MN_aligned,
bool K_aligned>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void
block_masked_gemm(
const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]],
device T* D [[buffer(3)]],
const constant GEMMParams* params [[buffer(4)]],
const constant int* batch_shape [[buffer(6)]],
const constant size_t* batch_strides [[buffer(7)]],
const device out_mask_t* out_mask [[buffer(10)]],
const device op_mask_t* lhs_mask [[buffer(11)]],
const device op_mask_t* rhs_mask [[buffer(12)]],
const constant int* mask_strides [[buffer(13)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
// Appease the compiler
(void)lid;
static_assert(
BM == BN,
"block_masked_gemm must have the same block M and block N size");
static_assert(BM % BK == 0, "block_masked_gemm must have BM % BK == 0");
constexpr bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>;
constexpr bool has_output_mask = !metal::is_same_v<out_mask_t, nomask_t>;
constexpr bool has_mul_operand_mask =
has_operand_mask && !metal::is_same_v<op_mask_t, bool>;
constexpr bool has_mul_output_mask =
has_output_mask && !metal::is_same_v<out_mask_t, bool>;
constexpr short k_mask_factor = short(BM / BK);
using gemm_kernel = GEMMKernel<
T,
T,
BM,
BN,
BK,
WM,
WN,
transpose_a,
transpose_b,
MN_aligned,
K_aligned>;
const int tid_y = ((tid.y) << params->swizzle_log) +
((tid.x) & ((1 << params->swizzle_log) - 1));
const int tid_x = (tid.x) >> params->swizzle_log;
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
return;
}
const constant size_t* mask_batch_strides =
batch_strides + 2 * params->batch_ndim;
if (params->batch_ndim > 1) {
if (has_output_mask) {
out_mask += elem_to_loc(
tid.z, batch_shape, mask_batch_strides, params->batch_ndim);
mask_batch_strides += params->batch_ndim;
}
if (has_operand_mask) {
const constant size_t* mask_strides_lhs = mask_batch_strides;
const constant size_t* mask_strides_rhs =
mask_strides_lhs + params->batch_ndim;
ulong2 batch_offsets = elem_to_loc_broadcast(
tid.z,
batch_shape,
mask_strides_lhs,
mask_strides_rhs,
params->batch_ndim);
lhs_mask += batch_offsets.x;
rhs_mask += batch_offsets.y;
}
} else {
if (has_output_mask) {
out_mask += tid.z * mask_batch_strides[0];
mask_batch_strides += params->batch_ndim;
}
if (has_operand_mask) {
lhs_mask += tid.z * mask_batch_strides[0];
rhs_mask += tid.z * mask_batch_strides[params->batch_ndim];
}
}
// Adjust for batch
if (params->batch_ndim > 1) {
const constant size_t* A_bstrides = batch_strides;
const constant size_t* B_bstrides = batch_strides + params->batch_ndim;
ulong2 batch_offsets = elem_to_loc_broadcast(
tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);
A += batch_offsets.x;
B += batch_offsets.y;
} else {
A += params->batch_stride_a * tid.z;
B += params->batch_stride_b * tid.z;
}
D += params->batch_stride_d * tid.z;
// Find block in A, B, C
const int c_row = tid_y * BM;
const int c_col = tid_x * BN;
const size_t c_row_long = size_t(c_row);
const size_t c_col_long = size_t(c_col);
A += transpose_a ? c_row_long : c_row_long * params->lda;
B += transpose_b ? c_col_long * params->ldb : c_col_long;
D += c_row_long * params->ldd + c_col_long;
const constant int* out_mask_strides = mask_strides;
const constant int* lhs_mask_strides =
mask_strides + (has_output_mask ? 2 : 0);
const constant int* rhs_mask_strides =
lhs_mask_strides + (has_operand_mask ? 2 : 0);
const int out_mask_offset = !has_output_mask
? 0
: tid_y * out_mask_strides[1] + tid_x * out_mask_strides[0];
int lhs_mask_offset = !has_operand_mask ? 0 : tid_y * lhs_mask_strides[1];
int rhs_mask_offset = !has_operand_mask ? 0 : tid_x * rhs_mask_strides[0];
const int lhs_mask_step = !has_operand_mask ? 0 : lhs_mask_strides[0];
const int rhs_mask_step = !has_operand_mask ? 0 : rhs_mask_strides[1];
short k_factor_cnt = k_mask_factor;
ScaleOp<float> out_mask_op;
ScaleOp<T> lhs_mask_op;
ScaleOp<T> rhs_mask_op;
if (has_output_mask) {
auto mask_out = out_mask[out_mask_offset];
if (has_mul_output_mask) {
out_mask_op.scale = float(mask_out);
}
// Write zeros and return
if (!mask_out) {
constexpr short tgp_size = WM * WN * 32;
constexpr short vec_size = 4;
// Tile threads in threadgroup
constexpr short TN = BN / vec_size;
constexpr short TM = tgp_size / TN;
const short thread_idx = simd_group_id * 32 + simd_lane_id;
const short bi = thread_idx / TN;
const short bj = vec_size * (thread_idx % TN);
D += bi * params->ldd + bj;
short tgp_bm = min(BM, params->M - c_row);
short tgp_bn = min(BN, params->N - c_col);
if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
for (short ti = 0; ti < BM; ti += TM) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
D[ti * params->ldd + j] = T(0.);
}
}
} else {
short jmax = tgp_bn - bj;
jmax = jmax < vec_size ? jmax : vec_size;
for (short ti = 0; (bi + ti) < tgp_bm; ti += TM) {
for (short j = 0; j < jmax; j++) {
D[ti * params->ldd + j] = T(0.);
}
}
}
return;
}
}
threadgroup_barrier(mem_flags::mem_none);
// Prepare threadgroup mma operation
thread typename gemm_kernel::mma_t mma_op(simd_group_id, simd_lane_id);
threadgroup T As[gemm_kernel::tgp_mem_size_a];
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
// Prepare threadgroup loading operations
thread typename gemm_kernel::loader_a_t loader_a(
A, params->lda, As, simd_group_id, simd_lane_id);
thread typename gemm_kernel::loader_b_t loader_b(
B, params->ldb, Bs, simd_group_id, simd_lane_id);
// Prepare threadgroup bounds
const short tgp_bm =
MN_aligned ? short(BM) : short(min(BM, params->M - c_row));
const short tgp_bn =
MN_aligned ? short(BN) : short(min(BN, params->N - c_col));
int gemm_k_iterations = params->gemm_k_iterations_aligned;
///////////////////////////////////////////////////////////////////////////////
// Do unaligned K iterations first
if (!K_aligned) {
const int k_last = params->gemm_k_iterations_aligned * BK;
const int mask_idx_last = k_last / BM;
if (!has_operand_mask ||
(bool(lhs_mask[lhs_mask_offset + mask_idx_last * lhs_mask_step]) &&
bool(rhs_mask[rhs_mask_offset + mask_idx_last * rhs_mask_step]))) {
if (has_mul_operand_mask) {
lhs_mask_op.scale =
lhs_mask[lhs_mask_offset + mask_idx_last * lhs_mask_step];
rhs_mask_op.scale =
rhs_mask[rhs_mask_offset + mask_idx_last * rhs_mask_step];
}
// Move loader source ahead to end
const int k_remain = params->K - k_last;
const size_t k_jump_a =
transpose_a ? params->lda * size_t(k_last) : size_t(k_last);
const size_t k_jump_b =
transpose_b ? size_t(k_last) : params->ldb * size_t(k_last);
loader_a.src += k_jump_a;
loader_b.src += k_jump_b;
// Load tile
const short2 tile_dims_A =
transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
const short2 tile_dims_B =
transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
loader_a.load_safe(tile_dims_A);
loader_b.load_safe(tile_dims_B);
if (has_mul_operand_mask) {
loader_a.apply_inplace_op(lhs_mask_op);
loader_b.apply_inplace_op(rhs_mask_op);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Do matmul
mma_op.mma(As, Bs);
// Reset source back to start
loader_a.src -= k_jump_a;
loader_b.src -= k_jump_b;
}
}
///////////////////////////////////////////////////////////////////////////////
// MNK aligned loop
if (MN_aligned) {
for (; gemm_k_iterations > 0; gemm_k_iterations--) {
threadgroup_barrier(mem_flags::mem_threadgroup);
if (!has_operand_mask ||
(bool(lhs_mask[lhs_mask_offset]) &&
bool(rhs_mask[rhs_mask_offset]))) {
if (has_mul_operand_mask) {
lhs_mask_op.scale = lhs_mask[lhs_mask_offset];
rhs_mask_op.scale = rhs_mask[rhs_mask_offset];
}
// Load elements into threadgroup
loader_a.load_unsafe();
loader_b.load_unsafe();
if (has_mul_operand_mask) {
loader_a.apply_inplace_op(lhs_mask_op);
loader_b.apply_inplace_op(rhs_mask_op);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
}
// Prepare for next iteration
loader_a.next();
loader_b.next();
k_factor_cnt--;
lhs_mask_offset += k_factor_cnt == 0 ? lhs_mask_step : 0;
rhs_mask_offset += k_factor_cnt == 0 ? rhs_mask_step : 0;
k_factor_cnt = k_factor_cnt == 0 ? k_mask_factor : k_factor_cnt;
}
if (has_mul_output_mask) {
mma_op.apply_epilogue(out_mask_op);
}
// Store results to device memory
mma_op.store_result(D, params->ldd);
return;
}
///////////////////////////////////////////////////////////////////////////////
// MN unaligned loop
else {
const bool M_aligned = (tgp_bm == BM);
const bool N_aligned = (tgp_bn == BN);
const short2 tile_dims_A =
transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm);
const short2 tile_dims_B =
transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK);
for (; gemm_k_iterations > 0; gemm_k_iterations--) {
threadgroup_barrier(mem_flags::mem_threadgroup);
if (!has_operand_mask ||
(bool(lhs_mask[lhs_mask_offset]) &&
bool(rhs_mask[rhs_mask_offset]))) {
if (has_mul_operand_mask) {
lhs_mask_op.scale = lhs_mask[lhs_mask_offset];
rhs_mask_op.scale = rhs_mask[rhs_mask_offset];
}
// Load elements into threadgroup
if (M_aligned) {
loader_a.load_unsafe();
} else {
loader_a.load_safe(tile_dims_A);
}
if (N_aligned) {
loader_b.load_unsafe();
} else {
loader_b.load_safe(tile_dims_B);
}
if (has_mul_operand_mask) {
loader_a.apply_inplace_op(lhs_mask_op);
loader_b.apply_inplace_op(rhs_mask_op);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
}
// Prepare for next iteration
loader_a.next();
loader_b.next();
k_factor_cnt--;
lhs_mask_offset += k_factor_cnt == 0 ? lhs_mask_step : 0;
rhs_mask_offset += k_factor_cnt == 0 ? rhs_mask_step : 0;
k_factor_cnt = k_factor_cnt == 0 ? k_mask_factor : k_factor_cnt;
}
if (has_mul_output_mask) {
mma_op.apply_epilogue(out_mask_op);
}
if (M_aligned && N_aligned) {
mma_op.store_result(D, params->ldd);
} else {
mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
}
}
}
///////////////////////////////////////////////////////////////////////////////
// GEMM kernel initializations
///////////////////////////////////////////////////////////////////////////////
#define instantiate_gemm( \
    outmaskname, \
@ -483,7 +59,6 @@ block_masked_gemm
      uint3 tid [[threadgroup_position_in_grid]], \
      uint3 lid [[thread_position_in_threadgroup]]);
// clang-format off
#define instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
  instantiate_gemm(bool_, bool, bool_, bool, tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
  instantiate_gemm(iname, itype, iname, itype, tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
@ -492,28 +67,24 @@ block_masked_gemm
  instantiate_gemm(nomask, nomask_t, bool_, bool, tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
  instantiate_gemm(nomask, nomask_t, iname, itype, tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
  instantiate_gemm(bool_, bool, nomask, nomask_t, tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
  instantiate_gemm(iname, itype, nomask, nomask_t, tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) // clang-format on
  instantiate_gemm(iname, itype, nomask, nomask_t, tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned)
// clang-format off
#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
  instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \
  instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \
  instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false) // clang-format on
  instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false)
// clang-format off
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) // clang-format on
  instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)
// clang-format off
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
  instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
  instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) // clang-format on
  instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2)
// clang-format off
instantiate_gemm_shapes_helper(float16, half, float16, half);
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);
instantiate_gemm_shapes_helper(float32, float, float32, float); // clang-format on


@ -0,0 +1,227 @@
// Copyright © 2024 Apple Inc.
using namespace mlx::steel;
///////////////////////////////////////////////////////////////////////////////
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////
template <
typename T,
typename U,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
bool MN_aligned,
bool K_aligned>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gemm_splitk(
const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]],
device U* C [[buffer(2)]],
const constant GEMMSpiltKParams* params [[buffer(3)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
(void)lid;
using gemm_kernel = GEMMKernel<
T,
U,
BM,
BN,
BK,
WM,
WN,
transpose_a,
transpose_b,
MN_aligned,
K_aligned>;
using loader_a_t = typename gemm_kernel::loader_a_t;
using loader_b_t = typename gemm_kernel::loader_b_t;
using mma_t = typename gemm_kernel::mma_t;
threadgroup T As[gemm_kernel::tgp_mem_size_a];
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
const int tid_x = tid.x;
const int tid_y = tid.y;
const int tid_z = tid.z;
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
return;
}
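  // Split-K: each tid.z computes a partial GEMM over its own K slice and writes
  // it to a separate region of C (split_k_partition_stride apart); the
  // accumulation kernels below reduce these partials into the final output.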
// Find block in A, B, C
const int c_row = tid_y * BM;
const int c_col = tid_x * BN;
const int k_start = params->split_k_partition_size * tid_z;
const size_t c_row_long = size_t(c_row);
const size_t c_col_long = size_t(c_col);
const size_t k_start_long = size_t(k_start);
A += transpose_a ? (c_row_long + k_start_long * params->lda)
: (k_start_long + c_row_long * params->lda);
B += transpose_b ? (k_start_long + c_col_long * params->ldb)
: (c_col_long + k_start_long * params->ldb);
C += (size_t(params->split_k_partition_stride) * tid_z) +
(c_row_long * params->ldc + c_col_long);
// Prepare threadgroup loading operations
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
// Prepare threadgroup mma operation
thread mma_t mma_op(simd_group_id, simd_lane_id);
int gemm_k_iterations = params->gemm_k_iterations_aligned;
short tgp_bm = min(BM, params->M - c_row);
short tgp_bn = min(BN, params->N - c_col);
short leftover_bk = params->K % BK;
if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<true, true, true>{});
} else if (tgp_bn == BN) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<false, true, true>{});
} else if (tgp_bm == BM) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<true, false, true>{});
} else {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<false, false, true>{});
}
threadgroup_barrier(mem_flags::mem_threadgroup);
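  // Only the last K partition also runs the leftover iterations past the even
  // split (including the K % BK tail via leftover_bk).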
if ((tid_z + 1) == (params->split_k_partitions)) {
int gemm_k_iter_remaining =
(params->K - (k_start + params->split_k_partition_size)) / BK;
if (!K_aligned || gemm_k_iter_remaining > 0)
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iter_remaining,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<false, false, K_aligned>{});
}
if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
mma_op.store_result(C, params->ldc);
} else {
mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm));
}
}
///////////////////////////////////////////////////////////////////////////////
// Split k accumulation kernel
///////////////////////////////////////////////////////////////////////////////
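// Each output element is reduced from k_partitions partial results spaced
// partition_stride apart in C_split.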
template <
typename AccT,
typename OutT,
typename Epilogue = TransformNone<OutT, AccT>>
[[kernel]] void gemm_splitk_accum(
const device AccT* C_split [[buffer(0)]],
device OutT* D [[buffer(1)]],
const constant int& k_partitions [[buffer(2)]],
const constant int& partition_stride [[buffer(3)]],
const constant int& ldd [[buffer(4)]],
uint2 gid [[thread_position_in_grid]]) {
  // Adjust D and C
D += gid.x + gid.y * size_t(ldd);
C_split += gid.x + gid.y * size_t(ldd);
size_t offset = 0;
AccT out = 0;
for (int i = 0; i < k_partitions; i++) {
out += C_split[offset];
offset += partition_stride;
}
// Write output
D[0] = Epilogue::apply(out);
}
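// Variant that additionally reads the original C and applies the alpha/beta
// (axpby) epilogue while accumulating the partial results.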
template <
typename AccT,
typename OutT,
typename Epilogue = TransformAxpby<OutT, AccT>>
[[kernel]] void gemm_splitk_accum_axpby(
const device AccT* C_split [[buffer(0)]],
device OutT* D [[buffer(1)]],
const constant int& k_partitions [[buffer(2)]],
const constant int& partition_stride [[buffer(3)]],
const constant int& ldd [[buffer(4)]],
const device OutT* C [[buffer(5)]],
const constant int& ldc [[buffer(6)]],
const constant int& fdc [[buffer(7)]],
const constant float& alpha [[buffer(8)]],
const constant float& beta [[buffer(9)]],
uint2 gid [[thread_position_in_grid]]) {
  // Adjust D and C
C += gid.x * size_t(fdc) + gid.y * size_t(ldc);
D += gid.x + gid.y * size_t(ldd);
C_split += gid.x + gid.y * size_t(ldd);
size_t offset = 0;
AccT out = 0;
for (int i = 0; i < k_partitions; i++) {
out += C_split[offset];
offset += partition_stride;
}
// Write output
Epilogue op(alpha, beta);
D[0] = op.apply(out, *C);
}


@ -1,173 +1,9 @@
// Copyright © 2024 Apple Inc.
// clang-format off
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
#include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h"
using namespace metal;
using namespace mlx::steel;
///////////////////////////////////////////////////////////////////////////////
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////
template <
typename T,
typename U,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
bool MN_aligned,
bool K_aligned>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gemm_splitk(
const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]],
device U* C [[buffer(2)]],
const constant GEMMSpiltKParams* params [[buffer(3)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
(void)lid;
using gemm_kernel = GEMMKernel<
T,
U,
BM,
BN,
BK,
WM,
WN,
transpose_a,
transpose_b,
MN_aligned,
K_aligned>;
using loader_a_t = typename gemm_kernel::loader_a_t;
using loader_b_t = typename gemm_kernel::loader_b_t;
using mma_t = typename gemm_kernel::mma_t;
threadgroup T As[gemm_kernel::tgp_mem_size_a];
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
const int tid_x = tid.x;
const int tid_y = tid.y;
const int tid_z = tid.z;
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
return;
}
// Find block in A, B, C
const int c_row = tid_y * BM;
const int c_col = tid_x * BN;
const int k_start = params->split_k_partition_size * tid_z;
const size_t c_row_long = size_t(c_row);
const size_t c_col_long = size_t(c_col);
const size_t k_start_long = size_t(k_start);
A += transpose_a ? (c_row_long + k_start_long * params->lda)
: (k_start_long + c_row_long * params->lda);
B += transpose_b ? (k_start_long + c_col_long * params->ldb)
: (c_col_long + k_start_long * params->ldb);
C += (size_t(params->split_k_partition_stride) * tid_z) +
(c_row_long * params->ldc + c_col_long);
// Prepare threadgroup loading operations
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
// Prepare threadgroup mma operation
thread mma_t mma_op(simd_group_id, simd_lane_id);
int gemm_k_iterations = params->gemm_k_iterations_aligned;
short tgp_bm = min(BM, params->M - c_row);
short tgp_bn = min(BN, params->N - c_col);
short leftover_bk = params->K % BK;
if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<true, true, true>{});
} else if (tgp_bn == BN) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<false, true, true>{});
} else if (tgp_bm == BM) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<true, false, true>{});
} else {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<false, false, true>{});
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if ((tid_z + 1) == (params->split_k_partitions)) {
int gemm_k_iter_remaining =
(params->K - (k_start + params->split_k_partition_size)) / BK;
if (!K_aligned || gemm_k_iter_remaining > 0)
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iter_remaining,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<false, false, K_aligned>{});
}
if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
mma_op.store_result(C, params->ldc);
} else {
mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm));
}
}
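For reference, a plain C++ sketch of the split-k strategy the kernel above implements on the GPU: each partition (indexed by tid.z) multiplies its own K-slice of A and B into a private slice of C_split, and the accumulation kernels reduce those slices into the final output. The even chunking below is illustrative only; the real partition sizes come from GEMMSpiltKParams set up on the host, which is not part of this diff.

#include <algorithm>
#include <cstddef>
#include <vector>

void splitk_gemm_reference(
    const std::vector<float>& A, // M x K, row-major
    const std::vector<float>& B, // K x N, row-major
    std::vector<float>& D,       // M x N, row-major
    int M, int N, int K, int partitions) {
  std::vector<float> C_split(std::size_t(partitions) * M * N, 0.0f);
  int chunk = (K + partitions - 1) / partitions; // illustrative split only
  for (int p = 0; p < partitions; ++p) {
    int k_begin = p * chunk;
    int k_end = std::min(K, k_begin + chunk);
    float* Cp = C_split.data() + std::size_t(p) * M * N;
    // Each partition accumulates only its own K-slice into its own slab.
    for (int i = 0; i < M; ++i) {
      for (int k = k_begin; k < k_end; ++k) {
        for (int j = 0; j < N; ++j) {
          Cp[std::size_t(i) * N + j] +=
              A[std::size_t(i) * K + k] * B[std::size_t(k) * N + j];
        }
      }
    }
  }
  // Second pass: what the gemm_splitk_accum kernel does per output element.
  D.assign(std::size_t(M) * N, 0.0f);
  for (std::size_t e = 0; e < std::size_t(M) * N; ++e) {
    for (int p = 0; p < partitions; ++p) {
      D[e] += C_split[std::size_t(p) * M * N + e];
    }
  }
}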
///////////////////////////////////////////////////////////////////////////////
// GEMM kernel initializations
///////////////////////////////////////////////////////////////////////////////
#define instantiate_gemm( \ #define instantiate_gemm( \
tname, \ tname, \
@ -210,97 +46,27 @@ template <
uint3 tid [[threadgroup_position_in_grid]], \ uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]); uint3 lid [[thread_position_in_threadgroup]]);
// clang-format off
#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ #define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \ instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \ instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \ instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false) // clang-format on instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false)
// clang-format off
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \ #define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \ instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) // clang-format on instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)
// clang-format off
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \ #define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 16, 16, 16, 2, 2) \ instantiate_gemm_transpose_helper(iname, itype, oname, otype, 16, 16, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 16, 32, 16, 2, 2) \ instantiate_gemm_transpose_helper(iname, itype, oname, otype, 16, 32, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 16, 16, 2, 2) \ instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 16, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) // clang-format on instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2)
// clang-format off
instantiate_gemm_shapes_helper(float16, half, float32, float); instantiate_gemm_shapes_helper(float16, half, float32, float);
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, float32, float); instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, float32, float);
instantiate_gemm_shapes_helper(float32, float, float32, float);
instantiate_gemm_shapes_helper(float32, float, float32, float); // clang-format on
///////////////////////////////////////////////////////////////////////////////
// Split k accumulation kernel
///////////////////////////////////////////////////////////////////////////////
template <
typename AccT,
typename OutT,
typename Epilogue = TransformNone<OutT, AccT>>
[[kernel]] void gemm_splitk_accum(
const device AccT* C_split [[buffer(0)]],
device OutT* D [[buffer(1)]],
const constant int& k_partitions [[buffer(2)]],
const constant int& partition_stride [[buffer(3)]],
const constant int& ldd [[buffer(4)]],
uint2 gid [[thread_position_in_grid]]) {
// Adjust D and C_split
D += gid.x + gid.y * size_t(ldd);
C_split += gid.x + gid.y * size_t(ldd);
size_t offset = 0;
AccT out = 0;
for (int i = 0; i < k_partitions; i++) {
out += C_split[offset];
offset += partition_stride;
}
// Write output
D[0] = Epilogue::apply(out);
}
template <
typename AccT,
typename OutT,
typename Epilogue = TransformAxpby<OutT, AccT>>
[[kernel]] void gemm_splitk_accum_axpby(
const device AccT* C_split [[buffer(0)]],
device OutT* D [[buffer(1)]],
const constant int& k_partitions [[buffer(2)]],
const constant int& partition_stride [[buffer(3)]],
const constant int& ldd [[buffer(4)]],
const device OutT* C [[buffer(5)]],
const constant int& ldc [[buffer(6)]],
const constant int& fdc [[buffer(7)]],
const constant float& alpha [[buffer(8)]],
const constant float& beta [[buffer(9)]],
uint2 gid [[thread_position_in_grid]]) {
// Adjust D and C
C += gid.x * size_t(fdc) + gid.y * size_t(ldc);
D += gid.x + gid.y * size_t(ldd);
C_split += gid.x + gid.y * size_t(ldd);
size_t offset = 0;
AccT out = 0;
for (int i = 0; i < k_partitions; i++) {
out += C_split[offset];
offset += partition_stride;
}
// Write output
Epilogue op(alpha, beta);
D[0] = op.apply(out, *C);
}
#define instantiate_accum(oname, otype, aname, atype) \ #define instantiate_accum(oname, otype, aname, atype) \
template [[host_name("steel_gemm_splitk_accum_" #oname \ template [[host_name("steel_gemm_splitk_accum_" #oname \
@ -313,7 +79,7 @@ template <
const constant int& ldd [[buffer(4)]], \ const constant int& ldd [[buffer(4)]], \
uint2 gid [[thread_position_in_grid]]); \ uint2 gid [[thread_position_in_grid]]); \
template [[host_name("steel_gemm_splitk_accum_" #oname "_" #aname \ template [[host_name("steel_gemm_splitk_accum_" #oname "_" #aname \
"_axpby")]] [[kernel]] void \ "_axbpy")]] [[kernel]] void \
gemm_splitk_accum_axpby<atype, otype>( \ gemm_splitk_accum_axpby<atype, otype>( \
const device atype* C_split [[buffer(0)]], \ const device atype* C_split [[buffer(0)]], \
device otype* D [[buffer(1)]], \ device otype* D [[buffer(1)]], \
@ -327,7 +93,6 @@ template <
const constant float& beta [[buffer(9)]], \ const constant float& beta [[buffer(9)]], \
uint2 gid [[thread_position_in_grid]]); uint2 gid [[thread_position_in_grid]]);
// clang-format off
instantiate_accum(bfloat16, bfloat16_t, float32, float); instantiate_accum(bfloat16, bfloat16_t, float32, float);
instantiate_accum(float16, half, float32, float); instantiate_accum(float16, half, float32, float);
instantiate_accum(float32, float, float32, float); // clang-format on instantiate_accum(float32, float, float32, float); // clang-format on

View File

@ -2,7 +2,7 @@
#pragma once #pragma once
#include "mlx/backend/metal/kernels/steel/utils.h" #include "mlx/backend/metal/kernels/steel/defines.h"
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
// Loading helper // Loading helper
@ -134,4 +134,4 @@ struct BlockLoader {
}; };
} // namespace steel } // namespace steel
} // namespace mlx } // namespace mlx

View File

@ -6,8 +6,8 @@
#include <metal_simdgroup_matrix> #include <metal_simdgroup_matrix>
#include <metal_stdlib> #include <metal_stdlib>
#include "mlx/backend/metal/kernels/steel/defines.h"
#include "mlx/backend/metal/kernels/steel/gemm/transforms.h" #include "mlx/backend/metal/kernels/steel/gemm/transforms.h"
#include "mlx/backend/metal/kernels/steel/utils.h"
using namespace metal; using namespace metal;
@ -358,4 +358,4 @@ struct BlockMMA {
}; };
} // namespace steel } // namespace steel
} // namespace mlx } // namespace mlx

View File

@ -4,9 +4,6 @@
#include <metal_stdlib> #include <metal_stdlib>
#define STEEL_CONST static constant constexpr const
#define STEEL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)")
METAL_FUNC ulong2 elem_to_loc_broadcast( METAL_FUNC ulong2 elem_to_loc_broadcast(
uint elem, uint elem,
constant const int* shape, constant const int* shape,
@ -42,4 +39,4 @@ METAL_FUNC ulong3 elem_to_loc_broadcast(
loc_c += pos_in_dim * c_strides[i]; loc_c += pos_in_dim * c_strides[i];
} }
return ulong3(loc_a, loc_b, loc_c); return ulong3(loc_a, loc_b, loc_c);
} }

View File

@ -8,9 +8,10 @@
OUTPUT_DIR=$1 OUTPUT_DIR=$1
CC=$2 CC=$2
SRC_DIR=$3 SRC_DIR=$3
SRC_NAME=$4 SRC_FILE=$4
CFLAGS=$5 CFLAGS=$5
INPUT_FILE=${SRC_DIR}/mlx/backend/metal/kernels/${SRC_NAME}.h SRC_NAME=$(basename -- "${SRC_FILE}")
INPUT_FILE=${SRC_DIR}/mlx/backend/metal/kernels/${SRC_FILE}.h
OUTPUT_FILE=${OUTPUT_DIR}/${SRC_NAME}.cpp OUTPUT_FILE=${OUTPUT_DIR}/${SRC_NAME}.cpp
mkdir -p $OUTPUT_DIR mkdir -p $OUTPUT_DIR

View File

@ -7,6 +7,7 @@
#include "mlx/backend/metal/copy.h" #include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h" #include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/kernels/defines.h" #include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/steel/gemm/params.h" #include "mlx/backend/metal/kernels/steel/gemm/params.h"
#include "mlx/backend/metal/matmul.h" #include "mlx/backend/metal/matmul.h"
@ -336,7 +337,19 @@ void steel_matmul_conv_groups(
// Encode and dispatch kernel // Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts); auto kernel = get_steel_gemm_fused_kernel(
d,
base_name,
hash_name,
func_consts,
out,
transpose_a,
transpose_b,
bm,
bn,
bk,
wm,
wn);
compute_encoder->setComputePipelineState(kernel); compute_encoder->setComputePipelineState(kernel);
@ -458,17 +471,31 @@ void steel_matmul(
C_split.set_data(allocator::malloc_or_wait(C_split.nbytes())); C_split.set_data(allocator::malloc_or_wait(C_split.nbytes()));
copies.push_back(C_split); copies.push_back(C_split);
bool mn_aligned = M % bm == 0 && N % bn == 0;
bool k_aligned = K % bk == 0;
std::ostringstream kname; std::ostringstream kname;
kname << "steel_gemm_splitk_" << (transpose_a ? 't' : 'n') kname << "steel_gemm_splitk_" << (transpose_a ? 't' : 'n')
<< (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_" << (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
<< type_to_name(C_split) << "_bm" << bm << "_bn" << bn << "_bk" << bk << type_to_name(C_split) << "_bm" << bm << "_bn" << bn << "_bk" << bk
<< "_wm" << wm << "_wn" << wn << "_MN_" << "_wm" << wm << "_wn" << wn << "_MN_" << (mn_aligned ? "t" : "n")
<< ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned" << "_K_" << "aligned" << "_K_" << (k_aligned ? "t" : "n") << "aligned";
<< ((K % bk == 0) ? "t" : "n") << "aligned";
// Encode and dispatch gemm kernel // Encode and dispatch gemm kernel
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str()); auto kernel = get_steel_gemm_splitk_kernel(
d,
kname.str(),
a,
C_split,
transpose_a,
transpose_b,
bm,
bn,
bk,
wm,
wn,
mn_aligned,
k_aligned);
compute_encoder->setComputePipelineState(kernel); compute_encoder->setComputePipelineState(kernel);
int tn = (N + bn - 1) / bn; int tn = (N + bn - 1) / bn;
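As a concrete illustration of the kernel names the stream above builds, a standalone snippet with assumed example values (bm=32, bn=32, bk=16, wm=2, wn=2, a non-transposed A, transposed B, float16 inputs, float32 accumulation, aligned M, N and K); the values are not taken from this diff.

#include <iostream>
#include <sstream>

int main() {
  bool transpose_a = false, transpose_b = true;
  int bm = 32, bn = 32, bk = 16, wm = 2, wn = 2;
  bool mn_aligned = true, k_aligned = true;
  std::ostringstream kname;
  kname << "steel_gemm_splitk_" << (transpose_a ? 't' : 'n')
        << (transpose_b ? 't' : 'n') << "_float16_float32"
        << "_bm" << bm << "_bn" << bn << "_bk" << bk
        << "_wm" << wm << "_wn" << wn
        << "_MN_" << (mn_aligned ? "t" : "n") << "aligned"
        << "_K_" << (k_aligned ? "t" : "n") << "aligned";
  // Prints: steel_gemm_splitk_nt_float16_float32_bm32_bn32_bk16_wm2_wn2_MN_taligned_K_taligned
  std::cout << kname.str() << std::endl;
  return 0;
}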
@ -504,10 +531,11 @@ void steel_matmul(
static_cast<const MTL::Resource*>(C_split.buffer().ptr()); static_cast<const MTL::Resource*>(C_split.buffer().ptr());
const class MTL::Resource* const resources[1] = {c_split_buf}; const class MTL::Resource* const resources[1] = {c_split_buf};
compute_encoder->memoryBarrier(resources, 1); compute_encoder->memoryBarrier(resources, 1);
auto kernel_name = "steel_gemm_splitk_accum_" + type_to_name(out) + "_" +
type_to_name(C_split);
auto kernel = d.get_kernel( auto kernel = get_steel_gemm_splitk_accum_kernel(
"steel_gemm_splitk_accum_" + type_to_name(out) + "_" + d, kernel_name, C_split, out, false);
type_to_name(C_split));
compute_encoder->setComputePipelineState(kernel); compute_encoder->setComputePipelineState(kernel);
// Set the arguments for the kernel // Set the arguments for the kernel
@ -587,7 +615,19 @@ void steel_matmul(
// Encode and dispatch kernel // Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts); auto kernel = get_steel_gemm_fused_kernel(
d,
base_name,
hash_name,
func_consts,
out,
transpose_a,
transpose_b,
bm,
bn,
bk,
wm,
wn);
compute_encoder->setComputePipelineState(kernel); compute_encoder->setComputePipelineState(kernel);
@ -1053,17 +1093,33 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
C_split.set_data(allocator::malloc_or_wait(C_split.nbytes())); C_split.set_data(allocator::malloc_or_wait(C_split.nbytes()));
copies.push_back(C_split); copies.push_back(C_split);
bool mn_aligned = M % bm == 0 && N % bn == 0;
bool k_aligned = K % bk == 0;
std::ostringstream kname; std::ostringstream kname;
kname << "steel_gemm_splitk_" << (transpose_a ? 't' : 'n') kname << "steel_gemm_splitk_" << (transpose_a ? 't' : 'n')
<< (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_" << (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
<< type_to_name(C_split) << "_bm" << bm << "_bn" << bn << "_bk" << bk << type_to_name(C_split) << "_bm" << bm << "_bn" << bn << "_bk" << bk
<< "_wm" << wm << "_wn" << wn << "_MN_" << "_wm" << wm << "_wn" << wn << "_MN_" << (mn_aligned ? "t" : "n")
<< ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned" << "_K_" << "aligned" << "_K_" << (k_aligned ? "t" : "n") << "aligned";
<< ((K % bk == 0) ? "t" : "n") << "aligned";
// Encode and dispatch gemm kernel // Encode and dispatch gemm kernel
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str()); auto kernel = get_steel_gemm_splitk_kernel(
d,
kname.str(),
a,
C_split,
transpose_a,
transpose_b,
bm,
bn,
bk,
wm,
wn,
mn_aligned,
k_aligned);
compute_encoder->setComputePipelineState(kernel); compute_encoder->setComputePipelineState(kernel);
int tn = (N + bn - 1) / bn; int tn = (N + bn - 1) / bn;
@ -1095,9 +1151,11 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
// Do accum kernel // Do accum kernel
{ {
auto kernel = d.get_kernel( auto kernel_name = "steel_gemm_splitk_accum_" + type_to_name(out) + "_" +
"steel_gemm_splitk_accum_" + type_to_name(out) + "_" + type_to_name(C_split) + "_axbpy";
type_to_name(C_split) + "_axpby"); auto kernel = get_steel_gemm_splitk_accum_kernel(
d, kernel_name, C_split, out, true);
compute_encoder->setComputePipelineState(kernel); compute_encoder->setComputePipelineState(kernel);
// Set the arguments for the kernel // Set the arguments for the kernel
@ -1182,7 +1240,19 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
// Encode and dispatch kernel // Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts); auto kernel = get_steel_gemm_fused_kernel(
d,
base_name,
hash_name,
func_consts,
out,
transpose_a,
transpose_b,
bm,
bn,
bk,
wm,
wn);
compute_encoder->setComputePipelineState(kernel); compute_encoder->setComputePipelineState(kernel);
@ -1348,6 +1418,8 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
// Determine dispatch kernel // Determine dispatch kernel
int bm = block_size_, bn = block_size_, bk = 16; int bm = block_size_, bn = block_size_, bk = 16;
int wm = 2, wn = 2; int wm = 2, wn = 2;
bool mn_aligned = M % bm == 0 && N % bn == 0;
bool k_aligned = K % bk == 0;
// Prepare kernel name // Prepare kernel name
std::string out_mask_nm = has_out_mask ? type_to_name(inputs[2]) : "nomask"; std::string out_mask_nm = has_out_mask ? type_to_name(inputs[2]) : "nomask";
@ -1358,13 +1430,26 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
<< op_mask_nm << "_" << (transpose_a ? 't' : 'n') << op_mask_nm << "_" << (transpose_a ? 't' : 'n')
<< (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_" << (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
<< type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk << type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk
<< "_wm" << wm << "_wn" << wn << "_MN_" << "_wm" << wm << "_wn" << wn << "_MN_" << (mn_aligned ? "t" : "n")
<< ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned" << "_K_" << "aligned" << "_K_" << (k_aligned ? "t" : "n") << "aligned";
<< ((K % bk == 0) ? "t" : "n") << "aligned";
// Encode and dispatch kernel // Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str()); auto kernel = get_steel_gemm_masked_kernel(
d,
kname.str(),
out,
has_out_mask ? std::optional<array>{inputs[2]} : std::nullopt,
has_op_mask ? std::optional<array>{inputs.back()} : std::nullopt,
transpose_a,
transpose_b,
bm,
bn,
bk,
wm,
wn,
mn_aligned,
k_aligned);
compute_encoder->setComputePipelineState(kernel); compute_encoder->setComputePipelineState(kernel);
// Use problem size to determine threadblock swizzle // Use problem size to determine threadblock swizzle
@ -1720,7 +1805,19 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
// Encode and dispatch kernel // Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts); auto kernel = get_steel_gemm_fused_kernel(
d,
base_name,
hash_name,
func_consts,
out,
transpose_a,
transpose_b,
bm,
bn,
bk,
wm,
wn);
compute_encoder->setComputePipelineState(kernel); compute_encoder->setComputePipelineState(kernel);

View File

@ -103,4 +103,90 @@ MTL::ComputePipelineState* get_reduce_kernel(
return d.get_kernel(kernel_name); return d.get_kernel(kernel_name);
} }
MTL::ComputePipelineState* get_steel_gemm_fused_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& hash_name,
const metal::MTLFCList& func_consts,
const array&,
bool,
bool,
int,
int,
int,
int,
int) {
return d.get_kernel(kernel_name, "mlx", hash_name, func_consts);
}
MTL::ComputePipelineState* get_steel_gemm_splitk_kernel(
metal::Device& d,
const std::string& kernel_name,
const array&,
const array&,
bool,
bool,
int,
int,
int,
int,
int,
bool,
bool) {
return d.get_kernel(kernel_name);
}
MTL::ComputePipelineState* get_steel_gemm_splitk_accum_kernel(
metal::Device& d,
const std::string& kernel_name,
const array&,
const array&,
bool) {
return d.get_kernel(kernel_name);
}
MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
metal::Device& d,
const std::string& kernel_name,
const array&,
const std::optional<array>& mask_out,
const std::optional<array>& mask_op,
bool,
bool,
int,
int,
int,
int,
int,
bool,
bool) {
return d.get_kernel(kernel_name);
}
MTL::ComputePipelineState* get_steel_conv_kernel(
metal::Device& d,
const std::string& kernel_name,
const array&,
int,
int,
int,
int,
int,
int,
bool) {
return d.get_kernel(kernel_name);
}
MTL::ComputePipelineState* get_steel_conv_general_kernel(
metal::Device& d,
const std::string& kernel_name,
const array&,
int,
int,
int,
int,
int) {
return d.get_kernel(kernel_name);
}
} // namespace mlx::core } // namespace mlx::core
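The helpers above are the non-JIT implementations: each one ignores its type, transpose, and tile arguments and resolves the pipeline purely by its mangled name from the pre-built library. A minimal sketch of that pattern, using hypothetical stand-in types (Pipeline, Library) rather than the actual MLX/Metal API, to contrast ahead-of-time lookup with on-demand compilation:

#include <functional>
#include <stdexcept>
#include <string>
#include <unordered_map>

struct Pipeline {}; // stand-in for a compiled compute pipeline

struct Library {
  std::unordered_map<std::string, Pipeline> precompiled;
  std::function<Pipeline(const std::string&)> jit_compile;

  Pipeline get_kernel(const std::string& name, bool jit) {
    if (!jit) {
      // AOT path: the variant must already exist in the pre-built library.
      auto it = precompiled.find(name);
      if (it == precompiled.end()) {
        throw std::runtime_error("kernel not found: " + name);
      }
      return it->second;
    }
    // JIT path: compile and cache the requested variant on first use.
    auto [it, inserted] = precompiled.try_emplace(name);
    if (inserted) {
      it->second = jit_compile(name);
    }
    return it->second;
  }
};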

View File

@ -3968,9 +3968,9 @@ void init_ops(nb::module_& m) {
a (array): Input array or scalar. a (array): Input array or scalar.
b (array): Input array or scalar. b (array): Input array or scalar.
block_size (int): Size of blocks to be masked. Must be ``32`` or ``64`` (default: ``64``) block_size (int): Size of blocks to be masked. Must be ``32`` or ``64`` (default: ``64``)
mask_out (array, optional): Boolean mask for output (default: ``None``) mask_out (array, optional): Mask for output (default: ``None``)
mask_lhs (array, optional): Boolean mask for a (default: ``None``) mask_lhs (array, optional): Mask for a (default: ``None``)
mask_rhs (array, optional): Boolean mask for b (default: ``None``) mask_rhs (array, optional): Mask for b (default: ``None``)
)pbdoc"); )pbdoc");
m.def( m.def(

View File

@ -556,8 +556,7 @@ class TestBlas(mlx_tests.MLXTestCase):
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape)) self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5)) self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
# Batched matmul with simple broadcast
# # Batched matmul with simple broadcast
a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32) a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
b_npy = np.random.normal(0.0, 1.0 / 128, (16, 16)).astype(np.float32) b_npy = np.random.normal(0.0, 1.0 / 128, (16, 16)).astype(np.float32)
@ -573,7 +572,6 @@ class TestBlas(mlx_tests.MLXTestCase):
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape)) self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5)) self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
# Matmul with vector # Matmul with vector
a_npy = np.random.normal(0.0, 1.0 / 128, (16,)).astype(np.float32) a_npy = np.random.normal(0.0, 1.0 / 128, (16,)).astype(np.float32)
b_npy = np.random.normal(0.0, 1.0 / 128, (32, 16, 128)).astype(np.float32) b_npy = np.random.normal(0.0, 1.0 / 128, (32, 16, 128)).astype(np.float32)