mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
Align CLI args and some smaller fixes (#167)
* Add `.DS_Store` files to `.gitignore` * Fix variable naming of `config` in `mixtral/convert.py` * Align CLI args and minor fixes * standardize * one more --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent
0eaa323c10
commit
f4709cb807
5
.gitignore
vendored
5
.gitignore
vendored
@ -127,5 +127,10 @@ dmypy.json
|
|||||||
|
|
||||||
# Pyre type checker
|
# Pyre type checker
|
||||||
.pyre/
|
.pyre/
|
||||||
|
|
||||||
|
# IDE files
|
||||||
.idea/
|
.idea/
|
||||||
.vscode/
|
.vscode/
|
||||||
|
|
||||||
|
# .DS_Store files
|
||||||
|
.DS_Store
|
||||||
|
@ -60,7 +60,7 @@ def llama(model_path):
|
|||||||
def tiny_llama(model_path):
|
def tiny_llama(model_path):
|
||||||
try:
|
try:
|
||||||
import transformers
|
import transformers
|
||||||
except ImportError as e:
|
except ImportError:
|
||||||
print("The transformers package must be installed for this model conversion:")
|
print("The transformers package must be installed for this model conversion:")
|
||||||
print("pip install transformers")
|
print("pip install transformers")
|
||||||
exit(0)
|
exit(0)
|
||||||
|
@ -5,7 +5,7 @@ import json
|
|||||||
import time
|
import time
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
@ -218,7 +218,6 @@ def toc(msg, start):
|
|||||||
|
|
||||||
|
|
||||||
def generate(args):
|
def generate(args):
|
||||||
|
|
||||||
input("Press enter to start generation")
|
input("Press enter to start generation")
|
||||||
print("------")
|
print("------")
|
||||||
print(args.prompt)
|
print(args.prompt)
|
||||||
@ -347,7 +346,7 @@ if __name__ == "__main__":
|
|||||||
parser = argparse.ArgumentParser(description="Llama inference script")
|
parser = argparse.ArgumentParser(description="Llama inference script")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--model-path",
|
"--model-path",
|
||||||
help="Path to the model directory containing the MLX weights",
|
help="Path to the model weights and tokenizer",
|
||||||
default="mlx_model",
|
default="mlx_model",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
@ -9,7 +9,6 @@ from typing import List, Optional, Tuple
|
|||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
import numpy as np
|
|
||||||
from mlx.utils import tree_map, tree_unflatten
|
from mlx.utils import tree_map, tree_unflatten
|
||||||
from sentencepiece import SentencePieceProcessor
|
from sentencepiece import SentencePieceProcessor
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user