prompt parameter (#291)

This commit is contained in:
Anjor Kanekar 2024-01-11 14:04:57 +00:00 committed by GitHub
parent 7380ebfb0d
commit e74889d0fa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -263,9 +263,9 @@ def train(model, train_set, val_set, optimizer, loss, tokenizer, args):
def generate(model, prompt, tokenizer, args):
print(args.prompt, end="", flush=True)
print(prompt, end="", flush=True)
prompt = mx.array(tokenizer.encode(args.prompt))
prompt = mx.array(tokenizer.encode(prompt))
tokens = []
skip = 0