diff --git a/lora/lora.py b/lora/lora.py
index 60b698a9..a90eda70 100644
--- a/lora/lora.py
+++ b/lora/lora.py
@@ -47,6 +47,12 @@ def build_parser():
         action="store_true",
         help="Do training",
     )
+    parser.add_argument(
+        "--add-eos-token",
+        type=int,
+        default=1,
+        help="Enable add_eos_token for tokenizer",
+    )
     parser.add_argument(
         "--data",
         type=str,
@@ -317,9 +323,13 @@ if __name__ == "__main__":
 
     np.random.seed(args.seed)
 
-    print("Loading pretrained model")
-    model, tokenizer, _ = lora_utils.load(args.model)
+    # Building tokenizer_config
+    tokenizer_config = {}
+    if args.train:
+        tokenizer_config["add_eos_token"] = bool(args.add_eos_token)
 
+    print("Loading pretrained model")
+    model, tokenizer, _ = lora_utils.load(args.model, tokenizer_config)
     # Freeze all layers other than LORA linears
     model.freeze()
     for l in model.model.layers[len(model.model.layers) - args.lora_layers :]:
diff --git a/lora/utils.py b/lora/utils.py
index 4026409f..a334723c 100644
--- a/lora/utils.py
+++ b/lora/utils.py
@@ -122,7 +122,7 @@ def save_model(save_dir: str, weights, tokenizer, config):
     )
 
 
-def load(path_or_hf_repo: str):
+def load(path_or_hf_repo: str, tokenizer_config={}):
     # If the path exists, it will try to load model form it
     # otherwise download and cache from the hf_repo and cache
     model_path = Path(path_or_hf_repo)
@@ -162,7 +162,9 @@ def load(path_or_hf_repo: str):
     model.load_weights(list(weights.items()))
 
     mx.eval(model.parameters())
 
-    tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)
+    tokenizer = transformers.AutoTokenizer.from_pretrained(
+        model_path, **tokenizer_config
+    )
 
     return model, tokenizer, config
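
With this change, `add_eos_token` is forwarded to `AutoTokenizer.from_pretrained` only when `--train` is set, so inference keeps the tokenizer's default behavior. A minimal usage sketch of the updated `load` signature; the import alias and the "mlx_model" path are placeholders, mirroring how `lora.py` calls `lora_utils.load`:

    import utils as lora_utils  # lora/utils.py, aliased as in lora.py (assumed import)

    # Training path: ask the tokenizer to append EOS to every encoded sample.
    tokenizer_config = {"add_eos_token": True}
    model, tokenizer, config = lora_utils.load("mlx_model", tokenizer_config)

    # Inference path: with no config passed, the tokenizer behaves exactly as before.
    model, tokenizer, config = lora_utils.load("mlx_model")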