mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
chore(mlx-lm): add adapter support in generate.py (#494)
* chore(mlx-lm): add adapter support in generate.py * chore: remove generate from lora.py and raise error to let user use mlx_lm.generate instead
This commit is contained in:
parent
ab0f1dd1b6
commit
13794a05da
@ -72,6 +72,17 @@ python -m mlx_lm.lora \
|
||||
--test
|
||||
```
|
||||
|
||||
### Generate
|
||||
|
||||
For generation use mlx_lm.generate:
|
||||
|
||||
```shell
|
||||
python -m mlx_lm.generate \
|
||||
--model <path_to_model> \
|
||||
--adapter-file <path_to_adapters.npz> \
|
||||
--prompt "<your_model_prompt>"
|
||||
```
|
||||
|
||||
## Fuse and Upload
|
||||
|
||||
You can generate a model fused with the low-rank adapters using the
|
||||
|
@ -23,6 +23,11 @@ def setup_arg_parser():
|
||||
default="mlx_model",
|
||||
help="The path to the local model directory or Hugging Face repo.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--adapter-file",
|
||||
type=str,
|
||||
help="Optional path for the trained adapter weights.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--trust-remote-code",
|
||||
action="store_true",
|
||||
@ -99,7 +104,9 @@ def main(args):
|
||||
if args.eos_token is not None:
|
||||
tokenizer_config["eos_token"] = args.eos_token
|
||||
|
||||
model, tokenizer = load(args.model, tokenizer_config=tokenizer_config)
|
||||
model, tokenizer = load(
|
||||
args.model, adapter_file=args.adapter_file, tokenizer_config=tokenizer_config
|
||||
)
|
||||
|
||||
if not args.ignore_chat_template and (
|
||||
hasattr(tokenizer, "apply_chat_template")
|
||||
|
@ -9,7 +9,7 @@ from mlx.utils import tree_flatten
|
||||
|
||||
from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
|
||||
from .tuner.utils import linear_to_lora_layers
|
||||
from .utils import generate, load
|
||||
from .utils import load
|
||||
|
||||
|
||||
def build_parser():
|
||||
@ -234,15 +234,8 @@ def run(args, training_callback: TrainingCallback = None):
|
||||
print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")
|
||||
|
||||
if args.prompt is not None:
|
||||
print("Generating")
|
||||
model.eval()
|
||||
generate(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
temp=args.temp,
|
||||
max_tokens=args.max_tokens,
|
||||
prompt=args.prompt,
|
||||
verbose=True,
|
||||
raise NotImplementedError(
|
||||
"Please use mlx_lm.generate with trained adapter for generation."
|
||||
)
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user