diff --git a/mistral/README.md b/mistral/README.md
index 1bbb385d..c2406b6d 100644
--- a/mistral/README.md
+++ b/mistral/README.md
@@ -2,7 +2,8 @@
 
 An example of generating text with Mistral using MLX.
 
-Mistral 7B is one of the top large language models in its size class. It is also fully open source with a permissive license[^1].
+Mistral 7B is one of the top large language models in its size class. It is
+also fully open source with a permissive license[^1].
 
 ### Setup
 
@@ -25,6 +26,8 @@ Then, convert the weights with:
 python convert.py
 ```
 
+The conversion script will save the converted weights in the same location.
+
 ### Run
 
 Once you've converted the weights to MLX format, you can generate text with
@@ -36,4 +39,6 @@ python mistral.py --prompt "It is a truth universally acknowledged," --temp 0
 
 Run `python mistral.py --help` for more details.
 
-[^1]: Refer to the [blog post](https://mistral.ai/news/announcing-mistral-7b/) and [github repository](https://github.com/mistralai/mistral-src) for more details.
+[^1]: Refer to the [blog post](https://mistral.ai/news/announcing-mistral-7b/)
+and [github repository](https://github.com/mistralai/mistral-src) for more
+details.
diff --git a/mistral/convert.py b/mistral/convert.py
index e170aed7..0efaf489 100644
--- a/mistral/convert.py
+++ b/mistral/convert.py
@@ -2,26 +2,23 @@
 
 import argparse
 
 import numpy as np
+from pathlib import Path
 import torch
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Convert Mistral weights to MLX.")
     parser.add_argument(
-        "--torch_model",
+        "--model_path",
         type=str,
-        default="mistral-7B-v0.1/consolidated.00.pth",
-        help="The path to the torch model weights",
-    )
-    parser.add_argument(
-        "--mlx_model",
-        type=str,
-        default="mistral-7B-v0.1/mlx_mistral_7b.npz",
-        help="The path to store the mlx model weights",
+        default="mistral-7B-v0.1/",
+        help="The path to the Mistral model. The MLX weights will also be saved there.",
     )
     args = parser.parse_args()
 
-    state = torch.load(args.torch_model)
+    model_path = Path(args.model_path)
+    state = torch.load(str(model_path / "consolidated.00.pth"))
     np.savez(
-        args.mlx_model, **{k: v.to(torch.float16).numpy() for k, v in state.items()}
+        str(model_path / "weights.npz"),
+        **{k: v.to(torch.float16).numpy() for k, v in state.items()}
     )
diff --git a/mistral/mistral.py b/mistral/mistral.py
index 767b5936..0c3976c1 100644
--- a/mistral/mistral.py
+++ b/mistral/mistral.py
@@ -196,7 +196,7 @@ def load_model(folder: str, dtype=mx.float16):
         config = json.loads(f.read())
         config.pop("sliding_window")
         model_args = ModelArgs(**config)
-    weights = mx.load(str(model_path / "mlx_mistral_7b.npz"))
+    weights = mx.load(str(model_path / "weights.npz"))
     weights = tree_unflatten(list(weights.items()))
     weights = tree_map(lambda p: p.astype(dtype), weights)
     model = Mistral(model_args)
diff --git a/mixtral/README.md b/mixtral/README.md
index 6891f7b9..a3af2b66 100644
--- a/mixtral/README.md
+++ b/mixtral/README.md
@@ -33,12 +33,12 @@ Now from `mlx-examples/mixtral` convert the weights to NumPy so MLX can read them:
 python convert.py --model_path mixtral-8x7b-32kseqlen/
 ```
 
-The conversion script will save the new weights in the same location.
+The conversion script will save the converted weights in the same location.
 
 After that's done, if you want to clean some stuff up:
 
 ```
-rm mixtral-8x7b-32kseqlen/*.pth
+rm mixtral-8x7b-32kseqlen/*.pth*
 ```
 
 ### Generate
diff --git a/mixtral/mixtral.py b/mixtral/mixtral.py
index a96d5f7d..1a9be600 100644
--- a/mixtral/mixtral.py
+++ b/mixtral/mixtral.py
@@ -128,23 +128,12 @@ class MOEFeedForward(nn.Module):
         inds = mx.argpartition(-gates, kth=ne, axis=-1)[:, :ne]
         scores = mx.softmax(mx.take_along_axis(gates, inds, axis=-1), axis=-1)
 
-        # For batch:
-        if x.shape[0] > 1:
-            mx.eval(inds)
-            inds = np.array(inds)
-            y = mx.zeros((x.shape[0], ne, x.shape[-1]))
-            for e, expert in enumerate(self.experts):
-                idx1, idx2 = map(mx.array, np.where(inds == e))
-                if idx1.size == 0:
-                    continue
-                y[idx1, idx2] = expert(x[idx1])
-            y = (y * scores[:, :, None]).sum(axis=1)
-
-        # For single:
-        else:
-            ys = [self.experts[e](x)[:, :, None] for e in inds.squeeze().tolist()]
-            y = mx.concatenate(ys, axis=-1)
-            y = (y * scores[:, None, 0]).sum(axis=-1)
+        y = []
+        for xt, st, it in zip(x, scores, inds.tolist()):
+            yt = mx.concatenate([self.experts[e](xt)[:, None] for e in it], axis=-1)
+            yt = (yt * st).sum(axis=-1)
+            y.append(yt[None, :])
+        y = mx.concatenate(y)
 
         return y.reshape(orig_shape)
 
@@ -300,18 +289,9 @@
     print(args.prompt, end="", flush=True)
     prompt = mx.array(tokenizer.encode(args.prompt))
-    mx.eval(prompt)
 
     tokens = []
-    import time
-    tic = time.time()
-    p = True
     for token, _ in zip(generate(prompt, model, args.temp), range(args.max_tokens)):
         tokens.append(token)
-        if p:
-            mx.eval(tokens)
-            p = False
-            prompt_time = time.time() - tic
-            tic = time.time()
 
         if (len(tokens) % 10) == 0:
             mx.eval(tokens)
@@ -320,9 +300,5 @@
             tokens = []
 
     mx.eval(tokens)
-    tok_time = time.time() - tic
     s = tokenizer.decode([t.item() for t in tokens])
     print(s, flush=True)
-    print("------")
-    print(f"Prompt time {prompt_time}")
-    print(f"Token time {tok_time}")
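The `MOEFeedForward` rewrite is the least obvious change above: the old batched/single special cases are replaced by one per-token loop that runs only each token's top-`ne` experts and mixes their outputs with the softmax-normalized gate scores. Here is a self-contained NumPy sketch of that routing math; the random matrices and the `moe_feed_forward` name are illustrative stand-ins, not the repo's API.

```python
import numpy as np

rng = np.random.default_rng(0)
dim, num_experts, ne = 16, 8, 2  # ne = experts used per token
# Stand-in experts: plain matrices instead of the real FeedForward modules.
experts = [rng.standard_normal((dim, dim)) * 0.1 for _ in range(num_experts)]
gate = rng.standard_normal((dim, num_experts)) * 0.1

def moe_feed_forward(x):
    # x: (tokens, dim). Pick each token's top-ne experts by gate score.
    gates = x @ gate                                         # (tokens, num_experts)
    inds = np.argpartition(-gates, kth=ne, axis=-1)[:, :ne]  # top-ne expert ids
    picked = np.take_along_axis(gates, inds, axis=-1)
    scores = np.exp(picked) / np.exp(picked).sum(-1, keepdims=True)  # softmax

    # Same loop structure as the new __call__: one token at a time,
    # evaluate only its chosen experts, then mix with the scores.
    y = []
    for xt, st, it in zip(x, scores, inds.tolist()):
        yt = np.stack([xt @ experts[e] for e in it], axis=-1)  # (dim, ne)
        y.append((yt * st).sum(axis=-1))
    return np.stack(y)

print(moe_feed_forward(rng.standard_normal((4, dim))).shape)  # (4, 16)
```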
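Separately, for the `mistral/convert.py` change: the script now writes `weights.npz` next to the original checkpoint. A minimal, hypothetical sanity check of that output (assuming the default `mistral-7B-v0.1/` path) could look like:

```python
import numpy as np

# Load the archive written by convert.py and confirm every tensor
# was cast to float16 before being saved.
weights = np.load("mistral-7B-v0.1/weights.npz")
for name in list(weights.files)[:5]:
    print(name, weights[name].dtype, weights[name].shape)
assert all(weights[name].dtype == np.float16 for name in weights.files)
```

This only verifies dtypes and shapes; `mistral.py` loads the same file via `mx.load` at generation time.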