Refactor gemv into a function

Jagrit Digani 2025-06-10 16:53:23 -07:00
parent c8b4787e4e
commit 0c69f10d55


@@ -470,6 +470,178 @@ void steel_matmul(
copies);
}
template <bool CHECK_AB = true>
void gemv_axbpy(
const Stream& s,
metal::Device& d,
const array& a,
const array& b,
const array& c,
array& out,
int M,
int N,
int K,
int batch_size_out,
int lda,
int ldb,
bool transpose_a,
bool transpose_b,
std::vector<array>& copies,
Shape batch_shape = {},
Strides A_batch_stride = {},
Strides B_batch_stride = {},
Strides C_batch_stride = {},
float alpha = 1.0f,
float beta = 0.0f) {
// Collect problem info
bool is_b_matrix = N != 1;
auto& mat = is_b_matrix ? b : a;
auto& vec = is_b_matrix ? a : b;
bool transpose_mat = is_b_matrix ? !transpose_b : transpose_a;
int in_vector_len = K;
int out_vector_len = is_b_matrix ? N : M;
int mat_cols = transpose_mat ? out_vector_len : in_vector_len;
int mat_rows = transpose_mat ? in_vector_len : out_vector_len;
int mat_ld = is_b_matrix ? ldb : lda;
auto batch_strides_mat = is_b_matrix ? B_batch_stride : A_batch_stride;
auto batch_strides_vec = is_b_matrix ? A_batch_stride : B_batch_stride;
int stride_mat = batch_strides_mat.back();
int stride_vec = batch_strides_vec.back();
// Determine if inputs have simple batching / broadcasting
bool contiguous_kernel = (batch_shape.size() == 1);
int batch_ndim = batch_shape.size();
// Determine dispatch kernel
int tm = 4, tn = 4;
int sm = 1, sn = 32;
int bm = 1, bn = 1;
int n_out_per_tgp;
std::ostringstream kname;
if (transpose_mat) {
if (in_vector_len >= 8192 && out_vector_len >= 2048) {
sm = 4;
sn = 8;
} else {
sm = 8;
sn = 4;
}
if (out_vector_len >= 2048) {
bn = 16;
} else if (out_vector_len >= 512) {
bn = 4;
} else {
bn = 2;
}
// Specialized kernel for very small outputs
tn = out_vector_len < tn ? 1 : tn;
n_out_per_tgp = bn * sn * tn;
kname << "gemv_t_" << type_to_name(out);
} else {
bm = out_vector_len >= 4096 ? 8 : 4;
sn = 32;
// Specialized kernel for very small outputs
tm = out_vector_len < tm ? 1 : tm;
n_out_per_tgp = bm * sm * tm;
kname << "gemv_" << type_to_name(out);
}
const bool do_axpby = CHECK_AB && (alpha != 1.0f || beta != 0.0f);
kname << "_bm" << bm << "_bn" << bn << "_sm" << sm << "_sn" << sn << "_tm"
<< tm << "_tn" << tn;
kname << "_nc" << !contiguous_kernel << "_axpby" << do_axpby;
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder.set_compute_pipeline_state(kernel);
int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp;
MTL::Size group_dims = MTL::Size(32, bn, bm);
MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out);
compute_encoder.set_input_array(mat, 0);
compute_encoder.set_input_array(vec, 1);
compute_encoder.set_output_array(out, 3);
compute_encoder.set_bytes(in_vector_len, 4);
compute_encoder.set_bytes(out_vector_len, 5);
compute_encoder.set_bytes(mat_ld, 6);
compute_encoder.set_bytes(batch_ndim, 9);
compute_encoder.set_vector_bytes(batch_shape, 10);
compute_encoder.set_vector_bytes(batch_strides_vec, 11);
compute_encoder.set_vector_bytes(batch_strides_mat, 12);
if (do_axpby) {
compute_encoder.set_input_array(c, 2);
compute_encoder.set_bytes(alpha, 7);
compute_encoder.set_bytes(beta, 8);
compute_encoder.set_vector_bytes(C_batch_stride, 13);
int bias_stride = c.strides()[c.ndim() - 1];
compute_encoder.set_bytes(bias_stride, 14);
}
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
d.add_temporaries(std::move(copies), s.index);
}
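As an aside, the kernel name and launch geometry that gemv_axbpy selects can be traced by hand. The standalone sketch below is illustrative only and not part of this diff; the float16 dtype and the shapes M = 1, N = K = 4096 with a non-transposed B (so the gemv_t path is taken, contiguous batch, no axpby) are arbitrary assumptions.
// Illustrative sketch, not part of this commit: replays the kernel-name and
// threadgroup selection in gemv_axbpy for an assumed float16 problem with
// M = 1, N = K = 4096 and a non-transposed B (the matrix is read transposed).
#include <iostream>
#include <sstream>

int main() {
  int in_vector_len = 4096;   // K
  int out_vector_len = 4096;  // N; b is the matrix since N != 1
  bool transpose_mat = true;  // !transpose_b for a non-transposed B

  int tm = 4, tn = 4, sm = 1, sn = 32, bm = 1, bn = 1, n_out_per_tgp = 0;
  std::ostringstream kname;
  if (transpose_mat) {
    if (in_vector_len >= 8192 && out_vector_len >= 2048) {
      sm = 4; sn = 8;
    } else {
      sm = 8; sn = 4;          // taken here: K < 8192
    }
    if (out_vector_len >= 2048) {
      bn = 16;                 // taken here: N >= 2048
    } else if (out_vector_len >= 512) {
      bn = 4;
    } else {
      bn = 2;
    }
    tn = out_vector_len < tn ? 1 : tn;
    n_out_per_tgp = bn * sn * tn;  // 16 * 4 * 4 = 256 outputs per threadgroup
    kname << "gemv_t_float16";
  } else {
    bm = out_vector_len >= 4096 ? 8 : 4;
    tm = out_vector_len < tm ? 1 : tm;
    n_out_per_tgp = bm * sm * tm;
    kname << "gemv_float16";
  }
  kname << "_bm" << bm << "_bn" << bn << "_sm" << sm << "_sn" << sn
        << "_tm" << tm << "_tn" << tn << "_nc0_axpby0";

  int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp;
  // Prints: gemv_t_float16_bm1_bn16_sm8_sn4_tm4_tn4_nc0_axpby0
  //         16 threadgroups of 32 x 16 x 1 threads
  std::cout << kname.str() << "\n"
            << n_tgp << " threadgroups of 32 x " << bn << " x " << bm
            << " threads\n";
  return 0;
}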
inline void gemv(
const Stream& s,
metal::Device& d,
const array& a,
const array& b,
array& out,
int M,
int N,
int K,
int batch_size_out,
int lda,
int ldb,
bool transpose_a,
bool transpose_b,
std::vector<array>& copies,
Shape batch_shape = {},
Strides A_batch_stride = {},
Strides B_batch_stride = {}) {
return gemv_axbpy<false>(
/* const Stream& s = */ s,
/* metal::Device& d = */ d,
/* const array& a = */ a,
/* const array& b = */ b,
/* const array& c = */ b,
/* array& out = */ out,
/* int M = */ M,
/* int N = */ N,
/* int K = */ K,
/* int batch_size_out = */ batch_size_out,
/* int lda = */ lda,
/* int ldb = */ ldb,
/* bool transpose_a = */ transpose_a,
/* bool transpose_b = */ transpose_b,
/* std::vector<array>& copies = */ copies,
/* Shape batch_shape = */ batch_shape,
/* Strides A_batch_stride = */ A_batch_stride,
/* Strides B_batch_stride = */ B_batch_stride);
}
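// Note (editorial, not part of this commit): with CHECK_AB = false the
// do_axpby condition in gemv_axbpy is always false, so the dummy `c`
// argument (the wrapper passes `b` again) is never bound, no alpha/beta or
// bias-stride bytes are set, and the kernel name ends in "_axpby0".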
void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
if (!issubdtype(out.dtype(), floating)) {
@@ -528,102 +700,26 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
// Route to gemv if needed
if (std::min(M, N) == 1) {
// Collect problem info
bool is_b_matrix = N != 1;
auto& mat = is_b_matrix ? b : a;
auto& vec = is_b_matrix ? a : b;
bool transpose_mat = is_b_matrix ? !b_transposed : a_transposed;
int in_vector_len = K;
int out_vector_len = is_b_matrix ? N : M;
int mat_cols = transpose_mat ? out_vector_len : in_vector_len;
int mat_rows = transpose_mat ? in_vector_len : out_vector_len;
int mat_ld = is_b_matrix ? b_cols : a_cols;
auto batch_strides_mat = is_b_matrix ? B_batch_stride : A_batch_stride;
auto batch_strides_vec = is_b_matrix ? A_batch_stride : B_batch_stride;
int stride_mat = batch_strides_mat.back();
int stride_vec = batch_strides_vec.back();
// Determine if inputs have simple batching / broadcasting
bool contiguous_kernel = (batch_shape.size() == 1);
int batch_ndim = batch_shape.size();
// Determine dispatch kernel
int tm = 4, tn = 4;
int sm = 1, sn = 32;
int bm = 1, bn = 1;
int n_out_per_tgp;
std::ostringstream kname;
if (transpose_mat) {
if (in_vector_len >= 8192 && out_vector_len >= 2048) {
sm = 4;
sn = 8;
} else {
sm = 8;
sn = 4;
}
if (out_vector_len >= 2048) {
bn = 16;
} else if (out_vector_len >= 512) {
bn = 4;
} else {
bn = 2;
}
// Specialized kernel for very small outputs
tn = out_vector_len < tn ? 1 : tn;
n_out_per_tgp = bn * sn * tn;
kname << "gemv_t_" << type_to_name(out);
} else {
bm = out_vector_len >= 4096 ? 8 : 4;
sn = 32;
// Specialized kernel for very small outputs
tm = out_vector_len < tm ? 1 : tm;
n_out_per_tgp = bm * sm * tm;
kname << "gemv_" << type_to_name(out);
}
kname << "_bm" << bm << "_bn" << bn << "_sm" << sm << "_sn" << sn << "_tm"
<< tm << "_tn" << tn;
kname << "_nc" << !contiguous_kernel << "_axpby0";
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder.set_compute_pipeline_state(kernel);
int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp;
MTL::Size group_dims = MTL::Size(32, bn, bm);
MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out);
compute_encoder.set_input_array(mat, 0);
compute_encoder.set_input_array(vec, 1);
compute_encoder.set_output_array(out, 3);
compute_encoder.set_bytes(in_vector_len, 4);
compute_encoder.set_bytes(out_vector_len, 5);
compute_encoder.set_bytes(mat_ld, 6);
compute_encoder.set_bytes(batch_ndim, 9);
compute_encoder.set_vector_bytes(batch_shape, 10);
compute_encoder.set_vector_bytes(batch_strides_vec, 11);
compute_encoder.set_vector_bytes(batch_strides_mat, 12);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
d.add_temporaries(std::move(copies), s.index);
return;
return gemv(
/* const Stream& s = */ s,
/* metal::Device& d = */ d,
/* const array& a = */ a,
/* const array& b = */ b,
/* array& out = */ out,
/* int M = */ M,
/* int N = */ N,
/* int K = */ K,
/* int batch_size_out = */ batch_size_out,
/* int lda = */ a_cols,
/* int ldb = */ b_cols,
/* bool transpose_a = */ a_transposed,
/* bool transpose_b = */ b_transposed,
/* std::vector<array>& copies = */ copies,
/* Shape batch_shape = */ batch_shape,
/* Strides A_batch_stride = */ A_batch_stride,
/* Strides B_batch_stride = */ B_batch_stride);
}
/////////////////////////////////////////////////////////////////////////////
// Gemm specialization
@@ -641,7 +737,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
/* int ldb = */ b_cols,
/* bool transpose_a = */ a_transposed,
/* bool transpose_b = */ b_transposed,
/* std::vector<array>& = */ copies,
/* std::vector<array>& copies = */ copies,
/* Shape batch_shape = */ batch_shape,
/* Strides A_batch_stride = */ A_batch_stride,
/* Strides B_batch_stride = */ B_batch_stride);
@@ -726,109 +822,28 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
// Route to gemv if needed
if (std::min(M, N) == 1) {
// Collect problem info
bool is_b_matrix = N != 1;
auto& mat = is_b_matrix ? b : a;
auto& vec = is_b_matrix ? a : b;
bool transpose_mat = is_b_matrix ? !transpose_b : transpose_a;
int in_vector_len = K;
int out_vector_len = is_b_matrix ? N : M;
int mat_cols = transpose_mat ? out_vector_len : in_vector_len;
int mat_rows = transpose_mat ? in_vector_len : out_vector_len;
int mat_ld = is_b_matrix ? b_cols : a_cols;
auto batch_strides_mat = is_b_matrix ? B_batch_stride : A_batch_stride;
auto batch_strides_vec = is_b_matrix ? A_batch_stride : B_batch_stride;
int stride_mat = batch_strides_mat.back();
int stride_vec = batch_strides_vec.back();
// Determine if inputs have simple batching / broadcasting
bool contiguous_kernel = (batch_shape.size() == 1);
int batch_ndim = batch_shape.size();
// Determine dispatch kernel
int tm = 4, tn = 4;
int sm = 1, sn = 32;
int bm = 1, bn = 1;
int n_out_per_tgp;
std::ostringstream kname;
if (transpose_mat) {
if (in_vector_len >= 8192 && out_vector_len >= 2048) {
sm = 4;
sn = 8;
} else {
sm = 8;
sn = 4;
}
if (out_vector_len >= 2048) {
bn = 16;
} else if (out_vector_len >= 512) {
bn = 4;
} else {
bn = 2;
}
// Specialized kernel for very small outputs
tn = out_vector_len < tn ? 1 : tn;
n_out_per_tgp = bn * sn * tn;
kname << "gemv_t_" << type_to_name(out);
} else {
bm = out_vector_len >= 4096 ? 8 : 4;
sn = 32;
// Specialized kernel for very small outputs
tm = out_vector_len < tm ? 1 : tm;
n_out_per_tgp = bm * sm * tm;
kname << "gemv_" << type_to_name(out);
}
kname << "_bm" << bm << "_bn" << bn << "_sm" << sm << "_sn" << sn << "_tm"
<< tm << "_tn" << tn;
kname << "_nc" << !contiguous_kernel << "_axpby1";
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder.set_compute_pipeline_state(kernel);
int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp;
MTL::Size group_dims = MTL::Size(32, bn, bm);
MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out);
compute_encoder.set_input_array(mat, 0);
compute_encoder.set_input_array(vec, 1);
compute_encoder.set_input_array(c, 2);
compute_encoder.set_output_array(out, 3);
compute_encoder.set_bytes(in_vector_len, 4);
compute_encoder.set_bytes(out_vector_len, 5);
compute_encoder.set_bytes(mat_ld, 6);
compute_encoder.set_bytes(alpha_, 7);
compute_encoder.set_bytes(beta_, 8);
compute_encoder.set_bytes(batch_ndim, 9);
compute_encoder.set_vector_bytes(batch_shape, 10);
compute_encoder.set_vector_bytes(batch_strides_vec, 11);
compute_encoder.set_vector_bytes(batch_strides_mat, 12);
compute_encoder.set_vector_bytes(C_batch_stride, 13);
int bias_stride = c.strides()[c.ndim() - 1];
compute_encoder.set_bytes(bias_stride, 14);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
d.add_temporaries(std::move(copies), s.index);
return;
return gemv_axbpy(
/* const Stream& s = */ s,
/* metal::Device& d = */ d,
/* const array& a = */ a,
/* const array& b = */ b,
/* const array& c = */ c,
/* array& out = */ out,
/* int M = */ M,
/* int N = */ N,
/* int K = */ K,
/* int batch_size_out = */ batch_size_out,
/* int lda = */ lda,
/* int ldb = */ ldb,
/* bool transpose_a = */ transpose_a,
/* bool transpose_b = */ transpose_b,
/* std::vector<array>& copies = */ copies,
/* Shape batch_shape = */ batch_shape,
/* Strides A_batch_stride = */ A_batch_stride,
/* Strides B_batch_stride = */ B_batch_stride,
/* Strides C_batch_stride = */ C_batch_stride,
/* float alpha = */ alpha_,
/* float beta = */ beta_);
}
using namespace mlx::steel;