Mirror of https://github.com/ml-explore/mlx-examples.git
Support --add_eos_token argument within LoRA training (#760)

* Support the `--add_eos_token` argument so users can control whether the EOS token is appended during LoRA training, addressing issues such as incomplete text generation.
* Support `--add_eos_token`, code formatting

Co-authored-by: Zhan ChengLong <zhanchenglong@bytedance.com>
This commit is contained in:
parent 10853b57d9
commit 1a86d985d9

Changed: lora/lora.py (14 lines)
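For context, a minimal sketch of what the flag ultimately controls, assuming a Llama-style tokenizer; the checkpoint id and the exact token ids below are illustrative placeholders, not part of this commit. With add_eos_token=True, encode() appends the EOS token, so each training example ends with an explicit stop signal:

from transformers import AutoTokenizer

# Hypothetical checkpoint; any tokenizer that supports the
# add_eos_token option (e.g. LlamaTokenizer) behaves this way.
path = "some-org/some-llama-model"
tok_plain = AutoTokenizer.from_pretrained(path)
tok_eos = AutoTokenizer.from_pretrained(path, add_eos_token=True)

ids_plain = tok_plain.encode("Hello world")  # e.g. [1, 15043, 3186]       (no EOS)
ids_eos = tok_eos.encode("Hello world")      # e.g. [1, 15043, 3186, 2]    (EOS appended)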
lora/lora.py
@@ -47,6 +47,12 @@ def build_parser():
         action="store_true",
         help="Do training",
     )
+    parser.add_argument(
+        "--add-eos-token",
+        type=int,
+        default=1,
+        help="Enable add_eos_token for tokenizer",
+    )
     parser.add_argument(
         "--data",
         type=str,
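A standalone sketch of the argparse pattern above, using the same names as the diff; note that argparse maps --add-eos-token to the attribute args.add_eos_token:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--add-eos-token",
    type=int,
    default=1,
    help="Enable add_eos_token for tokenizer",
)

# Dashes in the flag become underscores on the parsed namespace.
args = parser.parse_args(["--add-eos-token", "0"])
print(bool(args.add_eos_token))  # False; omitting the flag keeps the default 1 -> True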
@@ -317,9 +323,13 @@ if __name__ == "__main__":
 
     np.random.seed(args.seed)
 
-    print("Loading pretrained model")
-    model, tokenizer, _ = lora_utils.load(args.model)
+    # Building tokenizer_config
+    tokenizer_config = {}
+    if args.train:
+        tokenizer_config["add_eos_token"] = bool(args.add_eos_token)
+    print("Loading pretrained model")
+    model, tokenizer, _ = lora_utils.load(args.model, tokenizer_config)
 
     # Freeze all layers other than LORA linears
     model.freeze()
     for l in model.model.layers[len(model.model.layers) - args.lora_layers :]:
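A self-contained sketch of the config-building step; load_stub below is a hypothetical stand-in for lora_utils.load so the snippet runs on its own:

def load_stub(path_or_hf_repo, tokenizer_config={}):
    # Stand-in: the real lora_utils.load returns (model, tokenizer, config).
    return None, tokenizer_config, None

class Args:  # stand-in for the parsed CLI namespace
    train = True
    add_eos_token = 0
    model = "path/to/model"

args = Args()

# tokenizer_config is only populated when training, so other entry
# points (generation, fusing) keep the tokenizer's default behavior.
tokenizer_config = {}
if args.train:
    tokenizer_config["add_eos_token"] = bool(args.add_eos_token)

model, tokenizer, _ = load_stub(args.model, tokenizer_config)
print(tokenizer)  # {'add_eos_token': False}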
lora/utils.py
@@ -122,7 +122,7 @@ def save_model(save_dir: str, weights, tokenizer, config):
     )
 
 
-def load(path_or_hf_repo: str):
+def load(path_or_hf_repo: str, tokenizer_config={}):
     # If the path exists, it will try to load model form it
     # otherwise download and cache from the hf_repo and cache
     model_path = Path(path_or_hf_repo)
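One design note on the new signature: a mutable default like tokenizer_config={} is shared across calls in Python, which is safe here only because the dict is read, never mutated. A tiny illustration (load_sig is a hypothetical name, not from the diff):

def load_sig(path_or_hf_repo, tokenizer_config={}):
    # Safe: the default dict is only read. If it were mutated, state
    # would leak between calls; the usual defensive idiom is
    # tokenizer_config=None plus `tokenizer_config = tokenizer_config or {}`.
    return dict(tokenizer_config)

print(load_sig("a"))                            # {}
print(load_sig("b", {"add_eos_token": True}))   # {'add_eos_token': True}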
@@ -162,7 +162,9 @@ def load(path_or_hf_repo: str):
     model.load_weights(list(weights.items()))
 
     mx.eval(model.parameters())
-    tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)
+    tokenizer = transformers.AutoTokenizer.from_pretrained(
+        model_path, **tokenizer_config
+    )
     return model, tokenizer, config
 
 
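Finally, a hedged usage sketch of the updated load path (the path below is a placeholder, not a real checkpoint): the entries of tokenizer_config flow straight into AutoTokenizer.from_pretrained as keyword arguments, and passing an empty dict reproduces the old behavior exactly, which is why non-training entry points are unaffected.

import transformers

tokenizer_config = {"add_eos_token": True}
tokenizer = transformers.AutoTokenizer.from_pretrained(
    "path/or/hf-repo",  # placeholder: local path or Hugging Face repo id
    **tokenizer_config,
)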