diff --git a/mistral/README.md b/mistral/README.md
index eaf07ea5..1bbb385d 100644
--- a/mistral/README.md
+++ b/mistral/README.md
@@ -12,7 +12,7 @@ Install the dependencies:
 pip install -r requirements.txt
 ```
 
-Next, download the model and tokenizer.
+Next, download the model and tokenizer:
 
 ```
 curl -O https://files.mistral-7b-v0-1.mistral.ai/mistral-7B-v0.1.tar
@@ -22,16 +22,16 @@ tar -xf mistral-7B-v0.1.tar
 Then, convert the weights with:
 
 ```
-python convert.py mlx_mistral_weights.npz
+python convert.py
 ```
 
 ### Run
 
-Once you've converted the weights to MLX format, you can interact with the
-Mistral model:
+Once you've converted the weights to MLX format, you can generate text with
+the Mistral model:
 
 ```
-python mistral.py mlx_mistral.npz tokenizer.model "hello"
+python mistral.py --prompt "It is a truth universally acknowledged," --temp 0
 ```
 
 Run `python mistral.py --help` for more details.
diff --git a/mistral/mistral.py b/mistral/mistral.py
index 57a35585..e846ebfd 100644
--- a/mistral/mistral.py
+++ b/mistral/mistral.py
@@ -253,6 +253,12 @@ if __name__ == "__main__":
         default=100,
         help="Maximum number of tokens to generate",
     )
+    parser.add_argument(
+        "--temp",
+        help="The sampling temperature.",
+        type=float,
+        default=1.0,
+    )
     parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
 
     args = parser.parse_args()
@@ -266,7 +272,7 @@ if __name__ == "__main__":
     print(args.prompt, end="", flush=True)
     prompt = mx.array(tokenizer.encode(args.prompt))
    tokens = []
-    for token, _ in zip(generate(prompt, model), range(args.max_tokens)):
+    for token, _ in zip(generate(prompt, model, args.temp), range(args.max_tokens)):
         tokens.append(token)
 
         if (len(tokens) % 10) == 0:
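
The patch threads `args.temp` through to `generate`, but the sampling step
itself sits outside the diff context. For reference, a minimal sketch of what
a temperature-scaled sampler in MLX might look like — the `sample` helper name
and its use inside `generate` are assumptions, not part of this patch:

```python
import mlx.core as mx

def sample(logits: mx.array, temp: float) -> mx.array:
    # Hypothetical helper, not shown in this patch.
    # With --temp 0, fall back to greedy decoding (argmax), which makes runs
    # deterministic, as in the README example above.
    if temp == 0:
        return mx.argmax(logits, axis=-1)
    # Otherwise scale the logits by 1/temp and draw a token from the
    # resulting categorical distribution.
    return mx.random.categorical(logits * (1 / temp))
```

Lower temperatures concentrate probability mass on the most likely tokens,
while the default of 1.0 samples from the model's unscaled distribution.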