feat: add mistral tps (#173)

* feat: add mistral tps

* eval params before timing + format

---------

Co-authored-by: Awni Hannun <awni@apple.com>
Todsaporn Banjerdkit authored on 2023-12-22 22:55:57 +07:00 (committed by GitHub)
parent 188a91074b
commit 7ae445f6c7
4 changed files with 22 additions and 9 deletions


@@ -2,6 +2,7 @@
 
 import argparse
 import json
+import time
 from dataclasses import dataclass
 from pathlib import Path
 from typing import List, Optional, Tuple
@@ -204,6 +205,7 @@ def load_model(folder: str):
     if quantization is not None:
         nn.QuantizedLinear.quantize_module(model, **quantization)
     model.update(weights)
+    mx.eval(model.parameters())
     return model, tokenizer
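
MLX evaluates lazily, so without the added `mx.eval(model.parameters())` the weight loading (and optional quantization) would only be recorded as a graph, and its cost would land inside the first timed forward pass. A minimal sketch of the intent, using the `load_model` above (the model path is a placeholder):

    import time

    import mlx.core as mx

    model, tokenizer = load_model("path/to/mistral-7B")  # hypothetical path
    mx.eval(model.parameters())  # materialize the weights now ...
    tic = time.time()            # ... so the timer sees only inference work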
@@ -265,12 +267,17 @@ if __name__ == "__main__":
     model, tokenizer = load_model(args.model_path)
 
     print("[INFO] Starting generation...")
+    tic = time.time()
     print(args.prompt, end="", flush=True)
 
     prompt = mx.array(tokenizer.encode(args.prompt))
     tokens = []
-    for token, _ in zip(generate(prompt, model, args.temp), range(args.max_tokens)):
+    for token, ntoks in zip(generate(prompt, model, args.temp), range(args.max_tokens)):
         tokens.append(token)
 
+        if ntoks == 0:
+            mx.eval(tokens)
+            toc = time.time()
+            prompt_tps = prompt.size / (toc - tic)
+            tic = time.time()
+
         if (len(tokens) % args.tokens_per_eval) == 0:
             mx.eval(tokens)
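
The `ntoks == 0` branch runs exactly once, after the first token: because MLX is lazy, `mx.eval(tokens)` is what actually forces the prompt's forward pass, so the time from `tic` to `toc` is attributed to prompt processing. As a worked example, a 16-token prompt whose first evaluation finishes 0.25 s after `tic` yields prompt_tps = 16 / 0.25 = 64.0; resetting `tic` afterwards lets the decode phase be timed on its own.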
@@ -282,3 +289,8 @@ if __name__ == "__main__":
     s = tokenizer.decode([t.item() for t in tokens])
     print(s, flush=True)
     print("------")
+    generation_tps = ntoks / (time.time() - tic)
+    print(
+        f"Tokens per second: prompt {prompt_tps:.3f}, "
+        f"generation {generation_tps:.3f}"
+    )
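
Dividing by `ntoks` is consistent with the reset above: the timer restarts after the first token, leaving exactly `ntoks` decode steps inside the measured window. With these changes the script reports both rates after the generated text; illustrative output (made-up numbers):

    [INFO] Starting generation...
    <prompt and generated text>
    ------
    Tokens per second: prompt 64.213, generation 21.087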


@@ -209,10 +209,13 @@ def iterate_batches(dset, tokenizer, batch_size, train=False):
                 for j in range(batch_size)
             ]
             lengths = [len(x) for x in batch]
 
             # Check if any sequence is longer than 2048 tokens
             if max(lengths) > 2048:
-                print("Warning: Some sequences are longer than 2048 tokens. Consider pre-splitting your data to save memory.")
+                print(
+                    "[WARNING] Some sequences are longer than 2048 tokens. "
+                    "Consider pre-splitting your data to save memory."
+                )
 
             # Pad to the max length
             batch_arr = np.zeros((batch_size, max(lengths)), np.int32)
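
The warning's advice can be applied offline, before training; a minimal sketch, where `chunk` is a hypothetical helper and 2048 mirrors the limit checked above:

    def chunk(tokens, max_len=2048):
        # Split one long token sequence into pieces of at most max_len tokens.
        return [tokens[i : i + max_len] for i in range(0, len(tokens), max_len)]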


@@ -1,6 +1,6 @@
 import argparse
 
-from transformers import AutoTokenizer, T5EncoderModel, AutoModelForSeq2SeqLM
+from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, T5EncoderModel
 
 
 def embed(t5_model: str):


@@ -6,7 +6,7 @@ import mlx.core as mx
 import mlx.nn as nn
 import numpy as np
 from mlx.utils import tree_map, tree_unflatten
-from transformers import T5Config, AutoTokenizer
+from transformers import AutoTokenizer, T5Config
 
 
 def _relative_position_bucket(
@@ -252,9 +252,7 @@ class TransformerDecoder(nn.Module):
     def __init__(self, config: T5Config):
         super().__init__()
         n_layers = getattr(config, "num_decoder_layers", config.num_layers)
-        self.layers = [
-            TransformerDecoderLayer(config) for i in range(n_layers)
-        ]
+        self.layers = [TransformerDecoderLayer(config) for i in range(n_layers)]
         self.ln = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
         self.relative_attention_bias = RelativePositionBias(config, bidirectional=False)