feat: add mistral tps (#173)
* feat: add mistral tps

* eval params before timing + format

---------

Co-authored-by: Awni Hannun <awni@apple.com>
committed by GitHub
parent 188a91074b
commit 7ae445f6c7
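The "eval params before timing" note refers to forcing MLX's lazy arrays to materialize before the tokens-per-second timer starts, so parameter loading is not counted as generation time. A minimal sketch of that pattern follows; `model`, `tokenizer`, and `generate` are stand-ins for the Mistral example's objects, and the names are assumptions rather than the exact API.

import time

import mlx.core as mx

# Hypothetical sketch: measure tokens-per-second for autoregressive generation.
# `model`, `tokenizer`, and `generate` are placeholders, not the exact API.

mx.eval(model.parameters())  # evaluate lazy parameters before timing starts

prompt = mx.array(tokenizer.encode("hello"))
tokens = []
start = time.time()
for token in generate(prompt, model):
    mx.eval(token)  # materialize each token so the timer measures real work
    tokens.append(token.item())
elapsed = time.time() - start

print(f"{len(tokens) / elapsed:.3f} tokens-per-second")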
@@ -1,6 +1,6 @@
 import argparse
 
-from transformers import AutoTokenizer, T5EncoderModel, AutoModelForSeq2SeqLM
+from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, T5EncoderModel
 
 
 def embed(t5_model: str):
t5/t5.py (6 changed lines)
@@ -6,7 +6,7 @@ import mlx.core as mx
 import mlx.nn as nn
 import numpy as np
 from mlx.utils import tree_map, tree_unflatten
-from transformers import T5Config, AutoTokenizer
+from transformers import AutoTokenizer, T5Config
 
 
 def _relative_position_bucket(
@@ -252,9 +252,7 @@ class TransformerDecoder(nn.Module):
     def __init__(self, config: T5Config):
         super().__init__()
         n_layers = getattr(config, "num_decoder_layers", config.num_layers)
-        self.layers = [
-            TransformerDecoderLayer(config) for i in range(n_layers)
-        ]
+        self.layers = [TransformerDecoderLayer(config) for i in range(n_layers)]
         self.ln = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
         self.relative_attention_bias = RelativePositionBias(config, bidirectional=False)
 