Fix argpartition call in Mixtral and other MOES (#676)

* Update mixtral.py

* fix all moes

---------

Co-authored-by: yuhai-china <yuhai.china@gmail.com>
This commit is contained in:
Awni Hannun
2024-04-12 11:00:56 -07:00
committed by GitHub
parent 9c5554d8ee
commit d3f8e4aee9
4 changed files with 4 additions and 6 deletions

View File

@@ -133,9 +133,7 @@ class MixtralSparseMoeBlock(nn.Module):
gates = self.gate(x)
inds = mx.stop_gradient(
mx.argpartition(-gates, kth=ne, axis=-1)[:, :ne]
) # TODO remove it once we figure out how to fine tune TopK in MOE
inds = mx.stop_gradient(mx.argpartition(-gates, kth=ne - 1, axis=-1)[:, :ne])
scores = mx.softmax(
mx.take_along_axis(gates, inds, axis=-1).astype(mx.float32),