Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-24 09:21:18 +08:00)
mixtral runs a bit faster
This commit is contained in:
parent e42682dced
commit 2ffd0da009
mistral/README.md
@@ -2,7 +2,8 @@

 An example of generating text with Mistral using MLX.

-Mistral 7B is one of the top large language models in its size class. It is also fully open source with a permissive license[^1].
+Mistral 7B is one of the top large language models in its size class. It is
+also fully open source with a permissive license[^1].

 ### Setup
@@ -25,6 +26,8 @@ Then, convert the weights with:

 python convert.py
 ```

 The conversion script will save the converted weights in the same location.

 ### Run

 Once you've converted the weights to MLX format, you can generate text with
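A quick way to confirm the conversion step above worked is to open the saved archive. This is a sketch only; it assumes the default `mistral-7B-v0.1/` directory and the `weights.npz` file name introduced elsewhere in this commit.

```python
# Sketch: sanity-check the converted weights archive.
import numpy as np

weights = np.load("mistral-7B-v0.1/weights.npz")   # assumed default location
print(len(weights.files), "arrays saved")
print(weights[weights.files[0]].dtype)             # the conversion casts to float16
```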
@@ -36,4 +39,6 @@ python mistral.py --prompt "It is a truth universally acknowledged," --temp 0

 Run `python mistral.py --help` for more details.

-[^1]: Refer to the [blog post](https://mistral.ai/news/announcing-mistral-7b/) and [github repository](https://github.com/mistralai/mistral-src) for more details.
+[^1]: Refer to the [blog post](https://mistral.ai/news/announcing-mistral-7b/)
+and [github repository](https://github.com/mistralai/mistral-src) for more
+details.
mistral/convert.py
@@ -2,26 +2,23 @@

 import argparse
 import numpy as np
+from pathlib import Path
 import torch


 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Convert Mistral weights to MLX.")
     parser.add_argument(
-        "--torch_model",
+        "--model_path",
         type=str,
-        default="mistral-7B-v0.1/consolidated.00.pth",
-        help="The path to the torch model weights",
-    )
-    parser.add_argument(
-        "--mlx_model",
-        type=str,
-        default="mistral-7B-v0.1/mlx_mistral_7b.npz",
-        help="The path to store the mlx model weights",
+        default="mistral-7B-v0.1/",
+        help="The path to the Mistral model. The MLX weights will also be saved there.",
     )
     args = parser.parse_args()

-    state = torch.load(args.torch_model)
+    model_path = Path(args.model_path)
+    state = torch.load(str(model_path / "consolidated.00.pth"))
     np.savez(
-        args.mlx_model, **{k: v.to(torch.float16).numpy() for k, v in state.items()}
+        str(model_path / "weights.npz"),
+        **{k: v.to(torch.float16).numpy() for k, v in state.items()}
     )
mistral/mistral.py
@@ -196,7 +196,7 @@ def load_model(folder: str, dtype=mx.float16):
         config = json.loads(f.read())
         config.pop("sliding_window")
         model_args = ModelArgs(**config)
-    weights = mx.load(str(model_path / "mlx_mistral_7b.npz"))
+    weights = mx.load(str(model_path / "weights.npz"))
     weights = tree_unflatten(list(weights.items()))
     weights = tree_map(lambda p: p.astype(dtype), weights)
     model = Mistral(model_args)
mixtral/README.md
@@ -33,12 +33,12 @@ Now from `mlx-examples/mixtral` convert the weights to NumPy so MLX can read them

 python convert.py --model_path mixtral-8x7b-32kseqlen/
 ```

-The conversion script will save the new weights in the same location.
+The conversion script will save the converted weights in the same location.

 After that's done, if you want to clean some stuff up:

 ```
-rm mixtral-8x7b-32kseqlen/*.pth
+rm mixtral-8x7b-32kseqlen/*.pth*
 ```

 ### Generate
mixtral/mixtral.py
@@ -119,7 +119,6 @@ class MOEFeedForward(nn.Module):
         self.gate = nn.Linear(args.dim, self.num_experts, bias=False)

     def __call__(self, x) -> mx.array:
-
         ne = self.num_experts_per_tok
         orig_shape = x.shape
         x = x.reshape(-1, x.shape[-1])
@@ -128,23 +127,12 @@ class MOEFeedForward(nn.Module):
         inds = mx.argpartition(-gates, kth=ne, axis=-1)[:, :ne]
         scores = mx.softmax(mx.take_along_axis(gates, inds, axis=-1), axis=-1)

-        # For batch:
-        if x.shape[0] > 1:
-            mx.eval(inds)
-            inds = np.array(inds)
-            y = mx.zeros((x.shape[0], ne, x.shape[-1]))
-            for e, expert in enumerate(self.experts):
-                idx1, idx2 = map(mx.array, np.where(inds == e))
-                if idx1.size == 0:
-                    continue
-                y[idx1, idx2] = expert(x[idx1])
-            y = (y * scores[:, :, None]).sum(axis=1)
-
-        # For single:
-        else:
-            ys = [self.experts[e](x)[:, :, None] for e in inds.squeeze().tolist()]
-            y = mx.concatenate(ys, axis=-1)
-            y = (y * scores[:, None, 0]).sum(axis=-1)
+        y = []
+        for xt, st, it in zip(x, scores, inds.tolist()):
+            yt = mx.concatenate([self.experts[e](xt)[:, None] for e in it], axis=-1)
+            yt = (yt * st).sum(axis=-1)
+            y.append(yt[None, :])
+        y = mx.concatenate(y)

         return y.reshape(orig_shape)
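The behavioural change behind the commit message is this rewrite of the expert dispatch: instead of branching between a batched path (which evaluated the routing indices, moved them to NumPy, and scattered each expert's output into a zero buffer) and a separate single-token path, the new `__call__` loops over tokens, runs only each token's top-`ne` experts, and mixes their outputs with the softmax scores. Below is a NumPy-only sketch of the same per-token mixture for clarity; the toy sizes, random stand-in experts, and variable names are illustrative assumptions, not the module's API.

```python
# Sketch (NumPy stand-in): per-token top-k expert mixing as in the new loop.
import numpy as np

num_experts, ne, dim, num_tokens = 8, 2, 4, 3   # toy sizes, assumptions only
rng = np.random.default_rng(0)
experts = [rng.standard_normal((dim, dim)) for _ in range(num_experts)]  # stand-in experts
gate_w = rng.standard_normal((dim, num_experts))                         # stand-in router

x = rng.standard_normal((num_tokens, dim))
gates = x @ gate_w
inds = np.argpartition(-gates, ne, axis=-1)[:, :ne]            # top-ne experts per token
top = np.take_along_axis(gates, inds, axis=-1)
top = top - top.max(axis=-1, keepdims=True)
scores = np.exp(top) / np.exp(top).sum(axis=-1, keepdims=True)  # softmax over chosen experts

y = []
for xt, st, it in zip(x, scores, inds.tolist()):
    yt = np.stack([xt @ experts[e] for e in it], axis=-1)       # run only the selected experts
    y.append((yt * st).sum(axis=-1))                            # score-weighted mixture
print(np.stack(y).shape)                                        # (num_tokens, dim)
```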
@@ -300,18 +288,9 @@ if __name__ == "__main__":

     print(args.prompt, end="", flush=True)
     prompt = mx.array(tokenizer.encode(args.prompt))
-    mx.eval(prompt)
     tokens = []
-    import time
-    tic = time.time()
-    p = True
     for token, _ in zip(generate(prompt, model, args.temp), range(args.max_tokens)):
         tokens.append(token)
-        if p:
-            mx.eval(tokens)
-            p = False
-            prompt_time = time.time() - tic
-            tic = time.time()

         if (len(tokens) % 10) == 0:
             mx.eval(tokens)
@@ -320,9 +299,5 @@ if __name__ == "__main__":
             tokens = []

     mx.eval(tokens)
-    tok_time = time.time() - tic
     s = tokenizer.decode([t.item() for t in tokens])
     print(s, flush=True)
-    print("------")
-    print(f"Prompt time {prompt_time}")
-    print(f"Token time {tok_time}")
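The last two hunks strip the ad-hoc prompt/token timing instrumentation from the generation loop. If those numbers are still wanted, an equivalent measurement can be taken around any token generator. The sketch below is a stand-alone assumption, not part of the script, and with MLX's lazy evaluation it only reflects real compute if the tokens are evaluated, as the removed code did with `mx.eval`.

```python
# Sketch: time-to-first-token vs. remaining-token time around a token generator.
import time

def timed_generate(token_iter, max_tokens):
    tic = time.perf_counter()
    tokens, prompt_time = [], None
    for token, _ in zip(token_iter, range(max_tokens)):
        tokens.append(token)
        if prompt_time is None:
            prompt_time = time.perf_counter() - tic  # first token ~ prompt processing
            tic = time.perf_counter()
    return tokens, prompt_time, time.perf_counter() - tic
```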