Block sparse mm (#1058)
This commit is contained in:
parent
17f57df797
commit
f390957685
@ -32,6 +32,7 @@ Operations
|
||||
bitwise_or
|
||||
bitwise_xor
|
||||
block_masked_mm
|
||||
block_sparse_mm
|
||||
broadcast_to
|
||||
ceil
|
||||
clip
|
||||
|
@ -32,6 +32,7 @@ DEFAULT(ArgReduce)
|
||||
DEFAULT(ArgSort)
|
||||
DEFAULT(AsStrided)
|
||||
DEFAULT(BlockMaskedMM)
|
||||
DEFAULT(BlockSparseMM)
|
||||
DEFAULT(Broadcast)
|
||||
DEFAULT(Ceil)
|
||||
DEFAULT(Concatenate)
|
||||
|
@ -42,6 +42,7 @@ DEFAULT(AsType)
|
||||
DEFAULT(AsStrided)
|
||||
DEFAULT(Broadcast)
|
||||
DEFAULT(BlockMaskedMM)
|
||||
DEFAULT(BlockSparseMM)
|
||||
DEFAULT_MULTI(DivMod)
|
||||
DEFAULT(Ceil)
|
||||
DEFAULT(Concatenate)
|
||||
|
@ -190,4 +190,91 @@ void BlockMaskedMM::eval(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
void BlockSparseMM::eval(const std::vector<array>& inputs, array& out) {
|
||||
if (out.dtype() != float32) {
|
||||
throw std::runtime_error(
|
||||
"[BlockSparseMM::eval] Currently only supports float32.");
|
||||
}
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
auto& a_pre = inputs[0];
|
||||
auto& b_pre = inputs[1];
|
||||
|
||||
auto check_transpose = [](const array& arr) {
|
||||
auto stx = arr.strides()[arr.ndim() - 2];
|
||||
auto sty = arr.strides()[arr.ndim() - 1];
|
||||
if (stx == arr.shape(-1) && sty == 1) {
|
||||
return std::make_tuple(false, stx, arr);
|
||||
} else if (stx == 1 && sty == arr.shape(-2)) {
|
||||
return std::make_tuple(true, sty, arr);
|
||||
} else {
|
||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||
copy(arr, arr_copy, CopyType::General);
|
||||
size_t stx = arr.shape(-1);
|
||||
return std::make_tuple(false, stx, arr_copy);
|
||||
}
|
||||
};
|
||||
|
||||
auto [a_transposed, lda, a] = check_transpose(a_pre);
|
||||
auto [b_transposed, ldb, b] = check_transpose(b_pre);
|
||||
|
||||
size_t M = a.shape(-2);
|
||||
size_t N = b.shape(-1);
|
||||
size_t K = a.shape(-1);
|
||||
|
||||
if (M == 0 || N == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (K == 0) {
|
||||
std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
|
||||
return;
|
||||
}
|
||||
|
||||
// Get batch dims
|
||||
auto batch_size_out = out.size() / (M * N);
|
||||
size_t matrix_stride_out = M * N;
|
||||
|
||||
auto get_batch_dims = [](const auto& v) {
|
||||
return decltype(v){v.begin(), v.end() - 2};
|
||||
};
|
||||
|
||||
auto& lhs_indices = inputs[2];
|
||||
auto& rhs_indices = inputs[3];
|
||||
|
||||
std::vector<int> batch_shape = get_batch_dims(out.shape());
|
||||
int batch_ndim = batch_shape.size();
|
||||
|
||||
std::vector<int> batch_shape_A = get_batch_dims(a.shape());
|
||||
std::vector<size_t> batch_strides_A = get_batch_dims(a.strides());
|
||||
std::vector<int> batch_shape_B = get_batch_dims(b.shape());
|
||||
std::vector<size_t> batch_strides_B = get_batch_dims(b.strides());
|
||||
|
||||
const uint32_t* lhs_indices_ptr = lhs_indices.data<uint32_t>();
|
||||
const uint32_t* rhs_indices_ptr = rhs_indices.data<uint32_t>();
|
||||
|
||||
for (int i = 0; i < batch_size_out; i++) {
|
||||
// Get index
|
||||
uint32_t indx_A = lhs_indices_ptr[elem_to_loc(i, lhs_indices)];
|
||||
uint32_t indx_B = rhs_indices_ptr[elem_to_loc(i, rhs_indices)];
|
||||
|
||||
cblas_sgemm(
|
||||
CblasRowMajor,
|
||||
a_transposed ? CblasTrans : CblasNoTrans, // transA
|
||||
b_transposed ? CblasTrans : CblasNoTrans, // transB
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
1.0f, // alpha
|
||||
a.data<float>() + elem_to_loc(indx_A, batch_shape_A, batch_strides_A),
|
||||
lda,
|
||||
b.data<float>() + elem_to_loc(indx_B, batch_shape_B, batch_strides_B),
|
||||
ldb,
|
||||
0.0f, // beta
|
||||
out.data<float>() + matrix_stride_out * i,
|
||||
out.shape(-1) // ldc
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
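The CPU eval above gathers one matrix from a and one from b per output batch element (via lhs_indices / rhs_indices) and runs one float32 GEMM per pair. A minimal NumPy sketch of the same semantics, assuming the index arrays have already been broadcast to the output batch shape (the op does this itself); the names here are illustrative, not part of the MLX API:

import numpy as np

def block_sparse_mm_ref(a, b, lhs_indices, rhs_indices):
    M, K = a.shape[-2], a.shape[-1]
    N = b.shape[-1]
    a_flat = a.reshape(-1, M, K)      # collapse the batch dims of a
    b_flat = b.reshape(-1, K, N)      # collapse the batch dims of b
    batch_shape = lhs_indices.shape   # == rhs_indices.shape after broadcasting
    lhs = lhs_indices.reshape(-1)
    rhs = rhs_indices.reshape(-1)
    # one GEMM per output batch element, mirroring the cblas_sgemm loop above
    out = np.stack([a_flat[i] @ b_flat[j] for i, j in zip(lhs, rhs)])
    return out.reshape(batch_shape + (M, N))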
|
@ -7,6 +7,8 @@
|
||||
#include "mlx/backend/metal/kernels/defines.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
|
||||
#include "mlx/backend/metal/kernels/steel/utils.h"
|
||||
|
||||
using namespace metal;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
@ -33,22 +35,19 @@ struct GEMVKernel {
|
||||
// - We assume each thead group is launched with (BN, BM, 1) threads
|
||||
//
|
||||
// 1. A thread loads TN elements each from mat along TM contiguous rows
|
||||
// and the corresponding scalar from the vector
|
||||
// and the corresponding scalar from the vector
|
||||
// 2. The thread then multiplies and adds to accumulate its local result for
|
||||
// the block
|
||||
// the block
|
||||
// 3. At the end, each thread has accumulated results over all blocks across
|
||||
// the rows
|
||||
// These are then summed up across the threadgroup
|
||||
// the rows. These are then summed up across the threadgroup
|
||||
// 4. Each threadgroup writes its accumulated BN * TN outputs
|
||||
//
|
||||
// Edge case handling:
|
||||
// - The threadgroup with the largest tid will have blocks that exceed the
|
||||
// matrix
|
||||
// - The threadgroup with the largest tid has blocks that exceed the matrix
|
||||
// * The blocks that start outside the matrix are never read (thread results
|
||||
// remain zero)
|
||||
// remain zero)
|
||||
// * The last thread that partially overlaps with the matrix is shifted
|
||||
// inwards
|
||||
// such that the thread block fits exactly in the matrix
|
||||
// inwards such that the thread block fits exactly in the matrix
|
||||
|
||||
MLX_MTL_CONST short tgp_mem_size = BN * TN * 2;
|
||||
|
||||
@ -100,14 +99,14 @@ struct GEMVKernel {
|
||||
if (simd_gid == 0) {
|
||||
// Main load loop
|
||||
if (bn + TN <= in_vec_size) {
|
||||
#pragma clang loop unroll(full)
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int tn = 0; tn < TN; tn++) {
|
||||
in_vec_block[tn] = in_vec[bn + tn];
|
||||
}
|
||||
|
||||
} else { // Edgecase
|
||||
|
||||
#pragma clang loop unroll(full)
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int tn = 0; tn < TN; tn++) {
|
||||
in_vec_block[tn] = bn + tn < in_vec_size ? in_vec[bn + tn] : 0;
|
||||
}
|
||||
@ -116,24 +115,24 @@ struct GEMVKernel {
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Load for all rows
|
||||
#pragma clang loop unroll(full)
|
||||
// Load for all rows
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int tn = 0; tn < TN; tn++) {
|
||||
v_coeff[tn] = in_vec_block[tn];
|
||||
}
|
||||
|
||||
// Per thread work loop
|
||||
#pragma clang loop unroll(full)
|
||||
// Per thread work loop
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int tm = 0; tm < TM; tm++) {
|
||||
// Load for the row
|
||||
if (bn + TN <= in_vec_size) {
|
||||
#pragma clang loop unroll(full)
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int tn = 0; tn < TN; tn++) {
|
||||
inter[tn] = mat[tm * marix_ld + bn + tn];
|
||||
}
|
||||
|
||||
} else { // Edgecase
|
||||
#pragma clang loop unroll(full)
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int tn = 0; tn < TN; tn++) {
|
||||
int col_idx =
|
||||
(bn + tn) < in_vec_size ? (bn + tn) : (in_vec_size - 1);
|
||||
@ -142,21 +141,22 @@ struct GEMVKernel {
|
||||
}
|
||||
|
||||
// Accumulate results
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int tn = 0; tn < TN; tn++) {
|
||||
result[tm] += inter[tn] * v_coeff[tn];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Simdgroup accumulations
|
||||
#pragma clang loop unroll(full)
|
||||
// Simdgroup accumulations
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int tm = 0; tm < TM; tm++) {
|
||||
result[tm] = simd_sum(result[tm]);
|
||||
}
|
||||
|
||||
// Write outputs
|
||||
if (simd_lid == 0) {
|
||||
#pragma clang loop unroll(full)
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int tm = 0; tm < TM; tm++) {
|
||||
if (kDoAxpby) {
|
||||
out_vec[out_row + tm] = static_cast<T>(alpha) * result[tm] +
|
||||
@ -187,22 +187,18 @@ struct GEMVTKernel {
|
||||
// - We assume each thead group is launched with (BN, BM, 1) threads
|
||||
//
|
||||
// 1. A thread loads TN elements each from mat along TM contiguous rows
|
||||
// and the corresponding scalar from the vector
|
||||
// 2. The thread then multiplies and adds to accumulate its local result for
|
||||
// the block
|
||||
// and the corresponding scalar from the vector
|
||||
// 2. The thread then accumulates its local result for the block
|
||||
// 3. At the end, each thread has accumulated results over all blocks across
|
||||
// the rows
|
||||
// These are then summed up across the threadgroup
|
||||
// the rows. These are then summed up across the threadgroup
|
||||
// 4. Each threadgroup writes its accumulated BN * TN outputs
|
||||
//
|
||||
// Edge case handling:
|
||||
// - The threadgroup with the largest tid will have blocks that exceed the
|
||||
// matrix
|
||||
// - The threadgroup with the largest tid has blocks that exceed the matrix
|
||||
// * The blocks that start outside the matrix are never read (thread results
|
||||
// remain zero)
|
||||
// remain zero)
|
||||
// * The last thread that partially overlaps with the matrix is shifted
|
||||
// inwards
|
||||
// such that the thread block fits exactly in the matrix
|
||||
// inwards such that the thread block fits exactly in the matrix
|
||||
|
||||
MLX_MTL_CONST short tgp_mem_size = BN * BM * TN;
|
||||
|
||||
@ -249,12 +245,12 @@ struct GEMVTKernel {
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
if (bm + TM <= in_vec_size) {
|
||||
#pragma clang loop unroll(full)
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int tm = 0; tm < TM; tm++) {
|
||||
v_coeff[tm] = in_vec[bm + tm];
|
||||
}
|
||||
|
||||
#pragma clang loop unroll(full)
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int tm = 0; tm < TM; tm++) {
|
||||
for (int tn = 0; tn < TN; tn++) {
|
||||
inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn];
|
||||
@ -281,7 +277,7 @@ struct GEMVTKernel {
|
||||
|
||||
// Threadgroup collection
|
||||
|
||||
#pragma clang loop unroll(full)
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int i = 0; i < TN; i++) {
|
||||
tgp_results[lid.y * TN + i] = result[i];
|
||||
}
|
||||
@ -290,15 +286,15 @@ struct GEMVTKernel {
|
||||
|
||||
// Threadgroup accumulation and writing out results
|
||||
if (lid.y == 0 && out_col < out_vec_size) {
|
||||
#pragma clang loop unroll(full)
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int i = 1; i < BM; i++) {
|
||||
#pragma clang loop unroll(full)
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int j = 0; j < TN; j++) {
|
||||
result[j] += tgp_results[i * TN + j];
|
||||
}
|
||||
}
|
||||
|
||||
#pragma clang loop unroll(full)
|
||||
MLX_MTL_PRAGMA_UNROLL
|
||||
for (int j = 0; j < TN; j++) {
|
||||
if (kDoAxpby) {
|
||||
out_vec[out_col + j] = static_cast<T>(alpha) * result[j] +
|
||||
@ -408,20 +404,149 @@ template <
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||
|
||||
#define instantiate_gemv(name, itype, bm, bn, tm, tn) \
|
||||
instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 0, 0) \
|
||||
instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 0, 1) \
|
||||
instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 1, 0) \
|
||||
instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 1, 1)
|
||||
// clang-format off
|
||||
#define instantiate_gemv(name, itype, bm, bn, tm, tn) \
|
||||
instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 0, 0) \
|
||||
instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 0, 1) \
|
||||
instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 1, 0) \
|
||||
instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 1, 1) // clang-format on
|
||||
|
||||
#define instantiate_gemv_blocks(name, itype) \
|
||||
instantiate_gemv(name, itype, 4, 32, 1, 4) instantiate_gemv( \
|
||||
name, itype, 4, 32, 4, 4) instantiate_gemv(name, itype, 8, 32, 4, 4)
|
||||
// clang-format off
|
||||
#define instantiate_gemv_blocks(name, itype) \
|
||||
instantiate_gemv(name, itype, 4, 32, 1, 4) \
|
||||
instantiate_gemv(name, itype, 4, 32, 4, 4) \
|
||||
instantiate_gemv(name, itype, 8, 32, 4, 4) // clang-format on
|
||||
|
||||
instantiate_gemv_blocks(float32, float);
|
||||
instantiate_gemv_blocks(float16, half);
|
||||
instantiate_gemv_blocks(bfloat16, bfloat16_t);
|
||||
|
||||
template <
|
||||
typename T,
|
||||
const int BM, /* Threadgroup rows (in threads) */
|
||||
const int BN, /* Threadgroup cols (in threads) */
|
||||
const int TM, /* Thread rows (in elements) */
|
||||
const int TN> /* Thread cols (in elements) */
|
||||
[[kernel, max_total_threads_per_threadgroup(BM* BN)]] void gemv_bs(
|
||||
const device T* mat [[buffer(0)]],
|
||||
const device T* in_vec [[buffer(1)]],
|
||||
const device T* bias [[buffer(2)]],
|
||||
device T* out_vec [[buffer(3)]],
|
||||
const constant int& in_vec_size [[buffer(4)]],
|
||||
const constant int& out_vec_size [[buffer(5)]],
|
||||
const constant int& marix_ld [[buffer(6)]],
|
||||
const constant float& alpha [[buffer(7)]],
|
||||
const constant float& beta [[buffer(8)]],
|
||||
const constant int& batch_ndim [[buffer(9)]],
|
||||
const constant int* batch_shape [[buffer(10)]],
|
||||
const constant size_t* index_batch_strides [[buffer(11)]],
|
||||
const constant int& vector_batch_ndim [[buffer(12)]],
|
||||
const constant int* vector_batch_shape [[buffer(13)]],
|
||||
const constant size_t* vector_batch_stride [[buffer(14)]],
|
||||
const constant int& matrix_batch_ndim [[buffer(15)]],
|
||||
const constant int* matrix_batch_shape [[buffer(16)]],
|
||||
const constant size_t* matrix_batch_stride [[buffer(17)]],
|
||||
const constant uint32_t* vec_indices [[buffer(18)]],
|
||||
const constant uint32_t* mat_indices [[buffer(19)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
using gemv_kernel = GEMVKernel<T, BM, BN, TM, TN, false>;
|
||||
threadgroup T tgp_memory[gemv_kernel::tgp_mem_size];
|
||||
|
||||
uint32_t indx_vec;
|
||||
uint32_t indx_mat;
|
||||
|
||||
// Update batch offsets
|
||||
if (batch_ndim > 1) {
|
||||
const constant size_t* veci_bstrides = index_batch_strides;
|
||||
const constant size_t* mati_bstrides = index_batch_strides + batch_ndim;
|
||||
|
||||
ulong2 batch_offsets = elem_to_loc_broadcast(
|
||||
tid.z, batch_shape, veci_bstrides, mati_bstrides, batch_ndim);
|
||||
|
||||
indx_vec = vec_indices[batch_offsets.x];
|
||||
indx_mat = mat_indices[batch_offsets.y];
|
||||
|
||||
} else {
|
||||
indx_vec = vec_indices[index_batch_strides[0] * tid.z];
|
||||
indx_mat = mat_indices[index_batch_strides[batch_ndim] * tid.z];
|
||||
}
|
||||
|
||||
if (vector_batch_ndim > 1) {
|
||||
in_vec += elem_to_loc(
|
||||
indx_vec, vector_batch_shape, vector_batch_stride, vector_batch_ndim);
|
||||
} else {
|
||||
in_vec += indx_vec * vector_batch_stride[0];
|
||||
}
|
||||
|
||||
if (matrix_batch_ndim > 1) {
|
||||
mat += elem_to_loc(
|
||||
indx_mat, matrix_batch_shape, matrix_batch_stride, matrix_batch_ndim);
|
||||
} else {
|
||||
mat += indx_mat * matrix_batch_stride[0];
|
||||
}
|
||||
|
||||
out_vec += tid.z * out_vec_size;
|
||||
|
||||
gemv_kernel::run(
|
||||
mat,
|
||||
in_vec,
|
||||
bias,
|
||||
out_vec,
|
||||
in_vec_size,
|
||||
out_vec_size,
|
||||
marix_ld,
|
||||
alpha,
|
||||
beta,
|
||||
batch_ndim, // Not used
|
||||
tgp_memory,
|
||||
tid,
|
||||
lid,
|
||||
simd_gid,
|
||||
simd_lid);
|
||||
}
|
||||
|
||||
#define instantiate_gemv_bs_helper(nm, itype, bm, bn, tm, tn) \
|
||||
template [[host_name("gemv_bs_" #nm "_bm" #bm "_bn" #bn "_tm" #tm \
|
||||
"_tn" #tn)]] [[kernel]] void \
|
||||
gemv_bs<itype, bm, bn, tm, tn>( \
|
||||
const device itype* mat [[buffer(0)]], \
|
||||
const device itype* in_vec [[buffer(1)]], \
|
||||
const device itype* bias [[buffer(2)]], \
|
||||
device itype* out_vec [[buffer(3)]], \
|
||||
const constant int& in_vec_size [[buffer(4)]], \
|
||||
const constant int& out_vec_size [[buffer(5)]], \
|
||||
const constant int& marix_ld [[buffer(6)]], \
|
||||
const constant float& alpha [[buffer(7)]], \
|
||||
const constant float& beta [[buffer(8)]], \
|
||||
const constant int& batch_ndim [[buffer(9)]], \
|
||||
const constant int* batch_shape [[buffer(10)]], \
|
||||
const constant size_t* index_batch_strides [[buffer(11)]], \
|
||||
const constant int& vector_batch_ndim [[buffer(12)]], \
|
||||
const constant int* vector_batch_shape [[buffer(13)]], \
|
||||
const constant size_t* vector_batch_stride [[buffer(14)]], \
|
||||
const constant int& matrix_batch_ndim [[buffer(15)]], \
|
||||
const constant int* matrix_batch_shape [[buffer(16)]], \
|
||||
const constant size_t* matrix_batch_stride [[buffer(17)]], \
|
||||
const constant uint32_t* vec_indices [[buffer(18)]], \
|
||||
const constant uint32_t* mat_indices [[buffer(19)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]], \
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||
|
||||
// clang-format off
|
||||
#define instantiate_gemv_bs_blocks(name, itype) \
|
||||
instantiate_gemv_bs_helper(name, itype, 4, 32, 1, 4) \
|
||||
instantiate_gemv_bs_helper(name, itype, 4, 32, 4, 4) \
|
||||
instantiate_gemv_bs_helper(name, itype, 8, 32, 4, 4) // clang-format on
|
||||
|
||||
instantiate_gemv_bs_blocks(float32, float);
|
||||
instantiate_gemv_bs_blocks(float16, half);
|
||||
instantiate_gemv_bs_blocks(bfloat16, bfloat16_t);
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
/// Vector matrix multiplication
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
@ -539,3 +664,133 @@ template <
|
||||
instantiate_gemv_t_blocks(float32, float);
|
||||
instantiate_gemv_t_blocks(float16, half);
|
||||
instantiate_gemv_t_blocks(bfloat16, bfloat16_t); // clang-format on
|
||||
|
||||
template <
|
||||
typename T,
|
||||
const int BM, /* Threadgroup rows (in threads) */
|
||||
const int BN, /* Threadgroup cols (in threads) */
|
||||
const int TM, /* Thread rows (in elements) */
|
||||
const int TN> /* Thread cols (in elements) */
|
||||
[[kernel, max_total_threads_per_threadgroup(BM* BN)]] void gemv_t_bs(
|
||||
const device T* mat [[buffer(0)]],
|
||||
const device T* in_vec [[buffer(1)]],
|
||||
const device T* bias [[buffer(2)]],
|
||||
device T* out_vec [[buffer(3)]],
|
||||
const constant int& in_vec_size [[buffer(4)]],
|
||||
const constant int& out_vec_size [[buffer(5)]],
|
||||
const constant int& marix_ld [[buffer(6)]],
|
||||
const constant float& alpha [[buffer(7)]],
|
||||
const constant float& beta [[buffer(8)]],
|
||||
const constant int& batch_ndim [[buffer(9)]],
|
||||
const constant int* batch_shape [[buffer(10)]],
|
||||
const constant size_t* index_batch_strides [[buffer(11)]],
|
||||
const constant int& vector_batch_ndim [[buffer(12)]],
|
||||
const constant int* vector_batch_shape [[buffer(13)]],
|
||||
const constant size_t* vector_batch_stride [[buffer(14)]],
|
||||
const constant int& matrix_batch_ndim [[buffer(15)]],
|
||||
const constant int* matrix_batch_shape [[buffer(16)]],
|
||||
const constant size_t* matrix_batch_stride [[buffer(17)]],
|
||||
const constant uint32_t* vec_indices [[buffer(18)]],
|
||||
const constant uint32_t* mat_indices [[buffer(19)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
using gemv_kernel = GEMVTKernel<T, BM, BN, TM, TN, false>;
|
||||
threadgroup T tgp_memory[gemv_kernel::tgp_mem_size];
|
||||
|
||||
uint32_t indx_vec;
|
||||
uint32_t indx_mat;
|
||||
|
||||
// Update batch offsets
|
||||
if (batch_ndim > 1) {
|
||||
const constant size_t* veci_bstrides = index_batch_strides;
|
||||
const constant size_t* mati_bstrides = index_batch_strides + batch_ndim;
|
||||
|
||||
ulong2 batch_offsets = elem_to_loc_broadcast(
|
||||
tid.z, batch_shape, veci_bstrides, mati_bstrides, batch_ndim);
|
||||
|
||||
indx_vec = vec_indices[batch_offsets.x];
|
||||
indx_mat = mat_indices[batch_offsets.y];
|
||||
|
||||
} else {
|
||||
indx_vec = vec_indices[index_batch_strides[0] * tid.z];
|
||||
indx_mat = mat_indices[index_batch_strides[batch_ndim] * tid.z];
|
||||
}
|
||||
|
||||
if (vector_batch_ndim > 1) {
|
||||
in_vec += elem_to_loc(
|
||||
indx_vec, vector_batch_shape, vector_batch_stride, vector_batch_ndim);
|
||||
} else {
|
||||
in_vec += indx_vec * vector_batch_stride[0];
|
||||
}
|
||||
|
||||
if (matrix_batch_ndim > 1) {
|
||||
mat += elem_to_loc(
|
||||
indx_mat, matrix_batch_shape, matrix_batch_stride, matrix_batch_ndim);
|
||||
} else {
|
||||
mat += indx_mat * matrix_batch_stride[0];
|
||||
}
|
||||
|
||||
out_vec += tid.z * out_vec_size;
|
||||
|
||||
gemv_kernel::run(
|
||||
mat,
|
||||
in_vec,
|
||||
bias,
|
||||
out_vec,
|
||||
in_vec_size,
|
||||
out_vec_size,
|
||||
marix_ld,
|
||||
alpha,
|
||||
beta,
|
||||
batch_ndim, // Not used
|
||||
tgp_memory,
|
||||
tid,
|
||||
lid,
|
||||
simd_gid,
|
||||
simd_lid);
|
||||
}
|
||||
|
||||
#define instantiate_gemv_t_bs_helper(nm, itype, bm, bn, tm, tn) \
|
||||
template [[host_name("gemv_t_bs_" #nm "_bm" #bm "_bn" #bn "_tm" #tm \
|
||||
"_tn" #tn)]] [[kernel]] void \
|
||||
gemv_t_bs<itype, bm, bn, tm, tn>( \
|
||||
const device itype* mat [[buffer(0)]], \
|
||||
const device itype* in_vec [[buffer(1)]], \
|
||||
const device itype* bias [[buffer(2)]], \
|
||||
device itype* out_vec [[buffer(3)]], \
|
||||
const constant int& in_vec_size [[buffer(4)]], \
|
||||
const constant int& out_vec_size [[buffer(5)]], \
|
||||
const constant int& marix_ld [[buffer(6)]], \
|
||||
const constant float& alpha [[buffer(7)]], \
|
||||
const constant float& beta [[buffer(8)]], \
|
||||
const constant int& batch_ndim [[buffer(9)]], \
|
||||
const constant int* batch_shape [[buffer(10)]], \
|
||||
const constant size_t* index_batch_strides [[buffer(11)]], \
|
||||
const constant int& vector_batch_ndim [[buffer(12)]], \
|
||||
const constant int* vector_batch_shape [[buffer(13)]], \
|
||||
const constant size_t* vector_batch_stride [[buffer(14)]], \
|
||||
const constant int& matrix_batch_ndim [[buffer(15)]], \
|
||||
const constant int* matrix_batch_shape [[buffer(16)]], \
|
||||
const constant size_t* matrix_batch_stride [[buffer(17)]], \
|
||||
const constant uint32_t* vec_indices [[buffer(18)]], \
|
||||
const constant uint32_t* mat_indices [[buffer(19)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]], \
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||
|
||||
// clang-format off
|
||||
#define instantiate_gemv_t_bs_blocks(name, itype) \
|
||||
instantiate_gemv_t_bs_helper(name, itype, 8, 8, 4, 1) \
|
||||
instantiate_gemv_t_bs_helper(name, itype, 8, 8, 4, 4) \
|
||||
instantiate_gemv_t_bs_helper(name, itype, 8, 16, 4, 4) \
|
||||
instantiate_gemv_t_bs_helper(name, itype, 8, 32, 4, 4) \
|
||||
instantiate_gemv_t_bs_helper(name, itype, 8, 64, 4, 4) \
|
||||
instantiate_gemv_t_bs_helper(name, itype, 8, 128, 4, 4) // clang-format on
|
||||
|
||||
// clang-format off
|
||||
instantiate_gemv_t_bs_blocks(float32, float);
|
||||
instantiate_gemv_t_bs_blocks(float16, half);
|
||||
instantiate_gemv_t_bs_blocks(bfloat16, bfloat16_t); // clang-format on
|
@ -0,0 +1,168 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
|
||||
using namespace metal;
|
||||
using namespace mlx::steel;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// GEMM kernels
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
typename T,
|
||||
int BM,
|
||||
int BN,
|
||||
int BK,
|
||||
int WM,
|
||||
int WN,
|
||||
bool transpose_a,
|
||||
bool transpose_b,
|
||||
bool MN_aligned,
|
||||
bool K_aligned>
|
||||
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void bs_gemm(
|
||||
const device T* A [[buffer(0)]],
|
||||
const device T* B [[buffer(1)]],
|
||||
device T* D [[buffer(3)]],
|
||||
const constant GEMMParams* params [[buffer(4)]],
|
||||
const constant int* batch_shape [[buffer(6)]],
|
||||
const constant size_t* batch_strides [[buffer(7)]],
|
||||
const constant uint32_t* lhs_indices [[buffer(10)]],
|
||||
const constant uint32_t* rhs_indices [[buffer(11)]],
|
||||
const constant int* batch_shape_A [[buffer(12)]],
|
||||
const constant size_t* batch_strides_A [[buffer(13)]],
|
||||
const constant int* batch_shape_B [[buffer(14)]],
|
||||
const constant size_t* batch_strides_B [[buffer(15)]],
|
||||
const constant int2& operand_batch_ndim [[buffer(16)]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||
using gemm_kernel = GEMMKernel<
|
||||
T,
|
||||
T,
|
||||
BM,
|
||||
BN,
|
||||
BK,
|
||||
WM,
|
||||
WN,
|
||||
transpose_a,
|
||||
transpose_b,
|
||||
MN_aligned,
|
||||
K_aligned>;
|
||||
|
||||
threadgroup T As[gemm_kernel::tgp_mem_size_a];
|
||||
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
|
||||
|
||||
uint32_t indx_A;
|
||||
uint32_t indx_B;
|
||||
|
||||
// Adjust for batch
|
||||
if (params->batch_ndim > 1) {
|
||||
const constant size_t* A_bstrides = batch_strides;
|
||||
const constant size_t* B_bstrides = batch_strides + params->batch_ndim;
|
||||
|
||||
ulong2 batch_offsets = elem_to_loc_broadcast(
|
||||
tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);
|
||||
|
||||
indx_A = lhs_indices[batch_offsets.x];
|
||||
indx_B = rhs_indices[batch_offsets.y];
|
||||
|
||||
} else {
|
||||
indx_A = lhs_indices[params->batch_stride_a * tid.z];
|
||||
indx_B = rhs_indices[params->batch_stride_b * tid.z];
|
||||
}
|
||||
|
||||
int batch_ndim_A = operand_batch_ndim.x;
|
||||
int batch_ndim_B = operand_batch_ndim.y;
|
||||
|
||||
if (batch_ndim_A > 1) {
|
||||
A += elem_to_loc(indx_A, batch_shape_A, batch_strides_A, batch_ndim_A);
|
||||
} else {
|
||||
A += indx_A * batch_strides_A[0];
|
||||
}
|
||||
|
||||
if (batch_ndim_B > 1) {
|
||||
B += elem_to_loc(indx_B, batch_shape_B, batch_strides_B, batch_ndim_B);
|
||||
} else {
|
||||
B += indx_B * batch_strides_B[0];
|
||||
}
|
||||
|
||||
D += params->batch_stride_d * tid.z;
|
||||
|
||||
gemm_kernel::run(
|
||||
A, B, D, params, As, Bs, simd_lane_id, simd_group_id, tid, lid);
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// GEMM kernel initializations
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#define instantiate_gemm( \
|
||||
tname, \
|
||||
trans_a, \
|
||||
trans_b, \
|
||||
iname, \
|
||||
itype, \
|
||||
oname, \
|
||||
otype, \
|
||||
bm, \
|
||||
bn, \
|
||||
bk, \
|
||||
wm, \
|
||||
wn, \
|
||||
aname, \
|
||||
mn_aligned, \
|
||||
kname, \
|
||||
k_aligned) \
|
||||
template [[host_name("steel_block_sparse_gemm_" #tname "_" #iname "_" #oname \
|
||||
"_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn \
|
||||
"_MN_" #aname "_K_" #kname)]] [[kernel]] void \
|
||||
bs_gemm<itype, bm, bn, bk, wm, wn, trans_a, trans_b, mn_aligned, k_aligned>( \
|
||||
const device itype* A [[buffer(0)]], \
|
||||
const device itype* B [[buffer(1)]], \
|
||||
device itype* D [[buffer(3)]], \
|
||||
const constant GEMMParams* params [[buffer(4)]], \
|
||||
const constant int* batch_shape [[buffer(6)]], \
|
||||
const constant size_t* batch_strides [[buffer(7)]], \
|
||||
const constant uint32_t* lhs_indices [[buffer(10)]], \
|
||||
const constant uint32_t* rhs_indices [[buffer(11)]], \
|
||||
const constant int* batch_shape_A [[buffer(12)]], \
|
||||
const constant size_t* batch_strides_A [[buffer(13)]], \
|
||||
const constant int* batch_shape_B [[buffer(14)]], \
|
||||
const constant size_t* batch_strides_B [[buffer(15)]], \
|
||||
const constant int2& operand_batch_ndim [[buffer(16)]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]]);
|
||||
|
||||
// clang-format off
|
||||
#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
|
||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \
|
||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \
|
||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false) // clang-format on
|
||||
|
||||
// clang-format off
|
||||
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
||||
instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) // clang-format on
|
||||
|
||||
// clang-format off
|
||||
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 16, 2, 2) \
|
||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 2, 2) // clang-format on
|
||||
|
||||
// clang-format off
|
||||
instantiate_gemm_shapes_helper(float16, half, float16, half);
|
||||
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);
|
||||
|
||||
instantiate_gemm_shapes_helper(float32, float, float32, float); // clang-format on
|
@ -21,9 +21,9 @@ struct GEMMParams {
|
||||
const int tiles_n;
|
||||
const int tiles_m;
|
||||
|
||||
const int batch_stride_a;
|
||||
const int batch_stride_b;
|
||||
const int batch_stride_d;
|
||||
const size_t batch_stride_a;
|
||||
const size_t batch_stride_b;
|
||||
const size_t batch_stride_d;
|
||||
|
||||
const int swizzle_log;
|
||||
const int gemm_k_iterations_aligned;
|
||||
@ -54,7 +54,7 @@ struct GEMMAddMMParams {
|
||||
const int ldc;
|
||||
const int fdc;
|
||||
|
||||
const int batch_stride_c;
|
||||
const size_t batch_stride_c;
|
||||
|
||||
const float alpha;
|
||||
const float beta;
|
||||
|
@ -327,9 +327,9 @@ void steel_matmul_conv_groups(
|
||||
/* const int ldd = */ ldd,
|
||||
/* const int tiles_n = */ tn,
|
||||
/* const int tiles_m = */ tm,
|
||||
/* const int batch_stride_a = */ K,
|
||||
/* const int batch_stride_b = */ N * K,
|
||||
/* const int batch_stride_d = */ N,
|
||||
/* const size_t batch_stride_a = */ size_t(K),
|
||||
/* const size_t batch_stride_b = */ size_t(N) * K,
|
||||
/* const size_t batch_stride_d = */ size_t(N),
|
||||
/* const int swizzle_log = */ swizzle_log,
|
||||
/* const int gemm_k_iterations_aligned = */ (K / bk),
|
||||
/* const int batch_ndim = */ 1};
|
||||
@ -405,7 +405,7 @@ void steel_matmul(
|
||||
}
|
||||
}
|
||||
|
||||
int matrix_stride_out = M * N;
|
||||
size_t matrix_stride_out = size_t(M) * N;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Split K specialization
|
||||
@ -550,9 +550,9 @@ void steel_matmul(
|
||||
/* const int ldd = */ N,
|
||||
/* const int tiles_n = */ tn,
|
||||
/* const int tiles_m = */ tm,
|
||||
/* const int batch_stride_a = */ int(A_batch_stride.back()),
|
||||
/* const int batch_stride_b = */ int(B_batch_stride.back()),
|
||||
/* const int batch_stride_d = */ matrix_stride_out,
|
||||
/* const size_t batch_stride_a = */ A_batch_stride.back(),
|
||||
/* const size_t batch_stride_b = */ B_batch_stride.back(),
|
||||
/* const size_t batch_stride_d = */ matrix_stride_out,
|
||||
/* const int swizzle_log = */ swizzle_log,
|
||||
/* const int gemm_k_iterations_aligned = */ (K / bk),
|
||||
/* const int batch_ndim = */ int(batch_shape.size())};
|
||||
@ -645,7 +645,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
auto [batch_shape, A_batch_stride, B_batch_stride] = collapse_batches(a, b);
|
||||
|
||||
auto batch_size_out = out.size() / (M * N);
|
||||
auto batch_size_out = out.size() / (size_t(M) * size_t(N));
|
||||
|
||||
// Collapse batches into M if needed
|
||||
if (batch_size_out > 1 && !a_transposed && batch_shape.size() == 1 &&
|
||||
@ -853,7 +853,8 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto [batch_shape, A_batch_stride, B_batch_stride, C_batch_stride] =
|
||||
collapse_batches(a, b, c);
|
||||
|
||||
auto batch_size_out = out.size() / (M * N);
|
||||
size_t matrix_stride_out = size_t(M) * size_t(N);
|
||||
auto batch_size_out = out.size() / (matrix_stride_out);
|
||||
|
||||
// Collapse batches into M if needed
|
||||
if (batch_size_out > 1 && !transpose_a && batch_shape.size() == 1 &&
|
||||
@ -869,8 +870,6 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
batch_shape = {1};
|
||||
}
|
||||
|
||||
int matrix_stride_out = M * N;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Gemv specialization
|
||||
|
||||
@ -1120,9 +1119,9 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
/* const int ldd = */ N,
|
||||
/* const int tiles_n = */ tn,
|
||||
/* const int tiles_m = */ tm,
|
||||
/* const int batch_stride_a = */ int(A_batch_stride.back()),
|
||||
/* const int batch_stride_b = */ int(B_batch_stride.back()),
|
||||
/* const int batch_stride_d = */ matrix_stride_out,
|
||||
/* const size_t batch_stride_a = */ A_batch_stride.back(),
|
||||
/* const size_t batch_stride_b = */ B_batch_stride.back(),
|
||||
/* const size_t batch_stride_d = */ matrix_stride_out,
|
||||
/* const int swizzle_log = */ swizzle_log,
|
||||
/* const int gemm_k_iterations_aligned = */ (K / bk),
|
||||
/* const int batch_ndim = */ int(batch_shape.size())};
|
||||
@ -1130,7 +1129,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
GEMMAddMMParams params{
|
||||
/* const int ldc = */ ldc,
|
||||
/* const int fdc = */ fdc,
|
||||
/* const int batch_stride_c = */ int(C_batch_stride.back()),
|
||||
/* const size_t batch_stride_c = */ C_batch_stride.back(),
|
||||
/* const float alpha = */ alpha_,
|
||||
/* const float beta = */ beta_};
|
||||
|
||||
@ -1230,8 +1229,8 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& out_mask = inputs[2];
|
||||
|
||||
std::vector<int> batch_shape{1};
|
||||
int A_batch_str = 0;
|
||||
int B_batch_str = 0;
|
||||
size_t A_batch_str = 0;
|
||||
size_t B_batch_str = 0;
|
||||
|
||||
std::vector<size_t> batch_strides;
|
||||
|
||||
@ -1249,8 +1248,8 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
auto [bshape_c, bstrides_c] = collapse_contiguous_dims(bshape, bstrides);
|
||||
batch_shape = bshape_c;
|
||||
A_batch_str = int(bstrides_c[0].back());
|
||||
B_batch_str = int(bstrides_c[1].back());
|
||||
A_batch_str = bstrides_c[0].back();
|
||||
B_batch_str = bstrides_c[1].back();
|
||||
|
||||
for (auto& bstr : bstrides_c) {
|
||||
batch_strides.insert(batch_strides.end(), bstr.begin(), bstr.end());
|
||||
@ -1259,8 +1258,8 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
batch_strides = std::vector<size_t>(inputs.size(), 0);
|
||||
}
|
||||
|
||||
auto batch_size_out = out.size() / (M * N);
|
||||
int matrix_stride_out = M * N;
|
||||
size_t matrix_stride_out = size_t(M) * N;
|
||||
size_t batch_size_out = out.size() / (matrix_stride_out);
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Regular kernel dispatch
|
||||
@ -1301,9 +1300,9 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
/* const int ldd = */ N,
|
||||
/* const int tiles_n = */ tn,
|
||||
/* const int tiles_m = */ tm,
|
||||
/* const int batch_stride_a = */ A_batch_str,
|
||||
/* const int batch_stride_b = */ B_batch_str,
|
||||
/* const int batch_stride_d = */ matrix_stride_out,
|
||||
/* const size_t batch_stride_a = */ A_batch_str,
|
||||
/* const size_t batch_stride_b = */ B_batch_str,
|
||||
/* const size_t batch_stride_d = */ matrix_stride_out,
|
||||
/* const int swizzle_log = */ swizzle_log,
|
||||
/* const int gemm_k_iterations_aligned = */ (K / bk),
|
||||
/* const int batch_ndim = */ int(batch_shape.size())};
|
||||
@ -1355,4 +1354,314 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
return;
|
||||
}
|
||||
|
||||
void BlockSparseMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
using namespace mlx::steel;
|
||||
// assert(inputs.size() == 2);
|
||||
if (!issubdtype(out.dtype(), floating)) {
|
||||
throw std::runtime_error(
|
||||
"[matmul] Does not yet support non-floating point types.");
|
||||
}
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
auto& a_pre = inputs[0];
|
||||
auto& b_pre = inputs[1];
|
||||
// Return 0s if either input is empty
|
||||
if (a_pre.size() == 0 || b_pre.size() == 0) {
|
||||
array zero = array(0, a_pre.dtype());
|
||||
copy_gpu(zero, out, CopyType::Scalar, s);
|
||||
auto command_buffer = d.get_command_buffer(s.index);
|
||||
command_buffer->addCompletedHandler([zero](MTL::CommandBuffer*) {});
|
||||
return;
|
||||
}
|
||||
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Init checks and prep
|
||||
|
||||
// Keep a vector with copies to be cleared in the completed buffer to release
|
||||
// the arrays
|
||||
std::vector<array> copies;
|
||||
auto check_transpose = [&copies, &s](const array& arr) {
|
||||
auto stx = arr.strides()[arr.ndim() - 2];
|
||||
auto sty = arr.strides()[arr.ndim() - 1];
|
||||
if (sty == 1) {
|
||||
return std::make_tuple(false, stx, arr);
|
||||
} else if (stx == 1) {
|
||||
return std::make_tuple(true, sty, arr);
|
||||
} else {
|
||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||
copy_gpu(arr, arr_copy, CopyType::General, s);
|
||||
copies.push_back(arr_copy);
|
||||
size_t stx = arr.shape(-1);
|
||||
return std::make_tuple(false, stx, arr_copy);
|
||||
}
|
||||
};
|
||||
|
||||
auto [transpose_a, a_cols, a] = check_transpose(a_pre);
|
||||
auto [transpose_b, b_cols, b] = check_transpose(b_pre);
|
||||
|
||||
int lda = a_cols;
|
||||
int ldb = b_cols;
|
||||
|
||||
int M = a.shape(-2);
|
||||
int N = b.shape(-1);
|
||||
int K = a.shape(-1);
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Check and collapse batch dimensions
|
||||
|
||||
auto get_batch_dims = [](const auto& v) {
|
||||
return decltype(v){v.begin(), v.end() - 2};
|
||||
};
|
||||
|
||||
auto& lhs_indices = inputs[2];
|
||||
auto& rhs_indices = inputs[3];
|
||||
|
||||
std::vector<int> batch_shape = get_batch_dims(out.shape());
|
||||
std::vector<size_t> batch_strides;
|
||||
|
||||
batch_strides.insert(
|
||||
batch_strides.end(),
|
||||
lhs_indices.strides().begin(),
|
||||
lhs_indices.strides().end());
|
||||
size_t lhs_indices_str = batch_strides.empty() ? 0 : batch_strides.back();
|
||||
|
||||
batch_strides.insert(
|
||||
batch_strides.end(),
|
||||
rhs_indices.strides().begin(),
|
||||
rhs_indices.strides().end());
|
||||
size_t rhs_indices_str = batch_strides.empty() ? 0 : batch_strides.back();
|
||||
|
||||
int batch_ndim = batch_shape.size();
|
||||
|
||||
if (batch_ndim == 0) {
|
||||
batch_shape = {1};
|
||||
batch_strides = {0};
|
||||
}
|
||||
|
||||
int batch_ndim_A = a.ndim() - 2;
|
||||
int batch_ndim_B = b.ndim() - 2;
|
||||
std::vector<int> operand_batch_ndim = {batch_ndim_A, batch_ndim_B};
|
||||
|
||||
std::vector<int> batch_shape_A = get_batch_dims(a.shape());
|
||||
std::vector<size_t> batch_strides_A = get_batch_dims(a.strides());
|
||||
std::vector<int> batch_shape_B = get_batch_dims(b.shape());
|
||||
std::vector<size_t> batch_strides_B = get_batch_dims(b.strides());
|
||||
|
||||
if (batch_ndim_A == 0) {
|
||||
batch_shape_A = {1};
|
||||
batch_strides_A = {0};
|
||||
}
|
||||
|
||||
if (batch_ndim_B == 0) {
|
||||
batch_shape_B = {1};
|
||||
batch_strides_B = {0};
|
||||
}
|
||||
|
||||
size_t matrix_stride_out = size_t(M) * N;
|
||||
auto batch_size_out = out.size() / matrix_stride_out;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Gemv specialization
|
||||
|
||||
// Route to gemv if needed
|
||||
if (std::min(M, N) == 1) {
|
||||
// Collect problem info
|
||||
bool is_b_matrix = N != 1;
|
||||
|
||||
auto& mat = is_b_matrix ? b : a;
|
||||
auto& vec = is_b_matrix ? a : b;
|
||||
bool transpose_mat = is_b_matrix ? !transpose_b : transpose_a;
|
||||
int in_vector_len = K;
|
||||
int out_vector_len = is_b_matrix ? N : M;
|
||||
|
||||
int mat_cols = transpose_mat ? out_vector_len : in_vector_len;
|
||||
int mat_rows = transpose_mat ? in_vector_len : out_vector_len;
|
||||
int mat_ld = is_b_matrix ? b_cols : a_cols;
|
||||
|
||||
auto batch_strides_mat = is_b_matrix ? batch_strides_B : batch_strides_A;
|
||||
auto batch_strides_vec = is_b_matrix ? batch_strides_A : batch_strides_B;
|
||||
|
||||
auto batch_shape_mat = is_b_matrix ? batch_shape_B : batch_shape_A;
|
||||
auto batch_shape_vec = is_b_matrix ? batch_shape_A : batch_shape_B;
|
||||
|
||||
if (!is_b_matrix) {
|
||||
batch_strides = rhs_indices.strides();
|
||||
batch_strides.insert(
|
||||
batch_strides.end(),
|
||||
lhs_indices.strides().begin(),
|
||||
lhs_indices.strides().end());
|
||||
}
|
||||
|
||||
int batch_ndim = batch_shape.size();
|
||||
|
||||
// Determine dispatch kernel
|
||||
int tm = 4, tn = 4;
|
||||
int bm, bn, n_out_per_tgp;
|
||||
std::ostringstream kname;
|
||||
|
||||
if (transpose_mat) {
|
||||
bm = 8;
|
||||
bn = 8;
|
||||
if (out_vector_len >= 24576) {
|
||||
bn = 128;
|
||||
} else if (out_vector_len >= 16384) {
|
||||
bn = 64;
|
||||
} else if (out_vector_len >= 8192) {
|
||||
bn = 16;
|
||||
}
|
||||
|
||||
// Specialized kernel for very small outputs
|
||||
tn = out_vector_len < tn ? 1 : tn;
|
||||
|
||||
n_out_per_tgp = bn * tn;
|
||||
kname << "gemv_t_bs_" << type_to_name(out);
|
||||
|
||||
} else {
|
||||
bm = out_vector_len >= 4096 ? 8 : 4;
|
||||
bn = 32;
|
||||
|
||||
// Specialized kernel for very small outputs
|
||||
tm = out_vector_len < tm ? 1 : tm;
|
||||
|
||||
n_out_per_tgp = bm * tm;
|
||||
kname << "gemv_bs_" << type_to_name(out);
|
||||
}
|
||||
|
||||
kname << "_bm" << bm << "_bn" << bn << "_tm" << tm << "_tn" << tn;
|
||||
|
||||
// Encode and dispatch kernel
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp;
|
||||
MTL::Size group_dims = MTL::Size(bn, bm, 1);
|
||||
MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out);
|
||||
|
||||
compute_encoder.set_input_array(mat, 0);
|
||||
compute_encoder.set_input_array(vec, 1);
|
||||
compute_encoder.set_output_array(out, 3);
|
||||
|
||||
compute_encoder->setBytes(&in_vector_len, sizeof(int), 4);
|
||||
compute_encoder->setBytes(&out_vector_len, sizeof(int), 5);
|
||||
compute_encoder->setBytes(&mat_ld, sizeof(int), 6);
|
||||
|
||||
compute_encoder->setBytes(&batch_ndim, sizeof(int), 9);
|
||||
set_vector_bytes(compute_encoder, batch_shape, 10);
|
||||
set_vector_bytes(compute_encoder, batch_strides, 11);
|
||||
|
||||
int batch_ndim_vec = batch_shape_vec.size();
|
||||
compute_encoder->setBytes(&batch_ndim_vec, sizeof(int), 12);
|
||||
set_vector_bytes(compute_encoder, batch_shape_vec, 13);
|
||||
set_vector_bytes(compute_encoder, batch_strides_vec, 14);
|
||||
|
||||
int batch_ndim_mat = batch_shape_mat.size();
|
||||
compute_encoder->setBytes(&batch_ndim_mat, sizeof(int), 15);
|
||||
set_vector_bytes(compute_encoder, batch_shape_mat, 16);
|
||||
set_vector_bytes(compute_encoder, batch_strides_mat, 17);
|
||||
|
||||
compute_encoder.set_input_array(lhs_indices, 18 + int(!is_b_matrix));
|
||||
compute_encoder.set_input_array(rhs_indices, 18 + int(is_b_matrix));
|
||||
|
||||
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
|
||||
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
|
||||
return;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Regular kernel dispatch
|
||||
|
||||
// Determine dispatch kernel
|
||||
int bm = 32, bn = 32, bk = 16;
|
||||
int wm = 2, wn = 2;
|
||||
|
||||
if ((size_t)batch_size_out * M * N >= 1ul << 20) {
|
||||
if (!transpose_a && transpose_b) {
|
||||
bm = 64;
|
||||
bn = (out.dtype() == float32) ? 64 : 32;
|
||||
bk = (out.dtype() == float32) ? 16 : 32;
|
||||
} else {
|
||||
bm = 64;
|
||||
bn = 64;
|
||||
}
|
||||
}
|
||||
|
||||
// Prepare kernel name
|
||||
std::ostringstream kname;
|
||||
kname << "steel_block_sparse_gemm_" << (transpose_a ? 't' : 'n')
|
||||
<< (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
|
||||
<< type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk
|
||||
<< "_wm" << wm << "_wn" << wn << "_MN_"
|
||||
<< ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned" << "_K_"
|
||||
<< ((K % bk == 0) ? "t" : "n") << "aligned";
|
||||
|
||||
// Encode and dispatch kernel
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
// Use problem size to determine threadblock swizzle
|
||||
int tn = (N + bn - 1) / bn;
|
||||
int tm = (M + bm - 1) / bm;
|
||||
|
||||
// TODO: Explore device-based tuning for swizzle
|
||||
int swizzle_log = 0; // tm >= 6 ? 3 : (tm <= 3 ? 0 : 2);
|
||||
|
||||
// Prepare steel matmul params
|
||||
GEMMParams params{
|
||||
/* const int M = */ M,
|
||||
/* const int N = */ N,
|
||||
/* const int K = */ K,
|
||||
/* const int lda = */ lda,
|
||||
/* const int ldb = */ ldb,
|
||||
/* const int ldd = */ N,
|
||||
/* const int tiles_n = */ tn,
|
||||
/* const int tiles_m = */ tm,
|
||||
/* const size_t batch_stride_a = */ lhs_indices_str,
|
||||
/* const size_t batch_stride_b = */ rhs_indices_str,
|
||||
/* const size_t batch_stride_d = */ matrix_stride_out,
|
||||
/* const int swizzle_log = */ swizzle_log,
|
||||
/* const int gemm_k_iterations_aligned = */ (K / bk),
|
||||
/* const int batch_ndim = */ batch_ndim};
|
||||
|
||||
// Prepare launch grid params
|
||||
int tile = 1 << swizzle_log;
|
||||
tm = (tm + tile - 1) / tile;
|
||||
tn = tn * tile;
|
||||
|
||||
MTL::Size group_dims = MTL::Size(32, wn, wm);
|
||||
MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out);
|
||||
|
||||
// Launch kernel
|
||||
compute_encoder.set_input_array(a, 0);
|
||||
compute_encoder.set_input_array(b, 1);
|
||||
compute_encoder.set_output_array(out, 3);
|
||||
|
||||
compute_encoder->setBytes(¶ms, sizeof(GEMMParams), 4);
|
||||
|
||||
set_vector_bytes(compute_encoder, batch_shape, 6);
|
||||
set_vector_bytes(compute_encoder, batch_strides, 7);
|
||||
|
||||
compute_encoder.set_input_array(lhs_indices, 10);
|
||||
compute_encoder.set_input_array(rhs_indices, 11);
|
||||
|
||||
set_vector_bytes(compute_encoder, batch_shape_A, 12);
|
||||
set_vector_bytes(compute_encoder, batch_strides_A, 13);
|
||||
set_vector_bytes(compute_encoder, batch_shape_B, 14);
|
||||
set_vector_bytes(compute_encoder, batch_strides_B, 15);
|
||||
set_vector_bytes(compute_encoder, operand_batch_ndim, 16);
|
||||
|
||||
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
|
||||
|
||||
// Clear copies
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
|
||||
return;
|
||||
}
|
||||
|
||||
} // namespace mlx::core
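The steel kernel name assembled in eval_gpu above encodes the transposition of each operand, the input/output types, the tile sizes, and whether M/N and K are multiples of the tile. A small Python sketch of the naming scheme (illustrative; it mirrors the ostringstream logic and the instantiate_gemm host_name macro):

def steel_bs_kernel_name(trans_a, trans_b, itype, otype, bm, bn, bk, wm, wn, M, N, K):
    mn = "t" if (M % bm == 0 and N % bn == 0) else "n"
    k = "t" if K % bk == 0 else "n"
    return (f"steel_block_sparse_gemm_{'t' if trans_a else 'n'}"
            f"{'t' if trans_b else 'n'}_{itype}_{otype}"
            f"_bm{bm}_bn{bn}_bk{bk}_wm{wm}_wn{wn}"
            f"_MN_{mn}aligned_K_{k}aligned")

# e.g. steel_bs_kernel_name(False, True, "float16", "float16", 64, 32, 32, 2, 2, 128, 128, 64)
# -> "steel_block_sparse_gemm_nt_float16_float16_bm64_bn32_bk32_wm2_wn2_MN_taligned_K_taligned"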
|
||||
|
@ -32,6 +32,8 @@ NO_GPU(ArgSort)
|
||||
NO_GPU(AsType)
|
||||
NO_GPU(AsStrided)
|
||||
NO_GPU(BitwiseBinary)
|
||||
NO_GPU(BlockMaskedMM)
|
||||
NO_GPU(BlockSparseMM)
|
||||
NO_GPU(Broadcast)
|
||||
NO_GPU(Ceil)
|
||||
NO_GPU_MULTI(Compiled)
|
||||
@ -100,7 +102,6 @@ NO_GPU(Subtract)
|
||||
NO_GPU_MULTI(SVD)
|
||||
NO_GPU(Tan)
|
||||
NO_GPU(Tanh)
|
||||
NO_GPU(BlockMaskedMM)
|
||||
NO_GPU(Transpose)
|
||||
NO_GPU(Inverse)
|
||||
|
||||
|
118
mlx/ops.cpp
@ -3785,6 +3785,124 @@ array block_masked_mm(
|
||||
return out;
|
||||
}
|
||||
|
||||
/** Compute matrix product with matrix-level gather */
|
||||
array block_sparse_mm(
|
||||
array a,
|
||||
array b,
|
||||
std::optional<array> lhs_indices_ /* = std::nullopt */,
|
||||
std::optional<array> rhs_indices_ /* = std::nullopt */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
// If no indices, fall back to full matmul
|
||||
if (!lhs_indices_ && !rhs_indices_) {
|
||||
return matmul(a, b, s);
|
||||
}
|
||||
|
||||
// Do shape checks for operands
|
||||
int in_a_ndim = a.ndim();
|
||||
int in_b_ndim = b.ndim();
|
||||
|
||||
if (a.ndim() == 0 || b.ndim() == 0) {
|
||||
throw std::invalid_argument(
|
||||
"[block_sparse_mm] Got 0 dimension input. Inputs must "
|
||||
"have at least one dimension.");
|
||||
}
|
||||
|
||||
if (a.ndim() == 1) {
|
||||
// Insert a singleton dim at the beginning
|
||||
a = reshape(a, {1, -1}, s);
|
||||
}
|
||||
if (b.ndim() == 1) {
|
||||
// Insert a singleton dim at the end
|
||||
b = reshape(b, {-1, 1}, s);
|
||||
}
|
||||
|
||||
if (a.shape(-1) != b.shape(-2)) {
|
||||
std::ostringstream msg;
|
||||
msg << "[block_sparse_mm] Last dimension of first input with shape "
|
||||
<< a.shape() << " must match second to last dimension of"
|
||||
<< " second input with shape " << b.shape() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
// Type promotion
|
||||
auto out_type = result_type(a, b);
|
||||
if (!issubdtype(out_type, floating)) {
|
||||
std::ostringstream msg;
|
||||
msg << "[block_sparse_mm] Only real floating point types are supported but "
|
||||
<< a.dtype() << " and " << b.dtype()
|
||||
<< " were provided which results in " << out_type
|
||||
<< ", which is not a real floating point type.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
a = astype(a, out_type, s);
|
||||
b = astype(b, out_type, s);
|
||||
|
||||
// Handle broadcasting
|
||||
std::vector<int> bsx_a(a.shape().begin(), a.shape().end() - 2);
|
||||
std::vector<int> bsx_b(b.shape().begin(), b.shape().end() - 2);
|
||||
|
||||
auto indices_or_default = [&](const std::optional<array>& indices,
|
||||
const std::vector<int>& bsx_shape) {
|
||||
if (indices.has_value()) {
|
||||
return indices.value();
|
||||
} else {
|
||||
int n_batch = 1;
|
||||
for (auto& i : bsx_shape)
|
||||
n_batch *= i;
|
||||
return reshape(arange(n_batch, uint32, s), bsx_shape, s);
|
||||
}
|
||||
};
|
||||
|
||||
// Pull and broadcast indices
|
||||
array lhs_indices = indices_or_default(lhs_indices_, bsx_a);
|
||||
array rhs_indices = indices_or_default(rhs_indices_, bsx_b);
|
||||
|
||||
if (!issubdtype(lhs_indices.dtype(), integer)) {
|
||||
throw std::invalid_argument(
|
||||
"[block_sparse_mm] Got lhs_indices with invalid dtype. Indices must be integral.");
|
||||
}
|
||||
|
||||
if (!issubdtype(rhs_indices.dtype(), integer)) {
|
||||
throw std::invalid_argument(
|
||||
"[block_sparse_mm] Got rhs_indices with invalid dtype. Indices must be integral.");
|
||||
}
|
||||
|
||||
lhs_indices = astype(lhs_indices, uint32, s);
|
||||
rhs_indices = astype(rhs_indices, uint32, s);
|
||||
|
||||
int M = a.shape(-2);
|
||||
int N = b.shape(-1);
|
||||
int K = a.shape(-1);
|
||||
|
||||
auto out_bsx_shape =
|
||||
broadcast_shapes(lhs_indices.shape(), rhs_indices.shape());
|
||||
|
||||
lhs_indices = broadcast_to(lhs_indices, out_bsx_shape, s);
|
||||
rhs_indices = broadcast_to(rhs_indices, out_bsx_shape, s);
|
||||
|
||||
auto out_shape = out_bsx_shape;
|
||||
out_shape.push_back(M);
|
||||
out_shape.push_back(N);
|
||||
|
||||
// Calculate output array
|
||||
auto out = array(
|
||||
out_shape,
|
||||
out_type,
|
||||
std::make_shared<BlockSparseMM>(to_stream(s)),
|
||||
{a, b, lhs_indices, rhs_indices});
|
||||
|
||||
// Remove the possibly inserted singleton dimensions
|
||||
if (in_a_ndim == 1 || in_b_ndim == 1) {
|
||||
out_shape.erase(
|
||||
out_shape.end() - ((in_a_ndim == 1) ? 2 : 1),
|
||||
out_shape.end() - ((in_b_ndim == 1) ? 0 : 1));
|
||||
out = reshape(out, out_shape, s);
|
||||
}
|
||||
|
||||
return out;
|
||||
}
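When an index array is omitted, the indices_or_default helper above falls back to an arange over the operand's batch elements, so every batch matrix participates once and in order. A NumPy sketch of that default (illustrative only):

import numpy as np

def indices_or_default(indices, operand_batch_shape):
    if indices is not None:
        return indices
    n_batch = 1
    for d in operand_batch_shape:
        n_batch *= d
    return np.arange(n_batch, dtype=np.uint32).reshape(operand_batch_shape)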
|
||||
|
||||
array diagonal(
|
||||
const array& a,
|
||||
int offset /* = 0 */,
|
||||
|
@ -1183,6 +1183,14 @@ array block_masked_mm(
|
||||
std::optional<array> mask_rhs = std::nullopt,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Compute matrix product with matrix-level gather */
|
||||
array block_sparse_mm(
|
||||
array a,
|
||||
array b,
|
||||
std::optional<array> lhs_indices = std::nullopt,
|
||||
std::optional<array> rhs_indices = std::nullopt,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Extract a diagonal or construct a diagonal array */
|
||||
array diagonal(
|
||||
const array& a,
|
||||
|
@ -3357,7 +3357,61 @@ std::vector<array> BlockMaskedMM::vjp(
|
||||
|
||||
vjps.push_back(grad);
|
||||
} else {
|
||||
vjps.push_back(zeros_like(primals[arg], stream()));
|
||||
throw std::invalid_argument(
|
||||
"[BlockMaskedMM] Cannot calculate VJP with respect to masks.");
|
||||
}
|
||||
}
|
||||
return vjps;
|
||||
}
|
||||
|
||||
std::vector<array> BlockSparseMM::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
std::vector<array> vjps;
|
||||
auto& cotan = cotangents[0];
|
||||
|
||||
auto& lhs_indices = primals[2];
|
||||
auto& rhs_indices = primals[3];
|
||||
|
||||
int M = cotan.shape(-2);
|
||||
int N = cotan.shape(-1);
|
||||
int K = primals[0].shape(-1);
|
||||
|
||||
for (auto arg : argnums) {
|
||||
if (arg == 0) {
|
||||
// M X N * (K X N).T -> M X K
|
||||
auto base = zeros_like(primals[0], stream());
|
||||
auto bt = swapaxes(primals[1], -1, -2, stream());
|
||||
|
||||
auto base_shape = base.shape();
|
||||
base = reshape(base, {-1, M, K}, stream());
|
||||
|
||||
// g : (out_batch_shape) + (M, K)
|
||||
auto g = block_sparse_mm(cotan, bt, std::nullopt, rhs_indices, stream());
|
||||
g = expand_dims(g, -3, stream());
|
||||
auto gacc = scatter_add(base, lhs_indices, g, 0, stream());
|
||||
|
||||
vjps.push_back(reshape(gacc, base_shape, stream()));
|
||||
|
||||
} else if (arg == 1) {
|
||||
// (M X K).T * M X N -> K X N
|
||||
auto base = zeros_like(primals[1], stream());
|
||||
auto at = swapaxes(primals[0], -1, -2, stream());
|
||||
|
||||
auto base_shape = base.shape();
|
||||
base = reshape(base, {-1, K, N}, stream());
|
||||
|
||||
// g : (out_batch_shape) + (K, N)
|
||||
auto g = block_sparse_mm(at, cotan, lhs_indices, std::nullopt, stream());
|
||||
g = expand_dims(g, -3, stream());
|
||||
auto gacc = scatter_add(base, rhs_indices, g, 0, stream());
|
||||
|
||||
vjps.push_back(reshape(gacc, base_shape, stream()));
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[BlockSparseMM] Cannot calculate VJP with respect to indices.");
|
||||
}
|
||||
}
|
||||
return vjps;
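The arg == 0 branch above multiplies the cotangent with the gathered, transposed b and scatter-adds the result into the slot of a selected by lhs_indices (and symmetrically for arg == 1). A NumPy sketch of what that accumulation computes for the first operand (illustrative, not the MLX implementation):

import numpy as np

def vjp_wrt_a(cotan, a, b, lhs_indices, rhs_indices):
    # cotan: (batch..., M, N); a: (A_batch..., M, K); b: (B_batch..., K, N)
    M, N = cotan.shape[-2], cotan.shape[-1]
    K = a.shape[-1]
    grad = np.zeros_like(a).reshape(-1, M, K)   # plays the role of `base`
    b_flat = b.reshape(-1, K, N)
    cotan_flat = cotan.reshape(-1, M, N)
    lhs = lhs_indices.reshape(-1)
    rhs = rhs_indices.reshape(-1)
    for i in range(lhs.size):
        # g[i] = cotan[i] @ b[rhs_indices[i]].T, scatter-added at lhs_indices[i]
        grad[lhs[i]] += cotan_flat[i] @ b_flat[rhs[i]].T
    return grad.reshape(a.shape)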
|
||||
|
@ -485,6 +485,26 @@ class BlockMaskedMM : public UnaryPrimitive {
  void eval(const std::vector<array>& inputs, array& out);
};

class BlockSparseMM : public UnaryPrimitive {
 public:
  explicit BlockSparseMM(Stream stream) : UnaryPrimitive(stream) {};

  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  std::vector<array> vjp(
      const std::vector<array>& primals,
      const std::vector<array>& cotangents,
      const std::vector<int>& argnums,
      const std::vector<array>& outputs) override;

  DEFINE_PRINT(BlockSparseMM)
  DEFINE_DEFAULT_IS_EQUIVALENT()

 private:
  void eval(const std::vector<array>& inputs, array& out);
};

class Broadcast : public UnaryPrimitive {
 public:
  explicit Broadcast(Stream stream, const std::vector<int>& shape)
@ -3716,6 +3716,38 @@ void init_ops(nb::module_& m) {
        mask_lhs (array, optional): Boolean mask for a (default: ``None``)
        mask_rhs (array, optional): Boolean mask for b (default: ``None``)

  )pbdoc");
  m.def(
      "block_sparse_mm",
      &block_sparse_mm,
      nb::arg(),
      nb::arg(),
      "lhs_indices"_a = nb::none(),
      "rhs_indices"_a = nb::none(),
      nb::kw_only(),
      "stream"_a = nb::none(),
      nb::sig(
          "def block_sparse_mm(a: array, b: array, /, lhs_indices: array, rhs_indices: array, *, stream: Union[None, Stream, Device] = None) -> array"),
      R"pbdoc(
        Matrix multiplication with matrix-level gather.

        Performs a gather of the operands with the given indices followed by a
        (possibly batched) matrix multiplication of the two arrays. This
        operation is more efficient than explicitly applying a :func:`take`
        followed by a :func:`matmul`.

        The indices ``lhs_indices`` and ``rhs_indices`` contain flat indices
        along the batch dimensions (i.e. all but the last two dimensions) of
        ``a`` and ``b`` respectively.

        For ``a`` with shape ``(A1, A2, ..., AS, M, K)``, ``lhs_indices``
        contains indices from the range ``[0, A1 * A2 * ... * AS)``.

        For ``b`` with shape ``(B1, B2, ..., BS, K, N)``, ``rhs_indices``
        contains indices from the range ``[0, B1 * B2 * ... * BS)``.

        Args:
            a (array): Input array.
            b (array): Input array.
            lhs_indices (array, optional): Integer indices for ``a`` (default: ``None``)
            rhs_indices (array, optional): Integer indices for ``b`` (default: ``None``)

  )pbdoc");
  m.def(
      "diagonal",
@ -810,6 +810,199 @@ class TestBlas(mlx_tests.MLXTestCase):
        self.assertTrue(np.allclose(c_mx, c_np, atol=1e-5))

    def test_block_sparse_matmul(self):
        def np_block_sparse_mm(a, b, lhs_indices=None, rhs_indices=None):
            a = a.reshape((-1, a.shape[-2], a.shape[-1]))
            b = b.reshape((-1, b.shape[-2], b.shape[-1]))
            lhs_indices = lhs_indices or np.arange(a.shape[0])
            rhs_indices = rhs_indices or np.arange(b.shape[0])
            a = a[lhs_indices, :, :]
            b = b[rhs_indices, :, :]
            out = a @ b
            return out

        def test_shape(
            M,
            N,
            K,
            np_dtype=np.float32,
            batch_A=(),
            batch_B=(),
            lhs_indices=None,
            rhs_indices=None,
        ):
            with self.subTest(
                M=M,
                N=N,
                K=K,
                np_dtype=np_dtype,
                batch_A=batch_A,
                batch_B=batch_B,
                lhs_indices=lhs_indices,
                rhs_indices=rhs_indices,
            ):

                a_np = np.random.normal(size=batch_A + (M, K)).astype(np_dtype)
                b_np = np.random.normal(size=batch_B + (K, N)).astype(np_dtype)

                a_mx = mx.array(a_np)
                b_mx = mx.array(b_np)

                out_np = np_block_sparse_mm(a_np, b_np, lhs_indices, rhs_indices)

                lhs_indices_mx = None if lhs_indices is None else mx.array(lhs_indices)
                rhs_indices_mx = None if rhs_indices is None else mx.array(rhs_indices)

                out_mx = mx.block_sparse_mm(a_mx, b_mx, lhs_indices_mx, rhs_indices_mx)

                self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5))

        inputs = (
            {
                "batch_A": (1,),
                "lhs_indices": (0,),
                "batch_B": (3,),
                "rhs_indices": (2, 1),
            },
            {
                "batch_A": (1,),
                "lhs_indices": None,
                "batch_B": (3,),
                "rhs_indices": (2, 1),
            },
            {
                "batch_A": (2,),
                "lhs_indices": None,
                "batch_B": (3,),
                "rhs_indices": (2, 1),
            },
            {
                "batch_A": (3,),
                "lhs_indices": (0, 2),
                "batch_B": (1,),
                "rhs_indices": (0,),
            },
            {
                "batch_A": (5,),
                "lhs_indices": (0, 2),
                "batch_B": (3,),
                "rhs_indices": (2, 1),
            },
            {
                "batch_A": (4, 2),
                "lhs_indices": (
                    (7, 6),
                    (5, 4),
                    (1, 2),
                ),
                "batch_B": (4, 1),
                "rhs_indices": ((2,), (0,), (1,)),
            },
        )

        for kwargs in inputs:
            test_shape(32, 32, 32, **kwargs)
            test_shape(16, 1, 16, **kwargs)

        # Add tests for broadcasting
        a_np = np.random.normal(size=(5, 32, 32)).astype(np.float32)
        b_np = np.random.normal(size=(3, 32, 32)).astype(np.float32)
        a_mx = mx.array(a_np)
        b_mx = mx.array(b_np)

        # Numpy
        a_np = a_np.reshape((5, 1, 32, 32))
        b_np = b_np.reshape((1, 3, 32, 32))

        a_np = np.broadcast_to(a_np, (5, 4, 32, 32))
        b_np = np.broadcast_to(b_np, (2, 3, 32, 32)).swapaxes(1, 0)
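
        # Flat indices into the broadcast batches: a has batch shape (5, 4),
        # so 13 -> (3, 1) and 12 -> (3, 0); b (after the swapaxes) has batch
        # shape (3, 2), so 3 -> (1, 1) and 5 -> (2, 1).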
        lhs_indices = [0, 13, 12]
        rhs_indices = [0, 3, 5]

        out_np = np_block_sparse_mm(a_np, b_np, lhs_indices, rhs_indices)

        # MLX
        a_mx = a_mx.reshape((5, 1, 32, 32))
        b_mx = b_mx.reshape((1, 3, 32, 32))

        a_mx = mx.broadcast_to(a_mx, (5, 4, 32, 32))
        b_mx = mx.broadcast_to(b_mx, (2, 3, 32, 32)).swapaxes(1, 0)

        lhs_indices_mx = mx.array(lhs_indices)
        rhs_indices_mx = mx.array(rhs_indices)

        out_mx = mx.block_sparse_mm(a_mx, b_mx, lhs_indices_mx, rhs_indices_mx)

        self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5))

        # Gemv test
        a_np = np.random.normal(size=(5, 1, 32)).astype(np.float32)
        b_np = np.random.normal(size=(3, 16, 32)).astype(np.float32)
        a_mx = mx.array(a_np)
        b_mx = mx.array(b_np)

        lhs_indices = [3, 1]
        rhs_indices = [0, 2]

        b_np_t = np.swapaxes(b_np, -1, -2)
        out_np = np_block_sparse_mm(a_np, b_np_t, lhs_indices, rhs_indices)

        lhs_indices_mx = mx.array(lhs_indices)
        rhs_indices_mx = mx.array(rhs_indices)

        b_mx_t = mx.swapaxes(b_mx, -1, -2)
        out_mx = mx.block_sparse_mm(a_mx, b_mx_t, lhs_indices_mx, rhs_indices_mx)

        self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5))

    def test_block_sparse_matmul_grad(self):

        lhs_indices = mx.array([[7, 6], [4, 1], [0, 2]], dtype=mx.uint32)
        rhs_indices = mx.array([[2], [0], [1]], dtype=mx.uint32)

        def f_ref(a, b):
            lhs_indices_ = mx.broadcast_to(lhs_indices, (3, 2))
            rhs_indices_ = mx.broadcast_to(rhs_indices, (3, 2))
            M = a.shape[-2]
            N = b.shape[-1]
            K = a.shape[-1]

            a = a.reshape((-1, M, K))
            b = b.reshape((-1, K, N))

            a = mx.take(a, lhs_indices_, 0)
            b = mx.take(b, rhs_indices_, 0)

            return a @ b

        def f_test(a, b):
            return mx.block_sparse_mm(a, b, lhs_indices, rhs_indices)

        a_mx = mx.random.normal((4, 2, 32, 32))
        b_mx = mx.random.normal((4, 1, 32, 32))

        out_test = f_test(a_mx, b_mx)
        out_ref = f_ref(a_mx, b_mx)

        self.assertTrue(mx.allclose(out_test, out_ref, atol=1e-5))

        cotan = mx.ones_like(out_test)
        out_ref, dout_ref = mx.vjp(
            f_ref,
            [a_mx, b_mx],
            [cotan],
        )
        out_test, dout_test = mx.vjp(
            f_test,
            [a_mx, b_mx],
            [cotan],
        )

        for r, t in zip(dout_ref, dout_test):
            self.assertEqual(r.shape, t.shape)
            self.assertTrue(mx.allclose(r, t, atol=1e-4).item())


if __name__ == "__main__":
    unittest.main()