mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 04:14:38 +08:00
DBRX (#628)
* dbrx * format * format * comments * change scores slightly * remove inadvertant import
This commit is contained in:
@@ -125,7 +125,7 @@ class MOE(nn.Module):
|
||||
else:
|
||||
y = []
|
||||
for xt, st, it in zip(x, scores, inds.tolist()):
|
||||
yt = mx.concatenate([self.mlp[e](xt)[:, None] for e in it], axis=-1)
|
||||
yt = mx.stack([self.mlp[e](xt) for e in it], axis=-1)
|
||||
yt = (yt * st).sum(axis=-1)
|
||||
y.append(yt[None, :])
|
||||
y = mx.concatenate(y)
|
||||
|
||||
Reference in New Issue
Block a user