mixtral runs a bit faster

This commit is contained in:
Awni Hannun
2023-12-12 08:36:40 -08:00
parent e42682dced
commit 2ffd0da009
5 changed files with 24 additions and 47 deletions

View File

@@ -119,7 +119,6 @@ class MOEFeedForward(nn.Module):
self.gate = nn.Linear(args.dim, self.num_experts, bias=False)
def __call__(self, x) -> mx.array:
ne = self.num_experts_per_tok
orig_shape = x.shape
x = x.reshape(-1, x.shape[-1])
@@ -128,23 +127,12 @@ class MOEFeedForward(nn.Module):
inds = mx.argpartition(-gates, kth=ne, axis=-1)[:, :ne]
scores = mx.softmax(mx.take_along_axis(gates, inds, axis=-1), axis=-1)
# For batch:
if x.shape[0] > 1:
mx.eval(inds)
inds = np.array(inds)
y = mx.zeros((x.shape[0], ne, x.shape[-1]))
for e, expert in enumerate(self.experts):
idx1, idx2 = map(mx.array, np.where(inds == e))
if idx1.size == 0:
continue
y[idx1, idx2] = expert(x[idx1])
y = (y * scores[:, :, None]).sum(axis=1)
# For single:
else:
ys = [self.experts[e](x)[:, :, None] for e in inds.squeeze().tolist()]
y = mx.concatenate(ys, axis=-1)
y = (y * scores[:, None, 0]).sum(axis=-1)
y = []
for xt, st, it in zip(x, scores, inds.tolist()):
yt = mx.concatenate([self.experts[e](xt)[:, None] for e in it], axis=-1)
yt = (yt * st).sum(axis=-1)
y.append(yt[None, :])
y = mx.concatenate(y)
return y.reshape(orig_shape)
@@ -300,18 +288,9 @@ if __name__ == "__main__":
print(args.prompt, end="", flush=True)
prompt = mx.array(tokenizer.encode(args.prompt))
mx.eval(prompt)
tokens = []
import time
tic = time.time()
p = True
for token, _ in zip(generate(prompt, model, args.temp), range(args.max_tokens)):
tokens.append(token)
if p:
mx.eval(tokens)
p = False
prompt_time = time.time() - tic
tic = time.time()
if (len(tokens) % 10) == 0:
mx.eval(tokens)
@@ -320,9 +299,5 @@ if __name__ == "__main__":
tokens = []
mx.eval(tokens)
tok_time = time.time() - tic
s = tokenizer.decode([t.item() for t in tokens])
print(s, flush=True)
print("------")
print(f"Prompt time {prompt_time}")
print(f"Token time {tok_time}")