From 1a86d985d96053a8456285ff93b845b962290280 Mon Sep 17 00:00:00 2001
From: Jinwu Zhan
Date: Tue, 14 May 2024 08:17:42 +0800
Subject: [PATCH] Support `--add_eos_token` argument within Lora training
 (#760)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Support `--add_eos_token` argument to empower users to control the
  addition of the eos token during LoRA training, addressing issues like
  incomplete text generation.

* Support `--add_eos_token`, code format

---------

Co-authored-by: Zhan ChengLong
---
 lora/lora.py  | 14 ++++++++++++--
 lora/utils.py |  6 ++++--
 2 files changed, 16 insertions(+), 4 deletions(-)

diff --git a/lora/lora.py b/lora/lora.py
index 60b698a9..a90eda70 100644
--- a/lora/lora.py
+++ b/lora/lora.py
@@ -47,6 +47,12 @@ def build_parser():
         action="store_true",
         help="Do training",
     )
+    parser.add_argument(
+        "--add-eos-token",
+        type=int,
+        default=1,
+        help="Enable add_eos_token for tokenizer",
+    )
     parser.add_argument(
         "--data",
         type=str,
@@ -317,9 +323,13 @@ if __name__ == "__main__":
 
     np.random.seed(args.seed)
 
-    print("Loading pretrained model")
-    model, tokenizer, _ = lora_utils.load(args.model)
+    # Building tokenizer_config
+    tokenizer_config = {}
+    if args.train:
+        tokenizer_config["add_eos_token"] = bool(args.add_eos_token)
 
+    print("Loading pretrained model")
+    model, tokenizer, _ = lora_utils.load(args.model, tokenizer_config)
     # Freeze all layers other than LORA linears
     model.freeze()
     for l in model.model.layers[len(model.model.layers) - args.lora_layers :]:
diff --git a/lora/utils.py b/lora/utils.py
index 4026409f..a334723c 100644
--- a/lora/utils.py
+++ b/lora/utils.py
@@ -122,7 +122,7 @@ def save_model(save_dir: str, weights, tokenizer, config):
     )
 
 
-def load(path_or_hf_repo: str):
+def load(path_or_hf_repo: str, tokenizer_config={}):
     # If the path exists, it will try to load model form it
     # otherwise download and cache from the hf_repo and cache
     model_path = Path(path_or_hf_repo)
@@ -162,7 +162,9 @@ def load(path_or_hf_repo: str):
     model.load_weights(list(weights.items()))
     mx.eval(model.parameters())
 
-    tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)
+    tokenizer = transformers.AutoTokenizer.from_pretrained(
+        model_path, **tokenizer_config
+    )
 
     return model, tokenizer, config
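
A minimal usage sketch of the new flag, assuming the `lora/` directory
layout from the diff; `<path-or-hf-repo>` and `./data` are placeholders,
not values from the commit:

    # Default behavior (--add-eos-token 1): append the EOS token to
    # each training example so generation learns to terminate.
    python lora.py --model <path-or-hf-repo> --data ./data --train

    # Disable EOS-token addition, e.g. when the dataset already
    # terminates each example itself.
    python lora.py --model <path-or-hf-repo> --data ./data --train --add-eos-token 0

Note that `tokenizer_config` is only populated when `--train` is set, so
the flag has no effect on plain generation or evaluation runs. The value
is forwarded as a keyword argument to
`transformers.AutoTokenizer.from_pretrained`, where `add_eos_token` is
honored by SentencePiece-based tokenizers such as LLaMA's and Mistral's.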