address comments

This commit is contained in:
Awni Hannun
2024-01-03 11:33:41 -08:00
parent 9bb3b4bd77
commit 384ad5792e
4 changed files with 8 additions and 12 deletions

View File

@@ -1,6 +1,7 @@
## Generate Text in MLX
This an example of Llama style large language model text generation.
This an example of Llama style large language model text generation that can
pull models from the Hugging Face Hub.
### Setup
@@ -16,10 +17,10 @@ pip install -r requirements.txt
python generate.py --model <model_path> --prompt "hello"
```
The `<model_path>` should be either a path to a local directory with an MLX
formatted model, or a Hugging Face repo. If the latter, then the model will
be downloaded and cached the first time you use it. See the [#Models] section
for a full list of supported models.
The `<model_path>` should be either a path to a local directory or a Hugging
Face repo with weights stored in `safetensors` format. If you use a repo from
the Hugging Face hub, then the model will be downloaded and cached the first
time you run it. See the [#Models] section for a full list of supported models.
Run `python generate.py --help` to see all the options.

View File

@@ -1,17 +1,13 @@
# Copyright © 2023 Apple Inc.
import argparse
import collections
import copy
import glob
import json
import shutil
from pathlib import Path
import mlx.core as mx
import mlx.nn as nn
import numpy as np
import torch
import transformers
from huggingface_hub import snapshot_download
from mlx.utils import tree_flatten, tree_map
@@ -25,7 +21,7 @@ def fetch_from_hub(hf_path: str):
)
weight_files = glob.glob(f"{model_path}/*.safetensors")
if len(weight_files) == 0:
raise FileNotFoundError("No weights found in {}".format(model_path))
raise FileNotFoundError("No safetensors found in {}".format(model_path))
weights = {}
for wf in weight_files:

View File

@@ -220,7 +220,7 @@ def load(path_or_hf_repo: str):
weight_files = glob.glob(str(model_path / "*.safetensors"))
if len(weight_files) == 0:
raise FileNotFoundError("No weights found in {}".format(model_path))
raise FileNotFoundError("No safetensors found in {}".format(model_path))
weights = {}
for wf in weight_files:

View File

@@ -1,4 +1,3 @@
mlx>=0.0.7
numpy
torch
transformers