Block sparse mm (#1058)
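The commit adds a batched matmul whose operands are gathered through index arrays. As a rough orientation before the diff, here is a plain C++ sketch of the intended semantics for the simplest case (contiguous operands, no extra broadcast batching). The function name and layout assumptions are illustrative only, not part of the change itself:

    // Reference semantics (illustrative): out[i] = A[lhs_indices[i]] @ B[rhs_indices[i]]
    #include <cstddef>
    #include <cstdint>

    void block_sparse_mm_ref(
        const float* A,          // [num_A, M, K], row-major
        const float* B,          // [num_B, K, N], row-major
        const uint32_t* lhs_idx, // [batch]
        const uint32_t* rhs_idx, // [batch]
        float* out,              // [batch, M, N]
        int batch, int M, int N, int K) {
      for (int b = 0; b < batch; ++b) {
        const float* a = A + size_t(lhs_idx[b]) * M * K;
        const float* bmat = B + size_t(rhs_idx[b]) * K * N;
        float* o = out + size_t(b) * M * N;
        for (int m = 0; m < M; ++m) {
          for (int n = 0; n < N; ++n) {
            float acc = 0.0f;
            for (int k = 0; k < K; ++k) {
              acc += a[size_t(m) * K + k] * bmat[size_t(k) * N + n];
            }
            o[size_t(m) * N + n] = acc;
          }
        }
      }
    }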
@@ -7,6 +7,8 @@
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/utils.h"

#include "mlx/backend/metal/kernels/steel/utils.h"

using namespace metal;

///////////////////////////////////////////////////////////////////////////////
@@ -33,22 +35,19 @@ struct GEMVKernel {
// - We assume each thead group is launched with (BN, BM, 1) threads
//
// 1. A thread loads TN elements each from mat along TM contiguous rows
// and the corresponding scalar from the vector
// and the corresponding scalar from the vector
// 2. The thread then multiplies and adds to accumulate its local result for
// the block
// the block
// 3. At the end, each thread has accumulated results over all blocks across
// the rows
// These are then summed up across the threadgroup
// the rows. These are then summed up across the threadgroup
// 4. Each threadgroup writes its accumulated BN * TN outputs
//
// Edge case handling:
// - The threadgroup with the largest tid will have blocks that exceed the
// matrix
// - The threadgroup with the largest tid has blocks that exceed the matrix
// * The blocks that start outside the matrix are never read (thread results
// remain zero)
// remain zero)
// * The last thread that partially overlaps with the matrix is shifted
// inwards
// such that the thread block fits exactly in the matrix
// inwards such that the thread block fits exactly in the matrix

MLX_MTL_CONST short tgp_mem_size = BN * TN * 2;
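The "shifted inwards" edge case above can be written out as plain index arithmetic. A minimal host-side sketch (hypothetical helper, not the kernel's actual code): the last block that would run past the M rows of the matrix is moved back so all TM rows it touches stay in bounds, and the overlapping rows are simply recomputed.

    #include <algorithm>

    // Start row for the block of TM rows handled by a given thread block index.
    int block_start_row(int block_idx, int TM, int M) {
      int start = block_idx * TM;
      if (start + TM > M) {
        // Shift the last, partially overlapping block inwards so that the
        // TM rows it processes all lie inside the matrix.
        start = std::max(0, M - TM);
      }
      return start;
    }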
@@ -100,14 +99,14 @@ struct GEMVKernel {
if (simd_gid == 0) {
// Main load loop
if (bn + TN <= in_vec_size) {
#pragma clang loop unroll(full)
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
in_vec_block[tn] = in_vec[bn + tn];
}

} else { // Edgecase

#pragma clang loop unroll(full)
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
in_vec_block[tn] = bn + tn < in_vec_size ? in_vec[bn + tn] : 0;
}
@@ -116,24 +115,24 @@ struct GEMVKernel {

threadgroup_barrier(mem_flags::mem_threadgroup);

// Load for all rows
#pragma clang loop unroll(full)
// Load for all rows
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
v_coeff[tn] = in_vec_block[tn];
}

// Per thread work loop
#pragma clang loop unroll(full)
// Per thread work loop
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
// Load for the row
if (bn + TN <= in_vec_size) {
#pragma clang loop unroll(full)
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
inter[tn] = mat[tm * marix_ld + bn + tn];
}

} else { // Edgecase
#pragma clang loop unroll(full)
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
int col_idx =
(bn + tn) < in_vec_size ? (bn + tn) : (in_vec_size - 1);
@@ -142,21 +141,22 @@ struct GEMVKernel {
}

// Accumulate results
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
result[tm] += inter[tn] * v_coeff[tn];
}
}
}

// Simdgroup accumulations
#pragma clang loop unroll(full)
// Simdgroup accumulations
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
result[tm] = simd_sum(result[tm]);
}

// Write outputs
if (simd_lid == 0) {
#pragma clang loop unroll(full)
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
if (kDoAxpby) {
out_vec[out_row + tm] = static_cast<T>(alpha) * result[tm] +
@@ -187,22 +187,18 @@ struct GEMVTKernel {
// - We assume each thead group is launched with (BN, BM, 1) threads
//
// 1. A thread loads TN elements each from mat along TM contiguous rows
// and the corresponding scalar from the vector
// 2. The thread then multiplies and adds to accumulate its local result for
// the block
// and the corresponding scalar from the vector
// 2. The thread then accumulates its local result for the block
// 3. At the end, each thread has accumulated results over all blocks across
// the rows
// These are then summed up across the threadgroup
// the rows. These are then summed up across the threadgroup
// 4. Each threadgroup writes its accumulated BN * TN outputs
//
// Edge case handling:
// - The threadgroup with the largest tid will have blocks that exceed the
// matrix
// - The threadgroup with the largest tid has blocks that exceed the matrix
// * The blocks that start outside the matrix are never read (thread results
// remain zero)
// remain zero)
// * The last thread that partially overlaps with the matrix is shifted
// inwards
// such that the thread block fits exactly in the matrix
// inwards such that the thread block fits exactly in the matrix

MLX_MTL_CONST short tgp_mem_size = BN * BM * TN;
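Step 3 of the description above, the cross-row reduction through threadgroup memory, can be written out serially. A minimal sketch of the reduction idea only, assuming BM rows of TN partial results stored row-major in a hypothetical tgp_results[BM * TN] buffer (not the kernel's exact code):

    // Sum BM rows of TN partial results into a single accumulator row.
    void accumulate_partials(const float* tgp_results, float* result, int BM, int TN) {
      for (int j = 0; j < TN; j++) {
        result[j] = tgp_results[j]; // partials from the first row of threads
      }
      for (int i = 1; i < BM; i++) {
        for (int j = 0; j < TN; j++) {
          result[j] += tgp_results[i * TN + j];
        }
      }
    }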
@@ -249,12 +245,12 @@ struct GEMVTKernel {
threadgroup_barrier(mem_flags::mem_none);

if (bm + TM <= in_vec_size) {
#pragma clang loop unroll(full)
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
v_coeff[tm] = in_vec[bm + tm];
}

#pragma clang loop unroll(full)
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
for (int tn = 0; tn < TN; tn++) {
inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn];
@@ -281,7 +277,7 @@ struct GEMVTKernel {

// Threadgroup collection

#pragma clang loop unroll(full)
MLX_MTL_PRAGMA_UNROLL
for (int i = 0; i < TN; i++) {
tgp_results[lid.y * TN + i] = result[i];
}
@@ -290,15 +286,15 @@ struct GEMVTKernel {

// Threadgroup accumulation and writing out results
if (lid.y == 0 && out_col < out_vec_size) {
#pragma clang loop unroll(full)
MLX_MTL_PRAGMA_UNROLL
for (int i = 1; i < BM; i++) {
#pragma clang loop unroll(full)
MLX_MTL_PRAGMA_UNROLL
for (int j = 0; j < TN; j++) {
result[j] += tgp_results[i * TN + j];
}
}

#pragma clang loop unroll(full)
MLX_MTL_PRAGMA_UNROLL
for (int j = 0; j < TN; j++) {
if (kDoAxpby) {
out_vec[out_col + j] = static_cast<T>(alpha) * result[j] +
@@ -408,20 +404,149 @@ template <
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);

#define instantiate_gemv(name, itype, bm, bn, tm, tn) \
instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 0, 0) \
instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 0, 1) \
instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 1, 0) \
instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 1, 1)
// clang-format off
#define instantiate_gemv(name, itype, bm, bn, tm, tn) \
instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 0, 0) \
instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 0, 1) \
instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 1, 0) \
instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 1, 1) // clang-format on

#define instantiate_gemv_blocks(name, itype) \
instantiate_gemv(name, itype, 4, 32, 1, 4) instantiate_gemv( \
name, itype, 4, 32, 4, 4) instantiate_gemv(name, itype, 8, 32, 4, 4)
// clang-format off
#define instantiate_gemv_blocks(name, itype) \
instantiate_gemv(name, itype, 4, 32, 1, 4) \
instantiate_gemv(name, itype, 4, 32, 4, 4) \
instantiate_gemv(name, itype, 8, 32, 4, 4) // clang-format on

instantiate_gemv_blocks(float32, float);
instantiate_gemv_blocks(float16, half);
instantiate_gemv_blocks(bfloat16, bfloat16_t);

template <
typename T,
const int BM, /* Threadgroup rows (in threads) */
const int BN, /* Threadgroup cols (in threads) */
const int TM, /* Thread rows (in elements) */
const int TN> /* Thread cols (in elements) */
[[kernel, max_total_threads_per_threadgroup(BM* BN)]] void gemv_bs(
const device T* mat [[buffer(0)]],
const device T* in_vec [[buffer(1)]],
const device T* bias [[buffer(2)]],
device T* out_vec [[buffer(3)]],
const constant int& in_vec_size [[buffer(4)]],
const constant int& out_vec_size [[buffer(5)]],
const constant int& marix_ld [[buffer(6)]],
const constant float& alpha [[buffer(7)]],
const constant float& beta [[buffer(8)]],
const constant int& batch_ndim [[buffer(9)]],
const constant int* batch_shape [[buffer(10)]],
const constant size_t* index_batch_strides [[buffer(11)]],
const constant int& vector_batch_ndim [[buffer(12)]],
const constant int* vector_batch_shape [[buffer(13)]],
const constant size_t* vector_batch_stride [[buffer(14)]],
const constant int& matrix_batch_ndim [[buffer(15)]],
const constant int* matrix_batch_shape [[buffer(16)]],
const constant size_t* matrix_batch_stride [[buffer(17)]],
const constant uint32_t* vec_indices [[buffer(18)]],
const constant uint32_t* mat_indices [[buffer(19)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
using gemv_kernel = GEMVKernel<T, BM, BN, TM, TN, false>;
threadgroup T tgp_memory[gemv_kernel::tgp_mem_size];

uint32_t indx_vec;
uint32_t indx_mat;

// Update batch offsets
if (batch_ndim > 1) {
const constant size_t* veci_bstrides = index_batch_strides;
const constant size_t* mati_bstrides = index_batch_strides + batch_ndim;

ulong2 batch_offsets = elem_to_loc_broadcast(
tid.z, batch_shape, veci_bstrides, mati_bstrides, batch_ndim);

indx_vec = vec_indices[batch_offsets.x];
indx_mat = mat_indices[batch_offsets.y];

} else {
indx_vec = vec_indices[index_batch_strides[0] * tid.z];
indx_mat = mat_indices[index_batch_strides[batch_ndim] * tid.z];
}

if (vector_batch_ndim > 1) {
in_vec += elem_to_loc(
indx_vec, vector_batch_shape, vector_batch_stride, vector_batch_ndim);
} else {
in_vec += indx_vec * vector_batch_stride[0];
}

if (matrix_batch_ndim > 1) {
mat += elem_to_loc(
indx_mat, matrix_batch_shape, matrix_batch_stride, matrix_batch_ndim);
} else {
mat += indx_mat * matrix_batch_stride[0];
}

out_vec += tid.z * out_vec_size;

gemv_kernel::run(
mat,
in_vec,
bias,
out_vec,
in_vec_size,
out_vec_size,
marix_ld,
alpha,
beta,
batch_ndim, // Not used
tgp_memory,
tid,
lid,
simd_gid,
simd_lid);
}
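The elem_to_loc() helper used above comes from mlx/backend/metal/kernels/utils.h. For readers of this diff, a hedged plain C++ sketch of the arithmetic it performs (flat element index to strided memory offset, innermost dimension iterated last); this is illustrative only, not the actual helper:

    #include <cstddef>
    #include <cstdint>

    size_t elem_to_loc_sketch(
        uint32_t elem, const int* shape, const size_t* strides, int ndim) {
      size_t loc = 0;
      for (int i = ndim - 1; i >= 0; --i) {
        loc += size_t(elem % shape[i]) * strides[i];
        elem /= shape[i];
      }
      return loc;
    }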
#define instantiate_gemv_bs_helper(nm, itype, bm, bn, tm, tn) \
template [[host_name("gemv_bs_" #nm "_bm" #bm "_bn" #bn "_tm" #tm \
"_tn" #tn)]] [[kernel]] void \
gemv_bs<itype, bm, bn, tm, tn>( \
const device itype* mat [[buffer(0)]], \
const device itype* in_vec [[buffer(1)]], \
const device itype* bias [[buffer(2)]], \
device itype* out_vec [[buffer(3)]], \
const constant int& in_vec_size [[buffer(4)]], \
const constant int& out_vec_size [[buffer(5)]], \
const constant int& marix_ld [[buffer(6)]], \
const constant float& alpha [[buffer(7)]], \
const constant float& beta [[buffer(8)]], \
const constant int& batch_ndim [[buffer(9)]], \
const constant int* batch_shape [[buffer(10)]], \
const constant size_t* index_batch_strides [[buffer(11)]], \
const constant int& vector_batch_ndim [[buffer(12)]], \
const constant int* vector_batch_shape [[buffer(13)]], \
const constant size_t* vector_batch_stride [[buffer(14)]], \
const constant int& matrix_batch_ndim [[buffer(15)]], \
const constant int* matrix_batch_shape [[buffer(16)]], \
const constant size_t* matrix_batch_stride [[buffer(17)]], \
const constant uint32_t* vec_indices [[buffer(18)]], \
const constant uint32_t* mat_indices [[buffer(19)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);

// clang-format off
#define instantiate_gemv_bs_blocks(name, itype) \
instantiate_gemv_bs_helper(name, itype, 4, 32, 1, 4) \
instantiate_gemv_bs_helper(name, itype, 4, 32, 4, 4) \
instantiate_gemv_bs_helper(name, itype, 8, 32, 4, 4) // clang-format on

instantiate_gemv_bs_blocks(float32, float);
instantiate_gemv_bs_blocks(float16, half);
instantiate_gemv_bs_blocks(bfloat16, bfloat16_t);

///////////////////////////////////////////////////////////////////////////////
/// Vector matrix multiplication
///////////////////////////////////////////////////////////////////////////////
@@ -538,4 +663,134 @@ template <
// clang-format off
instantiate_gemv_t_blocks(float32, float);
instantiate_gemv_t_blocks(float16, half);
instantiate_gemv_t_blocks(bfloat16, bfloat16_t); // clang-format on

template <
typename T,
const int BM, /* Threadgroup rows (in threads) */
const int BN, /* Threadgroup cols (in threads) */
const int TM, /* Thread rows (in elements) */
const int TN> /* Thread cols (in elements) */
[[kernel, max_total_threads_per_threadgroup(BM* BN)]] void gemv_t_bs(
const device T* mat [[buffer(0)]],
const device T* in_vec [[buffer(1)]],
const device T* bias [[buffer(2)]],
device T* out_vec [[buffer(3)]],
const constant int& in_vec_size [[buffer(4)]],
const constant int& out_vec_size [[buffer(5)]],
const constant int& marix_ld [[buffer(6)]],
const constant float& alpha [[buffer(7)]],
const constant float& beta [[buffer(8)]],
const constant int& batch_ndim [[buffer(9)]],
const constant int* batch_shape [[buffer(10)]],
const constant size_t* index_batch_strides [[buffer(11)]],
const constant int& vector_batch_ndim [[buffer(12)]],
const constant int* vector_batch_shape [[buffer(13)]],
const constant size_t* vector_batch_stride [[buffer(14)]],
const constant int& matrix_batch_ndim [[buffer(15)]],
const constant int* matrix_batch_shape [[buffer(16)]],
const constant size_t* matrix_batch_stride [[buffer(17)]],
const constant uint32_t* vec_indices [[buffer(18)]],
const constant uint32_t* mat_indices [[buffer(19)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
using gemv_kernel = GEMVTKernel<T, BM, BN, TM, TN, false>;
threadgroup T tgp_memory[gemv_kernel::tgp_mem_size];

uint32_t indx_vec;
uint32_t indx_mat;

// Update batch offsets
if (batch_ndim > 1) {
const constant size_t* veci_bstrides = index_batch_strides;
const constant size_t* mati_bstrides = index_batch_strides + batch_ndim;

ulong2 batch_offsets = elem_to_loc_broadcast(
tid.z, batch_shape, veci_bstrides, mati_bstrides, batch_ndim);

indx_vec = vec_indices[batch_offsets.x];
indx_mat = mat_indices[batch_offsets.y];

} else {
indx_vec = vec_indices[index_batch_strides[0] * tid.z];
indx_mat = mat_indices[index_batch_strides[batch_ndim] * tid.z];
}

if (vector_batch_ndim > 1) {
in_vec += elem_to_loc(
indx_vec, vector_batch_shape, vector_batch_stride, vector_batch_ndim);
} else {
in_vec += indx_vec * vector_batch_stride[0];
}

if (matrix_batch_ndim > 1) {
mat += elem_to_loc(
indx_mat, matrix_batch_shape, matrix_batch_stride, matrix_batch_ndim);
} else {
mat += indx_mat * matrix_batch_stride[0];
}

out_vec += tid.z * out_vec_size;

gemv_kernel::run(
mat,
in_vec,
bias,
out_vec,
in_vec_size,
out_vec_size,
marix_ld,
alpha,
beta,
batch_ndim, // Not used,
tgp_memory,
tid,
lid,
simd_gid,
simd_lid);
}

#define instantiate_gemv_t_bs_helper(nm, itype, bm, bn, tm, tn) \
template [[host_name("gemv_t_bs_" #nm "_bm" #bm "_bn" #bn "_tm" #tm \
"_tn" #tn)]] [[kernel]] void \
gemv_t_bs<itype, bm, bn, tm, tn>( \
const device itype* mat [[buffer(0)]], \
const device itype* in_vec [[buffer(1)]], \
const device itype* bias [[buffer(2)]], \
device itype* out_vec [[buffer(3)]], \
const constant int& in_vec_size [[buffer(4)]], \
const constant int& out_vec_size [[buffer(5)]], \
const constant int& marix_ld [[buffer(6)]], \
const constant float& alpha [[buffer(7)]], \
const constant float& beta [[buffer(8)]], \
const constant int& batch_ndim [[buffer(9)]], \
const constant int* batch_shape [[buffer(10)]], \
const constant size_t* index_batch_strides [[buffer(11)]], \
const constant int& vector_batch_ndim [[buffer(12)]], \
const constant int* vector_batch_shape [[buffer(13)]], \
const constant size_t* vector_batch_stride [[buffer(14)]], \
const constant int& matrix_batch_ndim [[buffer(15)]], \
const constant int* matrix_batch_shape [[buffer(16)]], \
const constant size_t* matrix_batch_stride [[buffer(17)]], \
const constant uint32_t* vec_indices [[buffer(18)]], \
const constant uint32_t* mat_indices [[buffer(19)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);

// clang-format off
#define instantiate_gemv_t_bs_blocks(name, itype) \
instantiate_gemv_t_bs_helper(name, itype, 8, 8, 4, 1) \
instantiate_gemv_t_bs_helper(name, itype, 8, 8, 4, 4) \
instantiate_gemv_t_bs_helper(name, itype, 8, 16, 4, 4) \
instantiate_gemv_t_bs_helper(name, itype, 8, 32, 4, 4) \
instantiate_gemv_t_bs_helper(name, itype, 8, 64, 4, 4) \
instantiate_gemv_t_bs_helper(name, itype, 8, 128, 4, 4) // clang-format on

// clang-format off
instantiate_gemv_t_bs_blocks(float32, float);
instantiate_gemv_t_bs_blocks(float16, half);
instantiate_gemv_t_bs_blocks(bfloat16, bfloat16_t); // clang-format on
@@ -0,0 +1,168 @@
// Copyright © 2024 Apple Inc.

#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
#include "mlx/backend/metal/kernels/utils.h"

using namespace metal;
using namespace mlx::steel;

///////////////////////////////////////////////////////////////////////////////
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////

template <
typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
bool MN_aligned,
bool K_aligned>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void bs_gemm(
const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]],
device T* D [[buffer(3)]],
const constant GEMMParams* params [[buffer(4)]],
const constant int* batch_shape [[buffer(6)]],
const constant size_t* batch_strides [[buffer(7)]],
const constant uint32_t* lhs_indices [[buffer(10)]],
const constant uint32_t* rhs_indices [[buffer(11)]],
const constant int* batch_shape_A [[buffer(12)]],
const constant size_t* batch_strides_A [[buffer(13)]],
const constant int* batch_shape_B [[buffer(14)]],
const constant size_t* batch_strides_B [[buffer(15)]],
const constant int2& operand_batch_ndim [[buffer(16)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
using gemm_kernel = GEMMKernel<
T,
T,
BM,
BN,
BK,
WM,
WN,
transpose_a,
transpose_b,
MN_aligned,
K_aligned>;

threadgroup T As[gemm_kernel::tgp_mem_size_a];
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];

uint32_t indx_A;
uint32_t indx_B;

// Adjust for batch
if (params->batch_ndim > 1) {
const constant size_t* A_bstrides = batch_strides;
const constant size_t* B_bstrides = batch_strides + params->batch_ndim;

ulong2 batch_offsets = elem_to_loc_broadcast(
tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);

indx_A = lhs_indices[batch_offsets.x];
indx_B = rhs_indices[batch_offsets.y];

} else {
indx_A = lhs_indices[params->batch_stride_a * tid.z];
indx_B = rhs_indices[params->batch_stride_b * tid.z];
}

int batch_ndim_A = operand_batch_ndim.x;
int batch_ndim_B = operand_batch_ndim.y;

if (batch_ndim_A > 1) {
A += elem_to_loc(indx_A, batch_shape_A, batch_strides_A, batch_ndim_A);
} else {
A += indx_A * batch_strides_A[0];
}

if (batch_ndim_B > 1) {
B += elem_to_loc(indx_B, batch_shape_B, batch_strides_B, batch_ndim_B);
} else {
B += indx_B * batch_strides_B[0];
}

D += params->batch_stride_d * tid.z;

gemm_kernel::run(
A, B, D, params, As, Bs, simd_lane_id, simd_group_id, tid, lid);
}
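elem_to_loc_broadcast(), used in both new kernels, resolves one flat batch index against two stride sets that share a single batch shape (here, the strides of the lhs and rhs index arrays). A hedged C++ sketch of that arithmetic, not the actual helper from utils.h:

    #include <cstddef>
    #include <cstdint>
    #include <utility>

    // Map one flat batch element to two offsets, one per stride set.
    std::pair<size_t, size_t> elem_to_loc_broadcast_sketch(
        uint32_t elem,
        const int* shape,
        const size_t* a_strides,
        const size_t* b_strides,
        int ndim) {
      size_t loc_a = 0, loc_b = 0;
      for (int i = ndim - 1; i >= 0; --i) {
        uint32_t pos = elem % shape[i];
        loc_a += size_t(pos) * a_strides[i];
        loc_b += size_t(pos) * b_strides[i];
        elem /= shape[i];
      }
      return {loc_a, loc_b};
    }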
///////////////////////////////////////////////////////////////////////////////
// GEMM kernel initializations
///////////////////////////////////////////////////////////////////////////////

#define instantiate_gemm( \
tname, \
trans_a, \
trans_b, \
iname, \
itype, \
oname, \
otype, \
bm, \
bn, \
bk, \
wm, \
wn, \
aname, \
mn_aligned, \
kname, \
k_aligned) \
template [[host_name("steel_block_sparse_gemm_" #tname "_" #iname "_" #oname \
"_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn \
"_MN_" #aname "_K_" #kname)]] [[kernel]] void \
bs_gemm<itype, bm, bn, bk, wm, wn, trans_a, trans_b, mn_aligned, k_aligned>( \
const device itype* A [[buffer(0)]], \
const device itype* B [[buffer(1)]], \
device itype* D [[buffer(3)]], \
const constant GEMMParams* params [[buffer(4)]], \
const constant int* batch_shape [[buffer(6)]], \
const constant size_t* batch_strides [[buffer(7)]], \
const constant uint32_t* lhs_indices [[buffer(10)]], \
const constant uint32_t* rhs_indices [[buffer(11)]], \
const constant int* batch_shape_A [[buffer(12)]], \
const constant size_t* batch_strides_A [[buffer(13)]], \
const constant int* batch_shape_B [[buffer(14)]], \
const constant size_t* batch_strides_B [[buffer(15)]], \
const constant int2& operand_batch_ndim [[buffer(16)]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]);

// clang-format off
#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false) // clang-format on

// clang-format off
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) // clang-format on

// clang-format off
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 2, 2) // clang-format on

// clang-format off
instantiate_gemm_shapes_helper(float16, half, float16, half);
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);

instantiate_gemm_shapes_helper(float32, float, float32, float); // clang-format on
@@ -21,9 +21,9 @@ struct GEMMParams {
const int tiles_n;
const int tiles_m;

const int batch_stride_a;
const int batch_stride_b;
const int batch_stride_d;
const size_t batch_stride_a;
const size_t batch_stride_b;
const size_t batch_stride_d;

const int swizzle_log;
const int gemm_k_iterations_aligned;
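The int to size_t change above matters once a single matrix's element count times a batch index no longer fits in 32 bits. A small self-contained check with illustrative values:

    #include <cstdint>
    #include <cstdio>

    int main() {
      int M = 16384, N = 16384, batch_idx = 16;
      // The per-batch output stride is M * N elements; offsetting by the
      // batch index in 64-bit arithmetic gives 4,294,967,296, which exceeds
      // INT32_MAX, so a 32-bit batch stride computation would wrap around.
      size_t offset = size_t(M) * N * batch_idx;
      std::printf("offset = %zu, INT32_MAX = %d\n", offset, INT32_MAX);
      return 0;
    }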
@@ -54,7 +54,7 @@ struct GEMMAddMMParams {
const int ldc;
const int fdc;

const int batch_stride_c;
const size_t batch_stride_c;

const float alpha;
const float beta;
@@ -327,9 +327,9 @@ void steel_matmul_conv_groups(
/* const int ldd = */ ldd,
/* const int tiles_n = */ tn,
/* const int tiles_m = */ tm,
/* const int batch_stride_a = */ K,
/* const int batch_stride_b = */ N * K,
/* const int batch_stride_d = */ N,
/* const size_t batch_stride_a = */ size_t(K),
/* const size_t batch_stride_b = */ size_t(N) * K,
/* const size_t batch_stride_d = */ size_t(N),
/* const int swizzle_log = */ swizzle_log,
/* const int gemm_k_iterations_aligned = */ (K / bk),
/* const int batch_ndim = */ 1};
@@ -405,7 +405,7 @@ void steel_matmul(
}
}

int matrix_stride_out = M * N;
size_t matrix_stride_out = size_t(M) * N;

/////////////////////////////////////////////////////////////////////////////
// Split K specialization
@@ -550,9 +550,9 @@ void steel_matmul(
/* const int ldd = */ N,
/* const int tiles_n = */ tn,
/* const int tiles_m = */ tm,
/* const int batch_stride_a = */ int(A_batch_stride.back()),
/* const int batch_stride_b = */ int(B_batch_stride.back()),
/* const int batch_stride_d = */ matrix_stride_out,
/* const size_t batch_stride_a = */ A_batch_stride.back(),
/* const size_t batch_stride_b = */ B_batch_stride.back(),
/* const size_t batch_stride_d = */ matrix_stride_out,
/* const int swizzle_log = */ swizzle_log,
/* const int gemm_k_iterations_aligned = */ (K / bk),
/* const int batch_ndim = */ int(batch_shape.size())};
@@ -645,7 +645,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {

auto [batch_shape, A_batch_stride, B_batch_stride] = collapse_batches(a, b);

auto batch_size_out = out.size() / (M * N);
auto batch_size_out = out.size() / (size_t(M) * size_t(N));

// Collapse batches into M if needed
if (batch_size_out > 1 && !a_transposed && batch_shape.size() == 1 &&
@@ -853,7 +853,8 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
auto [batch_shape, A_batch_stride, B_batch_stride, C_batch_stride] =
collapse_batches(a, b, c);

auto batch_size_out = out.size() / (M * N);
size_t matrix_stride_out = size_t(M) * size_t(N);
auto batch_size_out = out.size() / (matrix_stride_out);

// Collapse batches into M if needed
if (batch_size_out > 1 && !transpose_a && batch_shape.size() == 1 &&
@@ -869,8 +870,6 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
batch_shape = {1};
}

int matrix_stride_out = M * N;

/////////////////////////////////////////////////////////////////////////////
// Gemv specialization
@@ -1120,9 +1119,9 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
/* const int ldd = */ N,
/* const int tiles_n = */ tn,
/* const int tiles_m = */ tm,
/* const int batch_stride_a = */ int(A_batch_stride.back()),
/* const int batch_stride_b = */ int(B_batch_stride.back()),
/* const int batch_stride_d = */ matrix_stride_out,
/* const size_t batch_stride_a = */ A_batch_stride.back(),
/* const size_t batch_stride_b = */ B_batch_stride.back(),
/* const size_t batch_stride_d = */ matrix_stride_out,
/* const int swizzle_log = */ swizzle_log,
/* const int gemm_k_iterations_aligned = */ (K / bk),
/* const int batch_ndim = */ int(batch_shape.size())};
@@ -1130,7 +1129,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
GEMMAddMMParams params{
/* const int ldc = */ ldc,
/* const int fdc = */ fdc,
/* const int batch_stride_c = */ int(C_batch_stride.back()),
/* const size_t batch_stride_c = */ C_batch_stride.back(),
/* const float alpha = */ alpha_,
/* const float beta = */ beta_};
@@ -1230,8 +1229,8 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& out_mask = inputs[2];

std::vector<int> batch_shape{1};
int A_batch_str = 0;
int B_batch_str = 0;
size_t A_batch_str = 0;
size_t B_batch_str = 0;

std::vector<size_t> batch_strides;
@@ -1249,8 +1248,8 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {

auto [bshape_c, bstrides_c] = collapse_contiguous_dims(bshape, bstrides);
batch_shape = bshape_c;
A_batch_str = int(bstrides_c[0].back());
B_batch_str = int(bstrides_c[1].back());
A_batch_str = bstrides_c[0].back();
B_batch_str = bstrides_c[1].back();

for (auto& bstr : bstrides_c) {
batch_strides.insert(batch_strides.end(), bstr.begin(), bstr.end());
@@ -1259,8 +1258,8 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
batch_strides = std::vector<size_t>(inputs.size(), 0);
}

auto batch_size_out = out.size() / (M * N);
int matrix_stride_out = M * N;
size_t matrix_stride_out = size_t(M) * N;
size_t batch_size_out = out.size() / (matrix_stride_out);

/////////////////////////////////////////////////////////////////////////////
// Regular kernel dispatch
@@ -1301,9 +1300,9 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
/* const int ldd = */ N,
/* const int tiles_n = */ tn,
/* const int tiles_m = */ tm,
/* const int batch_stride_a = */ A_batch_str,
/* const int batch_stride_b = */ B_batch_str,
/* const int batch_stride_d = */ matrix_stride_out,
/* const size_t batch_stride_a = */ A_batch_str,
/* const size_t batch_stride_b = */ B_batch_str,
/* const size_t batch_stride_d = */ matrix_stride_out,
/* const int swizzle_log = */ swizzle_log,
/* const int gemm_k_iterations_aligned = */ (K / bk),
/* const int batch_ndim = */ int(batch_shape.size())};
@@ -1355,4 +1354,314 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
return;
}

void BlockSparseMM::eval_gpu(const std::vector<array>& inputs, array& out) {
using namespace mlx::steel;
// assert(inputs.size() == 2);
if (!issubdtype(out.dtype(), floating)) {
throw std::runtime_error(
"[matmul] Does not yet support non-floating point types.");
}
auto& s = stream();
auto& d = metal::device(s.device);

auto& a_pre = inputs[0];
auto& b_pre = inputs[1];
// Return 0s if either input is empty
if (a_pre.size() == 0 || b_pre.size() == 0) {
array zero = array(0, a_pre.dtype());
copy_gpu(zero, out, CopyType::Scalar, s);
auto command_buffer = d.get_command_buffer(s.index);
command_buffer->addCompletedHandler([zero](MTL::CommandBuffer*) {});
return;
}

out.set_data(allocator::malloc_or_wait(out.nbytes()));

/////////////////////////////////////////////////////////////////////////////
// Init checks and prep

// Keep a vector with copies to be cleared in the completed buffer to release
// the arrays
std::vector<array> copies;
auto check_transpose = [&copies, &s](const array& arr) {
auto stx = arr.strides()[arr.ndim() - 2];
auto sty = arr.strides()[arr.ndim() - 1];
if (sty == 1) {
return std::make_tuple(false, stx, arr);
} else if (stx == 1) {
return std::make_tuple(true, sty, arr);
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_gpu(arr, arr_copy, CopyType::General, s);
copies.push_back(arr_copy);
size_t stx = arr.shape(-1);
return std::make_tuple(false, stx, arr_copy);
}
};

auto [transpose_a, a_cols, a] = check_transpose(a_pre);
auto [transpose_b, b_cols, b] = check_transpose(b_pre);

int lda = a_cols;
int ldb = b_cols;

int M = a.shape(-2);
int N = b.shape(-1);
int K = a.shape(-1);

/////////////////////////////////////////////////////////////////////////////
// Check and collapse batch dimensions

auto get_batch_dims = [](const auto& v) {
return decltype(v){v.begin(), v.end() - 2};
};

auto& lhs_indices = inputs[2];
auto& rhs_indices = inputs[3];

std::vector<int> batch_shape = get_batch_dims(out.shape());
std::vector<size_t> batch_strides;

batch_strides.insert(
batch_strides.end(),
lhs_indices.strides().begin(),
lhs_indices.strides().end());
size_t lhs_indices_str = batch_strides.empty() ? 0 : batch_strides.back();

batch_strides.insert(
batch_strides.end(),
rhs_indices.strides().begin(),
rhs_indices.strides().end());
size_t rhs_indices_str = batch_strides.empty() ? 0 : batch_strides.back();

int batch_ndim = batch_shape.size();

if (batch_ndim == 0) {
batch_shape = {1};
batch_strides = {0};
}

int batch_ndim_A = a.ndim() - 2;
int batch_ndim_B = b.ndim() - 2;
std::vector<int> operand_batch_ndim = {batch_ndim_A, batch_ndim_B};

std::vector<int> batch_shape_A = get_batch_dims(a.shape());
std::vector<size_t> batch_strides_A = get_batch_dims(a.strides());
std::vector<int> batch_shape_B = get_batch_dims(b.shape());
std::vector<size_t> batch_strides_B = get_batch_dims(b.strides());

if (batch_ndim_A == 0) {
batch_shape_A = {1};
batch_strides_A = {0};
}

if (batch_ndim_B == 0) {
batch_shape_B = {1};
batch_strides_B = {0};
}

size_t matrix_stride_out = size_t(M) * N;
auto batch_size_out = out.size() / matrix_stride_out;

/////////////////////////////////////////////////////////////////////////////
// Gemv specialization

// Route to gemv if needed
if (std::min(M, N) == 1) {
// Collect problem info
bool is_b_matrix = N != 1;

auto& mat = is_b_matrix ? b : a;
auto& vec = is_b_matrix ? a : b;
bool transpose_mat = is_b_matrix ? !transpose_b : transpose_a;
int in_vector_len = K;
int out_vector_len = is_b_matrix ? N : M;

int mat_cols = transpose_mat ? out_vector_len : in_vector_len;
int mat_rows = transpose_mat ? in_vector_len : out_vector_len;
int mat_ld = is_b_matrix ? b_cols : a_cols;

auto batch_strides_mat = is_b_matrix ? batch_strides_B : batch_strides_A;
auto batch_strides_vec = is_b_matrix ? batch_strides_A : batch_strides_B;

auto batch_shape_mat = is_b_matrix ? batch_shape_B : batch_shape_A;
auto batch_shape_vec = is_b_matrix ? batch_shape_A : batch_shape_B;

if (!is_b_matrix) {
batch_strides = rhs_indices.strides();
batch_strides.insert(
batch_strides.end(),
lhs_indices.strides().begin(),
lhs_indices.strides().end());
}

int batch_ndim = batch_shape.size();

// Determine dispatch kernel
int tm = 4, tn = 4;
int bm, bn, n_out_per_tgp;
std::ostringstream kname;

if (transpose_mat) {
bm = 8;
bn = 8;
if (out_vector_len >= 24576) {
bn = 128;
} else if (out_vector_len >= 16384) {
bn = 64;
} else if (out_vector_len >= 8192) {
bn = 16;
}

// Specialized kernel for very small outputs
tn = out_vector_len < tn ? 1 : tn;

n_out_per_tgp = bn * tn;
kname << "gemv_t_bs_" << type_to_name(out);

} else {
bm = out_vector_len >= 4096 ? 8 : 4;
bn = 32;

// Specialized kernel for very small outputs
tm = out_vector_len < tm ? 1 : tm;

n_out_per_tgp = bm * tm;
kname << "gemv_bs_" << type_to_name(out);
}

kname << "_bm" << bm << "_bn" << bn << "_tm" << tm << "_tn" << tn;
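For orientation, the name assembled above has to match one of the host_name strings produced by the instantiation macros in the kernel file. An illustrative standalone rerun of the same assembly for a transposed float32 problem with out_vector_len below 8192 (so bm = bn = 8 and tm = tn = 4):

    #include <iostream>
    #include <sstream>

    int main() {
      int bm = 8, bn = 8, tm = 4, tn = 4;
      std::ostringstream kname;
      kname << "gemv_t_bs_" << "float32";
      kname << "_bm" << bm << "_bn" << bn << "_tm" << tm << "_tn" << tn;
      // Prints: gemv_t_bs_float32_bm8_bn8_tm4_tn4
      std::cout << kname.str() << "\n";
      return 0;
    }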
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);

int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp;
MTL::Size group_dims = MTL::Size(bn, bm, 1);
MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out);

compute_encoder.set_input_array(mat, 0);
compute_encoder.set_input_array(vec, 1);
compute_encoder.set_output_array(out, 3);

compute_encoder->setBytes(&in_vector_len, sizeof(int), 4);
compute_encoder->setBytes(&out_vector_len, sizeof(int), 5);
compute_encoder->setBytes(&mat_ld, sizeof(int), 6);

compute_encoder->setBytes(&batch_ndim, sizeof(int), 9);
set_vector_bytes(compute_encoder, batch_shape, 10);
set_vector_bytes(compute_encoder, batch_strides, 11);

int batch_ndim_vec = batch_shape_vec.size();
compute_encoder->setBytes(&batch_ndim_vec, sizeof(int), 12);
set_vector_bytes(compute_encoder, batch_shape_vec, 13);
set_vector_bytes(compute_encoder, batch_strides_vec, 14);

int batch_ndim_mat = batch_shape_mat.size();
compute_encoder->setBytes(&batch_ndim_mat, sizeof(int), 15);
set_vector_bytes(compute_encoder, batch_shape_mat, 16);
set_vector_bytes(compute_encoder, batch_strides_mat, 17);

compute_encoder.set_input_array(lhs_indices, 18 + int(!is_b_matrix));
compute_encoder.set_input_array(rhs_indices, 18 + int(is_b_matrix));

compute_encoder->dispatchThreadgroups(grid_dims, group_dims);

d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
return;
}

/////////////////////////////////////////////////////////////////////////////
// Regular kernel dispatch

// Determine dispatch kernel
int bm = 32, bn = 32, bk = 16;
int wm = 2, wn = 2;

if ((size_t)batch_size_out * M * N >= 1ul << 20) {
if (!transpose_a && transpose_b) {
bm = 64;
bn = (out.dtype() == float32) ? 64 : 32;
bk = (out.dtype() == float32) ? 16 : 32;
} else {
bm = 64;
bn = 64;
}
}

// Prepare kernel name
std::ostringstream kname;
kname << "steel_block_sparse_gemm_" << (transpose_a ? 't' : 'n')
<< (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
<< type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk
<< "_wm" << wm << "_wn" << wn << "_MN_"
<< ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned" << "_K_"
<< ((K % bk == 0) ? "t" : "n") << "aligned";

// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);

// Use problem size to determine threadblock swizzle
int tn = (N + bn - 1) / bn;
int tm = (M + bm - 1) / bm;

// TODO: Explore device-based tuning for swizzle
int swizzle_log = 0; // tm >= 6 ? 3 : (tm <= 3 ? 0 : 2);

// Prepare steel matmul params
GEMMParams params{
/* const int M = */ M,
/* const int N = */ N,
/* const int K = */ K,
/* const int lda = */ lda,
/* const int ldb = */ ldb,
/* const int ldd = */ N,
/* const int tiles_n = */ tn,
/* const int tiles_m = */ tm,
/* const size_t batch_stride_a = */ lhs_indices_str,
/* const size_t batch_stride_b = */ rhs_indices_str,
/* const size_t batch_stride_d = */ matrix_stride_out,
/* const int swizzle_log = */ swizzle_log,
/* const int gemm_k_iterations_aligned = */ (K / bk),
/* const int batch_ndim = */ batch_ndim};

// Prepare launch grid params
int tile = 1 << swizzle_log;
tm = (tm + tile - 1) / tile;
tn = tn * tile;
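A quick worked example of the swizzle arithmetic above (values illustrative): with swizzle_log = 0 the grid is unchanged, while a non-zero swizzle_log folds tile rows into the column dimension, and the kernel side is expected to recover the original tile coordinates from the swizzled threadgroup id.

    #include <cstdio>

    int main() {
      int tiles_m = 10, tiles_n = 32;  // tile counts before swizzling
      int swizzle_log = 2;             // illustrative; the code above currently uses 0
      int tile = 1 << swizzle_log;     // 4
      int grid_y = (tiles_m + tile - 1) / tile; // 3
      int grid_x = tiles_n * tile;              // 128
      // 3 * 128 = 384 threadgroups are launched to cover the 10 * 32 = 320 real tiles.
      std::printf("grid = %d x %d for %d tiles\n", grid_x, grid_y, tiles_m * tiles_n);
      return 0;
    }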
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out);

// Launch kernel
compute_encoder.set_input_array(a, 0);
compute_encoder.set_input_array(b, 1);
compute_encoder.set_output_array(out, 3);

compute_encoder->setBytes(&params, sizeof(GEMMParams), 4);

set_vector_bytes(compute_encoder, batch_shape, 6);
set_vector_bytes(compute_encoder, batch_strides, 7);

compute_encoder.set_input_array(lhs_indices, 10);
compute_encoder.set_input_array(rhs_indices, 11);

set_vector_bytes(compute_encoder, batch_shape_A, 12);
set_vector_bytes(compute_encoder, batch_strides_A, 13);
set_vector_bytes(compute_encoder, batch_shape_B, 14);
set_vector_bytes(compute_encoder, batch_strides_B, 15);
set_vector_bytes(compute_encoder, operand_batch_ndim, 16);

compute_encoder->dispatchThreadgroups(grid_dims, group_dims);

// Clear copies
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
return;
}

} // namespace mlx::core