mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 09:51:17 +08:00
Refactor gemv into a function
This commit is contained in:
parent
c8b4787e4e
commit
0c69f10d55
@ -470,6 +470,178 @@ void steel_matmul(
|
|||||||
copies);
|
copies);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <bool CHECK_AB = true>
|
||||||
|
void gemv_axbpy(
|
||||||
|
const Stream& s,
|
||||||
|
metal::Device& d,
|
||||||
|
const array& a,
|
||||||
|
const array& b,
|
||||||
|
const array& c,
|
||||||
|
array& out,
|
||||||
|
int M,
|
||||||
|
int N,
|
||||||
|
int K,
|
||||||
|
int batch_size_out,
|
||||||
|
int lda,
|
||||||
|
int ldb,
|
||||||
|
bool transpose_a,
|
||||||
|
bool transpose_b,
|
||||||
|
std::vector<array>& copies,
|
||||||
|
Shape batch_shape = {},
|
||||||
|
Strides A_batch_stride = {},
|
||||||
|
Strides B_batch_stride = {},
|
||||||
|
Strides C_batch_stride = {},
|
||||||
|
float alpha = 1.0f,
|
||||||
|
float beta = 0.0f) {
|
||||||
|
// Collect problem info
|
||||||
|
bool is_b_matrix = N != 1;
|
||||||
|
|
||||||
|
auto& mat = is_b_matrix ? b : a;
|
||||||
|
auto& vec = is_b_matrix ? a : b;
|
||||||
|
bool transpose_mat = is_b_matrix ? !transpose_b : transpose_a;
|
||||||
|
int in_vector_len = K;
|
||||||
|
int out_vector_len = is_b_matrix ? N : M;
|
||||||
|
|
||||||
|
int mat_cols = transpose_mat ? out_vector_len : in_vector_len;
|
||||||
|
int mat_rows = transpose_mat ? in_vector_len : out_vector_len;
|
||||||
|
int mat_ld = is_b_matrix ? ldb : lda;
|
||||||
|
|
||||||
|
auto batch_strides_mat = is_b_matrix ? B_batch_stride : A_batch_stride;
|
||||||
|
auto batch_strides_vec = is_b_matrix ? A_batch_stride : B_batch_stride;
|
||||||
|
|
||||||
|
int stride_mat = batch_strides_mat.back();
|
||||||
|
int stride_vec = batch_strides_vec.back();
|
||||||
|
|
||||||
|
// Determine if inputs have simple batching / broadcasting
|
||||||
|
bool contiguous_kernel = (batch_shape.size() == 1);
|
||||||
|
|
||||||
|
int batch_ndim = batch_shape.size();
|
||||||
|
|
||||||
|
// Determine dispatch kernel
|
||||||
|
int tm = 4, tn = 4;
|
||||||
|
int sm = 1, sn = 32;
|
||||||
|
int bm = 1, bn = 1;
|
||||||
|
int n_out_per_tgp;
|
||||||
|
std::ostringstream kname;
|
||||||
|
|
||||||
|
if (transpose_mat) {
|
||||||
|
if (in_vector_len >= 8192 && out_vector_len >= 2048) {
|
||||||
|
sm = 4;
|
||||||
|
sn = 8;
|
||||||
|
} else {
|
||||||
|
sm = 8;
|
||||||
|
sn = 4;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (out_vector_len >= 2048) {
|
||||||
|
bn = 16;
|
||||||
|
} else if (out_vector_len >= 512) {
|
||||||
|
bn = 4;
|
||||||
|
} else {
|
||||||
|
bn = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Specialized kernel for very small outputs
|
||||||
|
tn = out_vector_len < tn ? 1 : tn;
|
||||||
|
|
||||||
|
n_out_per_tgp = bn * sn * tn;
|
||||||
|
kname << "gemv_t_" << type_to_name(out);
|
||||||
|
|
||||||
|
} else {
|
||||||
|
bm = out_vector_len >= 4096 ? 8 : 4;
|
||||||
|
sn = 32;
|
||||||
|
|
||||||
|
// Specialized kernel for very small outputs
|
||||||
|
tm = out_vector_len < tm ? 1 : tm;
|
||||||
|
|
||||||
|
n_out_per_tgp = bm * sm * tm;
|
||||||
|
kname << "gemv_" << type_to_name(out);
|
||||||
|
}
|
||||||
|
|
||||||
|
const bool do_axpby = CHECK_AB && (alpha != 1.0f || beta != 0.0f);
|
||||||
|
|
||||||
|
kname << "_bm" << bm << "_bn" << bn << "_sm" << sm << "_sn" << sn << "_tm"
|
||||||
|
<< tm << "_tn" << tn;
|
||||||
|
kname << "_nc" << !contiguous_kernel << "_axpby" << do_axpby;
|
||||||
|
|
||||||
|
// Encode and dispatch kernel
|
||||||
|
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||||
|
auto kernel = d.get_kernel(kname.str());
|
||||||
|
compute_encoder.set_compute_pipeline_state(kernel);
|
||||||
|
|
||||||
|
int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp;
|
||||||
|
MTL::Size group_dims = MTL::Size(32, bn, bm);
|
||||||
|
MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out);
|
||||||
|
|
||||||
|
compute_encoder.set_input_array(mat, 0);
|
||||||
|
compute_encoder.set_input_array(vec, 1);
|
||||||
|
compute_encoder.set_output_array(out, 3);
|
||||||
|
|
||||||
|
compute_encoder.set_bytes(in_vector_len, 4);
|
||||||
|
compute_encoder.set_bytes(out_vector_len, 5);
|
||||||
|
compute_encoder.set_bytes(mat_ld, 6);
|
||||||
|
|
||||||
|
compute_encoder.set_bytes(batch_ndim, 9);
|
||||||
|
compute_encoder.set_vector_bytes(batch_shape, 10);
|
||||||
|
compute_encoder.set_vector_bytes(batch_strides_vec, 11);
|
||||||
|
compute_encoder.set_vector_bytes(batch_strides_mat, 12);
|
||||||
|
|
||||||
|
if (do_axpby) {
|
||||||
|
compute_encoder.set_input_array(c, 2);
|
||||||
|
|
||||||
|
compute_encoder.set_bytes(alpha, 7);
|
||||||
|
compute_encoder.set_bytes(beta, 8);
|
||||||
|
|
||||||
|
compute_encoder.set_vector_bytes(C_batch_stride, 13);
|
||||||
|
|
||||||
|
int bias_stride = c.strides()[c.ndim() - 1];
|
||||||
|
compute_encoder.set_bytes(bias_stride, 14);
|
||||||
|
}
|
||||||
|
|
||||||
|
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||||
|
|
||||||
|
d.add_temporaries(std::move(copies), s.index);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void gemv(
|
||||||
|
const Stream& s,
|
||||||
|
metal::Device& d,
|
||||||
|
const array& a,
|
||||||
|
const array& b,
|
||||||
|
array& out,
|
||||||
|
int M,
|
||||||
|
int N,
|
||||||
|
int K,
|
||||||
|
int batch_size_out,
|
||||||
|
int lda,
|
||||||
|
int ldb,
|
||||||
|
bool transpose_a,
|
||||||
|
bool transpose_b,
|
||||||
|
std::vector<array>& copies,
|
||||||
|
Shape batch_shape = {},
|
||||||
|
Strides A_batch_stride = {},
|
||||||
|
Strides B_batch_stride = {}) {
|
||||||
|
return gemv_axbpy<false>(
|
||||||
|
/* const Stream& s = */ s,
|
||||||
|
/* metal::Device& d = */ d,
|
||||||
|
/* const array& a = */ a,
|
||||||
|
/* const array& b = */ b,
|
||||||
|
/* const array& c = */ b,
|
||||||
|
/* array& out = */ out,
|
||||||
|
/* int M = */ M,
|
||||||
|
/* int N = */ N,
|
||||||
|
/* int K = */ K,
|
||||||
|
/* int batch_size_out = */ batch_size_out,
|
||||||
|
/* int lda = */ lda,
|
||||||
|
/* int ldb = */ ldb,
|
||||||
|
/* bool transpose_a = */ transpose_a,
|
||||||
|
/* bool transpose_b = */ transpose_b,
|
||||||
|
/* std::vector<array>& copies = */ copies,
|
||||||
|
/* Shape batch_shape = */ batch_shape,
|
||||||
|
/* Strides A_batch_stride = */ A_batch_stride,
|
||||||
|
/* Strides B_batch_stride = */ B_batch_stride);
|
||||||
|
}
|
||||||
|
|
||||||
void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 2);
|
assert(inputs.size() == 2);
|
||||||
if (!issubdtype(out.dtype(), floating)) {
|
if (!issubdtype(out.dtype(), floating)) {
|
||||||
@ -528,102 +700,26 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
|
|
||||||
// Route to gemv if needed
|
// Route to gemv if needed
|
||||||
if (std::min(M, N) == 1) {
|
if (std::min(M, N) == 1) {
|
||||||
// Collect problem info
|
return gemv(
|
||||||
bool is_b_matrix = N != 1;
|
/* const Stream& s = */ s,
|
||||||
|
/* metal::Device& d = */ d,
|
||||||
auto& mat = is_b_matrix ? b : a;
|
/* const array& a = */ a,
|
||||||
auto& vec = is_b_matrix ? a : b;
|
/* const array& b = */ b,
|
||||||
bool transpose_mat = is_b_matrix ? !b_transposed : a_transposed;
|
/* array& out = */ out,
|
||||||
int in_vector_len = K;
|
/* int M = */ M,
|
||||||
int out_vector_len = is_b_matrix ? N : M;
|
/* int N = */ N,
|
||||||
|
/* int K = */ K,
|
||||||
int mat_cols = transpose_mat ? out_vector_len : in_vector_len;
|
/* int batch_size_out = */ batch_size_out,
|
||||||
int mat_rows = transpose_mat ? in_vector_len : out_vector_len;
|
/* int lda = */ a_cols,
|
||||||
int mat_ld = is_b_matrix ? b_cols : a_cols;
|
/* int ldb = */ b_cols,
|
||||||
|
/* bool transpose_a = */ a_transposed,
|
||||||
auto batch_strides_mat = is_b_matrix ? B_batch_stride : A_batch_stride;
|
/* bool transpose_b = */ b_transposed,
|
||||||
auto batch_strides_vec = is_b_matrix ? A_batch_stride : B_batch_stride;
|
/* std::vector<array>& copies = */ copies,
|
||||||
|
/* Shape batch_shape = */ batch_shape,
|
||||||
int stride_mat = batch_strides_mat.back();
|
/* Strides A_batch_stride = */ A_batch_stride,
|
||||||
int stride_vec = batch_strides_vec.back();
|
/* Strides B_batch_stride = */ B_batch_stride);
|
||||||
|
|
||||||
// Determine if inputs have simple batching / broadcasting
|
|
||||||
bool contiguous_kernel = (batch_shape.size() == 1);
|
|
||||||
|
|
||||||
int batch_ndim = batch_shape.size();
|
|
||||||
|
|
||||||
// Determine dispatch kernel
|
|
||||||
int tm = 4, tn = 4;
|
|
||||||
int sm = 1, sn = 32;
|
|
||||||
int bm = 1, bn = 1;
|
|
||||||
int n_out_per_tgp;
|
|
||||||
std::ostringstream kname;
|
|
||||||
|
|
||||||
if (transpose_mat) {
|
|
||||||
if (in_vector_len >= 8192 && out_vector_len >= 2048) {
|
|
||||||
sm = 4;
|
|
||||||
sn = 8;
|
|
||||||
} else {
|
|
||||||
sm = 8;
|
|
||||||
sn = 4;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (out_vector_len >= 2048) {
|
|
||||||
bn = 16;
|
|
||||||
} else if (out_vector_len >= 512) {
|
|
||||||
bn = 4;
|
|
||||||
} else {
|
|
||||||
bn = 2;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Specialized kernel for very small outputs
|
|
||||||
tn = out_vector_len < tn ? 1 : tn;
|
|
||||||
|
|
||||||
n_out_per_tgp = bn * sn * tn;
|
|
||||||
kname << "gemv_t_" << type_to_name(out);
|
|
||||||
|
|
||||||
} else {
|
|
||||||
bm = out_vector_len >= 4096 ? 8 : 4;
|
|
||||||
sn = 32;
|
|
||||||
|
|
||||||
// Specialized kernel for very small outputs
|
|
||||||
tm = out_vector_len < tm ? 1 : tm;
|
|
||||||
|
|
||||||
n_out_per_tgp = bm * sm * tm;
|
|
||||||
kname << "gemv_" << type_to_name(out);
|
|
||||||
}
|
|
||||||
|
|
||||||
kname << "_bm" << bm << "_bn" << bn << "_sm" << sm << "_sn" << sn << "_tm"
|
|
||||||
<< tm << "_tn" << tn;
|
|
||||||
kname << "_nc" << !contiguous_kernel << "_axpby0";
|
|
||||||
|
|
||||||
// Encode and dispatch kernel
|
|
||||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
|
||||||
auto kernel = d.get_kernel(kname.str());
|
|
||||||
compute_encoder.set_compute_pipeline_state(kernel);
|
|
||||||
|
|
||||||
int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp;
|
|
||||||
MTL::Size group_dims = MTL::Size(32, bn, bm);
|
|
||||||
MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out);
|
|
||||||
|
|
||||||
compute_encoder.set_input_array(mat, 0);
|
|
||||||
compute_encoder.set_input_array(vec, 1);
|
|
||||||
compute_encoder.set_output_array(out, 3);
|
|
||||||
|
|
||||||
compute_encoder.set_bytes(in_vector_len, 4);
|
|
||||||
compute_encoder.set_bytes(out_vector_len, 5);
|
|
||||||
compute_encoder.set_bytes(mat_ld, 6);
|
|
||||||
|
|
||||||
compute_encoder.set_bytes(batch_ndim, 9);
|
|
||||||
compute_encoder.set_vector_bytes(batch_shape, 10);
|
|
||||||
compute_encoder.set_vector_bytes(batch_strides_vec, 11);
|
|
||||||
compute_encoder.set_vector_bytes(batch_strides_mat, 12);
|
|
||||||
|
|
||||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
|
||||||
|
|
||||||
d.add_temporaries(std::move(copies), s.index);
|
|
||||||
return;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////////////////
|
||||||
// Gemm specialization
|
// Gemm specialization
|
||||||
|
|
||||||
@ -641,7 +737,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
/* int ldb = */ b_cols,
|
/* int ldb = */ b_cols,
|
||||||
/* bool transpose_a = */ a_transposed,
|
/* bool transpose_a = */ a_transposed,
|
||||||
/* bool transpose_b = */ b_transposed,
|
/* bool transpose_b = */ b_transposed,
|
||||||
/* std::vector<array>& = */ copies,
|
/* std::vector<array>& copies = */ copies,
|
||||||
/* Shape batch_shape = */ batch_shape,
|
/* Shape batch_shape = */ batch_shape,
|
||||||
/* Strides A_batch_stride = */ A_batch_stride,
|
/* Strides A_batch_stride = */ A_batch_stride,
|
||||||
/* Strides B_batch_stride = */ B_batch_stride);
|
/* Strides B_batch_stride = */ B_batch_stride);
|
||||||
@ -726,109 +822,28 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
|
|
||||||
// Route to gemv if needed
|
// Route to gemv if needed
|
||||||
if (std::min(M, N) == 1) {
|
if (std::min(M, N) == 1) {
|
||||||
// Collect problem info
|
return gemv_axbpy(
|
||||||
bool is_b_matrix = N != 1;
|
/* const Stream& s = */ s,
|
||||||
|
/* metal::Device& d = */ d,
|
||||||
auto& mat = is_b_matrix ? b : a;
|
/* const array& a = */ a,
|
||||||
auto& vec = is_b_matrix ? a : b;
|
/* const array& b = */ b,
|
||||||
bool transpose_mat = is_b_matrix ? !transpose_b : transpose_a;
|
/* const array& c = */ c,
|
||||||
int in_vector_len = K;
|
/* array& out = */ out,
|
||||||
int out_vector_len = is_b_matrix ? N : M;
|
/* int M = */ M,
|
||||||
|
/* int N = */ N,
|
||||||
int mat_cols = transpose_mat ? out_vector_len : in_vector_len;
|
/* int K = */ K,
|
||||||
int mat_rows = transpose_mat ? in_vector_len : out_vector_len;
|
/* int batch_size_out = */ batch_size_out,
|
||||||
int mat_ld = is_b_matrix ? b_cols : a_cols;
|
/* int lda = */ lda,
|
||||||
|
/* int ldb = */ ldb,
|
||||||
auto batch_strides_mat = is_b_matrix ? B_batch_stride : A_batch_stride;
|
/* bool transpose_a = */ transpose_a,
|
||||||
auto batch_strides_vec = is_b_matrix ? A_batch_stride : B_batch_stride;
|
/* bool transpose_b = */ transpose_b,
|
||||||
|
/* std::vector<array>& copies = */ copies,
|
||||||
int stride_mat = batch_strides_mat.back();
|
/* Shape batch_shape = */ batch_shape,
|
||||||
int stride_vec = batch_strides_vec.back();
|
/* Strides A_batch_stride = */ A_batch_stride,
|
||||||
|
/* Strides B_batch_stride = */ B_batch_stride,
|
||||||
// Determine if inputs have simple batching / broadcasting
|
/* Strides C_batch_stride = */ C_batch_stride,
|
||||||
bool contiguous_kernel = (batch_shape.size() == 1);
|
/* float alpha = */ alpha_,
|
||||||
|
/* float beta = */ beta_);
|
||||||
int batch_ndim = batch_shape.size();
|
|
||||||
|
|
||||||
// Determine dispatch kernel
|
|
||||||
int tm = 4, tn = 4;
|
|
||||||
int sm = 1, sn = 32;
|
|
||||||
int bm = 1, bn = 1;
|
|
||||||
int n_out_per_tgp;
|
|
||||||
std::ostringstream kname;
|
|
||||||
|
|
||||||
if (transpose_mat) {
|
|
||||||
if (in_vector_len >= 8192 && out_vector_len >= 2048) {
|
|
||||||
sm = 4;
|
|
||||||
sn = 8;
|
|
||||||
} else {
|
|
||||||
sm = 8;
|
|
||||||
sn = 4;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (out_vector_len >= 2048) {
|
|
||||||
bn = 16;
|
|
||||||
} else if (out_vector_len >= 512) {
|
|
||||||
bn = 4;
|
|
||||||
} else {
|
|
||||||
bn = 2;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Specialized kernel for very small outputs
|
|
||||||
tn = out_vector_len < tn ? 1 : tn;
|
|
||||||
|
|
||||||
n_out_per_tgp = bn * sn * tn;
|
|
||||||
kname << "gemv_t_" << type_to_name(out);
|
|
||||||
|
|
||||||
} else {
|
|
||||||
bm = out_vector_len >= 4096 ? 8 : 4;
|
|
||||||
sn = 32;
|
|
||||||
|
|
||||||
// Specialized kernel for very small outputs
|
|
||||||
tm = out_vector_len < tm ? 1 : tm;
|
|
||||||
|
|
||||||
n_out_per_tgp = bm * sm * tm;
|
|
||||||
kname << "gemv_" << type_to_name(out);
|
|
||||||
}
|
|
||||||
|
|
||||||
kname << "_bm" << bm << "_bn" << bn << "_sm" << sm << "_sn" << sn << "_tm"
|
|
||||||
<< tm << "_tn" << tn;
|
|
||||||
kname << "_nc" << !contiguous_kernel << "_axpby1";
|
|
||||||
|
|
||||||
// Encode and dispatch kernel
|
|
||||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
|
||||||
auto kernel = d.get_kernel(kname.str());
|
|
||||||
compute_encoder.set_compute_pipeline_state(kernel);
|
|
||||||
|
|
||||||
int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp;
|
|
||||||
MTL::Size group_dims = MTL::Size(32, bn, bm);
|
|
||||||
MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out);
|
|
||||||
|
|
||||||
compute_encoder.set_input_array(mat, 0);
|
|
||||||
compute_encoder.set_input_array(vec, 1);
|
|
||||||
compute_encoder.set_input_array(c, 2);
|
|
||||||
compute_encoder.set_output_array(out, 3);
|
|
||||||
|
|
||||||
compute_encoder.set_bytes(in_vector_len, 4);
|
|
||||||
compute_encoder.set_bytes(out_vector_len, 5);
|
|
||||||
compute_encoder.set_bytes(mat_ld, 6);
|
|
||||||
|
|
||||||
compute_encoder.set_bytes(alpha_, 7);
|
|
||||||
compute_encoder.set_bytes(beta_, 8);
|
|
||||||
|
|
||||||
compute_encoder.set_bytes(batch_ndim, 9);
|
|
||||||
compute_encoder.set_vector_bytes(batch_shape, 10);
|
|
||||||
compute_encoder.set_vector_bytes(batch_strides_vec, 11);
|
|
||||||
compute_encoder.set_vector_bytes(batch_strides_mat, 12);
|
|
||||||
compute_encoder.set_vector_bytes(C_batch_stride, 13);
|
|
||||||
|
|
||||||
int bias_stride = c.strides()[c.ndim() - 1];
|
|
||||||
compute_encoder.set_bytes(bias_stride, 14);
|
|
||||||
|
|
||||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
|
||||||
|
|
||||||
d.add_temporaries(std::move(copies), s.index);
|
|
||||||
return;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
using namespace mlx::steel;
|
using namespace mlx::steel;
|
||||||
|
Loading…
Reference in New Issue
Block a user