Align CLI args and some smaller fixes (#167)

* Add `.DS_Store` files to `.gitignore`

* Fix variable naming of `config` in `mixtral/convert.py`

* Align CLI args and minor fixes

* standardize

* one more

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Alvaro Bartolome 2023-12-22 23:34:32 +01:00 committed by GitHub
parent 0eaa323c10
commit f4709cb807
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 8 additions and 5 deletions

5
.gitignore vendored
View File

@ -127,5 +127,10 @@ dmypy.json
# Pyre type checker
.pyre/
# IDE files
.idea/
.vscode/
# .DS_Store files
.DS_Store

View File

@ -60,7 +60,7 @@ def llama(model_path):
def tiny_llama(model_path):
try:
import transformers
except ImportError as e:
except ImportError:
print("The transformers package must be installed for this model conversion:")
print("pip install transformers")
exit(0)

View File

@ -5,7 +5,7 @@ import json
import time
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Tuple
from typing import Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
@ -218,7 +218,6 @@ def toc(msg, start):
def generate(args):
input("Press enter to start generation")
print("------")
print(args.prompt)
@ -347,7 +346,7 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Llama inference script")
parser.add_argument(
"--model-path",
help="Path to the model directory containing the MLX weights",
help="Path to the model weights and tokenizer",
default="mlx_model",
)
parser.add_argument(

View File

@ -9,7 +9,6 @@ from typing import List, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from mlx.utils import tree_map, tree_unflatten
from sentencepiece import SentencePieceProcessor