mirror of https://github.com/ml-explore/mlx.git — synced 2025-06-24 09:21:16 +08:00

Option to JIT steel gemm / conv (#1139)

This commit is contained in: parent eab2685c67 · commit 7e26fd8032
@@ -1,4 +1,4 @@
-function(make_jit_source SRC_NAME)
+function(make_jit_source SRC_FILE)
 # This function takes a metal header file,
 # runs the C preprocessor on it, and makes
 # the processed contents available as a string in a C++ function
@@ -9,6 +9,7 @@ function(make_jit_source SRC_NAME)
 #
 # Additional arguments to this function are treated as dependencies
 # in the CMake build system.
+get_filename_component(SRC_NAME ${SRC_FILE} NAME)
 add_custom_command(
   OUTPUT jit/${SRC_NAME}.cpp
   COMMAND /bin/bash
@@ -16,10 +17,10 @@ function(make_jit_source SRC_NAME)
   ${CMAKE_CURRENT_BINARY_DIR}/jit
   ${CMAKE_C_COMPILER}
   ${PROJECT_SOURCE_DIR}
-  ${SRC_NAME}
+  ${SRC_FILE}
   "-D${MLX_METAL_VERSION}"
   DEPENDS make_compiled_preamble.sh
-  kernels/${SRC_NAME}.h
+  kernels/${SRC_FILE}.h
   ${ARGN}
 )
 add_custom_target(${SRC_NAME} DEPENDS jit/${SRC_NAME}.cpp)
@@ -73,6 +74,39 @@ if (MLX_METAL_JIT)
   kernels/reduction/reduce_col.h
   kernels/reduction/reduce_row.h
 )
+make_jit_source(
+  steel/gemm/gemm
+  kernels/steel/utils.h
+  kernels/steel/gemm/loader.h
+  kernels/steel/gemm/mma.h
+  kernels/steel/gemm/params.h
+  kernels/steel/gemm/transforms.h
+)
+make_jit_source(steel/gemm/kernels/steel_gemm_fused)
+make_jit_source(
+  steel/gemm/kernels/steel_gemm_masked
+  kernels/steel/defines.h
+)
+make_jit_source(steel/gemm/kernels/steel_gemm_splitk)
+make_jit_source(
+  steel/conv/conv
+  kernels/steel/utils.h
+  kernels/steel/defines.h
+  kernels/steel/gemm/mma.h
+  kernels/steel/gemm/transforms.h
+  kernels/steel/conv/params.h
+  kernels/steel/conv/loader.h
+  kernels/steel/conv/loaders/loader_channel_l.h
+  kernels/steel/conv/loaders/loader_channel_n.h
+)
+make_jit_source(
+  steel/conv/kernels/steel_conv
+)
+make_jit_source(
+  steel/conv/kernels/steel_conv_general
+  kernels/steel/defines.h
+  kernels/steel/conv/loaders/loader_general.h
+)
 else()
 target_sources(
   mlx
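For orientation: make_jit_source shells out to make_compiled_preamble.sh, which runs the C preprocessor over the named header and wraps the result in a C++ accessor function (the const char* declarations added to kernels.h further below). A minimal sketch of what a generated jit/<name>.cpp plausibly contains — the exact layout is an assumption, since only the accessor signatures appear in this diff:

    // Hypothetical generated jit/gemm.cpp; the real file is emitted by
    // make_compiled_preamble.sh at build time.
    namespace mlx::core::metal {

    const char* gemm() {
      // Preprocessed contents of kernels/steel/gemm/gemm.h, baked in as a
      // raw string so it can be prepended to JIT-compiled kernel sources.
      return R"preamble(
    // ... expanded header contents ...
    )preamble";
    }

    } // namespace mlx::core::metal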
@@ -7,6 +7,7 @@

 #include "mlx/backend/metal/copy.h"
 #include "mlx/backend/metal/device.h"
+#include "mlx/backend/metal/kernels.h"
 #include "mlx/backend/metal/kernels/defines.h"
 #include "mlx/backend/metal/kernels/steel/conv/params.h"
 #include "mlx/backend/metal/matmul.h"
@@ -335,7 +336,17 @@ void implicit_gemm_conv_2D_gpu(

   // Encode and dispatch kernel
   auto& compute_encoder = d.get_command_encoder(s.index);
-  auto kernel = d.get_kernel(kname.str());
+  auto kernel = get_steel_conv_kernel(
+      d,
+      kname.str(),
+      out,
+      bm,
+      bn,
+      bk,
+      wm,
+      wn,
+      n_channel_specialization,
+      small_filter);
   compute_encoder->setComputePipelineState(kernel);

   // Deduce grid launch dimensions
@@ -488,7 +499,8 @@ void implicit_gemm_conv_2D_general_gpu(

   // Encode and dispatch kernel
   auto& compute_encoder = d.get_command_encoder(s.index);
-  auto kernel = d.get_kernel(kname.str());
+  auto kernel =
+      get_steel_conv_general_kernel(d, kname.str(), out, bm, bn, bk, wm, wn);
   compute_encoder->setComputePipelineState(kernel);

   // Deduce grid launch dimensions
@@ -23,4 +23,12 @@ const char* softmax();
 const char* sort();
 const char* reduce();

+const char* gemm();
+const char* steel_gemm_fused();
+const char* steel_gemm_masked();
+const char* steel_gemm_splitk();
+const char* conv();
+const char* steel_conv();
+const char* steel_conv_general();
+
 } // namespace mlx::core::metal
mlx/backend/metal/jit/steel_conv.h (new file, 32 lines)
@@ -0,0 +1,32 @@
// Copyright © 2024 Apple Inc.

constexpr std::string_view steel_conv_kernels = R"(
template [[host_name("{name}")]] [[kernel]] void
implicit_gemm_conv_2d<{itype}, {bm}, {bn}, {bk}, {wm}, {wn}, {n_channels}, {small_filter}>(
    const device {itype}* A [[buffer(0)]],
    const device {itype}* B [[buffer(1)]],
    device {itype}* C [[buffer(2)]],
    const constant MLXConvParams<2>* params [[buffer(3)]],
    const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]);
)";

constexpr std::string_view steel_conv_general_kernels = R"(
template [[host_name("{name}")]] [[kernel]] void
implicit_gemm_conv_2d_general<{itype}, {bm}, {bn}, {bk}, {wm}, {wn}>(
    const device {itype}* A [[buffer(0)]],
    const device {itype}* B [[buffer(1)]],
    device {itype}* C [[buffer(2)]],
    const constant MLXConvParams<2>* params [[buffer(3)]],
    const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]],
    const constant Conv2DGeneralJumpParams* jump_params [[buffer(5)]],
    const constant Conv2DGeneralBaseInfo* base_h [[buffer(6)]],
    const constant Conv2DGeneralBaseInfo* base_w [[buffer(7)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]);
)";
mlx/backend/metal/jit/steel_gemm.h (new file, 106 lines)
@@ -0,0 +1,106 @@
// Copyright © 2024 Apple Inc.

constexpr std::string_view steel_gemm_fused_kernels = R"(
template [[host_name("{name}")]]
[[kernel]] void gemm<{itype}, {bm}, {bn}, {bk}, {wm}, {wn}, {trans_a}, {trans_b}, float>(
    const device {itype} *A [[buffer(0)]],
    const device {itype} *B [[buffer(1)]],
    const device {itype} *C [[buffer(2), function_constant(use_out_source)]],
    device {itype} *D [[buffer(3)]],
    const constant GEMMParams* params [[buffer(4)]],
    const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]],
    const constant int* batch_shape [[buffer(6)]],
    const constant size_t* batch_strides [[buffer(7)]],
    const constant uint32_t* lhs_indices [[buffer(10), function_constant(do_gather)]],
    const constant uint32_t* rhs_indices [[buffer(11), function_constant(do_gather)]],
    const constant uint32_t* C_indices [[buffer(12), function_constant(gather_bias)]],
    const constant int* operand_shape [[buffer(13), function_constant(do_gather)]],
    const constant size_t* operand_strides [[buffer(14), function_constant(do_gather)]],
    const constant packed_int3& operand_batch_ndim [[buffer(15), function_constant(do_gather)]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]);
)";

constexpr std::string_view steel_gemm_masked_kernels = R"(
template [[host_name("{name}")]] [[kernel]] void
block_masked_gemm<
    {itype},
    {outmasktype},
    {opmasktype},
    {bm},
    {bn},
    {bk},
    {wm},
    {wn},
    {trans_a},
    {trans_b},
    {mn_aligned},
    {k_aligned}>(
    const device {itype}* A [[buffer(0)]],
    const device {itype}* B [[buffer(1)]],
    device {itype}* D [[buffer(3)]],
    const constant GEMMParams* params [[buffer(4)]],
    const constant int* batch_shape [[buffer(6)]],
    const constant size_t* batch_strides [[buffer(7)]],
    const device {outmasktype}* out_mask [[buffer(10)]],
    const device {opmasktype}* lhs_mask [[buffer(11)]],
    const device {opmasktype}* rhs_mask [[buffer(12)]],
    const constant int* mask_strides [[buffer(13)]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]);
)";

constexpr std::string_view steel_gemm_splitk_kernels = R"(
template [[host_name("{name}")]] [[kernel]] void
gemm_splitk<
    {itype},
    {otype},
    {bm},
    {bn},
    {bk},
    {wm},
    {wn},
    {trans_a},
    {trans_b},
    {mn_aligned},
    {k_aligned}>(
    const device {itype}* A [[buffer(0)]],
    const device {itype}* B [[buffer(1)]],
    device {otype}* C [[buffer(2)]],
    const constant GEMMSpiltKParams* params [[buffer(3)]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]);
)";

constexpr std::string_view steel_gemm_splitk_accum_kernels = R"(
template [[host_name("{name}")]] [[kernel]] void
gemm_splitk_accum<{atype}, {otype}>(
    const device {atype}* C_split [[buffer(0)]],
    device {otype}* D [[buffer(1)]],
    const constant int& k_partitions [[buffer(2)]],
    const constant int& partition_stride [[buffer(3)]],
    const constant int& ldd [[buffer(4)]],
    uint2 gid [[thread_position_in_grid]]);
)";

constexpr std::string_view steel_gemm_splitk_accum_axbpy_kernels = R"(
template [[host_name("{name}")]] [[kernel]] void
gemm_splitk_accum_axpby<{atype}, {otype}>(
    const device {atype}* C_split [[buffer(0)]],
    device {otype}* D [[buffer(1)]],
    const constant int& k_partitions [[buffer(2)]],
    const constant int& partition_stride [[buffer(3)]],
    const constant int& ldd [[buffer(4)]],
    const device {otype}* C [[buffer(5)]],
    const constant int& ldc [[buffer(6)]],
    const constant int& fdc [[buffer(7)]],
    const constant float& alpha [[buffer(8)]],
    const constant float& beta [[buffer(9)]],
    uint2 gid [[thread_position_in_grid]]);
)";
@@ -1,5 +1,4 @@
 // Copyright © 2024 Apple Inc.

 #include <fmt/format.h>

 #include "mlx/backend/common/compiled.h"
@@ -12,11 +11,15 @@
 #include "mlx/backend/metal/jit/scan.h"
 #include "mlx/backend/metal/jit/softmax.h"
 #include "mlx/backend/metal/jit/sort.h"
+#include "mlx/backend/metal/jit/steel_conv.h"
+#include "mlx/backend/metal/jit/steel_gemm.h"
 #include "mlx/backend/metal/jit/ternary.h"
 #include "mlx/backend/metal/jit/unary.h"
 #include "mlx/backend/metal/kernels.h"
 #include "mlx/backend/metal/utils.h"

 using namespace fmt::literals;

 namespace mlx::core {

 std::string op_name(const array& arr) {
@@ -276,4 +279,208 @@ MTL::ComputePipelineState* get_reduce_kernel(
   return d.get_kernel(kernel_name, lib);
 }

+MTL::ComputePipelineState* get_steel_gemm_fused_kernel(
+    metal::Device& d,
+    const std::string& kernel_name,
+    const std::string& hash_name,
+    const metal::MTLFCList& func_consts,
+    const array& out,
+    bool transpose_a,
+    bool transpose_b,
+    int bm,
+    int bn,
+    int bk,
+    int wm,
+    int wn) {
+  const auto& lib_name = kernel_name;
+  auto lib = d.get_library(lib_name);
+  if (lib == nullptr) {
+    std::ostringstream kernel_source;
+    kernel_source << metal::utils() << metal::gemm()
+                  << metal::steel_gemm_fused()
+                  << fmt::format(
+                         steel_gemm_fused_kernels,
+                         "name"_a = lib_name,
+                         "itype"_a = get_type_string(out.dtype()),
+                         "bm"_a = bm,
+                         "bn"_a = bn,
+                         "bk"_a = bk,
+                         "wm"_a = wm,
+                         "wn"_a = wn,
+                         "trans_a"_a = transpose_a,
+                         "trans_b"_a = transpose_b);
+    lib = d.get_library(lib_name, kernel_source.str());
+  }
+  return d.get_kernel(kernel_name, lib, hash_name, func_consts);
+}
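All six getters in this file follow the same get-or-build pattern: the kernel name doubles as the library name, a cache miss assembles the source (preamble accessor strings plus an fmt-formatted instantiation) and triggers a run-time Metal compile, and a cache hit skips straight to pipeline lookup. Distilled into one function — a sketch that reuses MLX's own types and only the calls visible in this diff, so it assumes MLX's internal headers:

    // Distilled shape of every get_steel_*_kernel below.
    MTL::ComputePipelineState* get_or_build(
        metal::Device& d,
        const std::string& kernel_name,
        const std::string& preamble,        // e.g. metal::utils() + metal::gemm()
        const std::string& instantiation) { // fmt-formatted template string
      const auto& lib_name = kernel_name;   // one Metal library per variant
      auto lib = d.get_library(lib_name);   // nullptr on first use
      if (lib == nullptr) {
        lib = d.get_library(lib_name, preamble + instantiation); // JIT compile
      }
      return d.get_kernel(kernel_name, lib);
    }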
+
+MTL::ComputePipelineState* get_steel_gemm_splitk_kernel(
+    metal::Device& d,
+    const std::string& kernel_name,
+    const array& in,
+    const array& out,
+    bool transpose_a,
+    bool transpose_b,
+    int bm,
+    int bn,
+    int bk,
+    int wm,
+    int wn,
+    bool mn_aligned,
+    bool k_aligned) {
+  const auto& lib_name = kernel_name;
+  auto lib = d.get_library(lib_name);
+  if (lib == nullptr) {
+    std::ostringstream kernel_source;
+    kernel_source << metal::utils() << metal::gemm()
+                  << metal::steel_gemm_splitk()
+                  << fmt::format(
+                         steel_gemm_splitk_kernels,
+                         "name"_a = lib_name,
+                         "itype"_a = get_type_string(in.dtype()),
+                         "otype"_a = get_type_string(out.dtype()),
+                         "bm"_a = bm,
+                         "bn"_a = bn,
+                         "bk"_a = bk,
+                         "wm"_a = wm,
+                         "wn"_a = wn,
+                         "trans_a"_a = transpose_a,
+                         "trans_b"_a = transpose_b,
+                         "mn_aligned"_a = mn_aligned,
+                         "k_aligned"_a = k_aligned);
+    lib = d.get_library(lib_name, kernel_source.str());
+  }
+  return d.get_kernel(kernel_name, lib);
+}
+
+MTL::ComputePipelineState* get_steel_gemm_splitk_accum_kernel(
+    metal::Device& d,
+    const std::string& kernel_name,
+    const array& in,
+    const array& out,
+    bool axbpy) {
+  const auto& lib_name = kernel_name;
+  auto lib = d.get_library(lib_name);
+  if (lib == nullptr) {
+    std::ostringstream kernel_source;
+    kernel_source << metal::utils() << metal::gemm()
+                  << metal::steel_gemm_splitk()
+                  << fmt::format(
+                         axbpy ? steel_gemm_splitk_accum_axbpy_kernels
+                               : steel_gemm_splitk_accum_kernels,
+                         "name"_a = lib_name,
+                         "atype"_a = get_type_string(in.dtype()),
+                         "otype"_a = get_type_string(out.dtype()));
+    lib = d.get_library(lib_name, kernel_source.str());
+  }
+  return d.get_kernel(kernel_name, lib);
+}
+
+MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
+    metal::Device& d,
+    const std::string& kernel_name,
+    const array& out,
+    const std::optional<array>& mask_out,
+    const std::optional<array>& mask_op,
+    bool transpose_a,
+    bool transpose_b,
+    int bm,
+    int bn,
+    int bk,
+    int wm,
+    int wn,
+    bool mn_aligned,
+    bool k_aligned) {
+  const auto& lib_name = kernel_name;
+  auto lib = d.get_library(lib_name);
+  if (lib == nullptr) {
+    std::ostringstream kernel_source;
+    auto out_mask_type = mask_out.has_value()
+        ? get_type_string((*mask_out).dtype())
+        : "nomask_t";
+    auto op_mask_type =
+        mask_op.has_value() ? get_type_string((*mask_op).dtype()) : "nomask_t";
+    kernel_source << metal::utils() << metal::gemm()
+                  << metal::steel_gemm_masked()
+                  << fmt::format(
+                         steel_gemm_masked_kernels,
+                         "name"_a = lib_name,
+                         "itype"_a = get_type_string(out.dtype()),
+                         "outmasktype"_a = out_mask_type,
+                         "opmasktype"_a = op_mask_type,
+                         "bm"_a = bm,
+                         "bn"_a = bn,
+                         "bk"_a = bk,
+                         "wm"_a = wm,
+                         "wn"_a = wn,
+                         "trans_a"_a = transpose_a,
+                         "trans_b"_a = transpose_b,
+                         "mn_aligned"_a = mn_aligned,
+                         "k_aligned"_a = k_aligned);
+    lib = d.get_library(lib_name, kernel_source.str());
+  }
+  return d.get_kernel(kernel_name, lib);
+}
+
+MTL::ComputePipelineState* get_steel_conv_kernel(
+    metal::Device& d,
+    const std::string& kernel_name,
+    const array& out,
+    int bm,
+    int bn,
+    int bk,
+    int wm,
+    int wn,
+    int n_channel_specialization,
+    bool small_filter) {
+  const auto& lib_name = kernel_name;
+  auto lib = d.get_library(lib_name);
+  if (lib == nullptr) {
+    std::ostringstream kernel_source;
+    kernel_source << metal::utils() << metal::conv() << metal::steel_conv()
+                  << fmt::format(
+                         steel_conv_kernels,
+                         "name"_a = lib_name,
+                         "itype"_a = get_type_string(out.dtype()),
+                         "bm"_a = bm,
+                         "bn"_a = bn,
+                         "bk"_a = bk,
+                         "wm"_a = wm,
+                         "wn"_a = wn,
+                         "n_channels"_a = n_channel_specialization,
+                         "small_filter"_a = small_filter);
+    lib = d.get_library(lib_name, kernel_source.str());
+  }
+  return d.get_kernel(kernel_name, lib);
+}
+
+MTL::ComputePipelineState* get_steel_conv_general_kernel(
+    metal::Device& d,
+    const std::string& kernel_name,
+    const array& out,
+    int bm,
+    int bn,
+    int bk,
+    int wm,
+    int wn) {
+  const auto& lib_name = kernel_name;
+  auto lib = d.get_library(lib_name);
+  if (lib == nullptr) {
+    std::ostringstream kernel_source;
+    kernel_source << metal::utils() << metal::conv()
+                  << metal::steel_conv_general()
+                  << fmt::format(
+                         steel_conv_general_kernels,
+                         "name"_a = lib_name,
+                         "itype"_a = get_type_string(out.dtype()),
+                         "bm"_a = bm,
+                         "bn"_a = bn,
+                         "bk"_a = bk,
+                         "wm"_a = wm,
+                         "wn"_a = wn);
+    lib = d.get_library(lib_name, kernel_source.str());
+  }
+  return d.get_kernel(kernel_name, lib);
+}
+
 } // namespace mlx::core
@@ -79,4 +79,78 @@ MTL::ComputePipelineState* get_reduce_kernel(
     const array& in,
     const array& out);

+MTL::ComputePipelineState* get_steel_gemm_fused_kernel(
+    metal::Device& d,
+    const std::string& kernel_name,
+    const std::string& hash_name,
+    const metal::MTLFCList& func_consts,
+    const array& out,
+    bool transpose_a,
+    bool transpose_b,
+    int bm,
+    int bn,
+    int bk,
+    int wm,
+    int wn);
+
+MTL::ComputePipelineState* get_steel_gemm_splitk_kernel(
+    metal::Device& d,
+    const std::string& kernel_name,
+    const array& in,
+    const array& out,
+    bool transpose_a,
+    bool transpose_b,
+    int bm,
+    int bn,
+    int bk,
+    int wm,
+    int wn,
+    bool mn_aligned,
+    bool k_aligned);
+
+MTL::ComputePipelineState* get_steel_gemm_splitk_accum_kernel(
+    metal::Device& d,
+    const std::string& kernel_name,
+    const array& in,
+    const array& out,
+    bool axbpy);
+
+MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
+    metal::Device& d,
+    const std::string& kernel_name,
+    const array& out,
+    const std::optional<array>& mask_out,
+    const std::optional<array>& mask_op,
+    bool transpose_a,
+    bool transpose_b,
+    int bm,
+    int bn,
+    int bk,
+    int wm,
+    int wn,
+    bool mn_aligned,
+    bool k_aligned);
+
+MTL::ComputePipelineState* get_steel_conv_kernel(
+    metal::Device& d,
+    const std::string& kernel_name,
+    const array& out,
+    int bm,
+    int bn,
+    int bk,
+    int wm,
+    int wn,
+    int n_channel_specialization,
+    bool small_filter);
+
+MTL::ComputePipelineState* get_steel_conv_general_kernel(
+    metal::Device& d,
+    const std::string& kernel_name,
+    const array& out,
+    int bm,
+    int bn,
+    int bk,
+    int wm,
+    int wn);
+
 } // namespace mlx::core
@@ -1,14 +1,13 @@
 set(
   HEADERS
-  ${CMAKE_CURRENT_SOURCE_DIR}/atomic.h
-  ${CMAKE_CURRENT_SOURCE_DIR}/bf16.h
-  ${CMAKE_CURRENT_SOURCE_DIR}/bf16_math.h
-  ${CMAKE_CURRENT_SOURCE_DIR}/complex.h
-  ${CMAKE_CURRENT_SOURCE_DIR}/defines.h
-  ${CMAKE_CURRENT_SOURCE_DIR}/utils.h
+  bf16.h
+  bf16_math.h
+  complex.h
+  defines.h
+  utils.h
+  steel/conv/params.h
 )

 set(
   KERNELS
   "arg_reduce"
@@ -41,6 +40,7 @@ set(
 set(
   HEADERS
   ${HEADERS}
+  atomic.h
   arange.h
   unary_ops.h
   unary.h
@@ -89,14 +89,40 @@ foreach(KERNEL ${KERNELS})
   set(KERNEL_AIR ${KERNEL}.air ${KERNEL_AIR})
 endforeach()

-file(GLOB_RECURSE STEEL_KERNELS ${CMAKE_CURRENT_SOURCE_DIR}/steel/*.metal)
-file(GLOB_RECURSE STEEL_HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/steel/*.h)
-
-foreach(KERNEL ${STEEL_KERNELS})
-  cmake_path(GET KERNEL STEM TARGET)
-  build_kernel_base(${TARGET} ${KERNEL} "${STEEL_HEADERS}")
-  set(KERNEL_AIR ${TARGET}.air ${KERNEL_AIR})
-endforeach()
+if (NOT MLX_METAL_JIT)
+set(
+  STEEL_KERNELS
+  ${CMAKE_CURRENT_SOURCE_DIR}/steel/conv/kernels/steel_conv.metal
+  ${CMAKE_CURRENT_SOURCE_DIR}/steel/conv/kernels/steel_conv_general.metal
+  ${CMAKE_CURRENT_SOURCE_DIR}/steel/gemm/kernels/steel_gemm_fused.metal
+  ${CMAKE_CURRENT_SOURCE_DIR}/steel/gemm/kernels/steel_gemm_masked.metal
+  ${CMAKE_CURRENT_SOURCE_DIR}/steel/gemm/kernels/steel_gemm_splitk.metal
+)
+set(
+  STEEL_HEADERS
+  steel/defines.h
+  steel/utils.h
+  steel/conv/conv.h
+  steel/conv/loader.h
+  steel/conv/loaders/loader_channel_l.h
+  steel/conv/loaders/loader_channel_n.h
+  steel/conv/loaders/loader_general.h
+  steel/conv/kernels/steel_conv.h
+  steel/conv/kernels/steel_conv_general.h
+  steel/gemm/gemm.h
+  steel/gemm/mma.h
+  steel/gemm/loader.h
+  steel/gemm/transforms.h
+  steel/gemm/kernels/steel_gemm_fused.h
+  steel/gemm/kernels/steel_gemm_masked.h
+  steel/gemm/kernels/steel_gemm_splitk.h
+)
+foreach(KERNEL ${STEEL_KERNELS})
+  cmake_path(GET KERNEL STEM TARGET)
+  build_kernel_base(${TARGET} ${KERNEL} "${STEEL_HEADERS}")
+  set(KERNEL_AIR ${TARGET}.air ${KERNEL_AIR})
+endforeach()
+endif()

 add_custom_command(
   OUTPUT ${MLX_METAL_PATH}/mlx.metallib
@@ -1,6 +1,5 @@
 // Copyright © 2023 Apple Inc.

-#include <metal_atomic>
 #include <metal_simdgroup>

 #include "mlx/backend/metal/kernels/utils.h"
@@ -194,4 +193,4 @@ instantiate_arg_reduce(int32, int32_t)
 instantiate_arg_reduce(int64, int64_t)
 instantiate_arg_reduce(float16, half)
 instantiate_arg_reduce(float32, float)
-instantiate_arg_reduce(bfloat16, bfloat16_t) // clang-format on
\ No newline at end of file
+instantiate_arg_reduce(bfloat16, bfloat16_t) // clang-format on
@@ -650,4 +650,4 @@ winograd_conv_2d_output_transform(

 // clang-format off
 instantiate_winograd_conv_2d(float32, float);
-instantiate_winograd_conv_2d(float16, half); // clang-format on
\ No newline at end of file
+instantiate_winograd_conv_2d(float16, half); // clang-format on
@@ -2,10 +2,12 @@

 #pragma once

+#include "mlx/backend/metal/kernels/steel/defines.h"
 #include "mlx/backend/metal/kernels/steel/utils.h"

 #include "mlx/backend/metal/kernels/steel/conv/loader.h"
 #include "mlx/backend/metal/kernels/steel/conv/params.h"
 #include "mlx/backend/metal/kernels/steel/gemm/mma.h"

 using namespace metal;
-using namespace mlx::steel;
\ No newline at end of file
+using namespace mlx::steel;
mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h (new file, 176 lines)
@@ -0,0 +1,176 @@
// Copyright © 2024 Apple Inc.

#include <metal_stdlib>

using namespace metal;

template <
    typename T,
    int BM,
    int BN,
    int BK,
    int WM,
    int WN,
    int N_CHANNELS = 0,
    bool SMALL_FILTER = false>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void
implicit_gemm_conv_2d(
    const device T* A [[buffer(0)]],
    const device T* B [[buffer(1)]],
    device T* C [[buffer(2)]],
    const constant MLXConvParams<2>* params [[buffer(3)]],
    const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  using namespace mlx::steel;

  (void)lid;

  constexpr bool transpose_a = false;
  constexpr bool transpose_b = true;
  constexpr short tgp_padding_a = 16 / sizeof(T);
  constexpr short tgp_padding_b = 16 / sizeof(T);

  constexpr short shape_a_cols = (transpose_a ? BM : BK) + tgp_padding_a;
  constexpr short shape_b_cols = (transpose_b ? BK : BN) + tgp_padding_b;
  constexpr short shape_a_rows = (transpose_a ? BK : BM);
  constexpr short shape_b_rows = (transpose_b ? BN : BK);
  constexpr short tgp_mem_size_a = shape_a_cols * shape_a_rows;
  constexpr short tgp_mem_size_b = shape_b_cols * shape_b_rows;

  constexpr short tgp_size = WM * WN * 32;

  // Input loader
  using loader_a_t = typename metal::conditional_t<
      // Check for small channel specialization
      N_CHANNELS != 0 && N_CHANNELS <= 4,

      // Go to small channel specialization
      Conv2DInputBlockLoaderSmallChannels<
          T,
          BM,
          BN,
          BK,
          tgp_size,
          N_CHANNELS,
          tgp_padding_a>,

      // Else go to general loader
      typename metal::conditional_t<
          // Check if filter size is small enough
          SMALL_FILTER,

          // Go to small filter specialization
          Conv2DInputBlockLoaderSmallFilter<
              T,
              BM,
              BN,
              BK,
              tgp_size,
              tgp_padding_a>,

          // Else go to large filter generalization
          Conv2DInputBlockLoaderLargeFilter<
              T,
              BM,
              BN,
              BK,
              tgp_size,
              tgp_padding_a>>>;

  // Weight loader
  using loader_b_t = typename metal::conditional_t<
      // Check for small channel specialization
      N_CHANNELS != 0 && N_CHANNELS <= 4,

      // Go to small channel specialization
      Conv2DWeightBlockLoaderSmallChannels<
          T,
          BM,
          BN,
          BK,
          tgp_size,
          N_CHANNELS,
          tgp_padding_b>,

      // Else go to general loader
      Conv2DWeightBlockLoader<T, BM, BN, BK, tgp_size, tgp_padding_b>>;

  using mma_t = BlockMMA<
      T,
      T,
      BM,
      BN,
      BK,
      WM,
      WN,
      transpose_a,
      transpose_b,
      shape_a_cols,
      shape_b_cols>;

  threadgroup T As[tgp_mem_size_a];
  threadgroup T Bs[tgp_mem_size_b];

  const int tid_y = ((tid.y) << gemm_params->swizzle_log) +
      ((tid.x) & ((1 << gemm_params->swizzle_log) - 1));
  const int tid_x = (tid.x) >> gemm_params->swizzle_log;

  if (gemm_params->tiles_n <= tid_x || gemm_params->tiles_m <= tid_y) {
    return;
  }

  const int c_row = tid_y * BM;
  const int c_col = tid_x * BN;
  const int K = gemm_params->K;
  const int N = gemm_params->N;
  const int C_per_group = params->C / params->groups;

  // Groups
  A += tid.z * C_per_group;
  B += tid.z * N * K;
  C += tid.z * N;

  B += c_col * K;
  C += c_row * (N * params->groups) + c_col;

  const int2 offsets_a(0, c_row);
  const int2 offsets_b(0, c_col);

  // Prepare threadgroup loading operations
  loader_a_t loader_a(
      A, As, offsets_a, params, gemm_params, simd_gid, simd_lid);
  loader_b_t loader_b(
      B, Bs, offsets_b, params, gemm_params, simd_gid, simd_lid);

  // Prepare threadgroup mma operation
  mma_t mma_op(simd_gid, simd_lid);

  int gemm_k_iterations = gemm_params->gemm_k_iterations;
  for (int k = 0; k < gemm_k_iterations; k++) {
    threadgroup_barrier(mem_flags::mem_threadgroup);
    // Load elements into threadgroup
    loader_a.load_unsafe();
    loader_b.load_unsafe();

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Multiply and accumulate threadgroup elements
    mma_op.mma(As, Bs);

    // Prepare for next iteration
    loader_a.next();
    loader_b.next();
  }

  threadgroup_barrier(mem_flags::mem_none);

  // Store results to device memory
  short tgp_bm = min(BM, gemm_params->M - c_row);
  short tgp_bn = min(BN, gemm_params->N - c_col);
  const int ldc = N * params->groups;
  mma_op.store_result_safe(C, ldc, short2(tgp_bn, tgp_bm));
}
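The nested metal::conditional_t above resolves the input loader entirely at compile time, so each instantiated variant carries exactly one loader with no runtime branching. The same selection logic in standard C++, with placeholder types standing in for the Conv2D loader templates:

    #include <type_traits>

    // Placeholders for the three Conv2DInputBlockLoader* specializations.
    struct SmallChannelLoader {};
    struct SmallFilterLoader {};
    struct LargeFilterLoader {};

    // Mirrors the loader_a_t selection in the kernel above.
    template <int N_CHANNELS, bool SMALL_FILTER>
    using loader_a_t = std::conditional_t<
        (N_CHANNELS != 0 && N_CHANNELS <= 4),
        SmallChannelLoader,                  // specialized: <= 4 channels
        std::conditional_t<
            SMALL_FILTER,
            SmallFilterLoader,               // specialized: small filter
            LargeFilterLoader>>;             // general fallback

    static_assert(std::is_same_v<loader_a_t<3, false>, SmallChannelLoader>);
    static_assert(std::is_same_v<loader_a_t<0, true>, SmallFilterLoader>);
    static_assert(std::is_same_v<loader_a_t<0, false>, LargeFilterLoader>);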
@@ -2,184 +2,13 @@

 #include <metal_stdlib>

 // clang-format off
 #include "mlx/backend/metal/kernels/steel/gemm/mma.h"

 #include "mlx/backend/metal/kernels/bf16.h"
 #include "mlx/backend/metal/kernels/steel/conv/conv.h"
 #include "mlx/backend/metal/kernels/steel/conv/params.h"

 using namespace metal;
-[removed: the full inline implicit_gemm_conv_2d template definition, identical to the kernel body that now lives in steel_conv.h above]
+#include "mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h"

 #define instantiate_implicit_conv_2d( \
     name, \
@@ -207,25 +36,22 @@ implicit_gemm_conv_2d(
     uint simd_gid [[simdgroup_index_in_threadgroup]], \
     uint simd_lid [[thread_index_in_simdgroup]]);

 // clang-format off
 #define instantiate_implicit_2d_filter(name, itype, bm, bn, bk, wm, wn) \
   instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, l, 0, s, true) \
   instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, l, 0, l, false) \
   instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 1, 1, l, false) \
   instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 2, 2, l, false) \
   instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 3, 3, l, false) \
-  instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 4, 4, l, false) // clang-format on
+  instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 4, 4, l, false)

 // clang-format off
 #define instantiate_implicit_2d_blocks(name, itype) \
   instantiate_implicit_2d_filter(name, itype, 32, 8, 16, 4, 1) \
   instantiate_implicit_2d_filter(name, itype, 64, 8, 16, 4, 1) \
   instantiate_implicit_2d_filter(name, itype, 32, 32, 16, 2, 2) \
   instantiate_implicit_2d_filter(name, itype, 32, 64, 16, 2, 2) \
   instantiate_implicit_2d_filter(name, itype, 64, 32, 16, 2, 2) \
-  instantiate_implicit_2d_filter(name, itype, 64, 64, 16, 2, 2) // clang-format on
+  instantiate_implicit_2d_filter(name, itype, 64, 64, 16, 2, 2)

 // clang-format off
 instantiate_implicit_2d_blocks(float32, float);
 instantiate_implicit_2d_blocks(float16, half);
-instantiate_implicit_2d_blocks(bfloat16, bfloat16_t); // clang-format on
\ No newline at end of file
+instantiate_implicit_2d_blocks(bfloat16, bfloat16_t); // clang-format on
mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h (new file, 188 lines)
@@ -0,0 +1,188 @@
// Copyright © 2024 Apple Inc.

#include "mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h"

template <
    typename T,
    int BM,
    int BN,
    int BK,
    int WM,
    int WN,
    typename AccumType = float,
    typename Epilogue = TransformNone<T, AccumType>>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void
implicit_gemm_conv_2d_general(
    const device T* A [[buffer(0)]],
    const device T* B [[buffer(1)]],
    device T* C [[buffer(2)]],
    const constant MLXConvParams<2>* params [[buffer(3)]],
    const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]],
    const constant Conv2DGeneralJumpParams* jump_params [[buffer(5)]],
    const constant Conv2DGeneralBaseInfo* base_h [[buffer(6)]],
    const constant Conv2DGeneralBaseInfo* base_w [[buffer(7)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  (void)lid;

  constexpr bool transpose_a = false;
  constexpr bool transpose_b = true;
  constexpr short tgp_padding_a = 16 / sizeof(T);
  constexpr short tgp_padding_b = 16 / sizeof(T);

  constexpr short shape_a_cols = (transpose_a ? BM : BK) + tgp_padding_a;
  constexpr short shape_b_cols = (transpose_b ? BK : BN) + tgp_padding_b;
  constexpr short shape_a_rows = (transpose_a ? BK : BM);
  constexpr short shape_b_rows = (transpose_b ? BN : BK);
  constexpr short tgp_mem_size_a = shape_a_cols * shape_a_rows;
  constexpr short tgp_mem_size_b = shape_b_cols * shape_b_rows;

  constexpr short tgp_size = WM * WN * 32;

  // Input loader
  using loader_a_t =
      Conv2DInputBlockLoaderGeneral<T, BM, BN, BK, tgp_size, tgp_padding_a>;

  // Weight loader
  using loader_b_t =
      Conv2DWeightBlockLoaderGeneral<T, BM, BN, BK, tgp_size, tgp_padding_b>;

  using mma_t = BlockMMA<
      T,
      T,
      BM,
      BN,
      BK,
      WM,
      WN,
      transpose_a,
      transpose_b,
      shape_a_cols,
      shape_b_cols>;

  threadgroup T As[tgp_mem_size_a];
  threadgroup T Bs[tgp_mem_size_b];

  const int tid_y = ((tid.y) << gemm_params->swizzle_log) +
      ((tid.x) & ((1 << gemm_params->swizzle_log) - 1));
  const int tid_x = (tid.x) >> gemm_params->swizzle_log;

  if (gemm_params->tiles_n <= tid_x || gemm_params->tiles_m <= tid_y) {
    return;
  }

  const int tid_z = tid.z;

  const int base_oh = tid_z / jump_params->f_out_jump_w;
  const int base_ow = tid_z % jump_params->f_out_jump_w;

  const int base_wh = base_h[base_oh].weight_base;
  const int base_ww = base_w[base_ow].weight_base;

  const int base_wh_size = base_h[base_oh].weight_size;
  const int base_ww_size = base_w[base_ow].weight_size;

  const int c_row = tid_y * BM;
  const int c_col = tid_x * BN;
  const int K = gemm_params->K;

  B += c_col * K;

  const int4 offsets_a(0, c_row, base_oh, base_ow);
  const int2 offsets_b(0, c_col);

  // Prepare threadgroup loading operations
  loader_a_t loader_a(
      A,
      As,
      offsets_a,
      params,
      jump_params,
      base_wh,
      base_ww,
      simd_gid,
      simd_lid);
  loader_b_t loader_b(
      B,
      Bs,
      offsets_b,
      params,
      jump_params,
      base_wh,
      base_ww,
      simd_gid,
      simd_lid);

  // Prepare threadgroup mma operation
  mma_t mma_op(simd_gid, simd_lid);

  int gemm_k_iterations =
      base_wh_size * base_ww_size * gemm_params->gemm_k_iterations;

  for (int k = 0; k < gemm_k_iterations; k++) {
    threadgroup_barrier(mem_flags::mem_threadgroup);
    // Load elements into threadgroup
    loader_a.load_unsafe();
    loader_b.load_unsafe();

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Multiply and accumulate threadgroup elements
    mma_op.mma(As, Bs);

    // Prepare for next iteration
    loader_a.next();
    loader_b.next();
  }

  threadgroup_barrier(mem_flags::mem_none);

  // Store results to device memory
  {
    // Adjust for simdgroup and thread location
    int offset_m = c_row + mma_op.sm + mma_op.tm;
    int offset_n = c_col + mma_op.sn + mma_op.tn;
    C += offset_n;

    if (offset_n >= gemm_params->N)
      return;

    short diff = gemm_params->N - offset_n;

    STEEL_PRAGMA_UNROLL
    for (int i = 0; i < mma_t::TM; i++) {
      int cm = offset_m + i * mma_t::TM_stride;

      int n = cm / jump_params->adj_out_hw;
      int hw = cm % jump_params->adj_out_hw;
      int oh =
          (hw / jump_params->adj_out_w) * jump_params->f_out_jump_h + base_oh;
      int ow =
          (hw % jump_params->adj_out_w) * jump_params->f_out_jump_w + base_ow;

      if (n < params->N && oh < params->oS[0] && ow < params->oS[1]) {
        int offset_cm = n * params->out_strides[0] +
            oh * params->out_strides[1] + ow * params->out_strides[2];

        STEEL_PRAGMA_UNROLL
        for (int j = 0; j < mma_t::TN; j++) {
          // Get accumulated result and associated offset in C
          thread const auto& accum =
              mma_op.results[i * mma_t::TN + j].thread_elements();
          int offset = offset_cm + (j * mma_t::TN_stride);

          // Apply epilogue and output C
          if (j * mma_t::TN_stride < diff) {
            C[offset] = Epilogue::apply(accum[0]);
          }

          if (j * mma_t::TN_stride + 1 < diff) {
            C[offset + 1] = Epilogue::apply(accum[1]);
          }
        }
      }
    }
  }
}
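Both conv kernels (and the fused GEMM below) remap threadgroup coordinates with the same swizzle before picking an output tile. The index math in plain C++, as a sketch for illustration only:

    #include <cassert>

    // tid.y is scaled up by 2^swizzle_log and interleaved with the low bits
    // of tid.x; tid.x keeps the high bits. Threadgroups scheduled together
    // therefore land on nearby output tiles, improving cache reuse.
    struct Tile {
      int x, y;
    };

    Tile swizzle_tile(int raw_x, int raw_y, int swizzle_log) {
      int y = (raw_y << swizzle_log) + (raw_x & ((1 << swizzle_log) - 1));
      int x = raw_x >> swizzle_log;
      return {x, y};
    }

    int main() {
      // With swizzle_log == 1, raw (x=3, y=2) maps to tile (x=1, y=5).
      Tile t = swizzle_tile(3, 2, 1);
      assert(t.x == 1 && t.y == 5);
    }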
@@ -2,201 +2,18 @@

 #include <metal_stdlib>

 // clang-format off
 #include "mlx/backend/metal/kernels/steel/gemm/mma.h"

 #include "mlx/backend/metal/kernels/bf16.h"
 #include "mlx/backend/metal/kernels/steel/conv/conv.h"
 #include "mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h"
 #include "mlx/backend/metal/kernels/steel/conv/params.h"
 #include "mlx/backend/metal/kernels/steel/utils.h"
+#include "mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h"

 using namespace metal;
 using namespace mlx::steel;
-[removed: the full inline implicit_gemm_conv_2d_general template definition, identical to the kernel body that now lives in steel_conv_general.h above]

 #define instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn) \
   template \
   [[host_name("implicit_gemm_conv_2d_general_" #name "_bm" #bm "_bn" #bn \
@@ -218,16 +35,14 @@ implicit_gemm_conv_2d_general(
 #define instantiate_implicit_2d_filter(name, itype, bm, bn, bk, wm, wn) \
   instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn)

 // clang-format off
 #define instantiate_implicit_2d_blocks(name, itype) \
   instantiate_implicit_2d_filter(name, itype, 32, 8, 16, 4, 1) \
   instantiate_implicit_2d_filter(name, itype, 64, 8, 16, 4, 1) \
   instantiate_implicit_2d_filter(name, itype, 32, 32, 16, 2, 2) \
   instantiate_implicit_2d_filter(name, itype, 32, 64, 16, 2, 2) \
   instantiate_implicit_2d_filter(name, itype, 64, 32, 16, 2, 2) \
-  instantiate_implicit_2d_filter(name, itype, 64, 64, 16, 2, 2) // clang-format on
+  instantiate_implicit_2d_filter(name, itype, 64, 64, 16, 2, 2)

 // clang-format off
 instantiate_implicit_2d_blocks(float32, float);
 instantiate_implicit_2d_blocks(float16, half);
-instantiate_implicit_2d_blocks(bfloat16, bfloat16_t); // clang-format on
\ No newline at end of file
+instantiate_implicit_2d_blocks(bfloat16, bfloat16_t); // clang-format on
@@ -2,9 +2,7 @@

 #pragma once

-#include "mlx/backend/metal/kernels/steel/utils.h"
-
 #include "mlx/backend/metal/kernels/steel/conv/params.h"
+#include "mlx/backend/metal/kernels/steel/defines.h"

 ///////////////////////////////////////////////////////////////////////////////
 // Loading helper
@@ -285,4 +283,4 @@ struct Conv2DWeightBlockLoaderGeneral {
 };

 } // namespace steel
-} // namespace mlx
\ No newline at end of file
+} // namespace mlx
mlx/backend/metal/kernels/steel/defines.h (new file, 4 lines)
@@ -0,0 +1,4 @@
// Copyright © 2024 Apple Inc.

#define STEEL_CONST static constant constexpr const
#define STEEL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)")
mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h (new file, 415 lines)
@@ -0,0 +1,415 @@
// Copyright © 2024 Apple Inc.

using namespace mlx::steel;

///////////////////////////////////////////////////////////////////////////////
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////

constant bool has_batch [[function_constant(10)]];

constant bool use_out_source [[function_constant(100)]];
constant bool do_axpby [[function_constant(110)]];

constant bool align_M [[function_constant(200)]];
constant bool align_N [[function_constant(201)]];
constant bool align_K [[function_constant(202)]];

constant bool do_gather [[function_constant(300)]];

constant bool gather_bias = do_gather && use_out_source;
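These function constants are what get_steel_gemm_fused_kernel feeds through its MTLFCList/hash_name parameters, letting one compiled library serve many specializations without recompiling. A hypothetical host-side sketch binding the indices declared above directly with metal-cpp — MLX's own wrapper types differ, so the exact calls here are assumptions:

    #include <Metal/Metal.hpp>

    // Hypothetical helper: bind the fused-GEMM function constants at the
    // indices declared above (10, 100, 110, 200-202, 300).
    MTL::FunctionConstantValues* make_consts(
        bool has_batch, bool use_out_source, bool do_axpby,
        bool align_M, bool align_N, bool align_K, bool do_gather) {
      auto* v = MTL::FunctionConstantValues::alloc()->init();
      v->setConstantValue(&has_batch, MTL::DataTypeBool, NS::UInteger(10));
      v->setConstantValue(&use_out_source, MTL::DataTypeBool, NS::UInteger(100));
      v->setConstantValue(&do_axpby, MTL::DataTypeBool, NS::UInteger(110));
      v->setConstantValue(&align_M, MTL::DataTypeBool, NS::UInteger(200));
      v->setConstantValue(&align_N, MTL::DataTypeBool, NS::UInteger(201));
      v->setConstantValue(&align_K, MTL::DataTypeBool, NS::UInteger(202));
      v->setConstantValue(&do_gather, MTL::DataTypeBool, NS::UInteger(300));
      return v;
    }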
|
||||
|
||||
// clang-format off
|
||||
template <
|
||||
typename T,
|
||||
int BM,
|
||||
int BN,
|
||||
int BK,
|
||||
int WM,
|
||||
int WN,
|
||||
bool transpose_a,
|
||||
bool transpose_b,
|
||||
typename AccumType = float>
|
||||
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gemm(
|
||||
const device T* A [[buffer(0)]],
|
||||
const device T* B [[buffer(1)]],
|
||||
const device T* C [[buffer(2), function_constant(use_out_source)]],
|
||||
device T* D [[buffer(3)]],
|
||||
const constant GEMMParams* params [[buffer(4)]],
|
||||
const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]],
|
||||
const constant int* batch_shape [[buffer(6)]],
|
||||
const constant size_t* batch_strides [[buffer(7)]],
|
||||
const constant uint32_t* lhs_indices [[buffer(10), function_constant(do_gather)]],
|
||||
const constant uint32_t* rhs_indices [[buffer(11), function_constant(do_gather)]],
|
||||
const constant uint32_t* C_indices [[buffer(12), function_constant(gather_bias)]],
|
||||
const constant int* operand_shape [[buffer(13), function_constant(do_gather)]],
|
||||
const constant size_t* operand_strides [[buffer(14), function_constant(do_gather)]],
|
||||
const constant packed_int3& operand_batch_ndim [[buffer(15), function_constant(do_gather)]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]) { // clang-format on
|
||||
// Pacifying compiler
|
||||
(void)lid;
|
||||
|
||||
using gemm_kernel = GEMMKernel<
|
||||
T,
|
||||
T,
|
||||
BM,
|
||||
BN,
|
||||
BK,
|
||||
WM,
|
||||
WN,
|
||||
transpose_a,
|
||||
transpose_b,
|
||||
true,
|
||||
true,
|
||||
AccumType>;
|
||||
|
||||
using loader_a_t = typename gemm_kernel::loader_a_t;
|
||||
using loader_b_t = typename gemm_kernel::loader_b_t;
|
||||
using mma_t = typename gemm_kernel::mma_t;
|
||||
|
||||
// Find block
|
||||
const int tid_y = ((tid.y) << params->swizzle_log) +
|
||||
((tid.x) & ((1 << params->swizzle_log) - 1));
|
||||
const int tid_x = (tid.x) >> params->swizzle_log;
|
||||
|
||||
// Exit early if out of bounds
|
||||
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Adjust for batch
|
||||
|
||||
// Handle gather
|
||||
if (do_gather) {
|
||||
// Read indices
|
||||
uint32_t indx_A, indx_B, indx_C;
|
||||
|
||||
if (has_batch) {
|
||||
const constant size_t* indx_A_bstrides = batch_strides;
|
||||
const constant size_t* indx_B_bstrides =
|
||||
batch_strides + params->batch_ndim;
|
||||
|
||||
ulong2 indx_offsets = elem_to_loc_broadcast(
|
||||
tid.z,
|
||||
batch_shape,
|
||||
indx_A_bstrides,
|
||||
indx_B_bstrides,
|
||||
params->batch_ndim);
|
||||
indx_A = lhs_indices[indx_offsets.x];
|
||||
indx_B = rhs_indices[indx_offsets.y];
|
||||
|
||||
if (use_out_source) {
|
||||
const constant size_t* indx_C_bstrides =
|
||||
indx_B_bstrides + params->batch_ndim;
|
||||
auto indx_offset_C = elem_to_loc(
|
||||
tid.z, batch_shape, indx_C_bstrides, params->batch_ndim);
|
||||
indx_C = C_indices[indx_offset_C];
|
||||
}
|
||||
} else {
|
||||
indx_A = lhs_indices[params->batch_stride_a * tid.z];
|
||||
indx_B = rhs_indices[params->batch_stride_b * tid.z];
|
||||
|
||||
if (use_out_source) {
|
||||
indx_C = C_indices[addmm_params->batch_stride_c * tid.z];
|
||||
}
|
||||
}
|
||||
|
||||
// Translate indices to offsets
|
||||
int batch_ndim_A = operand_batch_ndim.x;
|
||||
const constant int* batch_shape_A = operand_shape;
|
||||
const constant size_t* batch_strides_A = operand_strides;
|
||||
A += elem_to_loc(indx_A, batch_shape_A, batch_strides_A, batch_ndim_A);

    int batch_ndim_B = operand_batch_ndim.y;
    const constant int* batch_shape_B = batch_shape_A + batch_ndim_A;
    const constant size_t* batch_strides_B = batch_strides_A + batch_ndim_A;
    B += elem_to_loc(indx_B, batch_shape_B, batch_strides_B, batch_ndim_B);

    if (use_out_source) {
      int batch_ndim_C = operand_batch_ndim.z;
      const constant int* batch_shape_C = batch_shape_B + batch_ndim_B;
      const constant size_t* batch_strides_C = batch_strides_B + batch_ndim_B;
      C += elem_to_loc(indx_C, batch_shape_C, batch_strides_C, batch_ndim_C);
    }

  }

  // Handle regular batch
  else {
    if (has_batch) {
      const constant size_t* A_bstrides = batch_strides;
      const constant size_t* B_bstrides = batch_strides + params->batch_ndim;

      ulong2 batch_offsets = elem_to_loc_broadcast(
          tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);

      A += batch_offsets.x;
      B += batch_offsets.y;

      if (use_out_source) {
        const constant size_t* C_bstrides = B_bstrides + params->batch_ndim;
        C += elem_to_loc(tid.z, batch_shape, C_bstrides, params->batch_ndim);
      }
    } else {
      A += params->batch_stride_a * tid.z;
      B += params->batch_stride_b * tid.z;

      if (use_out_source) {
        C += addmm_params->batch_stride_c * tid.z;
      }
    }
  }

  D += params->batch_stride_d * tid.z;

  // Prepare threadgroup memory
  threadgroup T As[gemm_kernel::tgp_mem_size_a];
  threadgroup T Bs[gemm_kernel::tgp_mem_size_b];

  threadgroup_barrier(mem_flags::mem_none);

  // Find block in A, B, C
  const int c_row = tid_y * BM;
  const int c_col = tid_x * BN;
  const size_t c_row_long = size_t(c_row);
  const size_t c_col_long = size_t(c_col);

  A += transpose_a ? c_row_long : c_row_long * params->lda;
  B += transpose_b ? c_col_long * params->ldb : c_col_long;
  D += c_row_long * params->ldd + c_col_long;

  if (use_out_source) {
    C += c_row_long * addmm_params->ldc + c_col_long * addmm_params->fdc;
  }

  // Prepare threadgroup mma operation
  thread mma_t mma_op(simd_group_id, simd_lane_id);

  // Prepare threadgroup loading operations
  thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
  thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);

  // Prepare threadgroup bounds
  const short tgp_bm = align_M ? BM : short(min(BM, params->M - c_row));
  const short tgp_bn = align_N ? BN : short(min(BN, params->N - c_col));

  // Prepare iterations
  int gemm_k_iterations = params->gemm_k_iterations_aligned;

  // Do unaligned K iterations first
  if (!align_K) {
    const int k_last = params->gemm_k_iterations_aligned * BK;
    const int k_remain = params->K - k_last;
    const size_t k_jump_a =
        transpose_a ? params->lda * size_t(k_last) : size_t(k_last);
    const size_t k_jump_b =
        transpose_b ? size_t(k_last) : params->ldb * size_t(k_last);

    // Move loader source ahead to end
    loader_a.src += k_jump_a;
    loader_b.src += k_jump_b;

    // Load tile
    const short2 tile_dims_A =
        transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
    const short2 tile_dims_B =
        transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);

    loader_a.load_safe(tile_dims_A);
    loader_b.load_safe(tile_dims_B);

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Do matmul
    mma_op.mma(As, Bs);

    // Reset source back to start
    loader_a.src -= k_jump_a;
    loader_b.src -= k_jump_b;
  }
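  // With the K remainder accumulated once up front (into mma_op), every
  // iteration of the main loops below runs over a full BK-sized tile.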

  const TransformAdd<AccumType, AccumType> epilogue_op_add(
      addmm_params->alpha, addmm_params->beta);
  const TransformAxpby<AccumType, AccumType> epilogue_op_axpby(
      addmm_params->alpha, addmm_params->beta);
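  // As the names suggest, the add epilogue applies a plain D = AB + C while
  // the axpby epilogue applies D = alpha * AB + beta * C; do_axpby picks
  // between them at pipeline-specialization time via the function constants.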

  ///////////////////////////////////////////////////////////////////////////////
  // MNK aligned loop
  if (align_M && align_N) {
    // Do gemm
    for (int k = 0; k < gemm_k_iterations; k++) {
      threadgroup_barrier(mem_flags::mem_threadgroup);
      // Load elements into threadgroup
      loader_a.load_unsafe();
      loader_b.load_unsafe();

      threadgroup_barrier(mem_flags::mem_threadgroup);

      // Multiply and accumulate threadgroup elements
      mma_op.mma(As, Bs);

      // Prepare for next iteration
      loader_a.next();
      loader_b.next();
    }

    threadgroup_barrier(mem_flags::mem_none);

    // Do epilogue
    if (use_out_source) {
      if (do_axpby) {
        mma_op.apply_epilogue(
            C, addmm_params->ldc, addmm_params->fdc, epilogue_op_axpby);
      } else {
        mma_op.apply_epilogue(
            C, addmm_params->ldc, addmm_params->fdc, epilogue_op_add);
      }
    }

    // Store results to device memory
    return mma_op.store_result(D, params->ldd);

  }
  ///////////////////////////////////////////////////////////////////////////////
  // MN unaligned loop
  else { // Loop over K - unaligned case
    const int leftover_bk = 0;
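    // Four specializations below: both M and N tile-aligned, only N aligned,
    // only M aligned, or neither; the unaligned sides use bounds-checked
    // loads/stores via LoopAlignment and the *_safe variants.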

    if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) {
      // Do gemm
      gemm_kernel::gemm_loop(
          As,
          Bs,
          gemm_k_iterations,
          loader_a,
          loader_b,
          mma_op,
          tgp_bm,
          tgp_bn,
          leftover_bk,
          LoopAlignment<true, true, true>{});

      // Do epilogue
      if (use_out_source) {
        if (do_axpby) {
          mma_op.apply_epilogue(
              C, addmm_params->ldc, addmm_params->fdc, epilogue_op_axpby);
        } else {
          mma_op.apply_epilogue(
              C, addmm_params->ldc, addmm_params->fdc, epilogue_op_add);
        }
      }

      // Store results to device memory
      return mma_op.store_result(D, params->ldd);

    } else if (align_N || tgp_bn == BN) {
      gemm_kernel::gemm_loop(
          As,
          Bs,
          gemm_k_iterations,
          loader_a,
          loader_b,
          mma_op,
          tgp_bm,
          tgp_bn,
          leftover_bk,
          LoopAlignment<false, true, true>{});

      // Do epilogue
      if (use_out_source) {
        if (do_axpby) {
          mma_op.apply_epilogue_safe(
              C,
              addmm_params->ldc,
              addmm_params->fdc,
              short2(tgp_bn, tgp_bm),
              epilogue_op_axpby);
        } else {
          mma_op.apply_epilogue_safe(
              C,
              addmm_params->ldc,
              addmm_params->fdc,
              short2(tgp_bn, tgp_bm),
              epilogue_op_add);
        }
      }

      // Store results to device memory
      return mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));

    } else if (align_M || tgp_bm == BM) {
      gemm_kernel::gemm_loop(
          As,
          Bs,
          gemm_k_iterations,
          loader_a,
          loader_b,
          mma_op,
          tgp_bm,
          tgp_bn,
          leftover_bk,
          LoopAlignment<true, false, true>{});

      // Do epilogue
      if (use_out_source) {
        if (do_axpby) {
          mma_op.apply_epilogue_safe(
              C,
              addmm_params->ldc,
              addmm_params->fdc,
              short2(tgp_bn, tgp_bm),
              epilogue_op_axpby);
        } else {
          mma_op.apply_epilogue_safe(
              C,
              addmm_params->ldc,
              addmm_params->fdc,
              short2(tgp_bn, tgp_bm),
              epilogue_op_add);
        }
      }

      // Store results to device memory
      return mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));

    } else {
      gemm_kernel::gemm_loop(
          As,
          Bs,
          gemm_k_iterations,
          loader_a,
          loader_b,
          mma_op,
          tgp_bm,
          tgp_bn,
          leftover_bk,
          LoopAlignment<false, false, true>{});

      // Do epilogue
      if (use_out_source) {
        if (do_axpby) {
          mma_op.apply_epilogue_safe(
              C,
              addmm_params->ldc,
              addmm_params->fdc,
              short2(tgp_bn, tgp_bm),
              epilogue_op_axpby);
        } else {
          mma_op.apply_epilogue_safe(
              C,
              addmm_params->ldc,
              addmm_params->fdc,
              short2(tgp_bn, tgp_bm),
              epilogue_op_add);
        }
      }

      // Store results to device memory
      return mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
    }
  }
}
@ -1,430 +1,12 @@
// Copyright © 2024 Apple Inc.

// clang-format off
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/utils.h"

#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
#include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h"

using namespace metal;
using namespace mlx::steel;

///////////////////////////////////////////////////////////////////////////////
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////

constant bool has_batch [[function_constant(10)]];

constant bool use_out_source [[function_constant(100)]];
constant bool do_axpby [[function_constant(110)]];

constant bool align_M [[function_constant(200)]];
constant bool align_N [[function_constant(201)]];
constant bool align_K [[function_constant(202)]];

constant bool do_gather [[function_constant(300)]];

constant bool gather_bias = do_gather && use_out_source;

// clang-format off
template <
    typename T,
    int BM,
    int BN,
    int BK,
    int WM,
    int WN,
    bool transpose_a,
    bool transpose_b,
    typename AccumType = float>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gemm(
    const device T* A [[buffer(0)]],
    const device T* B [[buffer(1)]],
    const device T* C [[buffer(2), function_constant(use_out_source)]],
    device T* D [[buffer(3)]],
    const constant GEMMParams* params [[buffer(4)]],
    const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]],
    const constant int* batch_shape [[buffer(6)]],
    const constant size_t* batch_strides [[buffer(7)]],
    const constant uint32_t* lhs_indices [[buffer(10), function_constant(do_gather)]],
    const constant uint32_t* rhs_indices [[buffer(11), function_constant(do_gather)]],
    const constant uint32_t* C_indices [[buffer(12), function_constant(gather_bias)]],
    const constant int* operand_shape [[buffer(13), function_constant(do_gather)]],
    const constant size_t* operand_strides [[buffer(14), function_constant(do_gather)]],
    const constant packed_int3& operand_batch_ndim [[buffer(15), function_constant(do_gather)]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]) { // clang-format on
  // Pacifying compiler
  (void)lid;

  using gemm_kernel = GEMMKernel<
      T,
      T,
      BM,
      BN,
      BK,
      WM,
      WN,
      transpose_a,
      transpose_b,
      true,
      true,
      AccumType>;

  using loader_a_t = typename gemm_kernel::loader_a_t;
  using loader_b_t = typename gemm_kernel::loader_b_t;
  using mma_t = typename gemm_kernel::mma_t;

  // Find block
  const int tid_y = ((tid.y) << params->swizzle_log) +
      ((tid.x) & ((1 << params->swizzle_log) - 1));
  const int tid_x = (tid.x) >> params->swizzle_log;

  // Exit early if out of bounds
  if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
    return;
  }

  // Adjust for batch

  // Handle gather
  if (do_gather) {
    // Read indices
    uint32_t indx_A, indx_B, indx_C;

    if (has_batch) {
      const constant size_t* indx_A_bstrides = batch_strides;
      const constant size_t* indx_B_bstrides =
          batch_strides + params->batch_ndim;

      ulong2 indx_offsets = elem_to_loc_broadcast(
          tid.z,
          batch_shape,
          indx_A_bstrides,
          indx_B_bstrides,
          params->batch_ndim);
      indx_A = lhs_indices[indx_offsets.x];
      indx_B = rhs_indices[indx_offsets.y];

      if (use_out_source) {
        const constant size_t* indx_C_bstrides =
            indx_B_bstrides + params->batch_ndim;
        auto indx_offset_C = elem_to_loc(
            tid.z, batch_shape, indx_C_bstrides, params->batch_ndim);
        indx_C = C_indices[indx_offset_C];
      }
    } else {
      indx_A = lhs_indices[params->batch_stride_a * tid.z];
      indx_B = rhs_indices[params->batch_stride_b * tid.z];

      if (use_out_source) {
        indx_C = C_indices[addmm_params->batch_stride_c * tid.z];
      }
    }

    // Translate indices to offsets
    int batch_ndim_A = operand_batch_ndim.x;
    const constant int* batch_shape_A = operand_shape;
    const constant size_t* batch_strides_A = operand_strides;
    A += elem_to_loc(indx_A, batch_shape_A, batch_strides_A, batch_ndim_A);

    int batch_ndim_B = operand_batch_ndim.y;
    const constant int* batch_shape_B = batch_shape_A + batch_ndim_A;
    const constant size_t* batch_strides_B = batch_strides_A + batch_ndim_A;
    B += elem_to_loc(indx_B, batch_shape_B, batch_strides_B, batch_ndim_B);

    if (use_out_source) {
      int batch_ndim_C = operand_batch_ndim.z;
      const constant int* batch_shape_C = batch_shape_B + batch_ndim_B;
      const constant size_t* batch_strides_C = batch_strides_B + batch_ndim_B;
      C += elem_to_loc(indx_C, batch_shape_C, batch_strides_C, batch_ndim_C);
    }

  }

  // Handle regular batch
  else {
    if (has_batch) {
      const constant size_t* A_bstrides = batch_strides;
      const constant size_t* B_bstrides = batch_strides + params->batch_ndim;

      ulong2 batch_offsets = elem_to_loc_broadcast(
          tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);

      A += batch_offsets.x;
      B += batch_offsets.y;

      if (use_out_source) {
        const constant size_t* C_bstrides = B_bstrides + params->batch_ndim;
        C += elem_to_loc(tid.z, batch_shape, C_bstrides, params->batch_ndim);
      }
    } else {
      A += params->batch_stride_a * tid.z;
      B += params->batch_stride_b * tid.z;

      if (use_out_source) {
        C += addmm_params->batch_stride_c * tid.z;
      }
    }
  }

  D += params->batch_stride_d * tid.z;

  // Prepare threadgroup memory
  threadgroup T As[gemm_kernel::tgp_mem_size_a];
  threadgroup T Bs[gemm_kernel::tgp_mem_size_b];

  threadgroup_barrier(mem_flags::mem_none);

  // Find block in A, B, C
  const int c_row = tid_y * BM;
  const int c_col = tid_x * BN;
  const size_t c_row_long = size_t(c_row);
  const size_t c_col_long = size_t(c_col);

  A += transpose_a ? c_row_long : c_row_long * params->lda;
  B += transpose_b ? c_col_long * params->ldb : c_col_long;
  D += c_row_long * params->ldd + c_col_long;

  if (use_out_source) {
    C += c_row_long * addmm_params->ldc + c_col_long * addmm_params->fdc;
  }

  // Prepare threadgroup mma operation
  thread mma_t mma_op(simd_group_id, simd_lane_id);

  // Prepare threadgroup loading operations
  thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
  thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);

  // Prepare threadgroup bounds
  const short tgp_bm = align_M ? BM : short(min(BM, params->M - c_row));
  const short tgp_bn = align_N ? BN : short(min(BN, params->N - c_col));

  // Prepare iterations
  int gemm_k_iterations = params->gemm_k_iterations_aligned;

  // Do unaligned K iterations first
  if (!align_K) {
    const int k_last = params->gemm_k_iterations_aligned * BK;
    const int k_remain = params->K - k_last;
    const size_t k_jump_a =
        transpose_a ? params->lda * size_t(k_last) : size_t(k_last);
    const size_t k_jump_b =
        transpose_b ? size_t(k_last) : params->ldb * size_t(k_last);

    // Move loader source ahead to end
    loader_a.src += k_jump_a;
    loader_b.src += k_jump_b;

    // Load tile
    const short2 tile_dims_A =
        transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
    const short2 tile_dims_B =
        transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);

    loader_a.load_safe(tile_dims_A);
    loader_b.load_safe(tile_dims_B);

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Do matmul
    mma_op.mma(As, Bs);

    // Reset source back to start
    loader_a.src -= k_jump_a;
    loader_b.src -= k_jump_b;
  }

  const TransformAdd<AccumType, AccumType> epilogue_op_add(
      addmm_params->alpha, addmm_params->beta);
  const TransformAxpby<AccumType, AccumType> epilogue_op_axpby(
      addmm_params->alpha, addmm_params->beta);

  ///////////////////////////////////////////////////////////////////////////////
  // MNK aligned loop
  if (align_M && align_N) {
    // Do gemm
    for (int k = 0; k < gemm_k_iterations; k++) {
      threadgroup_barrier(mem_flags::mem_threadgroup);
      // Load elements into threadgroup
      loader_a.load_unsafe();
      loader_b.load_unsafe();

      threadgroup_barrier(mem_flags::mem_threadgroup);

      // Multiply and accumulate threadgroup elements
      mma_op.mma(As, Bs);

      // Prepare for next iteration
      loader_a.next();
      loader_b.next();
    }

    threadgroup_barrier(mem_flags::mem_none);

    // Do epilogue
    if (use_out_source) {
      if (do_axpby) {
        mma_op.apply_epilogue(
            C, addmm_params->ldc, addmm_params->fdc, epilogue_op_axpby);
      } else {
        mma_op.apply_epilogue(
            C, addmm_params->ldc, addmm_params->fdc, epilogue_op_add);
      }
    }

    // Store results to device memory
    return mma_op.store_result(D, params->ldd);

  }
  ///////////////////////////////////////////////////////////////////////////////
  // MN unaligned loop
  else { // Loop over K - unaligned case
    const int leftover_bk = 0;

    if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) {
      // Do gemm
      gemm_kernel::gemm_loop(
          As,
          Bs,
          gemm_k_iterations,
          loader_a,
          loader_b,
          mma_op,
          tgp_bm,
          tgp_bn,
          leftover_bk,
          LoopAlignment<true, true, true>{});

      // Do epilogue
      if (use_out_source) {
        if (do_axpby) {
          mma_op.apply_epilogue(
              C, addmm_params->ldc, addmm_params->fdc, epilogue_op_axpby);
        } else {
          mma_op.apply_epilogue(
              C, addmm_params->ldc, addmm_params->fdc, epilogue_op_add);
        }
      }

      // Store results to device memory
      return mma_op.store_result(D, params->ldd);

    } else if (align_N || tgp_bn == BN) {
      gemm_kernel::gemm_loop(
          As,
          Bs,
          gemm_k_iterations,
          loader_a,
          loader_b,
          mma_op,
          tgp_bm,
          tgp_bn,
          leftover_bk,
          LoopAlignment<false, true, true>{});

      // Do epilogue
      if (use_out_source) {
        if (do_axpby) {
          mma_op.apply_epilogue_safe(
              C,
              addmm_params->ldc,
              addmm_params->fdc,
              short2(tgp_bn, tgp_bm),
              epilogue_op_axpby);
        } else {
          mma_op.apply_epilogue_safe(
              C,
              addmm_params->ldc,
              addmm_params->fdc,
              short2(tgp_bn, tgp_bm),
              epilogue_op_add);
        }
      }

      // Store results to device memory
      return mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));

    } else if (align_M || tgp_bm == BM) {
      gemm_kernel::gemm_loop(
          As,
          Bs,
          gemm_k_iterations,
          loader_a,
          loader_b,
          mma_op,
          tgp_bm,
          tgp_bn,
          leftover_bk,
          LoopAlignment<true, false, true>{});

      // Do epilogue
      if (use_out_source) {
        if (do_axpby) {
          mma_op.apply_epilogue_safe(
              C,
              addmm_params->ldc,
              addmm_params->fdc,
              short2(tgp_bn, tgp_bm),
              epilogue_op_axpby);
        } else {
          mma_op.apply_epilogue_safe(
              C,
              addmm_params->ldc,
              addmm_params->fdc,
              short2(tgp_bn, tgp_bm),
              epilogue_op_add);
        }
      }

      // Store results to device memory
      return mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));

    } else {
      gemm_kernel::gemm_loop(
          As,
          Bs,
          gemm_k_iterations,
          loader_a,
          loader_b,
          mma_op,
          tgp_bm,
          tgp_bn,
          leftover_bk,
          LoopAlignment<false, false, true>{});

      // Do epilogue
      if (use_out_source) {
        if (do_axpby) {
          mma_op.apply_epilogue_safe(
              C,
              addmm_params->ldc,
              addmm_params->fdc,
              short2(tgp_bn, tgp_bm),
              epilogue_op_axpby);
        } else {
          mma_op.apply_epilogue_safe(
              C,
              addmm_params->ldc,
              addmm_params->fdc,
              short2(tgp_bn, tgp_bm),
              epilogue_op_add);
        }
      }

      // Store results to device memory
      return mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
    }
  }
}

///////////////////////////////////////////////////////////////////////////////
// GEMM kernel initializations
///////////////////////////////////////////////////////////////////////////////

// clang-format off
#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  template [[host_name("steel_gemm_fused_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn)]] \
  [[kernel]] void gemm<itype, bm, bn, bk, wm, wn, trans_a, trans_b, float>( \
@ -445,24 +27,23 @@ template <
      uint simd_lane_id [[thread_index_in_simdgroup]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]], \
      uint3 tid [[threadgroup_position_in_grid]], \
      uint3 lid [[thread_position_in_threadgroup]]); // clang-format on
      uint3 lid [[thread_position_in_threadgroup]]);

// clang-format off
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
    instantiate_gemm(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
    instantiate_gemm(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
    instantiate_gemm(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
    instantiate_gemm(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) // clang-format on
    instantiate_gemm(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)

// clang-format off
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
    instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
    instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \
    instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
    instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 16, 2, 2) \
    instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 2, 2) // clang-format on
    instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 2, 2)

instantiate_gemm_shapes_helper(float16, half, float16, half);
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);

instantiate_gemm_shapes_helper(float32, float, float32, float);
// clang-format on
719 mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h Normal file
@ -0,0 +1,719 @@
// Copyright © 2024 Apple Inc.

#include "mlx/backend/metal/kernels/steel/defines.h"
using namespace metal;
using namespace mlx::steel;

///////////////////////////////////////////////////////////////////////////////
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////

struct _NoMask {
  char x;

  constexpr METAL_FUNC operator bool() {
    return true;
  }
  constexpr METAL_FUNC operator bool() const threadgroup {
    return true;
  }
  constexpr METAL_FUNC operator bool() const device {
    return true;
  }
  constexpr METAL_FUNC operator bool() const constant {
    return true;
  }
};

template <typename OutT, typename InT = OutT>
struct ScaleOp {
  OutT scale;

  METAL_FUNC OutT apply(InT x) const {
    return static_cast<OutT>(x) * scale;
  }
};

typedef struct _NoMask nomask_t;

template <
    typename T,
    typename out_mask_t,
    typename op_mask_t,
    int BM,
    int BN,
    int BK,
    int WM,
    int WN,
    bool transpose_a,
    bool transpose_b,
    bool MN_aligned,
    bool K_aligned>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void
block_masked_gemm(
    const device T* A [[buffer(0)]],
    const device T* B [[buffer(1)]],
    device T* D [[buffer(3)]],
    const constant GEMMParams* params [[buffer(4)]],
    const constant int* batch_shape [[buffer(6)]],
    const constant size_t* batch_strides [[buffer(7)]],
    const device out_mask_t* out_mask [[buffer(10)]],
    const device op_mask_t* lhs_mask [[buffer(11)]],
    const device op_mask_t* rhs_mask [[buffer(12)]],
    const constant int* mask_strides [[buffer(13)]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]) {
  // Appease the compiler
  (void)lid;

  static_assert(
      BM == BN,
      "block_masked_gemm must have the same block M and block N size");
  static_assert(BM % BK == 0, "block_masked_gemm must have BM % BK == 0");

  constexpr bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>;
  constexpr bool has_output_mask = !metal::is_same_v<out_mask_t, nomask_t>;

  constexpr bool has_mul_operand_mask =
      has_operand_mask && !metal::is_same_v<op_mask_t, bool>;
  constexpr bool has_mul_output_mask =
      has_output_mask && !metal::is_same_v<out_mask_t, bool>;

  constexpr short k_mask_factor = short(BM / BK);
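  // Masks are given per BM x BM block of the output (BM == BN is enforced
  // above), so along K the mask index advances once every k_mask_factor
  // (= BM / BK) BK-sized iterations.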

  using gemm_kernel = GEMMKernel<
      T,
      T,
      BM,
      BN,
      BK,
      WM,
      WN,
      transpose_a,
      transpose_b,
      MN_aligned,
      K_aligned>;

  const int tid_y = ((tid.y) << params->swizzle_log) +
      ((tid.x) & ((1 << params->swizzle_log) - 1));
  const int tid_x = (tid.x) >> params->swizzle_log;

  if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
    return;
  }

  const constant size_t* mask_batch_strides =
      batch_strides + 2 * params->batch_ndim;

  if (params->batch_ndim > 1) {
    if (has_output_mask) {
      out_mask += elem_to_loc(
          tid.z, batch_shape, mask_batch_strides, params->batch_ndim);

      mask_batch_strides += params->batch_ndim;
    }

    if (has_operand_mask) {
      const constant size_t* mask_strides_lhs = mask_batch_strides;
      const constant size_t* mask_strides_rhs =
          mask_strides_lhs + params->batch_ndim;

      ulong2 batch_offsets = elem_to_loc_broadcast(
          tid.z,
          batch_shape,
          mask_strides_lhs,
          mask_strides_rhs,
          params->batch_ndim);

      lhs_mask += batch_offsets.x;
      rhs_mask += batch_offsets.y;
    }
  } else {
    if (has_output_mask) {
      out_mask += tid.z * mask_batch_strides[0];
      mask_batch_strides += params->batch_ndim;
    }

    if (has_operand_mask) {
      lhs_mask += tid.z * mask_batch_strides[0];
      rhs_mask += tid.z * mask_batch_strides[params->batch_ndim];
    }
  }

  // Adjust for batch
  if (params->batch_ndim > 1) {
    const constant size_t* A_bstrides = batch_strides;
    const constant size_t* B_bstrides = batch_strides + params->batch_ndim;

    ulong2 batch_offsets = elem_to_loc_broadcast(
        tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);

    A += batch_offsets.x;
    B += batch_offsets.y;

  } else {
    A += params->batch_stride_a * tid.z;
    B += params->batch_stride_b * tid.z;
  }

  D += params->batch_stride_d * tid.z;

  // Find block in A, B, C
  const int c_row = tid_y * BM;
  const int c_col = tid_x * BN;
  const size_t c_row_long = size_t(c_row);
  const size_t c_col_long = size_t(c_col);

  A += transpose_a ? c_row_long : c_row_long * params->lda;
  B += transpose_b ? c_col_long * params->ldb : c_col_long;
  D += c_row_long * params->ldd + c_col_long;

  const constant int* out_mask_strides = mask_strides;
  const constant int* lhs_mask_strides =
      mask_strides + (has_output_mask ? 2 : 0);
  const constant int* rhs_mask_strides =
      lhs_mask_strides + (has_operand_mask ? 2 : 0);
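  // mask_strides holds an (x, y) stride pair for each mask that is present:
  // out_mask first, then lhs_mask, then rhs_mask; absent masks contribute
  // zero offsets and steps so the arithmetic below stays branch-free.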

  const int out_mask_offset = !has_output_mask
      ? 0
      : tid_y * out_mask_strides[1] + tid_x * out_mask_strides[0];
  int lhs_mask_offset = !has_operand_mask ? 0 : tid_y * lhs_mask_strides[1];
  int rhs_mask_offset = !has_operand_mask ? 0 : tid_x * rhs_mask_strides[0];
  const int lhs_mask_step = !has_operand_mask ? 0 : lhs_mask_strides[0];
  const int rhs_mask_step = !has_operand_mask ? 0 : rhs_mask_strides[1];
  short k_factor_cnt = k_mask_factor;

  ScaleOp<float> out_mask_op;
  ScaleOp<T> lhs_mask_op;
  ScaleOp<T> rhs_mask_op;

  if (has_output_mask) {
    auto mask_out = out_mask[out_mask_offset];

    if (has_mul_output_mask) {
      out_mask_op.scale = float(mask_out);
    }

    // Write zeros and return
    if (!mask_out) {
      constexpr short tgp_size = WM * WN * 32;
      constexpr short vec_size = 4;

      // Tile threads in threadgroup
      constexpr short TN = BN / vec_size;
      constexpr short TM = tgp_size / TN;
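      // Each thread zeroes vec_size contiguous elements per row: TN threads
      // cover one BN-wide row, and TM rows are written per step of ti below.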

      const short thread_idx = simd_group_id * 32 + simd_lane_id;
      const short bi = thread_idx / TN;
      const short bj = vec_size * (thread_idx % TN);

      D += bi * params->ldd + bj;

      short tgp_bm = min(BM, params->M - c_row);
      short tgp_bn = min(BN, params->N - c_col);

      if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
        for (short ti = 0; ti < BM; ti += TM) {
          STEEL_PRAGMA_UNROLL
          for (short j = 0; j < vec_size; j++) {
            D[ti * params->ldd + j] = T(0.);
          }
        }
      } else {
        short jmax = tgp_bn - bj;
        jmax = jmax < vec_size ? jmax : vec_size;
        for (short ti = 0; (bi + ti) < tgp_bm; ti += TM) {
          for (short j = 0; j < jmax; j++) {
            D[ti * params->ldd + j] = T(0.);
          }
        }
      }

      return;
    }
  }

  threadgroup_barrier(mem_flags::mem_none);

  // Prepare threadgroup mma operation
  thread typename gemm_kernel::mma_t mma_op(simd_group_id, simd_lane_id);

  threadgroup T As[gemm_kernel::tgp_mem_size_a];
  threadgroup T Bs[gemm_kernel::tgp_mem_size_b];

  // Prepare threadgroup loading operations
  thread typename gemm_kernel::loader_a_t loader_a(
      A, params->lda, As, simd_group_id, simd_lane_id);
  thread typename gemm_kernel::loader_b_t loader_b(
      B, params->ldb, Bs, simd_group_id, simd_lane_id);

  // Prepare threadgroup bounds
  const short tgp_bm =
      MN_aligned ? short(BM) : short(min(BM, params->M - c_row));
  const short tgp_bn =
      MN_aligned ? short(BN) : short(min(BN, params->N - c_col));

  int gemm_k_iterations = params->gemm_k_iterations_aligned;

  ///////////////////////////////////////////////////////////////////////////////
  // Do unaligned K iterations first
  if (!K_aligned) {
    const int k_last = params->gemm_k_iterations_aligned * BK;
    const int mask_idx_last = k_last / BM;

    if (!has_operand_mask ||
        (bool(lhs_mask[lhs_mask_offset + mask_idx_last * lhs_mask_step]) &&
         bool(rhs_mask[rhs_mask_offset + mask_idx_last * rhs_mask_step]))) {
      if (has_mul_operand_mask) {
        lhs_mask_op.scale =
            lhs_mask[lhs_mask_offset + mask_idx_last * lhs_mask_step];
        rhs_mask_op.scale =
            rhs_mask[rhs_mask_offset + mask_idx_last * rhs_mask_step];
      }

      // Move loader source ahead to end
      const int k_remain = params->K - k_last;
      const size_t k_jump_a =
          transpose_a ? params->lda * size_t(k_last) : size_t(k_last);
      const size_t k_jump_b =
          transpose_b ? size_t(k_last) : params->ldb * size_t(k_last);

      loader_a.src += k_jump_a;
      loader_b.src += k_jump_b;

      // Load tile
      const short2 tile_dims_A =
          transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
      const short2 tile_dims_B =
          transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);

      loader_a.load_safe(tile_dims_A);
      loader_b.load_safe(tile_dims_B);

      if (has_mul_operand_mask) {
        loader_a.apply_inplace_op(lhs_mask_op);
        loader_b.apply_inplace_op(rhs_mask_op);
      }

      threadgroup_barrier(mem_flags::mem_threadgroup);

      // Do matmul
      mma_op.mma(As, Bs);

      // Reset source back to start
      loader_a.src -= k_jump_a;
      loader_b.src -= k_jump_b;
    }
  }

  ///////////////////////////////////////////////////////////////////////////////
  // MNK aligned loop
  if (MN_aligned) {
    for (; gemm_k_iterations > 0; gemm_k_iterations--) {
      threadgroup_barrier(mem_flags::mem_threadgroup);

      if (!has_operand_mask ||
          (bool(lhs_mask[lhs_mask_offset]) &&
           bool(rhs_mask[rhs_mask_offset]))) {
        if (has_mul_operand_mask) {
          lhs_mask_op.scale = lhs_mask[lhs_mask_offset];
          rhs_mask_op.scale = rhs_mask[rhs_mask_offset];
        }

        // Load elements into threadgroup
        loader_a.load_unsafe();
        loader_b.load_unsafe();

        if (has_mul_operand_mask) {
          loader_a.apply_inplace_op(lhs_mask_op);
          loader_b.apply_inplace_op(rhs_mask_op);
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Multiply and accumulate threadgroup elements
        mma_op.mma(As, Bs);
      }

      // Prepare for next iteration
      loader_a.next();
      loader_b.next();
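      // Advance the mask offsets only once every k_mask_factor BK-sized
      // iterations, i.e. once per BM-sized step along K.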
      k_factor_cnt--;
      lhs_mask_offset += k_factor_cnt == 0 ? lhs_mask_step : 0;
      rhs_mask_offset += k_factor_cnt == 0 ? rhs_mask_step : 0;
      k_factor_cnt = k_factor_cnt == 0 ? k_mask_factor : k_factor_cnt;
    }

    if (has_mul_output_mask) {
      mma_op.apply_epilogue(out_mask_op);
    }

    // Store results to device memory
    mma_op.store_result(D, params->ldd);
    return;

  }
  ///////////////////////////////////////////////////////////////////////////////
  // MN unaligned loop
  else {
    const bool M_aligned = (tgp_bm == BM);
    const bool N_aligned = (tgp_bn == BN);

    const short2 tile_dims_A =
        transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm);
    const short2 tile_dims_B =
        transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK);

    for (; gemm_k_iterations > 0; gemm_k_iterations--) {
      threadgroup_barrier(mem_flags::mem_threadgroup);
      if (!has_operand_mask ||
          (bool(lhs_mask[lhs_mask_offset]) &&
           bool(rhs_mask[rhs_mask_offset]))) {
        if (has_mul_operand_mask) {
          lhs_mask_op.scale = lhs_mask[lhs_mask_offset];
          rhs_mask_op.scale = rhs_mask[rhs_mask_offset];
        }

        // Load elements into threadgroup
        if (M_aligned) {
          loader_a.load_unsafe();
        } else {
          loader_a.load_safe(tile_dims_A);
        }

        if (N_aligned) {
          loader_b.load_unsafe();
        } else {
          loader_b.load_safe(tile_dims_B);
        }

        if (has_mul_operand_mask) {
          loader_a.apply_inplace_op(lhs_mask_op);
          loader_b.apply_inplace_op(rhs_mask_op);
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Multiply and accumulate threadgroup elements
        mma_op.mma(As, Bs);
      }

      // Prepare for next iteration
      loader_a.next();
      loader_b.next();

      k_factor_cnt--;
      lhs_mask_offset += k_factor_cnt == 0 ? lhs_mask_step : 0;
      rhs_mask_offset += k_factor_cnt == 0 ? rhs_mask_step : 0;
      k_factor_cnt = k_factor_cnt == 0 ? k_mask_factor : k_factor_cnt;
    }

    if (has_mul_output_mask) {
      mma_op.apply_epilogue(out_mask_op);
    }

    if (M_aligned && N_aligned) {
      mma_op.store_result(D, params->ldd);
    } else {
      mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
    }
  }
}
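// The variant below takes plain bool masks and selects operand masking with
// a compile-time flag, rather than the mask-type parameters used above.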

template <
    typename T,
    int BM,
    int BN,
    int BK,
    int WM,
    int WN,
    bool transpose_a,
    bool transpose_b,
    bool MN_aligned,
    bool K_aligned,
    bool has_operand_mask = false>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void
block_masked_gemm(
    const device T* A [[buffer(0)]],
    const device T* B [[buffer(1)]],
    device T* D [[buffer(3)]],
    const constant GEMMParams* params [[buffer(4)]],
    const constant int* batch_shape [[buffer(6)]],
    const constant size_t* batch_strides [[buffer(7)]],
    const device bool* out_mask [[buffer(10)]],
    const device bool* lhs_mask [[buffer(11)]],
    const device bool* rhs_mask [[buffer(12)]],
    const constant int* mask_strides [[buffer(13)]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]) {
  // Appease the compiler
  (void)lid;

  using gemm_kernel = GEMMKernel<
      T,
      T,
      BM,
      BN,
      BK,
      WM,
      WN,
      transpose_a,
      transpose_b,
      MN_aligned,
      K_aligned>;

  const int tid_y = ((tid.y) << params->swizzle_log) +
      ((tid.x) & ((1 << params->swizzle_log) - 1));
  const int tid_x = (tid.x) >> params->swizzle_log;

  if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
    return;
  }

  if (params->batch_ndim > 1) {
    const constant size_t* mask_batch_strides =
        batch_strides + 2 * params->batch_ndim;
    out_mask +=
        elem_to_loc(tid.z, batch_shape, mask_batch_strides, params->batch_ndim);

    if (has_operand_mask) {
      const constant size_t* mask_strides_lhs =
          mask_batch_strides + params->batch_ndim;
      const constant size_t* mask_strides_rhs =
          mask_strides_lhs + params->batch_ndim;

      ulong2 batch_offsets = elem_to_loc_broadcast(
          tid.z,
          batch_shape,
          mask_strides_lhs,
          mask_strides_rhs,
          params->batch_ndim);

      lhs_mask += batch_offsets.x;
      rhs_mask += batch_offsets.y;
    }
  } else {
    out_mask += tid.z * batch_strides[2 * params->batch_ndim];
    if (has_operand_mask) {
      lhs_mask += tid.z * batch_strides[3 * params->batch_ndim];
      rhs_mask += tid.z * batch_strides[4 * params->batch_ndim];
    }
  }

  // Adjust for batch
  if (params->batch_ndim > 1) {
    const constant size_t* A_bstrides = batch_strides;
    const constant size_t* B_bstrides = batch_strides + params->batch_ndim;

    ulong2 batch_offsets = elem_to_loc_broadcast(
        tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);

    A += batch_offsets.x;
    B += batch_offsets.y;

  } else {
    A += params->batch_stride_a * tid.z;
    B += params->batch_stride_b * tid.z;
  }

  D += params->batch_stride_d * tid.z;

  // Find block in A, B, C
  const int c_row = tid_y * BM;
  const int c_col = tid_x * BN;
  const size_t c_row_long = size_t(c_row);
  const size_t c_col_long = size_t(c_col);

  A += transpose_a ? c_row_long : c_row_long * params->lda;
  B += transpose_b ? c_col_long * params->ldb : c_col_long;
  D += c_row_long * params->ldd + c_col_long;

  bool mask_out = out_mask[tid_y * mask_strides[1] + tid_x * mask_strides[0]];

  // Write zeros and return
  if (!mask_out) {
    constexpr short tgp_size = WM * WN * 32;
    constexpr short vec_size = 4;

    // Tile threads in threadgroup
    constexpr short TN = BN / vec_size;
    constexpr short TM = tgp_size / TN;

    const short thread_idx = simd_group_id * 32 + simd_lane_id;
    const short bi = thread_idx / TN;
    const short bj = vec_size * (thread_idx % TN);

    D += bi * params->ldd + bj;

    short tgp_bm = min(BM, params->M - c_row);
    short tgp_bn = min(BN, params->N - c_col);

    if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
      for (short ti = 0; ti < BM; ti += TM) {
        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < vec_size; j++) {
          D[ti * params->ldd + j] = T(0.);
        }
      }
    } else {
      short jmax = tgp_bn - bj;
      jmax = jmax < vec_size ? jmax : vec_size;
      for (short ti = 0; (bi + ti) < tgp_bm; ti += TM) {
        for (short j = 0; j < jmax; j++) {
          D[ti * params->ldd + j] = T(0.);
        }
      }
    }

    return;
  }

  threadgroup_barrier(mem_flags::mem_none);

  // Prepare threadgroup mma operation
  thread typename gemm_kernel::mma_t mma_op(simd_group_id, simd_lane_id);

  int gemm_k_iterations = params->gemm_k_iterations_aligned;

  threadgroup T As[gemm_kernel::tgp_mem_size_a];
  threadgroup T Bs[gemm_kernel::tgp_mem_size_b];

  // Prepare threadgroup loading operations
  thread typename gemm_kernel::loader_a_t loader_a(
      A, params->lda, As, simd_group_id, simd_lane_id);
  thread typename gemm_kernel::loader_b_t loader_b(
      B, params->ldb, Bs, simd_group_id, simd_lane_id);

  ///////////////////////////////////////////////////////////////////////////////
  // MNK aligned loop
  if (MN_aligned) {
    for (int k = 0; k < gemm_k_iterations; k++) {
      threadgroup_barrier(mem_flags::mem_threadgroup);

      if (!has_operand_mask ||
          (lhs_mask
               [tid_y * mask_strides[3] + ((k * BK) / BM) * mask_strides[2]] &&
           rhs_mask
               [((k * BK) / BM) * mask_strides[5] + tid_x * mask_strides[4]])) {
        // Load elements into threadgroup
        loader_a.load_unsafe();
        loader_b.load_unsafe();

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Multiply and accumulate threadgroup elements
        mma_op.mma(As, Bs);
      }

      // Prepare for next iteration
      loader_a.next();
      loader_b.next();
    }

    threadgroup_barrier(mem_flags::mem_none);

    // Loop tail
    if (!K_aligned) {
      if (!has_operand_mask ||
          (lhs_mask
               [tid_y * mask_strides[3] + (params->K / BM) * mask_strides[2]] &&
           rhs_mask
               [(params->K / BM) * mask_strides[5] +
                tid_x * mask_strides[4]])) {
        int lbk = params->K - params->gemm_k_iterations_aligned * BK;
        short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);
        short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk);

        loader_a.load_safe(tile_dims_A);
        loader_b.load_safe(tile_dims_B);

        threadgroup_barrier(mem_flags::mem_threadgroup);

        mma_op.mma(As, Bs);
      }
    }

    // Store results to device memory
    mma_op.store_result(D, params->ldd);
    return;

  }
  ///////////////////////////////////////////////////////////////////////////////
  // MN unaligned loop
  else { // Loop over K - unaligned case
    short tgp_bm = min(BM, params->M - c_row);
    short tgp_bn = min(BN, params->N - c_col);
    short lbk = params->K - params->gemm_k_iterations_aligned * BK;

    bool M_aligned = (tgp_bm == BM);
    bool N_aligned = (tgp_bn == BN);

    short2 tile_dims_A = transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm);
    short2 tile_dims_B = transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK);

    for (int k = 0; k < gemm_k_iterations; k++) {
      threadgroup_barrier(mem_flags::mem_threadgroup);
      if (!has_operand_mask ||
          (lhs_mask
               [tid_y * mask_strides[3] + ((k * BK) / BM) * mask_strides[2]] &&
           rhs_mask
               [((k * BK) / BM) * mask_strides[5] + tid_x * mask_strides[4]])) {
        // Load elements into threadgroup
        if (M_aligned) {
          loader_a.load_unsafe();
        } else {
          loader_a.load_safe(tile_dims_A);
        }

        if (N_aligned) {
          loader_b.load_unsafe();
        } else {
          loader_b.load_safe(tile_dims_B);
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Multiply and accumulate threadgroup elements
        mma_op.mma(As, Bs);
      }

      // Prepare for next iteration
      loader_a.next();
      loader_b.next();
    }

    if (!K_aligned) {
      threadgroup_barrier(mem_flags::mem_threadgroup);

      if (!has_operand_mask ||
          (lhs_mask
               [tid_y * mask_strides[3] + (params->K / BM) * mask_strides[2]] &&
           rhs_mask
               [(params->K / BM) * mask_strides[5] +
                tid_x * mask_strides[4]])) {
        short2 tile_dims_A_last =
            transpose_a ? short2(tgp_bm, lbk) : short2(lbk, tgp_bm);
        short2 tile_dims_B_last =
            transpose_b ? short2(lbk, tgp_bn) : short2(tgp_bn, lbk);

        loader_a.load_safe(tile_dims_A_last);
        loader_b.load_safe(tile_dims_B_last);

        threadgroup_barrier(mem_flags::mem_threadgroup);

        mma_op.mma(As, Bs);
      }
    }

    if (M_aligned && N_aligned) {
      mma_op.store_result(D, params->ldd);
    } else {
      mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
    }
  }
}
@ -1,434 +1,10 @@
// Copyright © 2024 Apple Inc.
|
||||
|
||||
// clang-format off
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
|
||||
using namespace metal;
|
||||
using namespace mlx::steel;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// GEMM kernels
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct _NoMask {
|
||||
char x;
|
||||
|
||||
constexpr METAL_FUNC operator bool() {
|
||||
return true;
|
||||
}
|
||||
constexpr METAL_FUNC operator bool() const threadgroup {
|
||||
return true;
|
||||
}
|
||||
constexpr METAL_FUNC operator bool() const device {
|
||||
return true;
|
||||
}
|
||||
constexpr METAL_FUNC operator bool() const constant {
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename OutT, typename InT = OutT>
|
||||
struct ScaleOp {
|
||||
OutT scale;
|
||||
|
||||
METAL_FUNC OutT apply(InT x) const {
|
||||
return static_cast<OutT>(x) * scale;
|
||||
}
|
||||
};
|
||||
|
||||
typedef struct _NoMask nomask_t;
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename out_mask_t,
|
||||
typename op_mask_t,
|
||||
int BM,
|
||||
int BN,
|
||||
int BK,
|
||||
int WM,
|
||||
int WN,
|
||||
bool transpose_a,
|
||||
bool transpose_b,
|
||||
bool MN_aligned,
|
||||
bool K_aligned>
|
||||
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void
|
||||
block_masked_gemm(
|
||||
const device T* A [[buffer(0)]],
|
||||
const device T* B [[buffer(1)]],
|
||||
device T* D [[buffer(3)]],
|
||||
const constant GEMMParams* params [[buffer(4)]],
|
||||
const constant int* batch_shape [[buffer(6)]],
|
||||
const constant size_t* batch_strides [[buffer(7)]],
|
||||
const device out_mask_t* out_mask [[buffer(10)]],
|
||||
const device op_mask_t* lhs_mask [[buffer(11)]],
|
||||
const device op_mask_t* rhs_mask [[buffer(12)]],
|
||||
const constant int* mask_strides [[buffer(13)]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||
// Appease the compiler
|
||||
(void)lid;
|
||||
|
||||
static_assert(
|
||||
BM == BN,
|
||||
"block_masked_gemm must have the same block M and block N size");
|
||||
static_assert(BM % BK == 0, "block_masked_gemm must have BM % BK == 0");
|
||||
|
||||
constexpr bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>;
|
||||
constexpr bool has_output_mask = !metal::is_same_v<out_mask_t, nomask_t>;
|
||||
|
||||
constexpr bool has_mul_operand_mask =
|
||||
has_operand_mask && !metal::is_same_v<op_mask_t, bool>;
|
||||
constexpr bool has_mul_output_mask =
|
||||
has_output_mask && !metal::is_same_v<out_mask_t, bool>;
|
||||
|
||||
constexpr short k_mask_factor = short(BM / BK);
|
||||
|
||||
using gemm_kernel = GEMMKernel<
|
||||
T,
|
||||
T,
|
||||
BM,
|
||||
BN,
|
||||
BK,
|
||||
WM,
|
||||
WN,
|
||||
transpose_a,
|
||||
transpose_b,
|
||||
MN_aligned,
|
||||
K_aligned>;
|
||||
|
||||
const int tid_y = ((tid.y) << params->swizzle_log) +
|
||||
((tid.x) & ((1 << params->swizzle_log) - 1));
|
||||
const int tid_x = (tid.x) >> params->swizzle_log;
|
||||
|
||||
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
|
||||
return;
|
||||
}
|
||||
|
||||
const constant size_t* mask_batch_strides =
|
||||
batch_strides + 2 * params->batch_ndim;
|
||||
|
||||
if (params->batch_ndim > 1) {
|
||||
if (has_output_mask) {
|
||||
out_mask += elem_to_loc(
|
||||
tid.z, batch_shape, mask_batch_strides, params->batch_ndim);
|
||||
|
||||
mask_batch_strides += params->batch_ndim;
|
||||
}
|
||||
|
||||
if (has_operand_mask) {
|
||||
const constant size_t* mask_strides_lhs = mask_batch_strides;
|
||||
const constant size_t* mask_strides_rhs =
|
||||
mask_strides_lhs + params->batch_ndim;
|
||||
|
||||
ulong2 batch_offsets = elem_to_loc_broadcast(
|
||||
tid.z,
|
||||
batch_shape,
|
||||
mask_strides_lhs,
|
||||
mask_strides_rhs,
|
||||
params->batch_ndim);
|
||||
|
||||
lhs_mask += batch_offsets.x;
|
||||
rhs_mask += batch_offsets.y;
|
||||
}
|
||||
} else {
|
||||
if (has_output_mask) {
|
||||
out_mask += tid.z * mask_batch_strides[0];
|
||||
mask_batch_strides += params->batch_ndim;
|
||||
}
|
||||
|
||||
if (has_operand_mask) {
|
||||
lhs_mask += tid.z * mask_batch_strides[0];
|
||||
rhs_mask += tid.z * mask_batch_strides[params->batch_ndim];
|
||||
}
|
||||
}
|
||||
|
||||
// Adjust for batch
|
||||
if (params->batch_ndim > 1) {
|
||||
const constant size_t* A_bstrides = batch_strides;
|
||||
const constant size_t* B_bstrides = batch_strides + params->batch_ndim;
|
||||
|
||||
ulong2 batch_offsets = elem_to_loc_broadcast(
|
||||
tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);
|
||||
|
||||
A += batch_offsets.x;
|
||||
B += batch_offsets.y;
|
||||
|
||||
} else {
|
||||
A += params->batch_stride_a * tid.z;
|
||||
B += params->batch_stride_b * tid.z;
|
||||
}
|
||||
|
||||
D += params->batch_stride_d * tid.z;
|
||||
|
||||
// Find block in A, B, C
|
||||
const int c_row = tid_y * BM;
|
||||
const int c_col = tid_x * BN;
|
||||
const size_t c_row_long = size_t(c_row);
|
||||
const size_t c_col_long = size_t(c_col);
|
||||
|
||||
A += transpose_a ? c_row_long : c_row_long * params->lda;
|
||||
B += transpose_b ? c_col_long * params->ldb : c_col_long;
|
||||
D += c_row_long * params->ldd + c_col_long;
|
||||
|
||||
const constant int* out_mask_strides = mask_strides;
|
||||
const constant int* lhs_mask_strides =
|
||||
mask_strides + (has_output_mask ? 2 : 0);
|
||||
const constant int* rhs_mask_strides =
|
||||
lhs_mask_strides + (has_operand_mask ? 2 : 0);
|
||||
|
||||
const int out_mask_offset = !has_output_mask
|
||||
? 0
|
||||
: tid_y * out_mask_strides[1] + tid_x * out_mask_strides[0];
|
||||
int lhs_mask_offset = !has_operand_mask ? 0 : tid_y * lhs_mask_strides[1];
|
||||
int rhs_mask_offset = !has_operand_mask ? 0 : tid_x * rhs_mask_strides[0];
|
||||
const int lhs_mask_step = !has_operand_mask ? 0 : lhs_mask_strides[0];
|
||||
const int rhs_mask_step = !has_operand_mask ? 0 : rhs_mask_strides[1];
|
||||
short k_factor_cnt = k_mask_factor;
|
||||
|
||||
ScaleOp<float> out_mask_op;
|
||||
ScaleOp<T> lhs_mask_op;
|
||||
ScaleOp<T> rhs_mask_op;
|
||||
|
||||
if (has_output_mask) {
|
||||
auto mask_out = out_mask[out_mask_offset];
|
||||
|
||||
if (has_mul_output_mask) {
|
||||
out_mask_op.scale = float(mask_out);
|
||||
}
|
||||
|
||||
// Write zeros and return
|
||||
if (!mask_out) {
|
||||
constexpr short tgp_size = WM * WN * 32;
|
||||
constexpr short vec_size = 4;
|
||||
|
||||
// Tile threads in threadgroup
|
||||
constexpr short TN = BN / vec_size;
|
||||
constexpr short TM = tgp_size / TN;
|
||||
|
||||
const short thread_idx = simd_group_id * 32 + simd_lane_id;
|
||||
const short bi = thread_idx / TN;
|
||||
const short bj = vec_size * (thread_idx % TN);
|
||||
|
||||
D += bi * params->ldd + bj;
|
||||
|
||||
short tgp_bm = min(BM, params->M - c_row);
|
||||
short tgp_bn = min(BN, params->N - c_col);
|
||||
|
||||
if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
|
||||
for (short ti = 0; ti < BM; ti += TM) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
D[ti * params->ldd + j] = T(0.);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
short jmax = tgp_bn - bj;
|
||||
jmax = jmax < vec_size ? jmax : vec_size;
|
||||
for (short ti = 0; (bi + ti) < tgp_bm; ti += TM) {
|
||||
for (short j = 0; j < jmax; j++) {
|
||||
D[ti * params->ldd + j] = T(0.);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
}

  threadgroup_barrier(mem_flags::mem_none);

  // Prepare threadgroup mma operation
  thread typename gemm_kernel::mma_t mma_op(simd_group_id, simd_lane_id);

  threadgroup T As[gemm_kernel::tgp_mem_size_a];
  threadgroup T Bs[gemm_kernel::tgp_mem_size_b];

  // Prepare threadgroup loading operations
  thread typename gemm_kernel::loader_a_t loader_a(
      A, params->lda, As, simd_group_id, simd_lane_id);
  thread typename gemm_kernel::loader_b_t loader_b(
      B, params->ldb, Bs, simd_group_id, simd_lane_id);

  // Prepare threadgroup bounds
  const short tgp_bm =
      MN_aligned ? short(BM) : short(min(BM, params->M - c_row));
  const short tgp_bn =
      MN_aligned ? short(BN) : short(min(BN, params->N - c_col));

  int gemm_k_iterations = params->gemm_k_iterations_aligned;

  ///////////////////////////////////////////////////////////////////////////////
  // Do unaligned K iterations first
  if (!K_aligned) {
    const int k_last = params->gemm_k_iterations_aligned * BK;
    const int mask_idx_last = k_last / BM;

    if (!has_operand_mask ||
        (bool(lhs_mask[lhs_mask_offset + mask_idx_last * lhs_mask_step]) &&
         bool(rhs_mask[rhs_mask_offset + mask_idx_last * rhs_mask_step]))) {
      if (has_mul_operand_mask) {
        lhs_mask_op.scale =
            lhs_mask[lhs_mask_offset + mask_idx_last * lhs_mask_step];
        rhs_mask_op.scale =
            rhs_mask[rhs_mask_offset + mask_idx_last * rhs_mask_step];
      }

      // Move loader source ahead to end
      const int k_remain = params->K - k_last;
      const size_t k_jump_a =
          transpose_a ? params->lda * size_t(k_last) : size_t(k_last);
      const size_t k_jump_b =
          transpose_b ? size_t(k_last) : params->ldb * size_t(k_last);

      loader_a.src += k_jump_a;
      loader_b.src += k_jump_b;

      // Load tile
      const short2 tile_dims_A =
          transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
      const short2 tile_dims_B =
          transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);

      loader_a.load_safe(tile_dims_A);
      loader_b.load_safe(tile_dims_B);

      if (has_mul_operand_mask) {
        loader_a.apply_inplace_op(lhs_mask_op);
        loader_b.apply_inplace_op(rhs_mask_op);
      }

      threadgroup_barrier(mem_flags::mem_threadgroup);

      // Do matmul
      mma_op.mma(As, Bs);

      // Reset source back to start
      loader_a.src -= k_jump_a;
      loader_b.src -= k_jump_b;
    }
  }

  ///////////////////////////////////////////////////////////////////////////////
  // MNK aligned loop
  if (MN_aligned) {
    for (; gemm_k_iterations > 0; gemm_k_iterations--) {
      threadgroup_barrier(mem_flags::mem_threadgroup);

      if (!has_operand_mask ||
          (bool(lhs_mask[lhs_mask_offset]) &&
           bool(rhs_mask[rhs_mask_offset]))) {
        if (has_mul_operand_mask) {
          lhs_mask_op.scale = lhs_mask[lhs_mask_offset];
          rhs_mask_op.scale = rhs_mask[rhs_mask_offset];
        }

        // Load elements into threadgroup
        loader_a.load_unsafe();
        loader_b.load_unsafe();

        if (has_mul_operand_mask) {
          loader_a.apply_inplace_op(lhs_mask_op);
          loader_b.apply_inplace_op(rhs_mask_op);
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Multiply and accumulate threadgroup elements
        mma_op.mma(As, Bs);
      }

      // Prepare for next iteration
      loader_a.next();
      loader_b.next();

      k_factor_cnt--;
      lhs_mask_offset += k_factor_cnt == 0 ? lhs_mask_step : 0;
      rhs_mask_offset += k_factor_cnt == 0 ? rhs_mask_step : 0;
      k_factor_cnt = k_factor_cnt == 0 ? k_mask_factor : k_factor_cnt;
    }

    if (has_mul_output_mask) {
      mma_op.apply_epilogue(out_mask_op);
    }

    // Store results to device memory
    mma_op.store_result(D, params->ldd);
    return;

  }
  ///////////////////////////////////////////////////////////////////////////////
  // MN unaligned loop
  else {
    const bool M_aligned = (tgp_bm == BM);
    const bool N_aligned = (tgp_bn == BN);

    const short2 tile_dims_A =
        transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm);
    const short2 tile_dims_B =
        transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK);

    for (; gemm_k_iterations > 0; gemm_k_iterations--) {
      threadgroup_barrier(mem_flags::mem_threadgroup);
      if (!has_operand_mask ||
          (bool(lhs_mask[lhs_mask_offset]) &&
           bool(rhs_mask[rhs_mask_offset]))) {
        if (has_mul_operand_mask) {
          lhs_mask_op.scale = lhs_mask[lhs_mask_offset];
          rhs_mask_op.scale = rhs_mask[rhs_mask_offset];
        }

        // Load elements into threadgroup
        if (M_aligned) {
          loader_a.load_unsafe();
        } else {
          loader_a.load_safe(tile_dims_A);
        }

        if (N_aligned) {
          loader_b.load_unsafe();
        } else {
          loader_b.load_safe(tile_dims_B);
        }

        if (has_mul_operand_mask) {
          loader_a.apply_inplace_op(lhs_mask_op);
          loader_b.apply_inplace_op(rhs_mask_op);
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Multiply and accumulate threadgroup elements
        mma_op.mma(As, Bs);
      }

      // Prepare for next iteration
      loader_a.next();
      loader_b.next();

      k_factor_cnt--;
      lhs_mask_offset += k_factor_cnt == 0 ? lhs_mask_step : 0;
      rhs_mask_offset += k_factor_cnt == 0 ? rhs_mask_step : 0;
      k_factor_cnt = k_factor_cnt == 0 ? k_mask_factor : k_factor_cnt;
    }

    if (has_mul_output_mask) {
      mma_op.apply_epilogue(out_mask_op);
    }

    if (M_aligned && N_aligned) {
      mma_op.store_result(D, params->ldd);
    } else {
      mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
    }
  }
}
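
One detail shared by both loops above: a single mask block can span several BK tiles along K, so the mask offsets advance only once every k_mask_factor iterations. The counter logic in isolation (plain C++, illustrative values):

#include <cstdio>

int main() {
  const int k_mask_factor = 2; // one mask block spans two BK tiles (illustrative)
  const int lhs_mask_step = 1, rhs_mask_step = 1;
  int lhs_off = 0, rhs_off = 0, cnt = k_mask_factor;
  for (int it = 0; it < 6; it++) {
    // ... load, mask, and mma for K-tile `it` would happen here ...
    cnt--;
    lhs_off += (cnt == 0) ? lhs_mask_step : 0;
    rhs_off += (cnt == 0) ? rhs_mask_step : 0;
    cnt = (cnt == 0) ? k_mask_factor : cnt;
    std::printf("after tile %d: mask offsets (%d, %d)\n", it, lhs_off, rhs_off);
  }
  return 0;
}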

///////////////////////////////////////////////////////////////////////////////
// GEMM kernel initializations
///////////////////////////////////////////////////////////////////////////////
#include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h"

#define instantiate_gemm( \
    outmaskname, \
@ -483,7 +59,6 @@ block_masked_gemm(
    uint3 tid [[threadgroup_position_in_grid]], \
    uint3 lid [[thread_position_in_threadgroup]]);

// clang-format off
#define instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
  instantiate_gemm(bool_, bool, bool_, bool, tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
  instantiate_gemm(iname, itype, iname, itype, tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
@ -492,28 +67,24 @@ block_masked_gemm(
  instantiate_gemm(nomask, nomask_t, bool_, bool, tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
  instantiate_gemm(nomask, nomask_t, iname, itype, tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
  instantiate_gemm(bool_, bool, nomask, nomask_t, tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
  instantiate_gemm(iname, itype, nomask, nomask_t, tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) // clang-format on
  instantiate_gemm(iname, itype, nomask, nomask_t, tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned)

// clang-format off
#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
  instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \
  instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \
  instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false) // clang-format on
  instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false)

// clang-format off
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) // clang-format on
  instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)

// clang-format off
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
  instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
  instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) // clang-format on
  instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2)

// clang-format off
instantiate_gemm_shapes_helper(float16, half, float16, half);
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);
instantiate_gemm_shapes_helper(float32, float, float32, float); // clang-format on
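
Functionally, the kernel these macros instantiate behaves like the following scalar reference: an output block whose boolean mask is false is written as zeros and its work skipped, while operand masks gate or scale the loads. A minimal sketch in plain C++, with a 4x4 problem, 2x2 blocks, and an output mask only (sizes illustrative):

#include <cstdio>
#include <vector>

int main() {
  const int M = 4, N = 4, K = 4, bm = 2, bn = 2;
  std::vector<float> A(M * K, 1.f), B(K * N, 1.f), D(M * N, 0.f);
  std::vector<bool> out_mask = {true, false, false, true}; // (M/bm) x (N/bn) blocks
  for (int i = 0; i < M; i++) {
    for (int j = 0; j < N; j++) {
      if (!out_mask[(i / bm) * (N / bn) + (j / bn)]) {
        continue; // block is masked out: D stays zero and the work is skipped
      }
      float acc = 0.f;
      for (int k = 0; k < K; k++) {
        acc += A[i * K + k] * B[k * N + j];
      }
      D[i * N + j] = acc;
    }
  }
  std::printf("D[0][0] = %g, D[0][3] = %g\n", D[0], D[3]); // 4 and 0
  return 0;
}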
227 mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h Normal file
@ -0,0 +1,227 @@
// Copyright © 2024 Apple Inc.

using namespace mlx::steel;

///////////////////////////////////////////////////////////////////////////////
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////

template <
    typename T,
    typename U,
    int BM,
    int BN,
    int BK,
    int WM,
    int WN,
    bool transpose_a,
    bool transpose_b,
    bool MN_aligned,
    bool K_aligned>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gemm_splitk(
    const device T* A [[buffer(0)]],
    const device T* B [[buffer(1)]],
    device U* C [[buffer(2)]],
    const constant GEMMSpiltKParams* params [[buffer(3)]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]) {
  (void)lid;

  using gemm_kernel = GEMMKernel<
      T,
      U,
      BM,
      BN,
      BK,
      WM,
      WN,
      transpose_a,
      transpose_b,
      MN_aligned,
      K_aligned>;
  using loader_a_t = typename gemm_kernel::loader_a_t;
  using loader_b_t = typename gemm_kernel::loader_b_t;
  using mma_t = typename gemm_kernel::mma_t;

  threadgroup T As[gemm_kernel::tgp_mem_size_a];
  threadgroup T Bs[gemm_kernel::tgp_mem_size_b];

  const int tid_x = tid.x;
  const int tid_y = tid.y;
  const int tid_z = tid.z;

  if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
    return;
  }

  // Find block in A, B, C
  const int c_row = tid_y * BM;
  const int c_col = tid_x * BN;
  const int k_start = params->split_k_partition_size * tid_z;

  const size_t c_row_long = size_t(c_row);
  const size_t c_col_long = size_t(c_col);
  const size_t k_start_long = size_t(k_start);

  A += transpose_a ? (c_row_long + k_start_long * params->lda)
                   : (k_start_long + c_row_long * params->lda);
  B += transpose_b ? (k_start_long + c_col_long * params->ldb)
                   : (c_col_long + k_start_long * params->ldb);
  C += (size_t(params->split_k_partition_stride) * tid_z) +
      (c_row_long * params->ldc + c_col_long);

  // Prepare threadgroup loading operations
  thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
  thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);

  // Prepare threadgroup mma operation
  thread mma_t mma_op(simd_group_id, simd_lane_id);

  int gemm_k_iterations = params->gemm_k_iterations_aligned;

  short tgp_bm = min(BM, params->M - c_row);
  short tgp_bn = min(BN, params->N - c_col);
  short leftover_bk = params->K % BK;

  if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
    gemm_kernel::gemm_loop(
        As,
        Bs,
        gemm_k_iterations,
        loader_a,
        loader_b,
        mma_op,
        tgp_bm,
        tgp_bn,
        leftover_bk,
        LoopAlignment<true, true, true>{});
  } else if (tgp_bn == BN) {
    gemm_kernel::gemm_loop(
        As,
        Bs,
        gemm_k_iterations,
        loader_a,
        loader_b,
        mma_op,
        tgp_bm,
        tgp_bn,
        leftover_bk,
        LoopAlignment<false, true, true>{});
  } else if (tgp_bm == BM) {
    gemm_kernel::gemm_loop(
        As,
        Bs,
        gemm_k_iterations,
        loader_a,
        loader_b,
        mma_op,
        tgp_bm,
        tgp_bn,
        leftover_bk,
        LoopAlignment<true, false, true>{});
  } else {
    gemm_kernel::gemm_loop(
        As,
        Bs,
        gemm_k_iterations,
        loader_a,
        loader_b,
        mma_op,
        tgp_bm,
        tgp_bn,
        leftover_bk,
        LoopAlignment<false, false, true>{});
  }

  threadgroup_barrier(mem_flags::mem_threadgroup);

  if ((tid_z + 1) == (params->split_k_partitions)) {
    int gemm_k_iter_remaining =
        (params->K - (k_start + params->split_k_partition_size)) / BK;
    if (!K_aligned || gemm_k_iter_remaining > 0)
      gemm_kernel::gemm_loop(
          As,
          Bs,
          gemm_k_iter_remaining,
          loader_a,
          loader_b,
          mma_op,
          tgp_bm,
          tgp_bn,
          leftover_bk,
          LoopAlignment<false, false, K_aligned>{});
  }

  if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
    mma_op.store_result(C, params->ldc);
  } else {
    mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm));
  }
}

///////////////////////////////////////////////////////////////////////////////
// Split k accumulation kernel
///////////////////////////////////////////////////////////////////////////////

template <
    typename AccT,
    typename OutT,
    typename Epilogue = TransformNone<OutT, AccT>>
[[kernel]] void gemm_splitk_accum(
    const device AccT* C_split [[buffer(0)]],
    device OutT* D [[buffer(1)]],
    const constant int& k_partitions [[buffer(2)]],
    const constant int& partition_stride [[buffer(3)]],
    const constant int& ldd [[buffer(4)]],
    uint2 gid [[thread_position_in_grid]]) {
  // Adjust D and C
  D += gid.x + gid.y * size_t(ldd);
  C_split += gid.x + gid.y * size_t(ldd);

  size_t offset = 0;
  AccT out = 0;

  for (int i = 0; i < k_partitions; i++) {
    out += C_split[offset];
    offset += partition_stride;
  }

  // Write output
  D[0] = Epilogue::apply(out);
}
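
gemm_splitk_accum collapses the K partitions that gemm_splitk wrote into C_split: each output element is the plain sum over its partitions, spaced partition_stride apart. The same reduction as a scalar sketch (plain C++, sizes illustrative):

#include <cstdio>
#include <vector>

int main() {
  const int k_partitions = 4, ldd = 8, rows = 2;
  const int partition_stride = rows * ldd; // one full output image per partition
  std::vector<float> C_split(k_partitions * partition_stride, 1.f);
  std::vector<float> D(rows * ldd, 0.f);
  for (int y = 0; y < rows; y++) {
    for (int x = 0; x < ldd; x++) {
      float out = 0.f;
      size_t offset = x + y * static_cast<size_t>(ldd);
      for (int i = 0; i < k_partitions; i++, offset += partition_stride) {
        out += C_split[offset]; // walk this element across the partitions
      }
      D[x + y * ldd] = out;
    }
  }
  std::printf("D[0] = %g\n", D[0]); // 4, the number of partitions
  return 0;
}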

template <
    typename AccT,
    typename OutT,
    typename Epilogue = TransformAxpby<OutT, AccT>>
[[kernel]] void gemm_splitk_accum_axpby(
    const device AccT* C_split [[buffer(0)]],
    device OutT* D [[buffer(1)]],
    const constant int& k_partitions [[buffer(2)]],
    const constant int& partition_stride [[buffer(3)]],
    const constant int& ldd [[buffer(4)]],
    const device OutT* C [[buffer(5)]],
    const constant int& ldc [[buffer(6)]],
    const constant int& fdc [[buffer(7)]],
    const constant float& alpha [[buffer(8)]],
    const constant float& beta [[buffer(9)]],
    uint2 gid [[thread_position_in_grid]]) {
  // Adjust D and C
  C += gid.x * size_t(fdc) + gid.y * size_t(ldc);
  D += gid.x + gid.y * size_t(ldd);
  C_split += gid.x + gid.y * size_t(ldd);

  size_t offset = 0;
  AccT out = 0;

  for (int i = 0; i < k_partitions; i++) {
    out += C_split[offset];
    offset += partition_stride;
  }

  // Write output
  Epilogue op(alpha, beta);
  D[0] = op.apply(out, *C);
}
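
The _axpby variant differs only in the epilogue: after the identical partition sum, the result is blended with the original C read through the ldc/fdc strides. A one-line sketch of that transform (plain C++; this mirrors how TransformAxpby is applied above, not the MLX implementation itself):

// With `acc` the summed partitions and `c` the matching element of the
// original C, the kernel stores D = alpha * acc + beta * c.
float axpby(float alpha, float beta, float acc, float c) {
  return alpha * acc + beta * c;
}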
@ -1,173 +1,9 @@
// Copyright © 2024 Apple Inc.

// clang-format off
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"

using namespace metal;
using namespace mlx::steel;

///////////////////////////////////////////////////////////////////////////////
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////

template <
    typename T,
    typename U,
    int BM,
    int BN,
    int BK,
    int WM,
    int WN,
    bool transpose_a,
    bool transpose_b,
    bool MN_aligned,
    bool K_aligned>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gemm_splitk(
    const device T* A [[buffer(0)]],
    const device T* B [[buffer(1)]],
    device U* C [[buffer(2)]],
    const constant GEMMSpiltKParams* params [[buffer(3)]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]) {
  (void)lid;

  using gemm_kernel = GEMMKernel<
      T,
      U,
      BM,
      BN,
      BK,
      WM,
      WN,
      transpose_a,
      transpose_b,
      MN_aligned,
      K_aligned>;
  using loader_a_t = typename gemm_kernel::loader_a_t;
  using loader_b_t = typename gemm_kernel::loader_b_t;
  using mma_t = typename gemm_kernel::mma_t;

  threadgroup T As[gemm_kernel::tgp_mem_size_a];
  threadgroup T Bs[gemm_kernel::tgp_mem_size_b];

  const int tid_x = tid.x;
  const int tid_y = tid.y;
  const int tid_z = tid.z;

  if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
    return;
  }

  // Find block in A, B, C
  const int c_row = tid_y * BM;
  const int c_col = tid_x * BN;
  const int k_start = params->split_k_partition_size * tid_z;

  const size_t c_row_long = size_t(c_row);
  const size_t c_col_long = size_t(c_col);
  const size_t k_start_long = size_t(k_start);

  A += transpose_a ? (c_row_long + k_start_long * params->lda)
                   : (k_start_long + c_row_long * params->lda);
  B += transpose_b ? (k_start_long + c_col_long * params->ldb)
                   : (c_col_long + k_start_long * params->ldb);
  C += (size_t(params->split_k_partition_stride) * tid_z) +
      (c_row_long * params->ldc + c_col_long);

  // Prepare threadgroup loading operations
  thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
  thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);

  // Prepare threadgroup mma operation
  thread mma_t mma_op(simd_group_id, simd_lane_id);

  int gemm_k_iterations = params->gemm_k_iterations_aligned;

  short tgp_bm = min(BM, params->M - c_row);
  short tgp_bn = min(BN, params->N - c_col);
  short leftover_bk = params->K % BK;

  if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
    gemm_kernel::gemm_loop(
        As,
        Bs,
        gemm_k_iterations,
        loader_a,
        loader_b,
        mma_op,
        tgp_bm,
        tgp_bn,
        leftover_bk,
        LoopAlignment<true, true, true>{});
  } else if (tgp_bn == BN) {
    gemm_kernel::gemm_loop(
        As,
        Bs,
        gemm_k_iterations,
        loader_a,
        loader_b,
        mma_op,
        tgp_bm,
        tgp_bn,
        leftover_bk,
        LoopAlignment<false, true, true>{});
  } else if (tgp_bm == BM) {
    gemm_kernel::gemm_loop(
        As,
        Bs,
        gemm_k_iterations,
        loader_a,
        loader_b,
        mma_op,
        tgp_bm,
        tgp_bn,
        leftover_bk,
        LoopAlignment<true, false, true>{});
  } else {
    gemm_kernel::gemm_loop(
        As,
        Bs,
        gemm_k_iterations,
        loader_a,
        loader_b,
        mma_op,
        tgp_bm,
        tgp_bn,
        leftover_bk,
        LoopAlignment<false, false, true>{});
  }

  threadgroup_barrier(mem_flags::mem_threadgroup);

  if ((tid_z + 1) == (params->split_k_partitions)) {
    int gemm_k_iter_remaining =
        (params->K - (k_start + params->split_k_partition_size)) / BK;
    if (!K_aligned || gemm_k_iter_remaining > 0)
      gemm_kernel::gemm_loop(
          As,
          Bs,
          gemm_k_iter_remaining,
          loader_a,
          loader_b,
          mma_op,
          tgp_bm,
          tgp_bn,
          leftover_bk,
          LoopAlignment<false, false, K_aligned>{});
  }

  if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
    mma_op.store_result(C, params->ldc);
  } else {
    mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm));
  }
}

///////////////////////////////////////////////////////////////////////////////
// GEMM kernel initializations
///////////////////////////////////////////////////////////////////////////////
#include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h"

#define instantiate_gemm( \
    tname, \
@ -210,97 +46,27 @@ template <
    uint3 tid [[threadgroup_position_in_grid]], \
    uint3 lid [[thread_position_in_threadgroup]]);

// clang-format off
#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
  instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \
  instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \
  instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false) // clang-format on
  instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false)

// clang-format off
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) // clang-format on
  instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)

// clang-format off
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
  instantiate_gemm_transpose_helper(iname, itype, oname, otype, 16, 16, 16, 2, 2) \
  instantiate_gemm_transpose_helper(iname, itype, oname, otype, 16, 32, 16, 2, 2) \
  instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 16, 16, 2, 2) \
  instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) // clang-format on
  instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2)

// clang-format off
instantiate_gemm_shapes_helper(float16, half, float32, float);
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, float32, float);

instantiate_gemm_shapes_helper(float32, float, float32, float); // clang-format on

///////////////////////////////////////////////////////////////////////////////
// Split k accumulation kernel
///////////////////////////////////////////////////////////////////////////////

template <
    typename AccT,
    typename OutT,
    typename Epilogue = TransformNone<OutT, AccT>>
[[kernel]] void gemm_splitk_accum(
    const device AccT* C_split [[buffer(0)]],
    device OutT* D [[buffer(1)]],
    const constant int& k_partitions [[buffer(2)]],
    const constant int& partition_stride [[buffer(3)]],
    const constant int& ldd [[buffer(4)]],
    uint2 gid [[thread_position_in_grid]]) {
  // Adjust D and C
  D += gid.x + gid.y * size_t(ldd);
  C_split += gid.x + gid.y * size_t(ldd);

  size_t offset = 0;
  AccT out = 0;

  for (int i = 0; i < k_partitions; i++) {
    out += C_split[offset];
    offset += partition_stride;
  }

  // Write output
  D[0] = Epilogue::apply(out);
}

template <
    typename AccT,
    typename OutT,
    typename Epilogue = TransformAxpby<OutT, AccT>>
[[kernel]] void gemm_splitk_accum_axpby(
    const device AccT* C_split [[buffer(0)]],
    device OutT* D [[buffer(1)]],
    const constant int& k_partitions [[buffer(2)]],
    const constant int& partition_stride [[buffer(3)]],
    const constant int& ldd [[buffer(4)]],
    const device OutT* C [[buffer(5)]],
    const constant int& ldc [[buffer(6)]],
    const constant int& fdc [[buffer(7)]],
    const constant float& alpha [[buffer(8)]],
    const constant float& beta [[buffer(9)]],
    uint2 gid [[thread_position_in_grid]]) {
  // Adjust D and C
  C += gid.x * size_t(fdc) + gid.y * size_t(ldc);
  D += gid.x + gid.y * size_t(ldd);
  C_split += gid.x + gid.y * size_t(ldd);

  size_t offset = 0;
  AccT out = 0;

  for (int i = 0; i < k_partitions; i++) {
    out += C_split[offset];
    offset += partition_stride;
  }

  // Write output
  Epilogue op(alpha, beta);
  D[0] = op.apply(out, *C);
}
instantiate_gemm_shapes_helper(float32, float, float32, float);

#define instantiate_accum(oname, otype, aname, atype) \
  template [[host_name("steel_gemm_splitk_accum_" #oname \
@ -313,7 +79,7 @@ template <
      const constant int& ldd [[buffer(4)]], \
      uint2 gid [[thread_position_in_grid]]); \
  template [[host_name("steel_gemm_splitk_accum_" #oname "_" #aname \
                       "_axpby")]] [[kernel]] void \
                       "_axbpy")]] [[kernel]] void \
  gemm_splitk_accum_axpby<atype, otype>( \
      const device atype* C_split [[buffer(0)]], \
      device otype* D [[buffer(1)]], \
@ -327,7 +93,6 @@ template <
      const constant float& beta [[buffer(9)]], \
      uint2 gid [[thread_position_in_grid]]);

// clang-format off
instantiate_accum(bfloat16, bfloat16_t, float32, float);
instantiate_accum(float16, half, float32, float);
instantiate_accum(float32, float, float32, float); // clang-format on
instantiate_accum(float32, float, float32, float); // clang-format on
@ -2,7 +2,7 @@

#pragma once

#include "mlx/backend/metal/kernels/steel/utils.h"
#include "mlx/backend/metal/kernels/steel/defines.h"

///////////////////////////////////////////////////////////////////////////////
// Loading helper
@ -134,4 +134,4 @@ struct BlockLoader {
};

} // namespace steel
} // namespace mlx
} // namespace mlx
@ -6,8 +6,8 @@
#include <metal_simdgroup_matrix>
#include <metal_stdlib>

#include "mlx/backend/metal/kernels/steel/defines.h"
#include "mlx/backend/metal/kernels/steel/gemm/transforms.h"
#include "mlx/backend/metal/kernels/steel/utils.h"

using namespace metal;

@ -358,4 +358,4 @@ struct BlockMMA {
};

} // namespace steel
} // namespace mlx
} // namespace mlx
@ -4,9 +4,6 @@

#include <metal_stdlib>

#define STEEL_CONST static constant constexpr const
#define STEEL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)")

METAL_FUNC ulong2 elem_to_loc_broadcast(
    uint elem,
    constant const int* shape,
@ -42,4 +39,4 @@ METAL_FUNC ulong3 elem_to_loc_broadcast(
    loc_c += pos_in_dim * c_strides[i];
  }
  return ulong3(loc_a, loc_b, loc_c);
}
}
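
For context, elem_to_loc_broadcast, whose tail survives in the hunk above, unravels a flat element index over a shared shape and re-linearizes it with each operand's strides, which is what lets broadcast inputs share one launch grid. A host-side sketch of the two-operand form (plain C++; the signature is paraphrased, not copied):

#include <cstdint>
#include <cstdio>

// Unravel a flat index over `shape`, then re-linearize with each array's strides.
void elem_to_loc2(
    uint32_t elem, const int* shape, const int64_t* a_strides,
    const int64_t* b_strides, int ndim, uint64_t& loc_a, uint64_t& loc_b) {
  loc_a = loc_b = 0;
  for (int i = ndim - 1; i >= 0; --i) {
    int pos_in_dim = elem % shape[i];
    elem /= shape[i];
    loc_a += uint64_t(pos_in_dim) * a_strides[i];
    loc_b += uint64_t(pos_in_dim) * b_strides[i];
  }
}

int main() {
  int shape[2] = {2, 3};
  int64_t as[2] = {3, 1}, bs[2] = {0, 1}; // b is broadcast along the first axis
  uint64_t la, lb;
  elem_to_loc2(4, shape, as, bs, 2, la, lb);
  std::printf("a offset %llu, b offset %llu\n",
              (unsigned long long)la, (unsigned long long)lb); // 4, 1
  return 0;
}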
@ -8,9 +8,10 @@
OUTPUT_DIR=$1
CC=$2
SRC_DIR=$3
SRC_NAME=$4
SRC_FILE=$4
CFLAGS=$5
INPUT_FILE=${SRC_DIR}/mlx/backend/metal/kernels/${SRC_NAME}.h
SRC_NAME=$(basename -- "${SRC_FILE}")
INPUT_FILE=${SRC_DIR}/mlx/backend/metal/kernels/${SRC_FILE}.h
OUTPUT_FILE=${OUTPUT_DIR}/${SRC_NAME}.cpp

mkdir -p $OUTPUT_DIR
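
With this change the fourth argument can be a relative path such as steel/gemm/gemm: the full path locates the input header while only its basename names the emitted .cpp. The same split expressed as a C++ parallel, purely for illustration (std::filesystem, C++17):

#include <filesystem>
#include <iostream>

int main() {
  std::filesystem::path src_file = "steel/gemm/gemm"; // as CMake now passes it
  std::cout << "input:  mlx/backend/metal/kernels/" << src_file.string() << ".h\n";
  std::cout << "output: jit/" << src_file.filename().string() << ".cpp\n";
  return 0;
}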
@ -7,6 +7,7 @@

#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/steel/gemm/params.h"
#include "mlx/backend/metal/matmul.h"
@ -336,7 +337,19 @@ void steel_matmul_conv_groups(

  // Encode and dispatch kernel
  auto& compute_encoder = d.get_command_encoder(s.index);
  auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts);
  auto kernel = get_steel_gemm_fused_kernel(
      d,
      base_name,
      hash_name,
      func_consts,
      out,
      transpose_a,
      transpose_b,
      bm,
      bn,
      bk,
      wm,
      wn);

  compute_encoder->setComputePipelineState(kernel);
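
Every dispatch site in this file now goes through a get_steel_*_kernel selector instead of calling d.get_kernel directly. With the JIT option off, the selector reduces to a name lookup (see the kernels.cpp stubs later in this diff); a JIT build is free to specialize and compile on demand. A toy model of that split (plain C++; names and strings hypothetical, not the MLX API):

#include <iostream>
#include <string>

// Toy model of the new indirection; names and behavior are hypothetical.
std::string get_kernel_toy(const std::string& name, bool jit) {
  if (!jit) {
    return "pipeline from metallib: " + name; // precompiled; the name encodes everything
  }
  return "pipeline compiled on demand: " + name; // a JIT path would build the source first
}

int main() {
  std::cout << get_kernel_toy("steel_gemm_fused_nn_float16_float16", false) << "\n";
  return 0;
}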

@ -458,17 +471,31 @@ void steel_matmul(
  C_split.set_data(allocator::malloc_or_wait(C_split.nbytes()));
  copies.push_back(C_split);

  bool mn_aligned = M % bm == 0 && N % bn == 0;
  bool k_aligned = K % bk == 0;
  std::ostringstream kname;
  kname << "steel_gemm_splitk_" << (transpose_a ? 't' : 'n')
        << (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
        << type_to_name(C_split) << "_bm" << bm << "_bn" << bn << "_bk" << bk
        << "_wm" << wm << "_wn" << wn << "_MN_"
        << ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned" << "_K_"
        << ((K % bk == 0) ? "t" : "n") << "aligned";
        << "_wm" << wm << "_wn" << wn << "_MN_" << (mn_aligned ? "t" : "n")
        << "aligned" << "_K_" << (k_aligned ? "t" : "n") << "aligned";

  // Encode and dispatch gemm kernel
  auto& compute_encoder = d.get_command_encoder(s.index);
  auto kernel = d.get_kernel(kname.str());
  auto kernel = get_steel_gemm_splitk_kernel(
      d,
      kname.str(),
      a,
      C_split,
      transpose_a,
      transpose_b,
      bm,
      bn,
      bk,
      wm,
      wn,
      mn_aligned,
      k_aligned);
  compute_encoder->setComputePipelineState(kernel);

  int tn = (N + bn - 1) / bn;
@ -504,10 +531,11 @@ void steel_matmul(
      static_cast<const MTL::Resource*>(C_split.buffer().ptr());
  const class MTL::Resource* const resources[1] = {c_split_buf};
  compute_encoder->memoryBarrier(resources, 1);
  auto kernel_name = "steel_gemm_splitk_accum_" + type_to_name(out) + "_" +
      type_to_name(C_split);

  auto kernel = d.get_kernel(
      "steel_gemm_splitk_accum_" + type_to_name(out) + "_" +
      type_to_name(C_split));
  auto kernel = get_steel_gemm_splitk_accum_kernel(
      d, kernel_name, C_split, out, false);
  compute_encoder->setComputePipelineState(kernel);

  // Set the arguments for the kernel
@ -587,7 +615,19 @@ void steel_matmul(

  // Encode and dispatch kernel
  auto& compute_encoder = d.get_command_encoder(s.index);
  auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts);
  auto kernel = get_steel_gemm_fused_kernel(
      d,
      base_name,
      hash_name,
      func_consts,
      out,
      transpose_a,
      transpose_b,
      bm,
      bn,
      bk,
      wm,
      wn);

  compute_encoder->setComputePipelineState(kernel);

@ -1053,17 +1093,33 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
  C_split.set_data(allocator::malloc_or_wait(C_split.nbytes()));
  copies.push_back(C_split);

  bool mn_aligned = M % bm == 0 && N % bn == 0;
  bool k_aligned = K % bk == 0;

  std::ostringstream kname;
  kname << "steel_gemm_splitk_" << (transpose_a ? 't' : 'n')
        << (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
        << type_to_name(C_split) << "_bm" << bm << "_bn" << bn << "_bk" << bk
        << "_wm" << wm << "_wn" << wn << "_MN_"
        << ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned" << "_K_"
        << ((K % bk == 0) ? "t" : "n") << "aligned";
        << "_wm" << wm << "_wn" << wn << "_MN_" << (mn_aligned ? "t" : "n")
        << "aligned" << "_K_" << (k_aligned ? "t" : "n") << "aligned";

  // Encode and dispatch gemm kernel
  auto& compute_encoder = d.get_command_encoder(s.index);
  auto kernel = d.get_kernel(kname.str());
  auto kernel = get_steel_gemm_splitk_kernel(
      d,
      kname.str(),
      a,
      C_split,
      transpose_a,
      transpose_b,
      bm,
      bn,
      bk,
      wm,
      wn,
      mn_aligned,
      k_aligned);

  compute_encoder->setComputePipelineState(kernel);

  int tn = (N + bn - 1) / bn;
@ -1095,9 +1151,11 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {

  // Do accum kernel
  {
    auto kernel = d.get_kernel(
        "steel_gemm_splitk_accum_" + type_to_name(out) + "_" +
        type_to_name(C_split) + "_axpby");
    auto kernel_name = "steel_gemm_splitk_accum_" + type_to_name(out) + "_" +
        type_to_name(C_split) + "_axbpy";
    auto kernel = get_steel_gemm_splitk_accum_kernel(
        d, kernel_name, C_split, out, true);

    compute_encoder->setComputePipelineState(kernel);

    // Set the arguments for the kernel
@ -1182,7 +1240,19 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {

  // Encode and dispatch kernel
  auto& compute_encoder = d.get_command_encoder(s.index);
  auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts);
  auto kernel = get_steel_gemm_fused_kernel(
      d,
      base_name,
      hash_name,
      func_consts,
      out,
      transpose_a,
      transpose_b,
      bm,
      bn,
      bk,
      wm,
      wn);

  compute_encoder->setComputePipelineState(kernel);

@ -1348,6 +1418,8 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
  // Determine dispatch kernel
  int bm = block_size_, bn = block_size_, bk = 16;
  int wm = 2, wn = 2;
  bool mn_aligned = M % bm == 0 && N % bn == 0;
  bool k_aligned = K % bk == 0;

  // Prepare kernel name
  std::string out_mask_nm = has_out_mask ? type_to_name(inputs[2]) : "nomask";
@ -1358,13 +1430,26 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
        << op_mask_nm << "_" << (transpose_a ? 't' : 'n')
        << (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
        << type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk
        << "_wm" << wm << "_wn" << wn << "_MN_"
        << ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned" << "_K_"
        << ((K % bk == 0) ? "t" : "n") << "aligned";
        << "_wm" << wm << "_wn" << wn << "_MN_" << (mn_aligned ? "t" : "n")
        << "aligned" << "_K_" << (k_aligned ? "t" : "n") << "aligned";

  // Encode and dispatch kernel
  auto& compute_encoder = d.get_command_encoder(s.index);
  auto kernel = d.get_kernel(kname.str());
  auto kernel = get_steel_gemm_masked_kernel(
      d,
      kname.str(),
      out,
      has_out_mask ? std::optional<array>{inputs[2]} : std::nullopt,
      has_op_mask ? std::optional<array>{inputs.back()} : std::nullopt,
      transpose_a,
      transpose_b,
      bm,
      bn,
      bk,
      wm,
      wn,
      mn_aligned,
      k_aligned);
  compute_encoder->setComputePipelineState(kernel);

  // Use problem size to determine threadblock swizzle
@ -1720,7 +1805,19 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {

  // Encode and dispatch kernel
  auto& compute_encoder = d.get_command_encoder(s.index);
  auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts);
  auto kernel = get_steel_gemm_fused_kernel(
      d,
      base_name,
      hash_name,
      func_consts,
      out,
      transpose_a,
      transpose_b,
      bm,
      bn,
      bk,
      wm,
      wn);

  compute_encoder->setComputePipelineState(kernel);

@ -103,4 +103,90 @@ MTL::ComputePipelineState* get_reduce_kernel(
  return d.get_kernel(kernel_name);
}

MTL::ComputePipelineState* get_steel_gemm_fused_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const std::string& hash_name,
    const metal::MTLFCList& func_consts,
    const array&,
    bool,
    bool,
    int,
    int,
    int,
    int,
    int) {
  return d.get_kernel(kernel_name, "mlx", hash_name, func_consts);
}

MTL::ComputePipelineState* get_steel_gemm_splitk_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array&,
    const array&,
    bool,
    bool,
    int,
    int,
    int,
    int,
    int,
    bool,
    bool) {
  return d.get_kernel(kernel_name);
}

MTL::ComputePipelineState* get_steel_gemm_splitk_accum_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array&,
    const array&,
    bool) {
  return d.get_kernel(kernel_name);
}

MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array&,
    const std::optional<array>& mask_out,
    const std::optional<array>& mask_op,
    bool,
    bool,
    int,
    int,
    int,
    int,
    int,
    bool,
    bool) {
  return d.get_kernel(kernel_name);
}

MTL::ComputePipelineState* get_steel_conv_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array&,
    int,
    int,
    int,
    int,
    int,
    int,
    bool) {
  return d.get_kernel(kernel_name);
}

MTL::ComputePipelineState* get_steel_conv_general_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array&,
    int,
    int,
    int,
    int,
    int) {
  return d.get_kernel(kernel_name);
}

} // namespace mlx::core
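
In this non-JIT build the array and tuning arguments are deliberately unused: every specialization already lives in the prebuilt metallib, so the name alone selects the pipeline. For reference, the call shape mirrors the matmul.cpp site earlier in this diff (an excerpt-style sketch, not a standalone program):

// Unbatched split-k accumulation (axpby = false), as dispatched above.
auto kernel_name = "steel_gemm_splitk_accum_" + type_to_name(out) + "_" +
    type_to_name(C_split);
auto kernel = get_steel_gemm_splitk_accum_kernel(d, kernel_name, C_split, out, false);
compute_encoder->setComputePipelineState(kernel);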
@ -3968,9 +3968,9 @@ void init_ops(nb::module_& m) {
            a (array): Input array or scalar.
            b (array): Input array or scalar.
            block_size (int): Size of blocks to be masked. Must be ``32`` or ``64`` (default: ``64``)
            mask_out (array, optional): Boolean mask for output (default: ``None``)
            mask_lhs (array, optional): Boolean mask for a (default: ``None``)
            mask_rhs (array, optional): Boolean mask for b (default: ``None``)
            mask_out (array, optional): Mask for output (default: ``None``)
            mask_lhs (array, optional): Mask for a (default: ``None``)
            mask_rhs (array, optional): Mask for b (default: ``None``)

      )pbdoc");
  m.def(
@ -556,8 +556,7 @@ class TestBlas(mlx_tests.MLXTestCase):

        self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
        self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))

        # # Batched matmul with simple broadcast
        # Batched matmul with simple broadcast
        a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
        b_npy = np.random.normal(0.0, 1.0 / 128, (16, 16)).astype(np.float32)

@ -573,7 +572,6 @@ class TestBlas(mlx_tests.MLXTestCase):

        self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
        self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))

        # Matmul with vector
        a_npy = np.random.normal(0.0, 1.0 / 128, (16,)).astype(np.float32)
        b_npy = np.random.normal(0.0, 1.0 / 128, (32, 16, 128)).astype(np.float32)