mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
Option to JIT steel gemm / conv (#1139)
This commit is contained in:
parent
eab2685c67
commit
7e26fd8032
@ -1,4 +1,4 @@
|
|||||||
function(make_jit_source SRC_NAME)
|
function(make_jit_source SRC_FILE)
|
||||||
# This function takes a metal header file,
|
# This function takes a metal header file,
|
||||||
# runs the C preprocessesor on it, and makes
|
# runs the C preprocessesor on it, and makes
|
||||||
# the processed contents available as a string in a C++ function
|
# the processed contents available as a string in a C++ function
|
||||||
@ -9,6 +9,7 @@ function(make_jit_source SRC_NAME)
|
|||||||
#
|
#
|
||||||
# Additional arguments to this function are treated as dependencies
|
# Additional arguments to this function are treated as dependencies
|
||||||
# in the Cmake build system.
|
# in the Cmake build system.
|
||||||
|
get_filename_component(SRC_NAME ${SRC_FILE} NAME)
|
||||||
add_custom_command(
|
add_custom_command(
|
||||||
OUTPUT jit/${SRC_NAME}.cpp
|
OUTPUT jit/${SRC_NAME}.cpp
|
||||||
COMMAND /bin/bash
|
COMMAND /bin/bash
|
||||||
@ -16,10 +17,10 @@ function(make_jit_source SRC_NAME)
|
|||||||
${CMAKE_CURRENT_BINARY_DIR}/jit
|
${CMAKE_CURRENT_BINARY_DIR}/jit
|
||||||
${CMAKE_C_COMPILER}
|
${CMAKE_C_COMPILER}
|
||||||
${PROJECT_SOURCE_DIR}
|
${PROJECT_SOURCE_DIR}
|
||||||
${SRC_NAME}
|
${SRC_FILE}
|
||||||
"-D${MLX_METAL_VERSION}"
|
"-D${MLX_METAL_VERSION}"
|
||||||
DEPENDS make_compiled_preamble.sh
|
DEPENDS make_compiled_preamble.sh
|
||||||
kernels/${SRC_NAME}.h
|
kernels/${SRC_FILE}.h
|
||||||
${ARGN}
|
${ARGN}
|
||||||
)
|
)
|
||||||
add_custom_target(${SRC_NAME} DEPENDS jit/${SRC_NAME}.cpp)
|
add_custom_target(${SRC_NAME} DEPENDS jit/${SRC_NAME}.cpp)
|
||||||
@ -73,6 +74,39 @@ if (MLX_METAL_JIT)
|
|||||||
kernels/reduction/reduce_col.h
|
kernels/reduction/reduce_col.h
|
||||||
kernels/reduction/reduce_row.h
|
kernels/reduction/reduce_row.h
|
||||||
)
|
)
|
||||||
|
make_jit_source(
|
||||||
|
steel/gemm/gemm
|
||||||
|
kernels/steel/utils.h
|
||||||
|
kernels/steel/gemm/loader.h
|
||||||
|
kernels/steel/gemm/mma.h
|
||||||
|
kernels/steel/gemm/params.h
|
||||||
|
kernels/steel/gemm/transforms.h
|
||||||
|
)
|
||||||
|
make_jit_source(steel/gemm/kernels/steel_gemm_fused)
|
||||||
|
make_jit_source(
|
||||||
|
steel/gemm/kernels/steel_gemm_masked
|
||||||
|
kernels/steel/defines.h
|
||||||
|
)
|
||||||
|
make_jit_source(steel/gemm/kernels/steel_gemm_splitk)
|
||||||
|
make_jit_source(
|
||||||
|
steel/conv/conv
|
||||||
|
kernels/steel/utils.h
|
||||||
|
kernels/steel/defines.h
|
||||||
|
kernels/steel/gemm/mma.h
|
||||||
|
kernels/steel/gemm/transforms.h
|
||||||
|
kernels/steel/conv/params.h
|
||||||
|
kernels/steel/conv/loader.h
|
||||||
|
kernels/steel/conv/loaders/loader_channel_l.h
|
||||||
|
kernels/steel/conv/loaders/loader_channel_n.h
|
||||||
|
)
|
||||||
|
make_jit_source(
|
||||||
|
steel/conv/kernels/steel_conv
|
||||||
|
)
|
||||||
|
make_jit_source(
|
||||||
|
steel/conv/kernels/steel_conv_general
|
||||||
|
kernels/steel/defines.h
|
||||||
|
kernels/steel/conv/loaders/loader_general.h
|
||||||
|
)
|
||||||
else()
|
else()
|
||||||
target_sources(
|
target_sources(
|
||||||
mlx
|
mlx
|
||||||
|
@ -7,6 +7,7 @@
|
|||||||
|
|
||||||
#include "mlx/backend/metal/copy.h"
|
#include "mlx/backend/metal/copy.h"
|
||||||
#include "mlx/backend/metal/device.h"
|
#include "mlx/backend/metal/device.h"
|
||||||
|
#include "mlx/backend/metal/kernels.h"
|
||||||
#include "mlx/backend/metal/kernels/defines.h"
|
#include "mlx/backend/metal/kernels/defines.h"
|
||||||
#include "mlx/backend/metal/kernels/steel/conv/params.h"
|
#include "mlx/backend/metal/kernels/steel/conv/params.h"
|
||||||
#include "mlx/backend/metal/matmul.h"
|
#include "mlx/backend/metal/matmul.h"
|
||||||
@ -335,7 +336,17 @@ void implicit_gemm_conv_2D_gpu(
|
|||||||
|
|
||||||
// Encode and dispatch kernel
|
// Encode and dispatch kernel
|
||||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||||
auto kernel = d.get_kernel(kname.str());
|
auto kernel = get_steel_conv_kernel(
|
||||||
|
d,
|
||||||
|
kname.str(),
|
||||||
|
out,
|
||||||
|
bm,
|
||||||
|
bn,
|
||||||
|
bk,
|
||||||
|
wm,
|
||||||
|
wn,
|
||||||
|
n_channel_specialization,
|
||||||
|
small_filter);
|
||||||
compute_encoder->setComputePipelineState(kernel);
|
compute_encoder->setComputePipelineState(kernel);
|
||||||
|
|
||||||
// Deduce grid launch dimensions
|
// Deduce grid launch dimensions
|
||||||
@ -488,7 +499,8 @@ void implicit_gemm_conv_2D_general_gpu(
|
|||||||
|
|
||||||
// Encode and dispatch kernel
|
// Encode and dispatch kernel
|
||||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||||
auto kernel = d.get_kernel(kname.str());
|
auto kernel =
|
||||||
|
get_steel_conv_general_kernel(d, kname.str(), out, bm, bn, bk, wm, wn);
|
||||||
compute_encoder->setComputePipelineState(kernel);
|
compute_encoder->setComputePipelineState(kernel);
|
||||||
|
|
||||||
// Deduce grid launch dimensions
|
// Deduce grid launch dimensions
|
||||||
|
@ -23,4 +23,12 @@ const char* softmax();
|
|||||||
const char* sort();
|
const char* sort();
|
||||||
const char* reduce();
|
const char* reduce();
|
||||||
|
|
||||||
|
const char* gemm();
|
||||||
|
const char* steel_gemm_fused();
|
||||||
|
const char* steel_gemm_masked();
|
||||||
|
const char* steel_gemm_splitk();
|
||||||
|
const char* conv();
|
||||||
|
const char* steel_conv();
|
||||||
|
const char* steel_conv_general();
|
||||||
|
|
||||||
} // namespace mlx::core::metal
|
} // namespace mlx::core::metal
|
||||||
|
32
mlx/backend/metal/jit/steel_conv.h
Normal file
32
mlx/backend/metal/jit/steel_conv.h
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
constexpr std::string_view steel_conv_kernels = R"(
|
||||||
|
template [[host_name("{name}")]] [[kernel]] void
|
||||||
|
implicit_gemm_conv_2d<{itype}, {bm}, {bn}, {bk}, {wm}, {wn}, {n_channels}, {small_filter}>(
|
||||||
|
const device {itype}* A [[buffer(0)]],
|
||||||
|
const device {itype}* B [[buffer(1)]],
|
||||||
|
device {itype}* C [[buffer(2)]],
|
||||||
|
const constant MLXConvParams<2>* params [[buffer(3)]],
|
||||||
|
const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]],
|
||||||
|
uint3 tid [[threadgroup_position_in_grid]],
|
||||||
|
uint3 lid [[thread_position_in_threadgroup]],
|
||||||
|
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||||
|
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||||
|
)";
|
||||||
|
|
||||||
|
constexpr std::string_view steel_conv_general_kernels = R"(
|
||||||
|
template [[host_name("{name}")]] [[kernel]] void
|
||||||
|
implicit_gemm_conv_2d_general<{itype}, {bm}, {bn}, {bk}, {wm}, {wn}>(
|
||||||
|
const device {itype}* A [[buffer(0)]],
|
||||||
|
const device {itype}* B [[buffer(1)]],
|
||||||
|
device {itype}* C [[buffer(2)]],
|
||||||
|
const constant MLXConvParams<2>* params [[buffer(3)]],
|
||||||
|
const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]],
|
||||||
|
const constant Conv2DGeneralJumpParams* jump_params [[buffer(5)]],
|
||||||
|
const constant Conv2DGeneralBaseInfo* base_h [[buffer(6)]],
|
||||||
|
const constant Conv2DGeneralBaseInfo* base_w [[buffer(7)]],
|
||||||
|
uint3 tid [[threadgroup_position_in_grid]],
|
||||||
|
uint3 lid [[thread_position_in_threadgroup]],
|
||||||
|
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||||
|
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||||
|
)";
|
106
mlx/backend/metal/jit/steel_gemm.h
Normal file
106
mlx/backend/metal/jit/steel_gemm.h
Normal file
@ -0,0 +1,106 @@
|
|||||||
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
constexpr std::string_view steel_gemm_fused_kernels = R"(
|
||||||
|
template [[host_name("{name}")]]
|
||||||
|
[[kernel]] void gemm<{itype}, {bm}, {bn}, {bk}, {wm}, {wn}, {trans_a}, {trans_b}, float>(
|
||||||
|
const device {itype} *A [[buffer(0)]],
|
||||||
|
const device {itype} *B [[buffer(1)]],
|
||||||
|
const device {itype} *C [[buffer(2), function_constant(use_out_source)]],
|
||||||
|
device {itype} *D [[buffer(3)]],
|
||||||
|
const constant GEMMParams* params [[buffer(4)]],
|
||||||
|
const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]],
|
||||||
|
const constant int* batch_shape [[buffer(6)]],
|
||||||
|
const constant size_t* batch_strides [[buffer(7)]],
|
||||||
|
const constant uint32_t* lhs_indices [[buffer(10), function_constant(do_gather)]],
|
||||||
|
const constant uint32_t* rhs_indices [[buffer(11), function_constant(do_gather)]],
|
||||||
|
const constant uint32_t* C_indices [[buffer(12), function_constant(gather_bias)]],
|
||||||
|
const constant int* operand_shape [[buffer(13), function_constant(do_gather)]],
|
||||||
|
const constant size_t* operand_strides [[buffer(14), function_constant(do_gather)]],
|
||||||
|
const constant packed_int3& operand_batch_ndim [[buffer(15), function_constant(do_gather)]],
|
||||||
|
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||||
|
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||||
|
uint3 tid [[threadgroup_position_in_grid]],
|
||||||
|
uint3 lid [[thread_position_in_threadgroup]]);
|
||||||
|
)";
|
||||||
|
|
||||||
|
constexpr std::string_view steel_gemm_masked_kernels = R"(
|
||||||
|
template [[host_name("{name}")]] [[kernel]] void
|
||||||
|
block_masked_gemm<
|
||||||
|
{itype},
|
||||||
|
{outmasktype},
|
||||||
|
{opmasktype},
|
||||||
|
{bm},
|
||||||
|
{bn},
|
||||||
|
{bk},
|
||||||
|
{wm},
|
||||||
|
{wn},
|
||||||
|
{trans_a},
|
||||||
|
{trans_b},
|
||||||
|
{mn_aligned},
|
||||||
|
{k_aligned}>(
|
||||||
|
const device {itype}* A [[buffer(0)]],
|
||||||
|
const device {itype}* B [[buffer(1)]],
|
||||||
|
device {itype}* D [[buffer(3)]],
|
||||||
|
const constant GEMMParams* params [[buffer(4)]],
|
||||||
|
const constant int* batch_shape [[buffer(6)]],
|
||||||
|
const constant size_t* batch_strides [[buffer(7)]],
|
||||||
|
const device {outmasktype}* out_mask [[buffer(10)]],
|
||||||
|
const device {opmasktype}* lhs_mask [[buffer(11)]],
|
||||||
|
const device {opmasktype}* rhs_mask [[buffer(12)]],
|
||||||
|
const constant int* mask_strides [[buffer(13)]],
|
||||||
|
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||||
|
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||||
|
uint3 tid [[threadgroup_position_in_grid]],
|
||||||
|
uint3 lid [[thread_position_in_threadgroup]]);
|
||||||
|
)";
|
||||||
|
|
||||||
|
constexpr std::string_view steel_gemm_splitk_kernels = R"(
|
||||||
|
template [[host_name("{name}")]] [[kernel]] void
|
||||||
|
gemm_splitk<
|
||||||
|
{itype},
|
||||||
|
{otype},
|
||||||
|
{bm},
|
||||||
|
{bn},
|
||||||
|
{bk},
|
||||||
|
{wm},
|
||||||
|
{wn},
|
||||||
|
{trans_a},
|
||||||
|
{trans_b},
|
||||||
|
{mn_aligned},
|
||||||
|
{k_aligned}>(
|
||||||
|
const device {itype}* A [[buffer(0)]],
|
||||||
|
const device {itype}* B [[buffer(1)]],
|
||||||
|
device {otype}* C [[buffer(2)]],
|
||||||
|
const constant GEMMSpiltKParams* params [[buffer(3)]],
|
||||||
|
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||||
|
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||||
|
uint3 tid [[threadgroup_position_in_grid]],
|
||||||
|
uint3 lid [[thread_position_in_threadgroup]]);
|
||||||
|
)";
|
||||||
|
|
||||||
|
constexpr std::string_view steel_gemm_splitk_accum_kernels = R"(
|
||||||
|
template [[host_name("{name}")]] [[kernel]] void
|
||||||
|
gemm_splitk_accum<{atype}, {otype}>(
|
||||||
|
const device {atype}* C_split [[buffer(0)]],
|
||||||
|
device {otype}* D [[buffer(1)]],
|
||||||
|
const constant int& k_partitions [[buffer(2)]],
|
||||||
|
const constant int& partition_stride [[buffer(3)]],
|
||||||
|
const constant int& ldd [[buffer(4)]],
|
||||||
|
uint2 gid [[thread_position_in_grid]]);
|
||||||
|
)";
|
||||||
|
|
||||||
|
constexpr std::string_view steel_gemm_splitk_accum_axbpy_kernels = R"(
|
||||||
|
template [[host_name("{name}")]] [[kernel]] void
|
||||||
|
gemm_splitk_accum_axpby<{atype}, {otype}>(
|
||||||
|
const device {atype}* C_split [[buffer(0)]],
|
||||||
|
device {otype}* D [[buffer(1)]],
|
||||||
|
const constant int& k_partitions [[buffer(2)]],
|
||||||
|
const constant int& partition_stride [[buffer(3)]],
|
||||||
|
const constant int& ldd [[buffer(4)]],
|
||||||
|
const device {otype}* C [[buffer(5)]],
|
||||||
|
const constant int& ldc [[buffer(6)]],
|
||||||
|
const constant int& fdc [[buffer(7)]],
|
||||||
|
const constant float& alpha [[buffer(8)]],
|
||||||
|
const constant float& beta [[buffer(9)]],
|
||||||
|
uint2 gid [[thread_position_in_grid]]);
|
||||||
|
)";
|
@ -1,5 +1,4 @@
|
|||||||
// Copyright © 2024 Apple Inc.
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
#include <fmt/format.h>
|
#include <fmt/format.h>
|
||||||
|
|
||||||
#include "mlx/backend/common/compiled.h"
|
#include "mlx/backend/common/compiled.h"
|
||||||
@ -12,11 +11,15 @@
|
|||||||
#include "mlx/backend/metal/jit/scan.h"
|
#include "mlx/backend/metal/jit/scan.h"
|
||||||
#include "mlx/backend/metal/jit/softmax.h"
|
#include "mlx/backend/metal/jit/softmax.h"
|
||||||
#include "mlx/backend/metal/jit/sort.h"
|
#include "mlx/backend/metal/jit/sort.h"
|
||||||
|
#include "mlx/backend/metal/jit/steel_conv.h"
|
||||||
|
#include "mlx/backend/metal/jit/steel_gemm.h"
|
||||||
#include "mlx/backend/metal/jit/ternary.h"
|
#include "mlx/backend/metal/jit/ternary.h"
|
||||||
#include "mlx/backend/metal/jit/unary.h"
|
#include "mlx/backend/metal/jit/unary.h"
|
||||||
#include "mlx/backend/metal/kernels.h"
|
#include "mlx/backend/metal/kernels.h"
|
||||||
#include "mlx/backend/metal/utils.h"
|
#include "mlx/backend/metal/utils.h"
|
||||||
|
|
||||||
|
using namespace fmt::literals;
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
std::string op_name(const array& arr) {
|
std::string op_name(const array& arr) {
|
||||||
@ -276,4 +279,208 @@ MTL::ComputePipelineState* get_reduce_kernel(
|
|||||||
return d.get_kernel(kernel_name, lib);
|
return d.get_kernel(kernel_name, lib);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
MTL::ComputePipelineState* get_steel_gemm_fused_kernel(
|
||||||
|
metal::Device& d,
|
||||||
|
const std::string& kernel_name,
|
||||||
|
const std::string& hash_name,
|
||||||
|
const metal::MTLFCList& func_consts,
|
||||||
|
const array& out,
|
||||||
|
bool transpose_a,
|
||||||
|
bool transpose_b,
|
||||||
|
int bm,
|
||||||
|
int bn,
|
||||||
|
int bk,
|
||||||
|
int wm,
|
||||||
|
int wn) {
|
||||||
|
const auto& lib_name = kernel_name;
|
||||||
|
auto lib = d.get_library(lib_name);
|
||||||
|
if (lib == nullptr) {
|
||||||
|
std::ostringstream kernel_source;
|
||||||
|
kernel_source << metal::utils() << metal::gemm()
|
||||||
|
<< metal::steel_gemm_fused()
|
||||||
|
<< fmt::format(
|
||||||
|
steel_gemm_fused_kernels,
|
||||||
|
"name"_a = lib_name,
|
||||||
|
"itype"_a = get_type_string(out.dtype()),
|
||||||
|
"bm"_a = bm,
|
||||||
|
"bn"_a = bn,
|
||||||
|
"bk"_a = bk,
|
||||||
|
"wm"_a = wm,
|
||||||
|
"wn"_a = wn,
|
||||||
|
"trans_a"_a = transpose_a,
|
||||||
|
"trans_b"_a = transpose_b);
|
||||||
|
lib = d.get_library(lib_name, kernel_source.str());
|
||||||
|
}
|
||||||
|
return d.get_kernel(kernel_name, lib, hash_name, func_consts);
|
||||||
|
}
|
||||||
|
|
||||||
|
MTL::ComputePipelineState* get_steel_gemm_splitk_kernel(
|
||||||
|
metal::Device& d,
|
||||||
|
const std::string& kernel_name,
|
||||||
|
const array& in,
|
||||||
|
const array& out,
|
||||||
|
bool transpose_a,
|
||||||
|
bool transpose_b,
|
||||||
|
int bm,
|
||||||
|
int bn,
|
||||||
|
int bk,
|
||||||
|
int wm,
|
||||||
|
int wn,
|
||||||
|
bool mn_aligned,
|
||||||
|
bool k_aligned) {
|
||||||
|
const auto& lib_name = kernel_name;
|
||||||
|
auto lib = d.get_library(lib_name);
|
||||||
|
if (lib == nullptr) {
|
||||||
|
std::ostringstream kernel_source;
|
||||||
|
kernel_source << metal::utils() << metal::gemm()
|
||||||
|
<< metal::steel_gemm_splitk()
|
||||||
|
<< fmt::format(
|
||||||
|
steel_gemm_splitk_kernels,
|
||||||
|
"name"_a = lib_name,
|
||||||
|
"itype"_a = get_type_string(in.dtype()),
|
||||||
|
"otype"_a = get_type_string(out.dtype()),
|
||||||
|
"bm"_a = bm,
|
||||||
|
"bn"_a = bn,
|
||||||
|
"bk"_a = bk,
|
||||||
|
"wm"_a = wm,
|
||||||
|
"wn"_a = wn,
|
||||||
|
"trans_a"_a = transpose_a,
|
||||||
|
"trans_b"_a = transpose_b,
|
||||||
|
"mn_aligned"_a = mn_aligned,
|
||||||
|
"k_aligned"_a = k_aligned);
|
||||||
|
lib = d.get_library(lib_name, kernel_source.str());
|
||||||
|
}
|
||||||
|
return d.get_kernel(kernel_name, lib);
|
||||||
|
}
|
||||||
|
|
||||||
|
MTL::ComputePipelineState* get_steel_gemm_splitk_accum_kernel(
|
||||||
|
metal::Device& d,
|
||||||
|
const std::string& kernel_name,
|
||||||
|
const array& in,
|
||||||
|
const array& out,
|
||||||
|
bool axbpy) {
|
||||||
|
const auto& lib_name = kernel_name;
|
||||||
|
auto lib = d.get_library(lib_name);
|
||||||
|
if (lib == nullptr) {
|
||||||
|
std::ostringstream kernel_source;
|
||||||
|
kernel_source << metal::utils() << metal::gemm()
|
||||||
|
<< metal::steel_gemm_splitk()
|
||||||
|
<< fmt::format(
|
||||||
|
axbpy ? steel_gemm_splitk_accum_axbpy_kernels
|
||||||
|
: steel_gemm_splitk_accum_kernels,
|
||||||
|
"name"_a = lib_name,
|
||||||
|
"atype"_a = get_type_string(in.dtype()),
|
||||||
|
"otype"_a = get_type_string(out.dtype()));
|
||||||
|
lib = d.get_library(lib_name, kernel_source.str());
|
||||||
|
}
|
||||||
|
return d.get_kernel(kernel_name, lib);
|
||||||
|
}
|
||||||
|
|
||||||
|
MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
|
||||||
|
metal::Device& d,
|
||||||
|
const std::string& kernel_name,
|
||||||
|
const array& out,
|
||||||
|
const std::optional<array>& mask_out,
|
||||||
|
const std::optional<array>& mask_op,
|
||||||
|
bool transpose_a,
|
||||||
|
bool transpose_b,
|
||||||
|
int bm,
|
||||||
|
int bn,
|
||||||
|
int bk,
|
||||||
|
int wm,
|
||||||
|
int wn,
|
||||||
|
bool mn_aligned,
|
||||||
|
bool k_aligned) {
|
||||||
|
const auto& lib_name = kernel_name;
|
||||||
|
auto lib = d.get_library(lib_name);
|
||||||
|
if (lib == nullptr) {
|
||||||
|
std::ostringstream kernel_source;
|
||||||
|
auto out_mask_type = mask_out.has_value()
|
||||||
|
? get_type_string((*mask_out).dtype())
|
||||||
|
: "nomask_t";
|
||||||
|
auto op_mask_type =
|
||||||
|
mask_op.has_value() ? get_type_string((*mask_op).dtype()) : "nomask_t";
|
||||||
|
kernel_source << metal::utils() << metal::gemm()
|
||||||
|
<< metal::steel_gemm_masked()
|
||||||
|
<< fmt::format(
|
||||||
|
steel_gemm_masked_kernels,
|
||||||
|
"name"_a = lib_name,
|
||||||
|
"itype"_a = get_type_string(out.dtype()),
|
||||||
|
"outmasktype"_a = out_mask_type,
|
||||||
|
"opmasktype"_a = op_mask_type,
|
||||||
|
"bm"_a = bm,
|
||||||
|
"bn"_a = bn,
|
||||||
|
"bk"_a = bk,
|
||||||
|
"wm"_a = wm,
|
||||||
|
"wn"_a = wn,
|
||||||
|
"trans_a"_a = transpose_a,
|
||||||
|
"trans_b"_a = transpose_b,
|
||||||
|
"mn_aligned"_a = mn_aligned,
|
||||||
|
"k_aligned"_a = k_aligned);
|
||||||
|
lib = d.get_library(lib_name, kernel_source.str());
|
||||||
|
}
|
||||||
|
return d.get_kernel(kernel_name, lib);
|
||||||
|
}
|
||||||
|
|
||||||
|
MTL::ComputePipelineState* get_steel_conv_kernel(
|
||||||
|
metal::Device& d,
|
||||||
|
const std::string& kernel_name,
|
||||||
|
const array& out,
|
||||||
|
int bm,
|
||||||
|
int bn,
|
||||||
|
int bk,
|
||||||
|
int wm,
|
||||||
|
int wn,
|
||||||
|
int n_channel_specialization,
|
||||||
|
bool small_filter) {
|
||||||
|
const auto& lib_name = kernel_name;
|
||||||
|
auto lib = d.get_library(lib_name);
|
||||||
|
if (lib == nullptr) {
|
||||||
|
std::ostringstream kernel_source;
|
||||||
|
kernel_source << metal::utils() << metal::conv() << metal::steel_conv()
|
||||||
|
<< fmt::format(
|
||||||
|
steel_conv_kernels,
|
||||||
|
"name"_a = lib_name,
|
||||||
|
"itype"_a = get_type_string(out.dtype()),
|
||||||
|
"bm"_a = bm,
|
||||||
|
"bn"_a = bn,
|
||||||
|
"bk"_a = bk,
|
||||||
|
"wm"_a = wm,
|
||||||
|
"wn"_a = wn,
|
||||||
|
"n_channels"_a = n_channel_specialization,
|
||||||
|
"small_filter"_a = small_filter);
|
||||||
|
lib = d.get_library(lib_name, kernel_source.str());
|
||||||
|
}
|
||||||
|
return d.get_kernel(kernel_name, lib);
|
||||||
|
}
|
||||||
|
|
||||||
|
MTL::ComputePipelineState* get_steel_conv_general_kernel(
|
||||||
|
metal::Device& d,
|
||||||
|
const std::string& kernel_name,
|
||||||
|
const array& out,
|
||||||
|
int bm,
|
||||||
|
int bn,
|
||||||
|
int bk,
|
||||||
|
int wm,
|
||||||
|
int wn) {
|
||||||
|
const auto& lib_name = kernel_name;
|
||||||
|
auto lib = d.get_library(lib_name);
|
||||||
|
if (lib == nullptr) {
|
||||||
|
std::ostringstream kernel_source;
|
||||||
|
kernel_source << metal::utils() << metal::conv()
|
||||||
|
<< metal::steel_conv_general()
|
||||||
|
<< fmt::format(
|
||||||
|
steel_conv_general_kernels,
|
||||||
|
"name"_a = lib_name,
|
||||||
|
"itype"_a = get_type_string(out.dtype()),
|
||||||
|
"bm"_a = bm,
|
||||||
|
"bn"_a = bn,
|
||||||
|
"bk"_a = bk,
|
||||||
|
"wm"_a = wm,
|
||||||
|
"wn"_a = wn);
|
||||||
|
lib = d.get_library(lib_name, kernel_source.str());
|
||||||
|
}
|
||||||
|
return d.get_kernel(kernel_name, lib);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -79,4 +79,78 @@ MTL::ComputePipelineState* get_reduce_kernel(
|
|||||||
const array& in,
|
const array& in,
|
||||||
const array& out);
|
const array& out);
|
||||||
|
|
||||||
|
MTL::ComputePipelineState* get_steel_gemm_fused_kernel(
|
||||||
|
metal::Device& d,
|
||||||
|
const std::string& kernel_name,
|
||||||
|
const std::string& hash_name,
|
||||||
|
const metal::MTLFCList& func_consts,
|
||||||
|
const array& out,
|
||||||
|
bool transpose_a,
|
||||||
|
bool transpose_b,
|
||||||
|
int bm,
|
||||||
|
int bn,
|
||||||
|
int bk,
|
||||||
|
int wm,
|
||||||
|
int wn);
|
||||||
|
|
||||||
|
MTL::ComputePipelineState* get_steel_gemm_splitk_kernel(
|
||||||
|
metal::Device& d,
|
||||||
|
const std::string& kernel_name,
|
||||||
|
const array& in,
|
||||||
|
const array& out,
|
||||||
|
bool transpose_a,
|
||||||
|
bool transpose_b,
|
||||||
|
int bm,
|
||||||
|
int bn,
|
||||||
|
int bk,
|
||||||
|
int wm,
|
||||||
|
int wn,
|
||||||
|
bool mn_aligned,
|
||||||
|
bool k_aligned);
|
||||||
|
|
||||||
|
MTL::ComputePipelineState* get_steel_gemm_splitk_accum_kernel(
|
||||||
|
metal::Device& d,
|
||||||
|
const std::string& kernel_name,
|
||||||
|
const array& in,
|
||||||
|
const array& out,
|
||||||
|
bool axbpy);
|
||||||
|
|
||||||
|
MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
|
||||||
|
metal::Device& d,
|
||||||
|
const std::string& kernel_name,
|
||||||
|
const array& out,
|
||||||
|
const std::optional<array>& mask_out,
|
||||||
|
const std::optional<array>& mask_op,
|
||||||
|
bool transpose_a,
|
||||||
|
bool transpose_b,
|
||||||
|
int bm,
|
||||||
|
int bn,
|
||||||
|
int bk,
|
||||||
|
int wm,
|
||||||
|
int wn,
|
||||||
|
bool mn_aligned,
|
||||||
|
bool k_aligned);
|
||||||
|
|
||||||
|
MTL::ComputePipelineState* get_steel_conv_kernel(
|
||||||
|
metal::Device& d,
|
||||||
|
const std::string& kernel_name,
|
||||||
|
const array& out,
|
||||||
|
int bm,
|
||||||
|
int bn,
|
||||||
|
int bk,
|
||||||
|
int wm,
|
||||||
|
int wn,
|
||||||
|
int n_channel_specialization,
|
||||||
|
bool small_filter);
|
||||||
|
|
||||||
|
MTL::ComputePipelineState* get_steel_conv_general_kernel(
|
||||||
|
metal::Device& d,
|
||||||
|
const std::string& kernel_name,
|
||||||
|
const array& out,
|
||||||
|
int bm,
|
||||||
|
int bn,
|
||||||
|
int bk,
|
||||||
|
int wm,
|
||||||
|
int wn);
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -1,14 +1,13 @@
|
|||||||
set(
|
set(
|
||||||
HEADERS
|
HEADERS
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/atomic.h
|
bf16.h
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/bf16.h
|
bf16_math.h
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/bf16_math.h
|
complex.h
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/complex.h
|
defines.h
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/defines.h
|
utils.h
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/utils.h
|
steel/conv/params.h
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
set(
|
set(
|
||||||
KERNELS
|
KERNELS
|
||||||
"arg_reduce"
|
"arg_reduce"
|
||||||
@ -41,6 +40,7 @@ set(
|
|||||||
set(
|
set(
|
||||||
HEADERS
|
HEADERS
|
||||||
${HEADERS}
|
${HEADERS}
|
||||||
|
atomic.h
|
||||||
arange.h
|
arange.h
|
||||||
unary_ops.h
|
unary_ops.h
|
||||||
unary.h
|
unary.h
|
||||||
@ -89,14 +89,40 @@ foreach(KERNEL ${KERNELS})
|
|||||||
set(KERNEL_AIR ${KERNEL}.air ${KERNEL_AIR})
|
set(KERNEL_AIR ${KERNEL}.air ${KERNEL_AIR})
|
||||||
endforeach()
|
endforeach()
|
||||||
|
|
||||||
file(GLOB_RECURSE STEEL_KERNELS ${CMAKE_CURRENT_SOURCE_DIR}/steel/*.metal)
|
if (NOT MLX_METAL_JIT)
|
||||||
file(GLOB_RECURSE STEEL_HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/steel/*.h)
|
set(
|
||||||
|
STEEL_KERNELS
|
||||||
foreach(KERNEL ${STEEL_KERNELS})
|
${CMAKE_CURRENT_SOURCE_DIR}/steel/conv/kernels/steel_conv.metal
|
||||||
cmake_path(GET KERNEL STEM TARGET)
|
${CMAKE_CURRENT_SOURCE_DIR}/steel/conv/kernels/steel_conv_general.metal
|
||||||
build_kernel_base(${TARGET} ${KERNEL} "${STEEL_HEADERS}")
|
${CMAKE_CURRENT_SOURCE_DIR}/steel/gemm/kernels/steel_gemm_fused.metal
|
||||||
set(KERNEL_AIR ${TARGET}.air ${KERNEL_AIR})
|
${CMAKE_CURRENT_SOURCE_DIR}/steel/gemm/kernels/steel_gemm_masked.metal
|
||||||
endforeach()
|
${CMAKE_CURRENT_SOURCE_DIR}/steel/gemm/kernels/steel_gemm_splitk.metal
|
||||||
|
)
|
||||||
|
set(
|
||||||
|
STEEL_HEADERS
|
||||||
|
steel/defines.h
|
||||||
|
steel/utils.h
|
||||||
|
steel/conv/conv.h
|
||||||
|
steel/conv/loader.h
|
||||||
|
steel/conv/loaders/loader_channel_l.h
|
||||||
|
steel/conv/loaders/loader_channel_n.h
|
||||||
|
steel/conv/loaders/loader_general.h
|
||||||
|
steel/conv/kernels/steel_conv.h
|
||||||
|
steel/conv/kernels/steel_conv_general.h
|
||||||
|
steel/gemm/gemm.h
|
||||||
|
steel/gemm/mma.h
|
||||||
|
steel/gemm/loader.h
|
||||||
|
steel/gemm/transforms.h
|
||||||
|
steel/gemm/kernels/steel_gemm_fused.h
|
||||||
|
steel/gemm/kernels/steel_gemm_masked.h
|
||||||
|
steel/gemm/kernels/steel_gemm_splitk.h
|
||||||
|
)
|
||||||
|
foreach(KERNEL ${STEEL_KERNELS})
|
||||||
|
cmake_path(GET KERNEL STEM TARGET)
|
||||||
|
build_kernel_base(${TARGET} ${KERNEL} "${STEEL_HEADERS}")
|
||||||
|
set(KERNEL_AIR ${TARGET}.air ${KERNEL_AIR})
|
||||||
|
endforeach()
|
||||||
|
endif()
|
||||||
|
|
||||||
add_custom_command(
|
add_custom_command(
|
||||||
OUTPUT ${MLX_METAL_PATH}/mlx.metallib
|
OUTPUT ${MLX_METAL_PATH}/mlx.metallib
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
// Copyright © 2023 Apple Inc.
|
// Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
#include <metal_atomic>
|
|
||||||
#include <metal_simdgroup>
|
#include <metal_simdgroup>
|
||||||
|
|
||||||
#include "mlx/backend/metal/kernels/utils.h"
|
#include "mlx/backend/metal/kernels/utils.h"
|
||||||
@ -194,4 +193,4 @@ instantiate_arg_reduce(int32, int32_t)
|
|||||||
instantiate_arg_reduce(int64, int64_t)
|
instantiate_arg_reduce(int64, int64_t)
|
||||||
instantiate_arg_reduce(float16, half)
|
instantiate_arg_reduce(float16, half)
|
||||||
instantiate_arg_reduce(float32, float)
|
instantiate_arg_reduce(float32, float)
|
||||||
instantiate_arg_reduce(bfloat16, bfloat16_t) // clang-format on
|
instantiate_arg_reduce(bfloat16, bfloat16_t) // clang-format on
|
||||||
|
@ -650,4 +650,4 @@ winograd_conv_2d_output_transform(
|
|||||||
|
|
||||||
// clang-format off
|
// clang-format off
|
||||||
instantiate_winograd_conv_2d(float32, float);
|
instantiate_winograd_conv_2d(float32, float);
|
||||||
instantiate_winograd_conv_2d(float16, half); // clang-format on
|
instantiate_winograd_conv_2d(float16, half); // clang-format on
|
||||||
|
@ -2,10 +2,12 @@
|
|||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlx/backend/metal/kernels/steel/defines.h"
|
||||||
#include "mlx/backend/metal/kernels/steel/utils.h"
|
#include "mlx/backend/metal/kernels/steel/utils.h"
|
||||||
|
|
||||||
#include "mlx/backend/metal/kernels/steel/conv/loader.h"
|
#include "mlx/backend/metal/kernels/steel/conv/loader.h"
|
||||||
#include "mlx/backend/metal/kernels/steel/conv/params.h"
|
#include "mlx/backend/metal/kernels/steel/conv/params.h"
|
||||||
|
#include "mlx/backend/metal/kernels/steel/gemm/mma.h"
|
||||||
|
|
||||||
using namespace metal;
|
using namespace metal;
|
||||||
using namespace mlx::steel;
|
using namespace mlx::steel;
|
||||||
|
176
mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h
Normal file
176
mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h
Normal file
@ -0,0 +1,176 @@
|
|||||||
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
#include <metal_stdlib>
|
||||||
|
|
||||||
|
using namespace metal;
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename T,
|
||||||
|
int BM,
|
||||||
|
int BN,
|
||||||
|
int BK,
|
||||||
|
int WM,
|
||||||
|
int WN,
|
||||||
|
int N_CHANNELS = 0,
|
||||||
|
bool SMALL_FILTER = false>
|
||||||
|
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void
|
||||||
|
implicit_gemm_conv_2d(
|
||||||
|
const device T* A [[buffer(0)]],
|
||||||
|
const device T* B [[buffer(1)]],
|
||||||
|
device T* C [[buffer(2)]],
|
||||||
|
const constant MLXConvParams<2>* params [[buffer(3)]],
|
||||||
|
const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]],
|
||||||
|
uint3 tid [[threadgroup_position_in_grid]],
|
||||||
|
uint3 lid [[thread_position_in_threadgroup]],
|
||||||
|
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||||
|
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||||
|
using namespace mlx::steel;
|
||||||
|
|
||||||
|
(void)lid;
|
||||||
|
|
||||||
|
constexpr bool transpose_a = false;
|
||||||
|
constexpr bool transpose_b = true;
|
||||||
|
constexpr short tgp_padding_a = 16 / sizeof(T);
|
||||||
|
constexpr short tgp_padding_b = 16 / sizeof(T);
|
||||||
|
|
||||||
|
constexpr short shape_a_cols = (transpose_a ? BM : BK) + tgp_padding_a;
|
||||||
|
constexpr short shape_b_cols = (transpose_b ? BK : BN) + tgp_padding_b;
|
||||||
|
constexpr short shape_a_rows = (transpose_a ? BK : BM);
|
||||||
|
constexpr short shape_b_rows = (transpose_b ? BN : BK);
|
||||||
|
constexpr short tgp_mem_size_a = shape_a_cols * shape_a_rows;
|
||||||
|
constexpr short tgp_mem_size_b = shape_b_cols * shape_b_rows;
|
||||||
|
|
||||||
|
constexpr short tgp_size = WM * WN * 32;
|
||||||
|
|
||||||
|
// Input loader
|
||||||
|
|
||||||
|
using loader_a_t = typename metal::conditional_t<
|
||||||
|
// Check for small channel specialization
|
||||||
|
N_CHANNELS != 0 && N_CHANNELS <= 4,
|
||||||
|
|
||||||
|
// Go to small channel specialization
|
||||||
|
Conv2DInputBlockLoaderSmallChannels<
|
||||||
|
T,
|
||||||
|
BM,
|
||||||
|
BN,
|
||||||
|
BK,
|
||||||
|
tgp_size,
|
||||||
|
N_CHANNELS,
|
||||||
|
tgp_padding_a>,
|
||||||
|
|
||||||
|
// Else go to general loader
|
||||||
|
typename metal::conditional_t<
|
||||||
|
// Check if filter size is small enough
|
||||||
|
SMALL_FILTER,
|
||||||
|
|
||||||
|
// Go to small filter specialization
|
||||||
|
Conv2DInputBlockLoaderSmallFilter<
|
||||||
|
T,
|
||||||
|
BM,
|
||||||
|
BN,
|
||||||
|
BK,
|
||||||
|
tgp_size,
|
||||||
|
tgp_padding_a>,
|
||||||
|
|
||||||
|
// Else go to large filter generalization
|
||||||
|
Conv2DInputBlockLoaderLargeFilter<
|
||||||
|
T,
|
||||||
|
BM,
|
||||||
|
BN,
|
||||||
|
BK,
|
||||||
|
tgp_size,
|
||||||
|
tgp_padding_a>>>;
|
||||||
|
|
||||||
|
// Weight loader
|
||||||
|
using loader_b_t = typename metal::conditional_t<
|
||||||
|
// Check for small channel specialization
|
||||||
|
N_CHANNELS != 0 && N_CHANNELS <= 4,
|
||||||
|
|
||||||
|
// Go to small channel specialization
|
||||||
|
Conv2DWeightBlockLoaderSmallChannels<
|
||||||
|
T,
|
||||||
|
BM,
|
||||||
|
BN,
|
||||||
|
BK,
|
||||||
|
tgp_size,
|
||||||
|
N_CHANNELS,
|
||||||
|
tgp_padding_b>,
|
||||||
|
|
||||||
|
// Else go to general loader
|
||||||
|
Conv2DWeightBlockLoader<T, BM, BN, BK, tgp_size, tgp_padding_b>>;
|
||||||
|
|
||||||
|
using mma_t = BlockMMA<
|
||||||
|
T,
|
||||||
|
T,
|
||||||
|
BM,
|
||||||
|
BN,
|
||||||
|
BK,
|
||||||
|
WM,
|
||||||
|
WN,
|
||||||
|
transpose_a,
|
||||||
|
transpose_b,
|
||||||
|
shape_a_cols,
|
||||||
|
shape_b_cols>;
|
||||||
|
|
||||||
|
threadgroup T As[tgp_mem_size_a];
|
||||||
|
threadgroup T Bs[tgp_mem_size_b];
|
||||||
|
|
||||||
|
const int tid_y = ((tid.y) << gemm_params->swizzle_log) +
|
||||||
|
((tid.x) & ((1 << gemm_params->swizzle_log) - 1));
|
||||||
|
const int tid_x = (tid.x) >> gemm_params->swizzle_log;
|
||||||
|
|
||||||
|
if (gemm_params->tiles_n <= tid_x || gemm_params->tiles_m <= tid_y) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const int c_row = tid_y * BM;
|
||||||
|
const int c_col = tid_x * BN;
|
||||||
|
const int K = gemm_params->K;
|
||||||
|
const int N = gemm_params->N;
|
||||||
|
const int C_per_group = params->C / params->groups;
|
||||||
|
|
||||||
|
// Groups
|
||||||
|
A += tid.z * C_per_group;
|
||||||
|
B += tid.z * N * K;
|
||||||
|
C += tid.z * N;
|
||||||
|
|
||||||
|
B += c_col * K;
|
||||||
|
C += c_row * (N * params->groups) + c_col;
|
||||||
|
|
||||||
|
const int2 offsets_a(0, c_row);
|
||||||
|
const int2 offsets_b(0, c_col);
|
||||||
|
|
||||||
|
// Prepare threadgroup loading operations
|
||||||
|
loader_a_t loader_a(
|
||||||
|
A, As, offsets_a, params, gemm_params, simd_gid, simd_lid);
|
||||||
|
loader_b_t loader_b(
|
||||||
|
B, Bs, offsets_b, params, gemm_params, simd_gid, simd_lid);
|
||||||
|
|
||||||
|
// Prepare threadgroup mma operation
|
||||||
|
mma_t mma_op(simd_gid, simd_lid);
|
||||||
|
|
||||||
|
int gemm_k_iterations = gemm_params->gemm_k_iterations;
|
||||||
|
for (int k = 0; k < gemm_k_iterations; k++) {
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
// Load elements into threadgroup
|
||||||
|
loader_a.load_unsafe();
|
||||||
|
loader_b.load_unsafe();
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
// Multiply and accumulate threadgroup elements
|
||||||
|
mma_op.mma(As, Bs);
|
||||||
|
|
||||||
|
// Prepare for next iteration
|
||||||
|
loader_a.next();
|
||||||
|
loader_b.next();
|
||||||
|
}
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_none);
|
||||||
|
|
||||||
|
// Store results to device memory
|
||||||
|
short tgp_bm = min(BM, gemm_params->M - c_row);
|
||||||
|
short tgp_bn = min(BN, gemm_params->N - c_col);
|
||||||
|
const int ldc = N * params->groups;
|
||||||
|
mma_op.store_result_safe(C, ldc, short2(tgp_bn, tgp_bm));
|
||||||
|
}
|
@ -2,184 +2,13 @@
|
|||||||
|
|
||||||
#include <metal_stdlib>
|
#include <metal_stdlib>
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
#include "mlx/backend/metal/kernels/steel/gemm/mma.h"
|
#include "mlx/backend/metal/kernels/steel/gemm/mma.h"
|
||||||
|
|
||||||
#include "mlx/backend/metal/kernels/bf16.h"
|
#include "mlx/backend/metal/kernels/bf16.h"
|
||||||
#include "mlx/backend/metal/kernels/steel/conv/conv.h"
|
#include "mlx/backend/metal/kernels/steel/conv/conv.h"
|
||||||
#include "mlx/backend/metal/kernels/steel/conv/params.h"
|
#include "mlx/backend/metal/kernels/steel/conv/params.h"
|
||||||
|
#include "mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h"
|
||||||
using namespace metal;
|
|
||||||
|
|
||||||
template <
|
|
||||||
typename T,
|
|
||||||
int BM,
|
|
||||||
int BN,
|
|
||||||
int BK,
|
|
||||||
int WM,
|
|
||||||
int WN,
|
|
||||||
int N_CHANNELS = 0,
|
|
||||||
bool SMALL_FILTER = false>
|
|
||||||
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void
|
|
||||||
implicit_gemm_conv_2d(
|
|
||||||
const device T* A [[buffer(0)]],
|
|
||||||
const device T* B [[buffer(1)]],
|
|
||||||
device T* C [[buffer(2)]],
|
|
||||||
const constant MLXConvParams<2>* params [[buffer(3)]],
|
|
||||||
const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]],
|
|
||||||
uint3 tid [[threadgroup_position_in_grid]],
|
|
||||||
uint3 lid [[thread_position_in_threadgroup]],
|
|
||||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
|
||||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
|
||||||
using namespace mlx::steel;
|
|
||||||
|
|
||||||
(void)lid;
|
|
||||||
|
|
||||||
constexpr bool transpose_a = false;
|
|
||||||
constexpr bool transpose_b = true;
|
|
||||||
constexpr short tgp_padding_a = 16 / sizeof(T);
|
|
||||||
constexpr short tgp_padding_b = 16 / sizeof(T);
|
|
||||||
|
|
||||||
constexpr short shape_a_cols = (transpose_a ? BM : BK) + tgp_padding_a;
|
|
||||||
constexpr short shape_b_cols = (transpose_b ? BK : BN) + tgp_padding_b;
|
|
||||||
constexpr short shape_a_rows = (transpose_a ? BK : BM);
|
|
||||||
constexpr short shape_b_rows = (transpose_b ? BN : BK);
|
|
||||||
constexpr short tgp_mem_size_a = shape_a_cols * shape_a_rows;
|
|
||||||
constexpr short tgp_mem_size_b = shape_b_cols * shape_b_rows;
|
|
||||||
|
|
||||||
constexpr short tgp_size = WM * WN * 32;
|
|
||||||
|
|
||||||
// Input loader
|
|
||||||
|
|
||||||
using loader_a_t = typename metal::conditional_t<
|
|
||||||
// Check for small channel specialization
|
|
||||||
N_CHANNELS != 0 && N_CHANNELS <= 4,
|
|
||||||
|
|
||||||
// Go to small channel specialization
|
|
||||||
Conv2DInputBlockLoaderSmallChannels<
|
|
||||||
T,
|
|
||||||
BM,
|
|
||||||
BN,
|
|
||||||
BK,
|
|
||||||
tgp_size,
|
|
||||||
N_CHANNELS,
|
|
||||||
tgp_padding_a>,
|
|
||||||
|
|
||||||
// Else go to general loader
|
|
||||||
typename metal::conditional_t<
|
|
||||||
// Check if filter size is small enough
|
|
||||||
SMALL_FILTER,
|
|
||||||
|
|
||||||
// Go to small filter specialization
|
|
||||||
Conv2DInputBlockLoaderSmallFilter<
|
|
||||||
T,
|
|
||||||
BM,
|
|
||||||
BN,
|
|
||||||
BK,
|
|
||||||
tgp_size,
|
|
||||||
tgp_padding_a>,
|
|
||||||
|
|
||||||
// Else go to large filter generalization
|
|
||||||
Conv2DInputBlockLoaderLargeFilter<
|
|
||||||
T,
|
|
||||||
BM,
|
|
||||||
BN,
|
|
||||||
BK,
|
|
||||||
tgp_size,
|
|
||||||
tgp_padding_a>>>;
|
|
||||||
|
|
||||||
// Weight loader
|
|
||||||
using loader_b_t = typename metal::conditional_t<
|
|
||||||
// Check for small channel specialization
|
|
||||||
N_CHANNELS != 0 && N_CHANNELS <= 4,
|
|
||||||
|
|
||||||
// Go to small channel specialization
|
|
||||||
Conv2DWeightBlockLoaderSmallChannels<
|
|
||||||
T,
|
|
||||||
BM,
|
|
||||||
BN,
|
|
||||||
BK,
|
|
||||||
tgp_size,
|
|
||||||
N_CHANNELS,
|
|
||||||
tgp_padding_b>,
|
|
||||||
|
|
||||||
// Else go to general loader
|
|
||||||
Conv2DWeightBlockLoader<T, BM, BN, BK, tgp_size, tgp_padding_b>>;
|
|
||||||
|
|
||||||
using mma_t = BlockMMA<
|
|
||||||
T,
|
|
||||||
T,
|
|
||||||
BM,
|
|
||||||
BN,
|
|
||||||
BK,
|
|
||||||
WM,
|
|
||||||
WN,
|
|
||||||
transpose_a,
|
|
||||||
transpose_b,
|
|
||||||
shape_a_cols,
|
|
||||||
shape_b_cols>;
|
|
||||||
|
|
||||||
threadgroup T As[tgp_mem_size_a];
|
|
||||||
threadgroup T Bs[tgp_mem_size_b];
|
|
||||||
|
|
||||||
const int tid_y = ((tid.y) << gemm_params->swizzle_log) +
|
|
||||||
((tid.x) & ((1 << gemm_params->swizzle_log) - 1));
|
|
||||||
const int tid_x = (tid.x) >> gemm_params->swizzle_log;
|
|
||||||
|
|
||||||
if (gemm_params->tiles_n <= tid_x || gemm_params->tiles_m <= tid_y) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const int c_row = tid_y * BM;
|
|
||||||
const int c_col = tid_x * BN;
|
|
||||||
const int K = gemm_params->K;
|
|
||||||
const int N = gemm_params->N;
|
|
||||||
const int C_per_group = params->C / params->groups;
|
|
||||||
|
|
||||||
// Groups
|
|
||||||
A += tid.z * C_per_group;
|
|
||||||
B += tid.z * N * K;
|
|
||||||
C += tid.z * N;
|
|
||||||
|
|
||||||
B += c_col * K;
|
|
||||||
C += c_row * (N * params->groups) + c_col;
|
|
||||||
|
|
||||||
const int2 offsets_a(0, c_row);
|
|
||||||
const int2 offsets_b(0, c_col);
|
|
||||||
|
|
||||||
// Prepare threadgroup loading operations
|
|
||||||
loader_a_t loader_a(
|
|
||||||
A, As, offsets_a, params, gemm_params, simd_gid, simd_lid);
|
|
||||||
loader_b_t loader_b(
|
|
||||||
B, Bs, offsets_b, params, gemm_params, simd_gid, simd_lid);
|
|
||||||
|
|
||||||
// Prepare threadgroup mma operation
|
|
||||||
mma_t mma_op(simd_gid, simd_lid);
|
|
||||||
|
|
||||||
int gemm_k_iterations = gemm_params->gemm_k_iterations;
|
|
||||||
for (int k = 0; k < gemm_k_iterations; k++) {
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
// Load elements into threadgroup
|
|
||||||
loader_a.load_unsafe();
|
|
||||||
loader_b.load_unsafe();
|
|
||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
|
|
||||||
// Multiply and accumulate threadgroup elements
|
|
||||||
mma_op.mma(As, Bs);
|
|
||||||
|
|
||||||
// Prepare for next iteration
|
|
||||||
loader_a.next();
|
|
||||||
loader_b.next();
|
|
||||||
}
|
|
||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_none);
|
|
||||||
|
|
||||||
// Store results to device memory
|
|
||||||
short tgp_bm = min(BM, gemm_params->M - c_row);
|
|
||||||
short tgp_bn = min(BN, gemm_params->N - c_col);
|
|
||||||
const int ldc = N * params->groups;
|
|
||||||
mma_op.store_result_safe(C, ldc, short2(tgp_bn, tgp_bm));
|
|
||||||
}
|
|
||||||
|
|
||||||
#define instantiate_implicit_conv_2d( \
|
#define instantiate_implicit_conv_2d( \
|
||||||
name, \
|
name, \
|
||||||
@ -207,25 +36,22 @@ implicit_gemm_conv_2d(
|
|||||||
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
||||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||||
|
|
||||||
// clang-format off
|
|
||||||
#define instantiate_implicit_2d_filter(name, itype, bm, bn, bk, wm, wn) \
|
#define instantiate_implicit_2d_filter(name, itype, bm, bn, bk, wm, wn) \
|
||||||
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, l, 0, s, true) \
|
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, l, 0, s, true) \
|
||||||
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, l, 0, l, false) \
|
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, l, 0, l, false) \
|
||||||
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 1, 1, l, false) \
|
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 1, 1, l, false) \
|
||||||
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 2, 2, l, false) \
|
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 2, 2, l, false) \
|
||||||
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 3, 3, l, false) \
|
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 3, 3, l, false) \
|
||||||
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 4, 4, l, false) // clang-format on
|
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 4, 4, l, false)
|
||||||
|
|
||||||
// clang-format off
|
|
||||||
#define instantiate_implicit_2d_blocks(name, itype) \
|
#define instantiate_implicit_2d_blocks(name, itype) \
|
||||||
instantiate_implicit_2d_filter(name, itype, 32, 8, 16, 4, 1) \
|
instantiate_implicit_2d_filter(name, itype, 32, 8, 16, 4, 1) \
|
||||||
instantiate_implicit_2d_filter(name, itype, 64, 8, 16, 4, 1) \
|
instantiate_implicit_2d_filter(name, itype, 64, 8, 16, 4, 1) \
|
||||||
instantiate_implicit_2d_filter(name, itype, 32, 32, 16, 2, 2) \
|
instantiate_implicit_2d_filter(name, itype, 32, 32, 16, 2, 2) \
|
||||||
instantiate_implicit_2d_filter(name, itype, 32, 64, 16, 2, 2) \
|
instantiate_implicit_2d_filter(name, itype, 32, 64, 16, 2, 2) \
|
||||||
instantiate_implicit_2d_filter(name, itype, 64, 32, 16, 2, 2) \
|
instantiate_implicit_2d_filter(name, itype, 64, 32, 16, 2, 2) \
|
||||||
instantiate_implicit_2d_filter(name, itype, 64, 64, 16, 2, 2) // clang-format on
|
instantiate_implicit_2d_filter(name, itype, 64, 64, 16, 2, 2)
|
||||||
|
|
||||||
// clang-format off
|
|
||||||
instantiate_implicit_2d_blocks(float32, float);
|
instantiate_implicit_2d_blocks(float32, float);
|
||||||
instantiate_implicit_2d_blocks(float16, half);
|
instantiate_implicit_2d_blocks(float16, half);
|
||||||
instantiate_implicit_2d_blocks(bfloat16, bfloat16_t); // clang-format on
|
instantiate_implicit_2d_blocks(bfloat16, bfloat16_t); // clang-format on
|
||||||
|
@ -0,0 +1,188 @@
|
|||||||
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h"
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename T,
|
||||||
|
int BM,
|
||||||
|
int BN,
|
||||||
|
int BK,
|
||||||
|
int WM,
|
||||||
|
int WN,
|
||||||
|
typename AccumType = float,
|
||||||
|
typename Epilogue = TransformNone<T, AccumType>>
|
||||||
|
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void
|
||||||
|
implicit_gemm_conv_2d_general(
|
||||||
|
const device T* A [[buffer(0)]],
|
||||||
|
const device T* B [[buffer(1)]],
|
||||||
|
device T* C [[buffer(2)]],
|
||||||
|
const constant MLXConvParams<2>* params [[buffer(3)]],
|
||||||
|
const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]],
|
||||||
|
const constant Conv2DGeneralJumpParams* jump_params [[buffer(5)]],
|
||||||
|
const constant Conv2DGeneralBaseInfo* base_h [[buffer(6)]],
|
||||||
|
const constant Conv2DGeneralBaseInfo* base_w [[buffer(7)]],
|
||||||
|
uint3 tid [[threadgroup_position_in_grid]],
|
||||||
|
uint3 lid [[thread_position_in_threadgroup]],
|
||||||
|
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||||
|
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||||
|
(void)lid;
|
||||||
|
|
||||||
|
constexpr bool transpose_a = false;
|
||||||
|
constexpr bool transpose_b = true;
|
||||||
|
constexpr short tgp_padding_a = 16 / sizeof(T);
|
||||||
|
constexpr short tgp_padding_b = 16 / sizeof(T);
|
||||||
|
|
||||||
|
constexpr short shape_a_cols = (transpose_a ? BM : BK) + tgp_padding_a;
|
||||||
|
constexpr short shape_b_cols = (transpose_b ? BK : BN) + tgp_padding_b;
|
||||||
|
constexpr short shape_a_rows = (transpose_a ? BK : BM);
|
||||||
|
constexpr short shape_b_rows = (transpose_b ? BN : BK);
|
||||||
|
constexpr short tgp_mem_size_a = shape_a_cols * shape_a_rows;
|
||||||
|
constexpr short tgp_mem_size_b = shape_b_cols * shape_b_rows;
|
||||||
|
|
||||||
|
constexpr short tgp_size = WM * WN * 32;
|
||||||
|
|
||||||
|
// Input loader
|
||||||
|
using loader_a_t =
|
||||||
|
Conv2DInputBlockLoaderGeneral<T, BM, BN, BK, tgp_size, tgp_padding_a>;
|
||||||
|
|
||||||
|
// Weight loader
|
||||||
|
using loader_b_t =
|
||||||
|
Conv2DWeightBlockLoaderGeneral<T, BM, BN, BK, tgp_size, tgp_padding_b>;
|
||||||
|
|
||||||
|
using mma_t = BlockMMA<
|
||||||
|
T,
|
||||||
|
T,
|
||||||
|
BM,
|
||||||
|
BN,
|
||||||
|
BK,
|
||||||
|
WM,
|
||||||
|
WN,
|
||||||
|
transpose_a,
|
||||||
|
transpose_b,
|
||||||
|
shape_a_cols,
|
||||||
|
shape_b_cols>;
|
||||||
|
|
||||||
|
threadgroup T As[tgp_mem_size_a];
|
||||||
|
threadgroup T Bs[tgp_mem_size_b];
|
||||||
|
|
||||||
|
const int tid_y = ((tid.y) << gemm_params->swizzle_log) +
|
||||||
|
((tid.x) & ((1 << gemm_params->swizzle_log) - 1));
|
||||||
|
const int tid_x = (tid.x) >> gemm_params->swizzle_log;
|
||||||
|
|
||||||
|
if (gemm_params->tiles_n <= tid_x || gemm_params->tiles_m <= tid_y) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const int tid_z = tid.z;
|
||||||
|
|
||||||
|
const int base_oh = tid_z / jump_params->f_out_jump_w;
|
||||||
|
const int base_ow = tid_z % jump_params->f_out_jump_w;
|
||||||
|
|
||||||
|
const int base_wh = base_h[base_oh].weight_base;
|
||||||
|
const int base_ww = base_w[base_ow].weight_base;
|
||||||
|
|
||||||
|
const int base_wh_size = base_h[base_oh].weight_size;
|
||||||
|
const int base_ww_size = base_w[base_ow].weight_size;
|
||||||
|
|
||||||
|
const int c_row = tid_y * BM;
|
||||||
|
const int c_col = tid_x * BN;
|
||||||
|
const int K = gemm_params->K;
|
||||||
|
|
||||||
|
B += c_col * K;
|
||||||
|
|
||||||
|
const int4 offsets_a(0, c_row, base_oh, base_ow);
|
||||||
|
const int2 offsets_b(0, c_col);
|
||||||
|
|
||||||
|
// Prepare threadgroup loading operations
|
||||||
|
loader_a_t loader_a(
|
||||||
|
A,
|
||||||
|
As,
|
||||||
|
offsets_a,
|
||||||
|
params,
|
||||||
|
jump_params,
|
||||||
|
base_wh,
|
||||||
|
base_ww,
|
||||||
|
simd_gid,
|
||||||
|
simd_lid);
|
||||||
|
loader_b_t loader_b(
|
||||||
|
B,
|
||||||
|
Bs,
|
||||||
|
offsets_b,
|
||||||
|
params,
|
||||||
|
jump_params,
|
||||||
|
base_wh,
|
||||||
|
base_ww,
|
||||||
|
simd_gid,
|
||||||
|
simd_lid);
|
||||||
|
|
||||||
|
// Prepare threadgroup mma operation
|
||||||
|
mma_t mma_op(simd_gid, simd_lid);
|
||||||
|
|
||||||
|
int gemm_k_iterations =
|
||||||
|
base_wh_size * base_ww_size * gemm_params->gemm_k_iterations;
|
||||||
|
|
||||||
|
for (int k = 0; k < gemm_k_iterations; k++) {
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
// Load elements into threadgroup
|
||||||
|
loader_a.load_unsafe();
|
||||||
|
loader_b.load_unsafe();
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
// Multiply and accumulate threadgroup elements
|
||||||
|
mma_op.mma(As, Bs);
|
||||||
|
|
||||||
|
// Prepare for next iteration
|
||||||
|
loader_a.next();
|
||||||
|
loader_b.next();
|
||||||
|
}
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_none);
|
||||||
|
|
||||||
|
// Store results to device memory
|
||||||
|
{
|
||||||
|
// Adjust for simdgroup and thread locatio
|
||||||
|
int offset_m = c_row + mma_op.sm + mma_op.tm;
|
||||||
|
int offset_n = c_col + mma_op.sn + mma_op.tn;
|
||||||
|
C += offset_n;
|
||||||
|
|
||||||
|
if (offset_n >= gemm_params->N)
|
||||||
|
return;
|
||||||
|
|
||||||
|
short diff = gemm_params->N - offset_n;
|
||||||
|
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (int i = 0; i < mma_t::TM; i++) {
|
||||||
|
int cm = offset_m + i * mma_t::TM_stride;
|
||||||
|
|
||||||
|
int n = cm / jump_params->adj_out_hw;
|
||||||
|
int hw = cm % jump_params->adj_out_hw;
|
||||||
|
int oh =
|
||||||
|
(hw / jump_params->adj_out_w) * jump_params->f_out_jump_h + base_oh;
|
||||||
|
int ow =
|
||||||
|
(hw % jump_params->adj_out_w) * jump_params->f_out_jump_w + base_ow;
|
||||||
|
|
||||||
|
if (n < params->N && oh < params->oS[0] && ow < params->oS[1]) {
|
||||||
|
int offset_cm = n * params->out_strides[0] +
|
||||||
|
oh * params->out_strides[1] + ow * params->out_strides[2];
|
||||||
|
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (int j = 0; j < mma_t::TN; j++) {
|
||||||
|
// Get accumulated result and associated offset in C
|
||||||
|
thread const auto& accum =
|
||||||
|
mma_op.results[i * mma_t::TN + j].thread_elements();
|
||||||
|
int offset = offset_cm + (j * mma_t::TN_stride);
|
||||||
|
|
||||||
|
// Apply epilogue and output C
|
||||||
|
if (j * mma_t::TN_stride < diff) {
|
||||||
|
C[offset] = Epilogue::apply(accum[0]);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (j * mma_t::TN_stride + 1 < diff) {
|
||||||
|
C[offset + 1] = Epilogue::apply(accum[1]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -2,201 +2,18 @@
|
|||||||
|
|
||||||
#include <metal_stdlib>
|
#include <metal_stdlib>
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
#include "mlx/backend/metal/kernels/steel/gemm/mma.h"
|
#include "mlx/backend/metal/kernels/steel/gemm/mma.h"
|
||||||
|
|
||||||
#include "mlx/backend/metal/kernels/bf16.h"
|
#include "mlx/backend/metal/kernels/bf16.h"
|
||||||
#include "mlx/backend/metal/kernels/steel/conv/conv.h"
|
#include "mlx/backend/metal/kernels/steel/conv/conv.h"
|
||||||
#include "mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h"
|
|
||||||
#include "mlx/backend/metal/kernels/steel/conv/params.h"
|
#include "mlx/backend/metal/kernels/steel/conv/params.h"
|
||||||
|
#include "mlx/backend/metal/kernels/steel/utils.h"
|
||||||
|
#include "mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h"
|
||||||
|
|
||||||
using namespace metal;
|
using namespace metal;
|
||||||
using namespace mlx::steel;
|
using namespace mlx::steel;
|
||||||
|
|
||||||
template <
|
|
||||||
typename T,
|
|
||||||
int BM,
|
|
||||||
int BN,
|
|
||||||
int BK,
|
|
||||||
int WM,
|
|
||||||
int WN,
|
|
||||||
typename AccumType = float,
|
|
||||||
typename Epilogue = TransformNone<T, AccumType>>
|
|
||||||
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void
|
|
||||||
implicit_gemm_conv_2d_general(
|
|
||||||
const device T* A [[buffer(0)]],
|
|
||||||
const device T* B [[buffer(1)]],
|
|
||||||
device T* C [[buffer(2)]],
|
|
||||||
const constant MLXConvParams<2>* params [[buffer(3)]],
|
|
||||||
const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]],
|
|
||||||
const constant Conv2DGeneralJumpParams* jump_params [[buffer(5)]],
|
|
||||||
const constant Conv2DGeneralBaseInfo* base_h [[buffer(6)]],
|
|
||||||
const constant Conv2DGeneralBaseInfo* base_w [[buffer(7)]],
|
|
||||||
uint3 tid [[threadgroup_position_in_grid]],
|
|
||||||
uint3 lid [[thread_position_in_threadgroup]],
|
|
||||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
|
||||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
|
||||||
(void)lid;
|
|
||||||
|
|
||||||
constexpr bool transpose_a = false;
|
|
||||||
constexpr bool transpose_b = true;
|
|
||||||
constexpr short tgp_padding_a = 16 / sizeof(T);
|
|
||||||
constexpr short tgp_padding_b = 16 / sizeof(T);
|
|
||||||
|
|
||||||
constexpr short shape_a_cols = (transpose_a ? BM : BK) + tgp_padding_a;
|
|
||||||
constexpr short shape_b_cols = (transpose_b ? BK : BN) + tgp_padding_b;
|
|
||||||
constexpr short shape_a_rows = (transpose_a ? BK : BM);
|
|
||||||
constexpr short shape_b_rows = (transpose_b ? BN : BK);
|
|
||||||
constexpr short tgp_mem_size_a = shape_a_cols * shape_a_rows;
|
|
||||||
constexpr short tgp_mem_size_b = shape_b_cols * shape_b_rows;
|
|
||||||
|
|
||||||
constexpr short tgp_size = WM * WN * 32;
|
|
||||||
|
|
||||||
// Input loader
|
|
||||||
using loader_a_t =
|
|
||||||
Conv2DInputBlockLoaderGeneral<T, BM, BN, BK, tgp_size, tgp_padding_a>;
|
|
||||||
|
|
||||||
// Weight loader
|
|
||||||
using loader_b_t =
|
|
||||||
Conv2DWeightBlockLoaderGeneral<T, BM, BN, BK, tgp_size, tgp_padding_b>;
|
|
||||||
|
|
||||||
using mma_t = BlockMMA<
|
|
||||||
T,
|
|
||||||
T,
|
|
||||||
BM,
|
|
||||||
BN,
|
|
||||||
BK,
|
|
||||||
WM,
|
|
||||||
WN,
|
|
||||||
transpose_a,
|
|
||||||
transpose_b,
|
|
||||||
shape_a_cols,
|
|
||||||
shape_b_cols>;
|
|
||||||
|
|
||||||
threadgroup T As[tgp_mem_size_a];
|
|
||||||
threadgroup T Bs[tgp_mem_size_b];
|
|
||||||
|
|
||||||
const int tid_y = ((tid.y) << gemm_params->swizzle_log) +
|
|
||||||
((tid.x) & ((1 << gemm_params->swizzle_log) - 1));
|
|
||||||
const int tid_x = (tid.x) >> gemm_params->swizzle_log;
|
|
||||||
|
|
||||||
if (gemm_params->tiles_n <= tid_x || gemm_params->tiles_m <= tid_y) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const int tid_z = tid.z;
|
|
||||||
|
|
||||||
const int base_oh = tid_z / jump_params->f_out_jump_w;
|
|
||||||
const int base_ow = tid_z % jump_params->f_out_jump_w;
|
|
||||||
|
|
||||||
const int base_wh = base_h[base_oh].weight_base;
|
|
||||||
const int base_ww = base_w[base_ow].weight_base;
|
|
||||||
|
|
||||||
const int base_wh_size = base_h[base_oh].weight_size;
|
|
||||||
const int base_ww_size = base_w[base_ow].weight_size;
|
|
||||||
|
|
||||||
const int c_row = tid_y * BM;
|
|
||||||
const int c_col = tid_x * BN;
|
|
||||||
const int K = gemm_params->K;
|
|
||||||
|
|
||||||
B += c_col * K;
|
|
||||||
|
|
||||||
const int4 offsets_a(0, c_row, base_oh, base_ow);
|
|
||||||
const int2 offsets_b(0, c_col);
|
|
||||||
|
|
||||||
// Prepare threadgroup loading operations
|
|
||||||
loader_a_t loader_a(
|
|
||||||
A,
|
|
||||||
As,
|
|
||||||
offsets_a,
|
|
||||||
params,
|
|
||||||
jump_params,
|
|
||||||
base_wh,
|
|
||||||
base_ww,
|
|
||||||
simd_gid,
|
|
||||||
simd_lid);
|
|
||||||
loader_b_t loader_b(
|
|
||||||
B,
|
|
||||||
Bs,
|
|
||||||
offsets_b,
|
|
||||||
params,
|
|
||||||
jump_params,
|
|
||||||
base_wh,
|
|
||||||
base_ww,
|
|
||||||
simd_gid,
|
|
||||||
simd_lid);
|
|
||||||
|
|
||||||
// Prepare threadgroup mma operation
|
|
||||||
mma_t mma_op(simd_gid, simd_lid);
|
|
||||||
|
|
||||||
int gemm_k_iterations =
|
|
||||||
base_wh_size * base_ww_size * gemm_params->gemm_k_iterations;
|
|
||||||
|
|
||||||
for (int k = 0; k < gemm_k_iterations; k++) {
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
// Load elements into threadgroup
|
|
||||||
loader_a.load_unsafe();
|
|
||||||
loader_b.load_unsafe();
|
|
||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
|
|
||||||
// Multiply and accumulate threadgroup elements
|
|
||||||
mma_op.mma(As, Bs);
|
|
||||||
|
|
||||||
// Prepare for next iteration
|
|
||||||
loader_a.next();
|
|
||||||
loader_b.next();
|
|
||||||
}
|
|
||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_none);
|
|
||||||
|
|
||||||
// Store results to device memory
|
|
||||||
{
|
|
||||||
// Adjust for simdgroup and thread locatio
|
|
||||||
int offset_m = c_row + mma_op.sm + mma_op.tm;
|
|
||||||
int offset_n = c_col + mma_op.sn + mma_op.tn;
|
|
||||||
C += offset_n;
|
|
||||||
|
|
||||||
if (offset_n >= gemm_params->N)
|
|
||||||
return;
|
|
||||||
|
|
||||||
short diff = gemm_params->N - offset_n;
|
|
||||||
|
|
||||||
STEEL_PRAGMA_UNROLL
|
|
||||||
for (int i = 0; i < mma_t::TM; i++) {
|
|
||||||
int cm = offset_m + i * mma_t::TM_stride;
|
|
||||||
|
|
||||||
int n = cm / jump_params->adj_out_hw;
|
|
||||||
int hw = cm % jump_params->adj_out_hw;
|
|
||||||
int oh =
|
|
||||||
(hw / jump_params->adj_out_w) * jump_params->f_out_jump_h + base_oh;
|
|
||||||
int ow =
|
|
||||||
(hw % jump_params->adj_out_w) * jump_params->f_out_jump_w + base_ow;
|
|
||||||
|
|
||||||
if (n < params->N && oh < params->oS[0] && ow < params->oS[1]) {
|
|
||||||
int offset_cm = n * params->out_strides[0] +
|
|
||||||
oh * params->out_strides[1] + ow * params->out_strides[2];
|
|
||||||
|
|
||||||
STEEL_PRAGMA_UNROLL
|
|
||||||
for (int j = 0; j < mma_t::TN; j++) {
|
|
||||||
// Get accumulated result and associated offset in C
|
|
||||||
thread const auto& accum =
|
|
||||||
mma_op.results[i * mma_t::TN + j].thread_elements();
|
|
||||||
int offset = offset_cm + (j * mma_t::TN_stride);
|
|
||||||
|
|
||||||
// Apply epilogue and output C
|
|
||||||
if (j * mma_t::TN_stride < diff) {
|
|
||||||
C[offset] = Epilogue::apply(accum[0]);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (j * mma_t::TN_stride + 1 < diff) {
|
|
||||||
C[offset + 1] = Epilogue::apply(accum[1]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#define instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn) \
|
#define instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn) \
|
||||||
template \
|
template \
|
||||||
[[host_name("implicit_gemm_conv_2d_general_" #name "_bm" #bm "_bn" #bn \
|
[[host_name("implicit_gemm_conv_2d_general_" #name "_bm" #bm "_bn" #bn \
|
||||||
@ -218,16 +35,14 @@ implicit_gemm_conv_2d_general(
|
|||||||
#define instantiate_implicit_2d_filter(name, itype, bm, bn, bk, wm, wn) \
|
#define instantiate_implicit_2d_filter(name, itype, bm, bn, bk, wm, wn) \
|
||||||
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn)
|
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn)
|
||||||
|
|
||||||
// clang-format off
|
|
||||||
#define instantiate_implicit_2d_blocks(name, itype) \
|
#define instantiate_implicit_2d_blocks(name, itype) \
|
||||||
instantiate_implicit_2d_filter(name, itype, 32, 8, 16, 4, 1) \
|
instantiate_implicit_2d_filter(name, itype, 32, 8, 16, 4, 1) \
|
||||||
instantiate_implicit_2d_filter(name, itype, 64, 8, 16, 4, 1) \
|
instantiate_implicit_2d_filter(name, itype, 64, 8, 16, 4, 1) \
|
||||||
instantiate_implicit_2d_filter(name, itype, 32, 32, 16, 2, 2) \
|
instantiate_implicit_2d_filter(name, itype, 32, 32, 16, 2, 2) \
|
||||||
instantiate_implicit_2d_filter(name, itype, 32, 64, 16, 2, 2) \
|
instantiate_implicit_2d_filter(name, itype, 32, 64, 16, 2, 2) \
|
||||||
instantiate_implicit_2d_filter(name, itype, 64, 32, 16, 2, 2) \
|
instantiate_implicit_2d_filter(name, itype, 64, 32, 16, 2, 2) \
|
||||||
instantiate_implicit_2d_filter(name, itype, 64, 64, 16, 2, 2) // clang-format on
|
instantiate_implicit_2d_filter(name, itype, 64, 64, 16, 2, 2)
|
||||||
|
|
||||||
// clang-format off
|
|
||||||
instantiate_implicit_2d_blocks(float32, float);
|
instantiate_implicit_2d_blocks(float32, float);
|
||||||
instantiate_implicit_2d_blocks(float16, half);
|
instantiate_implicit_2d_blocks(float16, half);
|
||||||
instantiate_implicit_2d_blocks(bfloat16, bfloat16_t); // clang-format on
|
instantiate_implicit_2d_blocks(bfloat16, bfloat16_t); // clang-format on
|
||||||
|
@ -2,9 +2,7 @@
|
|||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "mlx/backend/metal/kernels/steel/utils.h"
|
#include "mlx/backend/metal/kernels/steel/defines.h"
|
||||||
|
|
||||||
#include "mlx/backend/metal/kernels/steel/conv/params.h"
|
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
// Loading helper
|
// Loading helper
|
||||||
@ -285,4 +283,4 @@ struct Conv2DWeightBlockLoaderGeneral {
|
|||||||
};
|
};
|
||||||
|
|
||||||
} // namespace steel
|
} // namespace steel
|
||||||
} // namespace mlx
|
} // namespace mlx
|
||||||
|
4
mlx/backend/metal/kernels/steel/defines.h
Normal file
4
mlx/backend/metal/kernels/steel/defines.h
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
#define STEEL_CONST static constant constexpr const
|
||||||
|
#define STEEL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)")
|
415
mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h
Normal file
415
mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h
Normal file
@ -0,0 +1,415 @@
|
|||||||
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
using namespace mlx::steel;
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// GEMM kernels
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
constant bool has_batch [[function_constant(10)]];
|
||||||
|
|
||||||
|
constant bool use_out_source [[function_constant(100)]];
|
||||||
|
constant bool do_axpby [[function_constant(110)]];
|
||||||
|
|
||||||
|
constant bool align_M [[function_constant(200)]];
|
||||||
|
constant bool align_N [[function_constant(201)]];
|
||||||
|
constant bool align_K [[function_constant(202)]];
|
||||||
|
|
||||||
|
constant bool do_gather [[function_constant(300)]];
|
||||||
|
|
||||||
|
constant bool gather_bias = do_gather && use_out_source;
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
|
template <
|
||||||
|
typename T,
|
||||||
|
int BM,
|
||||||
|
int BN,
|
||||||
|
int BK,
|
||||||
|
int WM,
|
||||||
|
int WN,
|
||||||
|
bool transpose_a,
|
||||||
|
bool transpose_b,
|
||||||
|
typename AccumType = float>
|
||||||
|
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gemm(
|
||||||
|
const device T* A [[buffer(0)]],
|
||||||
|
const device T* B [[buffer(1)]],
|
||||||
|
const device T* C [[buffer(2), function_constant(use_out_source)]],
|
||||||
|
device T* D [[buffer(3)]],
|
||||||
|
const constant GEMMParams* params [[buffer(4)]],
|
||||||
|
const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]],
|
||||||
|
const constant int* batch_shape [[buffer(6)]],
|
||||||
|
const constant size_t* batch_strides [[buffer(7)]],
|
||||||
|
const constant uint32_t* lhs_indices [[buffer(10), function_constant(do_gather)]],
|
||||||
|
const constant uint32_t* rhs_indices [[buffer(11), function_constant(do_gather)]],
|
||||||
|
const constant uint32_t* C_indices [[buffer(12), function_constant(gather_bias)]],
|
||||||
|
const constant int* operand_shape [[buffer(13), function_constant(do_gather)]],
|
||||||
|
const constant size_t* operand_strides [[buffer(14), function_constant(do_gather)]],
|
||||||
|
const constant packed_int3& operand_batch_ndim [[buffer(15), function_constant(do_gather)]],
|
||||||
|
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||||
|
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||||
|
uint3 tid [[threadgroup_position_in_grid]],
|
||||||
|
uint3 lid [[thread_position_in_threadgroup]]) { // clang-format on
|
||||||
|
// Pacifying compiler
|
||||||
|
(void)lid;
|
||||||
|
|
||||||
|
using gemm_kernel = GEMMKernel<
|
||||||
|
T,
|
||||||
|
T,
|
||||||
|
BM,
|
||||||
|
BN,
|
||||||
|
BK,
|
||||||
|
WM,
|
||||||
|
WN,
|
||||||
|
transpose_a,
|
||||||
|
transpose_b,
|
||||||
|
true,
|
||||||
|
true,
|
||||||
|
AccumType>;
|
||||||
|
|
||||||
|
using loader_a_t = typename gemm_kernel::loader_a_t;
|
||||||
|
using loader_b_t = typename gemm_kernel::loader_b_t;
|
||||||
|
using mma_t = typename gemm_kernel::mma_t;
|
||||||
|
|
||||||
|
// Find block
|
||||||
|
const int tid_y = ((tid.y) << params->swizzle_log) +
|
||||||
|
((tid.x) & ((1 << params->swizzle_log) - 1));
|
||||||
|
const int tid_x = (tid.x) >> params->swizzle_log;
|
||||||
|
|
||||||
|
// Exit early if out of bounds
|
||||||
|
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Adjust for batch
|
||||||
|
|
||||||
|
// Handle gather
|
||||||
|
if (do_gather) {
|
||||||
|
// Read indices
|
||||||
|
uint32_t indx_A, indx_B, indx_C;
|
||||||
|
|
||||||
|
if (has_batch) {
|
||||||
|
const constant size_t* indx_A_bstrides = batch_strides;
|
||||||
|
const constant size_t* indx_B_bstrides =
|
||||||
|
batch_strides + params->batch_ndim;
|
||||||
|
|
||||||
|
ulong2 indx_offsets = elem_to_loc_broadcast(
|
||||||
|
tid.z,
|
||||||
|
batch_shape,
|
||||||
|
indx_A_bstrides,
|
||||||
|
indx_B_bstrides,
|
||||||
|
params->batch_ndim);
|
||||||
|
indx_A = lhs_indices[indx_offsets.x];
|
||||||
|
indx_B = rhs_indices[indx_offsets.y];
|
||||||
|
|
||||||
|
if (use_out_source) {
|
||||||
|
const constant size_t* indx_C_bstrides =
|
||||||
|
indx_B_bstrides + params->batch_ndim;
|
||||||
|
auto indx_offset_C = elem_to_loc(
|
||||||
|
tid.z, batch_shape, indx_C_bstrides, params->batch_ndim);
|
||||||
|
indx_C = C_indices[indx_offset_C];
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
indx_A = lhs_indices[params->batch_stride_a * tid.z];
|
||||||
|
indx_B = rhs_indices[params->batch_stride_b * tid.z];
|
||||||
|
|
||||||
|
if (use_out_source) {
|
||||||
|
indx_C = C_indices[addmm_params->batch_stride_c * tid.z];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Translate indices to offsets
|
||||||
|
int batch_ndim_A = operand_batch_ndim.x;
|
||||||
|
const constant int* batch_shape_A = operand_shape;
|
||||||
|
const constant size_t* batch_strides_A = operand_strides;
|
||||||
|
A += elem_to_loc(indx_A, batch_shape_A, batch_strides_A, batch_ndim_A);
|
||||||
|
|
||||||
|
int batch_ndim_B = operand_batch_ndim.y;
|
||||||
|
const constant int* batch_shape_B = batch_shape_A + batch_ndim_A;
|
||||||
|
const constant size_t* batch_strides_B = batch_strides_A + batch_ndim_A;
|
||||||
|
B += elem_to_loc(indx_B, batch_shape_B, batch_strides_B, batch_ndim_B);
|
||||||
|
|
||||||
|
if (use_out_source) {
|
||||||
|
int batch_ndim_C = operand_batch_ndim.z;
|
||||||
|
const constant int* batch_shape_C = batch_shape_B + batch_ndim_B;
|
||||||
|
const constant size_t* batch_strides_C = batch_strides_B + batch_ndim_B;
|
||||||
|
C += elem_to_loc(indx_C, batch_shape_C, batch_strides_C, batch_ndim_C);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle regular batch
|
||||||
|
else {
|
||||||
|
if (has_batch) {
|
||||||
|
const constant size_t* A_bstrides = batch_strides;
|
||||||
|
const constant size_t* B_bstrides = batch_strides + params->batch_ndim;
|
||||||
|
|
||||||
|
ulong2 batch_offsets = elem_to_loc_broadcast(
|
||||||
|
tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);
|
||||||
|
|
||||||
|
A += batch_offsets.x;
|
||||||
|
B += batch_offsets.y;
|
||||||
|
|
||||||
|
if (use_out_source) {
|
||||||
|
const constant size_t* C_bstrides = B_bstrides + params->batch_ndim;
|
||||||
|
C += elem_to_loc(tid.z, batch_shape, C_bstrides, params->batch_ndim);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
A += params->batch_stride_a * tid.z;
|
||||||
|
B += params->batch_stride_b * tid.z;
|
||||||
|
|
||||||
|
if (use_out_source) {
|
||||||
|
C += addmm_params->batch_stride_c * tid.z;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
D += params->batch_stride_d * tid.z;
|
||||||
|
|
||||||
|
// Prepare threadgroup memory
|
||||||
|
threadgroup T As[gemm_kernel::tgp_mem_size_a];
|
||||||
|
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_none);
|
||||||
|
|
||||||
|
// Find block in A, B, C
|
||||||
|
const int c_row = tid_y * BM;
|
||||||
|
const int c_col = tid_x * BN;
|
||||||
|
const size_t c_row_long = size_t(c_row);
|
||||||
|
const size_t c_col_long = size_t(c_col);
|
||||||
|
|
||||||
|
A += transpose_a ? c_row_long : c_row_long * params->lda;
|
||||||
|
B += transpose_b ? c_col_long * params->ldb : c_col_long;
|
||||||
|
D += c_row_long * params->ldd + c_col_long;
|
||||||
|
|
||||||
|
if (use_out_source) {
|
||||||
|
C += c_row_long * addmm_params->ldc + c_col_long * addmm_params->fdc;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prepare threadgroup mma operation
|
||||||
|
thread mma_t mma_op(simd_group_id, simd_lane_id);
|
||||||
|
|
||||||
|
// Prepare threadgroup loading operations
|
||||||
|
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
|
||||||
|
thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
|
||||||
|
|
||||||
|
// Prepare threadgroup bounds
|
||||||
|
const short tgp_bm = align_M ? BM : short(min(BM, params->M - c_row));
|
||||||
|
const short tgp_bn = align_N ? BN : short(min(BN, params->N - c_col));
|
||||||
|
|
||||||
|
// Prepare iterations
|
||||||
|
int gemm_k_iterations = params->gemm_k_iterations_aligned;
|
||||||
|
|
||||||
|
// Do unaligned K iterations first
|
||||||
|
if (!align_K) {
|
||||||
|
const int k_last = params->gemm_k_iterations_aligned * BK;
|
||||||
|
const int k_remain = params->K - k_last;
|
||||||
|
const size_t k_jump_a =
|
||||||
|
transpose_a ? params->lda * size_t(k_last) : size_t(k_last);
|
||||||
|
const size_t k_jump_b =
|
||||||
|
transpose_b ? size_t(k_last) : params->ldb * size_t(k_last);
|
||||||
|
|
||||||
|
// Move loader source ahead to end
|
||||||
|
loader_a.src += k_jump_a;
|
||||||
|
loader_b.src += k_jump_b;
|
||||||
|
|
||||||
|
// Load tile
|
||||||
|
const short2 tile_dims_A =
|
||||||
|
transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
|
||||||
|
const short2 tile_dims_B =
|
||||||
|
transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
|
||||||
|
|
||||||
|
loader_a.load_safe(tile_dims_A);
|
||||||
|
loader_b.load_safe(tile_dims_B);
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
// Do matmul
|
||||||
|
mma_op.mma(As, Bs);
|
||||||
|
|
||||||
|
// Reset source back to start
|
||||||
|
loader_a.src -= k_jump_a;
|
||||||
|
loader_b.src -= k_jump_b;
|
||||||
|
}
|
||||||
|
|
||||||
|
const TransformAdd<AccumType, AccumType> epilogue_op_add(
|
||||||
|
addmm_params->alpha, addmm_params->beta);
|
||||||
|
const TransformAxpby<AccumType, AccumType> epilogue_op_axpby(
|
||||||
|
addmm_params->alpha, addmm_params->beta);
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// MNK aligned loop
|
||||||
|
if (align_M && align_N) {
|
||||||
|
// Do gemm
|
||||||
|
for (int k = 0; k < gemm_k_iterations; k++) {
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
// Load elements into threadgroup
|
||||||
|
loader_a.load_unsafe();
|
||||||
|
loader_b.load_unsafe();
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
// Multiply and accumulate threadgroup elements
|
||||||
|
mma_op.mma(As, Bs);
|
||||||
|
|
||||||
|
// Prepare for next iteration
|
||||||
|
loader_a.next();
|
||||||
|
loader_b.next();
|
||||||
|
}
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_none);
|
||||||
|
|
||||||
|
// Do epilogue
|
||||||
|
if (use_out_source) {
|
||||||
|
if (do_axpby) {
|
||||||
|
mma_op.apply_epilogue(
|
||||||
|
C, addmm_params->ldc, addmm_params->fdc, epilogue_op_axpby);
|
||||||
|
} else {
|
||||||
|
mma_op.apply_epilogue(
|
||||||
|
C, addmm_params->ldc, addmm_params->fdc, epilogue_op_add);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store results to device memory
|
||||||
|
return mma_op.store_result(D, params->ldd);
|
||||||
|
|
||||||
|
}
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// MN unaligned loop
|
||||||
|
else { // Loop over K - unaligned case
|
||||||
|
const int leftover_bk = 0;
|
||||||
|
|
||||||
|
if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) {
|
||||||
|
// Do gemm
|
||||||
|
gemm_kernel::gemm_loop(
|
||||||
|
As,
|
||||||
|
Bs,
|
||||||
|
gemm_k_iterations,
|
||||||
|
loader_a,
|
||||||
|
loader_b,
|
||||||
|
mma_op,
|
||||||
|
tgp_bm,
|
||||||
|
tgp_bn,
|
||||||
|
leftover_bk,
|
||||||
|
LoopAlignment<true, true, true>{});
|
||||||
|
|
||||||
|
// Do epilogue
|
||||||
|
if (use_out_source) {
|
||||||
|
if (do_axpby) {
|
||||||
|
mma_op.apply_epilogue(
|
||||||
|
C, addmm_params->ldc, addmm_params->fdc, epilogue_op_axpby);
|
||||||
|
} else {
|
||||||
|
mma_op.apply_epilogue(
|
||||||
|
C, addmm_params->ldc, addmm_params->fdc, epilogue_op_add);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store results to device memory
|
||||||
|
return mma_op.store_result(D, params->ldd);
|
||||||
|
|
||||||
|
} else if (align_N || tgp_bn == BN) {
|
||||||
|
gemm_kernel::gemm_loop(
|
||||||
|
As,
|
||||||
|
Bs,
|
||||||
|
gemm_k_iterations,
|
||||||
|
loader_a,
|
||||||
|
loader_b,
|
||||||
|
mma_op,
|
||||||
|
tgp_bm,
|
||||||
|
tgp_bn,
|
||||||
|
leftover_bk,
|
||||||
|
LoopAlignment<false, true, true>{});
|
||||||
|
|
||||||
|
// Do epilogue
|
||||||
|
if (use_out_source) {
|
||||||
|
if (do_axpby) {
|
||||||
|
mma_op.apply_epilogue_safe(
|
||||||
|
C,
|
||||||
|
addmm_params->ldc,
|
||||||
|
addmm_params->fdc,
|
||||||
|
short2(tgp_bn, tgp_bm),
|
||||||
|
epilogue_op_axpby);
|
||||||
|
} else {
|
||||||
|
mma_op.apply_epilogue_safe(
|
||||||
|
C,
|
||||||
|
addmm_params->ldc,
|
||||||
|
addmm_params->fdc,
|
||||||
|
short2(tgp_bn, tgp_bm),
|
||||||
|
epilogue_op_add);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store results to device memory
|
||||||
|
return mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
|
||||||
|
|
||||||
|
} else if (align_M || tgp_bm == BM) {
|
||||||
|
gemm_kernel::gemm_loop(
|
||||||
|
As,
|
||||||
|
Bs,
|
||||||
|
gemm_k_iterations,
|
||||||
|
loader_a,
|
||||||
|
loader_b,
|
||||||
|
mma_op,
|
||||||
|
tgp_bm,
|
||||||
|
tgp_bn,
|
||||||
|
leftover_bk,
|
||||||
|
LoopAlignment<true, false, true>{});
|
||||||
|
|
||||||
|
// Do epilogue
|
||||||
|
if (use_out_source) {
|
||||||
|
if (do_axpby) {
|
||||||
|
mma_op.apply_epilogue_safe(
|
||||||
|
C,
|
||||||
|
addmm_params->ldc,
|
||||||
|
addmm_params->fdc,
|
||||||
|
short2(tgp_bn, tgp_bm),
|
||||||
|
epilogue_op_axpby);
|
||||||
|
} else {
|
||||||
|
mma_op.apply_epilogue_safe(
|
||||||
|
C,
|
||||||
|
addmm_params->ldc,
|
||||||
|
addmm_params->fdc,
|
||||||
|
short2(tgp_bn, tgp_bm),
|
||||||
|
epilogue_op_add);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store results to device memory
|
||||||
|
return mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
|
||||||
|
|
||||||
|
} else {
|
||||||
|
gemm_kernel::gemm_loop(
|
||||||
|
As,
|
||||||
|
Bs,
|
||||||
|
gemm_k_iterations,
|
||||||
|
loader_a,
|
||||||
|
loader_b,
|
||||||
|
mma_op,
|
||||||
|
tgp_bm,
|
||||||
|
tgp_bn,
|
||||||
|
leftover_bk,
|
||||||
|
LoopAlignment<false, false, true>{});
|
||||||
|
|
||||||
|
// Do epilogue
|
||||||
|
if (use_out_source) {
|
||||||
|
if (do_axpby) {
|
||||||
|
mma_op.apply_epilogue_safe(
|
||||||
|
C,
|
||||||
|
addmm_params->ldc,
|
||||||
|
addmm_params->fdc,
|
||||||
|
short2(tgp_bn, tgp_bm),
|
||||||
|
epilogue_op_axpby);
|
||||||
|
} else {
|
||||||
|
mma_op.apply_epilogue_safe(
|
||||||
|
C,
|
||||||
|
addmm_params->ldc,
|
||||||
|
addmm_params->fdc,
|
||||||
|
short2(tgp_bn, tgp_bm),
|
||||||
|
epilogue_op_add);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store results to device memory
|
||||||
|
return mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -1,430 +1,12 @@
|
|||||||
// Copyright © 2024 Apple Inc.
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
#include "mlx/backend/metal/kernels/bf16.h"
|
#include "mlx/backend/metal/kernels/bf16.h"
|
||||||
#include "mlx/backend/metal/kernels/utils.h"
|
#include "mlx/backend/metal/kernels/utils.h"
|
||||||
|
|
||||||
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
|
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
|
||||||
|
#include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h"
|
||||||
|
|
||||||
using namespace metal;
|
|
||||||
using namespace mlx::steel;
|
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
|
||||||
// GEMM kernels
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
constant bool has_batch [[function_constant(10)]];
|
|
||||||
|
|
||||||
constant bool use_out_source [[function_constant(100)]];
|
|
||||||
constant bool do_axpby [[function_constant(110)]];
|
|
||||||
|
|
||||||
constant bool align_M [[function_constant(200)]];
|
|
||||||
constant bool align_N [[function_constant(201)]];
|
|
||||||
constant bool align_K [[function_constant(202)]];
|
|
||||||
|
|
||||||
constant bool do_gather [[function_constant(300)]];
|
|
||||||
|
|
||||||
constant bool gather_bias = do_gather && use_out_source;
|
|
||||||
|
|
||||||
// clang-format off
|
|
||||||
template <
|
|
||||||
typename T,
|
|
||||||
int BM,
|
|
||||||
int BN,
|
|
||||||
int BK,
|
|
||||||
int WM,
|
|
||||||
int WN,
|
|
||||||
bool transpose_a,
|
|
||||||
bool transpose_b,
|
|
||||||
typename AccumType = float>
|
|
||||||
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gemm(
|
|
||||||
const device T* A [[buffer(0)]],
|
|
||||||
const device T* B [[buffer(1)]],
|
|
||||||
const device T* C [[buffer(2), function_constant(use_out_source)]],
|
|
||||||
device T* D [[buffer(3)]],
|
|
||||||
const constant GEMMParams* params [[buffer(4)]],
|
|
||||||
const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]],
|
|
||||||
const constant int* batch_shape [[buffer(6)]],
|
|
||||||
const constant size_t* batch_strides [[buffer(7)]],
|
|
||||||
const constant uint32_t* lhs_indices [[buffer(10), function_constant(do_gather)]],
|
|
||||||
const constant uint32_t* rhs_indices [[buffer(11), function_constant(do_gather)]],
|
|
||||||
const constant uint32_t* C_indices [[buffer(12), function_constant(gather_bias)]],
|
|
||||||
const constant int* operand_shape [[buffer(13), function_constant(do_gather)]],
|
|
||||||
const constant size_t* operand_strides [[buffer(14), function_constant(do_gather)]],
|
|
||||||
const constant packed_int3& operand_batch_ndim [[buffer(15), function_constant(do_gather)]],
|
|
||||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
|
||||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
|
||||||
uint3 tid [[threadgroup_position_in_grid]],
|
|
||||||
uint3 lid [[thread_position_in_threadgroup]]) { // clang-format on
|
|
||||||
// Pacifying compiler
|
|
||||||
(void)lid;
|
|
||||||
|
|
||||||
using gemm_kernel = GEMMKernel<
|
|
||||||
T,
|
|
||||||
T,
|
|
||||||
BM,
|
|
||||||
BN,
|
|
||||||
BK,
|
|
||||||
WM,
|
|
||||||
WN,
|
|
||||||
transpose_a,
|
|
||||||
transpose_b,
|
|
||||||
true,
|
|
||||||
true,
|
|
||||||
AccumType>;
|
|
||||||
|
|
||||||
using loader_a_t = typename gemm_kernel::loader_a_t;
|
|
||||||
using loader_b_t = typename gemm_kernel::loader_b_t;
|
|
||||||
using mma_t = typename gemm_kernel::mma_t;
|
|
||||||
|
|
||||||
// Find block
|
|
||||||
const int tid_y = ((tid.y) << params->swizzle_log) +
|
|
||||||
((tid.x) & ((1 << params->swizzle_log) - 1));
|
|
||||||
const int tid_x = (tid.x) >> params->swizzle_log;
|
|
||||||
|
|
||||||
// Exit early if out of bounds
|
|
||||||
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Adjust for batch
|
|
||||||
|
|
||||||
// Handle gather
|
|
||||||
if (do_gather) {
|
|
||||||
// Read indices
|
|
||||||
uint32_t indx_A, indx_B, indx_C;
|
|
||||||
|
|
||||||
if (has_batch) {
|
|
||||||
const constant size_t* indx_A_bstrides = batch_strides;
|
|
||||||
const constant size_t* indx_B_bstrides =
|
|
||||||
batch_strides + params->batch_ndim;
|
|
||||||
|
|
||||||
ulong2 indx_offsets = elem_to_loc_broadcast(
|
|
||||||
tid.z,
|
|
||||||
batch_shape,
|
|
||||||
indx_A_bstrides,
|
|
||||||
indx_B_bstrides,
|
|
||||||
params->batch_ndim);
|
|
||||||
indx_A = lhs_indices[indx_offsets.x];
|
|
||||||
indx_B = rhs_indices[indx_offsets.y];
|
|
||||||
|
|
||||||
if (use_out_source) {
|
|
||||||
const constant size_t* indx_C_bstrides =
|
|
||||||
indx_B_bstrides + params->batch_ndim;
|
|
||||||
auto indx_offset_C = elem_to_loc(
|
|
||||||
tid.z, batch_shape, indx_C_bstrides, params->batch_ndim);
|
|
||||||
indx_C = C_indices[indx_offset_C];
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
indx_A = lhs_indices[params->batch_stride_a * tid.z];
|
|
||||||
indx_B = rhs_indices[params->batch_stride_b * tid.z];
|
|
||||||
|
|
||||||
if (use_out_source) {
|
|
||||||
indx_C = C_indices[addmm_params->batch_stride_c * tid.z];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Translate indices to offsets
|
|
||||||
int batch_ndim_A = operand_batch_ndim.x;
|
|
||||||
const constant int* batch_shape_A = operand_shape;
|
|
||||||
const constant size_t* batch_strides_A = operand_strides;
|
|
||||||
A += elem_to_loc(indx_A, batch_shape_A, batch_strides_A, batch_ndim_A);
|
|
||||||
|
|
||||||
int batch_ndim_B = operand_batch_ndim.y;
|
|
||||||
const constant int* batch_shape_B = batch_shape_A + batch_ndim_A;
|
|
||||||
const constant size_t* batch_strides_B = batch_strides_A + batch_ndim_A;
|
|
||||||
B += elem_to_loc(indx_B, batch_shape_B, batch_strides_B, batch_ndim_B);
|
|
||||||
|
|
||||||
if (use_out_source) {
|
|
||||||
int batch_ndim_C = operand_batch_ndim.z;
|
|
||||||
const constant int* batch_shape_C = batch_shape_B + batch_ndim_B;
|
|
||||||
const constant size_t* batch_strides_C = batch_strides_B + batch_ndim_B;
|
|
||||||
C += elem_to_loc(indx_C, batch_shape_C, batch_strides_C, batch_ndim_C);
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle regular batch
|
|
||||||
else {
|
|
||||||
if (has_batch) {
|
|
||||||
const constant size_t* A_bstrides = batch_strides;
|
|
||||||
const constant size_t* B_bstrides = batch_strides + params->batch_ndim;
|
|
||||||
|
|
||||||
ulong2 batch_offsets = elem_to_loc_broadcast(
|
|
||||||
tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);
|
|
||||||
|
|
||||||
A += batch_offsets.x;
|
|
||||||
B += batch_offsets.y;
|
|
||||||
|
|
||||||
if (use_out_source) {
|
|
||||||
const constant size_t* C_bstrides = B_bstrides + params->batch_ndim;
|
|
||||||
C += elem_to_loc(tid.z, batch_shape, C_bstrides, params->batch_ndim);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
A += params->batch_stride_a * tid.z;
|
|
||||||
B += params->batch_stride_b * tid.z;
|
|
||||||
|
|
||||||
if (use_out_source) {
|
|
||||||
C += addmm_params->batch_stride_c * tid.z;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
D += params->batch_stride_d * tid.z;
|
|
||||||
|
|
||||||
// Prepare threadgroup memory
|
|
||||||
threadgroup T As[gemm_kernel::tgp_mem_size_a];
|
|
||||||
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
|
|
||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_none);
|
|
||||||
|
|
||||||
// Find block in A, B, C
|
|
||||||
const int c_row = tid_y * BM;
|
|
||||||
const int c_col = tid_x * BN;
|
|
||||||
const size_t c_row_long = size_t(c_row);
|
|
||||||
const size_t c_col_long = size_t(c_col);
|
|
||||||
|
|
||||||
A += transpose_a ? c_row_long : c_row_long * params->lda;
|
|
||||||
B += transpose_b ? c_col_long * params->ldb : c_col_long;
|
|
||||||
D += c_row_long * params->ldd + c_col_long;
|
|
||||||
|
|
||||||
if (use_out_source) {
|
|
||||||
C += c_row_long * addmm_params->ldc + c_col_long * addmm_params->fdc;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Prepare threadgroup mma operation
|
|
||||||
thread mma_t mma_op(simd_group_id, simd_lane_id);
|
|
||||||
|
|
||||||
// Prepare threadgroup loading operations
|
|
||||||
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
|
|
||||||
thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
|
|
||||||
|
|
||||||
// Prepare threadgroup bounds
|
|
||||||
const short tgp_bm = align_M ? BM : short(min(BM, params->M - c_row));
|
|
||||||
const short tgp_bn = align_N ? BN : short(min(BN, params->N - c_col));
|
|
||||||
|
|
||||||
// Prepare iterations
|
|
||||||
int gemm_k_iterations = params->gemm_k_iterations_aligned;
|
|
||||||
|
|
||||||
// Do unaligned K iterations first
|
|
||||||
if (!align_K) {
|
|
||||||
const int k_last = params->gemm_k_iterations_aligned * BK;
|
|
||||||
const int k_remain = params->K - k_last;
|
|
||||||
const size_t k_jump_a =
|
|
||||||
transpose_a ? params->lda * size_t(k_last) : size_t(k_last);
|
|
||||||
const size_t k_jump_b =
|
|
||||||
transpose_b ? size_t(k_last) : params->ldb * size_t(k_last);
|
|
||||||
|
|
||||||
// Move loader source ahead to end
|
|
||||||
loader_a.src += k_jump_a;
|
|
||||||
loader_b.src += k_jump_b;
|
|
||||||
|
|
||||||
// Load tile
|
|
||||||
const short2 tile_dims_A =
|
|
||||||
transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
|
|
||||||
const short2 tile_dims_B =
|
|
||||||
transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
|
|
||||||
|
|
||||||
loader_a.load_safe(tile_dims_A);
|
|
||||||
loader_b.load_safe(tile_dims_B);
|
|
||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
|
|
||||||
// Do matmul
|
|
||||||
mma_op.mma(As, Bs);
|
|
||||||
|
|
||||||
// Reset source back to start
|
|
||||||
loader_a.src -= k_jump_a;
|
|
||||||
loader_b.src -= k_jump_b;
|
|
||||||
}
|
|
||||||
|
|
||||||
const TransformAdd<AccumType, AccumType> epilogue_op_add(
|
|
||||||
addmm_params->alpha, addmm_params->beta);
|
|
||||||
const TransformAxpby<AccumType, AccumType> epilogue_op_axpby(
|
|
||||||
addmm_params->alpha, addmm_params->beta);
|
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
|
||||||
// MNK aligned loop
|
|
||||||
if (align_M && align_N) {
|
|
||||||
// Do gemm
|
|
||||||
for (int k = 0; k < gemm_k_iterations; k++) {
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
// Load elements into threadgroup
|
|
||||||
loader_a.load_unsafe();
|
|
||||||
loader_b.load_unsafe();
|
|
||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
|
|
||||||
// Multiply and accumulate threadgroup elements
|
|
||||||
mma_op.mma(As, Bs);
|
|
||||||
|
|
||||||
// Prepare for next iteration
|
|
||||||
loader_a.next();
|
|
||||||
loader_b.next();
|
|
||||||
}
|
|
||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_none);
|
|
||||||
|
|
||||||
// Do epilogue
|
|
||||||
if (use_out_source) {
|
|
||||||
if (do_axpby) {
|
|
||||||
mma_op.apply_epilogue(
|
|
||||||
C, addmm_params->ldc, addmm_params->fdc, epilogue_op_axpby);
|
|
||||||
} else {
|
|
||||||
mma_op.apply_epilogue(
|
|
||||||
C, addmm_params->ldc, addmm_params->fdc, epilogue_op_add);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Store results to device memory
|
|
||||||
return mma_op.store_result(D, params->ldd);
|
|
||||||
|
|
||||||
}
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
|
||||||
// MN unaligned loop
|
|
||||||
else { // Loop over K - unaligned case
|
|
||||||
const int leftover_bk = 0;
|
|
||||||
|
|
||||||
if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) {
|
|
||||||
// Do gemm
|
|
||||||
gemm_kernel::gemm_loop(
|
|
||||||
As,
|
|
||||||
Bs,
|
|
||||||
gemm_k_iterations,
|
|
||||||
loader_a,
|
|
||||||
loader_b,
|
|
||||||
mma_op,
|
|
||||||
tgp_bm,
|
|
||||||
tgp_bn,
|
|
||||||
leftover_bk,
|
|
||||||
LoopAlignment<true, true, true>{});
|
|
||||||
|
|
||||||
// Do epilogue
|
|
||||||
if (use_out_source) {
|
|
||||||
if (do_axpby) {
|
|
||||||
mma_op.apply_epilogue(
|
|
||||||
C, addmm_params->ldc, addmm_params->fdc, epilogue_op_axpby);
|
|
||||||
} else {
|
|
||||||
mma_op.apply_epilogue(
|
|
||||||
C, addmm_params->ldc, addmm_params->fdc, epilogue_op_add);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Store results to device memory
|
|
||||||
return mma_op.store_result(D, params->ldd);
|
|
||||||
|
|
||||||
} else if (align_N || tgp_bn == BN) {
|
|
||||||
gemm_kernel::gemm_loop(
|
|
||||||
As,
|
|
||||||
Bs,
|
|
||||||
gemm_k_iterations,
|
|
||||||
loader_a,
|
|
||||||
loader_b,
|
|
||||||
mma_op,
|
|
||||||
tgp_bm,
|
|
||||||
tgp_bn,
|
|
||||||
leftover_bk,
|
|
||||||
LoopAlignment<false, true, true>{});
|
|
||||||
|
|
||||||
// Do epilogue
|
|
||||||
if (use_out_source) {
|
|
||||||
if (do_axpby) {
|
|
||||||
mma_op.apply_epilogue_safe(
|
|
||||||
C,
|
|
||||||
addmm_params->ldc,
|
|
||||||
addmm_params->fdc,
|
|
||||||
short2(tgp_bn, tgp_bm),
|
|
||||||
epilogue_op_axpby);
|
|
||||||
} else {
|
|
||||||
mma_op.apply_epilogue_safe(
|
|
||||||
C,
|
|
||||||
addmm_params->ldc,
|
|
||||||
addmm_params->fdc,
|
|
||||||
short2(tgp_bn, tgp_bm),
|
|
||||||
epilogue_op_add);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Store results to device memory
|
|
||||||
return mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
|
|
||||||
|
|
||||||
} else if (align_M || tgp_bm == BM) {
|
|
||||||
gemm_kernel::gemm_loop(
|
|
||||||
As,
|
|
||||||
Bs,
|
|
||||||
gemm_k_iterations,
|
|
||||||
loader_a,
|
|
||||||
loader_b,
|
|
||||||
mma_op,
|
|
||||||
tgp_bm,
|
|
||||||
tgp_bn,
|
|
||||||
leftover_bk,
|
|
||||||
LoopAlignment<true, false, true>{});
|
|
||||||
|
|
||||||
// Do epilogue
|
|
||||||
if (use_out_source) {
|
|
||||||
if (do_axpby) {
|
|
||||||
mma_op.apply_epilogue_safe(
|
|
||||||
C,
|
|
||||||
addmm_params->ldc,
|
|
||||||
addmm_params->fdc,
|
|
||||||
short2(tgp_bn, tgp_bm),
|
|
||||||
epilogue_op_axpby);
|
|
||||||
} else {
|
|
||||||
mma_op.apply_epilogue_safe(
|
|
||||||
C,
|
|
||||||
addmm_params->ldc,
|
|
||||||
addmm_params->fdc,
|
|
||||||
short2(tgp_bn, tgp_bm),
|
|
||||||
epilogue_op_add);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Store results to device memory
|
|
||||||
return mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
|
|
||||||
|
|
||||||
} else {
|
|
||||||
gemm_kernel::gemm_loop(
|
|
||||||
As,
|
|
||||||
Bs,
|
|
||||||
gemm_k_iterations,
|
|
||||||
loader_a,
|
|
||||||
loader_b,
|
|
||||||
mma_op,
|
|
||||||
tgp_bm,
|
|
||||||
tgp_bn,
|
|
||||||
leftover_bk,
|
|
||||||
LoopAlignment<false, false, true>{});
|
|
||||||
|
|
||||||
// Do epilogue
|
|
||||||
if (use_out_source) {
|
|
||||||
if (do_axpby) {
|
|
||||||
mma_op.apply_epilogue_safe(
|
|
||||||
C,
|
|
||||||
addmm_params->ldc,
|
|
||||||
addmm_params->fdc,
|
|
||||||
short2(tgp_bn, tgp_bm),
|
|
||||||
epilogue_op_axpby);
|
|
||||||
} else {
|
|
||||||
mma_op.apply_epilogue_safe(
|
|
||||||
C,
|
|
||||||
addmm_params->ldc,
|
|
||||||
addmm_params->fdc,
|
|
||||||
short2(tgp_bn, tgp_bm),
|
|
||||||
epilogue_op_add);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Store results to device memory
|
|
||||||
return mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
|
||||||
// GEMM kernel initializations
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
// clang-format off
|
|
||||||
#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||||
template [[host_name("steel_gemm_fused_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn)]] \
|
template [[host_name("steel_gemm_fused_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn)]] \
|
||||||
[[kernel]] void gemm<itype, bm, bn, bk, wm, wn, trans_a, trans_b, float>( \
|
[[kernel]] void gemm<itype, bm, bn, bk, wm, wn, trans_a, trans_b, float>( \
|
||||||
@ -445,24 +27,23 @@ template <
|
|||||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||||
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
|
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
|
||||||
uint3 tid [[threadgroup_position_in_grid]], \
|
uint3 tid [[threadgroup_position_in_grid]], \
|
||||||
uint3 lid [[thread_position_in_threadgroup]]); // clang-format on
|
uint3 lid [[thread_position_in_threadgroup]]);
|
||||||
|
|
||||||
// clang-format off
|
|
||||||
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||||
instantiate_gemm(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
instantiate_gemm(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||||
instantiate_gemm(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
instantiate_gemm(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||||
instantiate_gemm(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
instantiate_gemm(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||||
instantiate_gemm(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) // clang-format on
|
instantiate_gemm(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)
|
||||||
|
|
||||||
// clang-format off
|
|
||||||
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
|
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
|
||||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
|
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
|
||||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \
|
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \
|
||||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
|
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
|
||||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 16, 2, 2) \
|
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 16, 2, 2) \
|
||||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 2, 2) // clang-format on
|
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 2, 2)
|
||||||
|
|
||||||
instantiate_gemm_shapes_helper(float16, half, float16, half);
|
instantiate_gemm_shapes_helper(float16, half, float16, half);
|
||||||
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);
|
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);
|
||||||
|
|
||||||
instantiate_gemm_shapes_helper(float32, float, float32, float);
|
instantiate_gemm_shapes_helper(float32, float, float32, float);
|
||||||
|
// clang-format on
|
||||||
|
719
mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h
Normal file
719
mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h
Normal file
@ -0,0 +1,719 @@
|
|||||||
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/metal/kernels/steel/defines.h"
|
||||||
|
using namespace metal;
|
||||||
|
using namespace mlx::steel;
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// GEMM kernels
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
struct _NoMask {
|
||||||
|
char x;
|
||||||
|
|
||||||
|
constexpr METAL_FUNC operator bool() {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
constexpr METAL_FUNC operator bool() const threadgroup {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
constexpr METAL_FUNC operator bool() const device {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
constexpr METAL_FUNC operator bool() const constant {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename OutT, typename InT = OutT>
|
||||||
|
struct ScaleOp {
|
||||||
|
OutT scale;
|
||||||
|
|
||||||
|
METAL_FUNC OutT apply(InT x) const {
|
||||||
|
return static_cast<OutT>(x) * scale;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
typedef struct _NoMask nomask_t;
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename T,
|
||||||
|
typename out_mask_t,
|
||||||
|
typename op_mask_t,
|
||||||
|
int BM,
|
||||||
|
int BN,
|
||||||
|
int BK,
|
||||||
|
int WM,
|
||||||
|
int WN,
|
||||||
|
bool transpose_a,
|
||||||
|
bool transpose_b,
|
||||||
|
bool MN_aligned,
|
||||||
|
bool K_aligned>
|
||||||
|
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void
|
||||||
|
block_masked_gemm(
|
||||||
|
const device T* A [[buffer(0)]],
|
||||||
|
const device T* B [[buffer(1)]],
|
||||||
|
device T* D [[buffer(3)]],
|
||||||
|
const constant GEMMParams* params [[buffer(4)]],
|
||||||
|
const constant int* batch_shape [[buffer(6)]],
|
||||||
|
const constant size_t* batch_strides [[buffer(7)]],
|
||||||
|
const device out_mask_t* out_mask [[buffer(10)]],
|
||||||
|
const device op_mask_t* lhs_mask [[buffer(11)]],
|
||||||
|
const device op_mask_t* rhs_mask [[buffer(12)]],
|
||||||
|
const constant int* mask_strides [[buffer(13)]],
|
||||||
|
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||||
|
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||||
|
uint3 tid [[threadgroup_position_in_grid]],
|
||||||
|
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||||
|
// Appease the compiler
|
||||||
|
(void)lid;
|
||||||
|
|
||||||
|
static_assert(
|
||||||
|
BM == BN,
|
||||||
|
"block_masked_gemm must have the same block M and block N size");
|
||||||
|
static_assert(BM % BK == 0, "block_masked_gemm must have BM % BK == 0");
|
||||||
|
|
||||||
|
constexpr bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>;
|
||||||
|
constexpr bool has_output_mask = !metal::is_same_v<out_mask_t, nomask_t>;
|
||||||
|
|
||||||
|
constexpr bool has_mul_operand_mask =
|
||||||
|
has_operand_mask && !metal::is_same_v<op_mask_t, bool>;
|
||||||
|
constexpr bool has_mul_output_mask =
|
||||||
|
has_output_mask && !metal::is_same_v<out_mask_t, bool>;
|
||||||
|
|
||||||
|
constexpr short k_mask_factor = short(BM / BK);
|
||||||
|
|
||||||
|
using gemm_kernel = GEMMKernel<
|
||||||
|
T,
|
||||||
|
T,
|
||||||
|
BM,
|
||||||
|
BN,
|
||||||
|
BK,
|
||||||
|
WM,
|
||||||
|
WN,
|
||||||
|
transpose_a,
|
||||||
|
transpose_b,
|
||||||
|
MN_aligned,
|
||||||
|
K_aligned>;
|
||||||
|
|
||||||
|
const int tid_y = ((tid.y) << params->swizzle_log) +
|
||||||
|
((tid.x) & ((1 << params->swizzle_log) - 1));
|
||||||
|
const int tid_x = (tid.x) >> params->swizzle_log;
|
||||||
|
|
||||||
|
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const constant size_t* mask_batch_strides =
|
||||||
|
batch_strides + 2 * params->batch_ndim;
|
||||||
|
|
||||||
|
if (params->batch_ndim > 1) {
|
||||||
|
if (has_output_mask) {
|
||||||
|
out_mask += elem_to_loc(
|
||||||
|
tid.z, batch_shape, mask_batch_strides, params->batch_ndim);
|
||||||
|
|
||||||
|
mask_batch_strides += params->batch_ndim;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (has_operand_mask) {
|
||||||
|
const constant size_t* mask_strides_lhs = mask_batch_strides;
|
||||||
|
const constant size_t* mask_strides_rhs =
|
||||||
|
mask_strides_lhs + params->batch_ndim;
|
||||||
|
|
||||||
|
ulong2 batch_offsets = elem_to_loc_broadcast(
|
||||||
|
tid.z,
|
||||||
|
batch_shape,
|
||||||
|
mask_strides_lhs,
|
||||||
|
mask_strides_rhs,
|
||||||
|
params->batch_ndim);
|
||||||
|
|
||||||
|
lhs_mask += batch_offsets.x;
|
||||||
|
rhs_mask += batch_offsets.y;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if (has_output_mask) {
|
||||||
|
out_mask += tid.z * mask_batch_strides[0];
|
||||||
|
mask_batch_strides += params->batch_ndim;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (has_operand_mask) {
|
||||||
|
lhs_mask += tid.z * mask_batch_strides[0];
|
||||||
|
rhs_mask += tid.z * mask_batch_strides[params->batch_ndim];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Adjust for batch
|
||||||
|
if (params->batch_ndim > 1) {
|
||||||
|
const constant size_t* A_bstrides = batch_strides;
|
||||||
|
const constant size_t* B_bstrides = batch_strides + params->batch_ndim;
|
||||||
|
|
||||||
|
ulong2 batch_offsets = elem_to_loc_broadcast(
|
||||||
|
tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);
|
||||||
|
|
||||||
|
A += batch_offsets.x;
|
||||||
|
B += batch_offsets.y;
|
||||||
|
|
||||||
|
} else {
|
||||||
|
A += params->batch_stride_a * tid.z;
|
||||||
|
B += params->batch_stride_b * tid.z;
|
||||||
|
}
|
||||||
|
|
||||||
|
D += params->batch_stride_d * tid.z;
|
||||||
|
|
||||||
|
// Find block in A, B, C
|
||||||
|
const int c_row = tid_y * BM;
|
||||||
|
const int c_col = tid_x * BN;
|
||||||
|
const size_t c_row_long = size_t(c_row);
|
||||||
|
const size_t c_col_long = size_t(c_col);
|
||||||
|
|
||||||
|
A += transpose_a ? c_row_long : c_row_long * params->lda;
|
||||||
|
B += transpose_b ? c_col_long * params->ldb : c_col_long;
|
||||||
|
D += c_row_long * params->ldd + c_col_long;
|
||||||
|
|
||||||
|
const constant int* out_mask_strides = mask_strides;
|
||||||
|
const constant int* lhs_mask_strides =
|
||||||
|
mask_strides + (has_output_mask ? 2 : 0);
|
||||||
|
const constant int* rhs_mask_strides =
|
||||||
|
lhs_mask_strides + (has_operand_mask ? 2 : 0);
|
||||||
|
|
||||||
|
const int out_mask_offset = !has_output_mask
|
||||||
|
? 0
|
||||||
|
: tid_y * out_mask_strides[1] + tid_x * out_mask_strides[0];
|
||||||
|
int lhs_mask_offset = !has_operand_mask ? 0 : tid_y * lhs_mask_strides[1];
|
||||||
|
int rhs_mask_offset = !has_operand_mask ? 0 : tid_x * rhs_mask_strides[0];
|
||||||
|
const int lhs_mask_step = !has_operand_mask ? 0 : lhs_mask_strides[0];
|
||||||
|
const int rhs_mask_step = !has_operand_mask ? 0 : rhs_mask_strides[1];
|
||||||
|
short k_factor_cnt = k_mask_factor;
|
||||||
|
|
||||||
|
ScaleOp<float> out_mask_op;
|
||||||
|
ScaleOp<T> lhs_mask_op;
|
||||||
|
ScaleOp<T> rhs_mask_op;
|
||||||
|
|
||||||
|
if (has_output_mask) {
|
||||||
|
auto mask_out = out_mask[out_mask_offset];
|
||||||
|
|
||||||
|
if (has_mul_output_mask) {
|
||||||
|
out_mask_op.scale = float(mask_out);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write zeros and return
|
||||||
|
if (!mask_out) {
|
||||||
|
constexpr short tgp_size = WM * WN * 32;
|
||||||
|
constexpr short vec_size = 4;
|
||||||
|
|
||||||
|
// Tile threads in threadgroup
|
||||||
|
constexpr short TN = BN / vec_size;
|
||||||
|
constexpr short TM = tgp_size / TN;
|
||||||
|
|
||||||
|
const short thread_idx = simd_group_id * 32 + simd_lane_id;
|
||||||
|
const short bi = thread_idx / TN;
|
||||||
|
const short bj = vec_size * (thread_idx % TN);
|
||||||
|
|
||||||
|
D += bi * params->ldd + bj;
|
||||||
|
|
||||||
|
short tgp_bm = min(BM, params->M - c_row);
|
||||||
|
short tgp_bn = min(BN, params->N - c_col);
|
||||||
|
|
||||||
|
if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
|
||||||
|
for (short ti = 0; ti < BM; ti += TM) {
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short j = 0; j < vec_size; j++) {
|
||||||
|
D[ti * params->ldd + j] = T(0.);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
short jmax = tgp_bn - bj;
|
||||||
|
jmax = jmax < vec_size ? jmax : vec_size;
|
||||||
|
for (short ti = 0; (bi + ti) < tgp_bm; ti += TM) {
|
||||||
|
for (short j = 0; j < jmax; j++) {
|
||||||
|
D[ti * params->ldd + j] = T(0.);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_none);
|
||||||
|
|
||||||
|
// Prepare threadgroup mma operation
|
||||||
|
thread typename gemm_kernel::mma_t mma_op(simd_group_id, simd_lane_id);
|
||||||
|
|
||||||
|
threadgroup T As[gemm_kernel::tgp_mem_size_a];
|
||||||
|
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
|
||||||
|
|
||||||
|
// Prepare threadgroup loading operations
|
||||||
|
thread typename gemm_kernel::loader_a_t loader_a(
|
||||||
|
A, params->lda, As, simd_group_id, simd_lane_id);
|
||||||
|
thread typename gemm_kernel::loader_b_t loader_b(
|
||||||
|
B, params->ldb, Bs, simd_group_id, simd_lane_id);
|
||||||
|
|
||||||
|
// Prepare threadgroup bounds
|
||||||
|
const short tgp_bm =
|
||||||
|
MN_aligned ? short(BM) : short(min(BM, params->M - c_row));
|
||||||
|
const short tgp_bn =
|
||||||
|
MN_aligned ? short(BN) : short(min(BN, params->N - c_col));
|
||||||
|
|
||||||
|
int gemm_k_iterations = params->gemm_k_iterations_aligned;
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Do unaligned K iterations first
|
||||||
|
if (!K_aligned) {
|
||||||
|
const int k_last = params->gemm_k_iterations_aligned * BK;
|
||||||
|
const int mask_idx_last = k_last / BM;
|
||||||
|
|
||||||
|
if (!has_operand_mask ||
|
||||||
|
(bool(lhs_mask[lhs_mask_offset + mask_idx_last * lhs_mask_step]) &&
|
||||||
|
bool(rhs_mask[rhs_mask_offset + mask_idx_last * rhs_mask_step]))) {
|
||||||
|
if (has_mul_operand_mask) {
|
||||||
|
lhs_mask_op.scale =
|
||||||
|
lhs_mask[lhs_mask_offset + mask_idx_last * lhs_mask_step];
|
||||||
|
rhs_mask_op.scale =
|
||||||
|
rhs_mask[rhs_mask_offset + mask_idx_last * rhs_mask_step];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Move loader source ahead to end
|
||||||
|
const int k_remain = params->K - k_last;
|
||||||
|
const size_t k_jump_a =
|
||||||
|
transpose_a ? params->lda * size_t(k_last) : size_t(k_last);
|
||||||
|
const size_t k_jump_b =
|
||||||
|
transpose_b ? size_t(k_last) : params->ldb * size_t(k_last);
|
||||||
|
|
||||||
|
loader_a.src += k_jump_a;
|
||||||
|
loader_b.src += k_jump_b;
|
||||||
|
|
||||||
|
// Load tile
|
||||||
|
const short2 tile_dims_A =
|
||||||
|
transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
|
||||||
|
const short2 tile_dims_B =
|
||||||
|
transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
|
||||||
|
|
||||||
|
loader_a.load_safe(tile_dims_A);
|
||||||
|
loader_b.load_safe(tile_dims_B);
|
||||||
|
|
||||||
|
if (has_mul_operand_mask) {
|
||||||
|
loader_a.apply_inplace_op(lhs_mask_op);
|
||||||
|
loader_b.apply_inplace_op(rhs_mask_op);
|
||||||
|
}
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
// Do matmul
|
||||||
|
mma_op.mma(As, Bs);
|
||||||
|
|
||||||
|
// Reset source back to start
|
||||||
|
loader_a.src -= k_jump_a;
|
||||||
|
loader_b.src -= k_jump_b;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// MNK aligned loop
|
||||||
|
if (MN_aligned) {
|
||||||
|
for (; gemm_k_iterations > 0; gemm_k_iterations--) {
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
if (!has_operand_mask ||
|
||||||
|
(bool(lhs_mask[lhs_mask_offset]) &&
|
||||||
|
bool(rhs_mask[rhs_mask_offset]))) {
|
||||||
|
if (has_mul_operand_mask) {
|
||||||
|
lhs_mask_op.scale = lhs_mask[lhs_mask_offset];
|
||||||
|
rhs_mask_op.scale = rhs_mask[rhs_mask_offset];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load elements into threadgroup
|
||||||
|
loader_a.load_unsafe();
|
||||||
|
loader_b.load_unsafe();
|
||||||
|
|
||||||
|
if (has_mul_operand_mask) {
|
||||||
|
loader_a.apply_inplace_op(lhs_mask_op);
|
||||||
|
loader_b.apply_inplace_op(rhs_mask_op);
|
||||||
|
}
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
// Multiply and accumulate threadgroup elements
|
||||||
|
mma_op.mma(As, Bs);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prepare for next iteration
|
||||||
|
loader_a.next();
|
||||||
|
loader_b.next();
|
||||||
|
|
||||||
|
k_factor_cnt--;
|
||||||
|
lhs_mask_offset += k_factor_cnt == 0 ? lhs_mask_step : 0;
|
||||||
|
rhs_mask_offset += k_factor_cnt == 0 ? rhs_mask_step : 0;
|
||||||
|
k_factor_cnt = k_factor_cnt == 0 ? k_mask_factor : k_factor_cnt;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (has_mul_output_mask) {
|
||||||
|
mma_op.apply_epilogue(out_mask_op);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store results to device memory
|
||||||
|
mma_op.store_result(D, params->ldd);
|
||||||
|
return;
|
||||||
|
|
||||||
|
}
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// MN unaligned loop
|
||||||
|
else {
|
||||||
|
const bool M_aligned = (tgp_bm == BM);
|
||||||
|
const bool N_aligned = (tgp_bn == BN);
|
||||||
|
|
||||||
|
const short2 tile_dims_A =
|
||||||
|
transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm);
|
||||||
|
const short2 tile_dims_B =
|
||||||
|
transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK);
|
||||||
|
|
||||||
|
for (; gemm_k_iterations > 0; gemm_k_iterations--) {
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
if (!has_operand_mask ||
|
||||||
|
(bool(lhs_mask[lhs_mask_offset]) &&
|
||||||
|
bool(rhs_mask[rhs_mask_offset]))) {
|
||||||
|
if (has_mul_operand_mask) {
|
||||||
|
lhs_mask_op.scale = lhs_mask[lhs_mask_offset];
|
||||||
|
rhs_mask_op.scale = rhs_mask[rhs_mask_offset];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load elements into threadgroup
|
||||||
|
if (M_aligned) {
|
||||||
|
loader_a.load_unsafe();
|
||||||
|
} else {
|
||||||
|
loader_a.load_safe(tile_dims_A);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (N_aligned) {
|
||||||
|
loader_b.load_unsafe();
|
||||||
|
} else {
|
||||||
|
loader_b.load_safe(tile_dims_B);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (has_mul_operand_mask) {
|
||||||
|
loader_a.apply_inplace_op(lhs_mask_op);
|
||||||
|
loader_b.apply_inplace_op(rhs_mask_op);
|
||||||
|
}
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
// Multiply and accumulate threadgroup elements
|
||||||
|
mma_op.mma(As, Bs);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prepare for next iteration
|
||||||
|
loader_a.next();
|
||||||
|
loader_b.next();
|
||||||
|
|
||||||
|
k_factor_cnt--;
|
||||||
|
lhs_mask_offset += k_factor_cnt == 0 ? lhs_mask_step : 0;
|
||||||
|
rhs_mask_offset += k_factor_cnt == 0 ? rhs_mask_step : 0;
|
||||||
|
k_factor_cnt = k_factor_cnt == 0 ? k_mask_factor : k_factor_cnt;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (has_mul_output_mask) {
|
||||||
|
mma_op.apply_epilogue(out_mask_op);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (M_aligned && N_aligned) {
|
||||||
|
mma_op.store_result(D, params->ldd);
|
||||||
|
} else {
|
||||||
|
mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename T,
|
||||||
|
int BM,
|
||||||
|
int BN,
|
||||||
|
int BK,
|
||||||
|
int WM,
|
||||||
|
int WN,
|
||||||
|
bool transpose_a,
|
||||||
|
bool transpose_b,
|
||||||
|
bool MN_aligned,
|
||||||
|
bool K_aligned,
|
||||||
|
bool has_operand_mask = false>
|
||||||
|
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void
|
||||||
|
block_masked_gemm(
|
||||||
|
const device T* A [[buffer(0)]],
|
||||||
|
const device T* B [[buffer(1)]],
|
||||||
|
device T* D [[buffer(3)]],
|
||||||
|
const constant GEMMParams* params [[buffer(4)]],
|
||||||
|
const constant int* batch_shape [[buffer(6)]],
|
||||||
|
const constant size_t* batch_strides [[buffer(7)]],
|
||||||
|
const device bool* out_mask [[buffer(10)]],
|
||||||
|
const device bool* lhs_mask [[buffer(11)]],
|
||||||
|
const device bool* rhs_mask [[buffer(12)]],
|
||||||
|
const constant int* mask_strides [[buffer(13)]],
|
||||||
|
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||||
|
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||||
|
uint3 tid [[threadgroup_position_in_grid]],
|
||||||
|
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||||
|
// Appease the compiler
|
||||||
|
(void)lid;
|
||||||
|
|
||||||
|
using gemm_kernel = GEMMKernel<
|
||||||
|
T,
|
||||||
|
T,
|
||||||
|
BM,
|
||||||
|
BN,
|
||||||
|
BK,
|
||||||
|
WM,
|
||||||
|
WN,
|
||||||
|
transpose_a,
|
||||||
|
transpose_b,
|
||||||
|
MN_aligned,
|
||||||
|
K_aligned>;
|
||||||
|
|
||||||
|
const int tid_y = ((tid.y) << params->swizzle_log) +
|
||||||
|
((tid.x) & ((1 << params->swizzle_log) - 1));
|
||||||
|
const int tid_x = (tid.x) >> params->swizzle_log;
|
||||||
|
|
||||||
|
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (params->batch_ndim > 1) {
|
||||||
|
const constant size_t* mask_batch_strides =
|
||||||
|
batch_strides + 2 * params->batch_ndim;
|
||||||
|
out_mask +=
|
||||||
|
elem_to_loc(tid.z, batch_shape, mask_batch_strides, params->batch_ndim);
|
||||||
|
|
||||||
|
if (has_operand_mask) {
|
||||||
|
const constant size_t* mask_strides_lhs =
|
||||||
|
mask_batch_strides + params->batch_ndim;
|
||||||
|
const constant size_t* mask_strides_rhs =
|
||||||
|
mask_strides_lhs + params->batch_ndim;
|
||||||
|
|
||||||
|
ulong2 batch_offsets = elem_to_loc_broadcast(
|
||||||
|
tid.z,
|
||||||
|
batch_shape,
|
||||||
|
mask_strides_lhs,
|
||||||
|
mask_strides_rhs,
|
||||||
|
params->batch_ndim);
|
||||||
|
|
||||||
|
lhs_mask += batch_offsets.x;
|
||||||
|
rhs_mask += batch_offsets.y;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
out_mask += tid.z * batch_strides[2 * params->batch_ndim];
|
||||||
|
if (has_operand_mask) {
|
||||||
|
lhs_mask += tid.z * batch_strides[3 * params->batch_ndim];
|
||||||
|
rhs_mask += tid.z * batch_strides[4 * params->batch_ndim];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Adjust for batch
|
||||||
|
if (params->batch_ndim > 1) {
|
||||||
|
const constant size_t* A_bstrides = batch_strides;
|
||||||
|
const constant size_t* B_bstrides = batch_strides + params->batch_ndim;
|
||||||
|
|
||||||
|
ulong2 batch_offsets = elem_to_loc_broadcast(
|
||||||
|
tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);
|
||||||
|
|
||||||
|
A += batch_offsets.x;
|
||||||
|
B += batch_offsets.y;
|
||||||
|
|
||||||
|
} else {
|
||||||
|
A += params->batch_stride_a * tid.z;
|
||||||
|
B += params->batch_stride_b * tid.z;
|
||||||
|
}
|
||||||
|
|
||||||
|
D += params->batch_stride_d * tid.z;
|
||||||
|
|
||||||
|
// Find block in A, B, C
|
||||||
|
const int c_row = tid_y * BM;
|
||||||
|
const int c_col = tid_x * BN;
|
||||||
|
const size_t c_row_long = size_t(c_row);
|
||||||
|
const size_t c_col_long = size_t(c_col);
|
||||||
|
|
||||||
|
A += transpose_a ? c_row_long : c_row_long * params->lda;
|
||||||
|
B += transpose_b ? c_col_long * params->ldb : c_col_long;
|
||||||
|
D += c_row_long * params->ldd + c_col_long;
|
||||||
|
|
||||||
|
bool mask_out = out_mask[tid_y * mask_strides[1] + tid_x * mask_strides[0]];
|
||||||
|
|
||||||
|
// Write zeros and return
|
||||||
|
if (!mask_out) {
|
||||||
|
constexpr short tgp_size = WM * WN * 32;
|
||||||
|
constexpr short vec_size = 4;
|
||||||
|
|
||||||
|
// Tile threads in threadgroup
|
||||||
|
constexpr short TN = BN / vec_size;
|
||||||
|
constexpr short TM = tgp_size / TN;
|
||||||
|
|
||||||
|
const short thread_idx = simd_group_id * 32 + simd_lane_id;
|
||||||
|
const short bi = thread_idx / TN;
|
||||||
|
const short bj = vec_size * (thread_idx % TN);
|
||||||
|
|
||||||
|
D += bi * params->ldd + bj;
|
||||||
|
|
||||||
|
short tgp_bm = min(BM, params->M - c_row);
|
||||||
|
short tgp_bn = min(BN, params->N - c_col);
|
||||||
|
|
||||||
|
if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
|
||||||
|
for (short ti = 0; ti < BM; ti += TM) {
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short j = 0; j < vec_size; j++) {
|
||||||
|
D[ti * params->ldd + j] = T(0.);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
short jmax = tgp_bn - bj;
|
||||||
|
jmax = jmax < vec_size ? jmax : vec_size;
|
||||||
|
for (short ti = 0; (bi + ti) < tgp_bm; ti += TM) {
|
||||||
|
for (short j = 0; j < jmax; j++) {
|
||||||
|
D[ti * params->ldd + j] = T(0.);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_none);
|
||||||
|
|
||||||
|
// Prepare threadgroup mma operation
|
||||||
|
thread typename gemm_kernel::mma_t mma_op(simd_group_id, simd_lane_id);
|
||||||
|
|
||||||
|
int gemm_k_iterations = params->gemm_k_iterations_aligned;
|
||||||
|
|
||||||
|
threadgroup T As[gemm_kernel::tgp_mem_size_a];
|
||||||
|
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
|
||||||
|
|
||||||
|
// Prepare threadgroup loading operations
|
||||||
|
thread typename gemm_kernel::loader_a_t loader_a(
|
||||||
|
A, params->lda, As, simd_group_id, simd_lane_id);
|
||||||
|
thread typename gemm_kernel::loader_b_t loader_b(
|
||||||
|
B, params->ldb, Bs, simd_group_id, simd_lane_id);
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// MNK aligned loop
|
||||||
|
if (MN_aligned) {
|
||||||
|
for (int k = 0; k < gemm_k_iterations; k++) {
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
if (!has_operand_mask ||
|
||||||
|
(lhs_mask
|
||||||
|
[tid_y * mask_strides[3] + ((k * BK) / BM) * mask_strides[2]] &&
|
||||||
|
rhs_mask
|
||||||
|
[((k * BK) / BM) * mask_strides[5] + tid_x * mask_strides[4]])) {
|
||||||
|
// Load elements into threadgroup
|
||||||
|
loader_a.load_unsafe();
|
||||||
|
loader_b.load_unsafe();
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
// Multiply and accumulate threadgroup elements
|
||||||
|
mma_op.mma(As, Bs);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prepare for next iteration
|
||||||
|
loader_a.next();
|
||||||
|
loader_b.next();
|
||||||
|
}
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_none);
|
||||||
|
|
||||||
|
// Loop tail
|
||||||
|
if (!K_aligned) {
|
||||||
|
if (!has_operand_mask ||
|
||||||
|
(lhs_mask
|
||||||
|
[tid_y * mask_strides[3] + (params->K / BM) * mask_strides[2]] &&
|
||||||
|
rhs_mask
|
||||||
|
[(params->K / BM) * mask_strides[5] +
|
||||||
|
tid_x * mask_strides[4]])) {
|
||||||
|
int lbk = params->K - params->gemm_k_iterations_aligned * BK;
|
||||||
|
short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);
|
||||||
|
short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk);
|
||||||
|
|
||||||
|
loader_a.load_safe(tile_dims_A);
|
||||||
|
loader_b.load_safe(tile_dims_B);
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
mma_op.mma(As, Bs);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store results to device memory
|
||||||
|
mma_op.store_result(D, params->ldd);
|
||||||
|
return;
|
||||||
|
|
||||||
|
}
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// MN unaligned loop
|
||||||
|
else { // Loop over K - unaligned case
|
||||||
|
short tgp_bm = min(BM, params->M - c_row);
|
||||||
|
short tgp_bn = min(BN, params->N - c_col);
|
||||||
|
short lbk = params->K - params->gemm_k_iterations_aligned * BK;
|
||||||
|
|
||||||
|
bool M_aligned = (tgp_bm == BM);
|
||||||
|
bool N_aligned = (tgp_bn == BN);
|
||||||
|
|
||||||
|
short2 tile_dims_A = transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm);
|
||||||
|
short2 tile_dims_B = transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK);
|
||||||
|
|
||||||
|
for (int k = 0; k < gemm_k_iterations; k++) {
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
if (!has_operand_mask ||
|
||||||
|
(lhs_mask
|
||||||
|
[tid_y * mask_strides[3] + ((k * BK) / BM) * mask_strides[2]] &&
|
||||||
|
rhs_mask
|
||||||
|
[((k * BK) / BM) * mask_strides[5] + tid_x * mask_strides[4]])) {
|
||||||
|
// Load elements into threadgroup
|
||||||
|
if (M_aligned) {
|
||||||
|
loader_a.load_unsafe();
|
||||||
|
} else {
|
||||||
|
loader_a.load_safe(tile_dims_A);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (N_aligned) {
|
||||||
|
loader_b.load_unsafe();
|
||||||
|
} else {
|
||||||
|
loader_b.load_safe(tile_dims_B);
|
||||||
|
}
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
// Multiply and accumulate threadgroup elements
|
||||||
|
mma_op.mma(As, Bs);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prepare for next iteration
|
||||||
|
loader_a.next();
|
||||||
|
loader_b.next();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!K_aligned) {
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
if (!has_operand_mask ||
|
||||||
|
(lhs_mask
|
||||||
|
[tid_y * mask_strides[3] + (params->K / BM) * mask_strides[2]] &&
|
||||||
|
rhs_mask
|
||||||
|
[(params->K / BM) * mask_strides[5] +
|
||||||
|
tid_x * mask_strides[4]])) {
|
||||||
|
short2 tile_dims_A_last =
|
||||||
|
transpose_a ? short2(tgp_bm, lbk) : short2(lbk, tgp_bm);
|
||||||
|
short2 tile_dims_B_last =
|
||||||
|
transpose_b ? short2(lbk, tgp_bn) : short2(tgp_bn, lbk);
|
||||||
|
|
||||||
|
loader_a.load_safe(tile_dims_A_last);
|
||||||
|
loader_b.load_safe(tile_dims_B_last);
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
mma_op.mma(As, Bs);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (M_aligned && N_aligned) {
|
||||||
|
mma_op.store_result(D, params->ldd);
|
||||||
|
} else {
|
||||||
|
mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -1,434 +1,10 @@
|
|||||||
// Copyright © 2024 Apple Inc.
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
#include "mlx/backend/metal/kernels/bf16.h"
|
#include "mlx/backend/metal/kernels/bf16.h"
|
||||||
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
|
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
|
||||||
#include "mlx/backend/metal/kernels/utils.h"
|
#include "mlx/backend/metal/kernels/utils.h"
|
||||||
|
#include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h"
|
||||||
using namespace metal;
|
|
||||||
using namespace mlx::steel;
|
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
|
||||||
// GEMM kernels
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
struct _NoMask {
|
|
||||||
char x;
|
|
||||||
|
|
||||||
constexpr METAL_FUNC operator bool() {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
constexpr METAL_FUNC operator bool() const threadgroup {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
constexpr METAL_FUNC operator bool() const device {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
constexpr METAL_FUNC operator bool() const constant {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename OutT, typename InT = OutT>
|
|
||||||
struct ScaleOp {
|
|
||||||
OutT scale;
|
|
||||||
|
|
||||||
METAL_FUNC OutT apply(InT x) const {
|
|
||||||
return static_cast<OutT>(x) * scale;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
typedef struct _NoMask nomask_t;
|
|
||||||
|
|
||||||
template <
|
|
||||||
typename T,
|
|
||||||
typename out_mask_t,
|
|
||||||
typename op_mask_t,
|
|
||||||
int BM,
|
|
||||||
int BN,
|
|
||||||
int BK,
|
|
||||||
int WM,
|
|
||||||
int WN,
|
|
||||||
bool transpose_a,
|
|
||||||
bool transpose_b,
|
|
||||||
bool MN_aligned,
|
|
||||||
bool K_aligned>
|
|
||||||
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void
|
|
||||||
block_masked_gemm(
|
|
||||||
const device T* A [[buffer(0)]],
|
|
||||||
const device T* B [[buffer(1)]],
|
|
||||||
device T* D [[buffer(3)]],
|
|
||||||
const constant GEMMParams* params [[buffer(4)]],
|
|
||||||
const constant int* batch_shape [[buffer(6)]],
|
|
||||||
const constant size_t* batch_strides [[buffer(7)]],
|
|
||||||
const device out_mask_t* out_mask [[buffer(10)]],
|
|
||||||
const device op_mask_t* lhs_mask [[buffer(11)]],
|
|
||||||
const device op_mask_t* rhs_mask [[buffer(12)]],
|
|
||||||
const constant int* mask_strides [[buffer(13)]],
|
|
||||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
|
||||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
|
||||||
uint3 tid [[threadgroup_position_in_grid]],
|
|
||||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
|
||||||
// Appease the compiler
|
|
||||||
(void)lid;
|
|
||||||
|
|
||||||
static_assert(
|
|
||||||
BM == BN,
|
|
||||||
"block_masked_gemm must have the same block M and block N size");
|
|
||||||
static_assert(BM % BK == 0, "block_masked_gemm must have BM % BK == 0");
|
|
||||||
|
|
||||||
constexpr bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>;
|
|
||||||
constexpr bool has_output_mask = !metal::is_same_v<out_mask_t, nomask_t>;
|
|
||||||
|
|
||||||
constexpr bool has_mul_operand_mask =
|
|
||||||
has_operand_mask && !metal::is_same_v<op_mask_t, bool>;
|
|
||||||
constexpr bool has_mul_output_mask =
|
|
||||||
has_output_mask && !metal::is_same_v<out_mask_t, bool>;
|
|
||||||
|
|
||||||
constexpr short k_mask_factor = short(BM / BK);
|
|
||||||
|
|
||||||
using gemm_kernel = GEMMKernel<
|
|
||||||
T,
|
|
||||||
T,
|
|
||||||
BM,
|
|
||||||
BN,
|
|
||||||
BK,
|
|
||||||
WM,
|
|
||||||
WN,
|
|
||||||
transpose_a,
|
|
||||||
transpose_b,
|
|
||||||
MN_aligned,
|
|
||||||
K_aligned>;
|
|
||||||
|
|
||||||
const int tid_y = ((tid.y) << params->swizzle_log) +
|
|
||||||
((tid.x) & ((1 << params->swizzle_log) - 1));
|
|
||||||
const int tid_x = (tid.x) >> params->swizzle_log;
|
|
||||||
|
|
||||||
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const constant size_t* mask_batch_strides =
|
|
||||||
batch_strides + 2 * params->batch_ndim;
|
|
||||||
|
|
||||||
if (params->batch_ndim > 1) {
|
|
||||||
if (has_output_mask) {
|
|
||||||
out_mask += elem_to_loc(
|
|
||||||
tid.z, batch_shape, mask_batch_strides, params->batch_ndim);
|
|
||||||
|
|
||||||
mask_batch_strides += params->batch_ndim;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (has_operand_mask) {
|
|
||||||
const constant size_t* mask_strides_lhs = mask_batch_strides;
|
|
||||||
const constant size_t* mask_strides_rhs =
|
|
||||||
mask_strides_lhs + params->batch_ndim;
|
|
||||||
|
|
||||||
ulong2 batch_offsets = elem_to_loc_broadcast(
|
|
||||||
tid.z,
|
|
||||||
batch_shape,
|
|
||||||
mask_strides_lhs,
|
|
||||||
mask_strides_rhs,
|
|
||||||
params->batch_ndim);
|
|
||||||
|
|
||||||
lhs_mask += batch_offsets.x;
|
|
||||||
rhs_mask += batch_offsets.y;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if (has_output_mask) {
|
|
||||||
out_mask += tid.z * mask_batch_strides[0];
|
|
||||||
mask_batch_strides += params->batch_ndim;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (has_operand_mask) {
|
|
||||||
lhs_mask += tid.z * mask_batch_strides[0];
|
|
||||||
rhs_mask += tid.z * mask_batch_strides[params->batch_ndim];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Adjust for batch
|
|
||||||
if (params->batch_ndim > 1) {
|
|
||||||
const constant size_t* A_bstrides = batch_strides;
|
|
||||||
const constant size_t* B_bstrides = batch_strides + params->batch_ndim;
|
|
||||||
|
|
||||||
ulong2 batch_offsets = elem_to_loc_broadcast(
|
|
||||||
tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);
|
|
||||||
|
|
||||||
A += batch_offsets.x;
|
|
||||||
B += batch_offsets.y;
|
|
||||||
|
|
||||||
} else {
|
|
||||||
A += params->batch_stride_a * tid.z;
|
|
||||||
B += params->batch_stride_b * tid.z;
|
|
||||||
}
|
|
||||||
|
|
||||||
D += params->batch_stride_d * tid.z;
|
|
||||||
|
|
||||||
// Find block in A, B, C
|
|
||||||
const int c_row = tid_y * BM;
|
|
||||||
const int c_col = tid_x * BN;
|
|
||||||
const size_t c_row_long = size_t(c_row);
|
|
||||||
const size_t c_col_long = size_t(c_col);
|
|
||||||
|
|
||||||
A += transpose_a ? c_row_long : c_row_long * params->lda;
|
|
||||||
B += transpose_b ? c_col_long * params->ldb : c_col_long;
|
|
||||||
D += c_row_long * params->ldd + c_col_long;
|
|
||||||
|
|
||||||
const constant int* out_mask_strides = mask_strides;
|
|
||||||
const constant int* lhs_mask_strides =
|
|
||||||
mask_strides + (has_output_mask ? 2 : 0);
|
|
||||||
const constant int* rhs_mask_strides =
|
|
||||||
lhs_mask_strides + (has_operand_mask ? 2 : 0);
|
|
||||||
|
|
||||||
const int out_mask_offset = !has_output_mask
|
|
||||||
? 0
|
|
||||||
: tid_y * out_mask_strides[1] + tid_x * out_mask_strides[0];
|
|
||||||
int lhs_mask_offset = !has_operand_mask ? 0 : tid_y * lhs_mask_strides[1];
|
|
||||||
int rhs_mask_offset = !has_operand_mask ? 0 : tid_x * rhs_mask_strides[0];
|
|
||||||
const int lhs_mask_step = !has_operand_mask ? 0 : lhs_mask_strides[0];
|
|
||||||
const int rhs_mask_step = !has_operand_mask ? 0 : rhs_mask_strides[1];
|
|
||||||
short k_factor_cnt = k_mask_factor;
|
|
||||||
|
|
||||||
ScaleOp<float> out_mask_op;
|
|
||||||
ScaleOp<T> lhs_mask_op;
|
|
||||||
ScaleOp<T> rhs_mask_op;
|
|
||||||
|
|
||||||
if (has_output_mask) {
|
|
||||||
auto mask_out = out_mask[out_mask_offset];
|
|
||||||
|
|
||||||
if (has_mul_output_mask) {
|
|
||||||
out_mask_op.scale = float(mask_out);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Write zeros and return
|
|
||||||
if (!mask_out) {
|
|
||||||
constexpr short tgp_size = WM * WN * 32;
|
|
||||||
constexpr short vec_size = 4;
|
|
||||||
|
|
||||||
// Tile threads in threadgroup
|
|
||||||
constexpr short TN = BN / vec_size;
|
|
||||||
constexpr short TM = tgp_size / TN;
|
|
||||||
|
|
||||||
const short thread_idx = simd_group_id * 32 + simd_lane_id;
|
|
||||||
const short bi = thread_idx / TN;
|
|
||||||
const short bj = vec_size * (thread_idx % TN);
|
|
||||||
|
|
||||||
D += bi * params->ldd + bj;
|
|
||||||
|
|
||||||
short tgp_bm = min(BM, params->M - c_row);
|
|
||||||
short tgp_bn = min(BN, params->N - c_col);
|
|
||||||
|
|
||||||
if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
|
|
||||||
for (short ti = 0; ti < BM; ti += TM) {
|
|
||||||
STEEL_PRAGMA_UNROLL
|
|
||||||
for (short j = 0; j < vec_size; j++) {
|
|
||||||
D[ti * params->ldd + j] = T(0.);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
short jmax = tgp_bn - bj;
|
|
||||||
jmax = jmax < vec_size ? jmax : vec_size;
|
|
||||||
for (short ti = 0; (bi + ti) < tgp_bm; ti += TM) {
|
|
||||||
for (short j = 0; j < jmax; j++) {
|
|
||||||
D[ti * params->ldd + j] = T(0.);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_none);
|
|
||||||
|
|
||||||
// Prepare threadgroup mma operation
|
|
||||||
thread typename gemm_kernel::mma_t mma_op(simd_group_id, simd_lane_id);
|
|
||||||
|
|
||||||
threadgroup T As[gemm_kernel::tgp_mem_size_a];
|
|
||||||
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
|
|
||||||
|
|
||||||
// Prepare threadgroup loading operations
|
|
||||||
thread typename gemm_kernel::loader_a_t loader_a(
|
|
||||||
A, params->lda, As, simd_group_id, simd_lane_id);
|
|
||||||
thread typename gemm_kernel::loader_b_t loader_b(
|
|
||||||
B, params->ldb, Bs, simd_group_id, simd_lane_id);
|
|
||||||
|
|
||||||
// Prepare threadgroup bounds
|
|
||||||
const short tgp_bm =
|
|
||||||
MN_aligned ? short(BM) : short(min(BM, params->M - c_row));
|
|
||||||
const short tgp_bn =
|
|
||||||
MN_aligned ? short(BN) : short(min(BN, params->N - c_col));
|
|
||||||
|
|
||||||
int gemm_k_iterations = params->gemm_k_iterations_aligned;
|
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
|
||||||
// Do unaligned K iterations first
|
|
||||||
if (!K_aligned) {
|
|
||||||
const int k_last = params->gemm_k_iterations_aligned * BK;
|
|
||||||
const int mask_idx_last = k_last / BM;
|
|
||||||
|
|
||||||
if (!has_operand_mask ||
|
|
||||||
(bool(lhs_mask[lhs_mask_offset + mask_idx_last * lhs_mask_step]) &&
|
|
||||||
bool(rhs_mask[rhs_mask_offset + mask_idx_last * rhs_mask_step]))) {
|
|
||||||
if (has_mul_operand_mask) {
|
|
||||||
lhs_mask_op.scale =
|
|
||||||
lhs_mask[lhs_mask_offset + mask_idx_last * lhs_mask_step];
|
|
||||||
rhs_mask_op.scale =
|
|
||||||
rhs_mask[rhs_mask_offset + mask_idx_last * rhs_mask_step];
|
|
||||||
}
|
|
||||||
|
|
||||||
// Move loader source ahead to end
|
|
||||||
const int k_remain = params->K - k_last;
|
|
||||||
const size_t k_jump_a =
|
|
||||||
transpose_a ? params->lda * size_t(k_last) : size_t(k_last);
|
|
||||||
const size_t k_jump_b =
|
|
||||||
transpose_b ? size_t(k_last) : params->ldb * size_t(k_last);
|
|
||||||
|
|
||||||
loader_a.src += k_jump_a;
|
|
||||||
loader_b.src += k_jump_b;
|
|
||||||
|
|
||||||
// Load tile
|
|
||||||
const short2 tile_dims_A =
|
|
||||||
transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
|
|
||||||
const short2 tile_dims_B =
|
|
||||||
transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
|
|
||||||
|
|
||||||
loader_a.load_safe(tile_dims_A);
|
|
||||||
loader_b.load_safe(tile_dims_B);
|
|
||||||
|
|
||||||
if (has_mul_operand_mask) {
|
|
||||||
loader_a.apply_inplace_op(lhs_mask_op);
|
|
||||||
loader_b.apply_inplace_op(rhs_mask_op);
|
|
||||||
}
|
|
||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
|
|
||||||
// Do matmul
|
|
||||||
mma_op.mma(As, Bs);
|
|
||||||
|
|
||||||
// Reset source back to start
|
|
||||||
loader_a.src -= k_jump_a;
|
|
||||||
loader_b.src -= k_jump_b;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
|
||||||
// MNK aligned loop
|
|
||||||
if (MN_aligned) {
|
|
||||||
for (; gemm_k_iterations > 0; gemm_k_iterations--) {
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
|
|
||||||
if (!has_operand_mask ||
|
|
||||||
(bool(lhs_mask[lhs_mask_offset]) &&
|
|
||||||
bool(rhs_mask[rhs_mask_offset]))) {
|
|
||||||
if (has_mul_operand_mask) {
|
|
||||||
lhs_mask_op.scale = lhs_mask[lhs_mask_offset];
|
|
||||||
rhs_mask_op.scale = rhs_mask[rhs_mask_offset];
|
|
||||||
}
|
|
||||||
|
|
||||||
// Load elements into threadgroup
|
|
||||||
loader_a.load_unsafe();
|
|
||||||
loader_b.load_unsafe();
|
|
||||||
|
|
||||||
if (has_mul_operand_mask) {
|
|
||||||
loader_a.apply_inplace_op(lhs_mask_op);
|
|
||||||
loader_b.apply_inplace_op(rhs_mask_op);
|
|
||||||
}
|
|
||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
|
|
||||||
// Multiply and accumulate threadgroup elements
|
|
||||||
mma_op.mma(As, Bs);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Prepare for next iteration
|
|
||||||
loader_a.next();
|
|
||||||
loader_b.next();
|
|
||||||
|
|
||||||
k_factor_cnt--;
|
|
||||||
lhs_mask_offset += k_factor_cnt == 0 ? lhs_mask_step : 0;
|
|
||||||
rhs_mask_offset += k_factor_cnt == 0 ? rhs_mask_step : 0;
|
|
||||||
k_factor_cnt = k_factor_cnt == 0 ? k_mask_factor : k_factor_cnt;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (has_mul_output_mask) {
|
|
||||||
mma_op.apply_epilogue(out_mask_op);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Store results to device memory
|
|
||||||
mma_op.store_result(D, params->ldd);
|
|
||||||
return;
|
|
||||||
|
|
||||||
}
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
|
||||||
// MN unaligned loop
|
|
||||||
else {
|
|
||||||
const bool M_aligned = (tgp_bm == BM);
|
|
||||||
const bool N_aligned = (tgp_bn == BN);
|
|
||||||
|
|
||||||
const short2 tile_dims_A =
|
|
||||||
transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm);
|
|
||||||
const short2 tile_dims_B =
|
|
||||||
transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK);
|
|
||||||
|
|
||||||
for (; gemm_k_iterations > 0; gemm_k_iterations--) {
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
if (!has_operand_mask ||
|
|
||||||
(bool(lhs_mask[lhs_mask_offset]) &&
|
|
||||||
bool(rhs_mask[rhs_mask_offset]))) {
|
|
||||||
if (has_mul_operand_mask) {
|
|
||||||
lhs_mask_op.scale = lhs_mask[lhs_mask_offset];
|
|
||||||
rhs_mask_op.scale = rhs_mask[rhs_mask_offset];
|
|
||||||
}
|
|
||||||
|
|
||||||
// Load elements into threadgroup
|
|
||||||
if (M_aligned) {
|
|
||||||
loader_a.load_unsafe();
|
|
||||||
} else {
|
|
||||||
loader_a.load_safe(tile_dims_A);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (N_aligned) {
|
|
||||||
loader_b.load_unsafe();
|
|
||||||
} else {
|
|
||||||
loader_b.load_safe(tile_dims_B);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (has_mul_operand_mask) {
|
|
||||||
loader_a.apply_inplace_op(lhs_mask_op);
|
|
||||||
loader_b.apply_inplace_op(rhs_mask_op);
|
|
||||||
}
|
|
||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
|
|
||||||
// Multiply and accumulate threadgroup elements
|
|
||||||
mma_op.mma(As, Bs);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Prepare for next iteration
|
|
||||||
loader_a.next();
|
|
||||||
loader_b.next();
|
|
||||||
|
|
||||||
k_factor_cnt--;
|
|
||||||
lhs_mask_offset += k_factor_cnt == 0 ? lhs_mask_step : 0;
|
|
||||||
rhs_mask_offset += k_factor_cnt == 0 ? rhs_mask_step : 0;
|
|
||||||
k_factor_cnt = k_factor_cnt == 0 ? k_mask_factor : k_factor_cnt;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (has_mul_output_mask) {
|
|
||||||
mma_op.apply_epilogue(out_mask_op);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (M_aligned && N_aligned) {
|
|
||||||
mma_op.store_result(D, params->ldd);
|
|
||||||
} else {
|
|
||||||
mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
|
||||||
// GEMM kernel initializations
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
#define instantiate_gemm( \
|
#define instantiate_gemm( \
|
||||||
outmaskname, \
|
outmaskname, \
|
||||||
@ -483,7 +59,6 @@ block_masked_gemm(
|
|||||||
uint3 tid [[threadgroup_position_in_grid]], \
|
uint3 tid [[threadgroup_position_in_grid]], \
|
||||||
uint3 lid [[thread_position_in_threadgroup]]);
|
uint3 lid [[thread_position_in_threadgroup]]);
|
||||||
|
|
||||||
// clang-format off
|
|
||||||
#define instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
|
#define instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
|
||||||
instantiate_gemm(bool_, bool, bool_, bool, tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
|
instantiate_gemm(bool_, bool, bool_, bool, tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
|
||||||
instantiate_gemm(iname, itype, iname, itype, tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
|
instantiate_gemm(iname, itype, iname, itype, tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
|
||||||
@ -492,28 +67,24 @@ block_masked_gemm(
|
|||||||
instantiate_gemm(nomask, nomask_t, bool_, bool, tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
|
instantiate_gemm(nomask, nomask_t, bool_, bool, tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
|
||||||
instantiate_gemm(nomask, nomask_t, iname, itype, tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
|
instantiate_gemm(nomask, nomask_t, iname, itype, tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
|
||||||
instantiate_gemm(bool_, bool, nomask, nomask_t, tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
|
instantiate_gemm(bool_, bool, nomask, nomask_t, tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
|
||||||
instantiate_gemm(iname, itype, nomask, nomask_t, tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) // clang-format on
|
instantiate_gemm(iname, itype, nomask, nomask_t, tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned)
|
||||||
|
|
||||||
// clang-format off
|
|
||||||
#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||||
instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
|
instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
|
||||||
instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \
|
instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \
|
||||||
instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \
|
instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \
|
||||||
instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false) // clang-format on
|
instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false)
|
||||||
|
|
||||||
// clang-format off
|
|
||||||
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||||
instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||||
instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||||
instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||||
instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) // clang-format on
|
instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)
|
||||||
|
|
||||||
// clang-format off
|
|
||||||
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
|
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
|
||||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
|
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
|
||||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) // clang-format on
|
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2)
|
||||||
|
|
||||||
// clang-format off
|
|
||||||
instantiate_gemm_shapes_helper(float16, half, float16, half);
|
instantiate_gemm_shapes_helper(float16, half, float16, half);
|
||||||
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);
|
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);
|
||||||
instantiate_gemm_shapes_helper(float32, float, float32, float); // clang-format on
|
instantiate_gemm_shapes_helper(float32, float, float32, float); // clang-format on
|
||||||
|
227
mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h
Normal file
227
mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h
Normal file
@ -0,0 +1,227 @@
|
|||||||
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
using namespace mlx::steel;
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// GEMM kernels
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename T,
|
||||||
|
typename U,
|
||||||
|
int BM,
|
||||||
|
int BN,
|
||||||
|
int BK,
|
||||||
|
int WM,
|
||||||
|
int WN,
|
||||||
|
bool transpose_a,
|
||||||
|
bool transpose_b,
|
||||||
|
bool MN_aligned,
|
||||||
|
bool K_aligned>
|
||||||
|
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gemm_splitk(
|
||||||
|
const device T* A [[buffer(0)]],
|
||||||
|
const device T* B [[buffer(1)]],
|
||||||
|
device U* C [[buffer(2)]],
|
||||||
|
const constant GEMMSpiltKParams* params [[buffer(3)]],
|
||||||
|
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||||
|
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||||
|
uint3 tid [[threadgroup_position_in_grid]],
|
||||||
|
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||||
|
(void)lid;
|
||||||
|
|
||||||
|
using gemm_kernel = GEMMKernel<
|
||||||
|
T,
|
||||||
|
U,
|
||||||
|
BM,
|
||||||
|
BN,
|
||||||
|
BK,
|
||||||
|
WM,
|
||||||
|
WN,
|
||||||
|
transpose_a,
|
||||||
|
transpose_b,
|
||||||
|
MN_aligned,
|
||||||
|
K_aligned>;
|
||||||
|
using loader_a_t = typename gemm_kernel::loader_a_t;
|
||||||
|
using loader_b_t = typename gemm_kernel::loader_b_t;
|
||||||
|
using mma_t = typename gemm_kernel::mma_t;
|
||||||
|
|
||||||
|
threadgroup T As[gemm_kernel::tgp_mem_size_a];
|
||||||
|
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
|
||||||
|
|
||||||
|
const int tid_x = tid.x;
|
||||||
|
const int tid_y = tid.y;
|
||||||
|
const int tid_z = tid.z;
|
||||||
|
|
||||||
|
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find block in A, B, C
|
||||||
|
const int c_row = tid_y * BM;
|
||||||
|
const int c_col = tid_x * BN;
|
||||||
|
const int k_start = params->split_k_partition_size * tid_z;
|
||||||
|
|
||||||
|
const size_t c_row_long = size_t(c_row);
|
||||||
|
const size_t c_col_long = size_t(c_col);
|
||||||
|
const size_t k_start_long = size_t(k_start);
|
||||||
|
|
||||||
|
A += transpose_a ? (c_row_long + k_start_long * params->lda)
|
||||||
|
: (k_start_long + c_row_long * params->lda);
|
||||||
|
B += transpose_b ? (k_start_long + c_col_long * params->ldb)
|
||||||
|
: (c_col_long + k_start_long * params->ldb);
|
||||||
|
C += (size_t(params->split_k_partition_stride) * tid_z) +
|
||||||
|
(c_row_long * params->ldc + c_col_long);
|
||||||
|
|
||||||
|
// Prepare threadgroup loading operations
|
||||||
|
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
|
||||||
|
thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
|
||||||
|
|
||||||
|
// Prepare threadgroup mma operation
|
||||||
|
thread mma_t mma_op(simd_group_id, simd_lane_id);
|
||||||
|
|
||||||
|
int gemm_k_iterations = params->gemm_k_iterations_aligned;
|
||||||
|
|
||||||
|
short tgp_bm = min(BM, params->M - c_row);
|
||||||
|
short tgp_bn = min(BN, params->N - c_col);
|
||||||
|
short leftover_bk = params->K % BK;
|
||||||
|
|
||||||
|
if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
|
||||||
|
gemm_kernel::gemm_loop(
|
||||||
|
As,
|
||||||
|
Bs,
|
||||||
|
gemm_k_iterations,
|
||||||
|
loader_a,
|
||||||
|
loader_b,
|
||||||
|
mma_op,
|
||||||
|
tgp_bm,
|
||||||
|
tgp_bn,
|
||||||
|
leftover_bk,
|
||||||
|
LoopAlignment<true, true, true>{});
|
||||||
|
} else if (tgp_bn == BN) {
|
||||||
|
gemm_kernel::gemm_loop(
|
||||||
|
As,
|
||||||
|
Bs,
|
||||||
|
gemm_k_iterations,
|
||||||
|
loader_a,
|
||||||
|
loader_b,
|
||||||
|
mma_op,
|
||||||
|
tgp_bm,
|
||||||
|
tgp_bn,
|
||||||
|
leftover_bk,
|
||||||
|
LoopAlignment<false, true, true>{});
|
||||||
|
} else if (tgp_bm == BM) {
|
||||||
|
gemm_kernel::gemm_loop(
|
||||||
|
As,
|
||||||
|
Bs,
|
||||||
|
gemm_k_iterations,
|
||||||
|
loader_a,
|
||||||
|
loader_b,
|
||||||
|
mma_op,
|
||||||
|
tgp_bm,
|
||||||
|
tgp_bn,
|
||||||
|
leftover_bk,
|
||||||
|
LoopAlignment<true, false, true>{});
|
||||||
|
} else {
|
||||||
|
gemm_kernel::gemm_loop(
|
||||||
|
As,
|
||||||
|
Bs,
|
||||||
|
gemm_k_iterations,
|
||||||
|
loader_a,
|
||||||
|
loader_b,
|
||||||
|
mma_op,
|
||||||
|
tgp_bm,
|
||||||
|
tgp_bn,
|
||||||
|
leftover_bk,
|
||||||
|
LoopAlignment<false, false, true>{});
|
||||||
|
}
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
if ((tid_z + 1) == (params->split_k_partitions)) {
|
||||||
|
int gemm_k_iter_remaining =
|
||||||
|
(params->K - (k_start + params->split_k_partition_size)) / BK;
|
||||||
|
if (!K_aligned || gemm_k_iter_remaining > 0)
|
||||||
|
gemm_kernel::gemm_loop(
|
||||||
|
As,
|
||||||
|
Bs,
|
||||||
|
gemm_k_iter_remaining,
|
||||||
|
loader_a,
|
||||||
|
loader_b,
|
||||||
|
mma_op,
|
||||||
|
tgp_bm,
|
||||||
|
tgp_bn,
|
||||||
|
leftover_bk,
|
||||||
|
LoopAlignment<false, false, K_aligned>{});
|
||||||
|
}
|
||||||
|
|
||||||
|
if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
|
||||||
|
mma_op.store_result(C, params->ldc);
|
||||||
|
} else {
|
||||||
|
mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Split k accumulation kernel
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename AccT,
|
||||||
|
typename OutT,
|
||||||
|
typename Epilogue = TransformNone<OutT, AccT>>
|
||||||
|
[[kernel]] void gemm_splitk_accum(
|
||||||
|
const device AccT* C_split [[buffer(0)]],
|
||||||
|
device OutT* D [[buffer(1)]],
|
||||||
|
const constant int& k_partitions [[buffer(2)]],
|
||||||
|
const constant int& partition_stride [[buffer(3)]],
|
||||||
|
const constant int& ldd [[buffer(4)]],
|
||||||
|
uint2 gid [[thread_position_in_grid]]) {
|
||||||
|
// Ajust D and C
|
||||||
|
D += gid.x + gid.y * size_t(ldd);
|
||||||
|
C_split += gid.x + gid.y * size_t(ldd);
|
||||||
|
|
||||||
|
size_t offset = 0;
|
||||||
|
AccT out = 0;
|
||||||
|
|
||||||
|
for (int i = 0; i < k_partitions; i++) {
|
||||||
|
out += C_split[offset];
|
||||||
|
offset += partition_stride;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write output
|
||||||
|
D[0] = Epilogue::apply(out);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename AccT,
|
||||||
|
typename OutT,
|
||||||
|
typename Epilogue = TransformAxpby<OutT, AccT>>
|
||||||
|
[[kernel]] void gemm_splitk_accum_axpby(
|
||||||
|
const device AccT* C_split [[buffer(0)]],
|
||||||
|
device OutT* D [[buffer(1)]],
|
||||||
|
const constant int& k_partitions [[buffer(2)]],
|
||||||
|
const constant int& partition_stride [[buffer(3)]],
|
||||||
|
const constant int& ldd [[buffer(4)]],
|
||||||
|
const device OutT* C [[buffer(5)]],
|
||||||
|
const constant int& ldc [[buffer(6)]],
|
||||||
|
const constant int& fdc [[buffer(7)]],
|
||||||
|
const constant float& alpha [[buffer(8)]],
|
||||||
|
const constant float& beta [[buffer(9)]],
|
||||||
|
uint2 gid [[thread_position_in_grid]]) {
|
||||||
|
// Ajust D and C
|
||||||
|
C += gid.x * size_t(fdc) + gid.y * size_t(ldc);
|
||||||
|
D += gid.x + gid.y * size_t(ldd);
|
||||||
|
C_split += gid.x + gid.y * size_t(ldd);
|
||||||
|
|
||||||
|
size_t offset = 0;
|
||||||
|
AccT out = 0;
|
||||||
|
|
||||||
|
for (int i = 0; i < k_partitions; i++) {
|
||||||
|
out += C_split[offset];
|
||||||
|
offset += partition_stride;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write output
|
||||||
|
Epilogue op(alpha, beta);
|
||||||
|
D[0] = op.apply(out, *C);
|
||||||
|
}
|
@ -1,173 +1,9 @@
|
|||||||
// Copyright © 2024 Apple Inc.
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
#include "mlx/backend/metal/kernels/bf16.h"
|
#include "mlx/backend/metal/kernels/bf16.h"
|
||||||
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
|
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
|
||||||
|
#include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h"
|
||||||
using namespace metal;
|
|
||||||
using namespace mlx::steel;
|
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
|
||||||
// GEMM kernels
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
template <
|
|
||||||
typename T,
|
|
||||||
typename U,
|
|
||||||
int BM,
|
|
||||||
int BN,
|
|
||||||
int BK,
|
|
||||||
int WM,
|
|
||||||
int WN,
|
|
||||||
bool transpose_a,
|
|
||||||
bool transpose_b,
|
|
||||||
bool MN_aligned,
|
|
||||||
bool K_aligned>
|
|
||||||
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gemm_splitk(
|
|
||||||
const device T* A [[buffer(0)]],
|
|
||||||
const device T* B [[buffer(1)]],
|
|
||||||
device U* C [[buffer(2)]],
|
|
||||||
const constant GEMMSpiltKParams* params [[buffer(3)]],
|
|
||||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
|
||||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
|
||||||
uint3 tid [[threadgroup_position_in_grid]],
|
|
||||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
|
||||||
(void)lid;
|
|
||||||
|
|
||||||
using gemm_kernel = GEMMKernel<
|
|
||||||
T,
|
|
||||||
U,
|
|
||||||
BM,
|
|
||||||
BN,
|
|
||||||
BK,
|
|
||||||
WM,
|
|
||||||
WN,
|
|
||||||
transpose_a,
|
|
||||||
transpose_b,
|
|
||||||
MN_aligned,
|
|
||||||
K_aligned>;
|
|
||||||
using loader_a_t = typename gemm_kernel::loader_a_t;
|
|
||||||
using loader_b_t = typename gemm_kernel::loader_b_t;
|
|
||||||
using mma_t = typename gemm_kernel::mma_t;
|
|
||||||
|
|
||||||
threadgroup T As[gemm_kernel::tgp_mem_size_a];
|
|
||||||
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
|
|
||||||
|
|
||||||
const int tid_x = tid.x;
|
|
||||||
const int tid_y = tid.y;
|
|
||||||
const int tid_z = tid.z;
|
|
||||||
|
|
||||||
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Find block in A, B, C
|
|
||||||
const int c_row = tid_y * BM;
|
|
||||||
const int c_col = tid_x * BN;
|
|
||||||
const int k_start = params->split_k_partition_size * tid_z;
|
|
||||||
|
|
||||||
const size_t c_row_long = size_t(c_row);
|
|
||||||
const size_t c_col_long = size_t(c_col);
|
|
||||||
const size_t k_start_long = size_t(k_start);
|
|
||||||
|
|
||||||
A += transpose_a ? (c_row_long + k_start_long * params->lda)
|
|
||||||
: (k_start_long + c_row_long * params->lda);
|
|
||||||
B += transpose_b ? (k_start_long + c_col_long * params->ldb)
|
|
||||||
: (c_col_long + k_start_long * params->ldb);
|
|
||||||
C += (size_t(params->split_k_partition_stride) * tid_z) +
|
|
||||||
(c_row_long * params->ldc + c_col_long);
|
|
||||||
|
|
||||||
// Prepare threadgroup loading operations
|
|
||||||
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
|
|
||||||
thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
|
|
||||||
|
|
||||||
// Prepare threadgroup mma operation
|
|
||||||
thread mma_t mma_op(simd_group_id, simd_lane_id);
|
|
||||||
|
|
||||||
int gemm_k_iterations = params->gemm_k_iterations_aligned;
|
|
||||||
|
|
||||||
short tgp_bm = min(BM, params->M - c_row);
|
|
||||||
short tgp_bn = min(BN, params->N - c_col);
|
|
||||||
short leftover_bk = params->K % BK;
|
|
||||||
|
|
||||||
if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
|
|
||||||
gemm_kernel::gemm_loop(
|
|
||||||
As,
|
|
||||||
Bs,
|
|
||||||
gemm_k_iterations,
|
|
||||||
loader_a,
|
|
||||||
loader_b,
|
|
||||||
mma_op,
|
|
||||||
tgp_bm,
|
|
||||||
tgp_bn,
|
|
||||||
leftover_bk,
|
|
||||||
LoopAlignment<true, true, true>{});
|
|
||||||
} else if (tgp_bn == BN) {
|
|
||||||
gemm_kernel::gemm_loop(
|
|
||||||
As,
|
|
||||||
Bs,
|
|
||||||
gemm_k_iterations,
|
|
||||||
loader_a,
|
|
||||||
loader_b,
|
|
||||||
mma_op,
|
|
||||||
tgp_bm,
|
|
||||||
tgp_bn,
|
|
||||||
leftover_bk,
|
|
||||||
LoopAlignment<false, true, true>{});
|
|
||||||
} else if (tgp_bm == BM) {
|
|
||||||
gemm_kernel::gemm_loop(
|
|
||||||
As,
|
|
||||||
Bs,
|
|
||||||
gemm_k_iterations,
|
|
||||||
loader_a,
|
|
||||||
loader_b,
|
|
||||||
mma_op,
|
|
||||||
tgp_bm,
|
|
||||||
tgp_bn,
|
|
||||||
leftover_bk,
|
|
||||||
LoopAlignment<true, false, true>{});
|
|
||||||
} else {
|
|
||||||
gemm_kernel::gemm_loop(
|
|
||||||
As,
|
|
||||||
Bs,
|
|
||||||
gemm_k_iterations,
|
|
||||||
loader_a,
|
|
||||||
loader_b,
|
|
||||||
mma_op,
|
|
||||||
tgp_bm,
|
|
||||||
tgp_bn,
|
|
||||||
leftover_bk,
|
|
||||||
LoopAlignment<false, false, true>{});
|
|
||||||
}
|
|
||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
|
|
||||||
if ((tid_z + 1) == (params->split_k_partitions)) {
|
|
||||||
int gemm_k_iter_remaining =
|
|
||||||
(params->K - (k_start + params->split_k_partition_size)) / BK;
|
|
||||||
if (!K_aligned || gemm_k_iter_remaining > 0)
|
|
||||||
gemm_kernel::gemm_loop(
|
|
||||||
As,
|
|
||||||
Bs,
|
|
||||||
gemm_k_iter_remaining,
|
|
||||||
loader_a,
|
|
||||||
loader_b,
|
|
||||||
mma_op,
|
|
||||||
tgp_bm,
|
|
||||||
tgp_bn,
|
|
||||||
leftover_bk,
|
|
||||||
LoopAlignment<false, false, K_aligned>{});
|
|
||||||
}
|
|
||||||
|
|
||||||
if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
|
|
||||||
mma_op.store_result(C, params->ldc);
|
|
||||||
} else {
|
|
||||||
mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
|
||||||
// GEMM kernel initializations
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
#define instantiate_gemm( \
|
#define instantiate_gemm( \
|
||||||
tname, \
|
tname, \
|
||||||
@ -210,97 +46,27 @@ template <
|
|||||||
uint3 tid [[threadgroup_position_in_grid]], \
|
uint3 tid [[threadgroup_position_in_grid]], \
|
||||||
uint3 lid [[thread_position_in_threadgroup]]);
|
uint3 lid [[thread_position_in_threadgroup]]);
|
||||||
|
|
||||||
// clang-format off
|
|
||||||
#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
|
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
|
||||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \
|
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \
|
||||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \
|
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \
|
||||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false) // clang-format on
|
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false)
|
||||||
|
|
||||||
// clang-format off
|
|
||||||
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||||
instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||||
instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||||
instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||||
instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) // clang-format on
|
instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)
|
||||||
|
|
||||||
// clang-format off
|
|
||||||
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
|
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
|
||||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 16, 16, 16, 2, 2) \
|
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 16, 16, 16, 2, 2) \
|
||||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 16, 32, 16, 2, 2) \
|
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 16, 32, 16, 2, 2) \
|
||||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 16, 16, 2, 2) \
|
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 16, 16, 2, 2) \
|
||||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) // clang-format on
|
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2)
|
||||||
|
|
||||||
// clang-format off
|
|
||||||
instantiate_gemm_shapes_helper(float16, half, float32, float);
|
instantiate_gemm_shapes_helper(float16, half, float32, float);
|
||||||
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, float32, float);
|
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, float32, float);
|
||||||
|
instantiate_gemm_shapes_helper(float32, float, float32, float);
|
||||||
instantiate_gemm_shapes_helper(float32, float, float32, float); // clang-format on
|
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
|
||||||
// Split k accumulation kernel
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
template <
|
|
||||||
typename AccT,
|
|
||||||
typename OutT,
|
|
||||||
typename Epilogue = TransformNone<OutT, AccT>>
|
|
||||||
[[kernel]] void gemm_splitk_accum(
|
|
||||||
const device AccT* C_split [[buffer(0)]],
|
|
||||||
device OutT* D [[buffer(1)]],
|
|
||||||
const constant int& k_partitions [[buffer(2)]],
|
|
||||||
const constant int& partition_stride [[buffer(3)]],
|
|
||||||
const constant int& ldd [[buffer(4)]],
|
|
||||||
uint2 gid [[thread_position_in_grid]]) {
|
|
||||||
// Ajust D and C
|
|
||||||
D += gid.x + gid.y * size_t(ldd);
|
|
||||||
C_split += gid.x + gid.y * size_t(ldd);
|
|
||||||
|
|
||||||
size_t offset = 0;
|
|
||||||
AccT out = 0;
|
|
||||||
|
|
||||||
for (int i = 0; i < k_partitions; i++) {
|
|
||||||
out += C_split[offset];
|
|
||||||
offset += partition_stride;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Write output
|
|
||||||
D[0] = Epilogue::apply(out);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <
|
|
||||||
typename AccT,
|
|
||||||
typename OutT,
|
|
||||||
typename Epilogue = TransformAxpby<OutT, AccT>>
|
|
||||||
[[kernel]] void gemm_splitk_accum_axpby(
|
|
||||||
const device AccT* C_split [[buffer(0)]],
|
|
||||||
device OutT* D [[buffer(1)]],
|
|
||||||
const constant int& k_partitions [[buffer(2)]],
|
|
||||||
const constant int& partition_stride [[buffer(3)]],
|
|
||||||
const constant int& ldd [[buffer(4)]],
|
|
||||||
const device OutT* C [[buffer(5)]],
|
|
||||||
const constant int& ldc [[buffer(6)]],
|
|
||||||
const constant int& fdc [[buffer(7)]],
|
|
||||||
const constant float& alpha [[buffer(8)]],
|
|
||||||
const constant float& beta [[buffer(9)]],
|
|
||||||
uint2 gid [[thread_position_in_grid]]) {
|
|
||||||
// Ajust D and C
|
|
||||||
C += gid.x * size_t(fdc) + gid.y * size_t(ldc);
|
|
||||||
D += gid.x + gid.y * size_t(ldd);
|
|
||||||
C_split += gid.x + gid.y * size_t(ldd);
|
|
||||||
|
|
||||||
size_t offset = 0;
|
|
||||||
AccT out = 0;
|
|
||||||
|
|
||||||
for (int i = 0; i < k_partitions; i++) {
|
|
||||||
out += C_split[offset];
|
|
||||||
offset += partition_stride;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Write output
|
|
||||||
Epilogue op(alpha, beta);
|
|
||||||
D[0] = op.apply(out, *C);
|
|
||||||
}
|
|
||||||
|
|
||||||
#define instantiate_accum(oname, otype, aname, atype) \
|
#define instantiate_accum(oname, otype, aname, atype) \
|
||||||
template [[host_name("steel_gemm_splitk_accum_" #oname \
|
template [[host_name("steel_gemm_splitk_accum_" #oname \
|
||||||
@ -313,7 +79,7 @@ template <
|
|||||||
const constant int& ldd [[buffer(4)]], \
|
const constant int& ldd [[buffer(4)]], \
|
||||||
uint2 gid [[thread_position_in_grid]]); \
|
uint2 gid [[thread_position_in_grid]]); \
|
||||||
template [[host_name("steel_gemm_splitk_accum_" #oname "_" #aname \
|
template [[host_name("steel_gemm_splitk_accum_" #oname "_" #aname \
|
||||||
"_axpby")]] [[kernel]] void \
|
"_axbpy")]] [[kernel]] void \
|
||||||
gemm_splitk_accum_axpby<atype, otype>( \
|
gemm_splitk_accum_axpby<atype, otype>( \
|
||||||
const device atype* C_split [[buffer(0)]], \
|
const device atype* C_split [[buffer(0)]], \
|
||||||
device otype* D [[buffer(1)]], \
|
device otype* D [[buffer(1)]], \
|
||||||
@ -327,7 +93,6 @@ template <
|
|||||||
const constant float& beta [[buffer(9)]], \
|
const constant float& beta [[buffer(9)]], \
|
||||||
uint2 gid [[thread_position_in_grid]]);
|
uint2 gid [[thread_position_in_grid]]);
|
||||||
|
|
||||||
// clang-format off
|
|
||||||
instantiate_accum(bfloat16, bfloat16_t, float32, float);
|
instantiate_accum(bfloat16, bfloat16_t, float32, float);
|
||||||
instantiate_accum(float16, half, float32, float);
|
instantiate_accum(float16, half, float32, float);
|
||||||
instantiate_accum(float32, float, float32, float); // clang-format on
|
instantiate_accum(float32, float, float32, float); // clang-format on
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "mlx/backend/metal/kernels/steel/utils.h"
|
#include "mlx/backend/metal/kernels/steel/defines.h"
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
// Loading helper
|
// Loading helper
|
||||||
@ -134,4 +134,4 @@ struct BlockLoader {
|
|||||||
};
|
};
|
||||||
|
|
||||||
} // namespace steel
|
} // namespace steel
|
||||||
} // namespace mlx
|
} // namespace mlx
|
||||||
|
@ -6,8 +6,8 @@
|
|||||||
#include <metal_simdgroup_matrix>
|
#include <metal_simdgroup_matrix>
|
||||||
#include <metal_stdlib>
|
#include <metal_stdlib>
|
||||||
|
|
||||||
|
#include "mlx/backend/metal/kernels/steel/defines.h"
|
||||||
#include "mlx/backend/metal/kernels/steel/gemm/transforms.h"
|
#include "mlx/backend/metal/kernels/steel/gemm/transforms.h"
|
||||||
#include "mlx/backend/metal/kernels/steel/utils.h"
|
|
||||||
|
|
||||||
using namespace metal;
|
using namespace metal;
|
||||||
|
|
||||||
@ -358,4 +358,4 @@ struct BlockMMA {
|
|||||||
};
|
};
|
||||||
|
|
||||||
} // namespace steel
|
} // namespace steel
|
||||||
} // namespace mlx
|
} // namespace mlx
|
||||||
|
@ -4,9 +4,6 @@
|
|||||||
|
|
||||||
#include <metal_stdlib>
|
#include <metal_stdlib>
|
||||||
|
|
||||||
#define STEEL_CONST static constant constexpr const
|
|
||||||
#define STEEL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)")
|
|
||||||
|
|
||||||
METAL_FUNC ulong2 elem_to_loc_broadcast(
|
METAL_FUNC ulong2 elem_to_loc_broadcast(
|
||||||
uint elem,
|
uint elem,
|
||||||
constant const int* shape,
|
constant const int* shape,
|
||||||
@ -42,4 +39,4 @@ METAL_FUNC ulong3 elem_to_loc_broadcast(
|
|||||||
loc_c += pos_in_dim * c_strides[i];
|
loc_c += pos_in_dim * c_strides[i];
|
||||||
}
|
}
|
||||||
return ulong3(loc_a, loc_b, loc_c);
|
return ulong3(loc_a, loc_b, loc_c);
|
||||||
}
|
}
|
||||||
|
@ -8,9 +8,10 @@
|
|||||||
OUTPUT_DIR=$1
|
OUTPUT_DIR=$1
|
||||||
CC=$2
|
CC=$2
|
||||||
SRC_DIR=$3
|
SRC_DIR=$3
|
||||||
SRC_NAME=$4
|
SRC_FILE=$4
|
||||||
CFLAGS=$5
|
CFLAGS=$5
|
||||||
INPUT_FILE=${SRC_DIR}/mlx/backend/metal/kernels/${SRC_NAME}.h
|
SRC_NAME=$(basename -- "${SRC_FILE}")
|
||||||
|
INPUT_FILE=${SRC_DIR}/mlx/backend/metal/kernels/${SRC_FILE}.h
|
||||||
OUTPUT_FILE=${OUTPUT_DIR}/${SRC_NAME}.cpp
|
OUTPUT_FILE=${OUTPUT_DIR}/${SRC_NAME}.cpp
|
||||||
|
|
||||||
mkdir -p $OUTPUT_DIR
|
mkdir -p $OUTPUT_DIR
|
||||||
|
@ -7,6 +7,7 @@
|
|||||||
|
|
||||||
#include "mlx/backend/metal/copy.h"
|
#include "mlx/backend/metal/copy.h"
|
||||||
#include "mlx/backend/metal/device.h"
|
#include "mlx/backend/metal/device.h"
|
||||||
|
#include "mlx/backend/metal/kernels.h"
|
||||||
#include "mlx/backend/metal/kernels/defines.h"
|
#include "mlx/backend/metal/kernels/defines.h"
|
||||||
#include "mlx/backend/metal/kernels/steel/gemm/params.h"
|
#include "mlx/backend/metal/kernels/steel/gemm/params.h"
|
||||||
#include "mlx/backend/metal/matmul.h"
|
#include "mlx/backend/metal/matmul.h"
|
||||||
@ -336,7 +337,19 @@ void steel_matmul_conv_groups(
|
|||||||
|
|
||||||
// Encode and dispatch kernel
|
// Encode and dispatch kernel
|
||||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||||
auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts);
|
auto kernel = get_steel_gemm_fused_kernel(
|
||||||
|
d,
|
||||||
|
base_name,
|
||||||
|
hash_name,
|
||||||
|
func_consts,
|
||||||
|
out,
|
||||||
|
transpose_a,
|
||||||
|
transpose_b,
|
||||||
|
bm,
|
||||||
|
bn,
|
||||||
|
bk,
|
||||||
|
wm,
|
||||||
|
wn);
|
||||||
|
|
||||||
compute_encoder->setComputePipelineState(kernel);
|
compute_encoder->setComputePipelineState(kernel);
|
||||||
|
|
||||||
@ -458,17 +471,31 @@ void steel_matmul(
|
|||||||
C_split.set_data(allocator::malloc_or_wait(C_split.nbytes()));
|
C_split.set_data(allocator::malloc_or_wait(C_split.nbytes()));
|
||||||
copies.push_back(C_split);
|
copies.push_back(C_split);
|
||||||
|
|
||||||
|
bool mn_aligned = M % bm == 0 && N % bn == 0;
|
||||||
|
bool k_aligned = K % bk == 0;
|
||||||
std::ostringstream kname;
|
std::ostringstream kname;
|
||||||
kname << "steel_gemm_splitk_" << (transpose_a ? 't' : 'n')
|
kname << "steel_gemm_splitk_" << (transpose_a ? 't' : 'n')
|
||||||
<< (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
|
<< (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
|
||||||
<< type_to_name(C_split) << "_bm" << bm << "_bn" << bn << "_bk" << bk
|
<< type_to_name(C_split) << "_bm" << bm << "_bn" << bn << "_bk" << bk
|
||||||
<< "_wm" << wm << "_wn" << wn << "_MN_"
|
<< "_wm" << wm << "_wn" << wn << "_MN_" << (mn_aligned ? "t" : "n")
|
||||||
<< ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned" << "_K_"
|
<< "aligned" << "_K_" << (k_aligned ? "t" : "n") << "aligned";
|
||||||
<< ((K % bk == 0) ? "t" : "n") << "aligned";
|
|
||||||
|
|
||||||
// Encode and dispatch gemm kernel
|
// Encode and dispatch gemm kernel
|
||||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||||
auto kernel = d.get_kernel(kname.str());
|
auto kernel = get_steel_gemm_splitk_kernel(
|
||||||
|
d,
|
||||||
|
kname.str(),
|
||||||
|
a,
|
||||||
|
C_split,
|
||||||
|
transpose_a,
|
||||||
|
transpose_b,
|
||||||
|
bm,
|
||||||
|
bn,
|
||||||
|
bk,
|
||||||
|
wm,
|
||||||
|
wn,
|
||||||
|
mn_aligned,
|
||||||
|
k_aligned);
|
||||||
compute_encoder->setComputePipelineState(kernel);
|
compute_encoder->setComputePipelineState(kernel);
|
||||||
|
|
||||||
int tn = (N + bn - 1) / bn;
|
int tn = (N + bn - 1) / bn;
|
||||||
@ -504,10 +531,11 @@ void steel_matmul(
|
|||||||
static_cast<const MTL::Resource*>(C_split.buffer().ptr());
|
static_cast<const MTL::Resource*>(C_split.buffer().ptr());
|
||||||
const class MTL::Resource* const resources[1] = {c_split_buf};
|
const class MTL::Resource* const resources[1] = {c_split_buf};
|
||||||
compute_encoder->memoryBarrier(resources, 1);
|
compute_encoder->memoryBarrier(resources, 1);
|
||||||
|
auto kernel_name = "steel_gemm_splitk_accum_" + type_to_name(out) + "_" +
|
||||||
|
type_to_name(C_split);
|
||||||
|
|
||||||
auto kernel = d.get_kernel(
|
auto kernel = get_steel_gemm_splitk_accum_kernel(
|
||||||
"steel_gemm_splitk_accum_" + type_to_name(out) + "_" +
|
d, kernel_name, C_split, out, false);
|
||||||
type_to_name(C_split));
|
|
||||||
compute_encoder->setComputePipelineState(kernel);
|
compute_encoder->setComputePipelineState(kernel);
|
||||||
|
|
||||||
// Set the arguments for the kernel
|
// Set the arguments for the kernel
|
||||||
@ -587,7 +615,19 @@ void steel_matmul(
|
|||||||
|
|
||||||
// Encode and dispatch kernel
|
// Encode and dispatch kernel
|
||||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||||
auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts);
|
auto kernel = get_steel_gemm_fused_kernel(
|
||||||
|
d,
|
||||||
|
base_name,
|
||||||
|
hash_name,
|
||||||
|
func_consts,
|
||||||
|
out,
|
||||||
|
transpose_a,
|
||||||
|
transpose_b,
|
||||||
|
bm,
|
||||||
|
bn,
|
||||||
|
bk,
|
||||||
|
wm,
|
||||||
|
wn);
|
||||||
|
|
||||||
compute_encoder->setComputePipelineState(kernel);
|
compute_encoder->setComputePipelineState(kernel);
|
||||||
|
|
||||||
@ -1053,17 +1093,33 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
C_split.set_data(allocator::malloc_or_wait(C_split.nbytes()));
|
C_split.set_data(allocator::malloc_or_wait(C_split.nbytes()));
|
||||||
copies.push_back(C_split);
|
copies.push_back(C_split);
|
||||||
|
|
||||||
|
bool mn_aligned = M % bm == 0 && N % bn == 0;
|
||||||
|
bool k_aligned = K % bk == 0;
|
||||||
|
|
||||||
std::ostringstream kname;
|
std::ostringstream kname;
|
||||||
kname << "steel_gemm_splitk_" << (transpose_a ? 't' : 'n')
|
kname << "steel_gemm_splitk_" << (transpose_a ? 't' : 'n')
|
||||||
<< (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
|
<< (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
|
||||||
<< type_to_name(C_split) << "_bm" << bm << "_bn" << bn << "_bk" << bk
|
<< type_to_name(C_split) << "_bm" << bm << "_bn" << bn << "_bk" << bk
|
||||||
<< "_wm" << wm << "_wn" << wn << "_MN_"
|
<< "_wm" << wm << "_wn" << wn << "_MN_" << (mn_aligned ? "t" : "n")
|
||||||
<< ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned" << "_K_"
|
<< "aligned" << "_K_" << (k_aligned ? "t" : "n") << "aligned";
|
||||||
<< ((K % bk == 0) ? "t" : "n") << "aligned";
|
|
||||||
|
|
||||||
// Encode and dispatch gemm kernel
|
// Encode and dispatch gemm kernel
|
||||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||||
auto kernel = d.get_kernel(kname.str());
|
auto kernel = get_steel_gemm_splitk_kernel(
|
||||||
|
d,
|
||||||
|
kname.str(),
|
||||||
|
a,
|
||||||
|
C_split,
|
||||||
|
transpose_a,
|
||||||
|
transpose_b,
|
||||||
|
bm,
|
||||||
|
bn,
|
||||||
|
bk,
|
||||||
|
wm,
|
||||||
|
wn,
|
||||||
|
mn_aligned,
|
||||||
|
k_aligned);
|
||||||
|
|
||||||
compute_encoder->setComputePipelineState(kernel);
|
compute_encoder->setComputePipelineState(kernel);
|
||||||
|
|
||||||
int tn = (N + bn - 1) / bn;
|
int tn = (N + bn - 1) / bn;
|
||||||
@ -1095,9 +1151,11 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
|
|
||||||
// Do accum kernel
|
// Do accum kernel
|
||||||
{
|
{
|
||||||
auto kernel = d.get_kernel(
|
auto kernel_name = "steel_gemm_splitk_accum_" + type_to_name(out) + "_" +
|
||||||
"steel_gemm_splitk_accum_" + type_to_name(out) + "_" +
|
type_to_name(C_split) + "_axbpy";
|
||||||
type_to_name(C_split) + "_axpby");
|
auto kernel = get_steel_gemm_splitk_accum_kernel(
|
||||||
|
d, kernel_name, C_split, out, true);
|
||||||
|
|
||||||
compute_encoder->setComputePipelineState(kernel);
|
compute_encoder->setComputePipelineState(kernel);
|
||||||
|
|
||||||
// Set the arguments for the kernel
|
// Set the arguments for the kernel
|
||||||
@ -1182,7 +1240,19 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
|
|
||||||
// Encode and dispatch kernel
|
// Encode and dispatch kernel
|
||||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||||
auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts);
|
auto kernel = get_steel_gemm_fused_kernel(
|
||||||
|
d,
|
||||||
|
base_name,
|
||||||
|
hash_name,
|
||||||
|
func_consts,
|
||||||
|
out,
|
||||||
|
transpose_a,
|
||||||
|
transpose_b,
|
||||||
|
bm,
|
||||||
|
bn,
|
||||||
|
bk,
|
||||||
|
wm,
|
||||||
|
wn);
|
||||||
|
|
||||||
compute_encoder->setComputePipelineState(kernel);
|
compute_encoder->setComputePipelineState(kernel);
|
||||||
|
|
||||||
@ -1348,6 +1418,8 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
// Determine dispatch kernel
|
// Determine dispatch kernel
|
||||||
int bm = block_size_, bn = block_size_, bk = 16;
|
int bm = block_size_, bn = block_size_, bk = 16;
|
||||||
int wm = 2, wn = 2;
|
int wm = 2, wn = 2;
|
||||||
|
bool mn_aligned = M % bm == 0 && N % bn == 0;
|
||||||
|
bool k_aligned = K % bk == 0;
|
||||||
|
|
||||||
// Prepare kernel name
|
// Prepare kernel name
|
||||||
std::string out_mask_nm = has_out_mask ? type_to_name(inputs[2]) : "nomask";
|
std::string out_mask_nm = has_out_mask ? type_to_name(inputs[2]) : "nomask";
|
||||||
@ -1358,13 +1430,26 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
<< op_mask_nm << "_" << (transpose_a ? 't' : 'n')
|
<< op_mask_nm << "_" << (transpose_a ? 't' : 'n')
|
||||||
<< (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
|
<< (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
|
||||||
<< type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk
|
<< type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk
|
||||||
<< "_wm" << wm << "_wn" << wn << "_MN_"
|
<< "_wm" << wm << "_wn" << wn << "_MN_" << (mn_aligned ? "t" : "n")
|
||||||
<< ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned" << "_K_"
|
<< "aligned" << "_K_" << (k_aligned ? "t" : "n") << "aligned";
|
||||||
<< ((K % bk == 0) ? "t" : "n") << "aligned";
|
|
||||||
|
|
||||||
// Encode and dispatch kernel
|
// Encode and dispatch kernel
|
||||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||||
auto kernel = d.get_kernel(kname.str());
|
auto kernel = get_steel_gemm_masked_kernel(
|
||||||
|
d,
|
||||||
|
kname.str(),
|
||||||
|
out,
|
||||||
|
has_out_mask ? std::optional<array>{inputs[2]} : std::nullopt,
|
||||||
|
has_op_mask ? std::optional<array>{inputs.back()} : std::nullopt,
|
||||||
|
transpose_a,
|
||||||
|
transpose_b,
|
||||||
|
bm,
|
||||||
|
bn,
|
||||||
|
bk,
|
||||||
|
wm,
|
||||||
|
wn,
|
||||||
|
mn_aligned,
|
||||||
|
k_aligned);
|
||||||
compute_encoder->setComputePipelineState(kernel);
|
compute_encoder->setComputePipelineState(kernel);
|
||||||
|
|
||||||
// Use problem size to determine threadblock swizzle
|
// Use problem size to determine threadblock swizzle
|
||||||
@ -1720,7 +1805,19 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
|
|
||||||
// Encode and dispatch kernel
|
// Encode and dispatch kernel
|
||||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||||
auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts);
|
auto kernel = get_steel_gemm_fused_kernel(
|
||||||
|
d,
|
||||||
|
base_name,
|
||||||
|
hash_name,
|
||||||
|
func_consts,
|
||||||
|
out,
|
||||||
|
transpose_a,
|
||||||
|
transpose_b,
|
||||||
|
bm,
|
||||||
|
bn,
|
||||||
|
bk,
|
||||||
|
wm,
|
||||||
|
wn);
|
||||||
|
|
||||||
compute_encoder->setComputePipelineState(kernel);
|
compute_encoder->setComputePipelineState(kernel);
|
||||||
|
|
||||||
|
@ -103,4 +103,90 @@ MTL::ComputePipelineState* get_reduce_kernel(
|
|||||||
return d.get_kernel(kernel_name);
|
return d.get_kernel(kernel_name);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
MTL::ComputePipelineState* get_steel_gemm_fused_kernel(
|
||||||
|
metal::Device& d,
|
||||||
|
const std::string& kernel_name,
|
||||||
|
const std::string& hash_name,
|
||||||
|
const metal::MTLFCList& func_consts,
|
||||||
|
const array&,
|
||||||
|
bool,
|
||||||
|
bool,
|
||||||
|
int,
|
||||||
|
int,
|
||||||
|
int,
|
||||||
|
int,
|
||||||
|
int) {
|
||||||
|
return d.get_kernel(kernel_name, "mlx", hash_name, func_consts);
|
||||||
|
}
|
||||||
|
|
||||||
|
MTL::ComputePipelineState* get_steel_gemm_splitk_kernel(
|
||||||
|
metal::Device& d,
|
||||||
|
const std::string& kernel_name,
|
||||||
|
const array&,
|
||||||
|
const array&,
|
||||||
|
bool,
|
||||||
|
bool,
|
||||||
|
int,
|
||||||
|
int,
|
||||||
|
int,
|
||||||
|
int,
|
||||||
|
int,
|
||||||
|
bool,
|
||||||
|
bool) {
|
||||||
|
return d.get_kernel(kernel_name);
|
||||||
|
}
|
||||||
|
|
||||||
|
MTL::ComputePipelineState* get_steel_gemm_splitk_accum_kernel(
|
||||||
|
metal::Device& d,
|
||||||
|
const std::string& kernel_name,
|
||||||
|
const array&,
|
||||||
|
const array&,
|
||||||
|
bool) {
|
||||||
|
return d.get_kernel(kernel_name);
|
||||||
|
}
|
||||||
|
|
||||||
|
MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
|
||||||
|
metal::Device& d,
|
||||||
|
const std::string& kernel_name,
|
||||||
|
const array&,
|
||||||
|
const std::optional<array>& mask_out,
|
||||||
|
const std::optional<array>& mask_op,
|
||||||
|
bool,
|
||||||
|
bool,
|
||||||
|
int,
|
||||||
|
int,
|
||||||
|
int,
|
||||||
|
int,
|
||||||
|
int,
|
||||||
|
bool,
|
||||||
|
bool) {
|
||||||
|
return d.get_kernel(kernel_name);
|
||||||
|
}
|
||||||
|
|
||||||
|
MTL::ComputePipelineState* get_steel_conv_kernel(
|
||||||
|
metal::Device& d,
|
||||||
|
const std::string& kernel_name,
|
||||||
|
const array&,
|
||||||
|
int,
|
||||||
|
int,
|
||||||
|
int,
|
||||||
|
int,
|
||||||
|
int,
|
||||||
|
int,
|
||||||
|
bool) {
|
||||||
|
return d.get_kernel(kernel_name);
|
||||||
|
}
|
||||||
|
|
||||||
|
MTL::ComputePipelineState* get_steel_conv_general_kernel(
|
||||||
|
metal::Device& d,
|
||||||
|
const std::string& kernel_name,
|
||||||
|
const array&,
|
||||||
|
int,
|
||||||
|
int,
|
||||||
|
int,
|
||||||
|
int,
|
||||||
|
int) {
|
||||||
|
return d.get_kernel(kernel_name);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -3968,9 +3968,9 @@ void init_ops(nb::module_& m) {
|
|||||||
a (array): Input array or scalar.
|
a (array): Input array or scalar.
|
||||||
b (array): Input array or scalar.
|
b (array): Input array or scalar.
|
||||||
block_size (int): Size of blocks to be masked. Must be ``32`` or ``64`` (default: ``64``)
|
block_size (int): Size of blocks to be masked. Must be ``32`` or ``64`` (default: ``64``)
|
||||||
mask_out (array, optional): Boolean mask for output (default: ``None``)
|
mask_out (array, optional): Mask for output (default: ``None``)
|
||||||
mask_lhs (array, optional): Boolean mask for a (default: ``None``)
|
mask_lhs (array, optional): Mask for a (default: ``None``)
|
||||||
mask_rhs (array, optional): Boolean mask for b (default: ``None``)
|
mask_rhs (array, optional): Mask for b (default: ``None``)
|
||||||
|
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
|
@ -556,8 +556,7 @@ class TestBlas(mlx_tests.MLXTestCase):
|
|||||||
|
|
||||||
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
|
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
|
||||||
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
|
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
|
||||||
|
# Batched matmul with simple broadcast
|
||||||
# # Batched matmul with simple broadcast
|
|
||||||
a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
|
a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
|
||||||
b_npy = np.random.normal(0.0, 1.0 / 128, (16, 16)).astype(np.float32)
|
b_npy = np.random.normal(0.0, 1.0 / 128, (16, 16)).astype(np.float32)
|
||||||
|
|
||||||
@ -573,7 +572,6 @@ class TestBlas(mlx_tests.MLXTestCase):
|
|||||||
|
|
||||||
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
|
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
|
||||||
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
|
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
|
||||||
|
|
||||||
# Matmul with vector
|
# Matmul with vector
|
||||||
a_npy = np.random.normal(0.0, 1.0 / 128, (16,)).astype(np.float32)
|
a_npy = np.random.normal(0.0, 1.0 / 128, (16,)).astype(np.float32)
|
||||||
b_npy = np.random.normal(0.0, 1.0 / 128, (32, 16, 128)).astype(np.float32)
|
b_npy = np.random.normal(0.0, 1.0 / 128, (32, 16, 128)).astype(np.float32)
|
||||||
|
Loading…
Reference in New Issue
Block a user