feat: add mistral tps (#173)

* feat: add mistral tps

* eval params before timing + format

---------

Co-authored-by: Awni Hannun <awni@apple.com>
Author: Todsaporn Banjerdkit, 2023-12-22 22:55:57 +07:00 (committed by GitHub)
parent 188a91074b
commit 7ae445f6c7
4 changed files with 22 additions and 9 deletions

File 1 of 4

@@ -2,6 +2,7 @@
 import argparse
 import json
+import time
 from dataclasses import dataclass
 from pathlib import Path
 from typing import List, Optional, Tuple
@@ -204,6 +205,7 @@ def load_model(folder: str):
     if quantization is not None:
         nn.QuantizedLinear.quantize_module(model, **quantization)
     model.update(weights)
+    mx.eval(model.parameters())
     return model, tokenizer
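
The mx.eval(model.parameters()) added here is what the commit message calls "eval params before timing": MLX evaluates arrays lazily, so without it the weights would only be materialized inside the first timed generation step, and the prompt tokens-per-second number below would absorb the model-loading cost. A minimal standalone sketch of the idea, using a toy nn.Linear as a stand-in for the real model (illustrative only, not the repo's code):

import time

import mlx.core as mx
import mlx.nn as nn

# Toy stand-in for the real model; the real script loads Mistral weights instead.
model = nn.Linear(4096, 4096)

# MLX is lazy: force the parameters to materialize before starting the clock,
# so the timed region measures compute rather than weight loading.
mx.eval(model.parameters())

tic = time.time()
y = model(mx.zeros((1, 4096)))  # the work we actually want to time
mx.eval(y)                      # force evaluation before reading the clock
toc = time.time()
print(f"forward pass: {toc - tic:.4f}s")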
@@ -265,12 +267,17 @@ if __name__ == "__main__":
     model, tokenizer = load_model(args.model_path)
     print("[INFO] Starting generation...")
+    tic = time.time()
     print(args.prompt, end="", flush=True)
     prompt = mx.array(tokenizer.encode(args.prompt))
     tokens = []
-    for token, _ in zip(generate(prompt, model, args.temp), range(args.max_tokens)):
+    for token, ntoks in zip(generate(prompt, model, args.temp), range(args.max_tokens)):
         tokens.append(token)
+        if ntoks == 0:
+            toc = time.time()
+            mx.eval(tokens)
+            prompt_tps = prompt.size / (toc - tic)
+            tic = time.time()
         if (len(tokens) % args.tokens_per_eval) == 0:
             mx.eval(tokens)
@@ -282,3 +289,8 @@ if __name__ == "__main__":
     s = tokenizer.decode([t.item() for t in tokens])
     print(s, flush=True)
     print("------")
+    generation_tps = ntoks / (time.time() - tic)
+    print(
+        f"Tokens per second: prompt {prompt_tps:.3f}, "
+        f"generation {generation_tps:.3f}"
+    )
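
Taken together, the hunks above split the timing into two phases: the clock starts before the prompt is encoded and fed through the model, stops when the first token comes back (prompt tokens per second), then restarts for the remaining tokens (generation tokens per second). A rough self-contained sketch of that measurement pattern, with a dummy generator standing in for the script's generate() and a fake 1 ms per-token cost (names and numbers here are illustrative assumptions, not the repo's API):

import time

import mlx.core as mx

def fake_generate(prompt):
    # Hypothetical stand-in for generate(prompt, model, temp): yields one token at a time.
    while True:
        time.sleep(0.001)  # pretend each token costs ~1 ms
        yield mx.array(0)

prompt = mx.array([1] * 32)  # pretend prompt of 32 tokens
max_tokens = 16

tic = time.time()
tokens = []
for token, ntoks in zip(fake_generate(prompt), range(max_tokens)):
    tokens.append(token)
    if ntoks == 0:
        # First token produced: everything so far counts as prompt processing.
        toc = time.time()
        mx.eval(tokens)
        prompt_tps = prompt.size / (toc - tic)
        tic = time.time()  # restart the clock for the generation phase

mx.eval(tokens)
generation_tps = ntoks / (time.time() - tic)
print(f"Tokens per second: prompt {prompt_tps:.3f}, generation {generation_tps:.3f}")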

File 2 of 4

@@ -212,7 +212,10 @@ def iterate_batches(dset, tokenizer, batch_size, train=False):
             # Check if any sequence is longer than 2048 tokens
             if max(lengths) > 2048:
-                print("Warning: Some sequences are longer than 2048 tokens. Consider pre-splitting your data to save memory.")
+                print(
+                    "[WARNING] Some sequences are longer than 2048 tokens. "
+                    "Consider pre-splitting your data to save memory."
+                )
             # Pad to the max length
             batch_arr = np.zeros((batch_size, max(lengths)), np.int32)

File 3 of 4

@@ -1,6 +1,6 @@
 import argparse
 
-from transformers import AutoTokenizer, T5EncoderModel, AutoModelForSeq2SeqLM
+from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, T5EncoderModel
 
 
 def embed(t5_model: str):

File 4 of 4

@@ -6,7 +6,7 @@ import mlx.core as mx
 import mlx.nn as nn
 import numpy as np
 from mlx.utils import tree_map, tree_unflatten
-from transformers import T5Config, AutoTokenizer
+from transformers import AutoTokenizer, T5Config
 
 
 def _relative_position_bucket(
@@ -252,9 +252,7 @@ class TransformerDecoder(nn.Module):
     def __init__(self, config: T5Config):
         super().__init__()
         n_layers = getattr(config, "num_decoder_layers", config.num_layers)
-        self.layers = [
-            TransformerDecoderLayer(config) for i in range(n_layers)
-        ]
+        self.layers = [TransformerDecoderLayer(config) for i in range(n_layers)]
         self.ln = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
         self.relative_attention_bias = RelativePositionBias(config, bidirectional=False)