chore(mlx-lm): add adapter support in generate.py (#494)

* chore(mlx-lm): add adapter support in generate.py

* chore: remove generate from lora.py and raise an error directing users to mlx_lm.generate instead
Author: Anchen
Date: 2024-02-29 02:49:25 +11:00 (committed by GitHub)
parent ab0f1dd1b6
commit 13794a05da
3 changed files with 22 additions and 11 deletions

llms/mlx_lm/LORA.md

@@ -72,6 +72,17 @@ python -m mlx_lm.lora \
     --test
 ```
 
+### Generate
+
+For generation use `mlx_lm.generate`:
+
+```shell
+python -m mlx_lm.generate \
+    --model <path_to_model> \
+    --adapter-file <path_to_adapters.npz> \
+    --prompt "<your_model_prompt>"
+```
+
 ## Fuse and Upload
 
 You can generate a model fused with the low-rank adapters using the
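The same flow is available from the Python API. A minimal sketch, assuming the `load` and `generate` helpers exported by `mlx_lm` take the keyword arguments visible elsewhere in this diff; the model path, adapter path, and prompt below are placeholders:

```python
# Sketch only: Python-API equivalent of the CLI call above. Assumes
# mlx_lm exports load() and generate() with the keyword arguments
# used in the generate.py and lora.py hunks of this diff.
from mlx_lm import load, generate

model, tokenizer = load(
    "mlx_model",                  # local path or Hugging Face repo
    adapter_file="adapters.npz",  # trained LoRA adapter weights
)
generate(
    model=model,
    tokenizer=tokenizer,
    prompt="<your_model_prompt>",
    max_tokens=100,
    verbose=True,
)
```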

llms/mlx_lm/generate.py

@@ -23,6 +23,11 @@ def setup_arg_parser():
         default="mlx_model",
         help="The path to the local model directory or Hugging Face repo.",
     )
+    parser.add_argument(
+        "--adapter-file",
+        type=str,
+        help="Optional path for the trained adapter weights.",
+    )
     parser.add_argument(
         "--trust-remote-code",
         action="store_true",
@@ -99,7 +104,9 @@ def main(args):
     if args.eos_token is not None:
         tokenizer_config["eos_token"] = args.eos_token
-    model, tokenizer = load(args.model, tokenizer_config=tokenizer_config)
+    model, tokenizer = load(
+        args.model, adapter_file=args.adapter_file, tokenizer_config=tokenizer_config
+    )
 
     if not args.ignore_chat_template and (
         hasattr(tokenizer, "apply_chat_template")
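What `adapter_file` does inside `load` is not shown in this diff. An illustrative sketch only, not the actual `mlx_lm.utils` implementation: the idea is to swap linear layers for LoRA layers (via the `linear_to_lora_layers` helper imported in `lora.py` below) and then load the trained low-rank weights on top of the frozen base model. The `lora_layers=16` default is an assumption mirroring `mlx_lm.lora`:

```python
# Illustrative sketch of applying trained adapter weights to a model;
# the real mlx_lm.utils.load() may differ in detail.
from mlx_lm.tuner.utils import linear_to_lora_layers

def load_adapter(model, adapter_file: str, lora_layers: int = 16):
    # Swap the last `lora_layers` linear projections for LoRA layers so
    # the parameter names in adapters.npz line up with the model's.
    linear_to_lora_layers(model, lora_layers)
    # strict=False: adapters.npz holds only the low-rank adapter weights,
    # not the full set of base-model parameters.
    model.load_weights(adapter_file, strict=False)
    model.eval()
    return model
```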

llms/mlx_lm/lora.py

@@ -9,7 +9,7 @@ from mlx.utils import tree_flatten
 
 from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
 from .tuner.utils import linear_to_lora_layers
-from .utils import generate, load
+from .utils import load
 
 
 def build_parser():
@@ -234,15 +234,8 @@ def run(args, training_callback: TrainingCallback = None):
         print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")
 
     if args.prompt is not None:
-        print("Generating")
-        model.eval()
-        generate(
-            model=model,
-            tokenizer=tokenizer,
-            temp=args.temp,
-            max_tokens=args.max_tokens,
-            prompt=args.prompt,
-            verbose=True,
-        )
+        raise NotImplementedError(
+            "Please use mlx_lm.generate with trained adapter for generation."
+        )
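With generation removed from `lora.py`, the end-to-end workflow becomes two explicit steps. A sketch assuming the default `adapters.npz` output name and the `--train`/`--data` flags documented in LORA.md; all paths are placeholders:

```shell
# Step 1: fine-tune; the trained adapter weights are saved to adapters.npz.
python -m mlx_lm.lora \
    --model <path_to_model> \
    --train \
    --data <path_to_data>

# Step 2: generate with the trained adapter, as the new error message directs.
python -m mlx_lm.generate \
    --model <path_to_model> \
    --adapter-file adapters.npz \
    --prompt "<your_model_prompt>"
```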