mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-04 15:54:34 +08:00
address comments
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
## Generate Text in MLX
|
||||
|
||||
This is an example of Llama style large language model text generation.
|
||||
This is an example of Llama style large language model text generation that can
|
||||
pull models from the Hugging Face Hub.
|
||||
|
||||
### Setup
|
||||
|
||||
@@ -16,10 +17,10 @@ pip install -r requirements.txt
|
||||
python generate.py --model <model_path> --prompt "hello"
|
||||
```
|
||||
|
||||
The `<model_path>` should be either a path to a local directory with an MLX
|
||||
formatted model, or a Hugging Face repo. If the latter, then the model will
|
||||
be downloaded and cached the first time you use it. See the [#Models] section
|
||||
for a full list of supported models.
|
||||
The `<model_path>` should be either a path to a local directory or a Hugging
|
||||
Face repo with weights stored in `safetensors` format. If you use a repo from
|
||||
the Hugging Face hub, then the model will be downloaded and cached the first
|
||||
time you run it. See the [Models](#models) section for a full list of supported models.
|
||||
|
||||
Run `python generate.py --help` to see all the options.
|
||||
|
||||
|
@@ -1,17 +1,13 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
|
||||
import argparse
|
||||
import collections
|
||||
import copy
|
||||
import glob
|
||||
import json
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
import torch
|
||||
import transformers
|
||||
from huggingface_hub import snapshot_download
|
||||
from mlx.utils import tree_flatten, tree_map
|
||||
@@ -25,7 +21,7 @@ def fetch_from_hub(hf_path: str):
|
||||
)
|
||||
weight_files = glob.glob(f"{model_path}/*.safetensors")
|
||||
if len(weight_files) == 0:
|
||||
raise FileNotFoundError("No weights found in {}".format(model_path))
|
||||
raise FileNotFoundError("No safetensors found in {}".format(model_path))
|
||||
|
||||
weights = {}
|
||||
for wf in weight_files:
|
||||
|
@@ -220,7 +220,7 @@ def load(path_or_hf_repo: str):
|
||||
|
||||
weight_files = glob.glob(str(model_path / "*.safetensors"))
|
||||
if len(weight_files) == 0:
|
||||
raise FileNotFoundError("No weights found in {}".format(model_path))
|
||||
raise FileNotFoundError("No safetensors found in {}".format(model_path))
|
||||
|
||||
weights = {}
|
||||
for wf in weight_files:
|
||||
|
@@ -1,4 +1,3 @@
|
||||
mlx>=0.0.7
|
||||
numpy
|
||||
torch
|
||||
transformers
|
||||
|
Reference in New Issue
Block a user