mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
nits
This commit is contained in:
parent
b7840a4721
commit
234a5f5cfe
@ -12,7 +12,7 @@ Install the dependencies:
|
|||||||
pip install -r requirements.txt
|
pip install -r requirements.txt
|
||||||
```
|
```
|
||||||
|
|
||||||
Next, download the model and tokenizer.
|
Next, download the model and tokenizer:
|
||||||
|
|
||||||
```
|
```
|
||||||
curl -O https://files.mistral-7b-v0-1.mistral.ai/mistral-7B-v0.1.tar
|
curl -O https://files.mistral-7b-v0-1.mistral.ai/mistral-7B-v0.1.tar
|
||||||
@ -22,16 +22,16 @@ tar -xf mistral-7B-v0.1.tar
|
|||||||
Then, convert the weights with:
|
Then, convert the weights with:
|
||||||
|
|
||||||
```
|
```
|
||||||
python convert.py <path_to_torch_weights> mlx_mistral_weights.npz
|
python convert.py
|
||||||
```
|
```
|
||||||
|
|
||||||
### Run
|
### Run
|
||||||
|
|
||||||
Once you've converted the weights to MLX format, you can interact with the
|
Once you've converted the weights to MLX format, you can generate text with
|
||||||
Mistral model:
|
the Mistral model:
|
||||||
|
|
||||||
```
|
```
|
||||||
python mistral.py mlx_mistral.npz tokenizer.model "hello"
|
python mistral.py --prompt "It is a truth universally acknowledged," --temp 0
|
||||||
```
|
```
|
||||||
|
|
||||||
Run `python mistral.py --help` for more details.
|
Run `python mistral.py --help` for more details.
|
||||||
|
@ -253,6 +253,12 @@ if __name__ == "__main__":
|
|||||||
default=100,
|
default=100,
|
||||||
help="Maximum number of tokens to generate",
|
help="Maximum number of tokens to generate",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--temp",
|
||||||
|
help="The sampling temperature.",
|
||||||
|
type=float,
|
||||||
|
default=1.0,
|
||||||
|
)
|
||||||
parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
|
parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
@ -266,7 +272,7 @@ if __name__ == "__main__":
|
|||||||
print(args.prompt, end="", flush=True)
|
print(args.prompt, end="", flush=True)
|
||||||
prompt = mx.array(tokenizer.encode(args.prompt))
|
prompt = mx.array(tokenizer.encode(args.prompt))
|
||||||
tokens = []
|
tokens = []
|
||||||
for token, _ in zip(generate(prompt, model), range(args.max_tokens)):
|
for token, _ in zip(generate(prompt, model, args.temp), range(args.max_tokens)):
|
||||||
tokens.append(token)
|
tokens.append(token)
|
||||||
|
|
||||||
if (len(tokens) % 10) == 0:
|
if (len(tokens) % 10) == 0:
|
||||||
|
Loading…
Reference in New Issue
Block a user