mixtral runs a bit faster
README.md
@@ -2,7 +2,8 @@
 
 An example of generating text with Mistral using MLX.
 
-Mistral 7B is one of the top large language models in its size class. It is also fully open source with a permissive license[^1].
+Mistral 7B is one of the top large language models in its size class. It is
+also fully open source with a permissive license[^1].
 
 ### Setup
 
@@ -25,6 +26,8 @@ Then, convert the weights with:
 python convert.py
 ```
 
+The conversion script will save the converted weights in the same location.
+
 ### Run
 
 Once you've converted the weights to MLX format, you can generate text with
@@ -36,4 +39,6 @@ python mistral.py --prompt "It is a truth universally acknowledged," --temp 0
 
 Run `python mistral.py --help` for more details.
 
-[^1]: Refer to the [blog post](https://mistral.ai/news/announcing-mistral-7b/) and [github repository](https://github.com/mistralai/mistral-src) for more details.
+[^1]: Refer to the [blog post](https://mistral.ai/news/announcing-mistral-7b/)
+and [github repository](https://github.com/mistralai/mistral-src) for more
+details.
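The line added in the second hunk states that conversion writes the MLX weights next to the torch checkpoint. Below is a minimal sketch of that round trip, assuming the default `mistral-7B-v0.1/` directory from `convert.py`; the tiny tensor and its key are stand-ins for the real checkpoint:

```python
# Hypothetical round-trip check: save float16 weights the way convert.py
# does, then load them back as MLX arrays the way mistral.py does.
from pathlib import Path

import mlx.core as mx
import numpy as np
import torch

model_path = Path("mistral-7B-v0.1/")
model_path.mkdir(exist_ok=True)

# Stand-in for the converted checkpoint (real keys come from torch.load).
state = {"tok_embeddings.weight": torch.ones(4, 8)}
np.savez(
    str(model_path / "weights.npz"),
    **{k: v.to(torch.float16).numpy() for k, v in state.items()},
)

# mistral.py reads the same file back with mx.load.
weights = mx.load(str(model_path / "weights.npz"))
assert weights["tok_embeddings.weight"].dtype == mx.float16
```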
convert.py
@@ -2,26 +2,23 @@
 
 import argparse
 import numpy as np
 from pathlib import Path
 import torch
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Convert Mistral weights to MLX.")
     parser.add_argument(
-        "--torch_model",
+        "--model_path",
         type=str,
-        default="mistral-7B-v0.1/consolidated.00.pth",
-        help="The path to the torch model weights",
-    )
-    parser.add_argument(
-        "--mlx_model",
-        type=str,
-        default="mistral-7B-v0.1/mlx_mistral_7b.npz",
-        help="The path to store the mlx model weights",
+        default="mistral-7B-v0.1/",
+        help="The path to the Mistral model. The MLX weights will also be saved there.",
     )
     args = parser.parse_args()
 
-    state = torch.load(args.torch_model)
+    model_path = Path(args.model_path)
+    state = torch.load(str(model_path / "consolidated.00.pth"))
     np.savez(
-        args.mlx_model, **{k: v.to(torch.float16).numpy() for k, v in state.items()}
+        str(model_path / "weights.npz"),
+        **{k: v.to(torch.float16).numpy() for k, v in state.items()}
     )
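The change collapses the two file flags (`--torch_model`, `--mlx_model`) into a single `--model_path` directory with fixed filenames, so the converted weights land beside the source checkpoint. A hedged way to exercise the new interface end to end; the one-tensor checkpoint is a fabricated fixture, and `convert.py` is assumed to be in the working directory:

```python
# Build a tiny fake checkpoint, run the new single-flag CLI, and check
# that weights.npz appears in the same directory as consolidated.00.pth.
import subprocess
from pathlib import Path

import numpy as np
import torch

model_path = Path("mistral-7B-v0.1")
model_path.mkdir(exist_ok=True)
torch.save({"norm.weight": torch.ones(8)}, model_path / "consolidated.00.pth")

subprocess.run(
    ["python", "convert.py", "--model_path", str(model_path)], check=True
)

with np.load(model_path / "weights.npz") as converted:
    assert converted["norm.weight"].dtype == np.float16
```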
mistral.py
@@ -196,7 +196,7 @@ def load_model(folder: str, dtype=mx.float16):
         config = json.loads(f.read())
         config.pop("sliding_window")
         model_args = ModelArgs(**config)
-    weights = mx.load(str(model_path / "mlx_mistral_7b.npz"))
+    weights = mx.load(str(model_path / "weights.npz"))
     weights = tree_unflatten(list(weights.items()))
     weights = tree_map(lambda p: p.astype(dtype), weights)
     model = Mistral(model_args)
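For context on the surrounding lines: `mx.load` returns a flat `{dotted.name: array}` mapping, which `tree_unflatten` and `tree_map` from `mlx.utils` turn into the nested, dtype-cast parameter tree the model expects. A small sketch with made-up keys:

```python
# How load_model's flat npz mapping becomes a nested parameter tree.
import mlx.core as mx
from mlx.utils import tree_map, tree_unflatten

flat = {
    "layers.0.attention.wq.weight": mx.zeros((4, 4)),
    "norm.weight": mx.ones((4,)),
}

# Dotted keys nest into dicts; numeric components become list indices.
weights = tree_unflatten(list(flat.items()))

# Cast every leaf, mirroring the tree_map call in the diff.
weights = tree_map(lambda p: p.astype(mx.float16), weights)

assert weights["norm"]["weight"].dtype == mx.float16
assert weights["layers"][0]["attention"]["wq"]["weight"].shape == (4, 4)
```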