[Lora] Fix generate (#282)

* fix generate

* update readme, fix test, better default

* nits

* typo
This commit is contained in:
Awni Hannun
2024-01-10 16:13:06 -08:00
committed by GitHub
parent a2bc8426f2
commit 80d18671ad
5 changed files with 25 additions and 16 deletions

View File

@@ -110,7 +110,7 @@ For generation use:
```
python lora.py --model <path_to_model> \
--adapter-file <path_to_adapters.npz> \
--num-tokens 50 \
--max-tokens 50 \
--prompt "table: 1-10015132-16
columns: Player, No., Nationality, Position, Years in Toronto, School/Club Team
Q: What is terrence ross' nationality

View File

@@ -265,7 +265,7 @@ def train(model, train_set, val_set, optimizer, loss, tokenizer, args):
def generate(model, prompt, tokenizer, args):
print(args.prompt, end="", flush=True)
prompt = tokenizer.encode(args.prompt)
prompt = mx.array(tokenizer.encode(args.prompt))
tokens = []
skip = 0