mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-31 20:04:38 +08:00
Align CLI args and minor fixes
This commit is contained in:

committed by
Awni Hannun

parent
bdc1f4d1f6
commit
d05fa79284
@@ -60,7 +60,7 @@ def llama(model_path):
|
||||
def tiny_llama(model_path):
|
||||
try:
|
||||
import transformers
|
||||
except ImportError as e:
|
||||
except ImportError:
|
||||
print("The transformers package must be installed for this model conversion:")
|
||||
print("pip install transformers")
|
||||
exit(0)
|
||||
@@ -154,7 +154,7 @@ if __name__ == "__main__":
|
||||
help="Path to save the MLX model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model-name",
|
||||
"--model_name",
|
||||
help=(
|
||||
"Name of the model to convert. Use 'llama' for models in the "
|
||||
"Llama family distributed by Meta including Llama 1, Llama 2, "
|
||||
|
@@ -5,7 +5,7 @@ import json
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
@@ -218,7 +218,6 @@ def toc(msg, start):
|
||||
|
||||
|
||||
def generate(args):
|
||||
|
||||
input("Press enter to start generation")
|
||||
print("------")
|
||||
print(args.prompt)
|
||||
@@ -347,7 +346,7 @@ if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Llama inference script")
|
||||
parser.add_argument(
|
||||
"--model-path",
|
||||
help="Path to the model directory containing the MLX weights",
|
||||
help="Path to the model directory containing the MLX weights and tokenizer",
|
||||
default="mlx_model",
|
||||
)
|
||||
parser.add_argument(
|
||||
@@ -356,14 +355,14 @@ if __name__ == "__main__":
|
||||
default="In the beginning the Universe was created.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--few-shot",
|
||||
"--few_shot",
|
||||
help="Read a few shot prompt from a file (as in `sample_prompt.txt`).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-tokens", "-m", type=int, default=100, help="How many tokens to generate"
|
||||
"--max_tokens", "-m", type=int, default=100, help="How many tokens to generate"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--write-every", type=int, default=1, help="After how many tokens to detokenize"
|
||||
"--write_every", type=int, default=1, help="After how many tokens to detokenize"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--temp", type=float, default=0.0, help="The sampling temperature"
|
||||
|
@@ -9,7 +9,6 @@ from typing import List, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
from mlx.utils import tree_map, tree_unflatten
|
||||
from sentencepiece import SentencePieceProcessor
|
||||
|
||||
|
@@ -176,7 +176,7 @@ def load_model(model_path: str):
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Phi-2 inference script")
|
||||
parser.add_argument(
|
||||
"--model-path",
|
||||
"--model_path",
|
||||
type=str,
|
||||
default="mlx_model",
|
||||
help="The path to the model weights",
|
||||
@@ -187,7 +187,7 @@ if __name__ == "__main__":
|
||||
default="Write a detailed analogy between mathematics and a lighthouse.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-tokens",
|
||||
"--max_tokens",
|
||||
"-m",
|
||||
type=int,
|
||||
default=100,
|
||||
|
Reference in New Issue
Block a user