feat: add mistral tps (#173)

* feat: add mistral tps

* eval params before timing + format

---------

Co-authored-by: Awni Hannun <awni@apple.com>
Author: Todsaporn Banjerdkit
Date: 2023-12-22 22:55:57 +07:00
Committed by: GitHub
Parent: 188a91074b
Commit: 7ae445f6c7

4 changed files with 22 additions and 9 deletions
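
The commit's headline change is a tokens-per-second (tps) report for the Mistral example, with the model parameters evaluated before the timer starts so that lazy weight materialization is not counted as generation time. A minimal sketch of that idea in MLX follows, assuming a lazily loaded model and a token-yielding generate helper; measure_tps, generate, and prompt are illustrative names, not the repository's actual API.

import time

import mlx.core as mx


def measure_tps(model, generate, prompt, max_tokens=100):
    # Materialize the lazily loaded weights up front so that parameter
    # evaluation is not billed to the generation timing
    # ("eval params before timing").
    mx.eval(model.parameters())

    tic = time.time()
    n_tokens = 0
    for token, _ in zip(generate(prompt, model), range(max_tokens)):
        mx.eval(token)  # force each generated token to actually be computed
        n_tokens += 1
    elapsed = time.time() - tic

    return n_tokens / elapsed  # tokens per second

With the parameters evaluated up front, the reported rate reflects only the decode loop rather than weight loading.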


@@ -1,6 +1,6 @@
 import argparse
-from transformers import AutoTokenizer, T5EncoderModel, AutoModelForSeq2SeqLM
+from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, T5EncoderModel
 def embed(t5_model: str):


@@ -6,7 +6,7 @@ import mlx.core as mx
 import mlx.nn as nn
 import numpy as np
 from mlx.utils import tree_map, tree_unflatten
-from transformers import T5Config, AutoTokenizer
+from transformers import AutoTokenizer, T5Config
def _relative_position_bucket(
@@ -252,9 +252,7 @@ class TransformerDecoder(nn.Module):
     def __init__(self, config: T5Config):
         super().__init__()
         n_layers = getattr(config, "num_decoder_layers", config.num_layers)
-        self.layers = [
-            TransformerDecoderLayer(config) for i in range(n_layers)
-        ]
+        self.layers = [TransformerDecoderLayer(config) for i in range(n_layers)]
         self.ln = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
         self.relative_attention_bias = RelativePositionBias(config, bidirectional=False)