Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-24 01:17:28 +08:00)
feat: add mistral tps (#173)
* feat: add mistral tps
* eval params before timing + format

Co-authored-by: Awni Hannun <awni@apple.com>

parent 188a91074b
commit 7ae445f6c7
@@ -2,6 +2,7 @@
 import argparse
 import json
+import time
 from dataclasses import dataclass
 from pathlib import Path
 from typing import List, Optional, Tuple
@@ -204,6 +205,7 @@ def load_model(folder: str):
     if quantization is not None:
         nn.QuantizedLinear.quantize_module(model, **quantization)
     model.update(weights)
+    mx.eval(model.parameters())
     return model, tokenizer
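Note: the commit message's "eval params before timing" refers to forcing MLX's lazy weight arrays to materialize before the clock starts, so the first forward pass is not charged for weight loading. A minimal sketch of that idea (the nn.Linear model and shapes are illustrative, not from this diff):

import time

import mlx.core as mx
import mlx.nn as nn

model = nn.Linear(512, 512)    # hypothetical stand-in for the full Mistral model
mx.eval(model.parameters())    # materialize lazy weights now, outside the timed region

tic = time.time()
y = model(mx.zeros((1, 512)))  # forward pass
mx.eval(y)                     # force the computation so the timing is meaningful
print(f"forward pass took {time.time() - tic:.4f}s")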
@@ -265,12 +267,17 @@ if __name__ == "__main__":
     model, tokenizer = load_model(args.model_path)
 
     print("[INFO] Starting generation...")
+    tic = time.time()
     print(args.prompt, end="", flush=True)
     prompt = mx.array(tokenizer.encode(args.prompt))
     tokens = []
-    for token, _ in zip(generate(prompt, model, args.temp), range(args.max_tokens)):
+    for token, ntoks in zip(generate(prompt, model, args.temp), range(args.max_tokens)):
         tokens.append(token)
 
+        if ntoks == 0:
+            toc = time.time()
+            mx.eval(tokens)
+            prompt_tps = prompt.size / (toc - tic)
+            tic = time.time()
+
         if (len(tokens) % args.tokens_per_eval) == 0:
             mx.eval(tokens)
@@ -282,3 +289,8 @@ if __name__ == "__main__":
     s = tokenizer.decode([t.item() for t in tokens])
     print(s, flush=True)
     print("------")
+    generation_tps = ntoks / (time.time() - tic)
+    print(
+        f"Tokens per second: prompt {prompt_tps:.3f}, "
+        f"generation {generation_tps:.3f}"
+    )
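For reference, a condensed sketch of the timing scheme these two hunks implement (the helper name and the token_stream argument are illustrative, not part of the diff): the first generated token closes the prompt-processing window and restarts the clock, so prompt and generation throughput are reported separately.

import time

def measure_tps(prompt_len, token_stream, max_tokens):
    # token_stream: any iterator of generated tokens (hypothetical stand-in
    # for generate(prompt, model, temp)); prompt_len: number of prompt tokens.
    tic = time.time()
    tokens = []
    for token, ntoks in zip(token_stream, range(max_tokens)):
        tokens.append(token)
        if ntoks == 0:
            # First token out: everything up to here was prompt processing.
            toc = time.time()
            prompt_tps = prompt_len / (toc - tic)
            tic = time.time()  # restart the clock for the generation phase
    generation_tps = ntoks / (time.time() - tic)
    return tokens, prompt_tps, generation_tps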
@@ -209,10 +209,13 @@ def iterate_batches(dset, tokenizer, batch_size, train=False):
                 for j in range(batch_size)
             ]
             lengths = [len(x) for x in batch]
 
             # Check if any sequence is longer than 2048 tokens
             if max(lengths) > 2048:
-                print("Warning: Some sequences are longer than 2048 tokens. Consider pre-splitting your data to save memory.")
+                print(
+                    "[WARNING] Some sequences are longer than 2048 tokens. "
+                    "Consider pre-splitting your data to save memory."
+                )
 
             # Pad to the max length
             batch_arr = np.zeros((batch_size, max(lengths)), np.int32)
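As an aside, a tiny self-contained illustration (toy data, not part of the change) of the zero-padding that follows this warning: each sequence is copied into a row of an int32 array sized to the longest sequence in the batch.

import numpy as np

batch = [[5, 8, 2], [7, 1], [9, 3, 6, 4]]      # toy token-id sequences
lengths = [len(x) for x in batch]
batch_arr = np.zeros((len(batch), max(lengths)), np.int32)
for j, seq in enumerate(batch):
    batch_arr[j, : lengths[j]] = seq           # left-aligned, zeros pad the right
print(batch_arr)
# [[5 8 2 0]
#  [7 1 0 0]
#  [9 3 6 4]]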
@@ -1,6 +1,6 @@
 import argparse
 
-from transformers import AutoTokenizer, T5EncoderModel, AutoModelForSeq2SeqLM
+from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, T5EncoderModel
 
 
 def embed(t5_model: str):
t5/t5.py (6 lines changed)
@@ -6,7 +6,7 @@ import mlx.core as mx
 import mlx.nn as nn
 import numpy as np
 from mlx.utils import tree_map, tree_unflatten
-from transformers import T5Config, AutoTokenizer
+from transformers import AutoTokenizer, T5Config
 
 
 def _relative_position_bucket(
@@ -252,9 +252,7 @@ class TransformerDecoder(nn.Module):
     def __init__(self, config: T5Config):
         super().__init__()
         n_layers = getattr(config, "num_decoder_layers", config.num_layers)
-        self.layers = [
-            TransformerDecoderLayer(config) for i in range(n_layers)
-        ]
+        self.layers = [TransformerDecoderLayer(config) for i in range(n_layers)]
         self.ln = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
         self.relative_attention_bias = RelativePositionBias(config, bidirectional=False)