mixtral runs a bit faster

Awni Hannun 2023-12-12 08:36:40 -08:00
parent e42682dced
commit 2ffd0da009
5 changed files with 24 additions and 47 deletions

mistral/README.md

@@ -2,7 +2,8 @@
 An example of generating text with Mistral using MLX.

-Mistral 7B is one of the top large language models in its size class. It is also fully open source with a permissive license[^1].
+Mistral 7B is one of the top large language models in its size class. It is
+also fully open source with a permissive license[^1].

 ### Setup
@@ -25,6 +26,8 @@ Then, convert the weights with:
 python convert.py
 ```

+The conversion script will save the converted weights in the same location.
+
 ### Run

 Once you've converted the weights to MLX format, you can generate text with
@@ -36,4 +39,6 @@ python mistral.py --prompt "It is a truth universally acknowledged," --temp 0
 Run `python mistral.py --help` for more details.

-[^1]: Refer to the [blog post](https://mistral.ai/news/announcing-mistral-7b/) and [github repository](https://github.com/mistralai/mistral-src) for more details.
+[^1]: Refer to the [blog post](https://mistral.ai/news/announcing-mistral-7b/)
+and [github repository](https://github.com/mistralai/mistral-src) for more
+details.
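For readers new to the `--temp` flag in the command above: a temperature of 0 means greedy decoding. A minimal sketch of that behavior, assuming the usual logits-scaling convention; `sample` is a hypothetical helper, not the example's actual sampling code:

```python
import mlx.core as mx

def sample(logits: mx.array, temp: float) -> mx.array:
    # temp == 0: pick the most likely token (greedy decoding).
    if temp == 0:
        return mx.argmax(logits, axis=-1)
    # Otherwise sample from logits scaled by 1/temp.
    return mx.random.categorical(logits * (1 / temp))
```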

mistral/convert.py

@@ -2,26 +2,23 @@
 import argparse
 import numpy as np
+from pathlib import Path
 import torch

 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Convert Mistral weights to MLX.")
     parser.add_argument(
-        "--torch_model",
+        "--model_path",
         type=str,
-        default="mistral-7B-v0.1/consolidated.00.pth",
-        help="The path to the torch model weights",
-    )
-    parser.add_argument(
-        "--mlx_model",
-        type=str,
-        default="mistral-7B-v0.1/mlx_mistral_7b.npz",
-        help="The path to store the mlx model weights",
+        default="mistral-7B-v0.1/",
+        help="The path to the Mistral model. The MLX weights will also be saved there.",
     )
     args = parser.parse_args()

-    state = torch.load(args.torch_model)
+    model_path = Path(args.model_path)
+    state = torch.load(str(model_path / "consolidated.00.pth"))
     np.savez(
-        args.mlx_model, **{k: v.to(torch.float16).numpy() for k, v in state.items()}
+        str(model_path / "weights.npz"),
+        **{k: v.to(torch.float16).numpy() for k, v in state.items()}
     )
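A quick way to sanity-check the converted file, assuming the default `--model_path`; this is illustrative only, not part of the commit:

```python
import numpy as np

# The converter writes a float16 copy of every tensor into weights.npz.
weights = np.load("mistral-7B-v0.1/weights.npz")
for name in weights.files[:3]:
    print(name, weights[name].shape, weights[name].dtype)  # expect float16
```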

mistral/mistral.py

@@ -196,7 +196,7 @@ def load_model(folder: str, dtype=mx.float16):
         config = json.loads(f.read())
         config.pop("sliding_window")
         model_args = ModelArgs(**config)
-    weights = mx.load(str(model_path / "mlx_mistral_7b.npz"))
+    weights = mx.load(str(model_path / "weights.npz"))
     weights = tree_unflatten(list(weights.items()))
     weights = tree_map(lambda p: p.astype(dtype), weights)
     model = Mistral(model_args)
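For context on the two `tree_*` calls above, a toy example with a made-up parameter name: `tree_unflatten` turns the flat dotted keys stored in the `.npz` back into the nested structure the model expects, and `tree_map` casts every leaf to the target dtype.

```python
import mlx.core as mx
from mlx.utils import tree_map, tree_unflatten

# Stand-in for one entry of weights.npz (the key name is invented):
flat = [("layers.0.attention.wq.weight", mx.zeros((8, 8)))]
nested = tree_unflatten(flat)  # {"layers": [{"attention": {"wq": {"weight": ...}}}]}
nested = tree_map(lambda p: p.astype(mx.float16), nested)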

mixtral/README.md

@@ -33,12 +33,12 @@ Now from `mlx-examples/mixtral` convert the weights to NumPy so MLX can read them
 python convert.py --model_path mixtral-8x7b-32kseqlen/
 ```

-The conversion script will save the new weights in the same location.
+The conversion script will save the converted weights in the same location.

 After that's done, if you want to clean some stuff up:

 ```
-rm mixtral-8x7b-32kseqlen/*.pth
+rm mixtral-8x7b-32kseqlen/*.pth*
 ```

 ### Generate

mixtral/mixtral.py

@@ -119,7 +119,6 @@ class MOEFeedForward(nn.Module):
         self.gate = nn.Linear(args.dim, self.num_experts, bias=False)

     def __call__(self, x) -> mx.array:
         ne = self.num_experts_per_tok
         orig_shape = x.shape
         x = x.reshape(-1, x.shape[-1])
@@ -128,23 +127,12 @@ class MOEFeedForward(nn.Module):
         inds = mx.argpartition(-gates, kth=ne, axis=-1)[:, :ne]
         scores = mx.softmax(mx.take_along_axis(gates, inds, axis=-1), axis=-1)

-        # For batch:
-        if x.shape[0] > 1:
-            mx.eval(inds)
-            inds = np.array(inds)
-            y = mx.zeros((x.shape[0], ne, x.shape[-1]))
-            for e, expert in enumerate(self.experts):
-                idx1, idx2 = map(mx.array, np.where(inds == e))
-                if idx1.size == 0:
-                    continue
-                y[idx1, idx2] = expert(x[idx1])
-            y = (y * scores[:, :, None]).sum(axis=1)
-        # For single:
-        else:
-            ys = [self.experts[e](x)[:, :, None] for e in inds.squeeze().tolist()]
-            y = mx.concatenate(ys, axis=-1)
-            y = (y * scores[:, None, 0]).sum(axis=-1)
+        y = []
+        for xt, st, it in zip(x, scores, inds.tolist()):
+            yt = mx.concatenate([self.experts[e](xt)[:, None] for e in it], axis=-1)
+            yt = (yt * st).sum(axis=-1)
+            y.append(yt[None, :])
+        y = mx.concatenate(y)

         return y.reshape(orig_shape)
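Since this routing change is the core of the speedup, here is the new per-token loop in isolation, as a self-contained toy: the sizes and the freshly initialized experts are made up (the real model uses the loaded Mixtral weights), but the dispatch pattern is the one above. Each token goes only to its top-`ne` experts and their outputs are mixed by the softmaxed gate scores.

```python
import mlx.core as mx
import mlx.nn as nn

dim, num_experts, ne = 8, 4, 2
experts = [nn.Linear(dim, dim) for _ in range(num_experts)]
gate = nn.Linear(dim, num_experts, bias=False)

x = mx.random.normal((3, dim))  # 3 tokens
gates = gate(x)
# Top-ne expert indices per token and their normalized weights.
inds = mx.argpartition(-gates, kth=ne, axis=-1)[:, :ne]
scores = mx.softmax(mx.take_along_axis(gates, inds, axis=-1), axis=-1)

y = []
for xt, st, it in zip(x, scores, inds.tolist()):
    # Run only the selected experts on this token and mix their outputs.
    yt = mx.concatenate([experts[e](xt)[:, None] for e in it], axis=-1)
    y.append((yt * st).sum(axis=-1)[None, :])
y = mx.concatenate(y)  # shape (3, dim)
```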
@@ -300,18 +288,9 @@ if __name__ == "__main__":
     print(args.prompt, end="", flush=True)
     prompt = mx.array(tokenizer.encode(args.prompt))
-    mx.eval(prompt)
     tokens = []
-    import time
-    tic = time.time()
-    p = True
     for token, _ in zip(generate(prompt, model, args.temp), range(args.max_tokens)):
         tokens.append(token)
-        if p:
-            mx.eval(tokens)
-            p = False
-            prompt_time = time.time() - tic
-            tic = time.time()
         if (len(tokens) % 10) == 0:
             mx.eval(tokens)
@@ -320,9 +299,5 @@ if __name__ == "__main__":
             tokens = []

     mx.eval(tokens)
-    tok_time = time.time() - tic
     s = tokenizer.decode([t.item() for t in tokens])
     print(s, flush=True)
-    print("------")
-    print(f"Prompt time {prompt_time}")
-    print(f"Token time {tok_time}")