From 31273bafbff4b4ef96f490b5bd2ea29e70df6535 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Fri, 22 Dec 2023 07:54:36 -0800
Subject: [PATCH] eval params before timing + format

---
 llms/mistral/mistral.py | 26 ++++++++++++++------------
 lora/lora.py            |  7 +++++--
 t5/hf_t5.py             |  2 +-
 t5/t5.py                |  6 ++----
 4 files changed, 22 insertions(+), 19 deletions(-)

diff --git a/llms/mistral/mistral.py b/llms/mistral/mistral.py
index 5bd67337..688360f2 100644
--- a/llms/mistral/mistral.py
+++ b/llms/mistral/mistral.py
@@ -2,6 +2,7 @@
 
 import argparse
 import json
+import time
 from dataclasses import dataclass
 from pathlib import Path
 from typing import List, Optional, Tuple
@@ -10,7 +11,6 @@
 import mlx.core as mx
 import mlx.nn as nn
 from mlx.utils import tree_unflatten
 from sentencepiece import SentencePieceProcessor
-import time
 
 
 @dataclass
@@ -205,6 +205,7 @@ def load_model(folder: str):
     if quantization is not None:
         nn.QuantizedLinear.quantize_module(model, **quantization)
     model.update(weights)
+    mx.eval(model.parameters())
     return model, tokenizer
 
 
@@ -266,16 +267,17 @@ if __name__ == "__main__":
     model, tokenizer = load_model(args.model_path)
 
     print("[INFO] Starting generation...")
-
-    start_time = time.time()
-
+    tic = time.time()
     print(args.prompt, end="", flush=True)
     prompt = mx.array(tokenizer.encode(args.prompt))
     tokens = []
-    tokens_len = 0
-    for token, _ in zip(generate(prompt, model, args.temp), range(args.max_tokens)):
+    for token, ntoks in zip(generate(prompt, model, args.temp), range(args.max_tokens)):
         tokens.append(token)
-        tokens_len = tokens_len + len(tokens)
+        if ntoks == 0:
+            toc = time.time()
+            mx.eval(tokens)
+            prompt_tps = prompt.size / (toc - tic)
+            tic = time.time()
 
         if (len(tokens) % args.tokens_per_eval) == 0:
             mx.eval(tokens)
@@ -287,8 +289,8 @@
     s = tokenizer.decode([t.item() for t in tokens])
     print(s, flush=True)
     print("------")
-
-    # Calculate overall tokens per second
-    elapsed_time = time.time() - start_time
-    tokens_per_sec = tokens_len / elapsed_time
-    print(f"Tokens per second: {tokens_per_sec:.2f}")
\ No newline at end of file
+    generation_tps = ntoks / (time.time() - tic)
+    print(
+        f"Tokens per second: prompt {prompt_tps:.3f}, "
+        f"generation {generation_tps:.3f}"
+    )
diff --git a/lora/lora.py b/lora/lora.py
index 718a27c9..528cf506 100644
--- a/lora/lora.py
+++ b/lora/lora.py
@@ -209,10 +209,13 @@ def iterate_batches(dset, tokenizer, batch_size, train=False):
             for j in range(batch_size)
         ]
         lengths = [len(x) for x in batch]
-
+
         # Check if any sequence is longer than 2048 tokens
         if max(lengths) > 2048:
-            print("Warning: Some sequences are longer than 2048 tokens. Consider pre-splitting your data to save memory.")
+            print(
+                "[WARNING] Some sequences are longer than 2048 tokens. "
+                "Consider pre-splitting your data to save memory."
+            )
 
         # Pad to the max length
         batch_arr = np.zeros((batch_size, max(lengths)), np.int32)
diff --git a/t5/hf_t5.py b/t5/hf_t5.py
index 12329e4b..98c6da80 100644
--- a/t5/hf_t5.py
+++ b/t5/hf_t5.py
@@ -1,6 +1,6 @@
 import argparse
 
-from transformers import AutoTokenizer, T5EncoderModel, AutoModelForSeq2SeqLM
+from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, T5EncoderModel
 
 
 def embed(t5_model: str):
diff --git a/t5/t5.py b/t5/t5.py
index 45855136..2acd39b4 100644
--- a/t5/t5.py
+++ b/t5/t5.py
@@ -6,7 +6,7 @@
 import mlx.core as mx
 import mlx.nn as nn
 import numpy as np
 from mlx.utils import tree_map, tree_unflatten
-from transformers import T5Config, AutoTokenizer
+from transformers import AutoTokenizer, T5Config
 
 
 def _relative_position_bucket(
@@ -252,9 +252,7 @@ class TransformerDecoder(nn.Module):
     def __init__(self, config: T5Config):
         super().__init__()
         n_layers = getattr(config, "num_decoder_layers", config.num_layers)
-        self.layers = [
-            TransformerDecoderLayer(config) for i in range(n_layers)
-        ]
+        self.layers = [TransformerDecoderLayer(config) for i in range(n_layers)]
         self.ln = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
         self.relative_attention_bias = RelativePositionBias(config, bidirectional=False)
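
Note on the timing change above: mistral.py now evaluates the model parameters right after loading (mx.eval(model.parameters())), so weight materialization is not charged to generation, and it reports two throughput numbers instead of one. The elapsed time up to the first evaluated token is charged to prompt processing (prompt_tps = prompt.size / elapsed), the clock is restarted, and the remaining tokens give generation_tps. The following is a minimal, self-contained sketch of that measurement pattern; fake_generate, prompt_size, and the sleep costs are invented stand-ins for the lazy MLX model and are not code from this repository.

import time


def fake_generate(n_tokens, prompt_cost=0.25, step_cost=0.01):
    # Hypothetical stand-in for the lazy model: the first next() call "pays"
    # for prompt processing, later tokens only pay the per-step cost.
    time.sleep(prompt_cost)
    for t in range(n_tokens):
        time.sleep(step_cost)
        yield t


prompt_size = 32  # pretend the prompt encoded to 32 tokens
max_tokens = 50

tokens = []
ntoks = 0
tic = time.time()
for token, ntoks in zip(fake_generate(max_tokens), range(max_tokens)):
    tokens.append(token)
    if ntoks == 0:
        # Everything up to the first token counts as prompt processing;
        # restart the clock so the rest is charged to generation.
        toc = time.time()
        prompt_tps = prompt_size / (toc - tic)
        tic = time.time()

generation_tps = ntoks / (time.time() - tic)
print(
    f"Tokens per second: prompt {prompt_tps:.3f}, "
    f"generation {generation_tps:.3f}"
)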