mixtral runs a bit faster

This commit is contained in:
Awni Hannun 2023-12-12 08:36:40 -08:00
parent e42682dced
commit 2ffd0da009
5 changed files with 24 additions and 47 deletions

View File

@@ -2,7 +2,8 @@
 An example of generating text with Mistral using MLX.

-Mistral 7B is one of the top large language models in its size class. It is also fully open source with a permissive license[^1].
+Mistral 7B is one of the top large language models in its size class. It is
+also fully open source with a permissive license[^1].

 ### Setup
@@ -25,6 +26,8 @@ Then, convert the weights with:
 python convert.py
 ```

+The conversion script will save the converted weights in the same location.
+
 ### Run

 Once you've converted the weights to MLX format, you can generate text with
@@ -36,4 +39,6 @@ python mistral.py --prompt "It is a truth universally acknowledged," --temp 0
 Run `python mistral.py --help` for more details.

-[^1]: Refer to the [blog post](https://mistral.ai/news/announcing-mistral-7b/) and [github repository](https://github.com/mistralai/mistral-src) for more details.
+[^1]: Refer to the [blog post](https://mistral.ai/news/announcing-mistral-7b/)
+and [github repository](https://github.com/mistralai/mistral-src) for more
+details.

View File

@@ -2,26 +2,23 @@
 import argparse

 import numpy as np
+from pathlib import Path
 import torch

 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Convert Mistral weights to MLX.")
     parser.add_argument(
-        "--torch_model",
+        "--model_path",
         type=str,
-        default="mistral-7B-v0.1/consolidated.00.pth",
-        help="The path to the torch model weights",
-    )
-    parser.add_argument(
-        "--mlx_model",
-        type=str,
-        default="mistral-7B-v0.1/mlx_mistral_7b.npz",
-        help="The path to store the mlx model weights",
+        default="mistral-7B-v0.1/",
+        help="The path to the Mistral model. The MLX weights will also be saved there.",
     )
     args = parser.parse_args()

-    state = torch.load(args.torch_model)
+    model_path = Path(args.model_path)
+    state = torch.load(str(model_path / "consolidated.00.pth"))
     np.savez(
-        args.mlx_model, **{k: v.to(torch.float16).numpy() for k, v in state.items()}
+        str(model_path / "weights.npz"),
+        **{k: v.to(torch.float16).numpy() for k, v in state.items()}
     )
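
A quick way to sanity-check what the updated conversion path wrote out is to load the archive back with `mx.load`. This is a minimal sketch, assuming the script's default `--model_path` of `mistral-7B-v0.1/`; it is illustrative only, not part of the commit:

```python
# Minimal sanity check of the converted weights (assumes the default
# --model_path of "mistral-7B-v0.1/" used by convert.py).
import mlx.core as mx

weights = mx.load("mistral-7B-v0.1/weights.npz")  # dict of name -> mx.array
print(f"{len(weights)} tensors")
for name, w in list(weights.items())[:3]:
    print(name, w.shape, w.dtype)

# Everything was cast to float16 during conversion.
assert all(w.dtype == mx.float16 for w in weights.values())
```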

View File

@@ -196,7 +196,7 @@ def load_model(folder: str, dtype=mx.float16):
         config = json.loads(f.read())
         config.pop("sliding_window")
         model_args = ModelArgs(**config)
-    weights = mx.load(str(model_path / "mlx_mistral_7b.npz"))
+    weights = mx.load(str(model_path / "weights.npz"))
     weights = tree_unflatten(list(weights.items()))
     weights = tree_map(lambda p: p.astype(dtype), weights)
     model = Mistral(model_args)
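
For context, the `tree_unflatten` / `tree_map` lines turn the flat, dot-separated keys stored in `weights.npz` into the nested structure `mlx.nn` modules expect, then cast every leaf to the target dtype. A toy sketch with made-up parameter names and shapes (not the repository's code):

```python
import mlx.core as mx
from mlx.utils import tree_map, tree_unflatten

# Hypothetical flat parameters, mimicking the dotted keys in weights.npz.
flat = {
    "layers.0.attention.wq.weight": mx.zeros((8, 8)),
    "layers.0.attention.wk.weight": mx.zeros((8, 8)),
    "norm.weight": mx.ones((8,)),
}

nested = tree_unflatten(list(flat.items()))                # nested dicts/lists
nested = tree_map(lambda p: p.astype(mx.float16), nested)  # cast every leaf

print(nested["norm"]["weight"].dtype)  # mlx.core.float16
# A model would then absorb these with something like model.update(nested).
```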

View File

@@ -33,12 +33,12 @@ Now from `mlx-exmaples/mixtral` conver the weights to NumPy so MLX can read them
 python convert.py --model_path mixtral-8x7b-32kseqlen/
 ```

-The conversion script will save the new weights in the same location.
+The conversion script will save the converted weights in the same location.

 After that's done, if you want to clean some stuff up:

 ```
-rm mixtral-8x7b-32kseqlen/*.pth
+rm mixtral-8x7b-32kseqlen/*.pth*
 ```

 ### Generate

View File

@@ -119,7 +119,6 @@ class MOEFeedForward(nn.Module):
         self.gate = nn.Linear(args.dim, self.num_experts, bias=False)

     def __call__(self, x) -> mx.array:
         ne = self.num_experts_per_tok
-
         orig_shape = x.shape
         x = x.reshape(-1, x.shape[-1])
@@ -128,23 +127,12 @@ class MOEFeedForward(nn.Module):
         inds = mx.argpartition(-gates, kth=ne, axis=-1)[:, :ne]
         scores = mx.softmax(mx.take_along_axis(gates, inds, axis=-1), axis=-1)

-        # For batch:
-        if x.shape[0] > 1:
-            mx.eval(inds)
-            inds = np.array(inds)
-            y = mx.zeros((x.shape[0], ne, x.shape[-1]))
-            for e, expert in enumerate(self.experts):
-                idx1, idx2 = map(mx.array, np.where(inds == e))
-                if idx1.size == 0:
-                    continue
-                y[idx1, idx2] = expert(x[idx1])
-            y = (y * scores[:, :, None]).sum(axis=1)
-
-        # For single:
-        else:
-            ys = [self.experts[e](x)[:, :, None] for e in inds.squeeze().tolist()]
-            y = mx.concatenate(ys, axis=-1)
-            y = (y * scores[:, None, 0]).sum(axis=-1)
+        y = []
+        for xt, st, it in zip(x, scores, inds.tolist()):
+            yt = mx.concatenate([self.experts[e](xt)[:, None] for e in it], axis=-1)
+            yt = (yt * st).sum(axis=-1)
+            y.append(yt[None, :])
+        y = mx.concatenate(y)

         return y.reshape(orig_shape)
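
The replacement routing is easier to follow in isolation: the gate produces one logit per expert, `argpartition` keeps the top `ne` experts per token, the selected logits are softmaxed into mixing weights, and a short per-token loop applies only those experts. Here is a self-contained toy version (arbitrary sizes, plain `nn.Linear` layers standing in for the Mixtral expert feed-forward blocks, not the repository's code):

```python
import mlx.core as mx
import mlx.nn as nn

dim, num_experts, ne, ntok = 16, 8, 2, 4
experts = [nn.Linear(dim, dim) for _ in range(num_experts)]
gate = nn.Linear(dim, num_experts, bias=False)

x = mx.random.normal((ntok, dim))  # one row per token
gates = gate(x)                    # (ntok, num_experts) router logits

# Top-ne experts per token and their normalized mixing weights.
inds = mx.argpartition(-gates, kth=ne, axis=-1)[:, :ne]
scores = mx.softmax(mx.take_along_axis(gates, inds, axis=-1), axis=-1)

y = []
for xt, st, it in zip(x, scores, inds.tolist()):
    # Run only the selected experts for this token, then mix by gate score.
    yt = mx.concatenate([experts[e](xt)[:, None] for e in it], axis=-1)
    y.append((yt * st).sum(axis=-1)[None, :])
y = mx.concatenate(y)

print(y.shape)  # (4, 16)
```

Each token only ever touches the `ne` experts chosen for it; in the real method the flattened result is then reshaped back to the original batch shape.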
@@ -300,18 +288,9 @@ if __name__ == "__main__":
     print(args.prompt, end="", flush=True)
     prompt = mx.array(tokenizer.encode(args.prompt))
-    mx.eval(prompt)
     tokens = []
-    import time
-    tic = time.time()
-    p = True
     for token, _ in zip(generate(prompt, model, args.temp), range(args.max_tokens)):
         tokens.append(token)
-        if p:
-            mx.eval(tokens)
-            p = False
-            prompt_time = time.time() - tic
-            tic = time.time()

         if (len(tokens) % 10) == 0:
             mx.eval(tokens)
@@ -320,9 +299,5 @@ if __name__ == "__main__":
             tokens = []

     mx.eval(tokens)
-    tok_time = time.time() - tic
     s = tokenizer.decode([t.item() for t in tokens])
     print(s, flush=True)
-    print("------")
-    print(f"Prompt time {prompt_time}")
-    print(f"Token time {tok_time}")