Mirror of https://github.com/ml-explore/mlx-examples.git (last synced 2025-09-01 12:49:50 +08:00)
Commit message: "mixtral runs a bit faster"
This commit is contained in:
@@ -119,7 +119,6 @@ class MOEFeedForward(nn.Module):
|
||||
self.gate = nn.Linear(args.dim, self.num_experts, bias=False)
|
||||
|
||||
def __call__(self, x) -> mx.array:
|
||||
|
||||
ne = self.num_experts_per_tok
|
||||
orig_shape = x.shape
|
||||
x = x.reshape(-1, x.shape[-1])
|
||||
@@ -128,23 +127,12 @@ class MOEFeedForward(nn.Module):
|
||||
inds = mx.argpartition(-gates, kth=ne, axis=-1)[:, :ne]
|
||||
scores = mx.softmax(mx.take_along_axis(gates, inds, axis=-1), axis=-1)
|
||||
|
||||
# For batch:
|
||||
if x.shape[0] > 1:
|
||||
mx.eval(inds)
|
||||
inds = np.array(inds)
|
||||
y = mx.zeros((x.shape[0], ne, x.shape[-1]))
|
||||
for e, expert in enumerate(self.experts):
|
||||
idx1, idx2 = map(mx.array, np.where(inds == e))
|
||||
if idx1.size == 0:
|
||||
continue
|
||||
y[idx1, idx2] = expert(x[idx1])
|
||||
y = (y * scores[:, :, None]).sum(axis=1)
|
||||
|
||||
# For single:
|
||||
else:
|
||||
ys = [self.experts[e](x)[:, :, None] for e in inds.squeeze().tolist()]
|
||||
y = mx.concatenate(ys, axis=-1)
|
||||
y = (y * scores[:, None, 0]).sum(axis=-1)
|
||||
y = []
|
||||
for xt, st, it in zip(x, scores, inds.tolist()):
|
||||
yt = mx.concatenate([self.experts[e](xt)[:, None] for e in it], axis=-1)
|
||||
yt = (yt * st).sum(axis=-1)
|
||||
y.append(yt[None, :])
|
||||
y = mx.concatenate(y)
|
||||
|
||||
return y.reshape(orig_shape)
|
||||
|
||||
@@ -300,18 +288,9 @@ if __name__ == "__main__":
|
||||
|
||||
print(args.prompt, end="", flush=True)
|
||||
prompt = mx.array(tokenizer.encode(args.prompt))
|
||||
mx.eval(prompt)
|
||||
tokens = []
|
||||
import time
|
||||
tic = time.time()
|
||||
p = True
|
||||
for token, _ in zip(generate(prompt, model, args.temp), range(args.max_tokens)):
|
||||
tokens.append(token)
|
||||
if p:
|
||||
mx.eval(tokens)
|
||||
p = False
|
||||
prompt_time = time.time() - tic
|
||||
tic = time.time()
|
||||
|
||||
if (len(tokens) % 10) == 0:
|
||||
mx.eval(tokens)
|
||||
@@ -320,9 +299,5 @@ if __name__ == "__main__":
|
||||
tokens = []
|
||||
|
||||
mx.eval(tokens)
|
||||
tok_time = time.time() - tic
|
||||
s = tokenizer.decode([t.item() for t in tokens])
|
||||
print(s, flush=True)
|
||||
print("------")
|
||||
print(f"Prompt time {prompt_time}")
|
||||
print(f"Token time {tok_time}")
|
||||
|
Reference in New Issue
Block a user