mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-20 10:27:41 +08:00
No copy gems (#801)
* Enable collapsing batch dims in gemm * Update gemm to only make copies when neither of the last 2 axes are contiguous * Update addmm to support gemv shapes * Update addmm to support irregular batch strides * Update tests
This commit is contained in:
parent
d0c544a868
commit
5ad133f8bb
@ -28,10 +28,12 @@ void explicit_gemm_conv_ND_gpu(
|
||||
const array& wt,
|
||||
array out,
|
||||
const MLXConvParams<N>& conv_params) {
|
||||
// Get gemm shapes
|
||||
int implicit_M = out.size() / conv_params.O;
|
||||
int implicit_K = wt.size() / conv_params.O;
|
||||
int implicit_N = conv_params.O;
|
||||
// Prepare unfolding array
|
||||
std::vector<int> unfolded_shape = {
|
||||
static_cast<int>(out.size() / conv_params.O),
|
||||
static_cast<int>(wt.size() / conv_params.O)};
|
||||
std::vector<int> unfolded_shape{implicit_M, implicit_K};
|
||||
array in_unfolded(unfolded_shape, in.dtype(), nullptr, {});
|
||||
|
||||
in_unfolded.set_data(allocator::malloc_or_wait(in_unfolded.nbytes()));
|
||||
@ -59,20 +61,29 @@ void explicit_gemm_conv_ND_gpu(
|
||||
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
|
||||
// Reshape weight
|
||||
std::vector<int> wt_reshape{implicit_K, implicit_N};
|
||||
std::vector<size_t> wt_restride{1, static_cast<size_t>(implicit_K)};
|
||||
array wt_reshaped(wt_reshape, wt.dtype(), nullptr, {});
|
||||
auto wt_flags = wt.flags();
|
||||
wt_flags.row_contiguous = false;
|
||||
wt_flags.col_contiguous = true;
|
||||
wt_reshaped.copy_shared_buffer(wt, wt_restride, wt_flags, wt.data_size());
|
||||
|
||||
// Perform gemm
|
||||
std::vector<array> copies;
|
||||
std::vector<array> copies = {in_unfolded, wt_reshaped};
|
||||
return steel_matmul(
|
||||
s,
|
||||
d,
|
||||
/*a = */ in_unfolded,
|
||||
/*b = */ wt,
|
||||
/*b = */ wt_reshaped,
|
||||
/*c = */ out,
|
||||
/*M = */ unfolded_shape[0],
|
||||
/*N = */ conv_params.O,
|
||||
/*K = */ unfolded_shape[1],
|
||||
/*M = */ implicit_M,
|
||||
/*N = */ implicit_N,
|
||||
/*K = */ implicit_K,
|
||||
/*batch_size_out = */ 1,
|
||||
/*a_cols = */ unfolded_shape[1],
|
||||
/*b_cols = */ unfolded_shape[1],
|
||||
/*a_cols = */ implicit_K,
|
||||
/*b_cols = */ implicit_K,
|
||||
/*a_transposed = */ false,
|
||||
/*b_transposed = */ true,
|
||||
/*copies = */ copies);
|
||||
|
@ -1,4 +1,4 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <metal_stdlib>
|
||||
#include <metal_simdgroup>
|
||||
@ -22,7 +22,8 @@ template <
|
||||
const int BM, /* Threadgroup rows (in threads) */
|
||||
const int BN, /* Threadgroup cols (in threads) */
|
||||
const int TM, /* Thread rows (in elements) */
|
||||
const int TN > /* Thread cols (in elements) */
|
||||
const int TN , /* Thread cols (in elements) */
|
||||
const bool kDoAxpby> /* Do out = alpha * out + beta * bias */
|
||||
struct GEMVKernel {
|
||||
|
||||
static_assert(BN == SIMD_SIZE, "gemv block must have a width of SIMD_SIZE");
|
||||
@ -48,11 +49,16 @@ struct GEMVKernel {
|
||||
MLX_MTL_CONST short tgp_mem_size = BN * TN * 2;
|
||||
|
||||
static METAL_FUNC void run(
|
||||
const device T* mat,
|
||||
const device T* in_vec,
|
||||
device T* out_vec,
|
||||
const constant int& in_vec_size [[buffer(3)]],
|
||||
const constant int& out_vec_size [[buffer(4)]],
|
||||
const device T* mat [[buffer(0)]],
|
||||
const device T* in_vec [[buffer(1)]],
|
||||
const device T* bias [[buffer(2)]],
|
||||
device T* out_vec [[buffer(3)]],
|
||||
const constant int& in_vec_size [[buffer(4)]],
|
||||
const constant int& out_vec_size [[buffer(5)]],
|
||||
const constant int& marix_ld [[buffer(6)]],
|
||||
const constant float& alpha [[buffer(7)]],
|
||||
const constant float& beta [[buffer(8)]],
|
||||
const constant int& bias_stride [[buffer(14)]],
|
||||
threadgroup T* tgp_memory [[threadgroup(0)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
@ -81,7 +87,7 @@ struct GEMVKernel {
|
||||
out_row = out_row + TM <= out_vec_size ? out_row : out_vec_size - TM;
|
||||
|
||||
// Advance matrix
|
||||
mat += out_row * in_vec_size;
|
||||
mat += out_row * marix_ld;
|
||||
|
||||
// Loop over in_vec in blocks of BN * TN
|
||||
for(int bn = simd_lid * TN; bn < in_vec_size; bn += BN * TN) {
|
||||
@ -124,14 +130,14 @@ struct GEMVKernel {
|
||||
if(bn + TN <= in_vec_size) {
|
||||
#pragma clang loop unroll(full)
|
||||
for(int tn = 0; tn < TN; tn++) {
|
||||
inter[tn] = mat[tm * in_vec_size + bn + tn];
|
||||
inter[tn] = mat[tm * marix_ld + bn + tn];
|
||||
}
|
||||
|
||||
} else { // Edgecase
|
||||
#pragma clang loop unroll(full)
|
||||
for(int tn = 0; tn < TN; tn++) {
|
||||
int col_idx = (bn + tn) < in_vec_size ? (bn + tn) : (in_vec_size - 1);
|
||||
inter[tn] = mat[tm * in_vec_size + col_idx];
|
||||
inter[tn] = mat[tm * marix_ld + col_idx];
|
||||
}
|
||||
}
|
||||
|
||||
@ -154,7 +160,13 @@ struct GEMVKernel {
|
||||
|
||||
#pragma clang loop unroll(full)
|
||||
for(int tm = 0; tm < TM; tm++) {
|
||||
out_vec[out_row + tm] = result[tm];
|
||||
if(kDoAxpby) {
|
||||
out_vec[out_row + tm] =
|
||||
static_cast<T>(alpha) * result[tm] +
|
||||
static_cast<T>(beta) * bias[(out_row + tm) * bias_stride];
|
||||
} else {
|
||||
out_vec[out_row + tm] = result[tm];
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
@ -172,7 +184,8 @@ template <
|
||||
const int BM, /* Threadgroup rows (in threads) */
|
||||
const int BN, /* Threadgroup cols (in threads) */
|
||||
const int TM, /* Thread rows (in elements) */
|
||||
const int TN > /* Thread cols (in elements) */
|
||||
const int TN, /* Thread cols (in elements) */
|
||||
const bool kDoAxpby> /* Do out = alpha * out + beta * bias */
|
||||
struct GEMVTKernel {
|
||||
|
||||
// - The matrix of size (M = in_vec_size, N = out_vec_size) is divided up
|
||||
@ -197,11 +210,16 @@ struct GEMVTKernel {
|
||||
MLX_MTL_CONST short tgp_mem_size = BN * BM * TN;
|
||||
|
||||
static METAL_FUNC void run(
|
||||
const device T* mat,
|
||||
const device T* in_vec,
|
||||
device T* out_vec,
|
||||
const constant int& in_vec_size [[buffer(3)]],
|
||||
const constant int& out_vec_size [[buffer(4)]],
|
||||
const device T* mat [[buffer(0)]],
|
||||
const device T* in_vec [[buffer(1)]],
|
||||
const device T* bias [[buffer(2)]],
|
||||
device T* out_vec [[buffer(3)]],
|
||||
const constant int& in_vec_size [[buffer(4)]],
|
||||
const constant int& out_vec_size [[buffer(5)]],
|
||||
const constant int& marix_ld [[buffer(6)]],
|
||||
const constant float& alpha [[buffer(7)]],
|
||||
const constant float& beta [[buffer(8)]],
|
||||
const constant int& bias_stride [[buffer(14)]],
|
||||
threadgroup T* tgp_memory [[threadgroup(0)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
@ -245,7 +263,7 @@ struct GEMVTKernel {
|
||||
#pragma clang loop unroll(full)
|
||||
for(int tm = 0; tm < TM; tm++) {
|
||||
for(int tn = 0; tn < TN; tn++) {
|
||||
inter[tn] = mat[(bm + tm) * out_vec_size + out_col + tn];
|
||||
inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn];
|
||||
}
|
||||
for(int tn = 0; tn < TN; tn++) {
|
||||
result[tn] += v_coeff[tm] * inter[tn];
|
||||
@ -257,7 +275,7 @@ struct GEMVTKernel {
|
||||
v_coeff[tm] = in_vec[bm + tm];
|
||||
|
||||
for(int tn = 0; tn < TN; tn++) {
|
||||
inter[tn] = mat[(bm + tm) * out_vec_size + out_col + tn];
|
||||
inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn];
|
||||
}
|
||||
for(int tn = 0; tn < TN; tn++) {
|
||||
result[tn] += v_coeff[tm] * inter[tn];
|
||||
@ -292,13 +310,17 @@ struct GEMVTKernel {
|
||||
|
||||
#pragma clang loop unroll(full)
|
||||
for(int j = 0; j < TN; j++) {
|
||||
out_vec[out_col + j] = result[j];
|
||||
|
||||
if(kDoAxpby) {
|
||||
out_vec[out_col + j] =
|
||||
static_cast<T>(alpha) * result[j] +
|
||||
static_cast<T>(beta) * bias[(out_col + j) * bias_stride];
|
||||
} else {
|
||||
out_vec[out_col + j] = result[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
@ -310,78 +332,64 @@ template <
|
||||
const int BM, /* Threadgroup rows (in threads) */
|
||||
const int BN, /* Threadgroup cols (in threads) */
|
||||
const int TM, /* Thread rows (in elements) */
|
||||
const int TN> /* Thread cols (in elements) */
|
||||
const int TN, /* Thread cols (in elements) */
|
||||
const bool kDoNCBatch, /* Batch ndim > 1 */
|
||||
const bool kDoAxpby> /* Do out = alpha * out + beta * bias */
|
||||
[[kernel, max_total_threads_per_threadgroup(BM * BN)]] void gemv(
|
||||
const device T* mat [[buffer(0)]],
|
||||
const device T* in_vec [[buffer(1)]],
|
||||
device T* out_vec [[buffer(2)]],
|
||||
const constant int& in_vec_size [[buffer(3)]],
|
||||
const constant int& out_vec_size [[buffer(4)]],
|
||||
const constant int& vector_batch_stride [[buffer(5)]],
|
||||
const constant int& matrix_batch_stride [[buffer(6)]],
|
||||
const device T* bias [[buffer(2)]],
|
||||
device T* out_vec [[buffer(3)]],
|
||||
const constant int& in_vec_size [[buffer(4)]],
|
||||
const constant int& out_vec_size [[buffer(5)]],
|
||||
const constant int& marix_ld [[buffer(6)]],
|
||||
const constant float& alpha [[buffer(7)]],
|
||||
const constant float& beta [[buffer(8)]],
|
||||
const constant int& batch_ndim [[buffer(9)]],
|
||||
const constant int* batch_shape [[buffer(10)]],
|
||||
const constant size_t* vector_batch_stride [[buffer(11)]],
|
||||
const constant size_t* matrix_batch_stride [[buffer(12)]],
|
||||
const constant size_t* bias_batch_stride [[buffer(13)]],
|
||||
const constant int& bias_stride [[buffer(14)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
|
||||
using gemv_kernel = GEMVKernel<T, BM, BN, TM, TN>;
|
||||
using gemv_kernel = GEMVKernel<T, BM, BN, TM, TN, kDoAxpby>;
|
||||
threadgroup T tgp_memory[gemv_kernel::tgp_mem_size];
|
||||
|
||||
// Update batch offsets
|
||||
in_vec += tid.z * vector_batch_stride;
|
||||
mat += tid.z * matrix_batch_stride;
|
||||
out_vec += tid.z * out_vec_size;
|
||||
|
||||
gemv_kernel::run(
|
||||
mat,
|
||||
in_vec,
|
||||
out_vec,
|
||||
in_vec_size,
|
||||
out_vec_size,
|
||||
tgp_memory,
|
||||
tid,
|
||||
lid,
|
||||
simd_gid,
|
||||
simd_lid
|
||||
);
|
||||
|
||||
}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
const int BM, /* Threadgroup rows (in threads) */
|
||||
const int BN, /* Threadgroup cols (in threads) */
|
||||
const int TM, /* Thread rows (in elements) */
|
||||
const int TN> /* Thread cols (in elements) */
|
||||
[[kernel, max_total_threads_per_threadgroup(BM * BN)]] void gemv_nc(
|
||||
const device T* mat [[buffer(0)]],
|
||||
const device T* in_vec [[buffer(1)]],
|
||||
device T* out_vec [[buffer(2)]],
|
||||
const constant int& in_vec_size [[buffer(3)]],
|
||||
const constant int& out_vec_size [[buffer(4)]],
|
||||
const constant int& nc_dim [[buffer(5)]],
|
||||
const device int* nc_shape [[buffer(6)]],
|
||||
const device size_t* nc_strides_vec [[buffer(7)]],
|
||||
const device size_t* nc_strides_mat [[buffer(8)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
|
||||
using gemv_kernel = GEMVKernel<T, BM, BN, TM, TN>;
|
||||
threadgroup T tgp_memory[gemv_kernel::tgp_mem_size];
|
||||
|
||||
// Update batch offsets
|
||||
in_vec += elem_to_loc(tid.z, nc_shape, nc_strides_vec, nc_dim);
|
||||
mat += elem_to_loc(tid.z, nc_shape, nc_strides_mat, nc_dim);
|
||||
if(kDoNCBatch) {
|
||||
in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim);
|
||||
mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim);
|
||||
|
||||
if(kDoAxpby) {
|
||||
bias += elem_to_loc(tid.z, batch_shape, bias_batch_stride, batch_ndim);
|
||||
}
|
||||
|
||||
} else {
|
||||
in_vec += tid.z * vector_batch_stride[0];
|
||||
mat += tid.z * matrix_batch_stride[0];
|
||||
|
||||
if(kDoAxpby) {
|
||||
bias += tid.z * bias_batch_stride[0];
|
||||
}
|
||||
}
|
||||
|
||||
out_vec += tid.z * out_vec_size;
|
||||
|
||||
gemv_kernel::run(
|
||||
mat,
|
||||
in_vec,
|
||||
bias,
|
||||
out_vec,
|
||||
in_vec_size,
|
||||
out_vec_size,
|
||||
marix_ld,
|
||||
alpha,
|
||||
beta,
|
||||
bias_stride,
|
||||
tgp_memory,
|
||||
tid,
|
||||
lid,
|
||||
@ -392,41 +400,34 @@ template <
|
||||
}
|
||||
|
||||
|
||||
#define instantiate_gemv_c(name, itype, bm, bn, tm, tn) \
|
||||
template [[host_name("gemv_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn)]] \
|
||||
[[kernel]] void gemv<itype, bm, bn, tm, tn>( \
|
||||
#define instantiate_gemv_helper(name, itype, bm, bn, tm, tn, nc, axpby) \
|
||||
template [[host_name("gemv_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn "_nc" #nc "_axpby" #axpby)]] \
|
||||
[[kernel]] void gemv<itype, bm, bn, tm, tn, nc, axpby>( \
|
||||
const device itype* mat [[buffer(0)]], \
|
||||
const device itype* vec [[buffer(1)]], \
|
||||
device itype* out [[buffer(2)]], \
|
||||
const constant int& in_vec_size [[buffer(3)]], \
|
||||
const constant int& out_vec_size [[buffer(4)]], \
|
||||
const constant int& vector_batch_stride [[buffer(5)]], \
|
||||
const constant int& matrix_batch_stride [[buffer(6)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]], \
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||
|
||||
#define instantiate_gemv_nc(name, itype, bm, bn, tm, tn) \
|
||||
template [[host_name("gemv_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn "_nc")]] \
|
||||
[[kernel]] void gemv_nc<itype, bm, bn, tm, tn>( \
|
||||
const device itype* mat [[buffer(0)]], \
|
||||
const device itype* vec [[buffer(1)]], \
|
||||
device itype* out [[buffer(2)]], \
|
||||
const constant int& in_vec_size [[buffer(3)]], \
|
||||
const constant int& out_vec_size [[buffer(4)]], \
|
||||
const constant int& nc_dim [[buffer(5)]], \
|
||||
const device int* nc_shape [[buffer(6)]], \
|
||||
const device size_t* nc_strides_vec [[buffer(7)]], \
|
||||
const device size_t* nc_strides_mat [[buffer(8)]], \
|
||||
const device itype* in_vec [[buffer(1)]], \
|
||||
const device itype* bias [[buffer(2)]], \
|
||||
device itype* out_vec [[buffer(3)]], \
|
||||
const constant int& in_vec_size [[buffer(4)]], \
|
||||
const constant int& out_vec_size [[buffer(5)]], \
|
||||
const constant int& marix_ld [[buffer(6)]], \
|
||||
const constant float& alpha [[buffer(7)]], \
|
||||
const constant float& beta [[buffer(8)]], \
|
||||
const constant int& batch_ndim [[buffer(9)]], \
|
||||
const constant int* batch_shape [[buffer(10)]], \
|
||||
const constant size_t* vector_batch_stride [[buffer(11)]], \
|
||||
const constant size_t* matrix_batch_stride [[buffer(12)]], \
|
||||
const constant size_t* bias_batch_stride [[buffer(13)]], \
|
||||
const constant int& bias_stride [[buffer(14)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]], \
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||
|
||||
#define instantiate_gemv(name, itype, bm, bn, tm, tn) \
|
||||
instantiate_gemv_c(name, itype, bm, bn, tm, tn) \
|
||||
instantiate_gemv_nc(name, itype, bm, bn, tm, tn)
|
||||
instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 0, 0) \
|
||||
instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 0, 1) \
|
||||
instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 1, 0) \
|
||||
instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 1, 1)
|
||||
|
||||
#define instantiate_gemv_blocks(name, itype) \
|
||||
instantiate_gemv(name, itype, 4, 32, 1, 4) \
|
||||
@ -446,77 +447,64 @@ template <
|
||||
const int BM, /* Threadgroup rows (in threads) */
|
||||
const int BN, /* Threadgroup cols (in threads) */
|
||||
const int TM, /* Thread rows (in elements) */
|
||||
const int TN> /* Thread cols (in elements) */
|
||||
const int TN, /* Thread cols (in elements) */
|
||||
const bool kDoNCBatch, /* Batch ndim > 1 */
|
||||
const bool kDoAxpby> /* Do out = alpha * out + beta * bias */
|
||||
[[kernel, max_total_threads_per_threadgroup(BM * BN)]] void gemv_t(
|
||||
const device T* mat [[buffer(0)]],
|
||||
const device T* in_vec [[buffer(1)]],
|
||||
device T* out_vec [[buffer(2)]],
|
||||
const constant int& in_vec_size [[buffer(3)]],
|
||||
const constant int& out_vec_size [[buffer(4)]],
|
||||
const constant int& vector_batch_stride [[buffer(5)]],
|
||||
const constant int& matrix_batch_stride [[buffer(6)]],
|
||||
const device T* bias [[buffer(2)]],
|
||||
device T* out_vec [[buffer(3)]],
|
||||
const constant int& in_vec_size [[buffer(4)]],
|
||||
const constant int& out_vec_size [[buffer(5)]],
|
||||
const constant int& marix_ld [[buffer(6)]],
|
||||
const constant float& alpha [[buffer(7)]],
|
||||
const constant float& beta [[buffer(8)]],
|
||||
const constant int& batch_ndim [[buffer(9)]],
|
||||
const constant int* batch_shape [[buffer(10)]],
|
||||
const constant size_t* vector_batch_stride [[buffer(11)]],
|
||||
const constant size_t* matrix_batch_stride [[buffer(12)]],
|
||||
const constant size_t* bias_batch_stride [[buffer(13)]],
|
||||
const constant int& bias_stride [[buffer(14)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
|
||||
using gemv_kernel = GEMVTKernel<T, BM, BN, TM, TN>;
|
||||
using gemv_kernel = GEMVTKernel<T, BM, BN, TM, TN, kDoAxpby>;
|
||||
threadgroup T tgp_memory[gemv_kernel::tgp_mem_size];
|
||||
|
||||
// Update batch offsets
|
||||
in_vec += tid.z * vector_batch_stride;
|
||||
mat += tid.z * matrix_batch_stride;
|
||||
out_vec += tid.z * out_vec_size;
|
||||
|
||||
gemv_kernel::run(
|
||||
mat,
|
||||
in_vec,
|
||||
out_vec,
|
||||
in_vec_size,
|
||||
out_vec_size,
|
||||
tgp_memory,
|
||||
tid,
|
||||
lid,
|
||||
simd_gid,
|
||||
simd_lid
|
||||
);
|
||||
}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
const int BM, /* Threadgroup rows (in threads) */
|
||||
const int BN, /* Threadgroup cols (in threads) */
|
||||
const int TM, /* Thread rows (in elements) */
|
||||
const int TN> /* Thread cols (in elements) */
|
||||
[[kernel, max_total_threads_per_threadgroup(BM * BN)]] void gemv_t_nc(
|
||||
const device T* mat [[buffer(0)]],
|
||||
const device T* in_vec [[buffer(1)]],
|
||||
device T* out_vec [[buffer(2)]],
|
||||
const constant int& in_vec_size [[buffer(3)]],
|
||||
const constant int& out_vec_size [[buffer(4)]],
|
||||
const constant int& nc_dim [[buffer(5)]],
|
||||
const device int* nc_shape [[buffer(6)]],
|
||||
const device size_t* nc_strides_vec [[buffer(7)]],
|
||||
const device size_t* nc_strides_mat [[buffer(8)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
|
||||
using gemv_kernel = GEMVTKernel<T, BM, BN, TM, TN>;
|
||||
threadgroup T tgp_memory[gemv_kernel::tgp_mem_size];
|
||||
|
||||
// Update batch offsets
|
||||
in_vec += elem_to_loc(tid.z, nc_shape, nc_strides_vec, nc_dim);
|
||||
mat += elem_to_loc(tid.z, nc_shape, nc_strides_mat, nc_dim);
|
||||
if(kDoNCBatch) {
|
||||
in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim);
|
||||
mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim);
|
||||
|
||||
if(kDoAxpby) {
|
||||
bias += elem_to_loc(tid.z, batch_shape, bias_batch_stride, batch_ndim);
|
||||
}
|
||||
|
||||
} else {
|
||||
in_vec += tid.z * vector_batch_stride[0];
|
||||
mat += tid.z * matrix_batch_stride[0];
|
||||
|
||||
if(kDoAxpby) {
|
||||
bias += tid.z * bias_batch_stride[0];
|
||||
}
|
||||
}
|
||||
|
||||
out_vec += tid.z * out_vec_size;
|
||||
|
||||
gemv_kernel::run(
|
||||
mat,
|
||||
in_vec,
|
||||
bias,
|
||||
out_vec,
|
||||
in_vec_size,
|
||||
out_vec_size,
|
||||
marix_ld,
|
||||
alpha,
|
||||
beta,
|
||||
bias_stride,
|
||||
tgp_memory,
|
||||
tid,
|
||||
lid,
|
||||
@ -526,41 +514,34 @@ template <
|
||||
|
||||
}
|
||||
|
||||
#define instantiate_gemv_t_c(name, itype, bm, bn, tm, tn) \
|
||||
template [[host_name("gemv_t_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn)]] \
|
||||
[[kernel]] void gemv_t<itype, bm, bn, tm, tn>( \
|
||||
#define instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, nc, axpby) \
|
||||
template [[host_name("gemv_t_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn "_nc" #nc "_axpby" #axpby)]] \
|
||||
[[kernel]] void gemv_t<itype, bm, bn, tm, tn, nc, axpby>( \
|
||||
const device itype* mat [[buffer(0)]], \
|
||||
const device itype* vec [[buffer(1)]], \
|
||||
device itype* out [[buffer(2)]], \
|
||||
const constant int& in_vec_size [[buffer(3)]], \
|
||||
const constant int& out_vec_size [[buffer(4)]], \
|
||||
const constant int& vector_batch_stride [[buffer(5)]], \
|
||||
const constant int& matrix_batch_stride [[buffer(6)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]], \
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||
|
||||
#define instantiate_gemv_t_nc(name, itype, bm, bn, tm, tn) \
|
||||
template [[host_name("gemv_t_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn "_nc")]] \
|
||||
[[kernel]] void gemv_t_nc<itype, bm, bn, tm, tn>( \
|
||||
const device itype* mat [[buffer(0)]], \
|
||||
const device itype* vec [[buffer(1)]], \
|
||||
device itype* out [[buffer(2)]], \
|
||||
const constant int& in_vec_size [[buffer(3)]], \
|
||||
const constant int& out_vec_size [[buffer(4)]], \
|
||||
const constant int& nc_dim [[buffer(5)]], \
|
||||
const device int* nc_shape [[buffer(6)]], \
|
||||
const device size_t* nc_strides_vec [[buffer(7)]], \
|
||||
const device size_t* nc_strides_mat [[buffer(8)]], \
|
||||
const device itype* in_vec [[buffer(1)]], \
|
||||
const device itype* bias [[buffer(2)]], \
|
||||
device itype* out_vec [[buffer(3)]], \
|
||||
const constant int& in_vec_size [[buffer(4)]], \
|
||||
const constant int& out_vec_size [[buffer(5)]], \
|
||||
const constant int& marix_ld [[buffer(6)]], \
|
||||
const constant float& alpha [[buffer(7)]], \
|
||||
const constant float& beta [[buffer(8)]], \
|
||||
const constant int& batch_ndim [[buffer(9)]], \
|
||||
const constant int* batch_shape [[buffer(10)]], \
|
||||
const constant size_t* vector_batch_stride [[buffer(11)]], \
|
||||
const constant size_t* matrix_batch_stride [[buffer(12)]], \
|
||||
const constant size_t* bias_batch_stride [[buffer(13)]], \
|
||||
const constant int& bias_stride [[buffer(14)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]], \
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||
|
||||
#define instantiate_gemv_t(name, itype, bm, bn, tm, tn) \
|
||||
instantiate_gemv_t_c(name, itype, bm, bn, tm, tn) \
|
||||
instantiate_gemv_t_nc(name, itype, bm, bn, tm, tn)
|
||||
instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, 0, 0) \
|
||||
instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, 0, 1) \
|
||||
instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, 1, 0) \
|
||||
instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, 1, 1)
|
||||
|
||||
#define instantiate_gemv_t_blocks(name, itype) \
|
||||
instantiate_gemv_t(name, itype, 8, 8, 4, 1) \
|
||||
|
@ -140,7 +140,7 @@ struct GEMMKernel {
|
||||
static METAL_FUNC void run(
|
||||
const device T* A [[buffer(0)]],
|
||||
const device T* B [[buffer(1)]],
|
||||
device U* C [[buffer(2)]],
|
||||
device U* D [[buffer(2)]],
|
||||
const constant GEMMParams* params [[buffer(3)]],
|
||||
threadgroup T* As [[threadgroup(0)]],
|
||||
threadgroup T* Bs [[threadgroup(1)]],
|
||||
@ -167,7 +167,7 @@ struct GEMMKernel {
|
||||
|
||||
A += transpose_a ? c_row : c_row * params->lda;
|
||||
B += transpose_b ? c_col * params->ldb : c_col;
|
||||
C += c_row * params->ldc + c_col;
|
||||
D += c_row * params->ldd + c_col;
|
||||
|
||||
// Prepare threadgroup loading operations
|
||||
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
|
||||
@ -214,7 +214,7 @@ struct GEMMKernel {
|
||||
}
|
||||
|
||||
// Store results to device memory
|
||||
mma_op.store_result(C, params->ldc);
|
||||
mma_op.store_result(D, params->ldd);
|
||||
return;
|
||||
|
||||
}
|
||||
@ -237,7 +237,7 @@ struct GEMMKernel {
|
||||
tgp_bn,
|
||||
leftover_bk);
|
||||
|
||||
mma_op.store_result(C, params->ldc);
|
||||
mma_op.store_result(D, params->ldd);
|
||||
return;
|
||||
|
||||
} else if (tgp_bn == BN) {
|
||||
@ -252,7 +252,7 @@ struct GEMMKernel {
|
||||
tgp_bn,
|
||||
leftover_bk);
|
||||
|
||||
mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm));
|
||||
mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
|
||||
return;
|
||||
|
||||
} else if (tgp_bm == BM) {
|
||||
@ -267,7 +267,7 @@ struct GEMMKernel {
|
||||
tgp_bn,
|
||||
leftover_bk);
|
||||
|
||||
mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm));
|
||||
mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
|
||||
return;
|
||||
|
||||
} else {
|
||||
@ -282,7 +282,7 @@ struct GEMMKernel {
|
||||
tgp_bn,
|
||||
leftover_bk);
|
||||
|
||||
mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm));
|
||||
mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
|
||||
|
||||
using namespace metal;
|
||||
@ -23,8 +24,10 @@ template <typename T,
|
||||
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void gemm(
|
||||
const device T *A [[buffer(0)]],
|
||||
const device T *B [[buffer(1)]],
|
||||
device T *C [[buffer(2)]],
|
||||
const constant GEMMParams* params [[buffer(3)]],
|
||||
device T *D [[buffer(3)]],
|
||||
const constant GEMMParams* params [[buffer(4)]],
|
||||
const constant int* batch_shape [[buffer(6)]],
|
||||
const constant size_t* batch_strides [[buffer(7)]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
@ -36,12 +39,25 @@ template <typename T,
|
||||
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
|
||||
|
||||
// Adjust for batch
|
||||
A += params->batch_stride_a * tid.z;
|
||||
B += params->batch_stride_b * tid.z;
|
||||
C += params->batch_stride_c * tid.z;
|
||||
if(params->batch_ndim > 1) {
|
||||
const constant size_t* A_bstrides = batch_strides;
|
||||
const constant size_t* B_bstrides = batch_strides + params->batch_ndim;
|
||||
|
||||
ulong2 batch_offsets = elem_to_loc_broadcast(
|
||||
tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);
|
||||
|
||||
A += batch_offsets.x;
|
||||
B += batch_offsets.y;
|
||||
|
||||
} else {
|
||||
A += params->batch_stride_a * tid.z;
|
||||
B += params->batch_stride_b * tid.z;
|
||||
}
|
||||
|
||||
D += params->batch_stride_d * tid.z;
|
||||
|
||||
gemm_kernel::run(
|
||||
A, B, C,
|
||||
A, B, D,
|
||||
params,
|
||||
As, Bs,
|
||||
simd_lane_id, simd_group_id, tid, lid
|
||||
@ -57,8 +73,10 @@ template <typename T,
|
||||
[[kernel]] void gemm<itype, bm, bn, bk, wm, wn, trans_a, trans_b, mn_aligned, k_aligned>( \
|
||||
const device itype *A [[buffer(0)]], \
|
||||
const device itype *B [[buffer(1)]], \
|
||||
device itype *C [[buffer(2)]], \
|
||||
const constant GEMMParams* params [[buffer(3)]], \
|
||||
device itype *D [[buffer(3)]], \
|
||||
const constant GEMMParams* params [[buffer(4)]], \
|
||||
const constant int* batch_shape [[buffer(6)]], \
|
||||
const constant size_t* batch_strides [[buffer(7)]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
|
@ -27,7 +27,10 @@ template <typename T,
|
||||
const device T *B [[buffer(1)]],
|
||||
const device T *C [[buffer(2)]],
|
||||
device T *D [[buffer(3)]],
|
||||
const constant GEMMAddMMParams* params [[buffer(4)]],
|
||||
const constant GEMMParams* params [[buffer(4)]],
|
||||
const constant GEMMAddMMParams* addmm_params [[buffer(5)]],
|
||||
const constant int* batch_shape [[buffer(6)]],
|
||||
const constant size_t* batch_strides [[buffer(7)]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
@ -50,9 +53,24 @@ template <typename T,
|
||||
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
|
||||
|
||||
// Adjust for batch
|
||||
A += params->batch_stride_a * tid.z;
|
||||
B += params->batch_stride_b * tid.z;
|
||||
C += params->batch_stride_c * tid.z;
|
||||
if(params->batch_ndim > 1) {
|
||||
const constant size_t* A_bstrides = batch_strides;
|
||||
const constant size_t* B_bstrides = batch_strides + params->batch_ndim;
|
||||
const constant size_t* C_bstrides = B_bstrides + params->batch_ndim;
|
||||
|
||||
ulong3 batch_offsets = elem_to_loc_broadcast(
|
||||
tid.z, batch_shape, A_bstrides, B_bstrides, C_bstrides, params->batch_ndim);
|
||||
|
||||
A += batch_offsets.x;
|
||||
B += batch_offsets.y;
|
||||
C += batch_offsets.z;
|
||||
|
||||
} else {
|
||||
A += params->batch_stride_a * tid.z;
|
||||
B += params->batch_stride_b * tid.z;
|
||||
C += addmm_params->batch_stride_c * tid.z;
|
||||
}
|
||||
|
||||
D += params->batch_stride_d * tid.z;
|
||||
|
||||
const int tid_y = ((tid.y) << params->swizzle_log) +
|
||||
@ -71,9 +89,10 @@ template <typename T,
|
||||
|
||||
A += transpose_a ? c_row : c_row * params->lda;
|
||||
B += transpose_b ? c_col * params->ldb : c_col;
|
||||
C += c_row * params->ldc + c_col * params->fdc;
|
||||
D += c_row * params->ldd + c_col;
|
||||
|
||||
C += c_row * addmm_params->ldc + c_col * addmm_params->fdc;
|
||||
|
||||
// Prepare threadgroup loading operations
|
||||
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
|
||||
thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
|
||||
@ -83,7 +102,7 @@ template <typename T,
|
||||
|
||||
int gemm_k_iterations = params->gemm_k_iterations_aligned;
|
||||
|
||||
const Epilogue epilogue_op(params->alpha, params->beta);
|
||||
const Epilogue epilogue_op(addmm_params->alpha, addmm_params->beta);
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// MNK aligned loop
|
||||
@ -121,7 +140,7 @@ template <typename T,
|
||||
}
|
||||
|
||||
// Store results to device memory
|
||||
mma_op.store_result(D, params->ldd, C, params->ldc, params->fdc, epilogue_op);
|
||||
mma_op.store_result(D, params->ldd, C, addmm_params->ldc, addmm_params->fdc, epilogue_op);
|
||||
return;
|
||||
|
||||
}
|
||||
@ -145,7 +164,7 @@ template <typename T,
|
||||
leftover_bk,
|
||||
LoopAlignment<true, true, K_aligned>{});
|
||||
|
||||
mma_op.store_result(D, params->ldd, C, params->ldc, params->fdc, epilogue_op);
|
||||
mma_op.store_result(D, params->ldd, C, addmm_params->ldc, addmm_params->fdc, epilogue_op);
|
||||
return;
|
||||
|
||||
} else if (tgp_bn == BN) {
|
||||
@ -163,7 +182,7 @@ template <typename T,
|
||||
|
||||
return mma_op.store_result_safe(
|
||||
D, params->ldd,
|
||||
C, params->ldc, params->fdc,
|
||||
C, addmm_params->ldc, addmm_params->fdc,
|
||||
short2(tgp_bn, tgp_bm),
|
||||
epilogue_op);
|
||||
|
||||
@ -182,7 +201,7 @@ template <typename T,
|
||||
|
||||
return mma_op.store_result_safe(
|
||||
D, params->ldd,
|
||||
C, params->ldc, params->fdc,
|
||||
C, addmm_params->ldc, addmm_params->fdc,
|
||||
short2(tgp_bn, tgp_bm),
|
||||
epilogue_op);
|
||||
|
||||
@ -201,7 +220,7 @@ template <typename T,
|
||||
|
||||
return mma_op.store_result_safe(
|
||||
D, params->ldd,
|
||||
C, params->ldc, params->fdc,
|
||||
C, addmm_params->ldc, addmm_params->fdc,
|
||||
short2(tgp_bn, tgp_bm),
|
||||
epilogue_op);
|
||||
}
|
||||
@ -219,7 +238,10 @@ template <typename T,
|
||||
const device itype *B [[buffer(1)]], \
|
||||
const device itype *C [[buffer(2)]], \
|
||||
device itype *D [[buffer(3)]], \
|
||||
const constant GEMMAddMMParams* params [[buffer(4)]], \
|
||||
const constant GEMMParams* gemm_params [[buffer(4)]], \
|
||||
const constant GEMMAddMMParams* params [[buffer(5)]], \
|
||||
const constant int* batch_shape [[buffer(6)]], \
|
||||
const constant size_t* batch_strides [[buffer(7)]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
|
@ -144,9 +144,9 @@ struct BlockMMA {
|
||||
}
|
||||
|
||||
/* Store results from simdgroup_matrix results into device memory */
|
||||
METAL_FUNC void store_result(device U* C, const int ldc) const {
|
||||
METAL_FUNC void store_result(device U* D, const int ldd) const {
|
||||
// Adjust for simdgroup and thread location
|
||||
C += (sm + tm) * ldc + tn + sn;
|
||||
D += (sm + tm) * ldd + tn + sn;
|
||||
|
||||
// Loop over all simdgroup tiles
|
||||
STEEL_PRAGMA_UNROLL
|
||||
@ -155,22 +155,22 @@ struct BlockMMA {
|
||||
for (short j = 0; j < TN; j++) {
|
||||
// Get accumulated result and associated offset in C
|
||||
thread const auto& accum = results[i * TN + j].thread_elements();
|
||||
int offset = (i * TM_stride) * ldc + (j * TN_stride);
|
||||
int offset = (i * TM_stride) * ldd + (j * TN_stride);
|
||||
|
||||
// Apply epilogue
|
||||
U outs[2] = {Epilogue::apply(accum[0]), Epilogue::apply(accum[1])};
|
||||
|
||||
// Write out C
|
||||
C[offset] = outs[0];
|
||||
C[offset + 1] = outs[1];
|
||||
// Write out D
|
||||
D[offset] = outs[0];
|
||||
D[offset + 1] = outs[1];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
METAL_FUNC void
|
||||
store_result_safe(device U* C, const int ldc, short2 dst_tile_dims) const {
|
||||
store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) const {
|
||||
// Adjust for simdgroup and thread location
|
||||
C += (sm + tm) * ldc + (tn + sn);
|
||||
D += (sm + tm) * ldd + (tn + sn);
|
||||
dst_tile_dims -= short2(tn + sn, sm + tm);
|
||||
|
||||
if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
|
||||
@ -183,15 +183,15 @@ struct BlockMMA {
|
||||
for (int j = 0; j < TN; j++) {
|
||||
// Get accumulated result and associated offset in C
|
||||
thread const auto& accum = results[i * TN + j].thread_elements();
|
||||
int offset = (i * TM_stride) * ldc + (j * TN_stride);
|
||||
int offset = (i * TM_stride) * ldd + (j * TN_stride);
|
||||
|
||||
// Apply epilogue and output C
|
||||
if (j * TN_stride < dst_tile_dims.x) {
|
||||
C[offset] = Epilogue::apply(accum[0]);
|
||||
D[offset] = Epilogue::apply(accum[0]);
|
||||
}
|
||||
|
||||
if (j * TN_stride + 1 < dst_tile_dims.x) {
|
||||
C[offset + 1] = Epilogue::apply(accum[1]);
|
||||
D[offset + 1] = Epilogue::apply(accum[1]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -16,17 +16,19 @@ struct GEMMParams {
|
||||
|
||||
const int lda;
|
||||
const int ldb;
|
||||
const int ldc;
|
||||
const int ldd;
|
||||
|
||||
const int tiles_n;
|
||||
const int tiles_m;
|
||||
|
||||
const int batch_stride_a;
|
||||
const int batch_stride_b;
|
||||
const int batch_stride_c;
|
||||
const int batch_stride_d;
|
||||
|
||||
const int swizzle_log;
|
||||
const int gemm_k_iterations_aligned;
|
||||
|
||||
const int batch_ndim;
|
||||
};
|
||||
|
||||
struct GEMMSpiltKParams {
|
||||
@ -49,30 +51,13 @@ struct GEMMSpiltKParams {
|
||||
};
|
||||
|
||||
struct GEMMAddMMParams {
|
||||
const int M;
|
||||
const int N;
|
||||
const int K;
|
||||
|
||||
const int lda;
|
||||
const int ldb;
|
||||
const int ldc;
|
||||
const int ldd;
|
||||
const int fdc;
|
||||
|
||||
const int tiles_n;
|
||||
const int tiles_m;
|
||||
|
||||
const int batch_stride_a;
|
||||
const int batch_stride_b;
|
||||
const int batch_stride_c;
|
||||
const int batch_stride_d;
|
||||
|
||||
const int swizzle_log;
|
||||
const int gemm_k_iterations_aligned;
|
||||
|
||||
const float alpha;
|
||||
const float beta;
|
||||
|
||||
const int fdc;
|
||||
};
|
||||
|
||||
} // namespace steel
|
||||
|
@ -5,4 +5,41 @@
|
||||
#include <metal_stdlib>
|
||||
|
||||
#define STEEL_CONST static constant constexpr const
|
||||
#define STEEL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)")
|
||||
#define STEEL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)")
|
||||
|
||||
METAL_FUNC ulong2 elem_to_loc_broadcast(
|
||||
uint elem,
|
||||
constant const int* shape,
|
||||
constant const size_t* a_strides,
|
||||
constant const size_t* b_strides,
|
||||
int ndim) {
|
||||
ulong loc_a{0};
|
||||
ulong loc_b{0};
|
||||
for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
|
||||
int pos_in_dim = (elem % shape[i]);
|
||||
elem /= shape[i];
|
||||
loc_a += pos_in_dim * a_strides[i];
|
||||
loc_b += pos_in_dim * b_strides[i];
|
||||
}
|
||||
return ulong2(loc_a, loc_b);
|
||||
}
|
||||
|
||||
METAL_FUNC ulong3 elem_to_loc_broadcast(
|
||||
uint elem,
|
||||
constant const int* shape,
|
||||
constant const size_t* a_strides,
|
||||
constant const size_t* b_strides,
|
||||
constant const size_t* c_strides,
|
||||
int ndim) {
|
||||
ulong loc_a{0};
|
||||
ulong loc_b{0};
|
||||
ulong loc_c{0};
|
||||
for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
|
||||
int pos_in_dim = (elem % shape[i]);
|
||||
elem /= shape[i];
|
||||
loc_a += pos_in_dim * a_strides[i];
|
||||
loc_b += pos_in_dim * b_strides[i];
|
||||
loc_c += pos_in_dim * c_strides[i];
|
||||
}
|
||||
return ulong3(loc_a, loc_b, loc_c);
|
||||
}
|
@ -191,6 +191,70 @@ inline void mps_matmul(
|
||||
});
|
||||
}
|
||||
|
||||
inline auto collapse_batches(const array& a, const array& b) {
|
||||
// Get and check the shape for the batched dims
|
||||
std::vector<int> A_bshape{a.shape().begin(), a.shape().end() - 2};
|
||||
std::vector<int> B_bshape{b.shape().begin(), b.shape().end() - 2};
|
||||
if (A_bshape != B_bshape) {
|
||||
std::ostringstream msg;
|
||||
msg << "[matmul] Got matrices with incorrectly broadcasted shapes: "
|
||||
<< "A " << a.shape() << ", B " << b.shape() << ".";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
std::vector<size_t> A_bstride{a.strides().begin(), a.strides().end() - 2};
|
||||
std::vector<size_t> B_bstride{b.strides().begin(), b.strides().end() - 2};
|
||||
|
||||
auto [batch_shape, batch_strides] =
|
||||
collapse_contiguous_dims(A_bshape, {A_bstride, B_bstride});
|
||||
|
||||
auto A_batch_stride = batch_strides[0];
|
||||
auto B_batch_stride = batch_strides[1];
|
||||
|
||||
if (batch_shape.empty()) {
|
||||
batch_shape.push_back(1);
|
||||
A_batch_stride.push_back(0);
|
||||
B_batch_stride.push_back(0);
|
||||
}
|
||||
|
||||
return std::make_tuple(batch_shape, A_batch_stride, B_batch_stride);
|
||||
}
|
||||
|
||||
inline auto collapse_batches(const array& a, const array& b, const array& c) {
|
||||
// Get and check the shape for the batched dims
|
||||
std::vector<int> A_bshape{a.shape().begin(), a.shape().end() - 2};
|
||||
std::vector<int> B_bshape{b.shape().begin(), b.shape().end() - 2};
|
||||
std::vector<int> C_bshape{c.shape().begin(), c.shape().end() - 2};
|
||||
if (A_bshape != B_bshape || A_bshape != C_bshape) {
|
||||
std::ostringstream msg;
|
||||
msg << "[addmm] Got matrices with incorrectly broadcasted shapes: "
|
||||
<< "A " << a.shape() << ", B " << b.shape() << ", B " << c.shape()
|
||||
<< ".";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
std::vector<size_t> A_bstride{a.strides().begin(), a.strides().end() - 2};
|
||||
std::vector<size_t> B_bstride{b.strides().begin(), b.strides().end() - 2};
|
||||
std::vector<size_t> C_bstride{c.strides().begin(), c.strides().end() - 2};
|
||||
|
||||
auto [batch_shape, batch_strides] =
|
||||
collapse_contiguous_dims(A_bshape, {A_bstride, B_bstride, C_bstride});
|
||||
|
||||
auto A_batch_stride = batch_strides[0];
|
||||
auto B_batch_stride = batch_strides[1];
|
||||
auto C_batch_stride = batch_strides[2];
|
||||
|
||||
if (batch_shape.empty()) {
|
||||
batch_shape.push_back(1);
|
||||
A_batch_stride.push_back(0);
|
||||
B_batch_stride.push_back(0);
|
||||
C_batch_stride.push_back(0);
|
||||
}
|
||||
|
||||
return std::make_tuple(
|
||||
batch_shape, A_batch_stride, B_batch_stride, C_batch_stride);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
@ -211,22 +275,33 @@ void steel_matmul(
|
||||
int ldb,
|
||||
bool transpose_a,
|
||||
bool transpose_b,
|
||||
std::vector<array>& copies) {
|
||||
std::vector<array>& copies,
|
||||
std::vector<int> batch_shape /* = {} */,
|
||||
std::vector<size_t> A_batch_stride /* = {} */,
|
||||
std::vector<size_t> B_batch_stride /* = {} */) {
|
||||
using namespace mlx::steel;
|
||||
|
||||
// Coalesce (B, M, K) X (K, N) to (B*M, K) X (K, N)
|
||||
if (batch_size_out > 1 && !transpose_a &&
|
||||
a.data_size() == batch_size_out * M * K && b.size() == K * N) {
|
||||
M = M * batch_size_out;
|
||||
batch_size_out = 1;
|
||||
if (batch_shape.empty()) {
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Check and collapse batch dimensions
|
||||
auto [batch_shape_, A_bstride_, B_bstride_] = collapse_batches(a, b);
|
||||
|
||||
batch_shape = batch_shape_;
|
||||
A_batch_stride = A_bstride_;
|
||||
B_batch_stride = B_bstride_;
|
||||
// Collapse batches into M if needed
|
||||
if (batch_size_out > 1 && !transpose_a && batch_shape.size() == 1 &&
|
||||
a.strides()[a.ndim() - 2] == K && A_batch_stride.back() == M * K &&
|
||||
B_batch_stride.back() == 0) {
|
||||
M *= batch_shape.back();
|
||||
batch_size_out = 1;
|
||||
|
||||
A_batch_stride = {0};
|
||||
B_batch_stride = {0};
|
||||
batch_shape = {1};
|
||||
}
|
||||
}
|
||||
|
||||
// Account for batch sizes and basic broadcasting
|
||||
int batch_size_a = a.data_size() / (M * K);
|
||||
int batch_size_b = b.data_size() / (K * N);
|
||||
|
||||
int matrix_stride_a = (batch_size_a == 1) ? 0 : M * K;
|
||||
int matrix_stride_b = (batch_size_b == 1) ? 0 : K * N;
|
||||
int matrix_stride_out = M * N;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
@ -269,18 +344,18 @@ void steel_matmul(
|
||||
int tm = (M + bm - 1) / bm;
|
||||
|
||||
GEMMSpiltKParams params{
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
lda,
|
||||
ldb,
|
||||
N,
|
||||
tn,
|
||||
tm,
|
||||
split_k_partitions,
|
||||
split_k_partition_stride,
|
||||
split_k_partition_size,
|
||||
gemm_k_iterations};
|
||||
/* const int M = */ M,
|
||||
/* const int N = */ N,
|
||||
/* const int K = */ K,
|
||||
/* const int lda = */ lda,
|
||||
/* const int ldb = */ ldb,
|
||||
/* const int ldc = */ N,
|
||||
/* const int tiles_n = */ tn,
|
||||
/* const int tiles_m = */ tm,
|
||||
/* const int split_k_partitions = */ split_k_partitions,
|
||||
/* const int split_k_partition_stride = */ split_k_partition_stride,
|
||||
/* const int split_k_partition_size = */ split_k_partition_size,
|
||||
/* const int gemm_k_iterations_aligned = */ gemm_k_iterations};
|
||||
|
||||
MTL::Size group_dims = MTL::Size(32, wn, wm);
|
||||
MTL::Size grid_dims = MTL::Size(tn, tm, split_k_partitions);
|
||||
@ -364,19 +439,20 @@ void steel_matmul(
|
||||
|
||||
// Prepare steel matmul params
|
||||
GEMMParams params{
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
lda,
|
||||
ldb,
|
||||
N,
|
||||
tn,
|
||||
tm,
|
||||
matrix_stride_a,
|
||||
matrix_stride_b,
|
||||
matrix_stride_out,
|
||||
swizzle_log,
|
||||
(K / bk)};
|
||||
/* const int M = */ M,
|
||||
/* const int N = */ N,
|
||||
/* const int K = */ K,
|
||||
/* const int lda = */ lda,
|
||||
/* const int ldb = */ ldb,
|
||||
/* const int ldd = */ N,
|
||||
/* const int tiles_n = */ tn,
|
||||
/* const int tiles_m = */ tm,
|
||||
/* const int batch_stride_a = */ int(A_batch_stride.back()),
|
||||
/* const int batch_stride_b = */ int(B_batch_stride.back()),
|
||||
/* const int batch_stride_d = */ matrix_stride_out,
|
||||
/* const int swizzle_log = */ swizzle_log,
|
||||
/* const int gemm_k_iterations_aligned = */ (K / bk),
|
||||
/* const int batch_ndim = */ int(batch_shape.size())};
|
||||
|
||||
// Prepare launch grid params
|
||||
int tile = 1 << swizzle_log;
|
||||
@ -386,37 +462,25 @@ void steel_matmul(
|
||||
MTL::Size group_dims = MTL::Size(32, wn, wm);
|
||||
MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out);
|
||||
|
||||
// Launch only 1 kernel in the case of simple batching / broadcasting
|
||||
if (batch_size_out == std::max(batch_size_a, batch_size_b) &&
|
||||
(batch_size_a == batch_size_b ||
|
||||
std::min(batch_size_a, batch_size_b) == 1)) {
|
||||
set_array_buffer(compute_encoder, a, 0);
|
||||
set_array_buffer(compute_encoder, b, 1);
|
||||
set_array_buffer(compute_encoder, out, 2);
|
||||
std::vector<size_t> batch_strides = A_batch_stride;
|
||||
batch_strides.insert(
|
||||
batch_strides.end(), B_batch_stride.begin(), B_batch_stride.end());
|
||||
|
||||
compute_encoder->setBytes(¶ms, sizeof(GEMMParams), 3);
|
||||
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
|
||||
} else { // Otherwise launch kernels with set offsets
|
||||
// Launch kernel
|
||||
set_array_buffer(compute_encoder, a, 0);
|
||||
set_array_buffer(compute_encoder, b, 1);
|
||||
set_array_buffer(compute_encoder, out, 3);
|
||||
|
||||
MTL::Size grid_dims_single = MTL::Size(tn, tm, 1);
|
||||
compute_encoder->setBytes(¶ms, sizeof(GEMMParams), 4);
|
||||
|
||||
for (int i = 0; i < batch_size_out; ++i) {
|
||||
auto a_off = elem_to_loc(M * K * i, a.shape(), a.strides());
|
||||
auto b_off = elem_to_loc(K * N * i, b.shape(), b.strides());
|
||||
compute_encoder->setBytes(
|
||||
batch_shape.data(), sizeof(int) * batch_shape.size(), 6);
|
||||
compute_encoder->setBytes(
|
||||
batch_strides.data(), sizeof(size_t) * batch_strides.size(), 7);
|
||||
|
||||
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
|
||||
auto b_buf = static_cast<const MTL::Buffer*>(b.buffer().ptr());
|
||||
auto out_buf = static_cast<const MTL::Buffer*>(out.buffer().ptr());
|
||||
|
||||
compute_encoder->setBuffer(a_buf, a_off * a.itemsize(), 0);
|
||||
compute_encoder->setBuffer(b_buf, b_off * b.itemsize(), 1);
|
||||
compute_encoder->setBuffer(out_buf, i * M * N * out.itemsize(), 2);
|
||||
|
||||
compute_encoder->setBytes(¶ms, sizeof(GEMMParams), 3);
|
||||
compute_encoder->dispatchThreadgroups(grid_dims_single, group_dims);
|
||||
}
|
||||
}
|
||||
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
|
||||
|
||||
// Clear copies
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
|
||||
return;
|
||||
@ -453,9 +517,9 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto check_transpose = [&copies, &s](const array& arr) {
|
||||
auto stx = arr.strides()[arr.ndim() - 2];
|
||||
auto sty = arr.strides()[arr.ndim() - 1];
|
||||
if (stx == arr.shape(-1) && sty == 1) {
|
||||
if (sty == 1) {
|
||||
return std::make_tuple(false, stx, arr);
|
||||
} else if (stx == 1 && sty == arr.shape(-2)) {
|
||||
} else if (stx == 1) {
|
||||
return std::make_tuple(true, sty, arr);
|
||||
} else {
|
||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||
@ -473,8 +537,25 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
int N = b.shape(-1);
|
||||
int K = a.shape(-1);
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Check and collapse batch dimensions
|
||||
|
||||
auto [batch_shape, A_batch_stride, B_batch_stride] = collapse_batches(a, b);
|
||||
|
||||
auto batch_size_out = out.size() / (M * N);
|
||||
|
||||
// Collapse batches into M if needed
|
||||
if (batch_size_out > 1 && !a_transposed && batch_shape.size() == 1 &&
|
||||
a.strides()[a.ndim() - 2] == K && A_batch_stride.back() == M * K &&
|
||||
B_batch_stride.back() == 0) {
|
||||
M *= batch_shape.back();
|
||||
batch_size_out = 1;
|
||||
|
||||
A_batch_stride = {0};
|
||||
B_batch_stride = {0};
|
||||
batch_shape = {1};
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Gemv specialization
|
||||
|
||||
@ -491,20 +572,18 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
int mat_cols = transpose_mat ? out_vector_len : in_vector_len;
|
||||
int mat_rows = transpose_mat ? in_vector_len : out_vector_len;
|
||||
int mat_ld = is_b_matrix ? b_cols : a_cols;
|
||||
|
||||
int batch_size_mat = mat.data_size() / (mat_cols * mat_rows);
|
||||
int stride_mat = batch_size_mat == 1 ? 0 : mat_cols * mat_rows;
|
||||
auto batch_strides_mat = is_b_matrix ? B_batch_stride : A_batch_stride;
|
||||
auto batch_strides_vec = is_b_matrix ? A_batch_stride : B_batch_stride;
|
||||
|
||||
int batch_size_vec = vec.data_size() / in_vector_len;
|
||||
int stride_vec = batch_size_vec == 1 ? 0 : in_vector_len;
|
||||
int stride_mat = batch_strides_mat.back();
|
||||
int stride_vec = batch_strides_vec.back();
|
||||
|
||||
// Determine if inputs have simple batching / broadcasting
|
||||
bool contiguous_kernel =
|
||||
(batch_size_out == std::max(batch_size_mat, batch_size_vec) &&
|
||||
(batch_size_mat == batch_size_vec ||
|
||||
std::min(batch_size_mat, batch_size_vec) == 1));
|
||||
bool contiguous_kernel = (batch_shape.size() == 1);
|
||||
|
||||
int nc_dim = out.ndim() - 2;
|
||||
int batch_ndim = batch_shape.size();
|
||||
|
||||
// Determine dispatch kernel
|
||||
int tm = 4, tn = 4;
|
||||
@ -540,10 +619,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
|
||||
kname << "_bm" << bm << "_bn" << bn << "_tm" << tm << "_tn" << tn;
|
||||
|
||||
if (!contiguous_kernel) {
|
||||
kname << "_nc";
|
||||
}
|
||||
kname << "_nc" << !contiguous_kernel << "_axpby0";
|
||||
|
||||
// Encode and dispatch kernel
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
@ -556,25 +632,18 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
set_array_buffer(compute_encoder, mat, 0);
|
||||
set_array_buffer(compute_encoder, vec, 1);
|
||||
set_array_buffer(compute_encoder, out, 2);
|
||||
set_array_buffer(compute_encoder, out, 3);
|
||||
|
||||
compute_encoder->setBytes(&in_vector_len, sizeof(int), 3);
|
||||
compute_encoder->setBytes(&out_vector_len, sizeof(int), 4);
|
||||
compute_encoder->setBytes(&in_vector_len, sizeof(int), 4);
|
||||
compute_encoder->setBytes(&out_vector_len, sizeof(int), 5);
|
||||
compute_encoder->setBytes(&mat_ld, sizeof(int), 6);
|
||||
|
||||
if (contiguous_kernel) {
|
||||
compute_encoder->setBytes(&stride_vec, sizeof(int), 5);
|
||||
compute_encoder->setBytes(&stride_mat, sizeof(int), 6);
|
||||
} else {
|
||||
// In case of complex broadcasting, we consider the shape[:-2] and
|
||||
// strides [:-2] to determine the location of a batch
|
||||
// nc_dim = out.ndim() - 2
|
||||
compute_encoder->setBytes(&nc_dim, sizeof(int), 5);
|
||||
compute_encoder->setBytes(out.shape().data(), nc_dim * sizeof(int), 6);
|
||||
compute_encoder->setBytes(
|
||||
vec.strides().data(), nc_dim * sizeof(size_t), 7);
|
||||
compute_encoder->setBytes(
|
||||
mat.strides().data(), nc_dim * sizeof(size_t), 8);
|
||||
}
|
||||
compute_encoder->setBytes(&batch_ndim, sizeof(int), 9);
|
||||
compute_encoder->setBytes(batch_shape.data(), batch_ndim * sizeof(int), 10);
|
||||
compute_encoder->setBytes(
|
||||
batch_strides_vec.data(), batch_ndim * sizeof(size_t), 11);
|
||||
compute_encoder->setBytes(
|
||||
batch_strides_mat.data(), batch_ndim * sizeof(size_t), 12);
|
||||
|
||||
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
|
||||
|
||||
@ -606,20 +675,23 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
|
||||
return steel_matmul(
|
||||
s,
|
||||
d,
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
batch_size_out,
|
||||
a_cols,
|
||||
b_cols,
|
||||
a_transposed,
|
||||
b_transposed,
|
||||
copies);
|
||||
/* const Stream& s = */ s,
|
||||
/* metal::Device& d = */ d,
|
||||
/* const array& a = */ a,
|
||||
/* const array& b = */ b,
|
||||
/* array& out = */ out,
|
||||
/* int M = */ M,
|
||||
/* int N = */ N,
|
||||
/* int K = */ K,
|
||||
/* int batch_size_out = */ batch_size_out,
|
||||
/* int lda = */ a_cols,
|
||||
/* int ldb = */ b_cols,
|
||||
/* bool transpose_a = */ a_transposed,
|
||||
/* bool transpose_b = */ b_transposed,
|
||||
/* std::vector<array>& = */ copies,
|
||||
/* std::vector<int> batch_shape = */ batch_shape,
|
||||
/* std::vector<size_t> A_batch_stride = */ A_batch_stride,
|
||||
/* std::vector<size_t> B_batch_stride = */ B_batch_stride);
|
||||
}
|
||||
|
||||
void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
@ -645,9 +717,9 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto check_transpose = [&copies, &s](const array& arr) {
|
||||
auto stx = arr.strides()[arr.ndim() - 2];
|
||||
auto sty = arr.strides()[arr.ndim() - 1];
|
||||
if (stx == arr.shape(-1) && sty == 1) {
|
||||
if (sty == 1) {
|
||||
return std::make_tuple(false, stx, arr);
|
||||
} else if (stx == 1 && sty == arr.shape(-2)) {
|
||||
} else if (stx == 1) {
|
||||
return std::make_tuple(true, sty, arr);
|
||||
} else {
|
||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||
@ -665,33 +737,151 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
int N = b.shape(-1);
|
||||
int K = a.shape(-1);
|
||||
|
||||
auto batch_size_out = out.size() / (M * N);
|
||||
|
||||
array c = c_pre;
|
||||
int ldc = c.strides()[c.ndim() - 2];
|
||||
int fdc = c.strides()[c.ndim() - 1];
|
||||
int matrix_stride_c = c.ndim() <= 2 ? 0 : c.strides()[c.ndim() - 3];
|
||||
|
||||
int lda = a_cols;
|
||||
int ldb = b_cols;
|
||||
int ldd = N;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Check and collapse batch dimensions
|
||||
auto [batch_shape, A_batch_stride, B_batch_stride, C_batch_stride] =
|
||||
collapse_batches(a, b, c);
|
||||
|
||||
auto batch_size_out = out.size() / (M * N);
|
||||
|
||||
// Collapse batches into M if needed
|
||||
if (batch_size_out > 1 && !transpose_a && batch_shape.size() == 1 &&
|
||||
a.strides()[a.ndim() - 2] == K && A_batch_stride.back() == M * K &&
|
||||
C_batch_stride.back() == M * c.strides()[c.ndim() - 2] &&
|
||||
B_batch_stride.back() == 0) {
|
||||
M *= batch_shape.back();
|
||||
batch_size_out = 1;
|
||||
|
||||
A_batch_stride = {0};
|
||||
B_batch_stride = {0};
|
||||
C_batch_stride = {0};
|
||||
batch_shape = {1};
|
||||
}
|
||||
|
||||
int matrix_stride_out = M * N;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Gemv specialization
|
||||
|
||||
// Route to gemv if needed
|
||||
if (std::min(M, N) == 1) {
|
||||
// Collect problem info
|
||||
bool is_b_matrix = N != 1;
|
||||
|
||||
auto& mat = is_b_matrix ? b : a;
|
||||
auto& vec = is_b_matrix ? a : b;
|
||||
bool transpose_mat = is_b_matrix ? !transpose_b : transpose_a;
|
||||
int in_vector_len = K;
|
||||
int out_vector_len = is_b_matrix ? N : M;
|
||||
|
||||
int mat_cols = transpose_mat ? out_vector_len : in_vector_len;
|
||||
int mat_rows = transpose_mat ? in_vector_len : out_vector_len;
|
||||
int mat_ld = is_b_matrix ? b_cols : a_cols;
|
||||
|
||||
auto batch_strides_mat = is_b_matrix ? B_batch_stride : A_batch_stride;
|
||||
auto batch_strides_vec = is_b_matrix ? A_batch_stride : B_batch_stride;
|
||||
|
||||
int stride_mat = batch_strides_mat.back();
|
||||
int stride_vec = batch_strides_vec.back();
|
||||
|
||||
// Determine if inputs have simple batching / broadcasting
|
||||
bool contiguous_kernel = (batch_shape.size() == 1);
|
||||
|
||||
int batch_ndim = batch_shape.size();
|
||||
|
||||
// Determine dispatch kernel
|
||||
int tm = 4, tn = 4;
|
||||
int bm, bn, n_out_per_tgp;
|
||||
std::ostringstream kname;
|
||||
|
||||
if (transpose_mat) {
|
||||
bm = 8;
|
||||
bn = 8;
|
||||
if (out_vector_len >= 24576) {
|
||||
bn = 128;
|
||||
} else if (out_vector_len >= 16384) {
|
||||
bn = 64;
|
||||
} else if (out_vector_len >= 8192) {
|
||||
bn = 16;
|
||||
}
|
||||
|
||||
// Specialized kernel for very small outputs
|
||||
tn = out_vector_len < tn ? 1 : tn;
|
||||
|
||||
n_out_per_tgp = bn * tn;
|
||||
kname << "gemv_t_" << type_to_name(out);
|
||||
|
||||
} else {
|
||||
bm = out_vector_len >= 4096 ? 8 : 4;
|
||||
bn = 32;
|
||||
|
||||
// Specialized kernel for very small outputs
|
||||
tm = out_vector_len < tm ? 1 : tm;
|
||||
|
||||
n_out_per_tgp = bm * tm;
|
||||
kname << "gemv_" << type_to_name(out);
|
||||
}
|
||||
|
||||
kname << "_bm" << bm << "_bn" << bn << "_tm" << tm << "_tn" << tn;
|
||||
kname << "_nc" << !contiguous_kernel << "_axpby1";
|
||||
|
||||
// Encode and dispatch kernel
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp;
|
||||
MTL::Size group_dims = MTL::Size(bn, bm, 1);
|
||||
MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out);
|
||||
|
||||
set_array_buffer(compute_encoder, mat, 0);
|
||||
set_array_buffer(compute_encoder, vec, 1);
|
||||
set_array_buffer(compute_encoder, c, 2);
|
||||
set_array_buffer(compute_encoder, out, 3);
|
||||
|
||||
compute_encoder->setBytes(&in_vector_len, sizeof(int), 4);
|
||||
compute_encoder->setBytes(&out_vector_len, sizeof(int), 5);
|
||||
compute_encoder->setBytes(&mat_ld, sizeof(int), 6);
|
||||
|
||||
compute_encoder->setBytes(&alpha_, sizeof(float), 7);
|
||||
compute_encoder->setBytes(&beta_, sizeof(float), 8);
|
||||
|
||||
compute_encoder->setBytes(&batch_ndim, sizeof(int), 9);
|
||||
compute_encoder->setBytes(batch_shape.data(), batch_ndim * sizeof(int), 10);
|
||||
compute_encoder->setBytes(
|
||||
batch_strides_vec.data(), batch_ndim * sizeof(size_t), 11);
|
||||
compute_encoder->setBytes(
|
||||
batch_strides_mat.data(), batch_ndim * sizeof(size_t), 12);
|
||||
compute_encoder->setBytes(
|
||||
C_batch_stride.data(), batch_ndim * sizeof(size_t), 13);
|
||||
|
||||
int bias_stride = c.strides()[c.ndim() - 1];
|
||||
compute_encoder->setBytes(&bias_stride, sizeof(int), 14);
|
||||
|
||||
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
|
||||
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
|
||||
return;
|
||||
}
|
||||
|
||||
using namespace mlx::steel;
|
||||
|
||||
// Account for batch sizes and basic broadcasting
|
||||
int batch_size_a = a.data_size() / (M * K);
|
||||
int batch_size_b = b.data_size() / (K * N);
|
||||
|
||||
int matrix_stride_a = (batch_size_a == 1) ? 0 : M * K;
|
||||
int matrix_stride_b = (batch_size_b == 1) ? 0 : K * N;
|
||||
int matrix_stride_out = M * N;
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Split K specialization
|
||||
|
||||
int _tm = M / 16;
|
||||
int _tn = N / 16;
|
||||
int _tk = K / 16;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Split K specialization
|
||||
|
||||
if (batch_size_out == 1 && (_tm * _tn) <= 32 && _tk >= 8) {
|
||||
int bm = M < 40 ? 16 : 32;
|
||||
int bn = N < 40 ? 16 : 32;
|
||||
@ -817,25 +1007,29 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
// TODO: Explore device-based tuning for swizzle
|
||||
int swizzle_log = 0; // tm >= 6 ? 3 : (tm <= 3 ? 0 : 2);
|
||||
|
||||
// Prepare steel matmul params
|
||||
GEMMParams gemm_params{
|
||||
/* const int M = */ M,
|
||||
/* const int N = */ N,
|
||||
/* const int K = */ K,
|
||||
/* const int lda = */ lda,
|
||||
/* const int ldb = */ ldb,
|
||||
/* const int ldd = */ N,
|
||||
/* const int tiles_n = */ tn,
|
||||
/* const int tiles_m = */ tm,
|
||||
/* const int batch_stride_a = */ int(A_batch_stride.back()),
|
||||
/* const int batch_stride_b = */ int(B_batch_stride.back()),
|
||||
/* const int batch_stride_d = */ matrix_stride_out,
|
||||
/* const int swizzle_log = */ swizzle_log,
|
||||
/* const int gemm_k_iterations_aligned = */ (K / bk),
|
||||
/* const int batch_ndim = */ int(batch_shape.size())};
|
||||
|
||||
GEMMAddMMParams params{
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
lda,
|
||||
ldb,
|
||||
ldc,
|
||||
N,
|
||||
tn,
|
||||
tm,
|
||||
matrix_stride_a,
|
||||
matrix_stride_b,
|
||||
matrix_stride_c,
|
||||
matrix_stride_out,
|
||||
swizzle_log,
|
||||
(K / bk),
|
||||
alpha_,
|
||||
beta_,
|
||||
fdc};
|
||||
/* const int ldc = */ ldc,
|
||||
/* const int fdc = */ fdc,
|
||||
/* const int batch_stride_c = */ int(C_batch_stride.back()),
|
||||
/* const float alpha = */ alpha_,
|
||||
/* const float beta = */ beta_};
|
||||
|
||||
int tile = 1 << swizzle_log;
|
||||
tm = (tm + tile - 1) / tile;
|
||||
@ -844,40 +1038,27 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MTL::Size group_dims = MTL::Size(32, wn, wm);
|
||||
MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out);
|
||||
|
||||
// Launch only 1 kernel in the case of simple batching / broadcasting
|
||||
if (batch_size_out == std::max(batch_size_a, batch_size_b) &&
|
||||
(batch_size_a == batch_size_b ||
|
||||
std::min(batch_size_a, batch_size_b) == 1)) {
|
||||
set_array_buffer(compute_encoder, a, 0);
|
||||
set_array_buffer(compute_encoder, b, 1);
|
||||
set_array_buffer(compute_encoder, c, 2);
|
||||
set_array_buffer(compute_encoder, out, 3);
|
||||
std::vector<size_t> batch_strides = A_batch_stride;
|
||||
batch_strides.insert(
|
||||
batch_strides.end(), B_batch_stride.begin(), B_batch_stride.end());
|
||||
batch_strides.insert(
|
||||
batch_strides.end(), C_batch_stride.begin(), C_batch_stride.end());
|
||||
|
||||
compute_encoder->setBytes(¶ms, sizeof(GEMMAddMMParams), 4);
|
||||
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
|
||||
} else { // Otherwise launch kernels with set offsets
|
||||
// Launch kernel
|
||||
set_array_buffer(compute_encoder, a, 0);
|
||||
set_array_buffer(compute_encoder, b, 1);
|
||||
set_array_buffer(compute_encoder, c, 2);
|
||||
set_array_buffer(compute_encoder, out, 3);
|
||||
|
||||
MTL::Size grid_dims_single = MTL::Size(tn, tm, 1);
|
||||
compute_encoder->setBytes(&gemm_params, sizeof(GEMMParams), 4);
|
||||
compute_encoder->setBytes(¶ms, sizeof(GEMMAddMMParams), 5);
|
||||
|
||||
for (int i = 0; i < batch_size_out; ++i) {
|
||||
auto a_off = elem_to_loc(M * K * i, a.shape(), a.strides());
|
||||
auto b_off = elem_to_loc(K * N * i, b.shape(), b.strides());
|
||||
auto c_off = elem_to_loc(M * N * i, c.shape(), c.strides());
|
||||
compute_encoder->setBytes(
|
||||
batch_shape.data(), sizeof(int) * batch_shape.size(), 6);
|
||||
compute_encoder->setBytes(
|
||||
batch_strides.data(), sizeof(size_t) * batch_strides.size(), 7);
|
||||
|
||||
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
|
||||
auto b_buf = static_cast<const MTL::Buffer*>(b.buffer().ptr());
|
||||
auto c_buf = static_cast<const MTL::Buffer*>(c.buffer().ptr());
|
||||
auto out_buf = static_cast<const MTL::Buffer*>(out.buffer().ptr());
|
||||
|
||||
compute_encoder->setBuffer(a_buf, a_off * a.itemsize(), 0);
|
||||
compute_encoder->setBuffer(b_buf, b_off * b.itemsize(), 1);
|
||||
compute_encoder->setBuffer(c_buf, c_off * c.itemsize(), 2);
|
||||
compute_encoder->setBuffer(out_buf, i * M * N * out.itemsize(), 3);
|
||||
|
||||
compute_encoder->setBytes(¶ms, sizeof(GEMMAddMMParams), 4);
|
||||
compute_encoder->dispatchThreadgroups(grid_dims_single, group_dims);
|
||||
}
|
||||
}
|
||||
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
|
||||
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
|
||||
|
@ -26,6 +26,9 @@ void steel_matmul(
|
||||
int ldb,
|
||||
bool transpose_a,
|
||||
bool transpose_b,
|
||||
std::vector<array>& copies);
|
||||
std::vector<array>& copies,
|
||||
std::vector<int> batch_shape = {},
|
||||
std::vector<size_t> A_batch_stride = {},
|
||||
std::vector<size_t> B_batch_stride = {});
|
||||
|
||||
} // namespace mlx::core
|
57
mlx/ops.cpp
57
mlx/ops.cpp
@ -3315,17 +3315,8 @@ array addmm(
|
||||
const float& alpha /* = 1.f */,
|
||||
const float& beta /* = 1.f */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
// Divert in the case of vector-matrix multiplication
|
||||
// TODO: Add the needed specializtion
|
||||
if (a.ndim() == 1 || b.ndim() == 1) {
|
||||
array X = matmul(a, b, s);
|
||||
array alpha_arr = array(alpha, X.dtype());
|
||||
array aX = multiply(alpha_arr, X, s);
|
||||
|
||||
array beta_arr = array(beta, c.dtype());
|
||||
array bY = multiply(beta_arr, c, s);
|
||||
return add(aX, bY, s);
|
||||
}
|
||||
int in_a_ndim = a.ndim();
|
||||
int in_b_ndim = b.ndim();
|
||||
|
||||
if (a.ndim() == 0 || b.ndim() == 0) {
|
||||
throw std::invalid_argument(
|
||||
@ -3333,6 +3324,15 @@ array addmm(
|
||||
"have at least one dimension.");
|
||||
}
|
||||
|
||||
if (a.ndim() == 1) {
|
||||
// Insert a singleton dim in the beginning
|
||||
a = reshape(a, {1, -1}, s);
|
||||
}
|
||||
if (b.ndim() == 1) {
|
||||
// Insert a singleton dim at the end
|
||||
b = reshape(b, {-1, 1}, s);
|
||||
}
|
||||
|
||||
if (a.shape(-1) != b.shape(-2)) {
|
||||
std::ostringstream msg;
|
||||
msg << "[addmm] Last dimension of first input with shape " << a.shape()
|
||||
@ -3361,7 +3361,13 @@ array addmm(
|
||||
std::vector<int> out_shape = a.shape();
|
||||
a = reshape(a, {-1, out_shape.back()}, s);
|
||||
out_shape.back() = b.shape(-1);
|
||||
|
||||
if (in_b_ndim == 1) {
|
||||
out_shape.pop_back();
|
||||
}
|
||||
|
||||
c = broadcast_to(c, {a.shape(0), b.shape(1)}, s);
|
||||
|
||||
auto out = array(
|
||||
{a.shape(0), b.shape(1)},
|
||||
out_type,
|
||||
@ -3389,15 +3395,42 @@ array addmm(
|
||||
auto out_shape = a.shape();
|
||||
out_shape.back() = b.shape(-1);
|
||||
|
||||
auto c_broadcast_shape = broadcast_shapes(c.shape(), out_shape);
|
||||
auto out_shape_adjusted = out_shape;
|
||||
|
||||
if (in_a_ndim == 1 || in_b_ndim == 1) {
|
||||
out_shape_adjusted.erase(
|
||||
out_shape_adjusted.end() - ((in_a_ndim == 1) ? 2 : 1),
|
||||
out_shape_adjusted.end() - ((in_b_ndim == 1) ? 0 : 1));
|
||||
}
|
||||
|
||||
auto c_broadcast_shape = broadcast_shapes(c.shape(), out_shape_adjusted);
|
||||
c = broadcast_to(c, c_broadcast_shape, s);
|
||||
|
||||
if (in_a_ndim == 1 || in_b_ndim == 1) {
|
||||
auto c_reshape = c.shape();
|
||||
if (in_b_ndim == 1) {
|
||||
c_reshape.push_back(1);
|
||||
}
|
||||
|
||||
if (in_a_ndim == 1) {
|
||||
c_reshape.push_back(c_reshape.back());
|
||||
c_reshape[c_reshape.size() - 2] = 1;
|
||||
}
|
||||
|
||||
c = reshape(c, c_reshape, s);
|
||||
}
|
||||
|
||||
auto out = array(
|
||||
out_shape,
|
||||
out_type,
|
||||
std::make_unique<AddMM>(to_stream(s), alpha, beta),
|
||||
{a, b, c});
|
||||
|
||||
// Remove the possibly inserted singleton dimensions
|
||||
if (in_a_ndim == 1 || in_b_ndim == 1) {
|
||||
out = reshape(out, out_shape_adjusted, s);
|
||||
}
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
|
@ -393,6 +393,77 @@ class TestBlas(mlx_tests.MLXTestCase):
|
||||
mlx_vec_f=lambda vec_mlx: mx.broadcast_to(vec_mlx, shape_vec),
|
||||
)
|
||||
|
||||
def test_matrix_vector_attn(self):
|
||||
# Multi-query style attention check
|
||||
for dtype in self.dtypes:
|
||||
# fmt: off
|
||||
for (B, D, n_kv_heads, factor, qsl, ksl) in (
|
||||
(1, 16, 8, 4, 1, 256),
|
||||
(1, 16, 8, 4, 32, 256),
|
||||
(1, 16, 8, 4, 256, 1),
|
||||
(4, 16, 8, 4, 1, 256),
|
||||
(4, 16, 8, 4, 256, 1),
|
||||
):
|
||||
# fmt: on
|
||||
with self.subTest(
|
||||
B=B, # Batch size
|
||||
D=D, # Dimension of mm
|
||||
n_kv_heads=n_kv_heads, # key-value heads
|
||||
factor=factor, # factor to get query heads
|
||||
qsl=qsl, # Query sequence length
|
||||
ksl=ksl, # Key sequence length
|
||||
dtype=dtype # Data type
|
||||
):
|
||||
|
||||
np_dtype = getattr(np, dtype)
|
||||
|
||||
# Fix shapes for kqv
|
||||
n_q_heads = n_kv_heads * factor
|
||||
Dk = D * n_kv_heads
|
||||
Dq = D * n_q_heads
|
||||
scale = 1. / math.sqrt(Dk)
|
||||
|
||||
shape_queries = (B, qsl, Dq)
|
||||
shape_keys = (B, ksl, Dk)
|
||||
shape_values = (B, ksl, Dk)
|
||||
|
||||
# Prepare numpy arrays
|
||||
q_np = np.random.uniform(-scale, scale, size=shape_queries).astype(np_dtype)
|
||||
k_np = np.random.uniform(-scale, scale, size=shape_keys).astype(np_dtype)
|
||||
v_np = np.random.uniform(-scale, scale, size=shape_values).astype(np_dtype)
|
||||
|
||||
# Rearrange to move heads up
|
||||
q_np_reshape = q_np.reshape(B, qsl, n_kv_heads, factor, -1).transpose(0, 2, 3, 1, 4)
|
||||
k_np_reshape = k_np.reshape(B, ksl, n_kv_heads, 1, -1).transpose(0, 2, 3, 4, 1)
|
||||
v_np_reshape = v_np.reshape(B, ksl, n_kv_heads, 1, -1).transpose(0, 2, 3, 1, 4)
|
||||
|
||||
# Do attn style matmul
|
||||
s_np = q_np_reshape @ k_np_reshape
|
||||
o_np = s_np @ v_np_reshape
|
||||
o_np = o_np.transpose(0, 3, 1, 2, 4).reshape(B, qsl, -1)
|
||||
|
||||
# Test mlx
|
||||
q_mx = mx.array(q_np)
|
||||
k_mx = mx.array(k_np)
|
||||
v_mx = mx.array(v_np)
|
||||
|
||||
# Rearrange to move heads up
|
||||
q_mx_reshape = q_mx.reshape(B, qsl, n_kv_heads, factor, -1).transpose(0, 2, 3, 1, 4)
|
||||
k_mx_reshape = k_mx.reshape(B, ksl, n_kv_heads, 1, -1).transpose(0, 2, 3, 4, 1)
|
||||
v_mx_reshape = v_mx.reshape(B, ksl, n_kv_heads, 1, -1).transpose(0, 2, 3, 1, 4)
|
||||
|
||||
# Do attn style matmul
|
||||
s_mx = q_mx_reshape @ k_mx_reshape
|
||||
o_mx = (s_mx @ v_mx_reshape)
|
||||
o_mx = o_mx.transpose(0, 3, 1, 2, 4).reshape(B, qsl, -1)
|
||||
|
||||
# Check against np
|
||||
self.assertListEqual(list(s_np.shape), list(s_mx.shape))
|
||||
self.assertTrue(np.allclose(s_np, s_mx, atol=1e-4))
|
||||
|
||||
self.assertListEqual(list(o_np.shape), list(o_mx.shape))
|
||||
self.assertTrue(np.allclose(o_np, o_mx, atol=1e-4))
|
||||
|
||||
def test_matrix_vector_edgecases(self):
|
||||
for dtype in self.dtypes:
|
||||
with self.subTest(dtype=dtype):
|
||||
@ -503,16 +574,29 @@ class TestBlas(mlx_tests.MLXTestCase):
|
||||
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
|
||||
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
|
||||
|
||||
# Matmul with vector
|
||||
a_npy = np.random.normal(0.0, 1.0 / 128, (16,)).astype(np.float32)
|
||||
b_npy = np.random.normal(0.0, 1.0 / 128, (32, 16, 128)).astype(np.float32)
|
||||
a_mlx = mx.array(a_npy)
|
||||
b_mlx = mx.array(b_npy)
|
||||
|
||||
for c_shape in ((1,), (128,), (32, 128)):
|
||||
c_npy = np.ones(c_shape).astype(np.float32)
|
||||
c_mlx = mx.array(c_npy)
|
||||
|
||||
d_npy = alpha * (a_npy @ b_npy) + beta * c_npy
|
||||
d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta)
|
||||
|
||||
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
|
||||
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
|
||||
|
||||
# Matmul with vector
|
||||
a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
|
||||
b_npy = np.random.normal(0.0, 1.0 / 128, (16,)).astype(np.float32)
|
||||
a_mlx = mx.array(a_npy)
|
||||
b_mlx = mx.array(b_npy)
|
||||
|
||||
for c_shape in (
|
||||
(1,),
|
||||
(32, 128),
|
||||
):
|
||||
for c_shape in ((1,), (32, 128)):
|
||||
c_npy = np.ones(c_shape).astype(np.float32)
|
||||
c_mlx = mx.array(c_npy)
|
||||
|
||||
@ -564,16 +648,12 @@ class TestBlas(mlx_tests.MLXTestCase):
|
||||
out_ref, dout_ref = mx.vjp(
|
||||
f_ref,
|
||||
[c, a, b],
|
||||
[
|
||||
cotan,
|
||||
],
|
||||
[cotan],
|
||||
)
|
||||
out_test, dout_test = mx.vjp(
|
||||
f_test,
|
||||
[c, a, b],
|
||||
[
|
||||
cotan,
|
||||
],
|
||||
[cotan],
|
||||
)
|
||||
|
||||
self.assertTrue(mx.allclose(out_ref[0], out_test[0], atol=1e-4).item())
|
||||
|
Loading…
Reference in New Issue
Block a user