Update GEMM (#424)

* Organize and collect metal subroutine templates and elements in `metal/kernels/steel/`
* Update gemm elements for better performance
* Add split-K specialization for gemm
* Add `addmm` primitive, op and bindings for fused matmul and bias addition
* Update tests and benchmarks as needed

Parent: 556cdf0e06
Commit: 78102a47ad
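The new fused op can be exercised directly from Python. A minimal sketch (shapes are illustrative, and the `alpha`/`beta` defaults noted in the comment are an assumption based on the `alpha_`/`beta_` parameters of the new primitive; the call pattern itself matches the benchmark diff below):

import mlx.core as mx

x = mx.random.normal((16, 768))
w = mx.random.normal((3072, 768))
b = mx.random.normal((3072,))

# Fused matmul + bias: one kernel instead of a matmul followed by an add.
fused = mx.addmm(b, x, mx.transpose(w, (1, 0)))

# Reference computation; addmm is assumed to default to alpha=1, beta=1.
reference = x @ mx.transpose(w, (1, 0)) + b
assert mx.allclose(fused, reference)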
@@ -166,13 +166,13 @@ if __name__ == "__main__":
     dtypes = ("float32", "float16")
     transposes = ("nn", "nt", "tn")
     shapes = (
+        (16, 234, 768, 3072),
+        (1, 64, 64, 25344),
         (16, 1024, 1024, 1024),
         (1, 1024, 1024, 2048),
         (4, 1024, 1024, 4096),
         (4, 1024, 4096, 1024),
         (1, 4096, 4096, 4096),
-        (15, 1023, 1023, 1023),
-        (17, 1025, 1025, 1025),
     )

     for dtype in dtypes:
@@ -257,6 +257,13 @@ def linear(w, b, x):
     mx.eval(ys)


+def linear_fused(w, b, x):
+    ys = []
+    for i in range(10):
+        ys.append(mx.addmm(b, x, mx.transpose(w, (1, 0))))
+    mx.eval(ys)
+
+
 def rope(x):
     *_, N, D = x.shape
     ys = []
@@ -397,7 +404,10 @@ if __name__ == "__main__":
         print(bench(quant_matmul[args.benchmark], *xs))

     elif args.benchmark == "linear":
-        print(bench(linear, *xs))
+        if args.fused:
+            print(bench(linear_fused, *xs))
+        else:
+            print(bench(linear, *xs))

     elif args.benchmark == "sum_axis":
         print(bench(reduction, "sum", axis, x))
@@ -29,12 +29,16 @@ std::tuple<bool, size_t, array> check_transpose(const array& arr) {
   }
 }

-inline void matmul_cblas(const array& a_pre, const array& b_pre, array& out) {
+inline void matmul_cblas_general(
+    const array& a_pre,
+    const array& b_pre,
+    array& out,
+    float alpha = 1.0f,
+    float beta = 0.0f) {
   if (out.dtype() != float32) {
     throw std::runtime_error(
         "[matmul_cblas] on CPU currently only supports float32");
   }
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));

   auto [a_transposed, lda, a] = check_transpose(a_pre);
   auto [b_transposed, ldb, b] = check_transpose(b_pre);
@@ -50,21 +54,34 @@ inline void matmul_cblas(const array& a_pre, const array& b_pre, array& out) {
         M,
         N,
         K,
-        1.0f, // alpha
+        alpha, // alpha
         a.data<float>() + elem_to_loc(M * K * i, a.shape(), a.strides()),
         lda,
         b.data<float>() + elem_to_loc(K * N * i, b.shape(), b.strides()),
         ldb,
-        0.0f, // beta
+        beta, // beta
         out.data<float>() + M * N * i,
         out.shape(-1) // ldc
     );
   }
 }

-inline void matmul_bnns(const array& a_pre, const array& b_pre, array& out) {
-  // TODO: Update to utilize BNNS broadcasting
+inline void matmul_cblas(const array& a_pre, const array& b_pre, array& out) {
+  if (out.dtype() != float32) {
+    throw std::runtime_error(
+        "[matmul_cblas] on CPU currently only supports float32");
+  }
   out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  return matmul_cblas_general(a_pre, b_pre, out);
+}
+
+inline void matmul_bnns_general(
+    const array& a_pre,
+    const array& b_pre,
+    array& out,
+    float alpha = 1.0f,
+    float beta = 0.0f) {
+  // TODO: Update to utilize BNNS broadcasting

   auto [a_transposed, lda, a] = check_transpose(a_pre);
   auto [b_transposed, ldb, b] = check_transpose(b_pre);
@@ -75,8 +92,8 @@ inline void matmul_bnns(const array& a_pre, const array& b_pre, array& out) {
   BNNSDataType bnns_dtype = to_bnns_dtype(out.dtype());

   const BNNSLayerParametersBroadcastMatMul gemm_params{
-      /* float alpha = */ 1.0,
-      /* float beta = */ 0.0,
+      /* float alpha = */ alpha,
+      /* float beta = */ beta,
       /* bool transA = */ a_transposed,
       /* bool transB = */ b_transposed,
       /* bool quadratic = */ false,
@@ -157,6 +174,12 @@ inline void matmul_bnns(const array& a_pre, const array& b_pre, array& out) {
   BNNSFilterDestroy(bnns_filter);
 }

+inline void matmul_bnns(const array& a_pre, const array& b_pre, array& out) {
+  // TODO: Update to utilize BNNS broadcasting
+  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  return matmul_bnns_general(a_pre, b_pre, out);
+}
+
 } // namespace

 void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -166,4 +189,16 @@ void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
   return matmul_bnns(inputs[0], inputs[1], out);
 }

+void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
+  // Fill output with C
+  auto& c = inputs[2];
+  CopyType ctype = c.data_size() == 1 ? CopyType::Scalar : CopyType::General;
+  copy(c, out, ctype);
+
+  if (out.dtype() == float32) {
+    return matmul_cblas_general(inputs[0], inputs[1], out, alpha_, beta_);
+  }
+  return matmul_bnns_general(inputs[0], inputs[1], out, alpha_, beta_);
+}
+
 } // namespace mlx::core
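Both `*_general` helpers expose the standard BLAS GEMM epilogue: `Matmul` calls them with the defaults (alpha = 1, beta = 0), while `AddMM` pre-fills the output with C and passes its `alpha_`/`beta_`, so a single `cblas_sgemm` (or BNNS) call computes the fused result:

    D = \alpha \, (A B) + \beta \, C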
@@ -98,16 +98,14 @@ DEFAULT(Tanh)
 DEFAULT(Transpose)
 DEFAULT_MULTI(DivMod)

-void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
-  if (out.dtype() != float32) {
-    throw std::runtime_error(
-        "[Matmul::eval_cpu] Currently only supports float32.");
-  }
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
-
-  auto& a_pre = inputs[0];
-  auto& b_pre = inputs[1];
+namespace {

+inline void matmul_common_general(
+    const array& a_pre,
+    const array& b_pre,
+    array& out,
+    float alpha = 1.0f,
+    float beta = 0.0f) {
   auto check_transpose = [](const array& arr) {
     auto stx = arr.strides()[arr.ndim() - 2];
     auto sty = arr.strides()[arr.ndim() - 1];
@@ -125,9 +123,10 @@ void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {

   auto [a_transposed, lda, a] = check_transpose(a_pre);
   auto [b_transposed, ldb, b] = check_transpose(b_pre);
-  int M = a.shape(-2);
-  int N = b.shape(-1);
-  int K = a.shape(-1);
+  size_t M = a.shape(-2);
+  size_t N = b.shape(-1);
+  size_t K = a.shape(-1);

   for (int i = 0; i < (a.size() / (M * K)); ++i) {
     cblas_sgemm(
         CblasRowMajor,
@@ -136,16 +135,41 @@ void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
         M,
         N,
         K,
-        1.0f, // alpha
+        alpha, // alpha
         a.data<float>() + elem_to_loc(M * K * i, a.shape(), a.strides()),
         lda,
         b.data<float>() + elem_to_loc(K * N * i, b.shape(), b.strides()),
         ldb,
-        0.0f, // beta
+        beta, // beta
         out.data<float>() + M * N * i,
         out.shape(-1) // ldc
     );
   }
 }

+} // namespace
+
+void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
+  if (out.dtype() != float32) {
+    throw std::runtime_error(
+        "[Matmul::eval_cpu] Currently only supports float32.");
+  }
+  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  return matmul_common_general(inputs[0], inputs[1], out);
+}
+
+void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
+  if (out.dtype() != float32) {
+    throw std::runtime_error(
+        "[AddMM::eval_cpu] Currently only supports float32.");
+  }
+
+  // Fill output with C
+  auto& c = inputs[2];
+  CopyType ctype = c.data_size() == 1 ? CopyType::Scalar : CopyType::General;
+  copy(c, out, ctype);
+
+  return matmul_common_general(inputs[0], inputs[1], out, alpha_, beta_);
+}
+
 } // namespace mlx::core
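The common-backend `AddMM::eval_cpu` above follows the same two-step pattern. A minimal NumPy sketch of the idea (hypothetical shapes; not MLX code):

import numpy as np

a = np.random.randn(4, 8).astype(np.float32)
b = np.random.randn(8, 3).astype(np.float32)
c = np.random.randn(3).astype(np.float32)
alpha, beta = 1.0, 1.0

# Step 1: fill the output with C (the copy(c, out, ctype) call above).
out = np.broadcast_to(c, (4, 3)).copy()

# Step 2: GEMM with beta scaling the pre-filled output, which is what
# matmul_common_general does in place via cblas_sgemm.
out = alpha * (a @ b) + beta * out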
@@ -70,7 +70,7 @@ void explicit_gemm_conv_1D_gpu(

   // Perform gemm
   std::vector<array> copies = {in_padded, in_strided};
-  mlx_matmul(
+  return steel_matmul(
       s,
       d,
       /*a = */ in_strided,
@@ -262,7 +262,7 @@ void explicit_gemm_conv_2D_gpu(

   // Perform gemm
   std::vector<array> copies = {in_padded, in_strided};
-  mlx_matmul(
+  return steel_matmul(
       s,
       d,
       /*a = */ in_strided,
@@ -411,7 +411,7 @@ void winograd_conv_2D_gpu(
   copies_w.push_back(out_wg);
   {
     std::vector<array> empty_copies;
-    mlx_matmul(
+    steel_matmul(
         s,
         d,
         /*a = */ inp_wg,
@@ -18,7 +18,6 @@ set(
   "binary_two"
   "conv"
   "copy"
-  "gemm"
   "gemv"
   "quantized"
   "random"
@@ -30,26 +29,27 @@ set(
   "indexing"
 )

-function(build_kernel KERNEL)
-  set(SRCFILE ${CMAKE_CURRENT_SOURCE_DIR}/${KERNEL}.metal)
-  set(HEADERS_PADDED ${HEADERS})
-  if(${KERNEL} STREQUAL "gemm")
-    set(HEADERS_PADDED ${HEADERS_PADDED} ${CMAKE_CURRENT_SOURCE_DIR}/gemm/gemm.h)
-  endif()
-  if(${KERNEL} STREQUAL "conv")
-    set(HEADERS_PADDED ${HEADERS_PADDED} ${CMAKE_CURRENT_SOURCE_DIR}/gemm/conv.h)
-  endif()
+function(build_kernel_base TARGET SRCFILE DEPS)
   add_custom_command(
     COMMAND xcrun -sdk macosx metal -Wall -Wextra
                   -fno-fast-math
                   -c ${SRCFILE}
                   -I${PROJECT_SOURCE_DIR}
-                  -o ${KERNEL}.air
-    DEPENDS ${SRCFILE} ${HEADERS_PADDED}
-    OUTPUT ${KERNEL}.air
-    COMMENT "Building ${KERNEL}.air"
+                  -o ${TARGET}.air
+    DEPENDS ${SRCFILE} ${DEPS}
+    OUTPUT ${TARGET}.air
+    COMMENT "Building ${TARGET}.air"
     VERBATIM
   )
+endfunction(build_kernel_base)
+
+function(build_kernel KERNEL)
+  set(SRCFILE ${CMAKE_CURRENT_SOURCE_DIR}/${KERNEL}.metal)
+  set(HEADERS_PADDED ${HEADERS})
+  if(${KERNEL} STREQUAL "conv")
+    set(HEADERS_PADDED ${HEADERS_PADDED} ${CMAKE_CURRENT_SOURCE_DIR}/conv.h)
+  endif()
+  build_kernel_base(${KERNEL} ${SRCFILE} "${HEADERS_PADDED}")
 endfunction(build_kernel)

 foreach(KERNEL ${KERNELS})
@@ -57,6 +57,15 @@ foreach(KERNEL ${KERNELS})
   set(KERNEL_AIR ${KERNEL}.air ${KERNEL_AIR})
 endforeach()

+file(GLOB_RECURSE STEEL_KERNELS ${CMAKE_CURRENT_SOURCE_DIR}/steel/*.metal)
+file(GLOB_RECURSE STEEL_HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/steel/*.h)
+
+foreach(KERNEL ${STEEL_KERNELS})
+  cmake_path(GET KERNEL STEM TARGET)
+  build_kernel_base(${TARGET} ${KERNEL} "${STEEL_HEADERS}")
+  set(KERNEL_AIR ${TARGET}.air ${KERNEL_AIR})
+endforeach()
+
 add_custom_command(
   OUTPUT ${MLX_METAL_PATH}/mlx.metallib
   COMMAND xcrun -sdk macosx metallib ${KERNEL_AIR} -o ${MLX_METAL_PATH}/mlx.metallib
@@ -5,7 +5,7 @@
 #include "mlx/backend/metal/kernels/conv_params.h"
 #include "mlx/backend/metal/kernels/bf16.h"

-#include "mlx/backend/metal/kernels/gemm/conv.h"
+#include "mlx/backend/metal/kernels/conv.h"

 using namespace metal;

@@ -1,538 +0,0 @@
-// Copyright © 2023 Apple Inc.
-
-#pragma once
-
-#include <metal_simdgroup>
-#include <metal_simdgroup_matrix>
-#include <metal_stdlib>
-
-#define MLX_MTL_CONST static constant constexpr const
-
-using namespace metal;
-
-///////////////////////////////////////////////////////////////////////////////
-// Loading helper
-///////////////////////////////////////////////////////////////////////////////
-
-template <
-    typename T,
-    int BROWS,
-    int BCOLS,
-    int BK,
-    int vec_size,
-    int tgp_size,
-    bool transpose,
-    bool ldK,
-    int tgp_padding = 0>
-struct BlockLoader {
-  // Destination dimensions
-  MLX_MTL_CONST int dst_fd = transpose ? BCOLS : BROWS;
-  MLX_MTL_CONST int dst_ld = (transpose ? BROWS : BCOLS) + tgp_padding;
-  MLX_MTL_CONST int n_vecs = (transpose ? BROWS : BCOLS) / vec_size;
-
-  // Stride along block row within the block
-  MLX_MTL_CONST int bstride = tgp_size / n_vecs;
-
-  // Leading dimension for src
-  const int src_ld;
-  // Stride along reduction axis between blocks
-  const int tstride;
-
-  // Thread location indices
-  const short thread_idx;
-  const short bi;
-  const short bj;
-
-  // threadgroup and device memory
-  threadgroup T* dst;
-  const device T* src;
-
-  /* Constructor */
-  METAL_FUNC BlockLoader(
-      const device T* src_,
-      const int src_ld_,
-      threadgroup T* dst_,
-      uint simd_group_id [[simdgroup_index_in_threadgroup]],
-      uint simd_lane_id [[thread_index_in_simdgroup]])
-      : src_ld(src_ld_),
-        tstride(
-            BK * ((int)(transpose ^ !ldK) * src_ld + (int)(transpose ^ ldK))),
-        thread_idx(simd_group_id * 32 + simd_lane_id),
-        bi(thread_idx / n_vecs),
-        bj(vec_size * (thread_idx % n_vecs)),
-        dst(dst_ + bi * dst_ld + bj),
-        src(src_ + bi * src_ld + bj) {}
-
-  /* Load from device memory into threadgroup memory - without bound checking */
-  METAL_FUNC void load_unsafe() const {
-#pragma clang loop unroll(full)
-    for (short i = 0; i < dst_fd; i += bstride) {
-#pragma clang loop unroll(full)
-      for (short j = 0; j < vec_size; j++) {
-        dst[i * dst_ld + j] = src[i * src_ld + j];
-      }
-    }
-  }
-
-  /* Load from device memory into threadgroup memory - with bound checking */
-  METAL_FUNC void load_safe(short2 src_tile_dim) const {
-    src_tile_dim = transpose ? src_tile_dim.yx : src_tile_dim.xy;
-
-    // Iterate over rows of block
-#pragma clang loop unroll(full)
-    for (short i = 0; i < dst_fd; i += bstride) {
-      // Row is in bounds, we check against column
-      if ((bi + i) < src_tile_dim.y) {
-        // Use fast thread memory for bound checks
-        short tmp_idx[vec_size];
-        T tmp_val[vec_size];
-
-        // Make sure tmp_idx only contains valid indices
-#pragma clang loop unroll(full)
-        for (short j = 0; j < vec_size; j++) {
-          tmp_idx[j] = bj + j < src_tile_dim.x ? j : 0;
-        }
-
-        // Read all valid indices into tmp_val
-#pragma clang loop unroll(full)
-        for (short j = 0; j < vec_size; j++) {
-          tmp_val[j] = src[i * src_ld + tmp_idx[j]];
-        }
-
-        // Zero out unneeded values
-#pragma clang loop unroll(full)
-        for (short j = 0; j < vec_size; j++) {
-          tmp_val[j] = bj + j < src_tile_dim.x ? tmp_val[j] : T(0);
-        }
-
-        // Copy values to threadgroup memory
-#pragma clang loop unroll(full)
-        for (short j = 0; j < vec_size; j++) {
-          dst[i * dst_ld + j] = tmp_val[j];
-        }
-      }
-
-      // Row is out of bounds, we just fill tgp memory with zeros
-      else {
-#pragma clang loop unroll(full)
-        for (short j = 0; j < vec_size; j++) {
-          dst[i * dst_ld + j] = T(0);
-        }
-      }
-    }
-  }
-
-  /* Iteration helper */
-  METAL_FUNC void next() {
-    src += tstride;
-  }
-};
-
-///////////////////////////////////////////////////////////////////////////////
-// Transforms
-///////////////////////////////////////////////////////////////////////////////
-
-template <typename OutT, typename InT>
-struct TransformNone {
-  static METAL_FUNC OutT apply(InT x) {
-    return static_cast<OutT>(x);
-  }
-};
-
-template <typename T>
-struct AccumHelper {
-  typedef float accum_type;
-};
-
-///////////////////////////////////////////////////////////////////////////////
-// MMA helper
-///////////////////////////////////////////////////////////////////////////////
-
-template <
-    typename T,
-    int BM,
-    int BN,
-    int BK,
-    int WM,
-    int WN,
-    bool transpose_a,
-    bool transpose_b,
-    int tgp_padding_a = 0,
-    int tgp_padding_b = 0,
-    typename AccumType = typename AccumHelper<T>::accum_type,
-    typename Epilogue = TransformNone<T, AccumType>>
-struct BlockMMA {
-  // Warp tile size along M
-  MLX_MTL_CONST int TM = BM / (WM * 8);
-  // Warp tile size along N
-  MLX_MTL_CONST int TN = BN / (WN * 8);
-
-  // Warp tile simdgroup matrix strides along M
-  MLX_MTL_CONST int TM_stride = 8 * WM;
-  // Warp tile simdgroup matrix strides along M
-  MLX_MTL_CONST int TN_stride = 8 * WN;
-
-  // Leading dimensions of threadgroup A, B blocks
-  MLX_MTL_CONST int lda_tgp = (transpose_a ? BM : BK) + tgp_padding_a;
-  MLX_MTL_CONST int ldb_tgp = (transpose_b ? BK : BN) + tgp_padding_b;
-
-  // Strides of A, B along reduction axis
-  MLX_MTL_CONST short simd_stride_a =
-      transpose_a ? TM_stride : TM_stride * lda_tgp;
-  MLX_MTL_CONST short simd_stride_b =
-      transpose_b ? TN_stride * ldb_tgp : TN_stride;
-
-  // Jump between elements
-  MLX_MTL_CONST short jump_a = transpose_a ? lda_tgp : 1;
-  MLX_MTL_CONST short jump_b = transpose_b ? ldb_tgp : 1;
-
-  // Offsets within threadgroup
-  const int tm;
-  const int tn;
-
-  // Simdgroup matrices
-  simdgroup_matrix<AccumType, 8, 8> Asimd[TM];
-  simdgroup_matrix<AccumType, 8, 8> Bsimd[TN];
-  simdgroup_matrix<AccumType, 8, 8> results[TM * TN] = {
-      simdgroup_matrix<AccumType, 8, 8>(0)};
-
-  short sm;
-  short sn;
-
-  /* Constructor */
-  METAL_FUNC BlockMMA(
-      uint simd_group_id [[simdgroup_index_in_threadgroup]],
-      uint simd_lane_id [[thread_index_in_simdgroup]])
-      : tm(8 * (simd_group_id / WN)), tn(8 * (simd_group_id % WN)) {
-    short qid = simd_lane_id / 4;
-    sm = (qid & 4) + (simd_lane_id / 2) % 4;
-    sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
-  }
-
-  /* (BM, BK) X (BK, BN) multiply accumulate function */
-  METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) {
-    // Iterate over BK in blocks of 8
-#pragma clang loop unroll(full)
-    for (short kk = 0; kk < BK; kk += 8) {
-      short2 offset_a =
-          transpose_a ? short2(tm + sm, kk + sn) : short2(kk + sn, tm + sm);
-      short2 offset_b =
-          transpose_b ? short2(kk + sm, tn + sn) : short2(tn + sn, kk + sm);
-
-      const threadgroup T* As__ = As + offset_a.y * lda_tgp + offset_a.x;
-      const threadgroup T* Bs__ = Bs + offset_b.y * ldb_tgp + offset_b.x;
-
-      simdgroup_barrier(mem_flags::mem_none);
-      // Load elements from threadgroup A as simdgroup matrices
-#pragma clang loop unroll(full)
-      for (short i = 0; i < TM; i++) {
-        Asimd[i].thread_elements()[0] = static_cast<AccumType>(As__[0]);
-        Asimd[i].thread_elements()[1] = static_cast<AccumType>(As__[jump_a]);
-        As__ += simd_stride_a;
-      }
-
-      simdgroup_barrier(mem_flags::mem_none);
-      // Load elements from threadgroup B as simdgroup matrices
-#pragma clang loop unroll(full)
-      for (short j = 0; j < TN; j++) {
-        Bsimd[j].thread_elements()[0] = static_cast<AccumType>(Bs__[0]);
-        Bsimd[j].thread_elements()[1] = static_cast<AccumType>(Bs__[jump_b]);
-        Bs__ += simd_stride_b;
-      }
-
-      simdgroup_barrier(mem_flags::mem_none);
-      // Multiply and accumulate into result simdgroup matrices
-#pragma clang loop unroll(full)
-      for (short i = 0; i < TM; i++) {
-#pragma clang loop unroll(full)
-        for (short j = 0; j < TN; j++) {
-          simdgroup_multiply_accumulate(
-              results[i * TN + j], Asimd[i], Bsimd[j], results[i * TN + j]);
-        }
-      }
-    }
-  }
-
-  /* Store results from simdgroup_matrix results into device memory */
-  METAL_FUNC void store_result(device T* C, const int ldc) const {
-#pragma clang loop unroll(full)
-    for (int i = 0; i < TM; i++) {
-#pragma clang loop unroll(full)
-      for (int j = 0; j < TN; j++) {
-        C[(i * TM_stride + sm + tm) * ldc + j * TN_stride + tn + sn] =
-            Epilogue::apply(results[i * TN + j].thread_elements()[0]);
-        C[(i * TM_stride + sm + tm) * ldc + j * TN_stride + tn + sn + 1] =
-            Epilogue::apply(results[i * TN + j].thread_elements()[1]);
-      }
-    }
-  }
-
-  METAL_FUNC void
-  store_result_safe(device T* C, const int ldc, short2 dst_tile_dims) const {
-#pragma clang loop unroll(full)
-    for (int i = 0; i < TM; i++) {
-      if (tm + i * TM_stride + sm < dst_tile_dims.y) {
-#pragma clang loop unroll(full)
-        for (int j = 0; j < TN; j++) {
-          if (tn + j * TN_stride + sn < dst_tile_dims.x) {
-            C[(tm + i * TM_stride + sm) * ldc + tn + j * TN_stride + sn] =
-                Epilogue::apply(results[i * TN + j].thread_elements()[0]);
-          }
-
-          if (tn + j * TN_stride + sn + 1 < dst_tile_dims.x) {
-            C[(tm + i * TM_stride + sm) * ldc + tn + j * TN_stride + sn + 1] =
-                Epilogue::apply(results[i * TN + j].thread_elements()[1]);
-          }
-        }
-      }
-    }
-  }
-};
-
-///////////////////////////////////////////////////////////////////////////////
-// GEMM kernels
-///////////////////////////////////////////////////////////////////////////////
-
-template <
-    typename T,
-    int BM,
-    int BN,
-    int BK,
-    int WM,
-    int WN,
-    bool transpose_a,
-    bool transpose_b,
-    bool MN_aligned,
-    bool K_aligned,
-    typename AccumType = typename AccumHelper<T>::accum_type,
-    typename Epilogue = TransformNone<T, AccumType>>
-struct GEMMKernel {
-  MLX_MTL_CONST short tgp_padding_a = 16 / sizeof(T);
-  MLX_MTL_CONST short tgp_padding_b = 16 / sizeof(T);
-  MLX_MTL_CONST short tgp_mem_size_a =
-      transpose_a ? BK * (BM + tgp_padding_a) : BM * (BK + tgp_padding_a);
-  MLX_MTL_CONST short tgp_mem_size_b =
-      transpose_b ? BN * (BK + tgp_padding_b) : BK * (BN + tgp_padding_b);
-  MLX_MTL_CONST short tgp_mem_size = tgp_mem_size_a + tgp_mem_size_b;
-
-  MLX_MTL_CONST short tgp_size = WM * WN * 32;
-  MLX_MTL_CONST short vec_size = (BM == 64 && BN == 64) ? 8 : 4;
-
-  using loader_a_t = BlockLoader<
-      T,
-      BM,
-      BK,
-      BK,
-      vec_size,
-      tgp_size,
-      transpose_a,
-      true,
-      tgp_padding_a>;
-  using loader_b_t = BlockLoader<
-      T,
-      BK,
-      BN,
-      BK,
-      vec_size,
-      tgp_size,
-      transpose_b,
-      false,
-      tgp_padding_b>;
-  using mma_t = BlockMMA<
-      T,
-      BM,
-      BN,
-      BK,
-      WM,
-      WN,
-      transpose_a,
-      transpose_b,
-      tgp_padding_a,
-      tgp_padding_b,
-      AccumType,
-      Epilogue>;
-
-  /* Main kernel function */
-  static METAL_FUNC void run(
-      const device T* A [[buffer(0)]],
-      const device T* B [[buffer(1)]],
-      device T* C [[buffer(2)]],
-      const constant int& M [[buffer(3)]],
-      const constant int& N [[buffer(4)]],
-      const constant int& K [[buffer(5)]],
-      const constant int& batch_stride_a [[buffer(6)]],
-      const constant int& batch_stride_b [[buffer(7)]],
-      const constant int& batch_stride_c [[buffer(8)]],
-      threadgroup T* tgp_memory [[threadgroup(0)]],
-      uint simd_lane_id [[thread_index_in_simdgroup]],
-      uint simd_group_id [[simdgroup_index_in_threadgroup]],
-      uint3 tid [[threadgroup_position_in_grid]],
-      uint3 lid [[thread_position_in_threadgroup]]) {
-    // Pacifying compiler
-    (void)lid;
-
-    // Adjust for batch
-    A += batch_stride_a * tid.z;
-    B += batch_stride_b * tid.z;
-    C += batch_stride_c * tid.z;
-
-    // Adjust for transpose
-    const int lda_dev = transpose_a ? M : K;
-    const int ldb_dev = transpose_b ? K : N;
-
-    // Find block in A, B, C
-    const int c_row = tid.y * BM;
-    const int c_col = tid.x * BN;
-
-    A += transpose_a ? c_row : c_row * K;
-    B += transpose_b ? c_col * K : c_col;
-    C += c_row * N + c_col;
-
-    // Prepare threadgroup memory for loading
-    threadgroup T* As = tgp_memory;
-    threadgroup T* Bs = tgp_memory + tgp_mem_size_a;
-
-    // Prepare threadgroup loading operations
-    loader_a_t loader_a(A, lda_dev, As, simd_group_id, simd_lane_id);
-    loader_b_t loader_b(B, ldb_dev, Bs, simd_group_id, simd_lane_id);
-
-    // Prepare threadgroup mma operation
-    mma_t mma_op(simd_group_id, simd_lane_id);
-
-    ///////////////////////////////////////////////////////////////////////////////
-    // MNK aligned loop
-    if (MN_aligned && K_aligned) {
-      for (int k = 0; k < K; k += BK) {
-        threadgroup_barrier(mem_flags::mem_threadgroup);
-        // Load elements into threadgroup
-        loader_a.load_unsafe();
-        loader_b.load_unsafe();
-
-        threadgroup_barrier(mem_flags::mem_threadgroup);
-
-        // Multiply and accumulate threadgroup elements
-        mma_op.mma(As, Bs);
-
-        // Prepare for next iteration
-        loader_a.next();
-        loader_b.next();
-      }
-
-      threadgroup_barrier(mem_flags::mem_none);
-
-      // Store results to device memory
-      mma_op.store_result(C, N);
-      return;
-
-    }
-    ///////////////////////////////////////////////////////////////////////////////
-    // MN aligned, K unaligned loop
-    else if (MN_aligned && !K_aligned) {
-      // Main loop
-      int k = 0;
-      for (; k + BK <= K; k += BK) {
-        threadgroup_barrier(mem_flags::mem_threadgroup);
-        // Load elements into threadgroup
-        loader_a.load_unsafe();
-        loader_b.load_unsafe();
-
-        threadgroup_barrier(mem_flags::mem_threadgroup);
-
-        // Multiply and accumulate threadgroup elements
-        mma_op.mma(As, Bs);
-
-        // Prepare for next iteration
-        loader_a.next();
-        loader_b.next();
-      }
-
-      // Loop tail
-      threadgroup_barrier(mem_flags::mem_threadgroup);
-
-      loader_a.load_safe(short2(K - k, BM));
-      loader_b.load_safe(short2(BN, K - k));
-
-      threadgroup_barrier(mem_flags::mem_threadgroup);
-
-      mma_op.mma(As, Bs);
-
-      // Store results to device memory
-      mma_op.store_result(C, N);
-      return;
-
-    }
-    ///////////////////////////////////////////////////////////////////////////////
-    // MNK unaligned loop
-    else { // Loop over K - unaligned case
-
-      short2 src_tile_dims(min(BN, N - c_col), min(BM, M - c_row));
-
-      if (src_tile_dims.y == BM && src_tile_dims.x == BN) {
-        int k = 0;
-        for (; k + BK <= K; k += BK) {
-          threadgroup_barrier(mem_flags::mem_threadgroup);
-          // Load elements into threadgroup
-          loader_a.load_unsafe();
-          loader_b.load_unsafe();
-
-          threadgroup_barrier(mem_flags::mem_threadgroup);
-
-          // Multiply and accumulate threadgroup elements
-          mma_op.mma(As, Bs);
-
-          // Prepare for next iteration
-          loader_a.next();
-          loader_b.next();
-        }
-
-        threadgroup_barrier(mem_flags::mem_none);
-
-        if (k < K) {
-          loader_a.load_safe(short2(K - k, BM));
-          loader_b.load_safe(short2(BN, K - k));
-
-          threadgroup_barrier(mem_flags::mem_threadgroup);
-
-          mma_op.mma(As, Bs);
-        }
-
-        mma_op.store_result(C, N);
-        return;
-
-      } else {
-        int k = 0;
-        for (; k + BK <= K; k += BK) {
-          threadgroup_barrier(mem_flags::mem_threadgroup);
-          // Load elements into threadgroup
-          loader_a.load_safe(short2(BK, src_tile_dims.y));
-          loader_b.load_safe(short2(src_tile_dims.x, BK));
-
-          threadgroup_barrier(mem_flags::mem_threadgroup);
-
-          // Multiply and accumulate threadgroup elements
-          mma_op.mma(As, Bs);
-
-          // Prepare for next iteration
-          loader_a.next();
-          loader_b.next();
-        }
-
-        threadgroup_barrier(mem_flags::mem_none);
-
-        if (k < K) {
-          loader_a.load_safe(short2(K - k, src_tile_dims.y));
-          loader_b.load_safe(short2(src_tile_dims.x, K - k));
-
-          threadgroup_barrier(mem_flags::mem_threadgroup);
-
-          mma_op.mma(As, Bs);
-        }
-
-        threadgroup_barrier(mem_flags::mem_none);
-        mma_op.store_result_safe(C, N, src_tile_dims);
-
-        return;
-      }
-    }
-  }
-};
@@ -5,9 +5,10 @@

 #include "mlx/backend/metal/kernels/bf16.h"
 #include "mlx/backend/metal/kernels/defines.h"
-#include "mlx/backend/metal/kernels/gemm/gemm.h"
 #include "mlx/backend/metal/kernels/utils.h"

+#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
+
 using namespace metal;

 #define MLX_MTL_CONST static constant constexpr const
@@ -239,8 +240,9 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
   constexpr int w_els_per_thread = (BN * BK / el_per_int) / (SIMD_SIZE * WM * WN);

   // Instantiate the appropriate BlockMMA and Loader
-  using mma_t = BlockMMA<T, BM, BN, BK, WM, WN, false, true>;
-  using loader_x_t = BlockLoader<T, BM, BK, BK, 4, WM * WN * SIMD_SIZE, false, true, 0>;
+  using mma_t = mlx::steel::BlockMMA<T, T, BM, BN, BK, WM, WN, false, true, BK, BK>;
+  using loader_x_t = mlx::steel::BlockLoader<T, BM, BK, BK, 1, WM * WN * SIMD_SIZE, 1, 4>;
+
   threadgroup T scales_block[BN * groups_per_block];
   threadgroup T biases_block[BN * groups_per_block];
@@ -392,8 +394,8 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
   constexpr int w_els_per_thread = (BK * BN / el_per_int) / (SIMD_SIZE * WM * WN);

   // Instantiate the appropriate BlockMMA and Loader
-  using mma_t = BlockMMA<T, BM, BN, BK, WM, WN, false, false>;
-  using loader_x_t = BlockLoader<T, BM, BK, BK, 4, WM * WN * SIMD_SIZE, false, true, 0>;
+  using mma_t = mlx::steel::BlockMMA<T, T, BM, BN, BK, WM, WN, false, false, BK, BN>;
+  using loader_x_t = mlx::steel::BlockLoader<T, BM, BK, BK, 1, WM * WN * SIMD_SIZE, 1, 4>;

   threadgroup T scales_block[BK * groups_per_block];
   threadgroup T biases_block[BK * groups_per_block];
mlx/backend/metal/kernels/steel/gemm/gemm.h (new file, 312 lines)
@@ -0,0 +1,312 @@
+// Copyright © 2024 Apple Inc.
+
+#pragma once
+
+#include "mlx/backend/metal/kernels/steel/gemm/loader.h"
+#include "mlx/backend/metal/kernels/steel/gemm/mma.h"
+#include "mlx/backend/metal/kernels/steel/gemm/transforms.h"
+#include "mlx/backend/metal/kernels/steel/utils.h"
+
+using namespace metal;
+
+///////////////////////////////////////////////////////////////////////////////
+// GEMM kernel class
+///////////////////////////////////////////////////////////////////////////////
+
+namespace mlx {
+namespace steel {
+
+template <bool M_aligned, bool N_aligned, bool K_aligned>
+struct LoopAlignment {};
+
+template <
+    typename T,
+    typename U,
+    int BM,
+    int BN,
+    int BK,
+    int WM,
+    int WN,
+    bool transpose_a,
+    bool transpose_b,
+    bool MN_aligned,
+    bool K_aligned,
+    typename AccumType = typename AccumHelper<T>::accum_type,
+    typename Epilogue = TransformNone<U, AccumType>>
+struct GEMMKernel {
+  STEEL_CONST short tgp_padding_a = 16 / sizeof(T);
+  STEEL_CONST short tgp_padding_b = 16 / sizeof(T);
+  STEEL_CONST short tgp_mem_size_a =
+      transpose_a ? BK * (BM + tgp_padding_a) : BM * (BK + tgp_padding_a);
+  STEEL_CONST short tgp_mem_size_b =
+      transpose_b ? BN * (BK + tgp_padding_b) : BK * (BN + tgp_padding_b);
+  STEEL_CONST short tgp_mem_size = tgp_mem_size_a + tgp_mem_size_b;
+
+  STEEL_CONST short tgp_size = WM * WN * 32;
+
+  using loader_a_t = BlockLoader<
+      T,
+      transpose_a ? BK : BM,
+      transpose_a ? BM : BK,
+      transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a,
+      !transpose_a,
+      tgp_size>;
+  using loader_b_t = BlockLoader<
+      T,
+      transpose_b ? BN : BK,
+      transpose_b ? BK : BN,
+      transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b,
+      transpose_b,
+      tgp_size>;
+  using mma_t = BlockMMA<
+      T,
+      U,
+      BM,
+      BN,
+      BK,
+      WM,
+      WN,
+      transpose_a,
+      transpose_b,
+      transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a,
+      transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b,
+      AccumType,
+      Epilogue>;
+
+  /* Main kernel function */
+  template <bool M_aligned, bool N_aligned, bool K_aligned_>
+  static METAL_FUNC void gemm_loop(
+      threadgroup T* As [[threadgroup(0)]],
+      threadgroup T* Bs [[threadgroup(1)]],
+      const int gemm_k_iterations,
+      thread loader_a_t& loader_a,
+      thread loader_b_t& loader_b,
+      thread mma_t& mma_op,
+      thread const short& tgp_bm,
+      thread const short& tgp_bn,
+      thread const short& lbk,
+      LoopAlignment<M_aligned, N_aligned, K_aligned_> l = {}) {
+    // Appease the compiler
+    (void)l;
+
+    thread bool mask_A[loader_a_t::n_rows][loader_a_t::vec_size];
+    thread bool mask_B[loader_b_t::n_rows][loader_b_t::vec_size];
+
+    if (!M_aligned) {
+      short2 tile_dims_A =
+          transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm);
+      loader_a.set_mask(tile_dims_A, mask_A);
+    }
+
+    if (!N_aligned) {
+      short2 tile_dims_B =
+          transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK);
+      loader_b.set_mask(tile_dims_B, mask_B);
+    }
+
+    for (int k = 0; k < gemm_k_iterations; k++) {
+      threadgroup_barrier(mem_flags::mem_threadgroup);
+      // Load elements into threadgroup
+      if (M_aligned) {
+        loader_a.load_unsafe();
+      } else {
+        loader_a.load_safe(mask_A);
+      }
+
+      if (N_aligned) {
+        loader_b.load_unsafe();
+      } else {
+        loader_b.load_safe(mask_B);
+      }
+
+      threadgroup_barrier(mem_flags::mem_threadgroup);
+
+      // Multiply and accumulate threadgroup elements
+      mma_op.mma(As, Bs);
+
+      // Prepare for next iteration
+      loader_a.next();
+      loader_b.next();
+    }
+
+    if (!K_aligned_) {
+      threadgroup_barrier(mem_flags::mem_threadgroup);
+
+      short2 tile_dims_A_last =
+          transpose_a ? short2(tgp_bm, lbk) : short2(lbk, tgp_bm);
+      short2 tile_dims_B_last =
+          transpose_b ? short2(lbk, tgp_bn) : short2(tgp_bn, lbk);
+
+      loader_a.set_mask(tile_dims_A_last, mask_A);
+      loader_b.set_mask(tile_dims_B_last, mask_B);
+
+      loader_a.load_safe(mask_A);
+      loader_b.load_safe(mask_B);
+
+      threadgroup_barrier(mem_flags::mem_threadgroup);
+
+      mma_op.mma(As, Bs);
+    }
+  }
+
+  /* Main kernel function */
+  static METAL_FUNC void run(
+      const device T* A [[buffer(0)]],
+      const device T* B [[buffer(1)]],
+      device U* C [[buffer(2)]],
+      const constant GEMMParams* params [[buffer(3)]],
+      threadgroup T* As [[threadgroup(0)]],
+      threadgroup T* Bs [[threadgroup(1)]],
+      uint simd_lane_id [[thread_index_in_simdgroup]],
+      uint simd_group_id [[simdgroup_index_in_threadgroup]],
+      uint3 tid [[threadgroup_position_in_grid]],
+      uint3 lid [[thread_position_in_threadgroup]]) {
+    // Pacifying compiler
+    (void)lid;
+
+    const int tid_y = ((tid.y) << params->swizzle_log) +
+        ((tid.x) & ((1 << params->swizzle_log) - 1));
+    const int tid_x = (tid.x) >> params->swizzle_log;
+
+    if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
+      return;
+    }
+
+    threadgroup_barrier(mem_flags::mem_none);
+
+    // Find block in A, B, C
+    const int c_row = tid_y * BM;
+    const int c_col = tid_x * BN;
+
+    A += transpose_a ? c_row : c_row * params->lda;
+    B += transpose_b ? c_col * params->ldb : c_col;
+    C += c_row * params->ldc + c_col;
+
+    // Prepare threadgroup loading operations
+    thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
+    thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
+
+    // Prepare threadgroup mma operation
+    thread mma_t mma_op(simd_group_id, simd_lane_id);
+
+    int gemm_k_iterations = params->gemm_k_iterations_aligned;
+
+    ///////////////////////////////////////////////////////////////////////////////
+    // MNK aligned loop
+    if (MN_aligned) {
+      for (int k = 0; k < gemm_k_iterations; k++) {
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        // Load elements into threadgroup
+        loader_a.load_unsafe();
+        loader_b.load_unsafe();
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        // Multiply and accumulate threadgroup elements
+        mma_op.mma(As, Bs);
+
+        // Prepare for next iteration
+        loader_a.next();
+        loader_b.next();
+      }
+
+      threadgroup_barrier(mem_flags::mem_none);
+
+      // Loop tail
+      if (!K_aligned) {
+        int lbk = params->K - params->gemm_k_iterations_aligned * BK;
+        short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);
+        short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk);
+
+        thread bool mask_A[loader_a_t::n_rows][loader_a_t::vec_size];
+        thread bool mask_B[loader_b_t::n_rows][loader_b_t::vec_size];
+
+        loader_a.set_mask(tile_dims_A, mask_A);
+        loader_b.set_mask(tile_dims_B, mask_B);
+
+        loader_a.load_safe(mask_A);
+        loader_b.load_safe(mask_B);
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        mma_op.mma(As, Bs);
+      }
+
+      // Store results to device memory
+      mma_op.store_result(C, params->ldc);
+      return;
+
+    }
+    ///////////////////////////////////////////////////////////////////////////////
+    // MN unaligned loop
+    else { // Loop over K - unaligned case
+      short tgp_bm = min(BM, params->M - c_row);
+      short tgp_bn = min(BN, params->N - c_col);
+      short leftover_bk = params->K - params->gemm_k_iterations_aligned * BK;
+
+      if (tgp_bm == BM && tgp_bn == BN) {
+        gemm_loop<true, true, K_aligned>(
+            As,
+            Bs,
+            gemm_k_iterations,
+            loader_a,
+            loader_b,
+            mma_op,
+            tgp_bm,
+            tgp_bn,
+            leftover_bk);
+
+        mma_op.store_result(C, params->ldc);
+        return;
+
+      } else if (tgp_bn == BN) {
+        gemm_loop<false, true, K_aligned>(
+            As,
+            Bs,
+            gemm_k_iterations,
+            loader_a,
+            loader_b,
+            mma_op,
+            tgp_bm,
+            tgp_bn,
+            leftover_bk);
+
+        mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm));
+        return;
+
+      } else if (tgp_bm == BM) {
+        gemm_loop<true, false, K_aligned>(
+            As,
+            Bs,
+            gemm_k_iterations,
+            loader_a,
+            loader_b,
+            mma_op,
+            tgp_bm,
+            tgp_bn,
+            leftover_bk);
+
+        mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm));
+        return;
+
+      } else {
+        gemm_loop<false, false, K_aligned>(
+            As,
+            Bs,
+            gemm_k_iterations,
+            loader_a,
+            loader_b,
+            mma_op,
+            tgp_bm,
+            tgp_bn,
+            leftover_bk);
+
+        mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm));
+        return;
+      }
+    }
+  }
+};
+
+} // namespace steel
+} // namespace mlx
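One detail worth calling out in the new `GEMMKernel::run` is the threadgroup swizzle. A plain-Python sketch of the index remapping (illustrative function name; `swizzle_log` comes from `GEMMParams`), which is intended to make consecutive threadgroups touch nearby rows of C and improve cache reuse:

def swizzle_tile(tid_x, tid_y, swizzle_log):
    # Mirrors the tid_y/tid_x computation at the top of GEMMKernel::run.
    tid_y_s = (tid_y << swizzle_log) + (tid_x & ((1 << swizzle_log) - 1))
    tid_x_s = tid_x >> swizzle_log
    return tid_x_s, tid_y_s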
@@ -1,9 +1,10 @@
-// Copyright © 2023 Apple Inc.
+// Copyright © 2024 Apple Inc.

 #include "mlx/backend/metal/kernels/bf16.h"
-#include "mlx/backend/metal/kernels/gemm/gemm.h"
+#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"

 using namespace metal;
+using namespace mlx::steel;

 ///////////////////////////////////////////////////////////////////////////////
 // GEMM kernels
@@ -23,26 +24,26 @@ template <typename T,
     const device T *A [[buffer(0)]],
     const device T *B [[buffer(1)]],
     device T *C [[buffer(2)]],
-    const constant int &M [[buffer(3)]],
-    const constant int &N [[buffer(4)]],
-    const constant int &K [[buffer(5)]],
-    const constant int &batch_stride_a [[buffer(6)]],
-    const constant int &batch_stride_b [[buffer(7)]],
-    const constant int &batch_stride_c [[buffer(8)]],
+    const constant GEMMParams* params [[buffer(3)]],
     uint simd_lane_id [[thread_index_in_simdgroup]],
     uint simd_group_id [[simdgroup_index_in_threadgroup]],
     uint3 tid [[threadgroup_position_in_grid]],
     uint3 lid [[thread_position_in_threadgroup]]) {

-    using gemm_kernel = GEMMKernel<T, BM, BN, BK, WM, WN, transpose_a, transpose_b, MN_aligned, K_aligned>;
+    using gemm_kernel = GEMMKernel<T, T, BM, BN, BK, WM, WN, transpose_a, transpose_b, MN_aligned, K_aligned>;

-    threadgroup T tgp_memory[gemm_kernel::tgp_mem_size];
+    threadgroup T As[gemm_kernel::tgp_mem_size_a];
+    threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
+
+    // Adjust for batch
+    A += params->batch_stride_a * tid.z;
+    B += params->batch_stride_b * tid.z;
+    C += params->batch_stride_c * tid.z;

     gemm_kernel::run(
       A, B, C,
-      M, N, K,
-      batch_stride_a, batch_stride_b, batch_stride_c,
-      tgp_memory,
+      params,
+      As, Bs,
       simd_lane_id, simd_group_id, tid, lid
     );
 }
@@ -52,17 +53,12 @@ template <typename T,
 ///////////////////////////////////////////////////////////////////////////////

 #define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
-  template [[host_name("gemm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname "_K_" #kname)]] \
+  template [[host_name("steel_gemm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname "_K_" #kname)]] \
   [[kernel]] void gemm<itype, bm, bn, bk, wm, wn, trans_a, trans_b, mn_aligned, k_aligned>( \
       const device itype *A [[buffer(0)]], \
       const device itype *B [[buffer(1)]], \
       device itype *C [[buffer(2)]], \
-      const constant int &M [[buffer(3)]], \
-      const constant int &N [[buffer(4)]], \
-      const constant int &K [[buffer(5)]], \
-      const constant int &batch_stride_a [[buffer(6)]], \
-      const constant int &batch_stride_b [[buffer(7)]], \
-      const constant int &batch_stride_c [[buffer(8)]], \
+      const constant GEMMParams* params [[buffer(3)]], \
       uint simd_lane_id [[thread_index_in_simdgroup]], \
       uint simd_group_id [[simdgroup_index_in_threadgroup]], \
       uint3 tid [[threadgroup_position_in_grid]], \
@@ -84,10 +80,10 @@ template <typename T,
   instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
   instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \
   instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
-  instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 16, 2, 2)
+  instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 16, 2, 2) \
+  instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 2, 2)

 instantiate_gemm_shapes_helper(float16, half, float16, half);
-instantiate_gemm_shapes_helper(float32, float, float32, float);
 instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);

-// TODO: Accumulation in different type
+instantiate_gemm_shapes_helper(float32, float, float32, float);
@ -0,0 +1,260 @@
// Copyright © 2024 Apple Inc.

#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"

using namespace metal;
using namespace mlx::steel;

///////////////////////////////////////////////////////////////////////////////
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////

template <typename T,
          int BM,
          int BN,
          int BK,
          int WM,
          int WN,
          bool transpose_a,
          bool transpose_b,
          bool MN_aligned,
          bool K_aligned,
          typename AccumType = float,
          typename Epilogue = TransformAdd<T, AccumType>>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void addmm(
    const device T *A [[buffer(0)]],
    const device T *B [[buffer(1)]],
    const device T *C [[buffer(2)]],
    device T *D [[buffer(3)]],
    const constant GEMMAddMMParams* params [[buffer(4)]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]) {

  // Pacifying compiler
  (void)lid;

  using gemm_kernel =
      GEMMKernel<T, T, BM, BN, BK, WM, WN,
                 transpose_a, transpose_b,
                 MN_aligned, K_aligned,
                 AccumType, Epilogue>;

  using loader_a_t = typename gemm_kernel::loader_a_t;
  using loader_b_t = typename gemm_kernel::loader_b_t;
  using mma_t = typename gemm_kernel::mma_t;

  threadgroup T As[gemm_kernel::tgp_mem_size_a];
  threadgroup T Bs[gemm_kernel::tgp_mem_size_b];

  // Adjust for batch
  A += params->batch_stride_a * tid.z;
  B += params->batch_stride_b * tid.z;
  C += params->batch_stride_c * tid.z;
  D += params->batch_stride_d * tid.z;

  const int tid_y = ((tid.y) << params->swizzle_log) +
      ((tid.x) & ((1 << params->swizzle_log) - 1));
  const int tid_x = (tid.x) >> params->swizzle_log;

  if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
    return;
  }

  threadgroup_barrier(mem_flags::mem_none);

  // Find block in A, B, C
  const int c_row = tid_y * BM;
  const int c_col = tid_x * BN;

  A += transpose_a ? c_row : c_row * params->lda;
  B += transpose_b ? c_col * params->ldb : c_col;
  C += c_row * params->ldc + c_col * params->fdc;
  D += c_row * params->ldd + c_col;

  // Prepare threadgroup loading operations
  thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
  thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);

  // Prepare threadgroup mma operation
  thread mma_t mma_op(simd_group_id, simd_lane_id);

  int gemm_k_iterations = params->gemm_k_iterations_aligned;

  const Epilogue epilogue_op(params->alpha, params->beta);

  /////////////////////////////////////////////////////////////////////////////
  // MNK aligned loop
  if (MN_aligned) {
    for (int k = 0; k < gemm_k_iterations; k++) {
      threadgroup_barrier(mem_flags::mem_threadgroup);
      // Load elements into threadgroup
      loader_a.load_unsafe();
      loader_b.load_unsafe();

      threadgroup_barrier(mem_flags::mem_threadgroup);

      // Multiply and accumulate threadgroup elements
      mma_op.mma(As, Bs);

      // Prepare for next iteration
      loader_a.next();
      loader_b.next();
    }

    threadgroup_barrier(mem_flags::mem_none);

    // Loop tail
    if (!K_aligned) {
      int lbk = params->K - params->gemm_k_iterations_aligned * BK;
      short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);
      short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk);

      thread bool mask_A[loader_a_t::n_rows][loader_a_t::vec_size];
      thread bool mask_B[loader_b_t::n_rows][loader_b_t::vec_size];

      loader_a.set_mask(tile_dims_A, mask_A);
      loader_b.set_mask(tile_dims_B, mask_B);

      loader_a.load_safe(mask_A);
      loader_b.load_safe(mask_B);

      threadgroup_barrier(mem_flags::mem_threadgroup);

      mma_op.mma(As, Bs);
    }

    // Store results to device memory
    mma_op.store_result(D, params->ldd, C, params->ldc, params->fdc, epilogue_op);
    return;

  }
  /////////////////////////////////////////////////////////////////////////////
  // MN unaligned loop
  else { // Loop over K - unaligned case
    short tgp_bm = min(BM, params->M - c_row);
    short tgp_bn = min(BN, params->N - c_col);
    short leftover_bk = params->K - params->gemm_k_iterations_aligned * BK;

    if (tgp_bm == BM && tgp_bn == BN) {
      gemm_kernel::gemm_loop(
          As,
          Bs,
          gemm_k_iterations,
          loader_a,
          loader_b,
          mma_op,
          tgp_bm,
          tgp_bn,
          leftover_bk,
          LoopAlignment<true, true, K_aligned>{});

      mma_op.store_result(D, params->ldd, C, params->ldc, params->fdc, epilogue_op);
      return;

    } else if (tgp_bn == BN) {
      gemm_kernel::gemm_loop(
          As,
          Bs,
          gemm_k_iterations,
          loader_a,
          loader_b,
          mma_op,
          tgp_bm,
          tgp_bn,
          leftover_bk,
          LoopAlignment<false, true, K_aligned>{});

      return mma_op.store_result_safe(
          D, params->ldd,
          C, params->ldc, params->fdc,
          short2(tgp_bn, tgp_bm),
          epilogue_op);

    } else if (tgp_bm == BM) {
      gemm_kernel::gemm_loop(
          As,
          Bs,
          gemm_k_iterations,
          loader_a,
          loader_b,
          mma_op,
          tgp_bm,
          tgp_bn,
          leftover_bk,
          LoopAlignment<true, false, K_aligned>{});

      return mma_op.store_result_safe(
          D, params->ldd,
          C, params->ldc, params->fdc,
          short2(tgp_bn, tgp_bm),
          epilogue_op);

    } else {
      gemm_kernel::gemm_loop(
          As,
          Bs,
          gemm_k_iterations,
          loader_a,
          loader_b,
          mma_op,
          tgp_bm,
          tgp_bn,
          leftover_bk,
          LoopAlignment<false, false, K_aligned>{});

      return mma_op.store_result_safe(
          D, params->ldd,
          C, params->ldc, params->fdc,
          short2(tgp_bn, tgp_bm),
          epilogue_op);
    }
  }
}

///////////////////////////////////////////////////////////////////////////////
// GEMM kernel initializations
///////////////////////////////////////////////////////////////////////////////

#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, ep_name, epilogue) \
  template [[host_name("steel_addmm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname "_K_" #kname "_" #ep_name)]] \
  [[kernel]] void addmm<itype, bm, bn, bk, wm, wn, trans_a, trans_b, mn_aligned, k_aligned, float, epilogue<itype, float>>( \
      const device itype *A [[buffer(0)]], \
      const device itype *B [[buffer(1)]], \
      const device itype *C [[buffer(2)]], \
      device itype *D [[buffer(3)]], \
      const constant GEMMAddMMParams* params [[buffer(4)]], \
      uint simd_lane_id [[thread_index_in_simdgroup]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]], \
      uint3 tid [[threadgroup_position_in_grid]], \
      uint3 lid [[thread_position_in_threadgroup]]);

#define instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
  instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, add, TransformAdd) \
  instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, axpby, TransformAxpby)

#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
  instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \
  instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \
  instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false)

#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)

#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
  instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
  instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \
  instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
  instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 16, 2, 2) \
  instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 2, 2)

instantiate_gemm_shapes_helper(float16, half, float16, half);
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);

instantiate_gemm_shapes_helper(float32, float, float32, float);
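At a high level, the addmm kernel above computes D = Epilogue(A @ B, C): D = A @ B + C under TransformAdd, and D = alpha * (A @ B) + beta * C under TransformAxpby. A naive scalar reference of those semantics, useful for checking the tiled kernel against (row-major, no transposes or batching; purely illustrative, not part of the commit):

#include <vector>

// Reference for the addmm epilogue semantics the Metal kernel above tiles.
void addmm_reference(
    const std::vector<float>& A, // M x K
    const std::vector<float>& B, // K x N
    const std::vector<float>& C, // M x N
    std::vector<float>& D,       // M x N
    int M, int N, int K,
    float alpha, float beta) {
  for (int i = 0; i < M; ++i) {
    for (int j = 0; j < N; ++j) {
      float acc = 0.0f; // accumulate in float, matching AccumType = float
      for (int k = 0; k < K; ++k) {
        acc += A[i * K + k] * B[k * N + j];
      }
      // TransformAxpby epilogue: D = alpha * (A @ B) + beta * C
      D[i * N + j] = alpha * acc + beta * C[i * N + j];
    }
  }
}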
@ -0,0 +1,280 @@
// Copyright © 2024 Apple Inc.

#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"

using namespace metal;
using namespace mlx::steel;

///////////////////////////////////////////////////////////////////////////////
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////

template <typename T,
          typename U,
          int BM,
          int BN,
          int BK,
          int WM,
          int WN,
          bool transpose_a,
          bool transpose_b,
          bool MN_aligned,
          bool K_aligned>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void gemm_splitk(
    const device T *A [[buffer(0)]],
    const device T *B [[buffer(1)]],
    device U *C [[buffer(2)]],
    const constant GEMMSpiltKParams* params [[buffer(3)]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]) {

  (void)lid;

  using gemm_kernel = GEMMKernel<T, U, BM, BN, BK, WM, WN, transpose_a, transpose_b, MN_aligned, K_aligned>;
  using loader_a_t = typename gemm_kernel::loader_a_t;
  using loader_b_t = typename gemm_kernel::loader_b_t;
  using mma_t = typename gemm_kernel::mma_t;

  threadgroup T As[gemm_kernel::tgp_mem_size_a];
  threadgroup T Bs[gemm_kernel::tgp_mem_size_b];

  const int tid_x = tid.x;
  const int tid_y = tid.y;
  const int tid_z = tid.z;

  if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
    return;
  }

  // Find block in A, B, C
  const int c_row = tid_y * BM;
  const int c_col = tid_x * BN;
  const int k_start = params->split_k_partition_size * tid_z;

  A += transpose_a ? (c_row + k_start * params->lda) : (k_start + c_row * params->lda);
  B += transpose_b ? (k_start + c_col * params->ldb) : (c_col + k_start * params->ldb);
  C += (params->split_k_partition_stride * tid_z) + (c_row * params->ldc + c_col);

  // Prepare threadgroup loading operations
  thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
  thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);

  // Prepare threadgroup mma operation
  thread mma_t mma_op(simd_group_id, simd_lane_id);

  int gemm_k_iterations = params->gemm_k_iterations_aligned;

  short tgp_bm = min(BM, params->M - c_row);
  short tgp_bn = min(BN, params->N - c_col);
  short leftover_bk = params->K % BK;

  if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
    gemm_kernel::gemm_loop(
        As,
        Bs,
        gemm_k_iterations,
        loader_a,
        loader_b,
        mma_op,
        tgp_bm,
        tgp_bn,
        leftover_bk,
        LoopAlignment<true, true, true>{});
  } else if (tgp_bn == BN) {
    gemm_kernel::gemm_loop(
        As,
        Bs,
        gemm_k_iterations,
        loader_a,
        loader_b,
        mma_op,
        tgp_bm,
        tgp_bn,
        leftover_bk,
        LoopAlignment<false, true, true>{});
  } else if (tgp_bm == BM) {
    gemm_kernel::gemm_loop(
        As,
        Bs,
        gemm_k_iterations,
        loader_a,
        loader_b,
        mma_op,
        tgp_bm,
        tgp_bn,
        leftover_bk,
        LoopAlignment<true, false, true>{});
  } else {
    gemm_kernel::gemm_loop(
        As,
        Bs,
        gemm_k_iterations,
        loader_a,
        loader_b,
        mma_op,
        tgp_bm,
        tgp_bn,
        leftover_bk,
        LoopAlignment<false, false, true>{});
  }

  threadgroup_barrier(mem_flags::mem_threadgroup);

  if ((tid_z + 1) == (params->split_k_partitions)) {
    int gemm_k_iter_remaining = (params->K - (k_start + params->split_k_partition_size)) / BK;
    if (!K_aligned || gemm_k_iter_remaining > 0)
      gemm_kernel::gemm_loop(
          As,
          Bs,
          gemm_k_iter_remaining,
          loader_a,
          loader_b,
          mma_op,
          tgp_bm,
          tgp_bn,
          leftover_bk,
          LoopAlignment<false, false, K_aligned>{});
  }

  if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
    mma_op.store_result(C, params->ldc);
  } else {
    mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm));
  }
}

///////////////////////////////////////////////////////////////////////////////
// GEMM kernel initializations
///////////////////////////////////////////////////////////////////////////////

#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
  template [[host_name("steel_gemm_splitk_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname "_K_" #kname)]] \
  [[kernel]] void gemm_splitk<itype, otype, bm, bn, bk, wm, wn, trans_a, trans_b, mn_aligned, k_aligned>( \
      const device itype *A [[buffer(0)]], \
      const device itype *B [[buffer(1)]], \
      device otype *C [[buffer(2)]], \
      const constant GEMMSpiltKParams* params [[buffer(3)]], \
      uint simd_lane_id [[thread_index_in_simdgroup]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]], \
      uint3 tid [[threadgroup_position_in_grid]], \
      uint3 lid [[thread_position_in_threadgroup]]);

#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
  instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \
  instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \
  instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false)

#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)

#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
  instantiate_gemm_transpose_helper(iname, itype, oname, otype, 16, 16, 16, 2, 2) \
  instantiate_gemm_transpose_helper(iname, itype, oname, otype, 16, 32, 16, 2, 2) \
  instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 16, 16, 2, 2) \
  instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2)

instantiate_gemm_shapes_helper(float16, half, float32, float);
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, float32, float);

instantiate_gemm_shapes_helper(float32, float, float32, float);

///////////////////////////////////////////////////////////////////////////////
// Split k accumulation kernel
///////////////////////////////////////////////////////////////////////////////

template <typename AccT,
          typename OutT,
          typename Epilogue = TransformNone<OutT, AccT>>
[[kernel]] void gemm_splitk_accum(
    const device AccT *C_split [[buffer(0)]],
    device OutT *D [[buffer(1)]],
    const constant int& k_partitions [[buffer(2)]],
    const constant int& partition_stride [[buffer(3)]],
    const constant int& ldd [[buffer(4)]],
    uint2 gid [[thread_position_in_grid]]) {

  // Adjust D and C
  D += gid.x + gid.y * ldd;
  C_split += gid.x + gid.y * ldd;

  int offset = 0;
  AccT out = 0;

  for (int i = 0; i < k_partitions; i++) {
    out += C_split[offset];
    offset += partition_stride;
  }

  // Write output
  D[0] = Epilogue::apply(out);
}

template <typename AccT,
          typename OutT,
          typename Epilogue = TransformAxpby<OutT, AccT>>
[[kernel]] void gemm_splitk_accum_axpby(
    const device AccT *C_split [[buffer(0)]],
    device OutT *D [[buffer(1)]],
    const constant int& k_partitions [[buffer(2)]],
    const constant int& partition_stride [[buffer(3)]],
    const constant int& ldd [[buffer(4)]],
    const device OutT *C [[buffer(5)]],
    const constant int& ldc [[buffer(6)]],
    const constant int& fdc [[buffer(7)]],
    const constant float& alpha [[buffer(8)]],
    const constant float& beta [[buffer(9)]],
    uint2 gid [[thread_position_in_grid]]) {

  // Adjust D and C
  C += gid.x * fdc + gid.y * ldc;
  D += gid.x + gid.y * ldd;
  C_split += gid.x + gid.y * ldd;

  int offset = 0;
  AccT out = 0;

  for (int i = 0; i < k_partitions; i++) {
    out += C_split[offset];
    offset += partition_stride;
  }

  // Write output
  Epilogue op(alpha, beta);
  D[0] = op.apply(out, *C);
}

#define instantiate_accum(oname, otype, aname, atype) \
  template [[host_name("steel_gemm_splitk_accum_" #oname "_" #aname)]] \
  [[kernel]] void gemm_splitk_accum<atype, otype>( \
      const device atype *C_split [[buffer(0)]], \
      device otype *D [[buffer(1)]], \
      const constant int& k_partitions [[buffer(2)]], \
      const constant int& partition_stride [[buffer(3)]], \
      const constant int& ldd [[buffer(4)]], \
      uint2 gid [[thread_position_in_grid]]); \
  template [[host_name("steel_gemm_splitk_accum_" #oname "_" #aname "_axpby")]] \
  [[kernel]] void gemm_splitk_accum_axpby<atype, otype>( \
      const device atype *C_split [[buffer(0)]], \
      device otype *D [[buffer(1)]], \
      const constant int& k_partitions [[buffer(2)]], \
      const constant int& partition_stride [[buffer(3)]], \
      const constant int& ldd [[buffer(4)]], \
      const device otype *C [[buffer(5)]], \
      const constant int& ldc [[buffer(6)]], \
      const constant int& fdc [[buffer(7)]], \
      const constant float& alpha [[buffer(8)]], \
      const constant float& beta [[buffer(9)]], \
      uint2 gid [[thread_position_in_grid]]);

instantiate_accum(bfloat16, bfloat16_t, float32, float);
instantiate_accum(float16, half, float32, float);
instantiate_accum(float32, float, float32, float);
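Split-K trades extra memory traffic for parallelism when the output has too few tiles to fill the GPU: each tid.z partition writes a partial product over its slice of K into a float32 scratch buffer, and the accumulation kernels above then reduce the partitions per output element. A scalar sketch of that reduction, assuming the contiguous (k_partitions, M, N) scratch layout the host dispatches (illustrative, not part of the commit):

#include <vector>

// Mirrors gemm_splitk_accum: each output element sums its partial results
// across k_partitions slabs of size M * N.
void splitk_accum_reference(
    const std::vector<float>& C_split, // (k_partitions, M, N), contiguous
    std::vector<float>& D,             // (M, N)
    int k_partitions, int M, int N) {
  const int partition_stride = M * N; // split_k_partition_stride on the host
  for (int y = 0; y < M; ++y) {
    for (int x = 0; x < N; ++x) {
      float out = 0.0f;
      int offset = y * N + x;
      for (int p = 0; p < k_partitions; ++p) {
        out += C_split[offset];
        offset += partition_stride;
      }
      D[y * N + x] = out;
    }
  }
}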
160  mlx/backend/metal/kernels/steel/gemm/loader.h  Normal file
@ -0,0 +1,160 @@
// Copyright © 2024 Apple Inc.

#pragma once

#include "mlx/backend/metal/kernels/steel/utils.h"

///////////////////////////////////////////////////////////////////////////////
// Loading helper
///////////////////////////////////////////////////////////////////////////////

namespace mlx {
namespace steel {

template <
    typename T,
    short BROWS,
    short BCOLS,
    short dst_ld,
    short reduction_dim,
    short tgp_size,
    short alignment = 1,
    short n_reads = (BCOLS * BROWS) / (tgp_size),
    short TCOLS = BCOLS / n_reads,
    short TROWS = tgp_size / TCOLS>
struct BlockLoader {
  STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS;
  STEEL_CONST short vec_size = n_reads;

  // Leading dimension for src
  const int src_ld;
  const int tile_stride;

  // Thread location indices
  const short thread_idx;
  const short bi;
  const short bj;

  // threadgroup and device memory
  threadgroup T* dst;
  const device T* src;

  struct alignas(alignment * sizeof(T)) ReadVector {
    uint8_t v[sizeof(T) * vec_size];
  };

  /* Constructor */
  METAL_FUNC BlockLoader(
      const device T* src_,
      const int src_ld_,
      threadgroup T* dst_,
      ushort simd_group_id [[simdgroup_index_in_threadgroup]],
      ushort simd_lane_id [[thread_index_in_simdgroup]])
      : src_ld(src_ld_),
        tile_stride(reduction_dim ? BCOLS : BROWS * src_ld),
        thread_idx(simd_group_id * 32 + simd_lane_id),
        bi(thread_idx / TCOLS),
        bj(vec_size * (thread_idx % TCOLS)),
        dst(dst_ + bi * dst_ld + bj),
        src(src_ + bi * src_ld + bj) {}

  /* Load from device memory into threadgroup memory - without bound checking */
  METAL_FUNC void load_unsafe() const {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < BROWS; i += TROWS) {
      *((threadgroup ReadVector*)(&dst[i * dst_ld])) =
          *((const device ReadVector*)(&src[i * src_ld]));
    }
  }

  /* Prepare a per-thread mask for a subsequent bound-checked load_safe */
  METAL_FUNC void set_mask(
      thread const short2& src_tile_dims,
      thread bool mask[n_rows][vec_size]) {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < n_rows; i++) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < vec_size; j++) {
        mask[i][j] =
            ((bi + i) < src_tile_dims.y) && ((bj + j) < src_tile_dims.x);
      }
    }
  }

  /* Load from device memory into threadgroup memory - with bound checking */
  METAL_FUNC void load_safe(short2 src_tile_dim) const {
    src_tile_dim = src_tile_dim - short2(bj, bi);

    // Use fast thread memory for bound checks
    bool tmp_idx[vec_size];
    T tmp_val[vec_size];

    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < BROWS; i += TROWS) {
      // Make sure tmp_idx only contains valid indices
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < vec_size; j++) {
        tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x);
      }

      // Read valid indices into tmp_val
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < vec_size; j++) {
        tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)];
      }

      // Zero out unneeded values
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < vec_size; j++) {
        tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0);
      }

      // Copy values to threadgroup memory
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < vec_size; j++) {
        dst[i * dst_ld + j] = tmp_val[j];
      }
    }
  }

  /* Load from device memory into threadgroup memory - with bound checking */
  METAL_FUNC void load_safe(const thread bool mask[n_rows][vec_size]) const {
    // Use fast thread memory for bound checks
    T tmp_val[vec_size];

    STEEL_PRAGMA_UNROLL
    for (short i = 0, ii = 0; i < BROWS; i += TROWS, ii++) {
      simdgroup_barrier(mem_flags::mem_none);

      // Read valid indices into tmp_val
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < vec_size; j++) {
        tmp_val[j] = src[(mask[ii][j] ? i * src_ld + j : 0)];
      }

      simdgroup_barrier(mem_flags::mem_none);

      // Zero out unneeded values
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < vec_size; j++) {
        tmp_val[j] = mask[ii][j] ? tmp_val[j] : T(0);
      }

      simdgroup_barrier(mem_flags::mem_none);

      // Copy values to threadgroup memory
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < vec_size; j++) {
        dst[i * dst_ld + j] = tmp_val[j];
      }
    }
  }

  /* Iteration helper */
  METAL_FUNC void next() {
    src += tile_stride;
  }
};

} // namespace steel
} // namespace mlx
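To make the template arithmetic concrete: with a hypothetical configuration of BROWS = 32, BCOLS = 16 and tgp_size = 128 (a 32x16 tile loaded by a 2x2-simdgroup threadgroup), the derived constants work out as below. The numbers are illustrative, not tied to a specific instantiation in this commit:

#include <cstdio>

// Worked example of BlockLoader's derived constants (hypothetical config).
int main() {
  const int BROWS = 32, BCOLS = 16, tgp_size = 128;
  const int n_reads = (BCOLS * BROWS) / tgp_size; // 4 elements per thread
  const int TCOLS = BCOLS / n_reads;              // 4 thread-columns
  const int TROWS = tgp_size / TCOLS;             // 32 thread-rows
  const int n_rows = (BROWS + TROWS - 1) / TROWS; // 1 row iteration per thread
  std::printf("n_reads=%d TCOLS=%d TROWS=%d n_rows=%d (vec_size == n_reads)\n",
              n_reads, TCOLS, TROWS, n_rows);
  // Each thread issues one vectorized 4-element read per tile, so a single
  // load_unsafe() covers the full 32x16 block with no thread left idle.
}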
264  mlx/backend/metal/kernels/steel/gemm/mma.h  Normal file
@ -0,0 +1,264 @@
// Copyright © 2024 Apple Inc.

#pragma once

#include "mlx/backend/metal/kernels/steel/gemm/transforms.h"
#include "mlx/backend/metal/kernels/steel/utils.h"

///////////////////////////////////////////////////////////////////////////////
// MMA helper
///////////////////////////////////////////////////////////////////////////////

namespace mlx {
namespace steel {

template <
    typename T,
    typename U,
    int BM,
    int BN,
    int BK,
    int WM,
    int WN,
    bool transpose_a,
    bool transpose_b,
    short lda_tgp,
    short ldb_tgp,
    typename AccumType = float,
    typename Epilogue = TransformNone<U, AccumType>>
struct BlockMMA {
  // Warp tile simdgroup matrix strides along M
  STEEL_CONST short TM_stride = 8 * WM;
  // Warp tile simdgroup matrix strides along N
  STEEL_CONST short TN_stride = 8 * WN;

  // Warp tile size along M
  STEEL_CONST short TM = BM / TM_stride;
  // Warp tile size along N
  STEEL_CONST short TN = BN / TN_stride;

  // Strides of A, B along reduction axis
  STEEL_CONST short simd_stride_a = {
      transpose_a ? TM_stride : TM_stride * lda_tgp};
  STEEL_CONST short simd_stride_b = {
      transpose_b ? TN_stride * ldb_tgp : TN_stride};

  // Jump between elements
  STEEL_CONST short jump_a = {transpose_a ? lda_tgp : 1};
  STEEL_CONST short jump_b = {transpose_b ? ldb_tgp : 1};

  STEEL_CONST short tile_stride_a = {transpose_a ? 8 * lda_tgp : 8};
  STEEL_CONST short tile_stride_b = {transpose_b ? 8 : 8 * ldb_tgp};

  // Simdgroup matrices
  simdgroup_matrix<AccumType, 8, 8> Asimd[TM];
  simdgroup_matrix<AccumType, 8, 8> Bsimd[TN];
  simdgroup_matrix<AccumType, 8, 8> results[TM * TN] = {
      simdgroup_matrix<AccumType, 8, 8>(0)};

  // Offsets within threadgroup
  const short tm;
  const short tn;

  short sm;
  short sn;

  short As_offset;
  short Bs_offset;

  /* Constructor */
  METAL_FUNC BlockMMA(
      ushort simd_group_id [[simdgroup_index_in_threadgroup]],
      ushort simd_lane_id [[thread_index_in_simdgroup]])
      : tm(8 * (simd_group_id / WN)), tn(8 * (simd_group_id % WN)) {
    // Determine thread position in simdgroup matrix
    short qid = simd_lane_id / 4;
    sm = (qid & 4) + (simd_lane_id / 2) % 4;
    sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;

    // Determine thread and simdgroup offset
    As_offset =
        transpose_a ? ((sn)*lda_tgp + (tm + sm)) : ((sn) + (tm + sm) * lda_tgp);
    Bs_offset =
        transpose_b ? ((tn + sn) * ldb_tgp + (sm)) : ((sm)*ldb_tgp + (tn + sn));
  }

  /* (BM, BK) X (BK, BN) multiply accumulate function */
  METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) {
    // Adjust for simdgroup and thread location
    As += As_offset;
    Bs += Bs_offset;

    // Iterate over BK in blocks of 8
    STEEL_PRAGMA_UNROLL
    for (short kk = 0; kk < BK; kk += 8) {
      simdgroup_barrier(mem_flags::mem_none);

      // Load elements from threadgroup A as simdgroup matrices
      STEEL_PRAGMA_UNROLL
      for (short i = 0; i < TM; i++) {
        Asimd[i].thread_elements()[0] =
            static_cast<AccumType>(As[i * simd_stride_a + 0]);
        Asimd[i].thread_elements()[1] =
            static_cast<AccumType>(As[i * simd_stride_a + jump_a]);
      }

      simdgroup_barrier(mem_flags::mem_none);

      // Load elements from threadgroup B as simdgroup matrices
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < TN; j++) {
        Bsimd[j].thread_elements()[0] =
            static_cast<AccumType>(Bs[j * simd_stride_b + 0]);
        Bsimd[j].thread_elements()[1] =
            static_cast<AccumType>(Bs[j * simd_stride_b + jump_b]);
      }

      simdgroup_barrier(mem_flags::mem_none);

      // Multiply and accumulate into result simdgroup matrices
      STEEL_PRAGMA_UNROLL
      for (short i = 0; i < TM; i++) {
        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < TN; j++) {
          short j_serp = (i % 2) ? (TN - 1 - j) : j;

          simdgroup_multiply_accumulate(
              results[i * TN + j_serp],
              Asimd[i],
              Bsimd[j_serp],
              results[i * TN + j_serp]);
        }
      }

      // Progress to next simdgroup tile
      As += tile_stride_a;
      Bs += tile_stride_b;
    }
  }

  /* Store results from simdgroup_matrix results into device memory */
  METAL_FUNC void store_result(device U* C, const int ldc) const {
    // Adjust for simdgroup and thread location
    C += (sm + tm) * ldc + tn + sn;

    // Loop over all simdgroup tiles
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < TM; i++) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < TN; j++) {
        // Get accumulated result and associated offset in C
        thread const auto& accum = results[i * TN + j].thread_elements();
        int offset = (i * TM_stride) * ldc + (j * TN_stride);

        // Apply epilogue
        U outs[2] = {Epilogue::apply(accum[0]), Epilogue::apply(accum[1])};

        // Write out C
        C[offset] = outs[0];
        C[offset + 1] = outs[1];
      }
    }
  }

  METAL_FUNC void
  store_result_safe(device U* C, const int ldc, short2 dst_tile_dims) const {
    // Adjust for simdgroup and thread location
    C += (sm + tm) * ldc + (tn + sn);
    dst_tile_dims -= short2(tn + sn, sm + tm);

    STEEL_PRAGMA_UNROLL
    for (int i = 0; i < TM; i++) {
      if (i * TM_stride < dst_tile_dims.y) {
        STEEL_PRAGMA_UNROLL
        for (int j = 0; j < TN; j++) {
          // Get accumulated result and associated offset in C
          thread const auto& accum = results[i * TN + j].thread_elements();
          int offset = (i * TM_stride) * ldc + (j * TN_stride);

          // Apply epilogue and output C
          if (j * TN_stride < dst_tile_dims.x) {
            C[offset] = Epilogue::apply(accum[0]);
          }

          if (j * TN_stride + 1 < dst_tile_dims.x) {
            C[offset + 1] = Epilogue::apply(accum[1]);
          }
        }
      }
    }
  }

  /* Store results from simdgroup_matrix results into device memory */
  METAL_FUNC void store_result(
      device U* D,
      const int ldd,
      const device U* C,
      const int ldc,
      const int fdc,
      thread const Epilogue& epilogue_op) const {
    // Adjust for simdgroup and thread location
    C += (sm + tm) * ldc + (tn + sn) * fdc;
    D += (sm + tm) * ldd + tn + sn;

    // Loop over all simdgroup tiles
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < TM; i++) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < TN; j++) {
        // Get accumulated result and associated offset in C
        thread const auto& accum = results[i * TN + j].thread_elements();
        int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
        int offset_d = (i * TM_stride) * ldd + (j * TN_stride);

        // Apply epilogue
        U outs[2] = {
            epilogue_op.apply(accum[0], C[offset_c]),
            epilogue_op.apply(accum[1], C[offset_c + fdc])};

        // Write out D
        D[offset_d] = outs[0];
        D[offset_d + 1] = outs[1];
      }
    }
  }

  METAL_FUNC void store_result_safe(
      device U* D,
      const int ldd,
      const device U* C,
      const int ldc,
      const int fdc,
      short2 dst_tile_dims,
      thread const Epilogue& epilogue_op) const {
    // Adjust for simdgroup and thread location
    C += (sm + tm) * ldc + (tn + sn) * fdc;
    D += (sm + tm) * ldd + tn + sn;
    dst_tile_dims -= short2(tn + sn, sm + tm);

    STEEL_PRAGMA_UNROLL
    for (int i = 0; i < TM; i++) {
      if (i * TM_stride < dst_tile_dims.y) {
        STEEL_PRAGMA_UNROLL
        for (int j = 0; j < TN; j++) {
          // Get accumulated result and associated offset in C
          thread const auto& accum = results[i * TN + j].thread_elements();
          int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
          int offset_d = (i * TM_stride) * ldd + (j * TN_stride);

          // Apply epilogue and output D
          if (j * TN_stride < dst_tile_dims.x) {
            D[offset_d] = epilogue_op.apply(accum[0], C[offset_c]);
          }

          if (j * TN_stride + 1 < dst_tile_dims.x) {
            D[offset_d + 1] = epilogue_op.apply(accum[1], C[offset_c + fdc]);
          }
        }
      }
    }
  }
};

} // namespace steel
} // namespace mlx
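Again to make the tiling arithmetic concrete: for a hypothetical BM = BN = 64 block with WM = WN = 2 simdgroups (one of the shapes instantiated in this commit), the constants resolve as below. Illustrative only:

#include <cstdio>

// Worked example of BlockMMA's tiling constants.
int main() {
  const int BM = 64, BN = 64, WM = 2, WN = 2;
  const int TM_stride = 8 * WM;  // 16: row stride between a simdgroup's 8x8 tiles
  const int TN_stride = 8 * WN;  // 16: column stride between its 8x8 tiles
  const int TM = BM / TM_stride; // 4 tiles along M per simdgroup
  const int TN = BN / TN_stride; // 4 tiles along N per simdgroup
  std::printf("TM=%d TN=%d -> each simdgroup owns a %dx%d grid of 8x8 fragments\n",
              TM, TN, TM, TN);
  // 4x4 fragments of 8x8 cover a 32x32 patch per simdgroup; the 2x2
  // simdgroups together tile the full 64x64 threadgroup block.
}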
79  mlx/backend/metal/kernels/steel/gemm/params.h  Normal file
@ -0,0 +1,79 @@
// Copyright © 2024 Apple Inc.

#pragma once

///////////////////////////////////////////////////////////////////////////////
// GEMM param classes
///////////////////////////////////////////////////////////////////////////////

namespace mlx {
namespace steel {

struct GEMMParams {
  const int M;
  const int N;
  const int K;

  const int lda;
  const int ldb;
  const int ldc;

  const int tiles_n;
  const int tiles_m;

  const int batch_stride_a;
  const int batch_stride_b;
  const int batch_stride_c;

  const int swizzle_log;
  const int gemm_k_iterations_aligned;
};

struct GEMMSpiltKParams {
  const int M;
  const int N;
  const int K;

  const int lda;
  const int ldb;
  const int ldc;

  const int tiles_n;
  const int tiles_m;

  const int split_k_partitions;
  const int split_k_partition_stride;
  const int split_k_partition_size;

  const int gemm_k_iterations_aligned;
};

struct GEMMAddMMParams {
  const int M;
  const int N;
  const int K;

  const int lda;
  const int ldb;
  const int ldc;
  const int ldd;

  const int tiles_n;
  const int tiles_m;

  const int batch_stride_a;
  const int batch_stride_b;
  const int batch_stride_c;
  const int batch_stride_d;

  const int swizzle_log;
  const int gemm_k_iterations_aligned;

  const float alpha;
  const float beta;

  const int fdc;
};

} // namespace steel
} // namespace mlx
63  mlx/backend/metal/kernels/steel/gemm/transforms.h  Normal file
@ -0,0 +1,63 @@
// Copyright © 2024 Apple Inc.

#pragma once

#include "mlx/backend/metal/kernels/steel/utils.h"

///////////////////////////////////////////////////////////////////////////////
// Transforms and Epilogues
///////////////////////////////////////////////////////////////////////////////

namespace mlx {
namespace steel {

template <typename OutT, typename InT>
struct TransformNone {
  static METAL_FUNC OutT apply(InT x) {
    return static_cast<OutT>(x);
  }

  static METAL_FUNC OutT apply(InT x, OutT) {
    return static_cast<OutT>(x);
  }
};

template <typename OutT, typename InT>
struct TransformAdd {
  TransformAdd(const float, const float) {}

  static METAL_FUNC OutT apply(InT x, OutT c) {
    return static_cast<OutT>(x) + c;
  }
};

template <typename OutT, typename InT>
struct TransformAxpby {
  const float alpha;
  const float beta;

  TransformAxpby(const float alpha_, const float beta_)
      : alpha(alpha_), beta(beta_) {}

  METAL_FUNC OutT apply(InT x, OutT c) const {
    return static_cast<OutT>(x * alpha + (beta * c));
  }
};

template <typename T>
struct AccumHelper {
  typedef float accum_type;
};

struct BlockSwizzle {
  static METAL_FUNC int2
  swizzle(uint3 tid [[threadgroup_position_in_grid]], const int swizzle_log) {
    const int tid_x = (tid.x) >> swizzle_log;
    const int tid_y =
        ((tid.y) << swizzle_log) + ((tid.x) & ((1 << swizzle_log) - 1));
    return int2(tid_x, tid_y);
  }
};

} // namespace steel
} // namespace mlx
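These transforms are the only per-element hooks the GEMM kernels apply between the accumulator and device memory. A plain C++ analogue of applying one of them to a single element, useful for sanity-checking expected outputs (illustrative, not part of the commit):

#include <cstdio>

// C++ analogue of the Metal TransformAxpby struct above, for one element.
template <typename OutT, typename InT>
struct TransformAxpbyRef {
  float alpha, beta;
  OutT apply(InT x, OutT c) const {
    return static_cast<OutT>(x * alpha + beta * c);
  }
};

int main() {
  // One accumulated GEMM element (x) plus its bias element (c).
  TransformAxpbyRef<float, float> op{2.0f, 0.5f};
  std::printf("%f\n", op.apply(3.0f, 4.0f)); // 2*3 + 0.5*4 = 8
}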
5  mlx/backend/metal/kernels/steel/host.h  Normal file
@ -0,0 +1,5 @@
// Copyright © 2024 Apple Inc.

#pragma once

#include "mlx/backend/metal/kernels/steel/gemm/params.h"
9  mlx/backend/metal/kernels/steel/utils.h  Normal file
@ -0,0 +1,9 @@
// Copyright © 2024 Apple Inc.

#pragma once

#include <metal_stdlib>
#include "mlx/backend/metal/kernels/steel/host.h"

#define STEEL_CONST static constant constexpr const
#define STEEL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)")
@ -8,6 +8,7 @@
 #include "mlx/backend/metal/copy.h"
 #include "mlx/backend/metal/device.h"
 #include "mlx/backend/metal/kernels/defines.h"
+#include "mlx/backend/metal/kernels/steel/host.h"
 #include "mlx/backend/metal/matmul.h"
 #include "mlx/backend/metal/mps/gemm.h"
 #include "mlx/backend/metal/utils.h"
@ -16,6 +17,10 @@
 namespace mlx::core {
 
+///////////////////////////////////////////////////////////////////////////////
+// MPS Matmul fallback
+///////////////////////////////////////////////////////////////////////////////
+
 namespace {
 
 bool use_mps() {
@ -46,7 +51,9 @@ inline void mps_matmul(
     int ldb,
     bool transpose_a,
     bool transpose_b,
-    std::vector<array>& copies) {
+    std::vector<array>& copies,
+    float alpha = 1.0f,
+    float beta = 0.0f) {
   MPS::DataType mps_dtype = MPS::DataTypeFloat32;
 
   if (out.dtype() == float16) {
@ -121,7 +128,7 @@ inline void mps_matmul(
   auto out_mat = MPS::Matrix::alloc()->init(out_buf, out_desc);
 
   auto kernel = MPS::MatrixMultiplication::alloc()->init(
-      d.mtl_device(), transpose_a, transpose_b, M, N, K, 1.0, 0.0);
+      d.mtl_device(), transpose_a, transpose_b, M, N, K, alpha, beta);
 
   auto command_buffer = d.get_command_buffer(s.index);
   kernel->setBatchSize(batch_size_out);
@ -162,7 +169,7 @@ inline void mps_matmul(
   auto out_mat = MPS::Matrix::alloc()->init(out_buf, out_desc);
 
   auto kernel = MPS::MatrixMultiplication::alloc()->init(
-      d.mtl_device(), transpose_a, transpose_b, M, N, K, 1.0, 0.0);
+      d.mtl_device(), transpose_a, transpose_b, M, N, K, alpha, beta);
 
   auto command_buffer = d.get_command_buffer(s.index);
   for (int i = 0; i < batch_size_out; ++i) {
@ -186,7 +193,11 @@ inline void mps_matmul(
 
 } // namespace
 
-void mlx_matmul(
+///////////////////////////////////////////////////////////////////////////////
+// Steel matmul fallback
+///////////////////////////////////////////////////////////////////////////////
+
+void steel_matmul(
     const Stream& s,
     metal::Device& d,
     const array& a,
@ -201,6 +212,15 @@ void mlx_matmul(
     bool transpose_a,
     bool transpose_b,
     std::vector<array>& copies) {
+  using namespace mlx::steel;
+
+  // Coalesce (B, M, K) X (K, N) to (B*M, K) X (K, N)
+  if (batch_size_out > 1 && !transpose_a &&
+      a.data_size() == batch_size_out * M * K && b.size() == K * N) {
+    M = M * batch_size_out;
+    batch_size_out = 1;
+  }
+
   // Account for batch sizes and basic broadcasting
   int batch_size_a = a.data_size() / (M * K);
   int batch_size_b = b.data_size() / (K * N);
@ -209,11 +229,108 @@ void mlx_matmul(
   int matrix_stride_b = (batch_size_b == 1) ? 0 : K * N;
   int matrix_stride_out = M * N;
 
+  /////////////////////////////////////////////////////////////////////////////
+  // Split K specialization
+
+  int _tm = M / 16;
+  int _tn = N / 16;
+  int _tk = K / 16;
+
+  if (batch_size_out == 1 && (_tm * _tn) <= 32 && _tk >= 8) {
+    int bm = M < 40 ? 16 : 32;
+    int bn = N < 40 ? 16 : 32;
+    int bk = 16;
+    int wm = 2, wn = 2;
+
+    int split_k_partitions =
+        _tk < 16 ? 2 : (_tk < 32 ? 4 : (_tk < 64 ? 8 : 16));
+    int split_k_partition_stride = M * N;
+    int gemm_k_iterations = (K / bk) / split_k_partitions;
+    int split_k_partition_size = gemm_k_iterations * bk;
+
+    array C_split({split_k_partitions, M, N}, float32, nullptr, {});
+    C_split.set_data(allocator::malloc_or_wait(C_split.nbytes()));
+    copies.push_back(C_split);
+
+    std::ostringstream kname;
+    kname << "steel_gemm_splitk_" << (transpose_a ? 't' : 'n')
+          << (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
+          << type_to_name(C_split) << "_bm" << bm << "_bn" << bn << "_bk" << bk
+          << "_wm" << wm << "_wn" << wn << "_MN_"
+          << ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned"
+          << "_K_" << ((K % bk == 0) ? "t" : "n") << "aligned";
+
+    // Encode and dispatch gemm kernel
+    auto compute_encoder = d.get_command_encoder(s.index);
+    auto kernel = d.get_kernel(kname.str());
+    compute_encoder->setComputePipelineState(kernel);
+
+    int tn = (N + bn - 1) / bn;
+    int tm = (M + bm - 1) / bm;
+
+    GEMMSpiltKParams params{
+        M,
+        N,
+        K,
+        lda,
+        ldb,
+        N,
+        tn,
+        tm,
+        split_k_partitions,
+        split_k_partition_stride,
+        split_k_partition_size,
+        gemm_k_iterations};
+
+    MTL::Size group_dims = MTL::Size(32, wn, wm);
+    MTL::Size grid_dims = MTL::Size(tn, tm, split_k_partitions);
+
+    set_array_buffer(compute_encoder, a, 0);
+    set_array_buffer(compute_encoder, b, 1);
+    set_array_buffer(compute_encoder, C_split, 2);
+
+    compute_encoder->setBytes(&params, sizeof(GEMMSpiltKParams), 3);
+    compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+
+    // Do accum kernel
+    {
+      auto c_split_buf =
+          static_cast<const MTL::Resource*>(C_split.buffer().ptr());
+      const class MTL::Resource* const resources[1] = {c_split_buf};
+      compute_encoder->memoryBarrier(resources, 1);
+
+      auto kernel = d.get_kernel(
+          "steel_gemm_splitk_accum_" + type_to_name(out) + "_" +
+          type_to_name(C_split));
+      compute_encoder->setComputePipelineState(kernel);
+
+      // Set the arguments for the kernel
+      set_array_buffer(compute_encoder, C_split, 0);
+      set_array_buffer(compute_encoder, out, 1);
+      compute_encoder->setBytes(&split_k_partitions, sizeof(int), 2);
+      compute_encoder->setBytes(&split_k_partition_stride, sizeof(int), 3);
+      compute_encoder->setBytes(&N, sizeof(int), 4);
+
+      // Launch enough thread groups for each output
+      MTL::Size grid_dims = MTL::Size(N, M, 1);
+      MTL::Size group_dims = MTL::Size(std::min(1024, N * M), 1, 1);
+
+      compute_encoder->dispatchThreads(grid_dims, group_dims);
+    }
+
+    d.get_command_buffer(s.index)->addCompletedHandler(
+        [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
+    return;
+  }
+
+  /////////////////////////////////////////////////////////////////////////////
+  // Regular kernel dispatch
+
   // Determine dispatch kernel
   int bm = 32, bn = 32, bk = 16;
   int wm = 2, wn = 2;
 
-  if ((size_t)batch_size_out * M * N >= 2ul << 20) {
+  if ((size_t)batch_size_out * M * N >= 1ul << 20) {
     if (!transpose_a && transpose_b) {
       bm = 64;
       bn = (out.dtype() == float32) ? 64 : 32;
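A concrete reading of the dispatch heuristic above, traced for one illustrative shape (an unbatched 64 x 64 output with K = 4096): few output tiles plus a deep reduction dimension routes the problem to split-K.

#include <cstdio>

// Tracing the split-K dispatch heuristic for M = N = 64, K = 4096.
int main() {
  int M = 64, N = 64, K = 4096;
  int _tm = M / 16, _tn = N / 16, _tk = K / 16;     // 4, 4, 256
  bool use_split_k = (_tm * _tn) <= 32 && _tk >= 8; // 16 tiles, deep K -> true
  int bm = M < 40 ? 16 : 32;                        // 32
  int bn = N < 40 ? 16 : 32;                        // 32
  int bk = 16;
  int split_k_partitions =
      _tk < 16 ? 2 : (_tk < 32 ? 4 : (_tk < 64 ? 8 : 16)); // 16
  int gemm_k_iterations = (K / bk) / split_k_partitions;   // 16
  std::printf("split_k=%d partitions=%d iters/partition=%d (bm=%d bn=%d)\n",
              use_split_k, split_k_partitions, gemm_k_iterations, bm, bn);
}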
@ -224,10 +341,12 @@ void mlx_matmul(
     }
   }
 
+  // Prepare kernel name
   std::ostringstream kname;
-  kname << "gemm_" << (transpose_a ? 't' : 'n') << (transpose_b ? 't' : 'n')
-        << "_" << type_to_name(a) << "_" << type_to_name(out) << "_bm" << bm
-        << "_bn" << bn << "_bk" << bk << "_wm" << wm << "_wn" << wn << "_MN_"
+  kname << "steel_gemm_" << (transpose_a ? 't' : 'n')
+        << (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
+        << type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk
+        << "_wm" << wm << "_wn" << wn << "_MN_"
         << ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned"
         << "_K_" << ((K % bk == 0) ? "t" : "n") << "aligned";
@ -236,34 +355,55 @@ void mlx_matmul(
   auto kernel = d.get_kernel(kname.str());
   compute_encoder->setComputePipelineState(kernel);
 
+  // Use problem size to determine threadblock swizzle
+  int tn = (N + bn - 1) / bn;
+  int tm = (M + bm - 1) / bm;
+
+  // TODO: Explore device-based tuning for swizzle
+  int swizzle_log = 0; // tm >= 6 ? 3 : (tm <= 3 ? 0 : 2);
+
+  // Prepare steel matmul params
+  GEMMParams params{
+      M,
+      N,
+      K,
+      lda,
+      ldb,
+      N,
+      tn,
+      tm,
+      matrix_stride_a,
+      matrix_stride_b,
+      matrix_stride_out,
+      swizzle_log,
+      (K / bk)};
+
+  // Prepare launch grid params
+  int tile = 1 << swizzle_log;
+  tm = (tm + tile - 1) / tile;
+  tn = tn * tile;
+
+  MTL::Size group_dims = MTL::Size(32, wn, wm);
+  MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out);
+
   // Launch only 1 kernel in the case of simple batching / broadcasting
   if (batch_size_out == std::max(batch_size_a, batch_size_b) &&
       (batch_size_a == batch_size_b ||
        std::min(batch_size_a, batch_size_b) == 1)) {
-    MTL::Size group_dims = MTL::Size(32, wn, wm);
-    MTL::Size grid_dims =
-        MTL::Size((N + bn - 1) / bn, (M + bm - 1) / bm, batch_size_out);
-
     set_array_buffer(compute_encoder, a, 0);
     set_array_buffer(compute_encoder, b, 1);
     set_array_buffer(compute_encoder, out, 2);
 
-    compute_encoder->setBytes(&M, sizeof(int), 3);
-    compute_encoder->setBytes(&N, sizeof(int), 4);
-    compute_encoder->setBytes(&K, sizeof(int), 5);
-    compute_encoder->setBytes(&matrix_stride_a, sizeof(int), 6);
-    compute_encoder->setBytes(&matrix_stride_b, sizeof(int), 7);
-    compute_encoder->setBytes(&matrix_stride_out, sizeof(int), 8);
+    compute_encoder->setBytes(&params, sizeof(GEMMParams), 3);
     compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
-  } else { // Other launch kernels with set offsets
+  } else { // Otherwise launch kernels with set offsets
+
+    MTL::Size grid_dims_single = MTL::Size(tn, tm, 1);
+
     for (int i = 0; i < batch_size_out; ++i) {
       auto a_off = elem_to_loc(M * K * i, a.shape(), a.strides());
       auto b_off = elem_to_loc(K * N * i, b.shape(), b.strides());
 
-      MTL::Size group_dims = MTL::Size(32, wn, wm);
-      MTL::Size grid_dims = MTL::Size((N + bn - 1) / bn, (M + bm - 1) / bm, 1);
-
       auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
       auto b_buf = static_cast<const MTL::Buffer*>(b.buffer().ptr());
       auto out_buf = static_cast<const MTL::Buffer*>(out.buffer().ptr());
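For reference, the swizzle remap used above (implemented by BlockSwizzle in transforms.h and inlined in the kernels) interleaves the threadgroup launch order so that neighboring threadgroups work on nearby output tiles and reuse more of A and B from cache. A trace of the remap, assuming swizzle_log = 2 purely for illustration (the commit dispatches with swizzle_log = 0):

#include <cstdio>

// Trace of the BlockSwizzle remap: with swizzle_log = 2, four consecutive
// launch-order x values map to four consecutive tile rows in one tile column.
int main() {
  const int swizzle_log = 2;
  for (int x = 0; x < 8; ++x) {
    int y = 0; // launched tid.y
    int tid_x = x >> swizzle_log;
    int tid_y = (y << swizzle_log) + (x & ((1 << swizzle_log) - 1));
    std::printf("launch x=%d -> tile (m=%d, n=%d)\n", x, tid_y, tid_x);
  }
}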
@ -272,13 +412,8 @@ void mlx_matmul(
|
|||||||
compute_encoder->setBuffer(b_buf, b_off * b.itemsize(), 1);
|
compute_encoder->setBuffer(b_buf, b_off * b.itemsize(), 1);
|
||||||
compute_encoder->setBuffer(out_buf, i * M * N * out.itemsize(), 2);
|
compute_encoder->setBuffer(out_buf, i * M * N * out.itemsize(), 2);
|
||||||
|
|
||||||
compute_encoder->setBytes(&M, sizeof(int), 3);
|
compute_encoder->setBytes(¶ms, sizeof(GEMMParams), 3);
|
||||||
compute_encoder->setBytes(&N, sizeof(int), 4);
|
compute_encoder->dispatchThreadgroups(grid_dims_single, group_dims);
|
||||||
compute_encoder->setBytes(&K, sizeof(int), 5);
|
|
||||||
compute_encoder->setBytes(&matrix_stride_a, sizeof(int), 6);
|
|
||||||
compute_encoder->setBytes(&matrix_stride_b, sizeof(int), 7);
|
|
||||||
compute_encoder->setBytes(&matrix_stride_out, sizeof(int), 8);
|
|
||||||
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
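As an aside for readers tracing the dispatch logic: the branch condition above reduces to a small predicate. A Python sketch (illustrative only; can_use_single_launch is a hypothetical helper, not MLX API):

# Illustrative sketch of the batching check above (hypothetical helper,
# not part of the MLX API).
def can_use_single_launch(batch_a: int, batch_b: int, batch_out: int) -> bool:
    # One launch works when the output batch matches the larger input batch
    # and the inputs either agree or one of them is broadcast (batch == 1).
    return batch_out == max(batch_a, batch_b) and (
        batch_a == batch_b or min(batch_a, batch_b) == 1
    )

assert can_use_single_launch(32, 32, 32)    # equal batches
assert can_use_single_launch(32, 1, 32)     # b broadcast over a's batch
assert not can_use_single_launch(4, 6, 12)  # general broadcast: per-batch offsets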
@@ -300,6 +435,9 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
   auto& a_pre = inputs[0];
   auto& b_pre = inputs[1];

+  /////////////////////////////////////////////////////////////////////////////
+  // Init checks and prep

   // Keep a vector with copies to be cleared in the completed buffer to release
   // the arrays
   std::vector<array> copies;
@@ -328,6 +466,9 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {

   auto batch_size_out = out.size() / (M * N);

+  /////////////////////////////////////////////////////////////////////////////
+  // Gemv specialization

   // Route to gemv if needed
   if (std::min(M, N) == 1) {
     // Collect problem info
@@ -433,10 +574,13 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
     return;
   }

-  d.end_encoding(s.index);
+  /////////////////////////////////////////////////////////////////////////////
+  // Gemm specialization

   if (use_mps()) {
-    mps_matmul(
+    d.end_encoding(s.index);
+
+    return mps_matmul(
       s,
       d,
       a,
@@ -451,10 +595,9 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
       a_transposed,
       b_transposed,
       copies);
-    return;
   }

-  mlx_matmul(
+  return steel_matmul(
     s,
     d,
     a,
@@ -471,4 +614,266 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
     copies);
 }
+void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 3);
+  if (!is_floating_point(out.dtype())) {
+    throw std::runtime_error(
+        "[matmul] Does not yet support non-floating point types.");
+  }
+  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  auto& s = stream();
+  auto& d = metal::device(s.device);
+
+  auto& a_pre = inputs[0];
+  auto& b_pre = inputs[1];
+  auto& c_pre = inputs[2];
+
+  /////////////////////////////////////////////////////////////////////////////
+  // Init checks and prep
+
+  // Keep a vector with copies to be cleared in the completed buffer to release
+  // the arrays
+  std::vector<array> copies;
+  auto check_transpose = [&copies, &s](const array& arr) {
+    auto stx = arr.strides()[arr.ndim() - 2];
+    auto sty = arr.strides()[arr.ndim() - 1];
+    if (stx == arr.shape(-1) && sty == 1) {
+      return std::make_tuple(false, stx, arr);
+    } else if (stx == 1 && sty == arr.shape(-2)) {
+      return std::make_tuple(true, sty, arr);
+    } else {
+      array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
+      copy_gpu(arr, arr_copy, CopyType::General, s);
+      copies.push_back(arr_copy);
+      size_t stx = arr.shape(-1);
+      return std::make_tuple(false, stx, arr_copy);
+    }
+  };
+
+  auto [transpose_a, a_cols, a] = check_transpose(a_pre);
+  auto [transpose_b, b_cols, b] = check_transpose(b_pre);
+
+  int M = a.shape(-2);
+  int N = b.shape(-1);
+  int K = a.shape(-1);
+
+  auto batch_size_out = out.size() / (M * N);
+
+  array c = c_pre;
+  int ldc = c.strides()[c.ndim() - 2];
+  int fdc = c.strides()[c.ndim() - 1];
+  int matrix_stride_c = c.ndim() <= 2 ? 0 : c.strides()[c.ndim() - 3];
+
+  int lda = a_cols;
+  int ldb = b_cols;
+
+  using namespace mlx::steel;
+
+  // Account for batch sizes and basic broadcasting
+  int batch_size_a = a.data_size() / (M * K);
+  int batch_size_b = b.data_size() / (K * N);
+
+  int matrix_stride_a = (batch_size_a == 1) ? 0 : M * K;
+  int matrix_stride_b = (batch_size_b == 1) ? 0 : K * N;
+  int matrix_stride_out = M * N;
+
+  int _tm = M / 16;
+  int _tn = N / 16;
+  int _tk = K / 16;
+
+  /////////////////////////////////////////////////////////////////////////////
+  // Split K specialization
+
+  if (batch_size_out == 1 && (_tm * _tn) <= 32 && _tk >= 8) {
+    int bm = M < 40 ? 16 : 32;
+    int bn = N < 40 ? 16 : 32;
+    int bk = 16;
+    int wm = 2, wn = 2;
+
+    int split_k_partitions =
+        _tk < 16 ? 2 : (_tk < 32 ? 4 : (_tk < 64 ? 8 : 16));
+    int split_k_partition_stride = M * N;
+    int gemm_k_iterations = (K / bk) / split_k_partitions;
+    int split_k_partition_size = gemm_k_iterations * bk;
+
+    array C_split({split_k_partitions, M, N}, float32, nullptr, {});
+    C_split.set_data(allocator::malloc_or_wait(C_split.nbytes()));
+    copies.push_back(C_split);
+
+    std::ostringstream kname;
+    kname << "steel_gemm_splitk_" << (transpose_a ? 't' : 'n')
+          << (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
+          << type_to_name(C_split) << "_bm" << bm << "_bn" << bn << "_bk" << bk
+          << "_wm" << wm << "_wn" << wn << "_MN_"
+          << ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned"
+          << "_K_" << ((K % bk == 0) ? "t" : "n") << "aligned";
+
+    // Encode and dispatch gemm kernel
+    auto compute_encoder = d.get_command_encoder(s.index);
+    auto kernel = d.get_kernel(kname.str());
+    compute_encoder->setComputePipelineState(kernel);
+
+    int tn = (N + bn - 1) / bn;
+    int tm = (M + bm - 1) / bm;
+
+    GEMMSpiltKParams params{
+        M,
+        N,
+        K,
+        lda,
+        ldb,
+        N,
+        tn,
+        tm,
+        split_k_partitions,
+        split_k_partition_stride,
+        split_k_partition_size,
+        gemm_k_iterations};
+
+    MTL::Size group_dims = MTL::Size(32, wn, wm);
+    MTL::Size grid_dims = MTL::Size(tn, tm, split_k_partitions);
+
+    set_array_buffer(compute_encoder, a, 0);
+    set_array_buffer(compute_encoder, b, 1);
+    set_array_buffer(compute_encoder, C_split, 2);
+
+    compute_encoder->setBytes(&params, sizeof(GEMMSpiltKParams), 3);
+    compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+
+    // Do accum kernel
+    {
+      auto kernel = d.get_kernel(
+          "steel_gemm_splitk_accum_" + type_to_name(out) + "_" +
+          type_to_name(C_split) + "_axpby");
+      compute_encoder->setComputePipelineState(kernel);
+
+      // Set the arguments for the kernel
+      set_array_buffer(compute_encoder, C_split, 0);
+      set_array_buffer(compute_encoder, out, 1);
+      compute_encoder->setBytes(&split_k_partitions, sizeof(int), 2);
+      compute_encoder->setBytes(&split_k_partition_stride, sizeof(int), 3);
+      compute_encoder->setBytes(&N, sizeof(int), 4);
+      set_array_buffer(compute_encoder, c, 5);
+      compute_encoder->setBytes(&ldc, sizeof(int), 6);
+      compute_encoder->setBytes(&fdc, sizeof(int), 7);
+      compute_encoder->setBytes(&alpha_, sizeof(float), 8);
+      compute_encoder->setBytes(&beta_, sizeof(float), 9);
+
+      // Launch enough thread groups for each output
+      MTL::Size grid_dims = MTL::Size(N, M, 1);
+      MTL::Size group_dims = MTL::Size(std::min(1024, N * M), 1, 1);
+
+      compute_encoder->dispatchThreads(grid_dims, group_dims);
+    }
+
+    d.get_command_buffer(s.index)->addCompletedHandler(
+        [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
+    return;
+  }
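A rough model of the split-K path above: K is carved into split_k_partitions slices, each slice accumulates a partial M x N product into its own slab of C_split, and the accum kernel reduces the slabs and applies the alpha/beta (axpby) epilogue with c. A NumPy sketch under those assumptions (illustrative, not the Metal kernel):

import numpy as np

def split_k_addmm(c, a, b, alpha, beta, partitions=4):
    # Partition the reduction (K) dimension, as the split-k kernel does.
    M, K = a.shape
    ks = np.array_split(np.arange(K), partitions)
    # Each partition computes a partial M x N product into its own slab.
    c_split = np.stack([a[:, k] @ b[k, :] for k in ks])
    # The accum kernel then reduces the slabs and applies the axpby epilogue.
    return alpha * c_split.sum(axis=0) + beta * c

a = np.random.randn(64, 4096)
b = np.random.randn(4096, 32)
c = np.ones((64, 32))
ref = 0.5 * (a @ b) + 2.0 * c
assert np.allclose(split_k_addmm(c, a, b, 0.5, 2.0), ref)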
+  /////////////////////////////////////////////////////////////////////////////
+  // Regular addmm dispatch
+
+  // Determine dispatch kernel
+  int bm = 32, bn = 32, bk = 16;
+  int wm = 2, wn = 2;
+
+  if ((size_t)batch_size_out * M * N >= 1ul << 20) {
+    if (!transpose_a && transpose_b) {
+      bm = 64;
+      bn = (out.dtype() == float32) ? 64 : 32;
+      bk = (out.dtype() == float32) ? 16 : 32;
+    } else {
+      bm = 64;
+      bn = 64;
+    }
+  }
+
+  std::ostringstream kname;
+  kname << "steel_addmm_" << (transpose_a ? 't' : 'n')
+        << (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
+        << type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk
+        << "_wm" << wm << "_wn" << wn << "_MN_"
+        << ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned"
+        << "_K_" << ((K % bk == 0) ? "t" : "n") << "aligned"
+        << ((alpha_ == 1. && beta_ == 1.) ? "_add" : "_axpby");
+
+  // Encode and dispatch kernel
+  auto compute_encoder = d.get_command_encoder(s.index);
+  auto kernel = d.get_kernel(kname.str());
+  compute_encoder->setComputePipelineState(kernel);
+
+  int tn = (N + bn - 1) / bn;
+  int tm = (M + bm - 1) / bm;
+
+  // TODO: Explore device-based tuning for swizzle
+  int swizzle_log = 0; // tm >= 6 ? 3 : (tm <= 3 ? 0 : 2);
+
+  GEMMAddMMParams params{
+      M,
+      N,
+      K,
+      lda,
+      ldb,
+      ldc,
+      N,
+      tn,
+      tm,
+      matrix_stride_a,
+      matrix_stride_b,
+      matrix_stride_c,
+      matrix_stride_out,
+      swizzle_log,
+      (K / bk),
+      alpha_,
+      beta_,
+      fdc};
+
+  int tile = 1 << swizzle_log;
+  tm = (tm + tile - 1) / tile;
+  tn = tn * tile;
+
+  MTL::Size group_dims = MTL::Size(32, wn, wm);
+  MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out);
+
+  // Launch only 1 kernel in the case of simple batching / broadcasting
+  if (batch_size_out == std::max(batch_size_a, batch_size_b) &&
+      (batch_size_a == batch_size_b ||
+       std::min(batch_size_a, batch_size_b) == 1)) {
+    set_array_buffer(compute_encoder, a, 0);
+    set_array_buffer(compute_encoder, b, 1);
+    set_array_buffer(compute_encoder, c, 2);
+    set_array_buffer(compute_encoder, out, 3);
+
+    compute_encoder->setBytes(&params, sizeof(GEMMAddMMParams), 4);
+    compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+  } else { // Otherwise launch kernels with set offsets
+
+    MTL::Size grid_dims_single = MTL::Size(tn, tm, 1);
+
+    for (int i = 0; i < batch_size_out; ++i) {
+      auto a_off = elem_to_loc(M * K * i, a.shape(), a.strides());
+      auto b_off = elem_to_loc(K * N * i, b.shape(), b.strides());
+      auto c_off = elem_to_loc(M * N * i, c.shape(), c.strides());
+
+      auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
+      auto b_buf = static_cast<const MTL::Buffer*>(b.buffer().ptr());
+      auto c_buf = static_cast<const MTL::Buffer*>(c.buffer().ptr());
+      auto out_buf = static_cast<const MTL::Buffer*>(out.buffer().ptr());
+
+      compute_encoder->setBuffer(a_buf, a_off * a.itemsize(), 0);
+      compute_encoder->setBuffer(b_buf, b_off * b.itemsize(), 1);
+      compute_encoder->setBuffer(c_buf, c_off * c.itemsize(), 2);
+      compute_encoder->setBuffer(out_buf, i * M * N * out.itemsize(), 3);
+
+      compute_encoder->setBytes(&params, sizeof(GEMMAddMMParams), 4);
+      compute_encoder->dispatchThreadgroups(grid_dims_single, group_dims);
+    }
+  }
+
+  d.get_command_buffer(s.index)->addCompletedHandler(
+      [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
+  return;
+}
+
 } // namespace mlx::core
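The kernel name built above selects a precompiled template specialization; the transpose flags, tile sizes, alignment, and epilogue are all baked into the string. A hedged Python sketch of the naming scheme (simplified to a single dtype slot for both inputs and output; the C++ above is authoritative):

def addmm_kernel_name(ta, tb, dtype, bm, bn, bk, wm, wn, M, N, K, alpha, beta):
    # Mirrors the ostringstream above: transpose flags, dtypes, tile sizes,
    # alignment of M/N and K with the tiles, and add vs. axpby epilogue.
    mn_aligned = "t" if (M % bm == 0 and N % bn == 0) else "n"
    k_aligned = "t" if K % bk == 0 else "n"
    epilogue = "_add" if (alpha == 1.0 and beta == 1.0) else "_axpby"
    return (
        f"steel_addmm_{'t' if ta else 'n'}{'t' if tb else 'n'}_{dtype}_{dtype}"
        f"_bm{bm}_bn{bn}_bk{bk}_wm{wm}_wn{wn}"
        f"_MN_{mn_aligned}aligned_K_{k_aligned}aligned{epilogue}"
    )

# e.g. "steel_addmm_nt_float32_float32_bm64_bn64_bk16_wm2_wn2_MN_taligned_K_taligned_axpby"
print(addmm_kernel_name(False, True, "float32", 64, 64, 16, 2, 2, 1024, 1024, 1024, 0.5, 2.0))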
@@ -12,7 +12,7 @@
 namespace mlx::core {

-void mlx_matmul(
+void steel_matmul(
     const Stream& s,
     metal::Device& d,
     const array& a,
@@ -17,6 +17,7 @@ namespace mlx::core {

 NO_GPU(Abs)
 NO_GPU(Add)
+NO_GPU(AddMM)
 NO_GPU(Arange)
 NO_GPU(ArcCos)
 NO_GPU(ArcCosh)
mlx/ops.cpp (+94)
@@ -3057,4 +3057,98 @@ array inner(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   return tensordot(a, b, {{-1}, {-1}}, s);
 }

+/** Compute D = beta * C + alpha * (A @ B) */
+array addmm(
+    array c,
+    array a,
+    array b,
+    const float& alpha /* = 1.f */,
+    const float& beta /* = 1.f */,
+    StreamOrDevice s /* = {} */) {
+  // Divert in the case of vector-matrix multiplication
+  // TODO: Add the needed specialization
+  if (a.ndim() == 1 || b.ndim() == 1) {
+    array X = matmul(a, b, s);
+    array alpha_arr = array(alpha, X.dtype());
+    array aX = multiply(alpha_arr, X, s);
+
+    array beta_arr = array(beta, c.dtype());
+    array bY = multiply(beta_arr, c, s);
+    return add(aX, bY, s);
+  }
+
+  if (a.ndim() == 0 || b.ndim() == 0) {
+    throw std::invalid_argument(
+        "[addmm] Got 0 dimension input. Inputs must "
+        "have at least one dimension.");
+  }
+
+  if (a.shape(-1) != b.shape(-2)) {
+    std::ostringstream msg;
+    msg << "[addmm] Last dimension of first input with shape " << a.shape()
+        << " must match second to last dimension of"
+        << " second input with shape " << b.shape() << ".";
+    throw std::invalid_argument(msg.str());
+  }
+
+  // Type promotion
+  auto out_type = result_type({a, b, c});
+  if (!is_floating_point(out_type) || is_complex(out_type)) {
+    std::ostringstream msg;
+    msg << "[addmm] Only real floating point types are supported but "
+        << c.dtype() << ", " << a.dtype() << " and " << b.dtype()
+        << " were provided which results in " << out_type
+        << ", which is not a real floating point type.";
+    throw std::invalid_argument(msg.str());
+  }
+
+  a = astype(a, out_type, s);
+  b = astype(b, out_type, s);
+  c = astype(c, out_type, s);
+
+  // We can batch the multiplication by reshaping a
+  if (a.ndim() > 2 && b.ndim() == 2 && c.ndim() <= 1) {
+    std::vector<int> out_shape = a.shape();
+    a = reshape(a, {-1, out_shape.back()}, s);
+    out_shape.back() = b.shape(-1);
+    c = broadcast_to(c, {a.shape(0), b.shape(1)}, s);
+    auto out = array(
+        {a.shape(0), b.shape(1)},
+        out_type,
+        std::make_unique<AddMM>(to_stream(s), alpha, beta),
+        {a, b, c});
+    return reshape(out, out_shape, s);
+  }
+
+  if (a.ndim() > 2 || b.ndim() > 2) {
+    std::vector<int> bsx_a(a.shape().begin(), a.shape().end() - 2);
+    std::vector<int> bsx_b(b.shape().begin(), b.shape().end() - 2);
+    auto inner_shape = broadcast_shapes(bsx_a, bsx_b);
+
+    // Broadcast a
+    inner_shape.push_back(a.shape(-2));
+    inner_shape.push_back(a.shape(-1));
+    a = broadcast_to(a, inner_shape, s);
+
+    // Broadcast b
+    *(inner_shape.end() - 2) = b.shape(-2);
+    *(inner_shape.end() - 1) = b.shape(-1);
+    b = broadcast_to(b, inner_shape, s);
+  }
+
+  auto out_shape = a.shape();
+  out_shape.back() = b.shape(-1);
+
+  auto c_broadcast_shape = broadcast_shapes(c.shape(), out_shape);
+  c = broadcast_to(c, c_broadcast_shape, s);
+
+  auto out = array(
+      out_shape,
+      out_type,
+      std::make_unique<AddMM>(to_stream(s), alpha, beta),
+      {a, b, c});
+
+  return out;
+}
+
 } // namespace mlx::core
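Note the divert path above: a 1-D a or b never reaches the steel kernels and instead decomposes into matmul, two scalar multiplies, and an add. A short illustration of the resulting semantics, assuming the Python bindings added later in this diff:

import mlx.core as mx

a = mx.random.normal((32, 128, 16))
b = mx.random.normal((16,))        # 1-D: takes the divert path
c = mx.ones((32, 128))

out = mx.addmm(c, a, b, 0.5, 2.0)
ref = 0.5 * (a @ b) + 2.0 * c      # same decomposition the op falls back to
assert mx.allclose(out, ref, atol=1e-5).item()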
@@ -1122,4 +1122,12 @@ std::unordered_map<std::string, array> load_gguf(

 void save_gguf(std::string file, std::unordered_map<std::string, array> a);

+/** Compute D = beta * C + alpha * (A @ B) */
+array addmm(
+    array c,
+    array a,
+    array b,
+    const float& alpha = 1.f,
+    const float& beta = 1.f,
+    StreamOrDevice s = {});
 } // namespace mlx::core
@@ -124,6 +124,52 @@ std::pair<std::vector<array>, std::vector<int>> Add::vmap(
   return {{add(a, b, stream())}, {to_ax}};
 }

+std::vector<array> AddMM::vjp(
+    const std::vector<array>& primals,
+    const std::vector<array>& cotangents,
+    const std::vector<int>& argnums,
+    const std::vector<array>&) {
+  std::vector<array> vjps;
+  auto& cotan = cotangents[0];
+  std::vector<int> reorder(cotan.ndim());
+  std::iota(reorder.begin(), reorder.end(), 0);
+  std::iter_swap(reorder.end() - 1, reorder.end() - 2);
+  for (auto arg : argnums) {
+    if (arg == 0) {
+      // M X N * (K X N).T -> M X K
+      auto cotan_scaled = cotan;
+      if (alpha_ != 1.) {
+        auto alpha_arr = array(alpha_, cotan.dtype());
+        cotan_scaled = (multiply(alpha_arr, cotan_scaled, stream()));
+      }
+      vjps.push_back(matmul(
+          cotan_scaled, transpose(primals[1], reorder, stream()), stream()));
+    } else if (arg == 1) {
+      // (M X K).T * M X N -> K X N
+      auto cotan_scaled = cotan;
+      if (alpha_ != 1.) {
+        auto alpha_arr = array(alpha_, cotan.dtype());
+        cotan_scaled = (multiply(alpha_arr, cotan_scaled, stream()));
+      }
+      vjps.push_back(matmul(
+          transpose(primals[0], reorder, stream()), cotan_scaled, stream()));
+    } else {
+      auto cotan_scaled = cotan;
+      if (beta_ != 1.) {
+        auto beta_arr = array(beta_, cotan.dtype());
+        cotan_scaled = (multiply(beta_arr, cotan_scaled, stream()));
+      }
+      vjps.push_back(cotan_scaled);
+    }
+  }
+  return vjps;
+}
+
+bool AddMM::is_equivalent(const Primitive& other) const {
+  const AddMM& a_other = static_cast<const AddMM&>(other);
+  return (alpha_ == a_other.alpha_ && beta_ == a_other.beta_);
+}
+
 bool Arange::is_equivalent(const Primitive& other) const {
   const Arange& a_other = static_cast<const Arange&>(other);
   return (
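In words, the vjp above implements dA = alpha * cotan @ B^T, dB = alpha * A^T @ cotan, and dC = beta * cotan. A quick NumPy sanity check of the dA rule against a central finite difference (illustrative):

import numpy as np

alpha, beta = 2.0, 0.5
a = np.random.randn(8, 5)
b = np.random.randn(5, 7)
c = np.random.randn(8, 7)
cotan = np.random.randn(8, 7)

# The rules encoded in AddMM::vjp above:
da = alpha * cotan @ b.T   # (M x N) * (K x N)^T -> M x K
db = alpha * a.T @ cotan   # (M x K)^T * (M x N) -> K x N
dc = beta * cotan

# Check da against a central finite difference in one coordinate.
eps = 1e-5
f = lambda a_: np.sum(cotan * (alpha * (a_ @ b) + beta * c))
e = np.zeros_like(a)
e[0, 0] = eps
num = (f(a + e) - f(a - e)) / (2 * eps)
assert np.isclose(num, da[0, 0], atol=1e-6)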
@@ -171,6 +171,29 @@ class Add : public UnaryPrimitive {
   void eval(const std::vector<array>& inputs, array& out);
 };

+class AddMM : public UnaryPrimitive {
+ public:
+  explicit AddMM(Stream stream, float alpha, float beta)
+      : UnaryPrimitive(stream), alpha_(alpha), beta_(beta){};
+
+  void eval_cpu(const std::vector<array>& inputs, array& out) override;
+  void eval_gpu(const std::vector<array>& inputs, array& out) override;
+
+  std::vector<array> vjp(
+      const std::vector<array>& primals,
+      const std::vector<array>& cotangents,
+      const std::vector<int>& argnums,
+      const std::vector<array>& outputs) override;
+
+  DEFINE_PRINT(AddMM)
+
+  bool is_equivalent(const Primitive& other) const override;
+
+ private:
+  const float alpha_;
+  const float beta_;
+};
+
 class Arange : public UnaryPrimitive {
  public:
  explicit Arange(Stream stream, double start, double stop, double step)
@@ -63,9 +63,10 @@ class Linear(Module):
         return f"input_dims={self.weight.shape[1]}, output_dims={self.weight.shape[0]}, bias={'bias' in self}"

     def __call__(self, x: mx.array) -> mx.array:
-        x = x @ self.weight.T
         if "bias" in self:
-            x = x + self.bias
+            x = mx.addmm(self.bias, x, self.weight.T)
+        else:
+            x = x @ self.weight.T
         return x
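The updated layer is numerically equivalent to the unfused form; the benefit is a single fused kernel instead of a matmul followed by a broadcast add. A quick equivalence check (illustrative):

import mlx.core as mx

x = mx.random.normal((64, 256))
w = mx.random.normal((128, 256))   # (output_dims, input_dims), as in Linear
b = mx.random.normal((128,))

fused = mx.addmm(b, x, w.T)        # what Linear.__call__ now does
unfused = x @ w.T + b              # what it did before
assert mx.allclose(fused, unfused, atol=1e-5).item()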
@@ -3476,4 +3476,34 @@ void init_ops(py::module_& m) {
       Returns:
           result (array): The tiled array.
   )pbdoc");
+  m.def(
+      "addmm",
+      &addmm,
+      "c"_a,
+      "a"_a,
+      "b"_a,
+      py::pos_only(),
+      "alpha"_a = 1.0f,
+      "beta"_a = 1.0f,
+      py::kw_only(),
+      "stream"_a = none,
+      R"pbdoc(
+        addmm(c: array, a: array, b: array, /, alpha: float = 1.0, beta: float = 1.0, *, stream: Union[None, Stream, Device] = None) -> array
+
+        Matrix multiplication with addition and optional scaling.
+
+        Perform the (possibly batched) matrix multiplication of two arrays and add to the result
+        with optional scaling factors.
+
+        Args:
+            c (array): Input array or scalar.
+            a (array): Input array or scalar.
+            b (array): Input array or scalar.
+            alpha (float, optional): Scaling factor for the
+                matrix product of ``a`` and ``b`` (default: ``1``)
+            beta (float, optional): Scaling factor for ``c`` (default: ``1``)
+
+        Returns:
+            array: ``alpha * (a @ b) + beta * c``
+      )pbdoc");
 }
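For completeness, a short usage sketch of the new binding with explicit scale factors (c, a, b are positional-only per the signature above):

import mlx.core as mx

a = mx.random.normal((64, 4096))
b = mx.random.normal((4096, 32))
c = mx.zeros((64, 32))

# alpha scales a @ b, beta scales c: out = 0.5 * (a @ b) + 2.0 * c
out = mx.addmm(c, a, b, 0.5, 2.0)
print(out.shape)  # (64, 32)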
@@ -74,6 +74,7 @@ class TestBlas(mlx_tests.MLXTestCase):
         if mx.default_device() == mx.gpu:
             shapes += [
                 (16, 768, 768, 128),
+                (1, 64, 64, 4096),
             ]

         for dtype in self.dtypes:
@@ -444,3 +445,139 @@ class TestBlas(mlx_tests.MLXTestCase):
                 list(c_npy.shape), list(c_mlx.shape)
             )
             self.assertTrue(np.array_equal(c_mlx, c_npy))
+
+    def test_addmm(self):
+        np.random.seed(0)
+        # Batched matmul
+        alpha = 0.5
+        beta = 2.0
+
+        # Regular batched case
+        a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
+        b_npy = np.random.normal(0.0, 1.0 / 128, (32, 16, 16)).astype(np.float32)
+
+        a_mlx = mx.array(a_npy)
+        b_mlx = mx.array(b_npy)
+
+        for c_shape in ((1,), (1, 16), (32, 1, 16), (1, 128, 16)):
+            c_npy = np.ones(c_shape).astype(np.float32)
+            c_mlx = mx.array(c_npy)
+
+            d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
+            d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
+
+            self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
+            self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
+
+        # Batched and transposed matmul
+        b_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
+        b_mlx = mx.array(b_npy)
+
+        for c_shape in ((1,), (32, 1, 128), (1, 128)):
+            c_npy = np.ones(c_shape).astype(np.float32)
+            c_mlx = mx.array(c_npy)
+
+            b_np_t = np.transpose(b_npy, (0, 2, 1))
+            b_mx_t = mx.transpose(b_mlx, (0, 2, 1))
+
+            d_npy = alpha * (a_npy @ b_np_t) + beta * c_npy
+            d_mlx = mx.addmm(c_mlx, a_mlx, b_mx_t, alpha, beta)
+
+            self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
+            self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
+
+        # Batched matmul with simple broadcast
+        a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
+        b_npy = np.random.normal(0.0, 1.0 / 128, (16, 16)).astype(np.float32)
+
+        a_mlx = mx.array(a_npy)
+        b_mlx = mx.array(b_npy)
+
+        for c_shape in ((1,), (1, 16), (32, 1, 16), (1, 128, 16)):
+            c_npy = np.ones(c_shape).astype(np.float32)
+            c_mlx = mx.array(c_npy)
+
+            d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
+            d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
+
+            self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
+            self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
+
+        # Matmul with vector
+        a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
+        b_npy = np.random.normal(0.0, 1.0 / 128, (16,)).astype(np.float32)
+        a_mlx = mx.array(a_npy)
+        b_mlx = mx.array(b_npy)
+
+        for c_shape in (
+            (1,),
+            (32, 128),
+        ):
+            c_npy = np.ones(c_shape).astype(np.float32)
+            c_mlx = mx.array(c_npy)
+
+            d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
+            d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
+
+            self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
+            self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
+
+        # Split K specialization
+        a_npy = np.random.normal(0.0, 1.0 / 128, (64, 4096)).astype(np.float32)
+        b_npy = np.random.normal(0.0, 1.0 / 128, (4096, 32)).astype(np.float32)
+
+        a_mlx = mx.array(a_npy)
+        b_mlx = mx.array(b_npy)
+
+        for c_shape in ((1,), (1, 32), (64, 1), (64, 32)):
+            c_npy = np.ones(c_shape).astype(np.float32)
+            c_mlx = mx.array(c_npy)
+
+            d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
+            d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
+
+            self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
+            self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
+
+    def test_addmm_grad(self):
+        def make_ref_addmm(alpha, beta):
+            return lambda c, a, b: alpha * (a @ b) + beta * c
+
+        def make_addmm(alpha, beta):
+            return lambda c, a, b: mx.addmm(c, a, b, alpha, beta)
+
+        # B, M, N, K
+        shapes = ((1, 64, 32, 128), (4, 28, 24, 47), (1, 1, 24, 47))
+
+        alpha = 2.0
+        beta = 0.5
+
+        f_test = make_addmm(alpha, beta)
+        f_ref = make_ref_addmm(alpha, beta)
+
+        for B, M, N, K in shapes:
+            cotan = mx.ones((B, M, N))
+            c = mx.random.normal((B, M, N))
+            a = mx.random.normal((B, M, K))
+            b = mx.random.normal((B, K, N))
+
+            out_ref, dout_ref = mx.vjp(
+                f_ref,
+                [c, a, b],
+                [
+                    cotan,
+                ],
+            )
+            out_test, dout_test = mx.vjp(
+                f_test,
+                [c, a, b],
+                [
+                    cotan,
+                ],
+            )
+
+            self.assertTrue(mx.allclose(out_ref[0], out_test[0], atol=1e-5).item())
+
+            for r, t in zip(dout_ref, dout_test):
+                self.assertListEqual(r.shape, t.shape)
+                self.assertTrue(mx.allclose(r, t, atol=1e-5).item())