Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-10-24 06:28:07 +08:00
Support --add_eos_token argument in LoRA training (#760)
* Support an `--add_eos_token` argument so users can control whether the eos token is appended during LoRA training, addressing issues such as incomplete text generation.
* Code formatting for the `--add_eos_token` change.

Co-authored-by: Zhan ChengLong <zhanchenglong@bytedance.com>
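Since the new `--add-eos-token` flag is an integer defaulting to `1`, eos appending stays enabled unless explicitly turned off. A hypothetical invocation (the model and data paths are placeholders, not values from this commit) would look like `python lora.py --model <model-path> --train --data <data-dir> --add-eos-token 0` to disable it.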
Showing 1 changed file: lora/lora.py (14 lines changed)
```diff
@@ -47,6 +47,12 @@ def build_parser():
         action="store_true",
         help="Do training",
     )
+    parser.add_argument(
+        "--add-eos-token",
+        type=int,
+        default=1,
+        help="Enable add_eos_token for tokenizer",
+    )
     parser.add_argument(
         "--data",
         type=str,
@@ -317,9 +323,13 @@ if __name__ == "__main__":
 
     np.random.seed(args.seed)
 
-    print("Loading pretrained model")
-    model, tokenizer, _ = lora_utils.load(args.model)
+    # Building tokenizer_config
+    tokenizer_config = {}
+    if args.train:
+        tokenizer_config["add_eos_token"] = bool(args.add_eos_token)
+
+    print("Loading pretrained model")
+    model, tokenizer, _ = lora_utils.load(args.model, tokenizer_config)
     # Freeze all layers other than LORA linears
     model.freeze()
     for l in model.model.layers[len(model.model.layers) - args.lora_layers :]:
```
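For readers wondering where `tokenizer_config` ends up: the `lora_utils.load` internals are not part of this diff, so the following is only a minimal sketch of the common pattern, assuming a Hugging Face tokenizer whose constructor accepts `add_eos_token` (as Llama-family tokenizers do). The function name and signature below are illustrative, not the repo's actual code.

```python
# Minimal sketch, not the actual lora_utils implementation (which this
# diff does not show). Assumes a Hugging Face tokenizer that accepts an
# add_eos_token option, e.g. LlamaTokenizer.
from transformers import AutoTokenizer


def load_tokenizer(model_path: str, tokenizer_config: dict | None = None):
    tokenizer_config = tokenizer_config or {}
    # from_pretrained forwards extra keyword arguments to the tokenizer's
    # constructor, so add_eos_token=True makes encoding append the eos id.
    return AutoTokenizer.from_pretrained(model_path, **tokenizer_config)
```

When `add_eos_token` is enabled, each training example terminates with the eos id, which is what teaches the fine-tuned model where to stop generating, the failure mode the commit message describes.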