* dbrx

* format

* format

* comments

* change scores slightly

* remove inadvertant import
This commit is contained in:
Awni Hannun
2024-03-28 21:03:53 -07:00
committed by GitHub
parent 297a908e3d
commit b80adbcc3e
4 changed files with 259 additions and 3 deletions

View File

@@ -125,7 +125,7 @@ class MOE(nn.Module):
else:
y = []
for xt, st, it in zip(x, scores, inds.tolist()):
yt = mx.concatenate([self.mlp[e](xt)[:, None] for e in it], axis=-1)
yt = mx.stack([self.mlp[e](xt) for e in it], axis=-1)
yt = (yt * st).sum(axis=-1)
y.append(yt[None, :])
y = mx.concatenate(y)