diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py
index def3b6dd..d32bfe6d 100644
--- a/llms/mlx_lm/lora.py
+++ b/llms/mlx_lm/lora.py
@@ -62,6 +62,7 @@ CONFIG_DEFAULTS = {
     "grad_checkpoint": False,
     "lr_schedule": None,
     "lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
+    "mask_prompt": False,
 }


@@ -99,7 +100,7 @@ def build_parser():
         "--mask-prompt",
         action="store_true",
         help="Mask the prompt in the loss when training",
-        default=False,
+        default=None,
     )
     parser.add_argument(
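
Note on the change: moving the real default into CONFIG_DEFAULTS and setting the argparse default to None lets "flag not given on the CLI" be distinguished from "flag explicitly given", so a config-file value is not silently overridden by the parser's built-in default. The snippet below is a minimal sketch of that merge pattern, not the verbatim lora.py code; the merge_config helper name and the config dict passed to it are illustrative assumptions.

import argparse

# Built-in fallbacks, mirroring the role of CONFIG_DEFAULTS in lora.py.
CONFIG_DEFAULTS = {
    "mask_prompt": False,
}


def merge_config(args: argparse.Namespace, config: dict) -> argparse.Namespace:
    """Fill in any option the user did not pass on the CLI (still None)."""
    for key, default in CONFIG_DEFAULTS.items():
        if getattr(args, key, None) is None:
            # A config-file value wins over the built-in default, but an
            # explicit CLI flag (which makes the attribute non-None) wins
            # over both.
            setattr(args, key, config.get(key, default))
    return args


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # With default=None, an omitted flag stays None instead of becoming False.
    parser.add_argument("--mask-prompt", action="store_true", default=None)
    args = parser.parse_args([])  # no flags passed
    merged = merge_config(args, {"mask_prompt": True})
    print(merged.mask_prompt)  # True: the config value is honored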