Route to gather qmm only for many tokens per expert (#2082)

This commit is contained in:
Angelos Katharopoulos 2025-04-17 14:53:08 -07:00 committed by GitHub
parent 5de6d94a90
commit 3cde719eb7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -850,14 +850,14 @@ void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
int M = x.shape(-2);
int N = out.shape(-1);
int B = out.size() / M / N;
int E = w.size() / w.shape(-1) / w.shape(-2);
int vector_limit = transpose_ ? get_qmv_batch_limit(K, N, d) : 4;
// We are walking x in order and w is also in order so we can batch up the
// matmuls and reuse reading x and w.
//
// TODO: Tune 16 here a bit better. Maybe also choose it dynamically based
// on B and (w.size() / K / N).
if (M == 1 && B >= 16 && right_sorted_ == true) {
// TODO: Tune 16 and 8 here a bit better.
if (M == 1 && B >= 16 && right_sorted_ == true && B / E >= 8) {
gather_qmm_rhs(
x,
w,