mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
Refactor gemv into a function
This commit is contained in:
parent
c8b4787e4e
commit
0c69f10d55
@ -470,6 +470,178 @@ void steel_matmul(
|
||||
copies);
|
||||
}
|
||||
|
||||
template <bool CHECK_AB = true>
|
||||
void gemv_axbpy(
|
||||
const Stream& s,
|
||||
metal::Device& d,
|
||||
const array& a,
|
||||
const array& b,
|
||||
const array& c,
|
||||
array& out,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
int batch_size_out,
|
||||
int lda,
|
||||
int ldb,
|
||||
bool transpose_a,
|
||||
bool transpose_b,
|
||||
std::vector<array>& copies,
|
||||
Shape batch_shape = {},
|
||||
Strides A_batch_stride = {},
|
||||
Strides B_batch_stride = {},
|
||||
Strides C_batch_stride = {},
|
||||
float alpha = 1.0f,
|
||||
float beta = 0.0f) {
|
||||
// Collect problem info
|
||||
bool is_b_matrix = N != 1;
|
||||
|
||||
auto& mat = is_b_matrix ? b : a;
|
||||
auto& vec = is_b_matrix ? a : b;
|
||||
bool transpose_mat = is_b_matrix ? !transpose_b : transpose_a;
|
||||
int in_vector_len = K;
|
||||
int out_vector_len = is_b_matrix ? N : M;
|
||||
|
||||
int mat_cols = transpose_mat ? out_vector_len : in_vector_len;
|
||||
int mat_rows = transpose_mat ? in_vector_len : out_vector_len;
|
||||
int mat_ld = is_b_matrix ? ldb : lda;
|
||||
|
||||
auto batch_strides_mat = is_b_matrix ? B_batch_stride : A_batch_stride;
|
||||
auto batch_strides_vec = is_b_matrix ? A_batch_stride : B_batch_stride;
|
||||
|
||||
int stride_mat = batch_strides_mat.back();
|
||||
int stride_vec = batch_strides_vec.back();
|
||||
|
||||
// Determine if inputs have simple batching / broadcasting
|
||||
bool contiguous_kernel = (batch_shape.size() == 1);
|
||||
|
||||
int batch_ndim = batch_shape.size();
|
||||
|
||||
// Determine dispatch kernel
|
||||
int tm = 4, tn = 4;
|
||||
int sm = 1, sn = 32;
|
||||
int bm = 1, bn = 1;
|
||||
int n_out_per_tgp;
|
||||
std::ostringstream kname;
|
||||
|
||||
if (transpose_mat) {
|
||||
if (in_vector_len >= 8192 && out_vector_len >= 2048) {
|
||||
sm = 4;
|
||||
sn = 8;
|
||||
} else {
|
||||
sm = 8;
|
||||
sn = 4;
|
||||
}
|
||||
|
||||
if (out_vector_len >= 2048) {
|
||||
bn = 16;
|
||||
} else if (out_vector_len >= 512) {
|
||||
bn = 4;
|
||||
} else {
|
||||
bn = 2;
|
||||
}
|
||||
|
||||
// Specialized kernel for very small outputs
|
||||
tn = out_vector_len < tn ? 1 : tn;
|
||||
|
||||
n_out_per_tgp = bn * sn * tn;
|
||||
kname << "gemv_t_" << type_to_name(out);
|
||||
|
||||
} else {
|
||||
bm = out_vector_len >= 4096 ? 8 : 4;
|
||||
sn = 32;
|
||||
|
||||
// Specialized kernel for very small outputs
|
||||
tm = out_vector_len < tm ? 1 : tm;
|
||||
|
||||
n_out_per_tgp = bm * sm * tm;
|
||||
kname << "gemv_" << type_to_name(out);
|
||||
}
|
||||
|
||||
const bool do_axpby = CHECK_AB && (alpha != 1.0f || beta != 0.0f);
|
||||
|
||||
kname << "_bm" << bm << "_bn" << bn << "_sm" << sm << "_sn" << sn << "_tm"
|
||||
<< tm << "_tn" << tn;
|
||||
kname << "_nc" << !contiguous_kernel << "_axpby" << do_axpby;
|
||||
|
||||
// Encode and dispatch kernel
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp;
|
||||
MTL::Size group_dims = MTL::Size(32, bn, bm);
|
||||
MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out);
|
||||
|
||||
compute_encoder.set_input_array(mat, 0);
|
||||
compute_encoder.set_input_array(vec, 1);
|
||||
compute_encoder.set_output_array(out, 3);
|
||||
|
||||
compute_encoder.set_bytes(in_vector_len, 4);
|
||||
compute_encoder.set_bytes(out_vector_len, 5);
|
||||
compute_encoder.set_bytes(mat_ld, 6);
|
||||
|
||||
compute_encoder.set_bytes(batch_ndim, 9);
|
||||
compute_encoder.set_vector_bytes(batch_shape, 10);
|
||||
compute_encoder.set_vector_bytes(batch_strides_vec, 11);
|
||||
compute_encoder.set_vector_bytes(batch_strides_mat, 12);
|
||||
|
||||
if (do_axpby) {
|
||||
compute_encoder.set_input_array(c, 2);
|
||||
|
||||
compute_encoder.set_bytes(alpha, 7);
|
||||
compute_encoder.set_bytes(beta, 8);
|
||||
|
||||
compute_encoder.set_vector_bytes(C_batch_stride, 13);
|
||||
|
||||
int bias_stride = c.strides()[c.ndim() - 1];
|
||||
compute_encoder.set_bytes(bias_stride, 14);
|
||||
}
|
||||
|
||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||
|
||||
d.add_temporaries(std::move(copies), s.index);
|
||||
}
|
||||
|
||||
inline void gemv(
|
||||
const Stream& s,
|
||||
metal::Device& d,
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
int batch_size_out,
|
||||
int lda,
|
||||
int ldb,
|
||||
bool transpose_a,
|
||||
bool transpose_b,
|
||||
std::vector<array>& copies,
|
||||
Shape batch_shape = {},
|
||||
Strides A_batch_stride = {},
|
||||
Strides B_batch_stride = {}) {
|
||||
return gemv_axbpy<false>(
|
||||
/* const Stream& s = */ s,
|
||||
/* metal::Device& d = */ d,
|
||||
/* const array& a = */ a,
|
||||
/* const array& b = */ b,
|
||||
/* const array& c = */ b,
|
||||
/* array& out = */ out,
|
||||
/* int M = */ M,
|
||||
/* int N = */ N,
|
||||
/* int K = */ K,
|
||||
/* int batch_size_out = */ batch_size_out,
|
||||
/* int lda = */ lda,
|
||||
/* int ldb = */ ldb,
|
||||
/* bool transpose_a = */ transpose_a,
|
||||
/* bool transpose_b = */ transpose_b,
|
||||
/* std::vector<array>& copies = */ copies,
|
||||
/* Shape batch_shape = */ batch_shape,
|
||||
/* Strides A_batch_stride = */ A_batch_stride,
|
||||
/* Strides B_batch_stride = */ B_batch_stride);
|
||||
}
|
||||
|
||||
void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
if (!issubdtype(out.dtype(), floating)) {
|
||||
@ -528,102 +700,26 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
// Route to gemv if needed
|
||||
if (std::min(M, N) == 1) {
|
||||
// Collect problem info
|
||||
bool is_b_matrix = N != 1;
|
||||
|
||||
auto& mat = is_b_matrix ? b : a;
|
||||
auto& vec = is_b_matrix ? a : b;
|
||||
bool transpose_mat = is_b_matrix ? !b_transposed : a_transposed;
|
||||
int in_vector_len = K;
|
||||
int out_vector_len = is_b_matrix ? N : M;
|
||||
|
||||
int mat_cols = transpose_mat ? out_vector_len : in_vector_len;
|
||||
int mat_rows = transpose_mat ? in_vector_len : out_vector_len;
|
||||
int mat_ld = is_b_matrix ? b_cols : a_cols;
|
||||
|
||||
auto batch_strides_mat = is_b_matrix ? B_batch_stride : A_batch_stride;
|
||||
auto batch_strides_vec = is_b_matrix ? A_batch_stride : B_batch_stride;
|
||||
|
||||
int stride_mat = batch_strides_mat.back();
|
||||
int stride_vec = batch_strides_vec.back();
|
||||
|
||||
// Determine if inputs have simple batching / broadcasting
|
||||
bool contiguous_kernel = (batch_shape.size() == 1);
|
||||
|
||||
int batch_ndim = batch_shape.size();
|
||||
|
||||
// Determine dispatch kernel
|
||||
int tm = 4, tn = 4;
|
||||
int sm = 1, sn = 32;
|
||||
int bm = 1, bn = 1;
|
||||
int n_out_per_tgp;
|
||||
std::ostringstream kname;
|
||||
|
||||
if (transpose_mat) {
|
||||
if (in_vector_len >= 8192 && out_vector_len >= 2048) {
|
||||
sm = 4;
|
||||
sn = 8;
|
||||
} else {
|
||||
sm = 8;
|
||||
sn = 4;
|
||||
}
|
||||
|
||||
if (out_vector_len >= 2048) {
|
||||
bn = 16;
|
||||
} else if (out_vector_len >= 512) {
|
||||
bn = 4;
|
||||
} else {
|
||||
bn = 2;
|
||||
}
|
||||
|
||||
// Specialized kernel for very small outputs
|
||||
tn = out_vector_len < tn ? 1 : tn;
|
||||
|
||||
n_out_per_tgp = bn * sn * tn;
|
||||
kname << "gemv_t_" << type_to_name(out);
|
||||
|
||||
} else {
|
||||
bm = out_vector_len >= 4096 ? 8 : 4;
|
||||
sn = 32;
|
||||
|
||||
// Specialized kernel for very small outputs
|
||||
tm = out_vector_len < tm ? 1 : tm;
|
||||
|
||||
n_out_per_tgp = bm * sm * tm;
|
||||
kname << "gemv_" << type_to_name(out);
|
||||
}
|
||||
|
||||
kname << "_bm" << bm << "_bn" << bn << "_sm" << sm << "_sn" << sn << "_tm"
|
||||
<< tm << "_tn" << tn;
|
||||
kname << "_nc" << !contiguous_kernel << "_axpby0";
|
||||
|
||||
// Encode and dispatch kernel
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp;
|
||||
MTL::Size group_dims = MTL::Size(32, bn, bm);
|
||||
MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out);
|
||||
|
||||
compute_encoder.set_input_array(mat, 0);
|
||||
compute_encoder.set_input_array(vec, 1);
|
||||
compute_encoder.set_output_array(out, 3);
|
||||
|
||||
compute_encoder.set_bytes(in_vector_len, 4);
|
||||
compute_encoder.set_bytes(out_vector_len, 5);
|
||||
compute_encoder.set_bytes(mat_ld, 6);
|
||||
|
||||
compute_encoder.set_bytes(batch_ndim, 9);
|
||||
compute_encoder.set_vector_bytes(batch_shape, 10);
|
||||
compute_encoder.set_vector_bytes(batch_strides_vec, 11);
|
||||
compute_encoder.set_vector_bytes(batch_strides_mat, 12);
|
||||
|
||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||
|
||||
d.add_temporaries(std::move(copies), s.index);
|
||||
return;
|
||||
return gemv(
|
||||
/* const Stream& s = */ s,
|
||||
/* metal::Device& d = */ d,
|
||||
/* const array& a = */ a,
|
||||
/* const array& b = */ b,
|
||||
/* array& out = */ out,
|
||||
/* int M = */ M,
|
||||
/* int N = */ N,
|
||||
/* int K = */ K,
|
||||
/* int batch_size_out = */ batch_size_out,
|
||||
/* int lda = */ a_cols,
|
||||
/* int ldb = */ b_cols,
|
||||
/* bool transpose_a = */ a_transposed,
|
||||
/* bool transpose_b = */ b_transposed,
|
||||
/* std::vector<array>& copies = */ copies,
|
||||
/* Shape batch_shape = */ batch_shape,
|
||||
/* Strides A_batch_stride = */ A_batch_stride,
|
||||
/* Strides B_batch_stride = */ B_batch_stride);
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Gemm specialization
|
||||
|
||||
@ -641,7 +737,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
/* int ldb = */ b_cols,
|
||||
/* bool transpose_a = */ a_transposed,
|
||||
/* bool transpose_b = */ b_transposed,
|
||||
/* std::vector<array>& = */ copies,
|
||||
/* std::vector<array>& copies = */ copies,
|
||||
/* Shape batch_shape = */ batch_shape,
|
||||
/* Strides A_batch_stride = */ A_batch_stride,
|
||||
/* Strides B_batch_stride = */ B_batch_stride);
|
||||
@ -726,109 +822,28 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
// Route to gemv if needed
|
||||
if (std::min(M, N) == 1) {
|
||||
// Collect problem info
|
||||
bool is_b_matrix = N != 1;
|
||||
|
||||
auto& mat = is_b_matrix ? b : a;
|
||||
auto& vec = is_b_matrix ? a : b;
|
||||
bool transpose_mat = is_b_matrix ? !transpose_b : transpose_a;
|
||||
int in_vector_len = K;
|
||||
int out_vector_len = is_b_matrix ? N : M;
|
||||
|
||||
int mat_cols = transpose_mat ? out_vector_len : in_vector_len;
|
||||
int mat_rows = transpose_mat ? in_vector_len : out_vector_len;
|
||||
int mat_ld = is_b_matrix ? b_cols : a_cols;
|
||||
|
||||
auto batch_strides_mat = is_b_matrix ? B_batch_stride : A_batch_stride;
|
||||
auto batch_strides_vec = is_b_matrix ? A_batch_stride : B_batch_stride;
|
||||
|
||||
int stride_mat = batch_strides_mat.back();
|
||||
int stride_vec = batch_strides_vec.back();
|
||||
|
||||
// Determine if inputs have simple batching / broadcasting
|
||||
bool contiguous_kernel = (batch_shape.size() == 1);
|
||||
|
||||
int batch_ndim = batch_shape.size();
|
||||
|
||||
// Determine dispatch kernel
|
||||
int tm = 4, tn = 4;
|
||||
int sm = 1, sn = 32;
|
||||
int bm = 1, bn = 1;
|
||||
int n_out_per_tgp;
|
||||
std::ostringstream kname;
|
||||
|
||||
if (transpose_mat) {
|
||||
if (in_vector_len >= 8192 && out_vector_len >= 2048) {
|
||||
sm = 4;
|
||||
sn = 8;
|
||||
} else {
|
||||
sm = 8;
|
||||
sn = 4;
|
||||
}
|
||||
|
||||
if (out_vector_len >= 2048) {
|
||||
bn = 16;
|
||||
} else if (out_vector_len >= 512) {
|
||||
bn = 4;
|
||||
} else {
|
||||
bn = 2;
|
||||
}
|
||||
|
||||
// Specialized kernel for very small outputs
|
||||
tn = out_vector_len < tn ? 1 : tn;
|
||||
|
||||
n_out_per_tgp = bn * sn * tn;
|
||||
kname << "gemv_t_" << type_to_name(out);
|
||||
|
||||
} else {
|
||||
bm = out_vector_len >= 4096 ? 8 : 4;
|
||||
sn = 32;
|
||||
|
||||
// Specialized kernel for very small outputs
|
||||
tm = out_vector_len < tm ? 1 : tm;
|
||||
|
||||
n_out_per_tgp = bm * sm * tm;
|
||||
kname << "gemv_" << type_to_name(out);
|
||||
}
|
||||
|
||||
kname << "_bm" << bm << "_bn" << bn << "_sm" << sm << "_sn" << sn << "_tm"
|
||||
<< tm << "_tn" << tn;
|
||||
kname << "_nc" << !contiguous_kernel << "_axpby1";
|
||||
|
||||
// Encode and dispatch kernel
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp;
|
||||
MTL::Size group_dims = MTL::Size(32, bn, bm);
|
||||
MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out);
|
||||
|
||||
compute_encoder.set_input_array(mat, 0);
|
||||
compute_encoder.set_input_array(vec, 1);
|
||||
compute_encoder.set_input_array(c, 2);
|
||||
compute_encoder.set_output_array(out, 3);
|
||||
|
||||
compute_encoder.set_bytes(in_vector_len, 4);
|
||||
compute_encoder.set_bytes(out_vector_len, 5);
|
||||
compute_encoder.set_bytes(mat_ld, 6);
|
||||
|
||||
compute_encoder.set_bytes(alpha_, 7);
|
||||
compute_encoder.set_bytes(beta_, 8);
|
||||
|
||||
compute_encoder.set_bytes(batch_ndim, 9);
|
||||
compute_encoder.set_vector_bytes(batch_shape, 10);
|
||||
compute_encoder.set_vector_bytes(batch_strides_vec, 11);
|
||||
compute_encoder.set_vector_bytes(batch_strides_mat, 12);
|
||||
compute_encoder.set_vector_bytes(C_batch_stride, 13);
|
||||
|
||||
int bias_stride = c.strides()[c.ndim() - 1];
|
||||
compute_encoder.set_bytes(bias_stride, 14);
|
||||
|
||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||
|
||||
d.add_temporaries(std::move(copies), s.index);
|
||||
return;
|
||||
return gemv_axbpy(
|
||||
/* const Stream& s = */ s,
|
||||
/* metal::Device& d = */ d,
|
||||
/* const array& a = */ a,
|
||||
/* const array& b = */ b,
|
||||
/* const array& c = */ c,
|
||||
/* array& out = */ out,
|
||||
/* int M = */ M,
|
||||
/* int N = */ N,
|
||||
/* int K = */ K,
|
||||
/* int batch_size_out = */ batch_size_out,
|
||||
/* int lda = */ lda,
|
||||
/* int ldb = */ ldb,
|
||||
/* bool transpose_a = */ transpose_a,
|
||||
/* bool transpose_b = */ transpose_b,
|
||||
/* std::vector<array>& copies = */ copies,
|
||||
/* Shape batch_shape = */ batch_shape,
|
||||
/* Strides A_batch_stride = */ A_batch_stride,
|
||||
/* Strides B_batch_stride = */ B_batch_stride,
|
||||
/* Strides C_batch_stride = */ C_batch_stride,
|
||||
/* float alpha = */ alpha_,
|
||||
/* float beta = */ beta_);
|
||||
}
|
||||
|
||||
using namespace mlx::steel;
|
||||
|
Loading…
Reference in New Issue
Block a user