diff --git a/llms/mlx_lm/LORA.md b/llms/mlx_lm/LORA.md
index 445b929a..74ec07f6 100644
--- a/llms/mlx_lm/LORA.md
+++ b/llms/mlx_lm/LORA.md
@@ -72,6 +72,17 @@ python -m mlx_lm.lora \
     --test
 ```
 
+### Generate
+
+For generation, use `mlx_lm.generate`:
+
+```shell
+python -m mlx_lm.generate \
+    --model <path_to_model> \
+    --adapter-file <path_to_adapters.npz> \
+    --prompt "<your_model_prompt>"
+```
+
 ## Fuse and Upload
 
 You can generate a model fused with the low-rank adapters using the
diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py
index 58ecde05..394c8e15 100644
--- a/llms/mlx_lm/generate.py
+++ b/llms/mlx_lm/generate.py
@@ -23,6 +23,11 @@ def setup_arg_parser():
         default="mlx_model",
         help="The path to the local model directory or Hugging Face repo.",
     )
+    parser.add_argument(
+        "--adapter-file",
+        type=str,
+        help="Optional path for the trained adapter weights.",
+    )
     parser.add_argument(
         "--trust-remote-code",
         action="store_true",
@@ -99,7 +104,9 @@ def main(args):
     if args.eos_token is not None:
         tokenizer_config["eos_token"] = args.eos_token
 
-    model, tokenizer = load(args.model, tokenizer_config=tokenizer_config)
+    model, tokenizer = load(
+        args.model, adapter_file=args.adapter_file, tokenizer_config=tokenizer_config
+    )
 
     if not args.ignore_chat_template and (
         hasattr(tokenizer, "apply_chat_template")
diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py
index e433bbd6..d316efe4 100644
--- a/llms/mlx_lm/lora.py
+++ b/llms/mlx_lm/lora.py
@@ -9,7 +9,7 @@ from mlx.utils import tree_flatten
 
 from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
 from .tuner.utils import linear_to_lora_layers
-from .utils import generate, load
+from .utils import load
 
 
 def build_parser():
@@ -234,15 +234,8 @@ def run(args, training_callback: TrainingCallback = None):
         print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")
 
     if args.prompt is not None:
-        print("Generating")
-        model.eval()
-        generate(
-            model=model,
-            tokenizer=tokenizer,
-            temp=args.temp,
-            max_tokens=args.max_tokens,
-            prompt=args.prompt,
-            verbose=True,
+        raise NotImplementedError(
+            "Please use mlx_lm.generate with the trained adapter for generation."
         )
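
For context, a minimal sketch of the programmatic flow this patch enables. It reuses the `adapter_file` keyword added to the `load` call in `generate.py` and the exact keyword arguments from the `generate` call removed from `lora.py`; the model path, adapter path, prompt, and sampling values below are placeholders, not values from the patch.

```python
# Hypothetical end-to-end sketch: load a model with trained LoRA adapter
# weights and generate from it, mirroring what mlx_lm.generate now does.
from mlx_lm.utils import generate, load

# `adapter_file` is the keyword this patch passes through to load().
model, tokenizer = load(
    "mlx_model",                  # placeholder: local directory or HF repo
    adapter_file="adapters.npz",  # placeholder: trained adapter weights
)

model.eval()  # as the removed lora.py code did before generating

# Same keyword arguments the removed lora.py code used.
generate(
    model=model,
    tokenizer=tokenizer,
    temp=0.7,        # placeholder sampling temperature
    max_tokens=100,  # placeholder token budget
    prompt="<your_model_prompt>",
    verbose=True,
)
```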