mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Route to gather qmm only for many tokens per expert (#2082)
This commit is contained in:
parent
5de6d94a90
commit
3cde719eb7
@ -850,14 +850,14 @@ void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
int M = x.shape(-2);
|
||||
int N = out.shape(-1);
|
||||
int B = out.size() / M / N;
|
||||
int E = w.size() / w.shape(-1) / w.shape(-2);
|
||||
int vector_limit = transpose_ ? get_qmv_batch_limit(K, N, d) : 4;
|
||||
|
||||
// We are walking x in order and w is also in order so we can batch up the
|
||||
// matmuls and reuse reading x and w.
|
||||
//
|
||||
// TODO: Tune 16 here a bit better. Maybe also choose it dynamically based
|
||||
// on B and (w.size() / K / N).
|
||||
if (M == 1 && B >= 16 && right_sorted_ == true) {
|
||||
// TODO: Tune 16 and 8 here a bit better.
|
||||
if (M == 1 && B >= 16 && right_sorted_ == true && B / E >= 8) {
|
||||
gather_qmm_rhs(
|
||||
x,
|
||||
w,
|
||||
|
Loading…
Reference in New Issue
Block a user