From 13794a05da6b4066552abcb25cad44329d96036b Mon Sep 17 00:00:00 2001
From: Anchen
Date: Thu, 29 Feb 2024 02:49:25 +1100
Subject: [PATCH] chore(mlx-lm): add adapter support in generate.py (#494)
* chore(mlx-lm): add adapter support in generate.py
* chore: remove generate from lora.py and raise an error directing users to mlx_lm.generate instead
---
llms/mlx_lm/LORA.md | 11 +++++++++++
llms/mlx_lm/generate.py | 9 ++++++++-
llms/mlx_lm/lora.py | 13 +++----------
3 files changed, 22 insertions(+), 11 deletions(-)
diff --git a/llms/mlx_lm/LORA.md b/llms/mlx_lm/LORA.md
index 445b929a..74ec07f6 100644
--- a/llms/mlx_lm/LORA.md
+++ b/llms/mlx_lm/LORA.md
@@ -72,6 +72,17 @@ python -m mlx_lm.lora \
--test
```
+### Generate
+
+For generation, use `mlx_lm.generate`:
+
+```shell
+python -m mlx_lm.generate \
+    --model <path_to_model> \
+    --adapter-file <path_to_adapters.npz> \
+    --prompt "<your_model_prompt>"
+```
+
## Fuse and Upload
You can generate a model fused with the low-rank adapters using the
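As an aside (not part of the patch itself), the same flow is available through the Python API: `load` with the `adapter_file` keyword this patch threads through generate.py, followed by `generate`, whose keywords mirror the call removed from lora.py below. A minimal sketch, with placeholder model path, adapter file, and prompt:

```python
# Minimal sketch: Python-API equivalent of the CLI command documented above.
# load()'s adapter_file keyword and generate()'s keywords are taken from the
# hunks in this patch; the paths and prompt are placeholders.
from mlx_lm.utils import generate, load

model, tokenizer = load("mlx_model", adapter_file="adapters.npz")
generate(
    model,
    tokenizer,
    prompt="<your prompt here>",
    temp=0.0,
    max_tokens=100,
    verbose=True,
)
```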
diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py
index 58ecde05..394c8e15 100644
--- a/llms/mlx_lm/generate.py
+++ b/llms/mlx_lm/generate.py
@@ -23,6 +23,11 @@ def setup_arg_parser():
default="mlx_model",
help="The path to the local model directory or Hugging Face repo.",
)
+ parser.add_argument(
+ "--adapter-file",
+ type=str,
+ help="Optional path for the trained adapter weights.",
+ )
parser.add_argument(
"--trust-remote-code",
action="store_true",
@@ -99,7 +104,9 @@ def main(args):
if args.eos_token is not None:
tokenizer_config["eos_token"] = args.eos_token
- model, tokenizer = load(args.model, tokenizer_config=tokenizer_config)
+ model, tokenizer = load(
+ args.model, adapter_file=args.adapter_file, tokenizer_config=tokenizer_config
+ )
if not args.ignore_chat_template and (
hasattr(tokenizer, "apply_chat_template")
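For context, a rough idea of what the `adapter_file` path through `load` plausibly does behind the scenes. This is a hypothetical sketch, not the repo's actual `utils.load`: `load_with_adapter` is an invented name, and the body is assembled from `linear_to_lora_layers` (the helper lora.py imports in the hunk below) and the standard `mlx.nn.Module.load_weights` method:

```python
# Hypothetical sketch only; not the actual mlx_lm.utils.load implementation.
from mlx_lm.tuner.utils import linear_to_lora_layers


def load_with_adapter(model, adapter_file, lora_layers=16):
    # Wrap the top `lora_layers` blocks' linear projections in LoRA layers
    # so the trained adapter parameters have matching module paths.
    linear_to_lora_layers(model, lora_layers)
    # Load only the low-rank adapter weights; strict=False because the file
    # does not contain the base model's parameters.
    model.load_weights(adapter_file, strict=False)
    model.eval()  # inference mode
    return model
```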
diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py
index e433bbd6..d316efe4 100644
--- a/llms/mlx_lm/lora.py
+++ b/llms/mlx_lm/lora.py
@@ -9,7 +9,7 @@ from mlx.utils import tree_flatten
from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
from .tuner.utils import linear_to_lora_layers
-from .utils import generate, load
+from .utils import load
def build_parser():
@@ -234,15 +234,8 @@ def run(args, training_callback: TrainingCallback = None):
print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")
if args.prompt is not None:
- print("Generating")
- model.eval()
- generate(
- model=model,
- tokenizer=tokenizer,
- temp=args.temp,
- max_tokens=args.max_tokens,
- prompt=args.prompt,
- verbose=True,
+ raise NotImplementedError(
+ "Please use mlx_lm.generate with trained adapter for generation."
)
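Since `lora.py --prompt` now raises instead of sampling, scripts that relied on in-process generation after training need to call the standalone entry point. One way to do that from Python, as a sketch; the model path and adapter file name are placeholders, and the flags are the ones generate.py defines in this patch:

```python
# Sketch: drive the new CLI entry point from a script after a training run.
# Paths are placeholders; --model, --adapter-file, and --prompt are the
# generate.py flags shown above.
import subprocess

subprocess.run(
    [
        "python", "-m", "mlx_lm.generate",
        "--model", "mlx_model",
        "--adapter-file", "adapters.npz",
        "--prompt", "Write a haiku about low-rank adapters.",
    ],
    check=True,
)
```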