Align CLI args and minor fixes

This commit is contained in:
Alvaro Bartolome
2023-12-21 09:49:45 +01:00
committed by Awni Hannun
parent bdc1f4d1f6
commit d05fa79284
4 changed files with 9 additions and 11 deletions

View File

@@ -60,7 +60,7 @@ def llama(model_path):
def tiny_llama(model_path):
try:
import transformers
except ImportError as e:
except ImportError:
print("The transformers package must be installed for this model conversion:")
print("pip install transformers")
exit(0)
@@ -154,7 +154,7 @@ if __name__ == "__main__":
help="Path to save the MLX model.",
)
parser.add_argument(
"--model-name",
"--model_name",
help=(
"Name of the model to convert. Use 'llama' for models in the "
"Llama family distributed by Meta including Llama 1, Llama 2, "

View File

@@ -5,7 +5,7 @@ import json
import time
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Tuple
from typing import Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
@@ -218,7 +218,6 @@ def toc(msg, start):
def generate(args):
input("Press enter to start generation")
print("------")
print(args.prompt)
@@ -347,7 +346,7 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Llama inference script")
parser.add_argument(
"--model-path",
help="Path to the model directory containing the MLX weights",
help="Path to the model directory containing the MLX weights and tokenizer",
default="mlx_model",
)
parser.add_argument(
@@ -356,14 +355,14 @@ if __name__ == "__main__":
default="In the beginning the Universe was created.",
)
parser.add_argument(
"--few-shot",
"--few_shot",
help="Read a few shot prompt from a file (as in `sample_prompt.txt`).",
)
parser.add_argument(
"--max-tokens", "-m", type=int, default=100, help="How many tokens to generate"
"--max_tokens", "-m", type=int, default=100, help="How many tokens to generate"
)
parser.add_argument(
"--write-every", type=int, default=1, help="After how many tokens to detokenize"
"--write_every", type=int, default=1, help="After how many tokens to detokenize"
)
parser.add_argument(
"--temp", type=float, default=0.0, help="The sampling temperature"

View File

@@ -9,7 +9,6 @@ from typing import List, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from mlx.utils import tree_map, tree_unflatten
from sentencepiece import SentencePieceProcessor

View File

@@ -176,7 +176,7 @@ def load_model(model_path: str):
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Phi-2 inference script")
parser.add_argument(
"--model-path",
"--model_path",
type=str,
default="mlx_model",
help="The path to the model weights",
@@ -187,7 +187,7 @@ if __name__ == "__main__":
default="Write a detailed analogy between mathematics and a lighthouse.",
)
parser.add_argument(
"--max-tokens",
"--max_tokens",
"-m",
type=int,
default=100,