* dbrx

* format

* format

* comments

* change scores slightly

* remove inadvertant import
This commit is contained in:
Awni Hannun
2024-03-28 21:03:53 -07:00
committed by GitHub
parent 297a908e3d
commit b80adbcc3e
4 changed files with 259 additions and 3 deletions

View File

@@ -143,7 +143,6 @@ class MixtralSparseMoeBlock(nn.Module):
).astype(gates.dtype)
if self.training:
mx.eval(inds)
inds = np.array(inds)
y = mx.zeros((x.shape[0], ne, x.shape[-1]), x.dtype)
for e, expert in enumerate(self.experts):
@@ -156,7 +155,7 @@ class MixtralSparseMoeBlock(nn.Module):
else:
y = []
for xt, st, it in zip(x, scores, inds.tolist()):
yt = mx.concatenate([self.experts[e](xt)[:, None] for e in it], axis=-1)
yt = mx.stack([self.experts[e](xt) for e in it], axis=-1)
yt = (yt * st).sum(axis=-1)
y.append(yt[None, :])
y = mx.concatenate(y)