diff --git a/lora/lora.py b/lora/lora.py index eb672996..3f1f085f 100644 --- a/lora/lora.py +++ b/lora/lora.py @@ -263,9 +263,9 @@ def train(model, train_set, val_set, optimizer, loss, tokenizer, args): def generate(model, prompt, tokenizer, args): - print(args.prompt, end="", flush=True) + print(prompt, end="", flush=True) - prompt = mx.array(tokenizer.encode(args.prompt)) + prompt = mx.array(tokenizer.encode(prompt)) tokens = [] skip = 0