This commit is contained in:
Awni Hannun 2023-12-05 11:24:30 -08:00
parent b7840a4721
commit 234a5f5cfe
2 changed files with 12 additions and 6 deletions

View File

@ -12,7 +12,7 @@ Install the dependencies:
pip install -r requirements.txt pip install -r requirements.txt
``` ```
Next, download the model and tokenizer. Next, download the model and tokenizer:
``` ```
curl -O https://files.mistral-7b-v0-1.mistral.ai/mistral-7B-v0.1.tar curl -O https://files.mistral-7b-v0-1.mistral.ai/mistral-7B-v0.1.tar
@ -22,16 +22,16 @@ tar -xf mistral-7B-v0.1.tar
Then, convert the weights with: Then, convert the weights with:
``` ```
python convert.py <path_to_torch_weights> mlx_mistral_weights.npz python convert.py
``` ```
### Run ### Run
Once you've converted the weights to MLX format, you can interact with the Once you've converted the weights to MLX format, you can generate text with
Mistral model: the Mistral model:
``` ```
python mistral.py mlx_mistral.npz tokenizer.model "hello" python mistral.py --prompt "It is a truth universally acknowledged," --temp 0
``` ```
Run `python mistral.py --help` for more details. Run `python mistral.py --help` for more details.

View File

@ -253,6 +253,12 @@ if __name__ == "__main__":
default=100, default=100,
help="Maximum number of tokens to generate", help="Maximum number of tokens to generate",
) )
parser.add_argument(
"--temp",
help="The sampling temperature.",
type=float,
default=1.0,
)
parser.add_argument("--seed", type=int, default=0, help="The PRNG seed") parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
args = parser.parse_args() args = parser.parse_args()
@ -266,7 +272,7 @@ if __name__ == "__main__":
print(args.prompt, end="", flush=True) print(args.prompt, end="", flush=True)
prompt = mx.array(tokenizer.encode(args.prompt)) prompt = mx.array(tokenizer.encode(args.prompt))
tokens = [] tokens = []
for token, _ in zip(generate(prompt, model), range(args.max_tokens)): for token, _ in zip(generate(prompt, model, args.temp), range(args.max_tokens)):
tokens.append(token) tokens.append(token)
if (len(tokens) % 10) == 0: if (len(tokens) % 10) == 0: