Mirror of https://github.com/ml-explore/mlx.git (synced 2025-12-16 01:49:05 +08:00)
Update GEMM (#424)
* Organize and collect metal subroutine templates and elements in `metal/kernels/steel/`
* Update gemm elements for better performance
* Add split-K specialization for gemm
* Add `addmm` primitive, op and bindings for fused matmul and bias addition
* Update tests and benchmarks as needed
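For context on the last bullet: the fused addmm computes alpha * (a @ b) + beta * c in one pass instead of a matmul followed by a separate bias add. A minimal standalone reference of that semantic (plain C++, illustrative only; the names and layout here are assumptions, not MLX code):

#include <vector>

// Reference semantics for the fused addmm: D = alpha * (A @ B) + beta * C.
// A is MxK, B is KxN, C and D are MxN, all row-major.
void addmm_reference(
    const std::vector<float>& A,
    const std::vector<float>& B,
    const std::vector<float>& C,
    std::vector<float>& D,
    int M, int N, int K,
    float alpha, float beta) {
  for (int i = 0; i < M; ++i) {
    for (int j = 0; j < N; ++j) {
      float acc = 0.0f;
      for (int k = 0; k < K; ++k) {
        acc += A[i * K + k] * B[k * N + j];
      }
      D[i * N + j] = alpha * acc + beta * C[i * N + j];
    }
  }
}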
@@ -8,6 +8,7 @@
 #include "mlx/backend/metal/copy.h"
 #include "mlx/backend/metal/device.h"
 #include "mlx/backend/metal/kernels/defines.h"
+#include "mlx/backend/metal/kernels/steel/host.h"
 #include "mlx/backend/metal/matmul.h"
 #include "mlx/backend/metal/mps/gemm.h"
 #include "mlx/backend/metal/utils.h"
@@ -16,6 +17,10 @@
 
 namespace mlx::core {
 
+///////////////////////////////////////////////////////////////////////////////
+// MPS Matmul fallback
+///////////////////////////////////////////////////////////////////////////////
+
 namespace {
 
 bool use_mps() {
@@ -46,7 +51,9 @@ inline void mps_matmul(
     int ldb,
     bool transpose_a,
     bool transpose_b,
-    std::vector<array>& copies) {
+    std::vector<array>& copies,
+    float alpha = 1.0f,
+    float beta = 0.0f) {
   MPS::DataType mps_dtype = MPS::DataTypeFloat32;
 
   if (out.dtype() == float16) {
@@ -121,7 +128,7 @@ inline void mps_matmul(
   auto out_mat = MPS::Matrix::alloc()->init(out_buf, out_desc);
 
   auto kernel = MPS::MatrixMultiplication::alloc()->init(
-      d.mtl_device(), transpose_a, transpose_b, M, N, K, 1.0, 0.0);
+      d.mtl_device(), transpose_a, transpose_b, M, N, K, alpha, beta);
 
   auto command_buffer = d.get_command_buffer(s.index);
   kernel->setBatchSize(batch_size_out);
@@ -162,7 +169,7 @@ inline void mps_matmul(
   auto out_mat = MPS::Matrix::alloc()->init(out_buf, out_desc);
 
   auto kernel = MPS::MatrixMultiplication::alloc()->init(
-      d.mtl_device(), transpose_a, transpose_b, M, N, K, 1.0, 0.0);
+      d.mtl_device(), transpose_a, transpose_b, M, N, K, alpha, beta);
 
   auto command_buffer = d.get_command_buffer(s.index);
   for (int i = 0; i < batch_size_out; ++i) {
@@ -186,7 +193,11 @@ inline void mps_matmul(
 
 } // namespace
 
-void mlx_matmul(
+///////////////////////////////////////////////////////////////////////////////
+// Steel matmul fallback
+///////////////////////////////////////////////////////////////////////////////
+
+void steel_matmul(
     const Stream& s,
     metal::Device& d,
     const array& a,
@@ -201,6 +212,15 @@ void mlx_matmul(
     bool transpose_a,
     bool transpose_b,
     std::vector<array>& copies) {
+  using namespace mlx::steel;
+
+  // Coalesce (B, M, K) X (K, N) to (B*M, K) X (K, N)
+  if (batch_size_out > 1 && !transpose_a &&
+      a.data_size() == batch_size_out * M * K && b.size() == K * N) {
+    M = M * batch_size_out;
+    batch_size_out = 1;
+  }
+
   // Account for batch sizes and basic broadcasting
   int batch_size_a = a.data_size() / (M * K);
   int batch_size_b = b.data_size() / (K * N);
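A note on the coalescing branch added above: when a is a fully packed (B, M, K) batch (its data size equals B * M * K, so there is no broadcasting) and b is a single shared (K, N) matrix, the whole batch can be computed as one (B*M, K) x (K, N) GEMM, giving the kernel a larger M to tile over. A standalone sketch of the same shape folding with made-up sizes (not MLX code):

#include <cstdio>

int main() {
  // Hypothetical problem: a is a packed (B, M, K) batch, b is one shared (K, N) matrix.
  int B = 8, M = 64, K = 128, N = 256;
  long a_data_size = (long)B * M * K; // stands in for a.data_size(): packed, not broadcast
  long b_size = (long)K * N;          // stands in for b.size(): a single matrix

  bool transpose_a = false;
  if (B > 1 && !transpose_a && a_data_size == (long)B * M * K && b_size == (long)K * N) {
    M = M * B; // fold the batch into the M dimension
    B = 1;
  }
  printf("dispatch one GEMM: M=%d N=%d K=%d batch=%d\n", M, N, K, B); // M=512, batch=1
  return 0;
}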
@@ -209,11 +229,108 @@ void mlx_matmul(
   int matrix_stride_b = (batch_size_b == 1) ? 0 : K * N;
   int matrix_stride_out = M * N;
 
+  /////////////////////////////////////////////////////////////////////////////
+  // Split K specialization
+
+  int _tm = M / 16;
+  int _tn = N / 16;
+  int _tk = K / 16;
+
+  if (batch_size_out == 1 && (_tm * _tn) <= 32 && _tk >= 8) {
+    int bm = M < 40 ? 16 : 32;
+    int bn = N < 40 ? 16 : 32;
+    int bk = 16;
+    int wm = 2, wn = 2;
+
+    int split_k_partitions =
+        _tk < 16 ? 2 : (_tk < 32 ? 4 : (_tk < 64 ? 8 : 16));
+    int split_k_partition_stride = M * N;
+    int gemm_k_iterations = (K / bk) / split_k_partitions;
+    int split_k_partition_size = gemm_k_iterations * bk;
+
+    array C_split({split_k_partitions, M, N}, float32, nullptr, {});
+    C_split.set_data(allocator::malloc_or_wait(C_split.nbytes()));
+    copies.push_back(C_split);
+
+    std::ostringstream kname;
+    kname << "steel_gemm_splitk_" << (transpose_a ? 't' : 'n')
+          << (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
+          << type_to_name(C_split) << "_bm" << bm << "_bn" << bn << "_bk" << bk
+          << "_wm" << wm << "_wn" << wn << "_MN_"
+          << ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned"
+          << "_K_" << ((K % bk == 0) ? "t" : "n") << "aligned";
+
+    // Encode and dispatch gemm kernel
+    auto compute_encoder = d.get_command_encoder(s.index);
+    auto kernel = d.get_kernel(kname.str());
+    compute_encoder->setComputePipelineState(kernel);
+
+    int tn = (N + bn - 1) / bn;
+    int tm = (M + bm - 1) / bm;
+
+    GEMMSpiltKParams params{
+        M,
+        N,
+        K,
+        lda,
+        ldb,
+        N,
+        tn,
+        tm,
+        split_k_partitions,
+        split_k_partition_stride,
+        split_k_partition_size,
+        gemm_k_iterations};
+
+    MTL::Size group_dims = MTL::Size(32, wn, wm);
+    MTL::Size grid_dims = MTL::Size(tn, tm, split_k_partitions);
+
+    set_array_buffer(compute_encoder, a, 0);
+    set_array_buffer(compute_encoder, b, 1);
+    set_array_buffer(compute_encoder, C_split, 2);
+
+    compute_encoder->setBytes(&params, sizeof(GEMMSpiltKParams), 3);
+    compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+
+    // Do accum kernel
+    {
+      auto c_split_buf =
+          static_cast<const MTL::Resource*>(C_split.buffer().ptr());
+      const class MTL::Resource* const resources[1] = {c_split_buf};
+      compute_encoder->memoryBarrier(resources, 1);
+
+      auto kernel = d.get_kernel(
+          "steel_gemm_splitk_accum_" + type_to_name(out) + "_" +
+          type_to_name(C_split));
+      compute_encoder->setComputePipelineState(kernel);
+
+      // Set the arguments for the kernel
+      set_array_buffer(compute_encoder, C_split, 0);
+      set_array_buffer(compute_encoder, out, 1);
+      compute_encoder->setBytes(&split_k_partitions, sizeof(int), 2);
+      compute_encoder->setBytes(&split_k_partition_stride, sizeof(int), 3);
+      compute_encoder->setBytes(&N, sizeof(int), 4);
+
+      // Launch enough thread groups for each output
+      MTL::Size grid_dims = MTL::Size(N, M, 1);
+      MTL::Size group_dims = MTL::Size(std::min(1024, N * M), 1, 1);
+
+      compute_encoder->dispatchThreads(grid_dims, group_dims);
+    }
+
+    d.get_command_buffer(s.index)->addCompletedHandler(
+        [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
+    return;
+  }
+
+  /////////////////////////////////////////////////////////////////////////////
+  // Regular kernel dispatch
+
   // Determine dispatch kernel
   int bm = 32, bn = 32, bk = 16;
   int wm = 2, wn = 2;
 
-  if ((size_t)batch_size_out * M * N >= 2ul << 20) {
+  if ((size_t)batch_size_out * M * N >= 1ul << 20) {
     if (!transpose_a && transpose_b) {
       bm = 64;
       bn = (out.dtype() == float32) ? 64 : 32;
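To make the split-K heuristic above concrete, the following standalone sketch evaluates the same arithmetic for an illustrative skinny-output, long-reduction problem (the sizes are made up; only the formulas mirror the diff):

#include <cstdio>

int main() {
  // A small output with a long reduction axis: the case split-K targets.
  int M = 32, N = 32, K = 4096;

  int _tm = M / 16, _tn = N / 16, _tk = K / 16;     // 2, 2, 256
  bool use_split_k = (_tm * _tn) <= 32 && _tk >= 8; // true: few output tiles, many K tiles

  int bm = M < 40 ? 16 : 32; // 16
  int bn = N < 40 ? 16 : 32; // 16
  int bk = 16;

  int split_k_partitions =
      _tk < 16 ? 2 : (_tk < 32 ? 4 : (_tk < 64 ? 8 : 16)); // 16 partitions
  int gemm_k_iterations = (K / bk) / split_k_partitions;   // 256 / 16 = 16 iterations each
  int split_k_partition_size = gemm_k_iterations * bk;     // 256 K-columns per partition
  int split_k_partition_stride = M * N;                    // one MxN slice of C_split each

  printf(
      "use_split_k=%d partitions=%d k_iters=%d size=%d stride=%d bm=%d bn=%d\n",
      use_split_k, split_k_partitions, gemm_k_iterations, split_k_partition_size,
      split_k_partition_stride, bm, bn);
  return 0;
}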
@@ -224,10 +341,12 @@ void mlx_matmul(
     }
   }
 
+  // Prepare kernel name
   std::ostringstream kname;
-  kname << "gemm_" << (transpose_a ? 't' : 'n') << (transpose_b ? 't' : 'n')
-        << "_" << type_to_name(a) << "_" << type_to_name(out) << "_bm" << bm
-        << "_bn" << bn << "_bk" << bk << "_wm" << wm << "_wn" << wn << "_MN_"
+  kname << "steel_gemm_" << (transpose_a ? 't' : 'n')
+        << (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
+        << type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk
+        << "_wm" << wm << "_wn" << wn << "_MN_"
        << ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned"
        << "_K_" << ((K % bk == 0) ? "t" : "n") << "aligned";
 
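The kernel name built above selects a precompiled specialization by encoding the transposes, dtypes, tile sizes, and alignment flags into a string. A standalone sketch reproducing the mangling for one hypothetical configuration (which specializations actually exist is defined by the steel kernel headers, not by this sketch):

#include <iostream>
#include <sstream>
#include <string>

int main() {
  bool transpose_a = false, transpose_b = true;
  std::string tname = "float16"; // stands in for type_to_name(a) / type_to_name(out)
  int M = 1024, N = 1024, K = 1024;
  int bm = 64, bn = 32, bk = 32, wm = 2, wn = 2; // the half-precision nt tile from the diff

  std::ostringstream kname;
  kname << "steel_gemm_" << (transpose_a ? 't' : 'n') << (transpose_b ? 't' : 'n')
        << "_" << tname << "_" << tname << "_bm" << bm << "_bn" << bn << "_bk" << bk
        << "_wm" << wm << "_wn" << wn << "_MN_"
        << ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned"
        << "_K_" << ((K % bk == 0) ? "t" : "n") << "aligned";

  // Prints: steel_gemm_nt_float16_float16_bm64_bn32_bk32_wm2_wn2_MN_taligned_K_taligned
  std::cout << kname.str() << std::endl;
  return 0;
}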
@@ -236,34 +355,55 @@ void mlx_matmul(
   auto kernel = d.get_kernel(kname.str());
   compute_encoder->setComputePipelineState(kernel);
 
+  // Use problem size to determine threadblock swizzle
+  int tn = (N + bn - 1) / bn;
+  int tm = (M + bm - 1) / bm;
+
+  // TODO: Explore device-based tuning for swizzle
+  int swizzle_log = 0; // tm >= 6 ? 3 : (tm <= 3 ? 0 : 2);
+
+  // Prepare steel matmul params
+  GEMMParams params{
+      M,
+      N,
+      K,
+      lda,
+      ldb,
+      N,
+      tn,
+      tm,
+      matrix_stride_a,
+      matrix_stride_b,
+      matrix_stride_out,
+      swizzle_log,
+      (K / bk)};
+
+  // Prepare launch grid params
+  int tile = 1 << swizzle_log;
+  tm = (tm + tile - 1) / tile;
+  tn = tn * tile;
+
+  MTL::Size group_dims = MTL::Size(32, wn, wm);
+  MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out);
+
   // Launch only 1 kernel in the case of simple batching / broadcasting
   if (batch_size_out == std::max(batch_size_a, batch_size_b) &&
       (batch_size_a == batch_size_b ||
        std::min(batch_size_a, batch_size_b) == 1)) {
-    MTL::Size group_dims = MTL::Size(32, wn, wm);
-    MTL::Size grid_dims =
-        MTL::Size((N + bn - 1) / bn, (M + bm - 1) / bm, batch_size_out);
-
     set_array_buffer(compute_encoder, a, 0);
     set_array_buffer(compute_encoder, b, 1);
     set_array_buffer(compute_encoder, out, 2);
 
-    compute_encoder->setBytes(&M, sizeof(int), 3);
-    compute_encoder->setBytes(&N, sizeof(int), 4);
-    compute_encoder->setBytes(&K, sizeof(int), 5);
-    compute_encoder->setBytes(&matrix_stride_a, sizeof(int), 6);
-    compute_encoder->setBytes(&matrix_stride_b, sizeof(int), 7);
-    compute_encoder->setBytes(&matrix_stride_out, sizeof(int), 8);
+    compute_encoder->setBytes(&params, sizeof(GEMMParams), 3);
     compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
-  } else { // Other launch kernels with set offsets
+  } else { // Otherwise launch kernels with set offsets
+
+    MTL::Size grid_dims_single = MTL::Size(tn, tm, 1);
 
     for (int i = 0; i < batch_size_out; ++i) {
       auto a_off = elem_to_loc(M * K * i, a.shape(), a.strides());
       auto b_off = elem_to_loc(K * N * i, b.shape(), b.strides());
 
-      MTL::Size group_dims = MTL::Size(32, wn, wm);
-      MTL::Size grid_dims = MTL::Size((N + bn - 1) / bn, (M + bm - 1) / bm, 1);
-
       auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
       auto b_buf = static_cast<const MTL::Buffer*>(b.buffer().ptr());
       auto out_buf = static_cast<const MTL::Buffer*>(out.buffer().ptr());
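For the regular dispatch path above, the launch grid comes from the output tile counts plus the (currently disabled) threadgroup swizzle: tile rows are folded into the x dimension in groups of 2^swizzle_log so that consecutively scheduled threadgroups touch nearby output tiles. A standalone sketch of that grid arithmetic with an illustrative non-zero swizzle_log (the diff ships with swizzle_log = 0, i.e. no swizzle):

#include <cstdio>

int main() {
  int M = 1000, N = 1536, batch_size_out = 4;
  int bm = 64, bn = 64, wm = 2, wn = 2;

  int tn = (N + bn - 1) / bn; // 24 output tiles along N
  int tm = (M + bm - 1) / bm; // 16 output tiles along M

  int swizzle_log = 2;         // illustrative; the diff uses 0
  int tile = 1 << swizzle_log; // 4
  tm = (tm + tile - 1) / tile; // 4
  tn = tn * tile;              // 96

  // Grid is (tn, tm, batch); each threadgroup is (32, wn, wm), i.e. wn * wm
  // SIMD groups of 32 threads cooperating on one bm x bn output tile.
  printf("grid = (%d, %d, %d), threadgroup = (32, %d, %d)\n",
         tn, tm, batch_size_out, wn, wm); // grid = (96, 4, 4)
  return 0;
}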
@@ -272,13 +412,8 @@ void mlx_matmul(
       compute_encoder->setBuffer(b_buf, b_off * b.itemsize(), 1);
       compute_encoder->setBuffer(out_buf, i * M * N * out.itemsize(), 2);
 
-      compute_encoder->setBytes(&M, sizeof(int), 3);
-      compute_encoder->setBytes(&N, sizeof(int), 4);
-      compute_encoder->setBytes(&K, sizeof(int), 5);
-      compute_encoder->setBytes(&matrix_stride_a, sizeof(int), 6);
-      compute_encoder->setBytes(&matrix_stride_b, sizeof(int), 7);
-      compute_encoder->setBytes(&matrix_stride_out, sizeof(int), 8);
-      compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+      compute_encoder->setBytes(&params, sizeof(GEMMParams), 3);
+      compute_encoder->dispatchThreadgroups(grid_dims_single, group_dims);
     }
   }
 
@@ -300,6 +435,9 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
   auto& a_pre = inputs[0];
   auto& b_pre = inputs[1];
 
+  /////////////////////////////////////////////////////////////////////////////
+  // Init checks and prep
+
   // Keep a vector with copies to be cleared in the completed buffer to release
   // the arrays
   std::vector<array> copies;
@@ -328,6 +466,9 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
 
   auto batch_size_out = out.size() / (M * N);
 
+  /////////////////////////////////////////////////////////////////////////////
+  // Gemv specialization
+
   // Route to gemv if needed
   if (std::min(M, N) == 1) {
     // Collect problem info
@@ -433,10 +574,13 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
     return;
   }
 
-  d.end_encoding(s.index);
+  /////////////////////////////////////////////////////////////////////////////
+  // Gemm specialization
 
   if (use_mps()) {
-    mps_matmul(
+    d.end_encoding(s.index);
+
+    return mps_matmul(
         s,
         d,
         a,
@@ -451,10 +595,9 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
         a_transposed,
         b_transposed,
         copies);
-    return;
   }
 
-  mlx_matmul(
+  return steel_matmul(
       s,
       d,
       a,
@@ -471,4 +614,266 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
       copies);
 }
 
+void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 2);
+  if (!is_floating_point(out.dtype())) {
+    throw std::runtime_error(
+        "[matmul] Does not yet support non-floating point types.");
+  }
+  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  auto& s = stream();
+  auto& d = metal::device(s.device);
+
+  auto& a_pre = inputs[0];
+  auto& b_pre = inputs[1];
+  auto& c_pre = inputs[2];
+
+  /////////////////////////////////////////////////////////////////////////////
+  // Init checks and prep
+
+  // Keep a vector with copies to be cleared in the completed buffer to release
+  // the arrays
+  std::vector<array> copies;
+  auto check_transpose = [&copies, &s](const array& arr) {
+    auto stx = arr.strides()[arr.ndim() - 2];
+    auto sty = arr.strides()[arr.ndim() - 1];
+    if (stx == arr.shape(-1) && sty == 1) {
+      return std::make_tuple(false, stx, arr);
+    } else if (stx == 1 && sty == arr.shape(-2)) {
+      return std::make_tuple(true, sty, arr);
+    } else {
+      array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
+      copy_gpu(arr, arr_copy, CopyType::General, s);
+      copies.push_back(arr_copy);
+      size_t stx = arr.shape(-1);
+      return std::make_tuple(false, stx, arr_copy);
+    }
+  };
+
+  auto [transpose_a, a_cols, a] = check_transpose(a_pre);
+  auto [transpose_b, b_cols, b] = check_transpose(b_pre);
+
+  int M = a.shape(-2);
+  int N = b.shape(-1);
+  int K = a.shape(-1);
+
+  auto batch_size_out = out.size() / (M * N);
+
+  array c = c_pre;
+  int ldc = c.strides()[c.ndim() - 2];
+  int fdc = c.strides()[c.ndim() - 1];
+  int matrix_stride_c = c.ndim() <= 2 ? 0 : c.strides()[c.ndim() - 3];
+
+  int lda = a_cols;
+  int ldb = b_cols;
+
+  using namespace mlx::steel;
+
+  // Account for batch sizes and basic broadcasting
+  int batch_size_a = a.data_size() / (M * K);
+  int batch_size_b = b.data_size() / (K * N);
+
+  int matrix_stride_a = (batch_size_a == 1) ? 0 : M * K;
+  int matrix_stride_b = (batch_size_b == 1) ? 0 : K * N;
+  int matrix_stride_out = M * N;
+
+  int _tm = M / 16;
+  int _tn = N / 16;
+  int _tk = K / 16;
+
+  /////////////////////////////////////////////////////////////////////////////
+  // Split K specialization
+
+  if (batch_size_out == 1 && (_tm * _tn) <= 32 && _tk >= 8) {
+    int bm = M < 40 ? 16 : 32;
+    int bn = N < 40 ? 16 : 32;
+    int bk = 16;
+    int wm = 2, wn = 2;
+
+    int split_k_partitions =
+        _tk < 16 ? 2 : (_tk < 32 ? 4 : (_tk < 64 ? 8 : 16));
+    int split_k_partition_stride = M * N;
+    int gemm_k_iterations = (K / bk) / split_k_partitions;
+    int split_k_partition_size = gemm_k_iterations * bk;
+
+    array C_split({split_k_partitions, M, N}, float32, nullptr, {});
+    C_split.set_data(allocator::malloc_or_wait(C_split.nbytes()));
+    copies.push_back(C_split);
+
+    std::ostringstream kname;
+    kname << "steel_gemm_splitk_" << (transpose_a ? 't' : 'n')
+          << (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
+          << type_to_name(C_split) << "_bm" << bm << "_bn" << bn << "_bk" << bk
+          << "_wm" << wm << "_wn" << wn << "_MN_"
+          << ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned"
+          << "_K_" << ((K % bk == 0) ? "t" : "n") << "aligned";
+
+    // Encode and dispatch gemm kernel
+    auto compute_encoder = d.get_command_encoder(s.index);
+    auto kernel = d.get_kernel(kname.str());
+    compute_encoder->setComputePipelineState(kernel);
+
+    int tn = (N + bn - 1) / bn;
+    int tm = (M + bm - 1) / bm;
+
+    GEMMSpiltKParams params{
+        M,
+        N,
+        K,
+        lda,
+        ldb,
+        N,
+        tn,
+        tm,
+        split_k_partitions,
+        split_k_partition_stride,
+        split_k_partition_size,
+        gemm_k_iterations};
+
+    MTL::Size group_dims = MTL::Size(32, wn, wm);
+    MTL::Size grid_dims = MTL::Size(tn, tm, split_k_partitions);
+
+    set_array_buffer(compute_encoder, a, 0);
+    set_array_buffer(compute_encoder, b, 1);
+    set_array_buffer(compute_encoder, C_split, 2);
+
+    compute_encoder->setBytes(&params, sizeof(GEMMSpiltKParams), 3);
+    compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+
+    // Do accum kernel
+    {
+      auto kernel = d.get_kernel(
+          "steel_gemm_splitk_accum_" + type_to_name(out) + "_" +
+          type_to_name(C_split) + "_axpby");
+      compute_encoder->setComputePipelineState(kernel);
+
+      // Set the arguments for the kernel
+      set_array_buffer(compute_encoder, C_split, 0);
+      set_array_buffer(compute_encoder, out, 1);
+      compute_encoder->setBytes(&split_k_partitions, sizeof(int), 2);
+      compute_encoder->setBytes(&split_k_partition_stride, sizeof(int), 3);
+      compute_encoder->setBytes(&N, sizeof(int), 4);
+      set_array_buffer(compute_encoder, c, 5);
+      compute_encoder->setBytes(&ldc, sizeof(int), 6);
+      compute_encoder->setBytes(&fdc, sizeof(int), 7);
+      compute_encoder->setBytes(&alpha_, sizeof(float), 8);
+      compute_encoder->setBytes(&beta_, sizeof(float), 9);
+
+      // Launch enough thread groups for each output
+      MTL::Size grid_dims = MTL::Size(N, M, 1);
+      MTL::Size group_dims = MTL::Size(std::min(1024, N * M), 1, 1);
+
+      compute_encoder->dispatchThreads(grid_dims, group_dims);
+    }
+
+    d.get_command_buffer(s.index)->addCompletedHandler(
+        [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
+    return;
+  }
+
+  /////////////////////////////////////////////////////////////////////////////
+  // Regular addmm dispatch
+
+  // Determine dispatch kernel
+  int bm = 32, bn = 32, bk = 16;
+  int wm = 2, wn = 2;
+
+  if ((size_t)batch_size_out * M * N >= 1ul << 20) {
+    if (!transpose_a && transpose_b) {
+      bm = 64;
+      bn = (out.dtype() == float32) ? 64 : 32;
+      bk = (out.dtype() == float32) ? 16 : 32;
+    } else {
+      bm = 64;
+      bn = 64;
+    }
+  }
+
+  std::ostringstream kname;
+  kname << "steel_addmm_" << (transpose_a ? 't' : 'n')
+        << (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
+        << type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk
+        << "_wm" << wm << "_wn" << wn << "_MN_"
+        << ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned"
+        << "_K_" << ((K % bk == 0) ? "t" : "n") << "aligned"
+        << ((alpha_ == 1. && beta_ == 1.) ? "_add" : "_axpby");
+
+  // Encode and dispatch kernel
+  auto compute_encoder = d.get_command_encoder(s.index);
+  auto kernel = d.get_kernel(kname.str());
+  compute_encoder->setComputePipelineState(kernel);
+
+  int tn = (N + bn - 1) / bn;
+  int tm = (M + bm - 1) / bm;
+
+  // TODO: Explore device-based tuning for swizzle
+  int swizzle_log = 0; // tm >= 6 ? 3 : (tm <= 3 ? 0 : 2);
+
+  GEMMAddMMParams params{
+      M,
+      N,
+      K,
+      lda,
+      ldb,
+      ldc,
+      N,
+      tn,
+      tm,
+      matrix_stride_a,
+      matrix_stride_b,
+      matrix_stride_c,
+      matrix_stride_out,
+      swizzle_log,
+      (K / bk),
+      alpha_,
+      beta_,
+      fdc};
+
+  int tile = 1 << swizzle_log;
+  tm = (tm + tile - 1) / tile;
+  tn = tn * tile;
+
+  MTL::Size group_dims = MTL::Size(32, wn, wm);
+  MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out);
+
+  // Launch only 1 kernel in the case of simple batching / broadcasting
+  if (batch_size_out == std::max(batch_size_a, batch_size_b) &&
+      (batch_size_a == batch_size_b ||
+       std::min(batch_size_a, batch_size_b) == 1)) {
+    set_array_buffer(compute_encoder, a, 0);
+    set_array_buffer(compute_encoder, b, 1);
+    set_array_buffer(compute_encoder, c, 2);
+    set_array_buffer(compute_encoder, out, 3);
+
+    compute_encoder->setBytes(&params, sizeof(GEMMAddMMParams), 4);
+    compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+  } else { // Otherwise launch kernels with set offsets
+
+    MTL::Size grid_dims_single = MTL::Size(tn, tm, 1);
+
+    for (int i = 0; i < batch_size_out; ++i) {
+      auto a_off = elem_to_loc(M * K * i, a.shape(), a.strides());
+      auto b_off = elem_to_loc(K * N * i, b.shape(), b.strides());
+      auto c_off = elem_to_loc(M * N * i, c.shape(), c.strides());
+
+      auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
+      auto b_buf = static_cast<const MTL::Buffer*>(b.buffer().ptr());
+      auto c_buf = static_cast<const MTL::Buffer*>(c.buffer().ptr());
+      auto out_buf = static_cast<const MTL::Buffer*>(out.buffer().ptr());
+
+      compute_encoder->setBuffer(a_buf, a_off * a.itemsize(), 0);
+      compute_encoder->setBuffer(b_buf, b_off * b.itemsize(), 1);
+      compute_encoder->setBuffer(c_buf, c_off * c.itemsize(), 2);
+      compute_encoder->setBuffer(out_buf, i * M * N * out.itemsize(), 3);
+
+      compute_encoder->setBytes(&params, sizeof(GEMMAddMMParams), 4);
+      compute_encoder->dispatchThreadgroups(grid_dims_single, group_dims);
+    }
+  }
+
+  d.get_command_buffer(s.index)->addCompletedHandler(
+      [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
+  return;
+}
+
 } // namespace mlx::core
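As a closing reference for the split-K paths above (the plain steel_gemm_splitk_accum kernel in steel_matmul and the _axpby variant used by AddMM), the accumulation step conceptually sums the per-partition partial results and, in the axpby case, applies the scaled bias. A plain C++ reference of that epilogue (illustrative only, not the Metal kernel):

#include <vector>

// out[m][n] = sum over partitions of C_split[p][m][n]; when axpby is true the
// AddMM epilogue out = alpha * out + beta * C is applied as well.
void splitk_accum_reference(
    const std::vector<float>& C_split, // (partitions, M, N), row-major partials
    const std::vector<float>& C,       // (M, N) bias, read only when axpby is true
    std::vector<float>& out,           // (M, N)
    int partitions, int M, int N,
    bool axpby, float alpha, float beta) {
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float acc = 0.0f;
      for (int p = 0; p < partitions; ++p) {
        acc += C_split[(long)p * M * N + (long)m * N + n];
      }
      out[(long)m * N + n] = axpby ? alpha * acc + beta * C[(long)m * N + n] : acc;
    }
  }
}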