mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Route to gather qmm only for many tokens per expert (#2082)
This commit is contained in:
parent
5de6d94a90
commit
3cde719eb7
@ -850,14 +850,14 @@ void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
int M = x.shape(-2);
|
int M = x.shape(-2);
|
||||||
int N = out.shape(-1);
|
int N = out.shape(-1);
|
||||||
int B = out.size() / M / N;
|
int B = out.size() / M / N;
|
||||||
|
int E = w.size() / w.shape(-1) / w.shape(-2);
|
||||||
int vector_limit = transpose_ ? get_qmv_batch_limit(K, N, d) : 4;
|
int vector_limit = transpose_ ? get_qmv_batch_limit(K, N, d) : 4;
|
||||||
|
|
||||||
// We are walking x in order and w is also in order so we can batch up the
|
// We are walking x in order and w is also in order so we can batch up the
|
||||||
// matmuls and reuse reading x and w.
|
// matmuls and reuse reading x and w.
|
||||||
//
|
//
|
||||||
// TODO: Tune 16 here a bit better. Maybe also choose it dynamically based
|
// TODO: Tune 16 and 8 here a bit better.
|
||||||
// on B and (w.size() / K / N).
|
if (M == 1 && B >= 16 && right_sorted_ == true && B / E >= 8) {
|
||||||
if (M == 1 && B >= 16 && right_sorted_ == true) {
|
|
||||||
gather_qmm_rhs(
|
gather_qmm_rhs(
|
||||||
x,
|
x,
|
||||||
w,
|
w,
|
||||||
|
Loading…
Reference in New Issue
Block a user