Fix argpartition call in Mixtral and other MOES (#676)

* Update mixtral.py

* fix all moes

---------

Co-authored-by: yuhai-china <yuhai.china@gmail.com>
This commit is contained in:
Awni Hannun 2024-04-12 11:00:56 -07:00 committed by GitHub
parent 9c5554d8ee
commit d3f8e4aee9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 4 additions and 6 deletions

View File

@ -103,7 +103,7 @@ class MOEFeedForward(nn.Module):
x = x.reshape(-1, x.shape[-1]) x = x.reshape(-1, x.shape[-1])
gates = self.gate(x) gates = self.gate(x)
inds = mx.argpartition(-gates, kth=ne, axis=-1)[:, :ne] inds = mx.argpartition(-gates, kth=ne - 1, axis=-1)[:, :ne]
scores = mx.softmax( scores = mx.softmax(
mx.take_along_axis(gates, inds, axis=-1).astype(mx.float32), mx.take_along_axis(gates, inds, axis=-1).astype(mx.float32),
axis=-1, axis=-1,

View File

@ -143,7 +143,7 @@ class SparseMoeBlock(nn.Module):
gates = self.router(x) gates = self.router(x)
gates = mx.softmax(gates.astype(mx.float32), axis=-1) gates = mx.softmax(gates.astype(mx.float32), axis=-1)
inds = mx.stop_gradient(mx.argpartition(-gates, kth=ne, axis=-1)[:, :ne]) inds = mx.stop_gradient(mx.argpartition(-gates, kth=ne - 1, axis=-1)[:, :ne])
scores = mx.take_along_axis(gates, inds, axis=-1) scores = mx.take_along_axis(gates, inds, axis=-1)
scores = scores / mx.linalg.norm(scores, ord=1, axis=-1, keepdims=True) scores = scores / mx.linalg.norm(scores, ord=1, axis=-1, keepdims=True)
scores = scores.astype(x.dtype) scores = scores.astype(x.dtype)

View File

@ -133,9 +133,7 @@ class MixtralSparseMoeBlock(nn.Module):
gates = self.gate(x) gates = self.gate(x)
inds = mx.stop_gradient( inds = mx.stop_gradient(mx.argpartition(-gates, kth=ne - 1, axis=-1)[:, :ne])
mx.argpartition(-gates, kth=ne, axis=-1)[:, :ne]
) # TODO remove it once we figure out how to fine tune TopK in MOE
scores = mx.softmax( scores = mx.softmax(
mx.take_along_axis(gates, inds, axis=-1).astype(mx.float32), mx.take_along_axis(gates, inds, axis=-1).astype(mx.float32),

View File

@ -106,7 +106,7 @@ class MOE(nn.Module):
x = x.reshape(-1, x.shape[-1]) x = x.reshape(-1, x.shape[-1])
gates = self.gate(x) gates = self.gate(x)
inds = mx.stop_gradient(mx.argpartition(-gates, kth=ne, axis=-1))[:, :ne] inds = mx.stop_gradient(mx.argpartition(-gates, kth=ne - 1, axis=-1))[:, :ne]
scores = mx.softmax( scores = mx.softmax(
mx.take_along_axis(gates, inds, axis=-1).astype(mx.float32), mx.take_along_axis(gates, inds, axis=-1).astype(mx.float32),
axis=-1, axis=-1,