diff --git a/llms/mistral/mistral.py b/llms/mistral/mistral.py
index f023ae02..688360f2 100644
--- a/llms/mistral/mistral.py
+++ b/llms/mistral/mistral.py
@@ -2,6 +2,7 @@
 
 import argparse
 import json
+import time
 from dataclasses import dataclass
 from pathlib import Path
 from typing import List, Optional, Tuple
@@ -204,6 +205,7 @@ def load_model(folder: str):
     if quantization is not None:
         nn.QuantizedLinear.quantize_module(model, **quantization)
     model.update(weights)
+    mx.eval(model.parameters())
     return model, tokenizer
 
 
@@ -265,12 +267,17 @@ if __name__ == "__main__":
     model, tokenizer = load_model(args.model_path)
 
     print("[INFO] Starting generation...")
-
+    tic = time.time()
     print(args.prompt, end="", flush=True)
    prompt = mx.array(tokenizer.encode(args.prompt))
     tokens = []
-    for token, _ in zip(generate(prompt, model, args.temp), range(args.max_tokens)):
+    for token, ntoks in zip(generate(prompt, model, args.temp), range(args.max_tokens)):
         tokens.append(token)
+        if ntoks == 0:
+            toc = time.time()
+            mx.eval(tokens)
+            prompt_tps = prompt.size / (toc - tic)
+            tic = time.time()
 
         if (len(tokens) % args.tokens_per_eval) == 0:
             mx.eval(tokens)
@@ -282,3 +289,8 @@ if __name__ == "__main__":
     s = tokenizer.decode([t.item() for t in tokens])
     print(s, flush=True)
     print("------")
+    generation_tps = ntoks / (time.time() - tic)
+    print(
+        f"Tokens per second: prompt {prompt_tps:.3f}, "
+        f"generation {generation_tps:.3f}"
+    )
diff --git a/lora/lora.py b/lora/lora.py
index 718a27c9..528cf506 100644
--- a/lora/lora.py
+++ b/lora/lora.py
@@ -209,10 +209,13 @@ def iterate_batches(dset, tokenizer, batch_size, train=False):
                 for j in range(batch_size)
             ]
             lengths = [len(x) for x in batch]
-
 
+            # Check if any sequence is longer than 2048 tokens
             if max(lengths) > 2048:
-                print("Warning: Some sequences are longer than 2048 tokens. Consider pre-splitting your data to save memory.")
+                print(
+                    "[WARNING] Some sequences are longer than 2048 tokens. "
+                    "Consider pre-splitting your data to save memory."
+                )
 
             # Pad to the max length
             batch_arr = np.zeros((batch_size, max(lengths)), np.int32)
diff --git a/t5/hf_t5.py b/t5/hf_t5.py
index 12329e4b..98c6da80 100644
--- a/t5/hf_t5.py
+++ b/t5/hf_t5.py
@@ -1,6 +1,6 @@
 import argparse
 
-from transformers import AutoTokenizer, T5EncoderModel, AutoModelForSeq2SeqLM
+from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, T5EncoderModel
 
 
 def embed(t5_model: str):
diff --git a/t5/t5.py b/t5/t5.py
index 45855136..2acd39b4 100644
--- a/t5/t5.py
+++ b/t5/t5.py
@@ -6,7 +6,7 @@ import mlx.core as mx
 import mlx.nn as nn
 import numpy as np
 from mlx.utils import tree_map, tree_unflatten
-from transformers import T5Config, AutoTokenizer
+from transformers import AutoTokenizer, T5Config
 
 
 def _relative_position_bucket(
@@ -252,9 +252,7 @@ class TransformerDecoder(nn.Module):
     def __init__(self, config: T5Config):
         super().__init__()
         n_layers = getattr(config, "num_decoder_layers", config.num_layers)
-        self.layers = [
-            TransformerDecoderLayer(config) for i in range(n_layers)
-        ]
+        self.layers = [TransformerDecoderLayer(config) for i in range(n_layers)]
         self.ln = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
         self.relative_attention_bias = RelativePositionBias(config, bidirectional=False)
 