mixtral runs a bit faster

This commit is contained in:
Awni Hannun
2023-12-12 08:36:40 -08:00
parent e42682dced
commit 2ffd0da009
5 changed files with 24 additions and 47 deletions

View File

@@ -33,12 +33,12 @@ Now from `mlx-exmaples/mixtral` conver the weights to NumPy so MLX can read them
python convert.py --model_path mixtral-8x7b-32kseqlen/
```
The conversion script will save the new weights in the same location.
The conversion script will save the converted weights in the same location.
After that's done, if you want to clean some stuff up:
```
rm mixtral-8x7b-32kseqlen/*.pth
rm mixtral-8x7b-32kseqlen/*.pth*
```
### Generate

View File

@@ -119,7 +119,6 @@ class MOEFeedForward(nn.Module):
self.gate = nn.Linear(args.dim, self.num_experts, bias=False)
def __call__(self, x) -> mx.array:
ne = self.num_experts_per_tok
orig_shape = x.shape
x = x.reshape(-1, x.shape[-1])
@@ -128,23 +127,12 @@ class MOEFeedForward(nn.Module):
inds = mx.argpartition(-gates, kth=ne, axis=-1)[:, :ne]
scores = mx.softmax(mx.take_along_axis(gates, inds, axis=-1), axis=-1)
# For batch:
if x.shape[0] > 1:
mx.eval(inds)
inds = np.array(inds)
y = mx.zeros((x.shape[0], ne, x.shape[-1]))
for e, expert in enumerate(self.experts):
idx1, idx2 = map(mx.array, np.where(inds == e))
if idx1.size == 0:
continue
y[idx1, idx2] = expert(x[idx1])
y = (y * scores[:, :, None]).sum(axis=1)
# For single:
else:
ys = [self.experts[e](x)[:, :, None] for e in inds.squeeze().tolist()]
y = mx.concatenate(ys, axis=-1)
y = (y * scores[:, None, 0]).sum(axis=-1)
y = []
for xt, st, it in zip(x, scores, inds.tolist()):
yt = mx.concatenate([self.experts[e](xt)[:, None] for e in it], axis=-1)
yt = (yt * st).sum(axis=-1)
y.append(yt[None, :])
y = mx.concatenate(y)
return y.reshape(orig_shape)
@@ -300,18 +288,9 @@ if __name__ == "__main__":
print(args.prompt, end="", flush=True)
prompt = mx.array(tokenizer.encode(args.prompt))
mx.eval(prompt)
tokens = []
import time
tic = time.time()
p = True
for token, _ in zip(generate(prompt, model, args.temp), range(args.max_tokens)):
tokens.append(token)
if p:
mx.eval(tokens)
p = False
prompt_time = time.time() - tic
tic = time.time()
if (len(tokens) % 10) == 0:
mx.eval(tokens)
@@ -320,9 +299,5 @@ if __name__ == "__main__":
tokens = []
mx.eval(tokens)
tok_time = time.time() - tic
s = tokenizer.decode([t.item() for t in tokens])
print(s, flush=True)
print("------")
print(f"Prompt time {prompt_time}")
print(f"Token time {tok_time}")