feat: add mistral tps (#173)

* feat: add mistral tps

* eval params before timing + format

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Todsaporn Banjerdkit
2023-12-22 22:55:57 +07:00
committed by GitHub
parent 188a91074b
commit 7ae445f6c7
4 changed files with 22 additions and 9 deletions

View File

@@ -209,10 +209,13 @@ def iterate_batches(dset, tokenizer, batch_size, train=False):
for j in range(batch_size)
]
lengths = [len(x) for x in batch]
# Check if any sequence is longer than 2048 tokens
if max(lengths) > 2048:
print("Warning: Some sequences are longer than 2048 tokens. Consider pre-splitting your data to save memory.")
print(
"[WARNING] Some sequences are longer than 2048 tokens. "
"Consider pre-splitting your data to save memory."
)
# Pad to the max length
batch_arr = np.zeros((batch_size, max(lengths)), np.int32)