mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 04:14:38 +08:00
DBRX (#628)
* dbrx * format * format * comments * change scores slightly * remove inadvertent import
This commit is contained in:
@@ -143,7 +143,6 @@ class MixtralSparseMoeBlock(nn.Module):
|
||||
).astype(gates.dtype)
|
||||
|
||||
if self.training:
|
||||
mx.eval(inds)
|
||||
inds = np.array(inds)
|
||||
y = mx.zeros((x.shape[0], ne, x.shape[-1]), x.dtype)
|
||||
for e, expert in enumerate(self.experts):
|
||||
@@ -156,7 +155,7 @@ class MixtralSparseMoeBlock(nn.Module):
|
||||
else:
|
||||
y = []
|
||||
for xt, st, it in zip(x, scores, inds.tolist()):
|
||||
yt = mx.concatenate([self.experts[e](xt)[:, None] for e in it], axis=-1)
|
||||
yt = mx.stack([self.experts[e](xt) for e in it], axis=-1)
|
||||
yt = (yt * st).sum(axis=-1)
|
||||
y.append(yt[None, :])
|
||||
y = mx.concatenate(y)
|
||||
|
Reference in New Issue
Block a user