[Lora] Fix generate (#282)

* fix generate

* update readme, fix test, better default

* nits

* typo
This commit is contained in:
Awni Hannun
2024-01-10 16:13:06 -08:00
committed by GitHub
parent a2bc8426f2
commit 80d18671ad
5 changed files with 25 additions and 16 deletions

View File

@@ -265,7 +265,7 @@ def train(model, train_set, val_set, optimizer, loss, tokenizer, args):
def generate(model, prompt, tokenizer, args):
print(args.prompt, end="", flush=True)
prompt = tokenizer.encode(args.prompt)
prompt = mx.array(tokenizer.encode(args.prompt))
tokens = []
skip = 0