Support --add_eos_token argument within Lora training (#760)

* Support `--add_eos_token` argument to empower users to control the addition of the eos token during LoRA training, addressing issues like incomplete text generation.

* Support `--add_eos_token`, code format

---------

Co-authored-by: Zhan ChengLong <zhanchenglong@bytedance.com>
This commit is contained in:
Jinwu Zhan 2024-05-14 08:17:42 +08:00 committed by GitHub
parent 10853b57d9
commit 1a86d985d9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 16 additions and 4 deletions

View File

@ -47,6 +47,12 @@ def build_parser():
action="store_true",
help="Do training",
)
parser.add_argument(
"--add-eos-token",
type=int,
default=1,
help="Enable add_eos_token for tokenizer",
)
parser.add_argument(
"--data",
type=str,
@ -317,9 +323,13 @@ if __name__ == "__main__":
np.random.seed(args.seed)
print("Loading pretrained model")
model, tokenizer, _ = lora_utils.load(args.model)
# Building tokenizer_config
tokenizer_config = {}
if args.train:
tokenizer_config["add_eos_token"] = bool(args.add_eos_token)
print("Loading pretrained model")
model, tokenizer, _ = lora_utils.load(args.model, tokenizer_config)
# Freeze all layers other than LORA linears
model.freeze()
for l in model.model.layers[len(model.model.layers) - args.lora_layers :]:

View File

@ -122,7 +122,7 @@ def save_model(save_dir: str, weights, tokenizer, config):
)
def load(path_or_hf_repo: str):
def load(path_or_hf_repo: str, tokenizer_config={}):
# If the path exists, it will try to load the model from it;
# otherwise download from the hf_repo and cache it
model_path = Path(path_or_hf_repo)
@ -162,7 +162,9 @@ def load(path_or_hf_repo: str):
model.load_weights(list(weights.items()))
mx.eval(model.parameters())
tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)
tokenizer = transformers.AutoTokenizer.from_pretrained(
model_path, **tokenizer_config
)
return model, tokenizer, config