diff --git a/llms/mixtral/mixtral.py b/llms/mixtral/mixtral.py
index d5e07926..a40f33e0 100644
--- a/llms/mixtral/mixtral.py
+++ b/llms/mixtral/mixtral.py
@@ -103,7 +103,7 @@ class MOEFeedForward(nn.Module):
         x = x.reshape(-1, x.shape[-1])
 
         gates = self.gate(x)
-        inds = mx.argpartition(-gates, kth=ne, axis=-1)[:, :ne]
+        inds = mx.argpartition(-gates, kth=ne - 1, axis=-1)[:, :ne]
         scores = mx.softmax(
             mx.take_along_axis(gates, inds, axis=-1).astype(mx.float32),
             axis=-1,
diff --git a/llms/mlx_lm/models/dbrx.py b/llms/mlx_lm/models/dbrx.py
index d0ac2e16..e2b362bb 100644
--- a/llms/mlx_lm/models/dbrx.py
+++ b/llms/mlx_lm/models/dbrx.py
@@ -143,7 +143,7 @@ class SparseMoeBlock(nn.Module):
         gates = self.router(x)
         gates = mx.softmax(gates.astype(mx.float32), axis=-1)
 
-        inds = mx.stop_gradient(mx.argpartition(-gates, kth=ne, axis=-1)[:, :ne])
+        inds = mx.stop_gradient(mx.argpartition(-gates, kth=ne - 1, axis=-1)[:, :ne])
         scores = mx.take_along_axis(gates, inds, axis=-1)
         scores = scores / mx.linalg.norm(scores, ord=1, axis=-1, keepdims=True)
         scores = scores.astype(x.dtype)
diff --git a/llms/mlx_lm/models/mixtral.py b/llms/mlx_lm/models/mixtral.py
index 6e8b7324..d11a7507 100644
--- a/llms/mlx_lm/models/mixtral.py
+++ b/llms/mlx_lm/models/mixtral.py
@@ -133,9 +133,7 @@ class MixtralSparseMoeBlock(nn.Module):
 
         gates = self.gate(x)
 
-        inds = mx.stop_gradient(
-            mx.argpartition(-gates, kth=ne, axis=-1)[:, :ne]
-        )  # TODO remove it once we figure out how to fine tune TopK in MOE
+        inds = mx.stop_gradient(mx.argpartition(-gates, kth=ne - 1, axis=-1)[:, :ne])
         scores = mx.softmax(
             mx.take_along_axis(gates, inds, axis=-1).astype(mx.float32),
             axis=-1,
diff --git a/llms/mlx_lm/models/phixtral.py b/llms/mlx_lm/models/phixtral.py
index 341e8984..fa4a24a4 100644
--- a/llms/mlx_lm/models/phixtral.py
+++ b/llms/mlx_lm/models/phixtral.py
@@ -106,7 +106,7 @@ class MOE(nn.Module):
         x = x.reshape(-1, x.shape[-1])
 
         gates = self.gate(x)
-        inds = mx.stop_gradient(mx.argpartition(-gates, kth=ne, axis=-1))[:, :ne]
+        inds = mx.stop_gradient(mx.argpartition(-gates, kth=ne - 1, axis=-1))[:, :ne]
         scores = mx.softmax(
             mx.take_along_axis(gates, inds, axis=-1).astype(mx.float32),
             axis=-1,
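All four hunks make the same one-token fix: kth=ne becomes kth=ne - 1. mx.argpartition places the element at index kth in its sorted position, with everything before it no larger, so on the negated gates kth=ne - 1 is the smallest value that guarantees the first ne columns hold the top-ne experts; kth=ne would also fall out of range whenever ne equals the total number of experts. Below is a minimal standalone sketch of the expert-selection step, not part of the patch; the gate values are invented for illustration.

import mlx.core as mx

# Fake router logits for two tokens and four experts (illustrative only).
gates = mx.array([[0.10, 0.70, 0.05, 0.15],
                  [0.40, 0.10, 0.30, 0.20]])
ne = 2  # experts routed per token

# Negate so "largest ne gates" becomes "smallest ne entries"; kth=ne - 1
# pins the ne-th smallest in place, so [:, :ne] is exactly the top-ne set.
inds = mx.argpartition(-gates, kth=ne - 1, axis=-1)[:, :ne]
scores = mx.softmax(mx.take_along_axis(gates, inds, axis=-1), axis=-1)

print(inds)    # e.g. [[1, 3], [0, 2]]; order within the top-ne is unspecified
print(scores)  # routing weights renormalized over the selected experts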