mirror of https://github.com/ml-explore/mlx-examples.git
chore(mlx-lm): add adapter support in generate.py (#494)

* chore(mlx-lm): add adapter support in generate.py
* chore: remove generate from lora.py and raise an error directing users to mlx_lm.generate instead
parent ab0f1dd1b6
commit 13794a05da
@@ -72,6 +72,17 @@ python -m mlx_lm.lora \
     --test
 ```
 
+### Generate
+
+For generation use mlx_lm.generate:
+
+```shell
+python -m mlx_lm.generate \
+    --model <path_to_model> \
+    --adapter-file <path_to_adapters.npz> \
+    --prompt "<your_model_prompt>"
+```
+
 ## Fuse and Upload
 
 You can generate a model fused with the low-rank adapters using the
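For reference, the same adapter-aware generation can be driven from Python. This is a minimal sketch, not part of the commit: the `load(..., adapter_file=...)` keyword and the `generate(...)` keyword arguments mirror the code changed below, while the model path, adapter filename, prompt, and sampling settings are placeholder assumptions.

```python
# Minimal sketch: programmatic equivalent of the mlx_lm.generate CLI call above.
# The adapter_file keyword on load() and generate()'s keyword arguments are taken
# from this diff; the paths, prompt, and sampling values are placeholders.
from mlx_lm.utils import generate, load

model, tokenizer = load(
    "mlx_model",                   # local model directory or Hugging Face repo
    adapter_file="adapters.npz",   # trained LoRA adapter weights
)

generate(
    model=model,
    tokenizer=tokenizer,
    prompt="<your_model_prompt>",
    temp=0.7,
    max_tokens=100,
    verbose=True,
)
```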
@@ -23,6 +23,11 @@ def setup_arg_parser():
         default="mlx_model",
         help="The path to the local model directory or Hugging Face repo.",
     )
+    parser.add_argument(
+        "--adapter-file",
+        type=str,
+        help="Optional path for the trained adapter weights.",
+    )
     parser.add_argument(
         "--trust-remote-code",
         action="store_true",
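Because no default is supplied, argparse leaves `args.adapter_file` as `None` when the flag is omitted, so generation without an adapter keeps working. A small standalone sketch of that behavior (not the full setup_arg_parser from generate.py):

```python
# Standalone illustration of the new flag's optional behavior; not the actual
# setup_arg_parser from generate.py.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--adapter-file",
    type=str,
    help="Optional path for the trained adapter weights.",
)

print(parser.parse_args([]).adapter_file)                                  # None: no adapter requested
print(parser.parse_args(["--adapter-file", "adapters.npz"]).adapter_file)  # adapters.npz
```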
@@ -99,7 +104,9 @@ def main(args):
     if args.eos_token is not None:
         tokenizer_config["eos_token"] = args.eos_token
 
-    model, tokenizer = load(args.model, tokenizer_config=tokenizer_config)
+    model, tokenizer = load(
+        args.model, adapter_file=args.adapter_file, tokenizer_config=tokenizer_config
+    )
 
     if not args.ignore_chat_template and (
         hasattr(tokenizer, "apply_chat_template")
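The adapter handling itself lives inside mlx_lm.utils.load, which this diff does not show. As a rough, assumption-labeled sketch of what honoring `adapter_file` can look like: convert the targeted Linear layers to LoRA layers (linear_to_lora_layers is imported in lora.py below) and load the saved adapter weights on top of the base weights. The helper name, the linear_to_lora_layers signature, and the layer count here are illustrative only, not the library's actual code.

```python
# Rough sketch only -- NOT the real mlx_lm.utils.load implementation.
# Assumptions: the linear_to_lora_layers(model, num_layers) signature and the
# default of 16 LoRA layers are illustrative; load_weights(strict=False) merges
# just the adapter parameters over the already-loaded base weights.
from mlx_lm.tuner.utils import linear_to_lora_layers

def apply_adapter(model, adapter_file, lora_layers=16):
    if adapter_file is not None:
        linear_to_lora_layers(model, lora_layers)       # wrap Linears with LoRA layers
        model.load_weights(adapter_file, strict=False)  # load only the adapter weights
    return model
```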
@@ -9,7 +9,7 @@ from mlx.utils import tree_flatten
 
 from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
 from .tuner.utils import linear_to_lora_layers
-from .utils import generate, load
+from .utils import load
 
 
 def build_parser():
@@ -234,15 +234,8 @@ def run(args, training_callback: TrainingCallback = None):
         print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")
 
     if args.prompt is not None:
-        print("Generating")
-        model.eval()
-        generate(
-            model=model,
-            tokenizer=tokenizer,
-            temp=args.temp,
-            max_tokens=args.max_tokens,
-            prompt=args.prompt,
-            verbose=True,
+        raise NotImplementedError(
+            "Please use mlx_lm.generate with trained adapter for generation."
         )
 
 