mixtral runs a bit faster
commit 2ffd0da009
parent e42682dced
@@ -2,7 +2,8 @@
 
 An example of generating text with Mistral using MLX.
 
-Mistral 7B is one of the top large language models in its size class. It is also fully open source with a permissive license[^1].
+Mistral 7B is one of the top large language models in its size class. It is
+also fully open source with a permissive license[^1].
 
 ### Setup
 
@@ -25,6 +26,8 @@ Then, convert the weights with:
 python convert.py
 ```
 
+The conversion script will save the converted weights in the same location.
+
 ### Run
 
 Once you've converted the weights to MLX format, you can generate text with
@@ -36,4 +39,6 @@ python mistral.py --prompt "It is a truth universally acknowledged," --temp 0
 
 Run `python mistral.py --help` for more details.
 
-[^1]: Refer to the [blog post](https://mistral.ai/news/announcing-mistral-7b/) and [github repository](https://github.com/mistralai/mistral-src) for more details.
+[^1]: Refer to the [blog post](https://mistral.ai/news/announcing-mistral-7b/)
+and [github repository](https://github.com/mistralai/mistral-src) for more
+details.
@@ -2,26 +2,23 @@
 
 import argparse
 import numpy as np
+from pathlib import Path
 import torch
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Convert Mistral weights to MLX.")
     parser.add_argument(
-        "--torch_model",
+        "--model_path",
         type=str,
-        default="mistral-7B-v0.1/consolidated.00.pth",
-        help="The path to the torch model weights",
-    )
-    parser.add_argument(
-        "--mlx_model",
-        type=str,
-        default="mistral-7B-v0.1/mlx_mistral_7b.npz",
-        help="The path to store the mlx model weights",
+        default="mistral-7B-v0.1/",
+        help="The path to the Mistral model. The MLX weights will also be saved there.",
     )
     args = parser.parse_args()
 
-    state = torch.load(args.torch_model)
+    model_path = Path(args.model_path)
+    state = torch.load(str(model_path / "consolidated.00.pth"))
     np.savez(
-        args.mlx_model, **{k: v.to(torch.float16).numpy() for k, v in state.items()}
+        str(model_path / "weights.npz"),
+        **{k: v.to(torch.float16).numpy() for k, v in state.items()}
     )
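With the two path flags collapsed into a single `--model_path`, conversion becomes one step against the model directory; for example (using the default from the diff, and assuming the original weights were downloaded there), `python convert.py --model_path mistral-7B-v0.1/` reads `consolidated.00.pth` and writes `weights.npz` as float16 NumPy arrays into the same directory.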
@@ -196,7 +196,7 @@ def load_model(folder: str, dtype=mx.float16):
         config = json.loads(f.read())
         config.pop("sliding_window")
         model_args = ModelArgs(**config)
-    weights = mx.load(str(model_path / "mlx_mistral_7b.npz"))
+    weights = mx.load(str(model_path / "weights.npz"))
     weights = tree_unflatten(list(weights.items()))
     weights = tree_map(lambda p: p.astype(dtype), weights)
     model = Mistral(model_args)
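For context on what `load_model` does with the renamed file: `mx.load` on an `.npz` returns a flat dict of arrays keyed by the dot-separated parameter names that convert.py saved, and `tree_unflatten` rebuilds the nested parameter tree from those keys. A minimal sketch of that step, assuming mlx is installed and `weights.npz` exists at the README's default location (the example key is illustrative):

```python
import mlx.core as mx
from mlx.utils import tree_map, tree_unflatten

# Flat dict of float16 arrays, keyed like "layers.0.attention.wq.weight".
weights = mx.load("mistral-7B-v0.1/weights.npz")
# Rebuild the nested parameter structure the module expects.
params = tree_unflatten(list(weights.items()))
# Cast every leaf to the requested dtype, as load_model does.
params = tree_map(lambda p: p.astype(mx.float16), params)
```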
@@ -33,12 +33,12 @@ Now from `mlx-exmaples/mixtral` conver the weights to NumPy so MLX can read them
 python convert.py --model_path mixtral-8x7b-32kseqlen/
 ```
 
-The conversion script will save the new weights in the same location.
+The conversion script will save the converted weights in the same location.
 
 After that's done, if you want to clean some stuff up:
 
 ```
-rm mixtral-8x7b-32kseqlen/*.pth
+rm mixtral-8x7b-32kseqlen/*.pth*
 ```
 
 ### Generate
@@ -119,7 +119,6 @@ class MOEFeedForward(nn.Module):
         self.gate = nn.Linear(args.dim, self.num_experts, bias=False)
 
     def __call__(self, x) -> mx.array:
-
         ne = self.num_experts_per_tok
         orig_shape = x.shape
         x = x.reshape(-1, x.shape[-1])
@@ -128,23 +127,12 @@ class MOEFeedForward(nn.Module):
         inds = mx.argpartition(-gates, kth=ne, axis=-1)[:, :ne]
         scores = mx.softmax(mx.take_along_axis(gates, inds, axis=-1), axis=-1)
 
-        # For batch:
-        if x.shape[0] > 1:
-            mx.eval(inds)
-            inds = np.array(inds)
-            y = mx.zeros((x.shape[0], ne, x.shape[-1]))
-            for e, expert in enumerate(self.experts):
-                idx1, idx2 = map(mx.array, np.where(inds == e))
-                if idx1.size == 0:
-                    continue
-                y[idx1, idx2] = expert(x[idx1])
-            y = (y * scores[:, :, None]).sum(axis=1)
-
-        # For single:
-        else:
-            ys = [self.experts[e](x)[:, :, None] for e in inds.squeeze().tolist()]
-            y = mx.concatenate(ys, axis=-1)
-            y = (y * scores[:, None, 0]).sum(axis=-1)
+        y = []
+        for xt, st, it in zip(x, scores, inds.tolist()):
+            yt = mx.concatenate([self.experts[e](xt)[:, None] for e in it], axis=-1)
+            yt = (yt * st).sum(axis=-1)
+            y.append(yt[None, :])
+        y = mx.concatenate(y)
 
         return y.reshape(orig_shape)
 
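The rewritten `__call__` routes each token through its top-`ne` experts and mixes the expert outputs with softmax weights over the selected gate logits, rather than scattering into a preallocated buffer per expert. A minimal standalone sketch of the selection step, assuming mlx is installed and using a random toy matrix in place of the module's `self.gate(x)`:

```python
import mlx.core as mx

# Toy setup: 4 tokens, 8 experts, route each token to its top-2 experts.
ne = 2
gates = mx.random.normal((4, 8))  # stand-in for self.gate(x)

# Top-ne expert indices per token (argpartition does not sort within the top-ne).
inds = mx.argpartition(-gates, kth=ne, axis=-1)[:, :ne]
# Mixing weights: softmax over only the selected logits, so each row sums to 1.
scores = mx.softmax(mx.take_along_axis(gates, inds, axis=-1), axis=-1)

print(inds.tolist())        # e.g. [[3, 5], [0, 6], [2, 7], [1, 4]]
print(scores.sum(axis=-1))  # one entry per token, each equal to 1
```

Per token, the new loop then concatenates `self.experts[e](xt)` for the selected experts and reduces with these scores, avoiding the `mx.eval(inds)` round-trip and the `y[idx1, idx2] = ...` scatter writes of the old batch path.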
@@ -300,18 +288,9 @@ if __name__ == "__main__":
 
     print(args.prompt, end="", flush=True)
     prompt = mx.array(tokenizer.encode(args.prompt))
-    mx.eval(prompt)
     tokens = []
-    import time
-    tic = time.time()
-    p = True
     for token, _ in zip(generate(prompt, model, args.temp), range(args.max_tokens)):
         tokens.append(token)
-        if p:
-            mx.eval(tokens)
-            p = False
-            prompt_time = time.time() - tic
-            tic = time.time()
 
         if (len(tokens) % 10) == 0:
             mx.eval(tokens)
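With the timing instrumentation removed, what remains is the chunked evaluation: MLX builds the computation lazily, and the periodic `mx.eval(tokens)` forces the pending graph every 10 tokens so decoded text can stream out. A tiny sketch of that pattern, assuming mlx is installed, with trivial arithmetic standing in for the model:

```python
import mlx.core as mx

tokens = []
for i in range(25):
    tokens.append(mx.array(i) * 2)   # stand-in for one generated token
    if (len(tokens) % 10) == 0:
        mx.eval(tokens)              # force the lazy graph for this chunk
        print([t.item() for t in tokens])
        tokens = []

mx.eval(tokens)                      # evaluate whatever is left over
print([t.item() for t in tokens])
```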
@@ -320,9 +299,5 @@ if __name__ == "__main__":
             tokens = []
 
     mx.eval(tokens)
-    tok_time = time.time() - tic
     s = tokenizer.decode([t.item() for t in tokens])
     print(s, flush=True)
-    print("------")
-    print(f"Prompt time {prompt_time}")
-    print(f"Token time {tok_time}")