No copy gems (#801)

* Enable collapsing batch dims in gemm
* Update gemm to only make copies when neither of the last 2 axes are contiguous
* Update addmm to support gemv shapes
* Update addmm to support irregular batch strides
* Update tests
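
The core idea behind the batch-dim collapsing above: adjacent batch dimensions whose strides nest contiguously can be merged into one, so a (2, 3, M, K) @ (2, 3, K, N) matmul dispatches as a single batch of 6 rather than nested loops. Below is a minimal single-array C++ sketch of that merge rule, for illustration only: the PR itself uses MLX's collapse_contiguous_dims, which merges the dims of several arrays jointly, and collapse_batch_dims here is a made-up name.

#include <cstddef>
#include <utility>
#include <vector>

// Merge batch dim i into dim i-1 whenever stride[i-1] == shape[i] * stride[i],
// i.e. the two dims address memory as one contiguous run.
std::pair<std::vector<int>, std::vector<size_t>> collapse_batch_dims(
    const std::vector<int>& shape,        // batch dims only, e.g. {2, 3}
    const std::vector<size_t>& strides) { // matching strides, e.g. {3*M*K, M*K}
  std::vector<int> out_shape;
  std::vector<size_t> out_strides;
  for (size_t i = 0; i < shape.size(); ++i) {
    if (!out_shape.empty() &&
        out_strides.back() == static_cast<size_t>(shape[i]) * strides[i]) {
      out_shape.back() *= shape[i]; // fold dim i into its neighbor
      out_strides.back() = strides[i];
    } else {
      out_shape.push_back(shape[i]);
      out_strides.push_back(strides[i]);
    }
  }
  return {out_shape, out_strides};
}

With shape {2, 3} and strides {3*M*K, M*K} this returns shape {6} and stride {M*K}: one flat batch loop instead of two nested ones, which is what lets the new kernels index batches with a single tid.z.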
Commit 5ad133f8bb (parent d0c544a868)
Jagrit Digani · 2024-03-12 13:13:41 -07:00 · committed by GitHub
12 changed files with 799 additions and 448 deletions

View File

@@ -28,10 +28,12 @@ void explicit_gemm_conv_ND_gpu(
     const array& wt,
     array out,
     const MLXConvParams<N>& conv_params) {
+  // Get gemm shapes
+  int implicit_M = out.size() / conv_params.O;
+  int implicit_K = wt.size() / conv_params.O;
+  int implicit_N = conv_params.O;
   // Prepare unfolding array
-  std::vector<int> unfolded_shape = {
-      static_cast<int>(out.size() / conv_params.O),
-      static_cast<int>(wt.size() / conv_params.O)};
+  std::vector<int> unfolded_shape{implicit_M, implicit_K};
   array in_unfolded(unfolded_shape, in.dtype(), nullptr, {});
   in_unfolded.set_data(allocator::malloc_or_wait(in_unfolded.nbytes()));

@@ -59,20 +61,29 @@
   compute_encoder->dispatchThreads(grid_dims, group_dims);

+  // Reshape weight
+  std::vector<int> wt_reshape{implicit_K, implicit_N};
+  std::vector<size_t> wt_restride{1, static_cast<size_t>(implicit_K)};
+  array wt_reshaped(wt_reshape, wt.dtype(), nullptr, {});
+  auto wt_flags = wt.flags();
+  wt_flags.row_contiguous = false;
+  wt_flags.col_contiguous = true;
+  wt_reshaped.copy_shared_buffer(wt, wt_restride, wt_flags, wt.data_size());
+
   // Perform gemm
-  std::vector<array> copies;
+  std::vector<array> copies = {in_unfolded, wt_reshaped};
   return steel_matmul(
       s,
       d,
       /*a = */ in_unfolded,
-      /*b = */ wt,
+      /*b = */ wt_reshaped,
       /*c = */ out,
-      /*M = */ unfolded_shape[0],
-      /*N = */ conv_params.O,
-      /*K = */ unfolded_shape[1],
+      /*M = */ implicit_M,
+      /*N = */ implicit_N,
+      /*K = */ implicit_K,
       /*batch_size_out = */ 1,
-      /*a_cols = */ unfolded_shape[1],
-      /*b_cols = */ unfolded_shape[1],
+      /*a_cols = */ implicit_K,
+      /*b_cols = */ implicit_K,
       /*a_transposed = */ false,
       /*b_transposed = */ true,
       /*copies = */ copies);
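
A note on the wt_reshaped view above: copy_shared_buffer creates a view, not a copy. A row-contiguous (O, K) weight read as a (K, O) matrix with strides {1, K} is exactly its transpose, and that view is column-contiguous, so steel_matmul can consume it with b_transposed = true and no data movement. A hedged, self-contained C++ sketch of the stride swap follows; StridedView and transpose_view are illustrative names, not MLX API.

#include <cstddef>
#include <vector>

struct StridedView {
  std::vector<int> shape;
  std::vector<size_t> strides;
  bool row_contiguous;
  bool col_contiguous;
};

// View a row-contiguous (rows, cols) buffer as its (cols, rows) transpose
// by swapping the axis strides; the underlying buffer is untouched.
StridedView transpose_view(int rows, int cols) {
  return StridedView{
      /*shape=*/{cols, rows},
      /*strides=*/{1, static_cast<size_t>(cols)},
      /*row_contiguous=*/false,
      /*col_contiguous=*/true};
}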

View File

@@ -1,4 +1,4 @@
-// Copyright © 2023 Apple Inc.
+// Copyright © 2023-2024 Apple Inc.

 #include <metal_stdlib>
 #include <metal_simdgroup>

@@ -22,7 +22,8 @@ template <
     const int BM, /* Threadgroup rows (in threads) */
     const int BN, /* Threadgroup cols (in threads) */
     const int TM, /* Thread rows (in elements) */
-    const int TN> /* Thread cols (in elements) */
+    const int TN, /* Thread cols (in elements) */
+    const bool kDoAxpby> /* Do out = alpha * out + beta * bias */
 struct GEMVKernel {

   static_assert(BN == SIMD_SIZE, "gemv block must have a width of SIMD_SIZE");

@@ -48,11 +49,16 @@ struct GEMVKernel {
   MLX_MTL_CONST short tgp_mem_size = BN * TN * 2;

   static METAL_FUNC void run(
-      const device T* mat,
-      const device T* in_vec,
-      device T* out_vec,
-      const constant int& in_vec_size [[buffer(3)]],
-      const constant int& out_vec_size [[buffer(4)]],
+      const device T* mat [[buffer(0)]],
+      const device T* in_vec [[buffer(1)]],
+      const device T* bias [[buffer(2)]],
+      device T* out_vec [[buffer(3)]],
+      const constant int& in_vec_size [[buffer(4)]],
+      const constant int& out_vec_size [[buffer(5)]],
+      const constant int& marix_ld [[buffer(6)]],
+      const constant float& alpha [[buffer(7)]],
+      const constant float& beta [[buffer(8)]],
+      const constant int& bias_stride [[buffer(14)]],
       threadgroup T* tgp_memory [[threadgroup(0)]],
       uint3 tid [[threadgroup_position_in_grid]],
       uint3 lid [[thread_position_in_threadgroup]],

@@ -81,7 +87,7 @@ struct GEMVKernel {
     out_row = out_row + TM <= out_vec_size ? out_row : out_vec_size - TM;

     // Advance matrix
-    mat += out_row * in_vec_size;
+    mat += out_row * marix_ld;

     // Loop over in_vec in blocks of BN * TN
     for(int bn = simd_lid * TN; bn < in_vec_size; bn += BN * TN) {

@@ -124,14 +130,14 @@ struct GEMVKernel {
       if(bn + TN <= in_vec_size) {
         #pragma clang loop unroll(full)
         for(int tn = 0; tn < TN; tn++) {
-          inter[tn] = mat[tm * in_vec_size + bn + tn];
+          inter[tn] = mat[tm * marix_ld + bn + tn];
         }
       } else { // Edgecase
         #pragma clang loop unroll(full)
         for(int tn = 0; tn < TN; tn++) {
           int col_idx = (bn + tn) < in_vec_size ? (bn + tn) : (in_vec_size - 1);
-          inter[tn] = mat[tm * in_vec_size + col_idx];
+          inter[tn] = mat[tm * marix_ld + col_idx];
         }
       }

@@ -154,8 +160,14 @@ struct GEMVKernel {
     #pragma clang loop unroll(full)
     for(int tm = 0; tm < TM; tm++) {
+      if(kDoAxpby) {
+        out_vec[out_row + tm] =
+            static_cast<T>(alpha) * result[tm] +
+            static_cast<T>(beta) * bias[(out_row + tm) * bias_stride];
+      } else {
         out_vec[out_row + tm] = result[tm];
+      }
     }
   }

@@ -172,7 +184,8 @@ template <
     const int BM, /* Threadgroup rows (in threads) */
     const int BN, /* Threadgroup cols (in threads) */
     const int TM, /* Thread rows (in elements) */
-    const int TN> /* Thread cols (in elements) */
+    const int TN, /* Thread cols (in elements) */
+    const bool kDoAxpby> /* Do out = alpha * out + beta * bias */
 struct GEMVTKernel {

   // - The matrix of size (M = in_vec_size, N = out_vec_size) is divided up

@@ -197,11 +210,16 @@ struct GEMVTKernel {
   MLX_MTL_CONST short tgp_mem_size = BN * BM * TN;

   static METAL_FUNC void run(
-      const device T* mat,
-      const device T* in_vec,
-      device T* out_vec,
-      const constant int& in_vec_size [[buffer(3)]],
-      const constant int& out_vec_size [[buffer(4)]],
+      const device T* mat [[buffer(0)]],
+      const device T* in_vec [[buffer(1)]],
+      const device T* bias [[buffer(2)]],
+      device T* out_vec [[buffer(3)]],
+      const constant int& in_vec_size [[buffer(4)]],
+      const constant int& out_vec_size [[buffer(5)]],
+      const constant int& marix_ld [[buffer(6)]],
+      const constant float& alpha [[buffer(7)]],
+      const constant float& beta [[buffer(8)]],
+      const constant int& bias_stride [[buffer(14)]],
       threadgroup T* tgp_memory [[threadgroup(0)]],
       uint3 tid [[threadgroup_position_in_grid]],
       uint3 lid [[thread_position_in_threadgroup]],

@@ -245,7 +263,7 @@ struct GEMVTKernel {
         #pragma clang loop unroll(full)
         for(int tm = 0; tm < TM; tm++) {
           for(int tn = 0; tn < TN; tn++) {
-            inter[tn] = mat[(bm + tm) * out_vec_size + out_col + tn];
+            inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn];
           }
           for(int tn = 0; tn < TN; tn++) {
             result[tn] += v_coeff[tm] * inter[tn];

@@ -257,7 +275,7 @@ struct GEMVTKernel {
           v_coeff[tm] = in_vec[bm + tm];
           for(int tn = 0; tn < TN; tn++) {
-            inter[tn] = mat[(bm + tm) * out_vec_size + out_col + tn];
+            inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn];
           }
           for(int tn = 0; tn < TN; tn++) {
             result[tn] += v_coeff[tm] * inter[tn];

@@ -292,13 +310,17 @@ struct GEMVTKernel {
       #pragma clang loop unroll(full)
       for(int j = 0; j < TN; j++) {
+        if(kDoAxpby) {
+          out_vec[out_col + j] =
+              static_cast<T>(alpha) * result[j] +
+              static_cast<T>(beta) * bias[(out_col + j) * bias_stride];
+        } else {
           out_vec[out_col + j] = result[j];
+        }
       }
     }
   }
 };

 ///////////////////////////////////////////////////////////////////////////////

@@ -310,78 +332,64 @@ template <
     const int BM, /* Threadgroup rows (in threads) */
     const int BN, /* Threadgroup cols (in threads) */
     const int TM, /* Thread rows (in elements) */
-    const int TN> /* Thread cols (in elements) */
+    const int TN, /* Thread cols (in elements) */
+    const bool kDoNCBatch, /* Batch ndim > 1 */
+    const bool kDoAxpby> /* Do out = alpha * out + beta * bias */
 [[kernel, max_total_threads_per_threadgroup(BM * BN)]] void gemv(
     const device T* mat [[buffer(0)]],
     const device T* in_vec [[buffer(1)]],
-    device T* out_vec [[buffer(2)]],
-    const constant int& in_vec_size [[buffer(3)]],
-    const constant int& out_vec_size [[buffer(4)]],
-    const constant int& vector_batch_stride [[buffer(5)]],
-    const constant int& matrix_batch_stride [[buffer(6)]],
+    const device T* bias [[buffer(2)]],
+    device T* out_vec [[buffer(3)]],
+    const constant int& in_vec_size [[buffer(4)]],
+    const constant int& out_vec_size [[buffer(5)]],
+    const constant int& marix_ld [[buffer(6)]],
+    const constant float& alpha [[buffer(7)]],
+    const constant float& beta [[buffer(8)]],
+    const constant int& batch_ndim [[buffer(9)]],
+    const constant int* batch_shape [[buffer(10)]],
+    const constant size_t* vector_batch_stride [[buffer(11)]],
+    const constant size_t* matrix_batch_stride [[buffer(12)]],
+    const constant size_t* bias_batch_stride [[buffer(13)]],
+    const constant int& bias_stride [[buffer(14)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
-  using gemv_kernel = GEMVKernel<T, BM, BN, TM, TN>;
+  using gemv_kernel = GEMVKernel<T, BM, BN, TM, TN, kDoAxpby>;
   threadgroup T tgp_memory[gemv_kernel::tgp_mem_size];

   // Update batch offsets
-  in_vec += tid.z * vector_batch_stride;
-  mat += tid.z * matrix_batch_stride;
-  out_vec += tid.z * out_vec_size;
-
-  gemv_kernel::run(
-      mat,
-      in_vec,
-      out_vec,
-      in_vec_size,
-      out_vec_size,
-      tgp_memory,
-      tid,
-      lid,
-      simd_gid,
-      simd_lid
-  );
-}
-
-template <
-    typename T,
-    const int BM, /* Threadgroup rows (in threads) */
-    const int BN, /* Threadgroup cols (in threads) */
-    const int TM, /* Thread rows (in elements) */
-    const int TN> /* Thread cols (in elements) */
-[[kernel, max_total_threads_per_threadgroup(BM * BN)]] void gemv_nc(
-    const device T* mat [[buffer(0)]],
-    const device T* in_vec [[buffer(1)]],
-    device T* out_vec [[buffer(2)]],
-    const constant int& in_vec_size [[buffer(3)]],
-    const constant int& out_vec_size [[buffer(4)]],
-    const constant int& nc_dim [[buffer(5)]],
-    const device int* nc_shape [[buffer(6)]],
-    const device size_t* nc_strides_vec [[buffer(7)]],
-    const device size_t* nc_strides_mat [[buffer(8)]],
-    uint3 tid [[threadgroup_position_in_grid]],
-    uint3 lid [[thread_position_in_threadgroup]],
-    uint simd_gid [[simdgroup_index_in_threadgroup]],
-    uint simd_lid [[thread_index_in_simdgroup]]) {
-  using gemv_kernel = GEMVKernel<T, BM, BN, TM, TN>;
-  threadgroup T tgp_memory[gemv_kernel::tgp_mem_size];
-
-  // Update batch offsets
-  in_vec += elem_to_loc(tid.z, nc_shape, nc_strides_vec, nc_dim);
-  mat += elem_to_loc(tid.z, nc_shape, nc_strides_mat, nc_dim);
+  if(kDoNCBatch) {
+    in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim);
+    mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim);
+
+    if(kDoAxpby) {
+      bias += elem_to_loc(tid.z, batch_shape, bias_batch_stride, batch_ndim);
+    }
+
+  } else {
+    in_vec += tid.z * vector_batch_stride[0];
+    mat += tid.z * matrix_batch_stride[0];
+
+    if(kDoAxpby) {
+      bias += tid.z * bias_batch_stride[0];
+    }
+  }
+
   out_vec += tid.z * out_vec_size;

   gemv_kernel::run(
       mat,
       in_vec,
+      bias,
       out_vec,
       in_vec_size,
       out_vec_size,
+      marix_ld,
+      alpha,
+      beta,
+      bias_stride,
       tgp_memory,
       tid,
       lid,

@@ -392,41 +400,34 @@ template <
 }

-#define instantiate_gemv_c(name, itype, bm, bn, tm, tn) \
-  template [[host_name("gemv_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn)]] \
-  [[kernel]] void gemv<itype, bm, bn, tm, tn>( \
-      const device itype* mat [[buffer(0)]], \
-      const device itype* vec [[buffer(1)]], \
-      device itype* out [[buffer(2)]], \
-      const constant int& in_vec_size [[buffer(3)]], \
-      const constant int& out_vec_size [[buffer(4)]], \
-      const constant int& vector_batch_stride [[buffer(5)]], \
-      const constant int& matrix_batch_stride [[buffer(6)]], \
-      uint3 tid [[threadgroup_position_in_grid]], \
-      uint3 lid [[thread_position_in_threadgroup]], \
-      uint simd_gid [[simdgroup_index_in_threadgroup]], \
-      uint simd_lid [[thread_index_in_simdgroup]]);
-
-#define instantiate_gemv_nc(name, itype, bm, bn, tm, tn) \
-  template [[host_name("gemv_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn "_nc")]] \
-  [[kernel]] void gemv_nc<itype, bm, bn, tm, tn>( \
-      const device itype* mat [[buffer(0)]], \
-      const device itype* vec [[buffer(1)]], \
-      device itype* out [[buffer(2)]], \
-      const constant int& in_vec_size [[buffer(3)]], \
-      const constant int& out_vec_size [[buffer(4)]], \
-      const constant int& nc_dim [[buffer(5)]], \
-      const device int* nc_shape [[buffer(6)]], \
-      const device size_t* nc_strides_vec [[buffer(7)]], \
-      const device size_t* nc_strides_mat [[buffer(8)]], \
+#define instantiate_gemv_helper(name, itype, bm, bn, tm, tn, nc, axpby) \
+  template [[host_name("gemv_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn "_nc" #nc "_axpby" #axpby)]] \
+  [[kernel]] void gemv<itype, bm, bn, tm, tn, nc, axpby>( \
+      const device itype* mat [[buffer(0)]], \
+      const device itype* in_vec [[buffer(1)]], \
+      const device itype* bias [[buffer(2)]], \
+      device itype* out_vec [[buffer(3)]], \
+      const constant int& in_vec_size [[buffer(4)]], \
+      const constant int& out_vec_size [[buffer(5)]], \
+      const constant int& marix_ld [[buffer(6)]], \
+      const constant float& alpha [[buffer(7)]], \
+      const constant float& beta [[buffer(8)]], \
+      const constant int& batch_ndim [[buffer(9)]], \
+      const constant int* batch_shape [[buffer(10)]], \
+      const constant size_t* vector_batch_stride [[buffer(11)]], \
+      const constant size_t* matrix_batch_stride [[buffer(12)]], \
+      const constant size_t* bias_batch_stride [[buffer(13)]], \
+      const constant int& bias_stride [[buffer(14)]], \
       uint3 tid [[threadgroup_position_in_grid]], \
       uint3 lid [[thread_position_in_threadgroup]], \
       uint simd_gid [[simdgroup_index_in_threadgroup]], \
       uint simd_lid [[thread_index_in_simdgroup]]);

 #define instantiate_gemv(name, itype, bm, bn, tm, tn) \
-  instantiate_gemv_c(name, itype, bm, bn, tm, tn) \
-  instantiate_gemv_nc(name, itype, bm, bn, tm, tn)
+  instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 0, 0) \
+  instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 0, 1) \
+  instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 1, 0) \
+  instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 1, 1)

 #define instantiate_gemv_blocks(name, itype) \
   instantiate_gemv(name, itype, 4, 32, 1, 4) \

@@ -446,77 +447,64 @@ template <
     const int BM, /* Threadgroup rows (in threads) */
     const int BN, /* Threadgroup cols (in threads) */
     const int TM, /* Thread rows (in elements) */
-    const int TN> /* Thread cols (in elements) */
+    const int TN, /* Thread cols (in elements) */
+    const bool kDoNCBatch, /* Batch ndim > 1 */
+    const bool kDoAxpby> /* Do out = alpha * out + beta * bias */
 [[kernel, max_total_threads_per_threadgroup(BM * BN)]] void gemv_t(
     const device T* mat [[buffer(0)]],
     const device T* in_vec [[buffer(1)]],
-    device T* out_vec [[buffer(2)]],
-    const constant int& in_vec_size [[buffer(3)]],
-    const constant int& out_vec_size [[buffer(4)]],
-    const constant int& vector_batch_stride [[buffer(5)]],
-    const constant int& matrix_batch_stride [[buffer(6)]],
+    const device T* bias [[buffer(2)]],
+    device T* out_vec [[buffer(3)]],
+    const constant int& in_vec_size [[buffer(4)]],
+    const constant int& out_vec_size [[buffer(5)]],
+    const constant int& marix_ld [[buffer(6)]],
+    const constant float& alpha [[buffer(7)]],
+    const constant float& beta [[buffer(8)]],
+    const constant int& batch_ndim [[buffer(9)]],
+    const constant int* batch_shape [[buffer(10)]],
+    const constant size_t* vector_batch_stride [[buffer(11)]],
+    const constant size_t* matrix_batch_stride [[buffer(12)]],
+    const constant size_t* bias_batch_stride [[buffer(13)]],
+    const constant int& bias_stride [[buffer(14)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
-  using gemv_kernel = GEMVTKernel<T, BM, BN, TM, TN>;
+  using gemv_kernel = GEMVTKernel<T, BM, BN, TM, TN, kDoAxpby>;
   threadgroup T tgp_memory[gemv_kernel::tgp_mem_size];

   // Update batch offsets
-  in_vec += tid.z * vector_batch_stride;
-  mat += tid.z * matrix_batch_stride;
-  out_vec += tid.z * out_vec_size;
-
-  gemv_kernel::run(
-      mat,
-      in_vec,
-      out_vec,
-      in_vec_size,
-      out_vec_size,
-      tgp_memory,
-      tid,
-      lid,
-      simd_gid,
-      simd_lid
-  );
-}
-
-template <
-    typename T,
-    const int BM, /* Threadgroup rows (in threads) */
-    const int BN, /* Threadgroup cols (in threads) */
-    const int TM, /* Thread rows (in elements) */
-    const int TN> /* Thread cols (in elements) */
-[[kernel, max_total_threads_per_threadgroup(BM * BN)]] void gemv_t_nc(
-    const device T* mat [[buffer(0)]],
-    const device T* in_vec [[buffer(1)]],
-    device T* out_vec [[buffer(2)]],
-    const constant int& in_vec_size [[buffer(3)]],
-    const constant int& out_vec_size [[buffer(4)]],
-    const constant int& nc_dim [[buffer(5)]],
-    const device int* nc_shape [[buffer(6)]],
-    const device size_t* nc_strides_vec [[buffer(7)]],
-    const device size_t* nc_strides_mat [[buffer(8)]],
-    uint3 tid [[threadgroup_position_in_grid]],
-    uint3 lid [[thread_position_in_threadgroup]],
-    uint simd_gid [[simdgroup_index_in_threadgroup]],
-    uint simd_lid [[thread_index_in_simdgroup]]) {
-  using gemv_kernel = GEMVTKernel<T, BM, BN, TM, TN>;
-  threadgroup T tgp_memory[gemv_kernel::tgp_mem_size];
-
-  // Update batch offsets
-  in_vec += elem_to_loc(tid.z, nc_shape, nc_strides_vec, nc_dim);
-  mat += elem_to_loc(tid.z, nc_shape, nc_strides_mat, nc_dim);
+  if(kDoNCBatch) {
+    in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim);
+    mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim);
+
+    if(kDoAxpby) {
+      bias += elem_to_loc(tid.z, batch_shape, bias_batch_stride, batch_ndim);
+    }
+
+  } else {
+    in_vec += tid.z * vector_batch_stride[0];
+    mat += tid.z * matrix_batch_stride[0];
+
+    if(kDoAxpby) {
+      bias += tid.z * bias_batch_stride[0];
+    }
+  }
+
   out_vec += tid.z * out_vec_size;

   gemv_kernel::run(
       mat,
       in_vec,
+      bias,
       out_vec,
       in_vec_size,
       out_vec_size,
+      marix_ld,
+      alpha,
+      beta,
+      bias_stride,
       tgp_memory,
       tid,
       lid,

@@ -526,41 +514,34 @@ template <
 }

-#define instantiate_gemv_t_c(name, itype, bm, bn, tm, tn) \
-  template [[host_name("gemv_t_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn)]] \
-  [[kernel]] void gemv_t<itype, bm, bn, tm, tn>( \
-      const device itype* mat [[buffer(0)]], \
-      const device itype* vec [[buffer(1)]], \
-      device itype* out [[buffer(2)]], \
-      const constant int& in_vec_size [[buffer(3)]], \
-      const constant int& out_vec_size [[buffer(4)]], \
-      const constant int& vector_batch_stride [[buffer(5)]], \
-      const constant int& matrix_batch_stride [[buffer(6)]], \
-      uint3 tid [[threadgroup_position_in_grid]], \
-      uint3 lid [[thread_position_in_threadgroup]], \
-      uint simd_gid [[simdgroup_index_in_threadgroup]], \
-      uint simd_lid [[thread_index_in_simdgroup]]);
-
-#define instantiate_gemv_t_nc(name, itype, bm, bn, tm, tn) \
-  template [[host_name("gemv_t_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn "_nc")]] \
-  [[kernel]] void gemv_t_nc<itype, bm, bn, tm, tn>( \
-      const device itype* mat [[buffer(0)]], \
-      const device itype* vec [[buffer(1)]], \
-      device itype* out [[buffer(2)]], \
-      const constant int& in_vec_size [[buffer(3)]], \
-      const constant int& out_vec_size [[buffer(4)]], \
-      const constant int& nc_dim [[buffer(5)]], \
-      const device int* nc_shape [[buffer(6)]], \
-      const device size_t* nc_strides_vec [[buffer(7)]], \
-      const device size_t* nc_strides_mat [[buffer(8)]], \
+#define instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, nc, axpby) \
+  template [[host_name("gemv_t_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn "_nc" #nc "_axpby" #axpby)]] \
+  [[kernel]] void gemv_t<itype, bm, bn, tm, tn, nc, axpby>( \
+      const device itype* mat [[buffer(0)]], \
+      const device itype* in_vec [[buffer(1)]], \
+      const device itype* bias [[buffer(2)]], \
+      device itype* out_vec [[buffer(3)]], \
+      const constant int& in_vec_size [[buffer(4)]], \
+      const constant int& out_vec_size [[buffer(5)]], \
+      const constant int& marix_ld [[buffer(6)]], \
+      const constant float& alpha [[buffer(7)]], \
+      const constant float& beta [[buffer(8)]], \
+      const constant int& batch_ndim [[buffer(9)]], \
+      const constant int* batch_shape [[buffer(10)]], \
+      const constant size_t* vector_batch_stride [[buffer(11)]], \
+      const constant size_t* matrix_batch_stride [[buffer(12)]], \
+      const constant size_t* bias_batch_stride [[buffer(13)]], \
+      const constant int& bias_stride [[buffer(14)]], \
       uint3 tid [[threadgroup_position_in_grid]], \
       uint3 lid [[thread_position_in_threadgroup]], \
       uint simd_gid [[simdgroup_index_in_threadgroup]], \
       uint simd_lid [[thread_index_in_simdgroup]]);

 #define instantiate_gemv_t(name, itype, bm, bn, tm, tn) \
-  instantiate_gemv_t_c(name, itype, bm, bn, tm, tn) \
-  instantiate_gemv_t_nc(name, itype, bm, bn, tm, tn)
+  instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, 0, 0) \
+  instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, 0, 1) \
+  instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, 1, 0) \
+  instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, 1, 1)

 #define instantiate_gemv_t_blocks(name, itype) \
   instantiate_gemv_t(name, itype, 8, 8, 4, 1) \

View File

@@ -140,7 +140,7 @@ struct GEMMKernel {
   static METAL_FUNC void run(
       const device T* A [[buffer(0)]],
       const device T* B [[buffer(1)]],
-      device U* C [[buffer(2)]],
+      device U* D [[buffer(2)]],
       const constant GEMMParams* params [[buffer(3)]],
       threadgroup T* As [[threadgroup(0)]],
       threadgroup T* Bs [[threadgroup(1)]],

@@ -167,7 +167,7 @@ struct GEMMKernel {
     A += transpose_a ? c_row : c_row * params->lda;
     B += transpose_b ? c_col * params->ldb : c_col;
-    C += c_row * params->ldc + c_col;
+    D += c_row * params->ldd + c_col;

     // Prepare threadgroup loading operations
     thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);

@@ -214,7 +214,7 @@ struct GEMMKernel {
       }

       // Store results to device memory
-      mma_op.store_result(C, params->ldc);
+      mma_op.store_result(D, params->ldd);
       return;
     }

@@ -237,7 +237,7 @@ struct GEMMKernel {
           tgp_bn,
           leftover_bk);

-      mma_op.store_result(C, params->ldc);
+      mma_op.store_result(D, params->ldd);
       return;

     } else if (tgp_bn == BN) {

@@ -252,7 +252,7 @@ struct GEMMKernel {
           tgp_bn,
           leftover_bk);

-      mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm));
+      mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
       return;

     } else if (tgp_bm == BM) {

@@ -267,7 +267,7 @@ struct GEMMKernel {
           tgp_bn,
           leftover_bk);

-      mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm));
+      mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
       return;

     } else {

@@ -282,7 +282,7 @@ struct GEMMKernel {
           tgp_bn,
           leftover_bk);

-      mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm));
+      mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
       return;
     }
   }

View File

@@ -1,6 +1,7 @@
 // Copyright © 2024 Apple Inc.

 #include "mlx/backend/metal/kernels/bf16.h"
+#include "mlx/backend/metal/kernels/utils.h"
 #include "mlx/backend/metal/kernels/steel/gemm/gemm.h"

 using namespace metal;

@@ -23,8 +24,10 @@ template <typename T,
 [[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void gemm(
     const device T *A [[buffer(0)]],
     const device T *B [[buffer(1)]],
-    device T *C [[buffer(2)]],
-    const constant GEMMParams* params [[buffer(3)]],
+    device T *D [[buffer(3)]],
+    const constant GEMMParams* params [[buffer(4)]],
+    const constant int* batch_shape [[buffer(6)]],
+    const constant size_t* batch_strides [[buffer(7)]],
     uint simd_lane_id [[thread_index_in_simdgroup]],
     uint simd_group_id [[simdgroup_index_in_threadgroup]],
     uint3 tid [[threadgroup_position_in_grid]],

@@ -36,12 +39,25 @@ template <typename T,
   threadgroup T Bs[gemm_kernel::tgp_mem_size_b];

   // Adjust for batch
-  A += params->batch_stride_a * tid.z;
-  B += params->batch_stride_b * tid.z;
-  C += params->batch_stride_c * tid.z;
+  if(params->batch_ndim > 1) {
+    const constant size_t* A_bstrides = batch_strides;
+    const constant size_t* B_bstrides = batch_strides + params->batch_ndim;
+
+    ulong2 batch_offsets = elem_to_loc_broadcast(
+        tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);
+
+    A += batch_offsets.x;
+    B += batch_offsets.y;
+
+  } else {
+    A += params->batch_stride_a * tid.z;
+    B += params->batch_stride_b * tid.z;
+  }
+
+  D += params->batch_stride_d * tid.z;

   gemm_kernel::run(
-      A, B, C,
+      A, B, D,
       params,
       As, Bs,
       simd_lane_id, simd_group_id, tid, lid

@@ -57,8 +73,10 @@ template <typename T,
   [[kernel]] void gemm<itype, bm, bn, bk, wm, wn, trans_a, trans_b, mn_aligned, k_aligned>( \
       const device itype *A [[buffer(0)]], \
       const device itype *B [[buffer(1)]], \
-      device itype *C [[buffer(2)]], \
-      const constant GEMMParams* params [[buffer(3)]], \
+      device itype *D [[buffer(3)]], \
+      const constant GEMMParams* params [[buffer(4)]], \
+      const constant int* batch_shape [[buffer(6)]], \
+      const constant size_t* batch_strides [[buffer(7)]], \
       uint simd_lane_id [[thread_index_in_simdgroup]], \
       uint simd_group_id [[simdgroup_index_in_threadgroup]], \
       uint3 tid [[threadgroup_position_in_grid]], \

View File

@@ -27,7 +27,10 @@ template <typename T,
     const device T *B [[buffer(1)]],
     const device T *C [[buffer(2)]],
     device T *D [[buffer(3)]],
-    const constant GEMMAddMMParams* params [[buffer(4)]],
+    const constant GEMMParams* params [[buffer(4)]],
+    const constant GEMMAddMMParams* addmm_params [[buffer(5)]],
+    const constant int* batch_shape [[buffer(6)]],
+    const constant size_t* batch_strides [[buffer(7)]],
     uint simd_lane_id [[thread_index_in_simdgroup]],
     uint simd_group_id [[simdgroup_index_in_threadgroup]],
     uint3 tid [[threadgroup_position_in_grid]],

@@ -50,9 +53,24 @@ template <typename T,
   threadgroup T Bs[gemm_kernel::tgp_mem_size_b];

   // Adjust for batch
-  A += params->batch_stride_a * tid.z;
-  B += params->batch_stride_b * tid.z;
-  C += params->batch_stride_c * tid.z;
+  if(params->batch_ndim > 1) {
+    const constant size_t* A_bstrides = batch_strides;
+    const constant size_t* B_bstrides = batch_strides + params->batch_ndim;
+    const constant size_t* C_bstrides = B_bstrides + params->batch_ndim;
+
+    ulong3 batch_offsets = elem_to_loc_broadcast(
+        tid.z, batch_shape, A_bstrides, B_bstrides, C_bstrides, params->batch_ndim);
+
+    A += batch_offsets.x;
+    B += batch_offsets.y;
+    C += batch_offsets.z;
+
+  } else {
+    A += params->batch_stride_a * tid.z;
+    B += params->batch_stride_b * tid.z;
+    C += addmm_params->batch_stride_c * tid.z;
+  }
+
   D += params->batch_stride_d * tid.z;

   const int tid_y = ((tid.y) << params->swizzle_log) +

@@ -71,9 +89,10 @@ template <typename T,
   A += transpose_a ? c_row : c_row * params->lda;
   B += transpose_b ? c_col * params->ldb : c_col;
-  C += c_row * params->ldc + c_col * params->fdc;
   D += c_row * params->ldd + c_col;
+  C += c_row * addmm_params->ldc + c_col * addmm_params->fdc;

   // Prepare threadgroup loading operations
   thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
   thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);

@@ -83,7 +102,7 @@ template <typename T,
   int gemm_k_iterations = params->gemm_k_iterations_aligned;

-  const Epilogue epilogue_op(params->alpha, params->beta);
+  const Epilogue epilogue_op(addmm_params->alpha, addmm_params->beta);

 ///////////////////////////////////////////////////////////////////////////////
 // MNK aligned loop

@@ -121,7 +140,7 @@ template <typename T,
     }

     // Store results to device memory
-    mma_op.store_result(D, params->ldd, C, params->ldc, params->fdc, epilogue_op);
+    mma_op.store_result(D, params->ldd, C, addmm_params->ldc, addmm_params->fdc, epilogue_op);
     return;
   }

@@ -145,7 +164,7 @@ template <typename T,
         leftover_bk,
         LoopAlignment<true, true, K_aligned>{});

-    mma_op.store_result(D, params->ldd, C, params->ldc, params->fdc, epilogue_op);
+    mma_op.store_result(D, params->ldd, C, addmm_params->ldc, addmm_params->fdc, epilogue_op);
    return;

   } else if (tgp_bn == BN) {

@@ -163,7 +182,7 @@ template <typename T,
     return mma_op.store_result_safe(
         D, params->ldd,
-        C, params->ldc, params->fdc,
+        C, addmm_params->ldc, addmm_params->fdc,
         short2(tgp_bn, tgp_bm),
         epilogue_op);

@@ -182,7 +201,7 @@ template <typename T,
     return mma_op.store_result_safe(
         D, params->ldd,
-        C, params->ldc, params->fdc,
+        C, addmm_params->ldc, addmm_params->fdc,
         short2(tgp_bn, tgp_bm),
         epilogue_op);

@@ -201,7 +220,7 @@ template <typename T,
     return mma_op.store_result_safe(
         D, params->ldd,
-        C, params->ldc, params->fdc,
+        C, addmm_params->ldc, addmm_params->fdc,
         short2(tgp_bn, tgp_bm),
         epilogue_op);
   }

@@ -219,7 +238,10 @@ template <typename T,
       const device itype *B [[buffer(1)]], \
       const device itype *C [[buffer(2)]], \
       device itype *D [[buffer(3)]], \
-      const constant GEMMAddMMParams* params [[buffer(4)]], \
+      const constant GEMMParams* gemm_params [[buffer(4)]], \
+      const constant GEMMAddMMParams* params [[buffer(5)]], \
+      const constant int* batch_shape [[buffer(6)]], \
+      const constant size_t* batch_strides [[buffer(7)]], \
       uint simd_lane_id [[thread_index_in_simdgroup]], \
       uint simd_group_id [[simdgroup_index_in_threadgroup]], \
       uint3 tid [[threadgroup_position_in_grid]], \

View File

@@ -144,9 +144,9 @@ struct BlockMMA {
   }

   /* Store results from simdgroup_matrix results into device memory */
-  METAL_FUNC void store_result(device U* C, const int ldc) const {
+  METAL_FUNC void store_result(device U* D, const int ldd) const {
     // Adjust for simdgroup and thread location
-    C += (sm + tm) * ldc + tn + sn;
+    D += (sm + tm) * ldd + tn + sn;

     // Loop over all simdgroup tiles
     STEEL_PRAGMA_UNROLL

@@ -155,22 +155,22 @@ struct BlockMMA {
       for (short j = 0; j < TN; j++) {
         // Get accumulated result and associated offset in C
         thread const auto& accum = results[i * TN + j].thread_elements();
-        int offset = (i * TM_stride) * ldc + (j * TN_stride);
+        int offset = (i * TM_stride) * ldd + (j * TN_stride);

         // Apply epilogue
         U outs[2] = {Epilogue::apply(accum[0]), Epilogue::apply(accum[1])};

-        // Write out C
-        C[offset] = outs[0];
-        C[offset + 1] = outs[1];
+        // Write out D
+        D[offset] = outs[0];
+        D[offset + 1] = outs[1];
       }
     }
   }

   METAL_FUNC void
-  store_result_safe(device U* C, const int ldc, short2 dst_tile_dims) const {
+  store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) const {
     // Adjust for simdgroup and thread location
-    C += (sm + tm) * ldc + (tn + sn);
+    D += (sm + tm) * ldd + (tn + sn);
     dst_tile_dims -= short2(tn + sn, sm + tm);

     if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)

@@ -183,15 +183,15 @@ struct BlockMMA {
       for (int j = 0; j < TN; j++) {
         // Get accumulated result and associated offset in C
         thread const auto& accum = results[i * TN + j].thread_elements();
-        int offset = (i * TM_stride) * ldc + (j * TN_stride);
+        int offset = (i * TM_stride) * ldd + (j * TN_stride);

         // Apply epilogue and output C
         if (j * TN_stride < dst_tile_dims.x) {
-          C[offset] = Epilogue::apply(accum[0]);
+          D[offset] = Epilogue::apply(accum[0]);
         }

         if (j * TN_stride + 1 < dst_tile_dims.x) {
-          C[offset + 1] = Epilogue::apply(accum[1]);
+          D[offset + 1] = Epilogue::apply(accum[1]);
         }
       }
     }

View File

@@ -16,17 +16,19 @@ struct GEMMParams {
   const int lda;
   const int ldb;
-  const int ldc;
+  const int ldd;

   const int tiles_n;
   const int tiles_m;

   const int batch_stride_a;
   const int batch_stride_b;
-  const int batch_stride_c;
+  const int batch_stride_d;

   const int swizzle_log;
   const int gemm_k_iterations_aligned;
+
+  const int batch_ndim;
 };

 struct GEMMSpiltKParams {

@@ -49,30 +51,13 @@ struct GEMMSpiltKParams {
 };

 struct GEMMAddMMParams {
-  const int M;
-  const int N;
-  const int K;
-
-  const int lda;
-  const int ldb;
   const int ldc;
-  const int ldd;
-
-  const int tiles_n;
-  const int tiles_m;
+  const int fdc;

-  const int batch_stride_a;
-  const int batch_stride_b;
   const int batch_stride_c;
-  const int batch_stride_d;
-
-  const int swizzle_log;
-  const int gemm_k_iterations_aligned;

   const float alpha;
   const float beta;
-
-  const int fdc;
 };

 } // namespace steel

View File

@@ -6,3 +6,40 @@
 #define STEEL_CONST static constant constexpr const
 #define STEEL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)")
+
+METAL_FUNC ulong2 elem_to_loc_broadcast(
+    uint elem,
+    constant const int* shape,
+    constant const size_t* a_strides,
+    constant const size_t* b_strides,
+    int ndim) {
+  ulong loc_a{0};
+  ulong loc_b{0};
+  for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
+    int pos_in_dim = (elem % shape[i]);
+    elem /= shape[i];
+    loc_a += pos_in_dim * a_strides[i];
+    loc_b += pos_in_dim * b_strides[i];
+  }
+  return ulong2(loc_a, loc_b);
+}
+
+METAL_FUNC ulong3 elem_to_loc_broadcast(
+    uint elem,
+    constant const int* shape,
+    constant const size_t* a_strides,
+    constant const size_t* b_strides,
+    constant const size_t* c_strides,
+    int ndim) {
+  ulong loc_a{0};
+  ulong loc_b{0};
+  ulong loc_c{0};
+  for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
+    int pos_in_dim = (elem % shape[i]);
+    elem /= shape[i];
+    loc_a += pos_in_dim * a_strides[i];
+    loc_b += pos_in_dim * b_strides[i];
+    loc_c += pos_in_dim * c_strides[i];
+  }
+  return ulong3(loc_a, loc_b, loc_c);
+}
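
To make the new helpers concrete: elem_to_loc_broadcast peels the flat batch index tid.z into per-dimension coordinates (row-major, innermost dim last) and dots them with each array's batch strides, so one dispatch handles arbitrary, even broadcast (stride 0), batch layouts. A host-side C++ rendition with a worked example follows; the stride values are illustrative, not taken from the diff.

#include <cstddef>
#include <cstdio>

int main() {
  const int ndim = 2;
  const int shape[2] = {2, 3};         // batch shape (2, 3)
  const size_t a_strides[2] = {12, 4}; // A: dense batch strides
  const size_t b_strides[2] = {0, 4};  // B: broadcast along the first dim
  unsigned elem = 5;                   // flat index 5 -> coords (1, 2)
  size_t loc_a = 0, loc_b = 0;
  for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
    int pos_in_dim = elem % shape[i];
    elem /= shape[i];
    loc_a += pos_in_dim * a_strides[i];
    loc_b += pos_in_dim * b_strides[i];
  }
  // coords (1, 2): loc_a = 1*12 + 2*4 = 20, loc_b = 1*0 + 2*4 = 8
  std::printf("loc_a = %zu, loc_b = %zu\n", loc_a, loc_b);
  return 0;
}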

View File

@ -191,6 +191,70 @@ inline void mps_matmul(
}); });
} }
inline auto collapse_batches(const array& a, const array& b) {
// Get and check the shape for the batched dims
std::vector<int> A_bshape{a.shape().begin(), a.shape().end() - 2};
std::vector<int> B_bshape{b.shape().begin(), b.shape().end() - 2};
if (A_bshape != B_bshape) {
std::ostringstream msg;
msg << "[matmul] Got matrices with incorrectly broadcasted shapes: "
<< "A " << a.shape() << ", B " << b.shape() << ".";
throw std::runtime_error(msg.str());
}
std::vector<size_t> A_bstride{a.strides().begin(), a.strides().end() - 2};
std::vector<size_t> B_bstride{b.strides().begin(), b.strides().end() - 2};
auto [batch_shape, batch_strides] =
collapse_contiguous_dims(A_bshape, {A_bstride, B_bstride});
auto A_batch_stride = batch_strides[0];
auto B_batch_stride = batch_strides[1];
if (batch_shape.empty()) {
batch_shape.push_back(1);
A_batch_stride.push_back(0);
B_batch_stride.push_back(0);
}
return std::make_tuple(batch_shape, A_batch_stride, B_batch_stride);
}
inline auto collapse_batches(const array& a, const array& b, const array& c) {
// Get and check the shape for the batched dims
std::vector<int> A_bshape{a.shape().begin(), a.shape().end() - 2};
std::vector<int> B_bshape{b.shape().begin(), b.shape().end() - 2};
std::vector<int> C_bshape{c.shape().begin(), c.shape().end() - 2};
if (A_bshape != B_bshape || A_bshape != C_bshape) {
std::ostringstream msg;
msg << "[addmm] Got matrices with incorrectly broadcasted shapes: "
<< "A " << a.shape() << ", B " << b.shape() << ", B " << c.shape()
<< ".";
throw std::runtime_error(msg.str());
}
std::vector<size_t> A_bstride{a.strides().begin(), a.strides().end() - 2};
std::vector<size_t> B_bstride{b.strides().begin(), b.strides().end() - 2};
std::vector<size_t> C_bstride{c.strides().begin(), c.strides().end() - 2};
auto [batch_shape, batch_strides] =
collapse_contiguous_dims(A_bshape, {A_bstride, B_bstride, C_bstride});
auto A_batch_stride = batch_strides[0];
auto B_batch_stride = batch_strides[1];
auto C_batch_stride = batch_strides[2];
if (batch_shape.empty()) {
batch_shape.push_back(1);
A_batch_stride.push_back(0);
B_batch_stride.push_back(0);
C_batch_stride.push_back(0);
}
return std::make_tuple(
batch_shape, A_batch_stride, B_batch_stride, C_batch_stride);
}
} // namespace } // namespace
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
@ -211,22 +275,33 @@ void steel_matmul(
int ldb, int ldb,
bool transpose_a, bool transpose_a,
bool transpose_b, bool transpose_b,
std::vector<array>& copies) { std::vector<array>& copies,
std::vector<int> batch_shape /* = {} */,
std::vector<size_t> A_batch_stride /* = {} */,
std::vector<size_t> B_batch_stride /* = {} */) {
using namespace mlx::steel; using namespace mlx::steel;
// Coalesce (B, M, K) X (K, N) to (B*M, K) X (K, N) if (batch_shape.empty()) {
if (batch_size_out > 1 && !transpose_a && /////////////////////////////////////////////////////////////////////////////
a.data_size() == batch_size_out * M * K && b.size() == K * N) { // Check and collapse batch dimensions
M = M * batch_size_out; auto [batch_shape_, A_bstride_, B_bstride_] = collapse_batches(a, b);
batch_shape = batch_shape_;
A_batch_stride = A_bstride_;
B_batch_stride = B_bstride_;
// Collapse batches into M if needed
if (batch_size_out > 1 && !transpose_a && batch_shape.size() == 1 &&
a.strides()[a.ndim() - 2] == K && A_batch_stride.back() == M * K &&
B_batch_stride.back() == 0) {
M *= batch_shape.back();
batch_size_out = 1; batch_size_out = 1;
A_batch_stride = {0};
B_batch_stride = {0};
batch_shape = {1};
}
} }
// Account for batch sizes and basic broadcasting
int batch_size_a = a.data_size() / (M * K);
int batch_size_b = b.data_size() / (K * N);
int matrix_stride_a = (batch_size_a == 1) ? 0 : M * K;
int matrix_stride_b = (batch_size_b == 1) ? 0 : K * N;
int matrix_stride_out = M * N; int matrix_stride_out = M * N;
///////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////
@ -269,18 +344,18 @@ void steel_matmul(
int tm = (M + bm - 1) / bm; int tm = (M + bm - 1) / bm;
GEMMSpiltKParams params{ GEMMSpiltKParams params{
M, /* const int M = */ M,
N, /* const int N = */ N,
K, /* const int K = */ K,
lda, /* const int lda = */ lda,
ldb, /* const int ldb = */ ldb,
N, /* const int ldc = */ N,
tn, /* const int tiles_n = */ tn,
tm, /* const int tiles_m = */ tm,
split_k_partitions, /* const int split_k_partitions = */ split_k_partitions,
split_k_partition_stride, /* const int split_k_partition_stride = */ split_k_partition_stride,
split_k_partition_size, /* const int split_k_partition_size = */ split_k_partition_size,
gemm_k_iterations}; /* const int gemm_k_iterations_aligned = */ gemm_k_iterations};
MTL::Size group_dims = MTL::Size(32, wn, wm); MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(tn, tm, split_k_partitions); MTL::Size grid_dims = MTL::Size(tn, tm, split_k_partitions);
@ -364,19 +439,20 @@ void steel_matmul(
// Prepare steel matmul params // Prepare steel matmul params
GEMMParams params{ GEMMParams params{
M, /* const int M = */ M,
N, /* const int N = */ N,
K, /* const int K = */ K,
lda, /* const int lda = */ lda,
ldb, /* const int ldb = */ ldb,
N, /* const int ldd = */ N,
tn, /* const int tiles_n = */ tn,
tm, /* const int tiles_m = */ tm,
matrix_stride_a, /* const int batch_stride_a = */ int(A_batch_stride.back()),
matrix_stride_b, /* const int batch_stride_b = */ int(B_batch_stride.back()),
matrix_stride_out, /* const int batch_stride_d = */ matrix_stride_out,
swizzle_log, /* const int swizzle_log = */ swizzle_log,
(K / bk)}; /* const int gemm_k_iterations_aligned = */ (K / bk),
/* const int batch_ndim = */ int(batch_shape.size())};
// Prepare launch grid params // Prepare launch grid params
int tile = 1 << swizzle_log; int tile = 1 << swizzle_log;
@ -386,37 +462,25 @@ void steel_matmul(
MTL::Size group_dims = MTL::Size(32, wn, wm); MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out); MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out);
// Launch only 1 kernel in the case of simple batching / broadcasting std::vector<size_t> batch_strides = A_batch_stride;
if (batch_size_out == std::max(batch_size_a, batch_size_b) && batch_strides.insert(
(batch_size_a == batch_size_b || batch_strides.end(), B_batch_stride.begin(), B_batch_stride.end());
std::min(batch_size_a, batch_size_b) == 1)) {
// Launch kernel
set_array_buffer(compute_encoder, a, 0); set_array_buffer(compute_encoder, a, 0);
set_array_buffer(compute_encoder, b, 1); set_array_buffer(compute_encoder, b, 1);
set_array_buffer(compute_encoder, out, 2); set_array_buffer(compute_encoder, out, 3);
compute_encoder->setBytes(&params, sizeof(GEMMParams), 4);
compute_encoder->setBytes(
batch_shape.data(), sizeof(int) * batch_shape.size(), 6);
compute_encoder->setBytes(
batch_strides.data(), sizeof(size_t) * batch_strides.size(), 7);
compute_encoder->setBytes(&params, sizeof(GEMMParams), 3);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims); compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
} else { // Otherwise launch kernels with set offsets
MTL::Size grid_dims_single = MTL::Size(tn, tm, 1);
for (int i = 0; i < batch_size_out; ++i) {
auto a_off = elem_to_loc(M * K * i, a.shape(), a.strides());
auto b_off = elem_to_loc(K * N * i, b.shape(), b.strides());
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
auto b_buf = static_cast<const MTL::Buffer*>(b.buffer().ptr());
auto out_buf = static_cast<const MTL::Buffer*>(out.buffer().ptr());
compute_encoder->setBuffer(a_buf, a_off * a.itemsize(), 0);
compute_encoder->setBuffer(b_buf, b_off * b.itemsize(), 1);
compute_encoder->setBuffer(out_buf, i * M * N * out.itemsize(), 2);
compute_encoder->setBytes(&params, sizeof(GEMMParams), 3);
compute_encoder->dispatchThreadgroups(grid_dims_single, group_dims);
}
}
// Clear copies
d.get_command_buffer(s.index)->addCompletedHandler( d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); }); [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
return; return;
@ -453,9 +517,9 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
auto check_transpose = [&copies, &s](const array& arr) { auto check_transpose = [&copies, &s](const array& arr) {
auto stx = arr.strides()[arr.ndim() - 2]; auto stx = arr.strides()[arr.ndim() - 2];
auto sty = arr.strides()[arr.ndim() - 1]; auto sty = arr.strides()[arr.ndim() - 1];
if (stx == arr.shape(-1) && sty == 1) { if (sty == 1) {
return std::make_tuple(false, stx, arr); return std::make_tuple(false, stx, arr);
} else if (stx == 1 && sty == arr.shape(-2)) { } else if (stx == 1) {
return std::make_tuple(true, sty, arr); return std::make_tuple(true, sty, arr);
} else { } else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
@ -473,8 +537,25 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
int N = b.shape(-1); int N = b.shape(-1);
int K = a.shape(-1); int K = a.shape(-1);
/////////////////////////////////////////////////////////////////////////////
// Check and collapse batch dimensions
auto [batch_shape, A_batch_stride, B_batch_stride] = collapse_batches(a, b);
auto batch_size_out = out.size() / (M * N); auto batch_size_out = out.size() / (M * N);
// Collapse batches into M if needed
if (batch_size_out > 1 && !a_transposed && batch_shape.size() == 1 &&
a.strides()[a.ndim() - 2] == K && A_batch_stride.back() == M * K &&
B_batch_stride.back() == 0) {
M *= batch_shape.back();
batch_size_out = 1;
A_batch_stride = {0};
B_batch_stride = {0};
batch_shape = {1};
}
///////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////
// Gemv specialization // Gemv specialization
@ -491,20 +572,18 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
int mat_cols = transpose_mat ? out_vector_len : in_vector_len; int mat_cols = transpose_mat ? out_vector_len : in_vector_len;
int mat_rows = transpose_mat ? in_vector_len : out_vector_len; int mat_rows = transpose_mat ? in_vector_len : out_vector_len;
int mat_ld = is_b_matrix ? b_cols : a_cols;
int batch_size_mat = mat.data_size() / (mat_cols * mat_rows); auto batch_strides_mat = is_b_matrix ? B_batch_stride : A_batch_stride;
int stride_mat = batch_size_mat == 1 ? 0 : mat_cols * mat_rows; auto batch_strides_vec = is_b_matrix ? A_batch_stride : B_batch_stride;
int batch_size_vec = vec.data_size() / in_vector_len; int stride_mat = batch_strides_mat.back();
int stride_vec = batch_size_vec == 1 ? 0 : in_vector_len; int stride_vec = batch_strides_vec.back();
// Determine if inputs have simple batching / broadcasting // Determine if inputs have simple batching / broadcasting
bool contiguous_kernel = bool contiguous_kernel = (batch_shape.size() == 1);
(batch_size_out == std::max(batch_size_mat, batch_size_vec) &&
(batch_size_mat == batch_size_vec ||
std::min(batch_size_mat, batch_size_vec) == 1));
int nc_dim = out.ndim() - 2; int batch_ndim = batch_shape.size();
// Determine dispatch kernel // Determine dispatch kernel
int tm = 4, tn = 4; int tm = 4, tn = 4;
@ -540,10 +619,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
} }
kname << "_bm" << bm << "_bn" << bn << "_tm" << tm << "_tn" << tn; kname << "_bm" << bm << "_bn" << bn << "_tm" << tm << "_tn" << tn;
kname << "_nc" << !contiguous_kernel << "_axpby0";
if (!contiguous_kernel) {
kname << "_nc";
}
// Encode and dispatch kernel // Encode and dispatch kernel
auto compute_encoder = d.get_command_encoder(s.index); auto compute_encoder = d.get_command_encoder(s.index);
@ -556,25 +632,18 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
set_array_buffer(compute_encoder, mat, 0); set_array_buffer(compute_encoder, mat, 0);
set_array_buffer(compute_encoder, vec, 1); set_array_buffer(compute_encoder, vec, 1);
set_array_buffer(compute_encoder, out, 2); set_array_buffer(compute_encoder, out, 3);
compute_encoder->setBytes(&in_vector_len, sizeof(int), 3); compute_encoder->setBytes(&in_vector_len, sizeof(int), 4);
compute_encoder->setBytes(&out_vector_len, sizeof(int), 4); compute_encoder->setBytes(&out_vector_len, sizeof(int), 5);
compute_encoder->setBytes(&mat_ld, sizeof(int), 6);
if (contiguous_kernel) { compute_encoder->setBytes(&batch_ndim, sizeof(int), 9);
compute_encoder->setBytes(&stride_vec, sizeof(int), 5); compute_encoder->setBytes(batch_shape.data(), batch_ndim * sizeof(int), 10);
compute_encoder->setBytes(&stride_mat, sizeof(int), 6);
} else {
// In case of complex broadcasting, we consider the shape[:-2] and
// strides [:-2] to determine the location of a batch
// nc_dim = out.ndim() - 2
compute_encoder->setBytes(&nc_dim, sizeof(int), 5);
compute_encoder->setBytes(out.shape().data(), nc_dim * sizeof(int), 6);
compute_encoder->setBytes( compute_encoder->setBytes(
vec.strides().data(), nc_dim * sizeof(size_t), 7); batch_strides_vec.data(), batch_ndim * sizeof(size_t), 11);
compute_encoder->setBytes( compute_encoder->setBytes(
mat.strides().data(), nc_dim * sizeof(size_t), 8); batch_strides_mat.data(), batch_ndim * sizeof(size_t), 12);
}
compute_encoder->dispatchThreadgroups(grid_dims, group_dims); compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
@ -606,20 +675,23 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
} }
return steel_matmul( return steel_matmul(
s, /* const Stream& s = */ s,
d, /* metal::Device& d = */ d,
a, /* const array& a = */ a,
b, /* const array& b = */ b,
out, /* array& out = */ out,
M, /* int M = */ M,
N, /* int N = */ N,
K, /* int K = */ K,
batch_size_out, /* int batch_size_out = */ batch_size_out,
a_cols, /* int lda = */ a_cols,
b_cols, /* int ldb = */ b_cols,
a_transposed, /* bool transpose_a = */ a_transposed,
b_transposed, /* bool transpose_b = */ b_transposed,
copies); /* std::vector<array>& = */ copies,
/* std::vector<int> batch_shape = */ batch_shape,
/* std::vector<size_t> A_batch_stride = */ A_batch_stride,
/* std::vector<size_t> B_batch_stride = */ B_batch_stride);
} }
void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) { void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
@ -645,9 +717,9 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
auto check_transpose = [&copies, &s](const array& arr) { auto check_transpose = [&copies, &s](const array& arr) {
auto stx = arr.strides()[arr.ndim() - 2]; auto stx = arr.strides()[arr.ndim() - 2];
auto sty = arr.strides()[arr.ndim() - 1]; auto sty = arr.strides()[arr.ndim() - 1];
if (stx == arr.shape(-1) && sty == 1) { if (sty == 1) {
return std::make_tuple(false, stx, arr); return std::make_tuple(false, stx, arr);
} else if (stx == 1 && sty == arr.shape(-2)) { } else if (stx == 1) {
return std::make_tuple(true, sty, arr); return std::make_tuple(true, sty, arr);
} else { } else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
@ -665,33 +737,151 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
int N = b.shape(-1); int N = b.shape(-1);
int K = a.shape(-1); int K = a.shape(-1);
auto batch_size_out = out.size() / (M * N);
array c = c_pre; array c = c_pre;
int ldc = c.strides()[c.ndim() - 2]; int ldc = c.strides()[c.ndim() - 2];
int fdc = c.strides()[c.ndim() - 1]; int fdc = c.strides()[c.ndim() - 1];
int matrix_stride_c = c.ndim() <= 2 ? 0 : c.strides()[c.ndim() - 3];
int lda = a_cols; int lda = a_cols;
int ldb = b_cols; int ldb = b_cols;
int ldd = N;
/////////////////////////////////////////////////////////////////////////////
// Check and collapse batch dimensions
auto [batch_shape, A_batch_stride, B_batch_stride, C_batch_stride] =
collapse_batches(a, b, c);
auto batch_size_out = out.size() / (M * N);
// Collapse batches into M if needed
if (batch_size_out > 1 && !transpose_a && batch_shape.size() == 1 &&
a.strides()[a.ndim() - 2] == K && A_batch_stride.back() == M * K &&
C_batch_stride.back() == M * c.strides()[c.ndim() - 2] &&
B_batch_stride.back() == 0) {
M *= batch_shape.back();
batch_size_out = 1;
A_batch_stride = {0};
B_batch_stride = {0};
C_batch_stride = {0};
batch_shape = {1};
}
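This fold is valid because back-to-back batch entries of a multiplied by a broadcast b are bit-identical to one taller GEMM over the stacked rows. A numpy sketch of the equivalence (shapes are made up for illustration):

    import numpy as np

    a = np.random.randn(4, 8, 16).astype(np.float32)  # 4 batched (8, 16) matrices
    b = np.random.randn(16, 32).astype(np.float32)    # shared across the batch

    batched = a @ b                                     # (4, 8, 32)
    folded = (a.reshape(-1, 16) @ b).reshape(4, 8, 32)  # one (32, 16) @ (16, 32) GEMM

    assert np.allclose(batched, folded, atol=1e-5)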
int matrix_stride_out = M * N;
/////////////////////////////////////////////////////////////////////////////
// Gemv specialization
// Route to gemv if needed
if (std::min(M, N) == 1) {
// Collect problem info
bool is_b_matrix = N != 1;
auto& mat = is_b_matrix ? b : a;
auto& vec = is_b_matrix ? a : b;
bool transpose_mat = is_b_matrix ? !transpose_b : transpose_a;
int in_vector_len = K;
int out_vector_len = is_b_matrix ? N : M;
int mat_cols = transpose_mat ? out_vector_len : in_vector_len;
int mat_rows = transpose_mat ? in_vector_len : out_vector_len;
int mat_ld = is_b_matrix ? b_cols : a_cols;
auto batch_strides_mat = is_b_matrix ? B_batch_stride : A_batch_stride;
auto batch_strides_vec = is_b_matrix ? A_batch_stride : B_batch_stride;
int stride_mat = batch_strides_mat.back();
int stride_vec = batch_strides_vec.back();
// Determine if inputs have simple batching / broadcasting
bool contiguous_kernel = (batch_shape.size() == 1);
int batch_ndim = batch_shape.size();
// Determine dispatch kernel
int tm = 4, tn = 4;
int bm, bn, n_out_per_tgp;
std::ostringstream kname;
if (transpose_mat) {
bm = 8;
bn = 8;
if (out_vector_len >= 24576) {
bn = 128;
} else if (out_vector_len >= 16384) {
bn = 64;
} else if (out_vector_len >= 8192) {
bn = 16;
}
// Specialized kernel for very small outputs
tn = out_vector_len < tn ? 1 : tn;
n_out_per_tgp = bn * tn;
kname << "gemv_t_" << type_to_name(out);
} else {
bm = out_vector_len >= 4096 ? 8 : 4;
bn = 32;
// Specialized kernel for very small outputs
tm = out_vector_len < tm ? 1 : tm;
n_out_per_tgp = bm * tm;
kname << "gemv_" << type_to_name(out);
}
kname << "_bm" << bm << "_bn" << bn << "_tm" << tm << "_tn" << tn;
kname << "_nc" << !contiguous_kernel << "_axpby1";
// Encode and dispatch kernel
auto compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp;
MTL::Size group_dims = MTL::Size(bn, bm, 1);
MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out);
set_array_buffer(compute_encoder, mat, 0);
set_array_buffer(compute_encoder, vec, 1);
set_array_buffer(compute_encoder, c, 2);
set_array_buffer(compute_encoder, out, 3);
compute_encoder->setBytes(&in_vector_len, sizeof(int), 4);
compute_encoder->setBytes(&out_vector_len, sizeof(int), 5);
compute_encoder->setBytes(&mat_ld, sizeof(int), 6);
compute_encoder->setBytes(&alpha_, sizeof(float), 7);
compute_encoder->setBytes(&beta_, sizeof(float), 8);
compute_encoder->setBytes(&batch_ndim, sizeof(int), 9);
compute_encoder->setBytes(batch_shape.data(), batch_ndim * sizeof(int), 10);
compute_encoder->setBytes(
batch_strides_vec.data(), batch_ndim * sizeof(size_t), 11);
compute_encoder->setBytes(
batch_strides_mat.data(), batch_ndim * sizeof(size_t), 12);
compute_encoder->setBytes(
C_batch_stride.data(), batch_ndim * sizeof(size_t), 13);
int bias_stride = c.strides()[c.ndim() - 1];
compute_encoder->setBytes(&bias_stride, sizeof(int), 14);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
return;
}
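With batch_shape and the per-input batch strides bound above, each threadgroup can derive its operand offsets from its z index, so broadcast batches need no host-side copies or per-batch dispatches. A Python sketch of that index arithmetic (batch_offset is illustrative):

    def batch_offset(z, batch_shape, batch_strides):
        # Walk the collapsed batch dims from innermost to outermost.
        off = 0
        for dim, stride in zip(reversed(batch_shape), reversed(batch_strides)):
            off += (z % dim) * stride
            z //= dim
        return off

    # Batch shape (2, 3); the vector broadcast along the inner batch axis.
    assert batch_offset(4, [2, 3], [12, 0]) == 12  # z=4 -> batch index (1, 1)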
using namespace mlx::steel;
-// Account for batch sizes and basic broadcasting
-int batch_size_a = a.data_size() / (M * K);
-int batch_size_b = b.data_size() / (K * N);
-int matrix_stride_a = (batch_size_a == 1) ? 0 : M * K;
-int matrix_stride_b = (batch_size_b == 1) ? 0 : K * N;
-int matrix_stride_out = M * N;
+/////////////////////////////////////////////////////////////////////////////
+// Split K specialization
int _tm = M / 16;
int _tn = N / 16;
int _tk = K / 16;
-/////////////////////////////////////////////////////////////////////////////
-// Split K specialization
if (batch_size_out == 1 && (_tm * _tn) <= 32 && _tk >= 8) {
int bm = M < 40 ? 16 : 32;
int bn = N < 40 ? 16 : 32;
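The split-K branch is taken when the output is small (few 16x16 tiles) but the reduction axis is long, where one threadgroup per output tile would leave the GPU underused. A sketch of the routing predicate with the same /16 tilings:

    def use_split_k(M, N, K, batch_size_out):
        tm, tn, tk = M // 16, N // 16, K // 16
        return batch_size_out == 1 and (tm * tn) <= 32 and tk >= 8

    assert use_split_k(32, 32, 4096, 1)        # 2x2 output tiles, 256 k-tiles
    assert not use_split_k(1024, 1024, 64, 1)  # large output, short reduction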
@@ -817,25 +1007,29 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
// TODO: Explore device-based tuning for swizzle
int swizzle_log = 0; // tm >= 6 ? 3 : (tm <= 3 ? 0 : 2);
-GEMMAddMMParams params{
-    M,
-    N,
-    K,
-    lda,
-    ldb,
-    ldc,
-    N,
-    tn,
-    tm,
-    matrix_stride_a,
-    matrix_stride_b,
-    matrix_stride_c,
-    matrix_stride_out,
-    swizzle_log,
-    (K / bk),
-    alpha_,
-    beta_,
-    fdc};
+// Prepare steel matmul params
+GEMMParams gemm_params{
+    /* const int M = */ M,
+    /* const int N = */ N,
+    /* const int K = */ K,
+    /* const int lda = */ lda,
+    /* const int ldb = */ ldb,
+    /* const int ldd = */ N,
+    /* const int tiles_n = */ tn,
+    /* const int tiles_m = */ tm,
+    /* const int batch_stride_a = */ int(A_batch_stride.back()),
+    /* const int batch_stride_b = */ int(B_batch_stride.back()),
+    /* const int batch_stride_d = */ matrix_stride_out,
+    /* const int swizzle_log = */ swizzle_log,
+    /* const int gemm_k_iterations_aligned = */ (K / bk),
+    /* const int batch_ndim = */ int(batch_shape.size())};
+GEMMAddMMParams params{
+    /* const int ldc = */ ldc,
+    /* const int fdc = */ fdc,
+    /* const int batch_stride_c = */ int(C_batch_stride.back()),
+    /* const float alpha = */ alpha_,
+    /* const float beta = */ beta_};
int tile = 1 << swizzle_log;
tm = (tm + tile - 1) / tile;
@@ -844,40 +1038,27 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out);
-// Launch only 1 kernel in the case of simple batching / broadcasting
-if (batch_size_out == std::max(batch_size_a, batch_size_b) &&
-    (batch_size_a == batch_size_b ||
-     std::min(batch_size_a, batch_size_b) == 1)) {
-set_array_buffer(compute_encoder, a, 0);
-set_array_buffer(compute_encoder, b, 1);
-set_array_buffer(compute_encoder, c, 2);
-set_array_buffer(compute_encoder, out, 3);
-compute_encoder->setBytes(&params, sizeof(GEMMAddMMParams), 4);
-compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
-} else { // Otherwise launch kernels with set offsets
-MTL::Size grid_dims_single = MTL::Size(tn, tm, 1);
-for (int i = 0; i < batch_size_out; ++i) {
-auto a_off = elem_to_loc(M * K * i, a.shape(), a.strides());
-auto b_off = elem_to_loc(K * N * i, b.shape(), b.strides());
-auto c_off = elem_to_loc(M * N * i, c.shape(), c.strides());
-auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
-auto b_buf = static_cast<const MTL::Buffer*>(b.buffer().ptr());
-auto c_buf = static_cast<const MTL::Buffer*>(c.buffer().ptr());
-auto out_buf = static_cast<const MTL::Buffer*>(out.buffer().ptr());
-compute_encoder->setBuffer(a_buf, a_off * a.itemsize(), 0);
-compute_encoder->setBuffer(b_buf, b_off * b.itemsize(), 1);
-compute_encoder->setBuffer(c_buf, c_off * c.itemsize(), 2);
-compute_encoder->setBuffer(out_buf, i * M * N * out.itemsize(), 3);
-compute_encoder->setBytes(&params, sizeof(GEMMAddMMParams), 4);
-compute_encoder->dispatchThreadgroups(grid_dims_single, group_dims);
-}
-}
+std::vector<size_t> batch_strides = A_batch_stride;
+batch_strides.insert(
+    batch_strides.end(), B_batch_stride.begin(), B_batch_stride.end());
+batch_strides.insert(
+    batch_strides.end(), C_batch_stride.begin(), C_batch_stride.end());
+// Launch kernel
+set_array_buffer(compute_encoder, a, 0);
+set_array_buffer(compute_encoder, b, 1);
+set_array_buffer(compute_encoder, c, 2);
+set_array_buffer(compute_encoder, out, 3);
+compute_encoder->setBytes(&gemm_params, sizeof(GEMMParams), 4);
+compute_encoder->setBytes(&params, sizeof(GEMMAddMMParams), 5);
+compute_encoder->setBytes(
+    batch_shape.data(), sizeof(int) * batch_shape.size(), 6);
+compute_encoder->setBytes(
+    batch_strides.data(), sizeof(size_t) * batch_strides.size(), 7);
+compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
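The packed strides buffer at index 7 carries batch_ndim entries per operand, concatenated in A, B, C order; the kernel presumably slices it back apart at the same offsets. A sketch of the layout with made-up stride values:

    batch_ndim = 2
    A_batch_stride = [4096, 1024]
    B_batch_stride = [0, 0]      # b broadcast over the whole batch
    C_batch_stride = [64, 0]     # c broadcast along the inner batch axis

    batch_strides = A_batch_stride + B_batch_stride + C_batch_stride
    assert batch_strides == [4096, 1024, 0, 0, 64, 0]
    # A's strides live at batch_strides[0:batch_ndim],
    # B's at [batch_ndim:2*batch_ndim], C's at [2*batch_ndim:3*batch_ndim].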
d.get_command_buffer(s.index)->addCompletedHandler(
    [copies](MTL::CommandBuffer*) mutable { copies.clear(); });

View File

@@ -26,6 +26,9 @@ void steel_matmul(
int ldb,
bool transpose_a,
bool transpose_b,
-std::vector<array>& copies);
+std::vector<array>& copies,
+std::vector<int> batch_shape = {},
+std::vector<size_t> A_batch_stride = {},
+std::vector<size_t> B_batch_stride = {});
} // namespace mlx::core

View File

@@ -3315,17 +3315,8 @@ array addmm(
const float& alpha /* = 1.f */,
const float& beta /* = 1.f */,
StreamOrDevice s /* = {} */) {
-// Divert in the case of vector-matrix multiplication
-// TODO: Add the needed specialization
-if (a.ndim() == 1 || b.ndim() == 1) {
-array X = matmul(a, b, s);
-array alpha_arr = array(alpha, X.dtype());
-array aX = multiply(alpha_arr, X, s);
-array beta_arr = array(beta, c.dtype());
-array bY = multiply(beta_arr, c, s);
-return add(aX, bY, s);
-}
+int in_a_ndim = a.ndim();
+int in_b_ndim = b.ndim();
if (a.ndim() == 0 || b.ndim() == 0) {
throw std::invalid_argument(
@@ -3333,6 +3324,15 @@ array addmm(
"have at least one dimension.");
}
if (a.ndim() == 1) {
// Insert a singleton dim at the beginning
a = reshape(a, {1, -1}, s);
}
if (b.ndim() == 1) {
// Insert a singleton dim at the end
b = reshape(b, {-1, 1}, s);
}
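With these reshapes, 1-D inputs now flow through the regular addmm path (and its gemv fast path) instead of decomposing into matmul, multiply, and add. A usage sketch mirroring the new tests:

    import mlx.core as mx
    import numpy as np

    a = np.random.randn(16).astype(np.float32)           # 1-D first input
    b = np.random.randn(32, 16, 128).astype(np.float32)  # batched matrix
    c = np.ones((128,), dtype=np.float32)

    ref = 2.0 * (a @ b) + 0.5 * c  # numpy reference, shape (32, 128)
    out = mx.addmm(mx.array(c), mx.array(a), mx.array(b), 2.0, 0.5)

    assert list(out.shape) == list(ref.shape)
    assert np.allclose(out, ref, atol=1e-4)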
if (a.shape(-1) != b.shape(-2)) {
std::ostringstream msg;
msg << "[addmm] Last dimension of first input with shape " << a.shape()
@@ -3361,7 +3361,13 @@ array addmm(
std::vector<int> out_shape = a.shape();
a = reshape(a, {-1, out_shape.back()}, s);
out_shape.back() = b.shape(-1);
if (in_b_ndim == 1) {
out_shape.pop_back();
}
c = broadcast_to(c, {a.shape(0), b.shape(1)}, s);
auto out = array(
{a.shape(0), b.shape(1)},
out_type,
@@ -3389,15 +3395,42 @@ array addmm(
auto out_shape = a.shape();
out_shape.back() = b.shape(-1);
-auto c_broadcast_shape = broadcast_shapes(c.shape(), out_shape);
+auto out_shape_adjusted = out_shape;
+if (in_a_ndim == 1 || in_b_ndim == 1) {
+out_shape_adjusted.erase(
+    out_shape_adjusted.end() - ((in_a_ndim == 1) ? 2 : 1),
+    out_shape_adjusted.end() - ((in_b_ndim == 1) ? 0 : 1));
+}
+auto c_broadcast_shape = broadcast_shapes(c.shape(), out_shape_adjusted);
c = broadcast_to(c, c_broadcast_shape, s);
if (in_a_ndim == 1 || in_b_ndim == 1) {
auto c_reshape = c.shape();
if (in_b_ndim == 1) {
c_reshape.push_back(1);
}
if (in_a_ndim == 1) {
c_reshape.push_back(c_reshape.back());
c_reshape[c_reshape.size() - 2] = 1;
}
c = reshape(c, c_reshape, s);
}
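The reshape keeps c aligned with the promoted 2-D shapes the primitive actually computes with. The same bookkeeping as a shape-only Python sketch (adjust_c_shape is illustrative):

    def adjust_c_shape(c_shape, in_a_ndim, in_b_ndim):
        c_shape = list(c_shape)
        if in_b_ndim == 1:
            c_shape.append(1)            # b was (K,): promoted to (K, 1)
        if in_a_ndim == 1:
            c_shape.append(c_shape[-1])  # a was (K,): promoted to (1, K)
            c_shape[-2] = 1
        return c_shape

    assert adjust_c_shape([32, 128], in_a_ndim=1, in_b_ndim=2) == [32, 1, 128]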
auto out = array(
out_shape,
out_type,
std::make_unique<AddMM>(to_stream(s), alpha, beta),
{a, b, c});
// Remove the possibly inserted singleton dimensions
if (in_a_ndim == 1 || in_b_ndim == 1) {
out = reshape(out, out_shape_adjusted, s);
}
return out;
}

View File

@@ -393,6 +393,77 @@ class TestBlas(mlx_tests.MLXTestCase):
mlx_vec_f=lambda vec_mlx: mx.broadcast_to(vec_mlx, shape_vec),
)
def test_matrix_vector_attn(self):
# Multi-query style attention check
for dtype in self.dtypes:
# fmt: off
for (B, D, n_kv_heads, factor, qsl, ksl) in (
(1, 16, 8, 4, 1, 256),
(1, 16, 8, 4, 32, 256),
(1, 16, 8, 4, 256, 1),
(4, 16, 8, 4, 1, 256),
(4, 16, 8, 4, 256, 1),
):
# fmt: on
with self.subTest(
B=B, # Batch size
D=D, # Dimension of mm
n_kv_heads=n_kv_heads, # key-value heads
factor=factor, # factor to get query heads
qsl=qsl, # Query sequence length
ksl=ksl, # Key sequence length
dtype=dtype # Data type
):
np_dtype = getattr(np, dtype)
# Fix shapes for kqv
n_q_heads = n_kv_heads * factor
Dk = D * n_kv_heads
Dq = D * n_q_heads
scale = 1. / math.sqrt(Dk)
shape_queries = (B, qsl, Dq)
shape_keys = (B, ksl, Dk)
shape_values = (B, ksl, Dk)
# Prepare numpy arrays
q_np = np.random.uniform(-scale, scale, size=shape_queries).astype(np_dtype)
k_np = np.random.uniform(-scale, scale, size=shape_keys).astype(np_dtype)
v_np = np.random.uniform(-scale, scale, size=shape_values).astype(np_dtype)
# Rearrange to move heads up
q_np_reshape = q_np.reshape(B, qsl, n_kv_heads, factor, -1).transpose(0, 2, 3, 1, 4)
k_np_reshape = k_np.reshape(B, ksl, n_kv_heads, 1, -1).transpose(0, 2, 3, 4, 1)
v_np_reshape = v_np.reshape(B, ksl, n_kv_heads, 1, -1).transpose(0, 2, 3, 1, 4)
# Do attn style matmul
s_np = q_np_reshape @ k_np_reshape
o_np = s_np @ v_np_reshape
o_np = o_np.transpose(0, 3, 1, 2, 4).reshape(B, qsl, -1)
# Test mlx
q_mx = mx.array(q_np)
k_mx = mx.array(k_np)
v_mx = mx.array(v_np)
# Rearrange to move heads up
q_mx_reshape = q_mx.reshape(B, qsl, n_kv_heads, factor, -1).transpose(0, 2, 3, 1, 4)
k_mx_reshape = k_mx.reshape(B, ksl, n_kv_heads, 1, -1).transpose(0, 2, 3, 4, 1)
v_mx_reshape = v_mx.reshape(B, ksl, n_kv_heads, 1, -1).transpose(0, 2, 3, 1, 4)
# Do attn style matmul
s_mx = q_mx_reshape @ k_mx_reshape
o_mx = (s_mx @ v_mx_reshape)
o_mx = o_mx.transpose(0, 3, 1, 2, 4).reshape(B, qsl, -1)
# Check against np
self.assertListEqual(list(s_np.shape), list(s_mx.shape))
self.assertTrue(np.allclose(s_np, s_mx, atol=1e-4))
self.assertListEqual(list(o_np.shape), list(o_mx.shape))
self.assertTrue(np.allclose(o_np, o_mx, atol=1e-4))
def test_matrix_vector_edgecases(self):
for dtype in self.dtypes:
with self.subTest(dtype=dtype):
@@ -503,16 +574,29 @@ class TestBlas(mlx_tests.MLXTestCase):
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
# Matmul with vector
a_npy = np.random.normal(0.0, 1.0 / 128, (16,)).astype(np.float32)
b_npy = np.random.normal(0.0, 1.0 / 128, (32, 16, 128)).astype(np.float32)
a_mlx = mx.array(a_npy)
b_mlx = mx.array(b_npy)
for c_shape in ((1,), (128,), (32, 128)):
c_npy = np.ones(c_shape).astype(np.float32)
c_mlx = mx.array(c_npy)
d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
# Matmul with vector
a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
b_npy = np.random.normal(0.0, 1.0 / 128, (16,)).astype(np.float32)
a_mlx = mx.array(a_npy)
b_mlx = mx.array(b_npy)
-for c_shape in (
-    (1,),
-    (32, 128),
-):
+for c_shape in ((1,), (32, 128)):
c_npy = np.ones(c_shape).astype(np.float32)
c_mlx = mx.array(c_npy)
@@ -564,16 +648,12 @@ class TestBlas(mlx_tests.MLXTestCase):
out_ref, dout_ref = mx.vjp(
f_ref,
[c, a, b],
-[
-    cotan,
-],
+[cotan],
)
out_test, dout_test = mx.vjp(
f_test,
[c, a, b],
-[
-    cotan,
-],
+[cotan],
)
self.assertTrue(mx.allclose(out_ref[0], out_test[0], atol=1e-4).item())