Block sparse mm (#1058)

Jagrit Digani 2024-05-02 14:03:58 -07:00 committed by GitHub
parent 17f57df797
commit f390957685
15 changed files with 1323 additions and 75 deletions

View File

@ -32,6 +32,7 @@ Operations
bitwise_or
bitwise_xor
block_masked_mm
block_sparse_mm
broadcast_to
ceil
clip

View File

@ -32,6 +32,7 @@ DEFAULT(ArgReduce)
DEFAULT(ArgSort)
DEFAULT(AsStrided)
DEFAULT(BlockMaskedMM)
DEFAULT(BlockSparseMM)
DEFAULT(Broadcast)
DEFAULT(Ceil)
DEFAULT(Concatenate)

View File

@ -42,6 +42,7 @@ DEFAULT(AsType)
DEFAULT(AsStrided)
DEFAULT(Broadcast)
DEFAULT(BlockMaskedMM)
DEFAULT(BlockSparseMM)
DEFAULT_MULTI(DivMod)
DEFAULT(Ceil)
DEFAULT(Concatenate)

View File

@ -190,4 +190,91 @@ void BlockMaskedMM::eval(const std::vector<array>& inputs, array& out) {
}
}
void BlockSparseMM::eval(const std::vector<array>& inputs, array& out) {
if (out.dtype() != float32) {
throw std::runtime_error(
"[BlockSparseMM::eval] Currently only supports float32.");
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& a_pre = inputs[0];
auto& b_pre = inputs[1];
auto check_transpose = [](const array& arr) {
auto stx = arr.strides()[arr.ndim() - 2];
auto sty = arr.strides()[arr.ndim() - 1];
if (stx == arr.shape(-1) && sty == 1) {
return std::make_tuple(false, stx, arr);
} else if (stx == 1 && sty == arr.shape(-2)) {
return std::make_tuple(true, sty, arr);
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General);
size_t stx = arr.shape(-1);
return std::make_tuple(false, stx, arr_copy);
}
};
auto [a_transposed, lda, a] = check_transpose(a_pre);
auto [b_transposed, ldb, b] = check_transpose(b_pre);
size_t M = a.shape(-2);
size_t N = b.shape(-1);
size_t K = a.shape(-1);
if (M == 0 || N == 0) {
return;
}
if (K == 0) {
std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
return;
}
// Get batch dims
auto batch_size_out = out.size() / (M * N);
size_t matrix_stride_out = M * N;
auto get_batch_dims = [](const auto& v) {
return decltype(v){v.begin(), v.end() - 2};
};
auto& lhs_indices = inputs[2];
auto& rhs_indices = inputs[3];
std::vector<int> batch_shape = get_batch_dims(out.shape());
int batch_ndim = batch_shape.size();
std::vector<int> batch_shape_A = get_batch_dims(a.shape());
std::vector<size_t> batch_strides_A = get_batch_dims(a.strides());
std::vector<int> batch_shape_B = get_batch_dims(b.shape());
std::vector<size_t> batch_strides_B = get_batch_dims(b.strides());
const uint32_t* lhs_indices_ptr = lhs_indices.data<uint32_t>();
const uint32_t* rhs_indices_ptr = rhs_indices.data<uint32_t>();
for (int i = 0; i < batch_size_out; i++) {
// Get index
uint32_t indx_A = lhs_indices_ptr[elem_to_loc(i, lhs_indices)];
uint32_t indx_B = rhs_indices_ptr[elem_to_loc(i, rhs_indices)];
cblas_sgemm(
CblasRowMajor,
a_transposed ? CblasTrans : CblasNoTrans, // transA
b_transposed ? CblasTrans : CblasNoTrans, // transB
M,
N,
K,
1.0f, // alpha
a.data<float>() + elem_to_loc(indx_A, batch_shape_A, batch_strides_A),
lda,
b.data<float>() + elem_to_loc(indx_B, batch_shape_B, batch_strides_B),
ldb,
0.0f, // beta
out.data<float>() + matrix_stride_out * i,
out.shape(-1) // ldc
);
}
}
} // namespace mlx::core
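For reference, the semantics this CPU fallback implements can be summarized with a short NumPy sketch (illustrative only: the helper name and the assumption of contiguous operands are mine, and the real implementation walks strides via elem_to_loc instead of reshaping):
import numpy as np

def block_sparse_mm_ref(a, b, lhs_indices, rhs_indices):
    # Broadcast the index arrays against each other to get the output batch,
    # gather the requested matrices from each operand's flattened batch
    # dimensions, and run a batched matmul.
    lhs_indices, rhs_indices = np.broadcast_arrays(lhs_indices, rhs_indices)
    a_flat = a.reshape((-1,) + a.shape[-2:])  # (prod(batch_A), M, K)
    b_flat = b.reshape((-1,) + b.shape[-2:])  # (prod(batch_B), K, N)
    out = a_flat[lhs_indices.ravel()] @ b_flat[rhs_indices.ravel()]
    return out.reshape(lhs_indices.shape + out.shape[-2:])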

View File

@ -7,6 +7,8 @@
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/utils.h"
using namespace metal;
///////////////////////////////////////////////////////////////////////////////
@ -33,22 +35,19 @@ struct GEMVKernel {
// - We assume each thread group is launched with (BN, BM, 1) threads
//
// 1. A thread loads TN elements each from mat along TM contiguous rows
// and the corresponding scalar from the vector
// and the corresponding scalar from the vector
// 2. The thread then multiplies and adds to accumulate its local result for
// the block
// the block
// 3. At the end, each thread has accumulated results over all blocks across
// the rows
// These are then summed up across the threadgroup
// the rows. These are then summed up across the threadgroup
// 4. Each threadgroup writes its accumulated BN * TN outputs
//
// Edge case handling:
// - The threadgroup with the largest tid will have blocks that exceed the
// matrix
// - The threadgroup with the largest tid has blocks that exceed the matrix
// * The blocks that start outside the matrix are never read (thread results
// remain zero)
// remain zero)
// * The last thread that partially overlaps with the matrix is shifted
// inwards
// such that the thread block fits exactly in the matrix
// inwards such that the thread block fits exactly in the matrix
MLX_MTL_CONST short tgp_mem_size = BN * TN * 2;
@ -100,14 +99,14 @@ struct GEMVKernel {
if (simd_gid == 0) {
// Main load loop
if (bn + TN <= in_vec_size) {
#pragma clang loop unroll(full)
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
in_vec_block[tn] = in_vec[bn + tn];
}
} else { // Edgecase
#pragma clang loop unroll(full)
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
in_vec_block[tn] = bn + tn < in_vec_size ? in_vec[bn + tn] : 0;
}
@ -116,24 +115,24 @@ struct GEMVKernel {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load for all rows
#pragma clang loop unroll(full)
// Load for all rows
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
v_coeff[tn] = in_vec_block[tn];
}
// Per thread work loop
#pragma clang loop unroll(full)
// Per thread work loop
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
// Load for the row
if (bn + TN <= in_vec_size) {
#pragma clang loop unroll(full)
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
inter[tn] = mat[tm * marix_ld + bn + tn];
}
} else { // Edgecase
#pragma clang loop unroll(full)
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
int col_idx =
(bn + tn) < in_vec_size ? (bn + tn) : (in_vec_size - 1);
@ -142,21 +141,22 @@ struct GEMVKernel {
}
// Accumulate results
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
result[tm] += inter[tn] * v_coeff[tn];
}
}
}
// Simdgroup accumulations
#pragma clang loop unroll(full)
// Simdgroup accumulations
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
result[tm] = simd_sum(result[tm]);
}
// Write outputs
if (simd_lid == 0) {
#pragma clang loop unroll(full)
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
if (kDoAxpby) {
out_vec[out_row + tm] = static_cast<T>(alpha) * result[tm] +
@ -187,22 +187,18 @@ struct GEMVTKernel {
// - We assume each thread group is launched with (BN, BM, 1) threads
//
// 1. A thread loads TN elements each from mat along TM contiguous rows
// and the corresponding scalar from the vector
// 2. The thread then multiplies and adds to accumulate its local result for
// the block
// and the corresponding scalar from the vector
// 2. The thread then accumulates its local result for the block
// 3. At the end, each thread has accumulated results over all blocks across
// the rows
// These are then summed up across the threadgroup
// the rows. These are then summed up across the threadgroup
// 4. Each threadgroup writes its accumulated BN * TN outputs
//
// Edge case handling:
// - The threadgroup with the largest tid will have blocks that exceed the
// matrix
// - The threadgroup with the largest tid has blocks that exceed the matrix
// * The blocks that start outside the matrix are never read (thread results
// remain zero)
// remain zero)
// * The last thread that partially overlaps with the matrix is shifted
// inwards
// such that the thread block fits exactly in the matrix
// inwards such that the thread block fits exactly in the matrix
MLX_MTL_CONST short tgp_mem_size = BN * BM * TN;
@ -249,12 +245,12 @@ struct GEMVTKernel {
threadgroup_barrier(mem_flags::mem_none);
if (bm + TM <= in_vec_size) {
#pragma clang loop unroll(full)
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
v_coeff[tm] = in_vec[bm + tm];
}
#pragma clang loop unroll(full)
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
for (int tn = 0; tn < TN; tn++) {
inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn];
@ -281,7 +277,7 @@ struct GEMVTKernel {
// Threadgroup collection
#pragma clang loop unroll(full)
MLX_MTL_PRAGMA_UNROLL
for (int i = 0; i < TN; i++) {
tgp_results[lid.y * TN + i] = result[i];
}
@ -290,15 +286,15 @@ struct GEMVTKernel {
// Threadgroup accumulation and writing out results
if (lid.y == 0 && out_col < out_vec_size) {
#pragma clang loop unroll(full)
MLX_MTL_PRAGMA_UNROLL
for (int i = 1; i < BM; i++) {
#pragma clang loop unroll(full)
MLX_MTL_PRAGMA_UNROLL
for (int j = 0; j < TN; j++) {
result[j] += tgp_results[i * TN + j];
}
}
#pragma clang loop unroll(full)
MLX_MTL_PRAGMA_UNROLL
for (int j = 0; j < TN; j++) {
if (kDoAxpby) {
out_vec[out_col + j] = static_cast<T>(alpha) * result[j] +
@ -408,20 +404,149 @@ template <
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
#define instantiate_gemv(name, itype, bm, bn, tm, tn) \
instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 0, 0) \
instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 0, 1) \
instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 1, 0) \
instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 1, 1)
// clang-format off
#define instantiate_gemv(name, itype, bm, bn, tm, tn) \
instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 0, 0) \
instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 0, 1) \
instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 1, 0) \
instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 1, 1) // clang-format on
#define instantiate_gemv_blocks(name, itype) \
instantiate_gemv(name, itype, 4, 32, 1, 4) instantiate_gemv( \
name, itype, 4, 32, 4, 4) instantiate_gemv(name, itype, 8, 32, 4, 4)
// clang-format off
#define instantiate_gemv_blocks(name, itype) \
instantiate_gemv(name, itype, 4, 32, 1, 4) \
instantiate_gemv(name, itype, 4, 32, 4, 4) \
instantiate_gemv(name, itype, 8, 32, 4, 4) // clang-format on
instantiate_gemv_blocks(float32, float);
instantiate_gemv_blocks(float16, half);
instantiate_gemv_blocks(bfloat16, bfloat16_t);
template <
typename T,
const int BM, /* Threadgroup rows (in threads) */
const int BN, /* Threadgroup cols (in threads) */
const int TM, /* Thread rows (in elements) */
const int TN> /* Thread cols (in elements) */
[[kernel, max_total_threads_per_threadgroup(BM* BN)]] void gemv_bs(
const device T* mat [[buffer(0)]],
const device T* in_vec [[buffer(1)]],
const device T* bias [[buffer(2)]],
device T* out_vec [[buffer(3)]],
const constant int& in_vec_size [[buffer(4)]],
const constant int& out_vec_size [[buffer(5)]],
const constant int& marix_ld [[buffer(6)]],
const constant float& alpha [[buffer(7)]],
const constant float& beta [[buffer(8)]],
const constant int& batch_ndim [[buffer(9)]],
const constant int* batch_shape [[buffer(10)]],
const constant size_t* index_batch_strides [[buffer(11)]],
const constant int& vector_batch_ndim [[buffer(12)]],
const constant int* vector_batch_shape [[buffer(13)]],
const constant size_t* vector_batch_stride [[buffer(14)]],
const constant int& matrix_batch_ndim [[buffer(15)]],
const constant int* matrix_batch_shape [[buffer(16)]],
const constant size_t* matrix_batch_stride [[buffer(17)]],
const constant uint32_t* vec_indices [[buffer(18)]],
const constant uint32_t* mat_indices [[buffer(19)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
using gemv_kernel = GEMVKernel<T, BM, BN, TM, TN, false>;
threadgroup T tgp_memory[gemv_kernel::tgp_mem_size];
uint32_t indx_vec;
uint32_t indx_mat;
// Update batch offsets
if (batch_ndim > 1) {
const constant size_t* veci_bstrides = index_batch_strides;
const constant size_t* mati_bstrides = index_batch_strides + batch_ndim;
ulong2 batch_offsets = elem_to_loc_broadcast(
tid.z, batch_shape, veci_bstrides, mati_bstrides, batch_ndim);
indx_vec = vec_indices[batch_offsets.x];
indx_mat = mat_indices[batch_offsets.y];
} else {
indx_vec = vec_indices[index_batch_strides[0] * tid.z];
indx_mat = mat_indices[index_batch_strides[batch_ndim] * tid.z];
}
if (vector_batch_ndim > 1) {
in_vec += elem_to_loc(
indx_vec, vector_batch_shape, vector_batch_stride, vector_batch_ndim);
} else {
in_vec += indx_vec * vector_batch_stride[0];
}
if (matrix_batch_ndim > 1) {
mat += elem_to_loc(
indx_mat, matrix_batch_shape, matrix_batch_stride, matrix_batch_ndim);
} else {
mat += indx_mat * matrix_batch_stride[0];
}
out_vec += tid.z * out_vec_size;
gemv_kernel::run(
mat,
in_vec,
bias,
out_vec,
in_vec_size,
out_vec_size,
marix_ld,
alpha,
beta,
batch_ndim, // Not used
tgp_memory,
tid,
lid,
simd_gid,
simd_lid);
}
#define instantiate_gemv_bs_helper(nm, itype, bm, bn, tm, tn) \
template [[host_name("gemv_bs_" #nm "_bm" #bm "_bn" #bn "_tm" #tm \
"_tn" #tn)]] [[kernel]] void \
gemv_bs<itype, bm, bn, tm, tn>( \
const device itype* mat [[buffer(0)]], \
const device itype* in_vec [[buffer(1)]], \
const device itype* bias [[buffer(2)]], \
device itype* out_vec [[buffer(3)]], \
const constant int& in_vec_size [[buffer(4)]], \
const constant int& out_vec_size [[buffer(5)]], \
const constant int& marix_ld [[buffer(6)]], \
const constant float& alpha [[buffer(7)]], \
const constant float& beta [[buffer(8)]], \
const constant int& batch_ndim [[buffer(9)]], \
const constant int* batch_shape [[buffer(10)]], \
const constant size_t* index_batch_strides [[buffer(11)]], \
const constant int& vector_batch_ndim [[buffer(12)]], \
const constant int* vector_batch_shape [[buffer(13)]], \
const constant size_t* vector_batch_stride [[buffer(14)]], \
const constant int& matrix_batch_ndim [[buffer(15)]], \
const constant int* matrix_batch_shape [[buffer(16)]], \
const constant size_t* matrix_batch_stride [[buffer(17)]], \
const constant uint32_t* vec_indices [[buffer(18)]], \
const constant uint32_t* mat_indices [[buffer(19)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
// clang-format off
#define instantiate_gemv_bs_blocks(name, itype) \
instantiate_gemv_bs_helper(name, itype, 4, 32, 1, 4) \
instantiate_gemv_bs_helper(name, itype, 4, 32, 4, 4) \
instantiate_gemv_bs_helper(name, itype, 8, 32, 4, 4) // clang-format on
instantiate_gemv_bs_blocks(float32, float);
instantiate_gemv_bs_blocks(float16, half);
instantiate_gemv_bs_blocks(bfloat16, bfloat16_t);
///////////////////////////////////////////////////////////////////////////////
/// Vector matrix multiplication
///////////////////////////////////////////////////////////////////////////////
@ -538,4 +663,134 @@ template <
// clang-format off
instantiate_gemv_t_blocks(float32, float);
instantiate_gemv_t_blocks(float16, half);
instantiate_gemv_t_blocks(bfloat16, bfloat16_t); // clang-format on
instantiate_gemv_t_blocks(bfloat16, bfloat16_t); // clang-format on
template <
typename T,
const int BM, /* Threadgroup rows (in threads) */
const int BN, /* Threadgroup cols (in threads) */
const int TM, /* Thread rows (in elements) */
const int TN> /* Thread cols (in elements) */
[[kernel, max_total_threads_per_threadgroup(BM* BN)]] void gemv_t_bs(
const device T* mat [[buffer(0)]],
const device T* in_vec [[buffer(1)]],
const device T* bias [[buffer(2)]],
device T* out_vec [[buffer(3)]],
const constant int& in_vec_size [[buffer(4)]],
const constant int& out_vec_size [[buffer(5)]],
const constant int& marix_ld [[buffer(6)]],
const constant float& alpha [[buffer(7)]],
const constant float& beta [[buffer(8)]],
const constant int& batch_ndim [[buffer(9)]],
const constant int* batch_shape [[buffer(10)]],
const constant size_t* index_batch_strides [[buffer(11)]],
const constant int& vector_batch_ndim [[buffer(12)]],
const constant int* vector_batch_shape [[buffer(13)]],
const constant size_t* vector_batch_stride [[buffer(14)]],
const constant int& matrix_batch_ndim [[buffer(15)]],
const constant int* matrix_batch_shape [[buffer(16)]],
const constant size_t* matrix_batch_stride [[buffer(17)]],
const constant uint32_t* vec_indices [[buffer(18)]],
const constant uint32_t* mat_indices [[buffer(19)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
using gemv_kernel = GEMVTKernel<T, BM, BN, TM, TN, false>;
threadgroup T tgp_memory[gemv_kernel::tgp_mem_size];
uint32_t indx_vec;
uint32_t indx_mat;
// Update batch offsets
if (batch_ndim > 1) {
const constant size_t* veci_bstrides = index_batch_strides;
const constant size_t* mati_bstrides = index_batch_strides + batch_ndim;
ulong2 batch_offsets = elem_to_loc_broadcast(
tid.z, batch_shape, veci_bstrides, mati_bstrides, batch_ndim);
indx_vec = vec_indices[batch_offsets.x];
indx_mat = mat_indices[batch_offsets.y];
} else {
indx_vec = vec_indices[index_batch_strides[0] * tid.z];
indx_mat = mat_indices[index_batch_strides[batch_ndim] * tid.z];
}
if (vector_batch_ndim > 1) {
in_vec += elem_to_loc(
indx_vec, vector_batch_shape, vector_batch_stride, vector_batch_ndim);
} else {
in_vec += indx_vec * vector_batch_stride[0];
}
if (matrix_batch_ndim > 1) {
mat += elem_to_loc(
indx_mat, matrix_batch_shape, matrix_batch_stride, matrix_batch_ndim);
} else {
mat += indx_mat * matrix_batch_stride[0];
}
out_vec += tid.z * out_vec_size;
gemv_kernel::run(
mat,
in_vec,
bias,
out_vec,
in_vec_size,
out_vec_size,
marix_ld,
alpha,
beta,
batch_ndim, // Not used
tgp_memory,
tid,
lid,
simd_gid,
simd_lid);
}
#define instantiate_gemv_t_bs_helper(nm, itype, bm, bn, tm, tn) \
template [[host_name("gemv_t_bs_" #nm "_bm" #bm "_bn" #bn "_tm" #tm \
"_tn" #tn)]] [[kernel]] void \
gemv_t_bs<itype, bm, bn, tm, tn>( \
const device itype* mat [[buffer(0)]], \
const device itype* in_vec [[buffer(1)]], \
const device itype* bias [[buffer(2)]], \
device itype* out_vec [[buffer(3)]], \
const constant int& in_vec_size [[buffer(4)]], \
const constant int& out_vec_size [[buffer(5)]], \
const constant int& marix_ld [[buffer(6)]], \
const constant float& alpha [[buffer(7)]], \
const constant float& beta [[buffer(8)]], \
const constant int& batch_ndim [[buffer(9)]], \
const constant int* batch_shape [[buffer(10)]], \
const constant size_t* index_batch_strides [[buffer(11)]], \
const constant int& vector_batch_ndim [[buffer(12)]], \
const constant int* vector_batch_shape [[buffer(13)]], \
const constant size_t* vector_batch_stride [[buffer(14)]], \
const constant int& matrix_batch_ndim [[buffer(15)]], \
const constant int* matrix_batch_shape [[buffer(16)]], \
const constant size_t* matrix_batch_stride [[buffer(17)]], \
const constant uint32_t* vec_indices [[buffer(18)]], \
const constant uint32_t* mat_indices [[buffer(19)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
// clang-format off
#define instantiate_gemv_t_bs_blocks(name, itype) \
instantiate_gemv_t_bs_helper(name, itype, 8, 8, 4, 1) \
instantiate_gemv_t_bs_helper(name, itype, 8, 8, 4, 4) \
instantiate_gemv_t_bs_helper(name, itype, 8, 16, 4, 4) \
instantiate_gemv_t_bs_helper(name, itype, 8, 32, 4, 4) \
instantiate_gemv_t_bs_helper(name, itype, 8, 64, 4, 4) \
instantiate_gemv_t_bs_helper(name, itype, 8, 128, 4, 4) // clang-format on
// clang-format off
instantiate_gemv_t_bs_blocks(float32, float);
instantiate_gemv_t_bs_blocks(float16, half);
instantiate_gemv_t_bs_blocks(bfloat16, bfloat16_t); // clang-format on

View File

@ -0,0 +1,168 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
#include "mlx/backend/metal/kernels/utils.h"
using namespace metal;
using namespace mlx::steel;
///////////////////////////////////////////////////////////////////////////////
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////
template <
typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
bool MN_aligned,
bool K_aligned>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void bs_gemm(
const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]],
device T* D [[buffer(3)]],
const constant GEMMParams* params [[buffer(4)]],
const constant int* batch_shape [[buffer(6)]],
const constant size_t* batch_strides [[buffer(7)]],
const constant uint32_t* lhs_indices [[buffer(10)]],
const constant uint32_t* rhs_indices [[buffer(11)]],
const constant int* batch_shape_A [[buffer(12)]],
const constant size_t* batch_strides_A [[buffer(13)]],
const constant int* batch_shape_B [[buffer(14)]],
const constant size_t* batch_strides_B [[buffer(15)]],
const constant int2& operand_batch_ndim [[buffer(16)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
using gemm_kernel = GEMMKernel<
T,
T,
BM,
BN,
BK,
WM,
WN,
transpose_a,
transpose_b,
MN_aligned,
K_aligned>;
threadgroup T As[gemm_kernel::tgp_mem_size_a];
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
uint32_t indx_A;
uint32_t indx_B;
// Adjust for batch
if (params->batch_ndim > 1) {
const constant size_t* A_bstrides = batch_strides;
const constant size_t* B_bstrides = batch_strides + params->batch_ndim;
ulong2 batch_offsets = elem_to_loc_broadcast(
tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);
indx_A = lhs_indices[batch_offsets.x];
indx_B = rhs_indices[batch_offsets.y];
} else {
indx_A = lhs_indices[params->batch_stride_a * tid.z];
indx_B = rhs_indices[params->batch_stride_b * tid.z];
}
int batch_ndim_A = operand_batch_ndim.x;
int batch_ndim_B = operand_batch_ndim.y;
if (batch_ndim_A > 1) {
A += elem_to_loc(indx_A, batch_shape_A, batch_strides_A, batch_ndim_A);
} else {
A += indx_A * batch_strides_A[0];
}
if (batch_ndim_B > 1) {
B += elem_to_loc(indx_B, batch_shape_B, batch_strides_B, batch_ndim_B);
} else {
B += indx_B * batch_strides_B[0];
}
D += params->batch_stride_d * tid.z;
gemm_kernel::run(
A, B, D, params, As, Bs, simd_lane_id, simd_group_id, tid, lid);
}
///////////////////////////////////////////////////////////////////////////////
// GEMM kernel initializations
///////////////////////////////////////////////////////////////////////////////
#define instantiate_gemm( \
tname, \
trans_a, \
trans_b, \
iname, \
itype, \
oname, \
otype, \
bm, \
bn, \
bk, \
wm, \
wn, \
aname, \
mn_aligned, \
kname, \
k_aligned) \
template [[host_name("steel_block_sparse_gemm_" #tname "_" #iname "_" #oname \
"_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn \
"_MN_" #aname "_K_" #kname)]] [[kernel]] void \
bs_gemm<itype, bm, bn, bk, wm, wn, trans_a, trans_b, mn_aligned, k_aligned>( \
const device itype* A [[buffer(0)]], \
const device itype* B [[buffer(1)]], \
device itype* D [[buffer(3)]], \
const constant GEMMParams* params [[buffer(4)]], \
const constant int* batch_shape [[buffer(6)]], \
const constant size_t* batch_strides [[buffer(7)]], \
const constant uint32_t* lhs_indices [[buffer(10)]], \
const constant uint32_t* rhs_indices [[buffer(11)]], \
const constant int* batch_shape_A [[buffer(12)]], \
const constant size_t* batch_strides_A [[buffer(13)]], \
const constant int* batch_shape_B [[buffer(14)]], \
const constant size_t* batch_strides_B [[buffer(15)]], \
const constant int2& operand_batch_ndim [[buffer(16)]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]);
// clang-format off
#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false) // clang-format on
// clang-format off
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) // clang-format on
// clang-format off
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 2, 2) // clang-format on
// clang-format off
instantiate_gemm_shapes_helper(float16, half, float16, half);
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);
instantiate_gemm_shapes_helper(float32, float, float32, float); // clang-format on
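The batch offsets in the kernels above rely on elem_to_loc, which maps a flat (row-major) element index to a strided memory offset. A Python sketch of that mapping, for illustration only:
def elem_to_loc(elem, shape, strides):
    # Walk the dimensions from innermost to outermost, peeling off the
    # coordinate in each dimension and accumulating its strided offset.
    loc = 0
    for dim in range(len(shape) - 1, -1, -1):
        loc += (elem % shape[dim]) * strides[dim]
        elem //= shape[dim]
    return loc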

View File

@ -21,9 +21,9 @@ struct GEMMParams {
const int tiles_n;
const int tiles_m;
const int batch_stride_a;
const int batch_stride_b;
const int batch_stride_d;
const size_t batch_stride_a;
const size_t batch_stride_b;
const size_t batch_stride_d;
const int swizzle_log;
const int gemm_k_iterations_aligned;
@ -54,7 +54,7 @@ struct GEMMAddMMParams {
const int ldc;
const int fdc;
const int batch_stride_c;
const size_t batch_stride_c;
const float alpha;
const float beta;

View File

@ -327,9 +327,9 @@ void steel_matmul_conv_groups(
/* const int ldd = */ ldd,
/* const int tiles_n = */ tn,
/* const int tiles_m = */ tm,
/* const int batch_stride_a = */ K,
/* const int batch_stride_b = */ N * K,
/* const int batch_stride_d = */ N,
/* const size_t batch_stride_a = */ size_t(K),
/* const size_t batch_stride_b = */ size_t(N) * K,
/* const size_t batch_stride_d = */ size_t(N),
/* const int swizzle_log = */ swizzle_log,
/* const int gemm_k_iterations_aligned = */ (K / bk),
/* const int batch_ndim = */ 1};
@ -405,7 +405,7 @@ void steel_matmul(
}
}
int matrix_stride_out = M * N;
size_t matrix_stride_out = size_t(M) * N;
/////////////////////////////////////////////////////////////////////////////
// Split K specialization
@ -550,9 +550,9 @@ void steel_matmul(
/* const int ldd = */ N,
/* const int tiles_n = */ tn,
/* const int tiles_m = */ tm,
/* const int batch_stride_a = */ int(A_batch_stride.back()),
/* const int batch_stride_b = */ int(B_batch_stride.back()),
/* const int batch_stride_d = */ matrix_stride_out,
/* const size_t batch_stride_a = */ A_batch_stride.back(),
/* const size_t batch_stride_b = */ B_batch_stride.back(),
/* const size_t batch_stride_d = */ matrix_stride_out,
/* const int swizzle_log = */ swizzle_log,
/* const int gemm_k_iterations_aligned = */ (K / bk),
/* const int batch_ndim = */ int(batch_shape.size())};
@ -645,7 +645,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
auto [batch_shape, A_batch_stride, B_batch_stride] = collapse_batches(a, b);
auto batch_size_out = out.size() / (M * N);
auto batch_size_out = out.size() / (size_t(M) * size_t(N));
// Collapse batches into M if needed
if (batch_size_out > 1 && !a_transposed && batch_shape.size() == 1 &&
@ -853,7 +853,8 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
auto [batch_shape, A_batch_stride, B_batch_stride, C_batch_stride] =
collapse_batches(a, b, c);
auto batch_size_out = out.size() / (M * N);
size_t matrix_stride_out = size_t(M) * size_t(N);
auto batch_size_out = out.size() / (matrix_stride_out);
// Collapse batches into M if needed
if (batch_size_out > 1 && !transpose_a && batch_shape.size() == 1 &&
@ -869,8 +870,6 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
batch_shape = {1};
}
int matrix_stride_out = M * N;
/////////////////////////////////////////////////////////////////////////////
// Gemv specialization
@ -1120,9 +1119,9 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
/* const int ldd = */ N,
/* const int tiles_n = */ tn,
/* const int tiles_m = */ tm,
/* const int batch_stride_a = */ int(A_batch_stride.back()),
/* const int batch_stride_b = */ int(B_batch_stride.back()),
/* const int batch_stride_d = */ matrix_stride_out,
/* const size_t batch_stride_a = */ A_batch_stride.back(),
/* const size_t batch_stride_b = */ B_batch_stride.back(),
/* const size_t batch_stride_d = */ matrix_stride_out,
/* const int swizzle_log = */ swizzle_log,
/* const int gemm_k_iterations_aligned = */ (K / bk),
/* const int batch_ndim = */ int(batch_shape.size())};
@ -1130,7 +1129,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
GEMMAddMMParams params{
/* const int ldc = */ ldc,
/* const int fdc = */ fdc,
/* const int batch_stride_c = */ int(C_batch_stride.back()),
/* const size_t batch_stride_c = */ C_batch_stride.back(),
/* const float alpha = */ alpha_,
/* const float beta = */ beta_};
@ -1230,8 +1229,8 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& out_mask = inputs[2];
std::vector<int> batch_shape{1};
int A_batch_str = 0;
int B_batch_str = 0;
size_t A_batch_str = 0;
size_t B_batch_str = 0;
std::vector<size_t> batch_strides;
@ -1249,8 +1248,8 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
auto [bshape_c, bstrides_c] = collapse_contiguous_dims(bshape, bstrides);
batch_shape = bshape_c;
A_batch_str = int(bstrides_c[0].back());
B_batch_str = int(bstrides_c[1].back());
A_batch_str = bstrides_c[0].back();
B_batch_str = bstrides_c[1].back();
for (auto& bstr : bstrides_c) {
batch_strides.insert(batch_strides.end(), bstr.begin(), bstr.end());
@ -1259,8 +1258,8 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
batch_strides = std::vector<size_t>(inputs.size(), 0);
}
auto batch_size_out = out.size() / (M * N);
int matrix_stride_out = M * N;
size_t matrix_stride_out = size_t(M) * N;
size_t batch_size_out = out.size() / (matrix_stride_out);
/////////////////////////////////////////////////////////////////////////////
// Regular kernel dispatch
@ -1301,9 +1300,9 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
/* const int ldd = */ N,
/* const int tiles_n = */ tn,
/* const int tiles_m = */ tm,
/* const int batch_stride_a = */ A_batch_str,
/* const int batch_stride_b = */ B_batch_str,
/* const int batch_stride_d = */ matrix_stride_out,
/* const size_t batch_stride_a = */ A_batch_str,
/* const size_t batch_stride_b = */ B_batch_str,
/* const size_t batch_stride_d = */ matrix_stride_out,
/* const int swizzle_log = */ swizzle_log,
/* const int gemm_k_iterations_aligned = */ (K / bk),
/* const int batch_ndim = */ int(batch_shape.size())};
@ -1355,4 +1354,314 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
return;
}
void BlockSparseMM::eval_gpu(const std::vector<array>& inputs, array& out) {
using namespace mlx::steel;
// assert(inputs.size() == 4);
if (!issubdtype(out.dtype(), floating)) {
throw std::runtime_error(
"[matmul] Does not yet support non-floating point types.");
}
auto& s = stream();
auto& d = metal::device(s.device);
auto& a_pre = inputs[0];
auto& b_pre = inputs[1];
// Return 0s if either input is empty
if (a_pre.size() == 0 || b_pre.size() == 0) {
array zero = array(0, a_pre.dtype());
copy_gpu(zero, out, CopyType::Scalar, s);
auto command_buffer = d.get_command_buffer(s.index);
command_buffer->addCompletedHandler([zero](MTL::CommandBuffer*) {});
return;
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
/////////////////////////////////////////////////////////////////////////////
// Init checks and prep
// Keep a vector with copies to be cleared in the completed buffer to release
// the arrays
std::vector<array> copies;
auto check_transpose = [&copies, &s](const array& arr) {
auto stx = arr.strides()[arr.ndim() - 2];
auto sty = arr.strides()[arr.ndim() - 1];
if (sty == 1) {
return std::make_tuple(false, stx, arr);
} else if (stx == 1) {
return std::make_tuple(true, sty, arr);
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_gpu(arr, arr_copy, CopyType::General, s);
copies.push_back(arr_copy);
size_t stx = arr.shape(-1);
return std::make_tuple(false, stx, arr_copy);
}
};
auto [transpose_a, a_cols, a] = check_transpose(a_pre);
auto [transpose_b, b_cols, b] = check_transpose(b_pre);
int lda = a_cols;
int ldb = b_cols;
int M = a.shape(-2);
int N = b.shape(-1);
int K = a.shape(-1);
/////////////////////////////////////////////////////////////////////////////
// Check and collapse batch dimensions
auto get_batch_dims = [](const auto& v) {
return decltype(v){v.begin(), v.end() - 2};
};
auto& lhs_indices = inputs[2];
auto& rhs_indices = inputs[3];
std::vector<int> batch_shape = get_batch_dims(out.shape());
std::vector<size_t> batch_strides;
batch_strides.insert(
batch_strides.end(),
lhs_indices.strides().begin(),
lhs_indices.strides().end());
size_t lhs_indices_str = batch_strides.empty() ? 0 : batch_strides.back();
batch_strides.insert(
batch_strides.end(),
rhs_indices.strides().begin(),
rhs_indices.strides().end());
size_t rhs_indices_str = batch_strides.empty() ? 0 : batch_strides.back();
int batch_ndim = batch_shape.size();
if (batch_ndim == 0) {
batch_shape = {1};
batch_strides = {0};
}
int batch_ndim_A = a.ndim() - 2;
int batch_ndim_B = b.ndim() - 2;
std::vector<int> operand_batch_ndim = {batch_ndim_A, batch_ndim_B};
std::vector<int> batch_shape_A = get_batch_dims(a.shape());
std::vector<size_t> batch_strides_A = get_batch_dims(a.strides());
std::vector<int> batch_shape_B = get_batch_dims(b.shape());
std::vector<size_t> batch_strides_B = get_batch_dims(b.strides());
if (batch_ndim_A == 0) {
batch_shape_A = {1};
batch_strides_A = {0};
}
if (batch_ndim_B == 0) {
batch_shape_B = {1};
batch_strides_B = {0};
}
size_t matrix_stride_out = size_t(M) * N;
auto batch_size_out = out.size() / matrix_stride_out;
/////////////////////////////////////////////////////////////////////////////
// Gemv specialization
// Route to gemv if needed
if (std::min(M, N) == 1) {
// Collect problem info
bool is_b_matrix = N != 1;
auto& mat = is_b_matrix ? b : a;
auto& vec = is_b_matrix ? a : b;
bool transpose_mat = is_b_matrix ? !transpose_b : transpose_a;
int in_vector_len = K;
int out_vector_len = is_b_matrix ? N : M;
int mat_cols = transpose_mat ? out_vector_len : in_vector_len;
int mat_rows = transpose_mat ? in_vector_len : out_vector_len;
int mat_ld = is_b_matrix ? b_cols : a_cols;
auto batch_strides_mat = is_b_matrix ? batch_strides_B : batch_strides_A;
auto batch_strides_vec = is_b_matrix ? batch_strides_A : batch_strides_B;
auto batch_shape_mat = is_b_matrix ? batch_shape_B : batch_shape_A;
auto batch_shape_vec = is_b_matrix ? batch_shape_A : batch_shape_B;
if (!is_b_matrix) {
batch_strides = rhs_indices.strides();
batch_strides.insert(
batch_strides.end(),
lhs_indices.strides().begin(),
lhs_indices.strides().end());
}
int batch_ndim = batch_shape.size();
// Determine dispatch kernel
int tm = 4, tn = 4;
int bm, bn, n_out_per_tgp;
std::ostringstream kname;
if (transpose_mat) {
bm = 8;
bn = 8;
if (out_vector_len >= 24576) {
bn = 128;
} else if (out_vector_len >= 16384) {
bn = 64;
} else if (out_vector_len >= 8192) {
bn = 16;
}
// Specialized kernel for very small outputs
tn = out_vector_len < tn ? 1 : tn;
n_out_per_tgp = bn * tn;
kname << "gemv_t_bs_" << type_to_name(out);
} else {
bm = out_vector_len >= 4096 ? 8 : 4;
bn = 32;
// Specialized kernel for very small outputs
tm = out_vector_len < tm ? 1 : tm;
n_out_per_tgp = bm * tm;
kname << "gemv_bs_" << type_to_name(out);
}
kname << "_bm" << bm << "_bn" << bn << "_tm" << tm << "_tn" << tn;
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp;
MTL::Size group_dims = MTL::Size(bn, bm, 1);
MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out);
compute_encoder.set_input_array(mat, 0);
compute_encoder.set_input_array(vec, 1);
compute_encoder.set_output_array(out, 3);
compute_encoder->setBytes(&in_vector_len, sizeof(int), 4);
compute_encoder->setBytes(&out_vector_len, sizeof(int), 5);
compute_encoder->setBytes(&mat_ld, sizeof(int), 6);
compute_encoder->setBytes(&batch_ndim, sizeof(int), 9);
set_vector_bytes(compute_encoder, batch_shape, 10);
set_vector_bytes(compute_encoder, batch_strides, 11);
int batch_ndim_vec = batch_shape_vec.size();
compute_encoder->setBytes(&batch_ndim_vec, sizeof(int), 12);
set_vector_bytes(compute_encoder, batch_shape_vec, 13);
set_vector_bytes(compute_encoder, batch_strides_vec, 14);
int batch_ndim_mat = batch_shape_mat.size();
compute_encoder->setBytes(&batch_ndim_mat, sizeof(int), 15);
set_vector_bytes(compute_encoder, batch_shape_mat, 16);
set_vector_bytes(compute_encoder, batch_strides_mat, 17);
compute_encoder.set_input_array(lhs_indices, 18 + int(!is_b_matrix));
compute_encoder.set_input_array(rhs_indices, 18 + int(is_b_matrix));
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
return;
}
/////////////////////////////////////////////////////////////////////////////
// Regular kernel dispatch
// Determine dispatch kernel
int bm = 32, bn = 32, bk = 16;
int wm = 2, wn = 2;
if ((size_t)batch_size_out * M * N >= 1ul << 20) {
if (!transpose_a && transpose_b) {
bm = 64;
bn = (out.dtype() == float32) ? 64 : 32;
bk = (out.dtype() == float32) ? 16 : 32;
} else {
bm = 64;
bn = 64;
}
}
// Prepare kernel name
std::ostringstream kname;
kname << "steel_block_sparse_gemm_" << (transpose_a ? 't' : 'n')
<< (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
<< type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk
<< "_wm" << wm << "_wn" << wn << "_MN_"
<< ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned" << "_K_"
<< ((K % bk == 0) ? "t" : "n") << "aligned";
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
// Use problem size to determine threadblock swizzle
int tn = (N + bn - 1) / bn;
int tm = (M + bm - 1) / bm;
// TODO: Explore device-based tuning for swizzle
int swizzle_log = 0; // tm >= 6 ? 3 : (tm <= 3 ? 0 : 2);
// Prepare steel matmul params
GEMMParams params{
/* const int M = */ M,
/* const int N = */ N,
/* const int K = */ K,
/* const int lda = */ lda,
/* const int ldb = */ ldb,
/* const int ldd = */ N,
/* const int tiles_n = */ tn,
/* const int tiles_m = */ tm,
/* const size_t batch_stride_a = */ lhs_indices_str,
/* const size_t batch_stride_b = */ rhs_indices_str,
/* const size_t batch_stride_d = */ matrix_stride_out,
/* const int swizzle_log = */ swizzle_log,
/* const int gemm_k_iterations_aligned = */ (K / bk),
/* const int batch_ndim = */ batch_ndim};
// Prepare launch grid params
int tile = 1 << swizzle_log;
tm = (tm + tile - 1) / tile;
tn = tn * tile;
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out);
// Launch kernel
compute_encoder.set_input_array(a, 0);
compute_encoder.set_input_array(b, 1);
compute_encoder.set_output_array(out, 3);
compute_encoder->setBytes(&params, sizeof(GEMMParams), 4);
set_vector_bytes(compute_encoder, batch_shape, 6);
set_vector_bytes(compute_encoder, batch_strides, 7);
compute_encoder.set_input_array(lhs_indices, 10);
compute_encoder.set_input_array(rhs_indices, 11);
set_vector_bytes(compute_encoder, batch_shape_A, 12);
set_vector_bytes(compute_encoder, batch_strides_A, 13);
set_vector_bytes(compute_encoder, batch_shape_B, 14);
set_vector_bytes(compute_encoder, batch_strides_B, 15);
set_vector_bytes(compute_encoder, operand_batch_ndim, 16);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
// Clear copies
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
return;
}
} // namespace mlx::core
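As a sanity aid, the kernel name assembled above must match the host_name strings produced by the instantiation macros in the Metal source; a small Python helper mirroring that format (illustrative, not part of the codebase):
def steel_bs_gemm_kernel_name(trans_a, trans_b, itype, otype,
                              M, N, K, bm, bn, bk, wm, wn):
    # e.g. steel_block_sparse_gemm_nt_float16_float16_bm64_bn32_bk32_wm2_wn2_MN_taligned_K_naligned
    mn = "t" if (M % bm == 0 and N % bn == 0) else "n"
    k = "t" if (K % bk == 0) else "n"
    return (
        f"steel_block_sparse_gemm_{'t' if trans_a else 'n'}"
        f"{'t' if trans_b else 'n'}_{itype}_{otype}"
        f"_bm{bm}_bn{bn}_bk{bk}_wm{wm}_wn{wn}"
        f"_MN_{mn}aligned_K_{k}aligned"
    )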

View File

@ -32,6 +32,8 @@ NO_GPU(ArgSort)
NO_GPU(AsType)
NO_GPU(AsStrided)
NO_GPU(BitwiseBinary)
NO_GPU(BlockMaskedMM)
NO_GPU(BlockSparseMM)
NO_GPU(Broadcast)
NO_GPU(Ceil)
NO_GPU_MULTI(Compiled)
@ -100,7 +102,6 @@ NO_GPU(Subtract)
NO_GPU_MULTI(SVD)
NO_GPU(Tan)
NO_GPU(Tanh)
NO_GPU(BlockMaskedMM)
NO_GPU(Transpose)
NO_GPU(Inverse)

View File

@ -3785,6 +3785,124 @@ array block_masked_mm(
return out;
}
/** Compute matrix product with matrix-level gather */
array block_sparse_mm(
array a,
array b,
std::optional<array> lhs_indices_ /* = std::nullopt */,
std::optional<array> rhs_indices_ /* = std::nullopt */,
StreamOrDevice s /* = {} */) {
// If no indices, fall back to full matmul
if (!lhs_indices_ && !rhs_indices_) {
return matmul(a, b, s);
}
// Do shape checks for operands
int in_a_ndim = a.ndim();
int in_b_ndim = b.ndim();
if (a.ndim() == 0 || b.ndim() == 0) {
throw std::invalid_argument(
"[block_sparse_mm] Got 0 dimension input. Inputs must "
"have at least one dimension.");
}
if (a.ndim() == 1) {
// Insert a singleton dim in the beginning
a = reshape(a, {1, -1}, s);
}
if (b.ndim() == 1) {
// Insert a singleton dim at the end
b = reshape(b, {-1, 1}, s);
}
if (a.shape(-1) != b.shape(-2)) {
std::ostringstream msg;
msg << "[block_sparse_mm] Last dimension of first input with shape "
<< a.shape() << " must match second to last dimension of"
<< " second input with shape " << b.shape() << ".";
throw std::invalid_argument(msg.str());
}
// Type promotion
auto out_type = result_type(a, b);
if (!issubdtype(out_type, floating)) {
std::ostringstream msg;
msg << "[block_sparse_mm] Only real floating point types are supported but "
<< a.dtype() << " and " << b.dtype()
<< " were provided which results in " << out_type
<< ", which is not a real floating point type.";
throw std::invalid_argument(msg.str());
}
a = astype(a, out_type, s);
b = astype(b, out_type, s);
// Handle broadcasting
std::vector<int> bsx_a(a.shape().begin(), a.shape().end() - 2);
std::vector<int> bsx_b(b.shape().begin(), b.shape().end() - 2);
auto indices_or_default = [&](const std::optional<array>& indices,
const std::vector<int>& bsx_shape) {
if (indices.has_value()) {
return indices.value();
} else {
int n_batch = 1;
for (auto& i : bsx_shape)
n_batch *= i;
return reshape(arange(n_batch, uint32, s), bsx_shape, s);
}
};
// Pull and broadcast indices
array lhs_indices = indices_or_default(lhs_indices_, bsx_a);
array rhs_indices = indices_or_default(rhs_indices_, bsx_b);
if (!issubdtype(lhs_indices.dtype(), integer)) {
throw std::invalid_argument(
"[block_sparse_mm] Got lhs_indices with invalid dtype. Indices must be integral.");
}
if (!issubdtype(rhs_indices.dtype(), integer)) {
throw std::invalid_argument(
"[block_sparse_mm] Got rhs_indices with invalid dtype. Indices must be integral.");
}
lhs_indices = astype(lhs_indices, uint32, s);
rhs_indices = astype(rhs_indices, uint32, s);
int M = a.shape(-2);
int N = b.shape(-1);
int K = a.shape(-1);
auto out_bsx_shape =
broadcast_shapes(lhs_indices.shape(), rhs_indices.shape());
lhs_indices = broadcast_to(lhs_indices, out_bsx_shape, s);
rhs_indices = broadcast_to(rhs_indices, out_bsx_shape, s);
auto out_shape = out_bsx_shape;
out_shape.push_back(M);
out_shape.push_back(N);
// Calculate the output array
auto out = array(
out_shape,
out_type,
std::make_shared<BlockSparseMM>(to_stream(s)),
{a, b, lhs_indices, rhs_indices});
// Remove the possibly inserted singleton dimensions
if (in_a_ndim == 1 || in_b_ndim == 1) {
out_shape.erase(
out_shape.end() - ((in_a_ndim == 1) ? 2 : 1),
out_shape.end() - ((in_b_ndim == 1) ? 0 : 1));
out = reshape(out, out_shape, s);
}
return out;
}
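A brief Python sketch of the index defaulting and output-shape rule implemented above (helper names are illustrative):
import numpy as np

def default_indices(operand_batch_shape):
    # If an index array is omitted, every matrix in that operand's batch is
    # used once, in order: arange over the flattened batch dimensions.
    n_batch = int(np.prod(operand_batch_shape, dtype=np.int64))
    return np.arange(n_batch, dtype=np.uint32).reshape(operand_batch_shape)

def block_sparse_out_shape(lhs_indices, rhs_indices, M, N):
    # The two index arrays are broadcast against each other, and the output
    # carries one (M, N) matrix per broadcast batch element.
    batch = np.broadcast_shapes(lhs_indices.shape, rhs_indices.shape)
    return tuple(batch) + (M, N)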
array diagonal(
const array& a,
int offset /* = 0 */,

View File

@ -1183,6 +1183,14 @@ array block_masked_mm(
std::optional<array> mask_rhs = std::nullopt,
StreamOrDevice s = {});
/** Compute matrix product with matrix-level gather */
array block_sparse_mm(
array a,
array b,
std::optional<array> lhs_indices = std::nullopt,
std::optional<array> rhs_indices = std::nullopt,
StreamOrDevice s = {});
/** Extract a diagonal or construct a diagonal array */
array diagonal(
const array& a,

View File

@ -3357,7 +3357,61 @@ std::vector<array> BlockMaskedMM::vjp(
vjps.push_back(grad);
} else {
vjps.push_back(zeros_like(primals[arg], stream()));
throw std::invalid_argument(
"[BlockMaskedMM] Cannot calculate VJP with respect to masks.");
}
}
return vjps;
}
std::vector<array> BlockSparseMM::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
std::vector<array> vjps;
auto& cotan = cotangents[0];
auto& lhs_indices = primals[2];
auto& rhs_indices = primals[3];
int M = cotan.shape(-2);
int N = cotan.shape(-1);
int K = primals[0].shape(-1);
for (auto arg : argnums) {
if (arg == 0) {
// M X N * (K X N).T -> M X K
auto base = zeros_like(primals[0], stream());
auto bt = swapaxes(primals[1], -1, -2, stream());
auto base_shape = base.shape();
base = reshape(base, {-1, M, K}, stream());
// g : (out_batch_shape) + (M, K)
auto g = block_sparse_mm(cotan, bt, std::nullopt, rhs_indices, stream());
g = expand_dims(g, -3, stream());
auto gacc = scatter_add(base, lhs_indices, g, 0, stream());
vjps.push_back(reshape(gacc, base_shape, stream()));
} else if (arg == 1) {
// (M X K).T * M X N -> K X N
auto base = zeros_like(primals[1], stream());
auto at = swapaxes(primals[0], -1, -2, stream());
auto base_shape = base.shape();
base = reshape(base, {-1, K, N}, stream());
// g : (out_batch_shape) + (K, N)
auto g = block_sparse_mm(at, cotan, lhs_indices, std::nullopt, stream());
g = expand_dims(g, -3, stream());
auto gacc = scatter_add(base, rhs_indices, g, 0, stream());
vjps.push_back(reshape(gacc, base_shape, stream()));
} else {
throw std::invalid_argument(
"[BlockSparseMM] Cannot calculate VJP with respect to indices.");
}
}
return vjps;
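To make the two branches above concrete, here is a NumPy sketch of the accumulation for the first operand (arg == 0), assuming contiguous inputs and 1-D index arrays (names are illustrative):
import numpy as np

def block_sparse_mm_vjp_a(cotan, a, b, lhs_indices, rhs_indices):
    # dA[i] accumulates cotan[j] @ b[rhs_indices[j]].T over every j with
    # lhs_indices[j] == i, i.e. a scatter-add back into a's batch.
    M, K = a.shape[-2], a.shape[-1]
    da = np.zeros_like(a).reshape((-1, M, K))
    b_flat = b.reshape((-1,) + b.shape[-2:])
    g = cotan @ np.swapaxes(b_flat[rhs_indices], -1, -2)  # (batch, M, K)
    np.add.at(da, lhs_indices, g)
    return da.reshape(a.shape)
The arg == 1 branch is symmetric: it scatters a^T @ cotan into b's batch using rhs_indices.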

View File

@ -485,6 +485,26 @@ class BlockMaskedMM : public UnaryPrimitive {
void eval(const std::vector<array>& inputs, array& out);
};
class BlockSparseMM : public UnaryPrimitive {
public:
explicit BlockSparseMM(Stream stream) : UnaryPrimitive(stream) {};
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
std::vector<array> vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>& outputs) override;
DEFINE_PRINT(BlockSparseMM)
DEFINE_DEFAULT_IS_EQUIVALENT()
private:
void eval(const std::vector<array>& inputs, array& out);
};
class Broadcast : public UnaryPrimitive {
public:
explicit Broadcast(Stream stream, const std::vector<int>& shape)

View File

@ -3716,6 +3716,38 @@ void init_ops(nb::module_& m) {
mask_lhs (array, optional): Boolean mask for a (default: ``None``)
mask_rhs (array, optional): Boolean mask for b (default: ``None``)
)pbdoc");
m.def(
"block_sparse_mm",
&block_sparse_mm,
nb::arg(),
nb::arg(),
"lhs_indices"_a = nb::none(),
"rhs_indices"_a = nb::none(),
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def block_sparse_mm(a: array, b: array, /, lhs_indices: array, rhs_indices: array, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Matrix multiplication with matrix-level gather.
Performs a gather of the operands with the given indices followed by a (possibly batched) matrix multiplication of two arrays.
This operation is more efficient than explicitly applying a :func:`take` followed by a :func:`matmul`.
The indices ``lhs_indices`` and ``rhs_indices`` contain flat indices along the batch dimensions (i.e. all but the last two dimensions) of ``a`` and ``b`` respectively.
For ``a`` with shape ``(A1, A2, ..., AS, M, K)``,
``lhs_indices`` contains indices from the range ``[0, A1 * A2 * ... * AS)``.
For ``b`` with shape ``(B1, B2, ..., BS, K, N)``,
``rhs_indices`` contains indices from the range ``[0, B1 * B2 * ... * BS)``.
Args:
a (array): Input array.
b (array): Input array.
lhs_indices (array, optional): Integer indices for ``a`` (default: ``None``)
rhs_indices (array, optional): Integer indices for ``b`` (default: ``None``)
)pbdoc");
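For illustration, a minimal usage sketch of the binding above (shapes are arbitrary):
import mlx.core as mx

a = mx.random.normal((4, 16, 32))   # 4 matrices of shape (16, 32)
b = mx.random.normal((2, 32, 8))    # 2 matrices of shape (32, 8)
lhs_indices = mx.array([0, 3])      # flat indices into a's batch, in [0, 4)
rhs_indices = mx.array([1, 1])      # flat indices into b's batch, in [0, 2)
out = mx.block_sparse_mm(a, b, lhs_indices, rhs_indices)
# out.shape == (2, 16, 8); out[0] == a[0] @ b[1], out[1] == a[3] @ b[1]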
m.def(
"diagonal",

View File

@ -810,6 +810,199 @@ class TestBlas(mlx_tests.MLXTestCase):
self.assertTrue(np.allclose(c_mx, c_np, atol=1e-5))
def test_block_sparse_matmul(self):
def np_block_sparse_mm(a, b, lhs_indices=None, rhs_indices=None):
a = a.reshape((-1, a.shape[-2], a.shape[-1]))
b = b.reshape((-1, b.shape[-2], b.shape[-1]))
lhs_indices = lhs_indices or np.arange(a.shape[0])
rhs_indices = rhs_indices or np.arange(b.shape[0])
a = a[lhs_indices, :, :]
b = b[rhs_indices, :, :]
out = a @ b
return out
def test_shape(
M,
N,
K,
np_dtype=np.float32,
batch_A=(),
batch_B=(),
lhs_indices=None,
rhs_indices=None,
):
with self.subTest(
M=M,
N=N,
K=K,
np_dtype=np_dtype,
batch_A=batch_A,
batch_B=batch_B,
lhs_indices=lhs_indices,
rhs_indices=rhs_indices,
):
a_np = np.random.normal(size=batch_A + (M, K)).astype(np_dtype)
b_np = np.random.normal(size=batch_B + (K, N)).astype(np_dtype)
a_mx = mx.array(a_np)
b_mx = mx.array(b_np)
out_np = np_block_sparse_mm(a_np, b_np, lhs_indices, rhs_indices)
lhs_indices_mx = None if lhs_indices is None else mx.array(lhs_indices)
rhs_indices_mx = None if rhs_indices is None else mx.array(rhs_indices)
out_mx = mx.block_sparse_mm(a_mx, b_mx, lhs_indices_mx, rhs_indices_mx)
self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5))
inputs = (
{
"batch_A": (1,),
"lhs_indices": (0,),
"batch_B": (3,),
"rhs_indices": (2, 1),
},
{
"batch_A": (1,),
"lhs_indices": None,
"batch_B": (3,),
"rhs_indices": (2, 1),
},
{
"batch_A": (2,),
"lhs_indices": None,
"batch_B": (3,),
"rhs_indices": (2, 1),
},
{
"batch_A": (3,),
"lhs_indices": (0, 2),
"batch_B": (1,),
"rhs_indices": (0,),
},
{
"batch_A": (5,),
"lhs_indices": (0, 2),
"batch_B": (3,),
"rhs_indices": (2, 1),
},
{
"batch_A": (4, 2),
"lhs_indices": (
(7, 6),
(5, 4),
(1, 2),
),
"batch_B": (4, 1),
"rhs_indices": ((2,), (0,), (1,)),
},
)
for kwargs in inputs:
test_shape(32, 32, 32, **kwargs)
test_shape(16, 1, 16, **kwargs)
# Add tests for broadcasting
a_np = np.random.normal(size=(5, 32, 32)).astype(np.float32)
b_np = np.random.normal(size=(3, 32, 32)).astype(np.float32)
a_mx = mx.array(a_np)
b_mx = mx.array(b_np)
# Numpy
a_np = a_np.reshape((5, 1, 32, 32))
b_np = b_np.reshape((1, 3, 32, 32))
a_np = np.broadcast_to(a_np, (5, 4, 32, 32))
b_np = np.broadcast_to(b_np, (2, 3, 32, 32)).swapaxes(1, 0)
lhs_indices = [0, 13, 12]
rhs_indices = [0, 3, 5]
out_np = np_block_sparse_mm(a_np, b_np, lhs_indices, rhs_indices)
# MLX
a_mx = a_mx.reshape((5, 1, 32, 32))
b_mx = b_mx.reshape((1, 3, 32, 32))
a_mx = mx.broadcast_to(a_mx, (5, 4, 32, 32))
b_mx = mx.broadcast_to(b_mx, (2, 3, 32, 32)).swapaxes(1, 0)
lhs_indices_mx = mx.array(lhs_indices)
rhs_indices_mx = mx.array(rhs_indices)
out_mx = mx.block_sparse_mm(a_mx, b_mx, lhs_indices_mx, rhs_indices_mx)
self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5))
# Gemv test
a_np = np.random.normal(size=(5, 1, 32)).astype(np.float32)
b_np = np.random.normal(size=(3, 16, 32)).astype(np.float32)
a_mx = mx.array(a_np)
b_mx = mx.array(b_np)
lhs_indices = [3, 1]
rhs_indices = [0, 2]
b_np_t = np.swapaxes(b_np, -1, -2)
out_np = np_block_sparse_mm(a_np, b_np_t, lhs_indices, rhs_indices)
lhs_indices_mx = mx.array(lhs_indices)
rhs_indices_mx = mx.array(rhs_indices)
b_mx_t = mx.swapaxes(b_mx, -1, -2)
out_mx = mx.block_sparse_mm(a_mx, b_mx_t, lhs_indices_mx, rhs_indices_mx)
self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5))
def test_block_sparse_matmul_grad(self):
lhs_indices = mx.array([[7, 6], [4, 1], [0, 2]], dtype=mx.uint32)
rhs_indices = mx.array([[2], [0], [1]], dtype=mx.uint32)
def f_ref(a, b):
lhs_indices_ = mx.broadcast_to(lhs_indices, (3, 2))
rhs_indices_ = mx.broadcast_to(rhs_indices, (3, 2))
M = a.shape[-2]
N = b.shape[-1]
K = a.shape[-1]
a = a.reshape((-1, M, K))
b = b.reshape((-1, K, N))
a = mx.take(a, lhs_indices_, 0)
b = mx.take(b, rhs_indices_, 0)
return a @ b
def f_test(a, b):
return mx.block_sparse_mm(a, b, lhs_indices, rhs_indices)
a_mx = mx.random.normal((4, 2, 32, 32))
b_mx = mx.random.normal((4, 1, 32, 32))
out_test = f_test(a_mx, b_mx)
out_ref = f_ref(a_mx, b_mx)
self.assertTrue(mx.allclose(out_test, out_ref, atol=1e-5))
cotan = mx.ones_like(out_test)
out_ref, dout_ref = mx.vjp(
f_ref,
[a_mx, b_mx],
[cotan],
)
out_test, dout_test = mx.vjp(
f_test,
[a_mx, b_mx],
[cotan],
)
for r, t in zip(dout_ref, dout_test):
self.assertEqual(r.shape, t.shape)
self.assertTrue(mx.allclose(r, t, atol=1e-4).item())
if __name__ == "__main__":
unittest.main()