mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 19:06:37 +08:00
with cache
This commit is contained in:
parent
29bfb93455
commit
688a6e1e78
@ -14,10 +14,8 @@ SHARED_REPLACEMENT_PATTERNS = [
|
|||||||
(".layer.1.layer_norm.", ".ln2."),
|
(".layer.1.layer_norm.", ".ln2."),
|
||||||
(".layer.2.layer_norm.", ".ln3."),
|
(".layer.2.layer_norm.", ".ln3."),
|
||||||
(".final_layer_norm.", ".ln."),
|
(".final_layer_norm.", ".ln."),
|
||||||
(
|
("layers.0.layer.0.SelfAttention.relative_attention_bias.",
|
||||||
".relative_attention_bias.",
|
"relative_attention_bias.embeddings."),
|
||||||
".relative_attention_bias.embeddings."
|
|
||||||
),
|
|
||||||
]
|
]
|
||||||
|
|
||||||
ENCODER_REPLACEMENT_PATTERNS = [
|
ENCODER_REPLACEMENT_PATTERNS = [
|
||||||
@ -33,6 +31,7 @@ DECODER_REPLACEMENT_PATTERNS = [
|
|||||||
(".layer.2.DenseReluDense.wo.", ".linear2."),
|
(".layer.2.DenseReluDense.wo.", ".linear2."),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def replace_key(key: str) -> str:
|
def replace_key(key: str) -> str:
|
||||||
for old, new in SHARED_REPLACEMENT_PATTERNS:
|
for old, new in SHARED_REPLACEMENT_PATTERNS:
|
||||||
key = key.replace(old, new)
|
key = key.replace(old, new)
|
||||||
@ -45,14 +44,22 @@ def replace_key(key: str) -> str:
|
|||||||
return key
|
return key
|
||||||
|
|
||||||
|
|
||||||
def convert():
|
def convert(model_name):
|
||||||
model = T5ForConditionalGeneration.from_pretrained(
|
model = T5ForConditionalGeneration.from_pretrained(model_name, torch_dtype="auto")
|
||||||
"t5-small", torch_dtype="auto"
|
weights = {replace_key(k): v.numpy() for k, v in model.state_dict().items()}
|
||||||
)
|
|
||||||
state_dict = model.state_dict()
|
|
||||||
weights = {replace_key(k): v.numpy() for k, v in state_dict.items()}
|
|
||||||
np.savez("weights.npz", **weights)
|
np.savez("weights.npz", **weights)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
convert()
|
import argparse
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(description="Convert T5 weights to MLX")
|
||||||
|
parser.add_argument(
|
||||||
|
"--model_name",
|
||||||
|
type=str,
|
||||||
|
help="Name of the T5 model.",
|
||||||
|
choices=["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"],
|
||||||
|
default="t5-small",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
convert(args.model_name)
|
||||||
|
236
t5/t5.py
236
t5/t5.py
@ -1,5 +1,6 @@
|
|||||||
import argparse
|
import argparse
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional, Tuple, List
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from time import perf_counter_ns
|
from time import perf_counter_ns
|
||||||
|
|
||||||
@ -110,18 +111,23 @@ class RelativePositionBias(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class MultiHeadAttention(nn.Module):
|
class MultiHeadAttention(nn.Module):
|
||||||
def __init__(self, config: ModelArgs, has_relative_attention_bias: bool = False):
|
def __init__(self, config: ModelArgs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.num_heads = config.num_heads
|
self.num_heads = config.num_heads
|
||||||
self.query_proj = nn.Linear(config.d_model, config.d_model, bias=False)
|
self.query_proj = nn.Linear(config.d_model, config.d_model, bias=False)
|
||||||
self.key_proj = nn.Linear(config.d_model, config.d_model, bias=False)
|
self.key_proj = nn.Linear(config.d_model, config.d_model, bias=False)
|
||||||
self.value_proj = nn.Linear(config.d_model, config.d_model, bias=False)
|
self.value_proj = nn.Linear(config.d_model, config.d_model, bias=False)
|
||||||
self.out_proj = nn.Linear(config.d_model, config.d_model, bias=False)
|
self.out_proj = nn.Linear(config.d_model, config.d_model, bias=False)
|
||||||
self.has_relative_attention_bias = has_relative_attention_bias
|
|
||||||
if has_relative_attention_bias:
|
|
||||||
self.relative_attention_bias = RelativePositionBias(config)
|
|
||||||
|
|
||||||
def __call__(self, queries, keys, values, mask=None, position_bias=None):
|
def __call__(
|
||||||
|
self,
|
||||||
|
queries: mx.array,
|
||||||
|
keys: mx.array,
|
||||||
|
values: mx.array,
|
||||||
|
mask: mx.array,
|
||||||
|
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||||
|
) -> mx.array:
|
||||||
|
|
||||||
queries = self.query_proj(queries)
|
queries = self.query_proj(queries)
|
||||||
keys = self.key_proj(keys)
|
keys = self.key_proj(keys)
|
||||||
values = self.value_proj(values)
|
values = self.value_proj(values)
|
||||||
@ -133,106 +139,95 @@ class MultiHeadAttention(nn.Module):
|
|||||||
keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1)
|
keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1)
|
||||||
values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
|
values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
|
||||||
|
|
||||||
|
if cache is not None:
|
||||||
|
key_cache, value_cache = cache
|
||||||
|
keys = mx.concatenate([key_cache, keys], axis=3)
|
||||||
|
values = mx.concatenate([value_cache, values], axis=2)
|
||||||
|
|
||||||
# Dimensions are [batch x num heads x sequence x hidden dim]
|
# Dimensions are [batch x num heads x sequence x hidden dim]
|
||||||
scores = queries @ keys
|
scores = queries @ keys
|
||||||
if mask is not None:
|
if mask is not None:
|
||||||
scores = scores + mask.astype(scores.dtype)
|
scores = scores + mask.astype(scores.dtype)
|
||||||
|
|
||||||
if self.has_relative_attention_bias:
|
|
||||||
position_bias = self.relative_attention_bias(L, S)
|
|
||||||
if position_bias is not None:
|
|
||||||
scores += position_bias
|
|
||||||
scores = mx.softmax(scores, axis=-1)
|
scores = mx.softmax(scores, axis=-1)
|
||||||
values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
|
values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||||
out = self.out_proj(values_hat)
|
return self.out_proj(values_hat), (keys, values)
|
||||||
return out, position_bias
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def create_additive_causal_mask(N: int, dtype: mx.Dtype = mx.float32):
|
|
||||||
indices = mx.arange(N)
|
|
||||||
mask = indices[:, None] < indices[None]
|
|
||||||
# usually inf but 1e9 is as good and softmax(full(1e9)) != nan
|
|
||||||
# TODO: Should replace this with finfo(dtype).min
|
|
||||||
mask = mask.astype(dtype) * -1e9
|
|
||||||
return mask
|
|
||||||
|
|
||||||
|
|
||||||
class LayerNorm(nn.Module):
|
class RMSNorm(nn.Module):
|
||||||
def __init__(self, dims: int, eps: float = 1e-5):
|
def __init__(self, dims: int, eps: float = 1e-5):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.weight = mx.ones((dims,))
|
self.weight = mx.ones((dims,))
|
||||||
self.eps = eps
|
self.eps = eps
|
||||||
self.dims = dims
|
|
||||||
|
def _norm(self, x):
|
||||||
|
return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
|
||||||
|
|
||||||
def __call__(self, x):
|
def __call__(self, x):
|
||||||
var = x.var(axis=-1, keepdims=True)
|
output = self._norm(x.astype(mx.float32)).astype(x.dtype)
|
||||||
x = x * mx.rsqrt(var + self.eps)
|
return self.weight * output
|
||||||
return x * self.weight
|
|
||||||
|
|
||||||
|
|
||||||
class TransformerEncoderLayer(nn.Module):
|
class TransformerEncoderLayer(nn.Module):
|
||||||
def __init__(self, config: ModelArgs, has_relative_attention_bias: bool = False):
|
def __init__(self, config: ModelArgs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
mlp_dims = config.d_ff or config.d_model * 4
|
mlp_dims = config.d_ff or config.d_model * 4
|
||||||
self.attention = MultiHeadAttention(
|
self.attention = MultiHeadAttention(config)
|
||||||
config, has_relative_attention_bias=has_relative_attention_bias
|
self.ln1 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||||
)
|
self.ln2 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||||
self.ln1 = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
|
|
||||||
self.ln2 = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
|
|
||||||
self.linear1 = nn.Linear(config.d_model, mlp_dims, bias=False)
|
self.linear1 = nn.Linear(config.d_model, mlp_dims, bias=False)
|
||||||
self.linear2 = nn.Linear(mlp_dims, config.d_model, bias=False)
|
self.linear2 = nn.Linear(mlp_dims, config.d_model, bias=False)
|
||||||
|
|
||||||
def __call__(self, x, mask, position_bias=None):
|
def __call__(self, x, mask):
|
||||||
y = self.ln1(x)
|
y = self.ln1(x)
|
||||||
y, position_bias = self.attention(
|
y, _ = self.attention(y, y, y, mask=mask)
|
||||||
queries=y, keys=y, values=y, mask=mask, position_bias=position_bias
|
|
||||||
)
|
|
||||||
x = x + y
|
x = x + y
|
||||||
|
|
||||||
y = self.ln2(x)
|
y = self.ln2(x)
|
||||||
y = self.linear1(y)
|
y = self.linear1(y)
|
||||||
y = mx.maximum(y, 0)
|
y = mx.maximum(y, 0)
|
||||||
y = self.linear2(y)
|
y = self.linear2(y)
|
||||||
x = x + y
|
return x + y
|
||||||
|
|
||||||
return x, position_bias
|
|
||||||
|
|
||||||
|
|
||||||
class TransformerEncoder(nn.Module):
|
class TransformerEncoder(nn.Module):
|
||||||
def __init__(self, config: ModelArgs):
|
def __init__(self, config: ModelArgs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.layers = [
|
self.layers = [
|
||||||
TransformerEncoderLayer(config, has_relative_attention_bias=i == 0)
|
TransformerEncoderLayer(config) for i in range(config.num_layers)
|
||||||
for i in range(config.num_layers)
|
|
||||||
]
|
]
|
||||||
self.ln = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
|
self.ln = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||||
|
self.relative_attention_bias = RelativePositionBias(config)
|
||||||
|
|
||||||
def __call__(self, x, mask):
|
def __call__(self, x):
|
||||||
position_bias = None
|
pos_bias = self.relative_attention_bias(x.shape[1], x.shape[1])
|
||||||
for layer in self.layers:
|
for layer in self.layers:
|
||||||
x, position_bias = layer(x, mask, position_bias=position_bias)
|
x = layer(x, mask=pos_bias)
|
||||||
x = self.ln(x)
|
return self.ln(x)
|
||||||
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class TransformerDecoderLayer(nn.Module):
|
class TransformerDecoderLayer(nn.Module):
|
||||||
def __init__(self, config: ModelArgs, has_relative_attention_bias: bool = False):
|
def __init__(self, config: ModelArgs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
mlp_dims = config.d_ff or config.d_model * 4
|
mlp_dims = config.d_ff or config.d_model * 4
|
||||||
self.self_attention = MultiHeadAttention(
|
self.self_attention = MultiHeadAttention(config)
|
||||||
config, has_relative_attention_bias=has_relative_attention_bias
|
|
||||||
)
|
|
||||||
self.cross_attention = MultiHeadAttention(config)
|
self.cross_attention = MultiHeadAttention(config)
|
||||||
self.ln1 = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
|
self.ln1 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||||
self.ln2 = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
|
self.ln2 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||||
self.ln3 = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
|
self.ln3 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||||
self.linear1 = nn.Linear(config.d_model, mlp_dims, bias=False)
|
self.linear1 = nn.Linear(config.d_model, mlp_dims, bias=False)
|
||||||
self.linear2 = nn.Linear(mlp_dims, config.d_model, bias=False)
|
self.linear2 = nn.Linear(mlp_dims, config.d_model, bias=False)
|
||||||
|
|
||||||
def __call__(self, x, memory, x_mask, memory_mask, position_bias=None):
|
def __call__(
|
||||||
|
self,
|
||||||
|
x: mx.array,
|
||||||
|
memory: mx.array,
|
||||||
|
mask: mx.array,
|
||||||
|
memory_mask: mx.array,
|
||||||
|
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||||
|
) -> mx.array:
|
||||||
y = self.ln1(x)
|
y = self.ln1(x)
|
||||||
y, position_bias = self.self_attention(y, y, y, x_mask, position_bias=position_bias)
|
y, cache = self.self_attention(y, y, y, mask, cache)
|
||||||
x = x + y
|
x = x + y
|
||||||
|
|
||||||
y = self.ln2(x)
|
y = self.ln2(x)
|
||||||
@ -245,27 +240,39 @@ class TransformerDecoderLayer(nn.Module):
|
|||||||
y = self.linear2(y)
|
y = self.linear2(y)
|
||||||
x = x + y
|
x = x + y
|
||||||
|
|
||||||
return x, position_bias
|
return x, cache
|
||||||
|
|
||||||
|
|
||||||
class TransformerDecoder(nn.Module):
|
class TransformerDecoder(nn.Module):
|
||||||
def __init__(self, config: ModelArgs):
|
def __init__(self, config: ModelArgs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.layers = [
|
self.layers = [
|
||||||
TransformerDecoderLayer(config, has_relative_attention_bias=i == 0)
|
TransformerDecoderLayer(config) for i in range(config.num_layers)
|
||||||
for i in range(config.num_layers)
|
|
||||||
]
|
]
|
||||||
self.ln = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
|
self.ln = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||||
|
self.relative_attention_bias = RelativePositionBias(config)
|
||||||
|
|
||||||
def __call__(self, x, memory, x_mask, memory_mask):
|
def __call__(self, x, memory, mask, memory_mask, cache=None):
|
||||||
position_bias = None
|
if cache is not None:
|
||||||
for layer in self.layers:
|
offset = cache[0][0].shape[3]
|
||||||
x, position_bias = layer(
|
else:
|
||||||
x, memory, x_mask, memory_mask, position_bias=position_bias
|
offset = 0
|
||||||
)
|
cache = [None] * len(self.layers)
|
||||||
|
|
||||||
|
T = offset + x.shape[1]
|
||||||
|
# TODO, add offset to RelativePositionBias class to avoid wasted work
|
||||||
|
pos_bias = self.relative_attention_bias(T, T)
|
||||||
|
pos_bias = pos_bias[:, :, -x.shape[1]:, :]
|
||||||
|
if mask is not None:
|
||||||
|
mask += pos_bias
|
||||||
|
else:
|
||||||
|
mask = pos_bias
|
||||||
|
|
||||||
|
for e, layer in enumerate(self.layers):
|
||||||
|
x, cache[e] = layer(x, memory, mask, memory_mask, cache=cache[e])
|
||||||
x = self.ln(x)
|
x = self.ln(x)
|
||||||
|
|
||||||
return x
|
return x, cache
|
||||||
|
|
||||||
|
|
||||||
class OutputHead(nn.Module):
|
class OutputHead(nn.Module):
|
||||||
@ -283,26 +290,37 @@ class T5(nn.Module):
|
|||||||
self.decoder = TransformerDecoder(config)
|
self.decoder = TransformerDecoder(config)
|
||||||
self.lm_head = OutputHead(config)
|
self.lm_head = OutputHead(config)
|
||||||
|
|
||||||
|
def encode(
|
||||||
|
self,
|
||||||
|
inputs: mx.array
|
||||||
|
) -> mx.array:
|
||||||
|
return self.encoder(self.wte(inputs))
|
||||||
|
|
||||||
|
def decode(
|
||||||
|
self,
|
||||||
|
inputs: mx.array,
|
||||||
|
memory: mx.array,
|
||||||
|
cache = None,
|
||||||
|
):
|
||||||
|
inputs = self.wte(inputs)
|
||||||
|
T = inputs.shape[1]
|
||||||
|
if T > 1:
|
||||||
|
mask = nn.MultiHeadAttention.create_additive_causal_mask(T)
|
||||||
|
mask = mask.astype(inputs.dtype)
|
||||||
|
else:
|
||||||
|
mask = None
|
||||||
|
|
||||||
|
y, cache = self.decoder(
|
||||||
|
inputs, memory=memory, mask=mask, memory_mask=None, cache=cache
|
||||||
|
)
|
||||||
|
return self.lm_head(y), cache
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
decoder_inputs: mx.array,
|
decoder_inputs: mx.array,
|
||||||
mask: mx.array = None,
|
) -> mx.array:
|
||||||
cache: mx.array = None,
|
return decode(decoder_inputs, encode(inputs))[0]
|
||||||
) -> tuple[mx.array, mx.array]:
|
|
||||||
x = self.wte(inputs)
|
|
||||||
y = self.encoder(x, mask=None) # , cache)
|
|
||||||
|
|
||||||
decoder_inputs = self.wte(decoder_inputs)
|
|
||||||
decoder_n_tokens = decoder_inputs.shape[1]
|
|
||||||
if decoder_n_tokens > 1 and mask is None:
|
|
||||||
mask = MultiHeadAttention.create_additive_causal_mask(decoder_n_tokens)
|
|
||||||
mask = mask.astype(x.dtype)
|
|
||||||
|
|
||||||
y = self.decoder(
|
|
||||||
x=decoder_inputs, x_mask=mask, memory=y, memory_mask=None
|
|
||||||
) # , cache)
|
|
||||||
return self.lm_head(y), cache
|
|
||||||
|
|
||||||
|
|
||||||
def generate(
|
def generate(
|
||||||
@ -314,12 +332,15 @@ def generate(
|
|||||||
else:
|
else:
|
||||||
return mx.random.categorical(logits * (1 / temp))
|
return mx.random.categorical(logits * (1 / temp))
|
||||||
|
|
||||||
|
memory = model.encode(inputs)
|
||||||
|
cache = None
|
||||||
|
y = decoder_inputs
|
||||||
while True:
|
while True:
|
||||||
# TODO: add cache
|
logits, cache = model.decode(y[None], memory, cache=cache)
|
||||||
logits, _ = model(inputs, decoder_inputs)
|
# logits, cache = model.decode(decoder_inputs[None], memory, cache=cache)
|
||||||
y = mx.expand_dims(sample(logits[:, -1, :]), 0)
|
y = sample(logits[:, -1, :])
|
||||||
decoder_inputs = mx.concatenate([decoder_inputs, y], axis=1)
|
#decoder_inputs = mx.concatenate([decoder_inputs, y])
|
||||||
yield y
|
yield y.squeeze()
|
||||||
|
|
||||||
|
|
||||||
def load_model(model_config):
|
def load_model(model_config):
|
||||||
@ -340,6 +361,7 @@ def load_model(model_config):
|
|||||||
print("Expected shape: ", current_weights_dict[key].shape)
|
print("Expected shape: ", current_weights_dict[key].shape)
|
||||||
print("Loading shape: ", weights_to_load_dict[key].shape)
|
print("Loading shape: ", weights_to_load_dict[key].shape)
|
||||||
model.update(tree_unflatten(weights_to_load))
|
model.update(tree_unflatten(weights_to_load))
|
||||||
|
mx.eval(model.parameters())
|
||||||
tokenizer = AutoTokenizer.from_pretrained("t5-small", trust_remote_code=True)
|
tokenizer = AutoTokenizer.from_pretrained("t5-small", trust_remote_code=True)
|
||||||
return model, tokenizer
|
return model, tokenizer
|
||||||
|
|
||||||
@ -397,39 +419,23 @@ if __name__ == "__main__":
|
|||||||
print("[INFO] Generating with T5...", flush=True)
|
print("[INFO] Generating with T5...", flush=True)
|
||||||
print("Input: ", args.prompt, flush=True)
|
print("Input: ", args.prompt, flush=True)
|
||||||
|
|
||||||
decoder_inputs = mx.array([[config.decoder_start_token_id]]).astype(mx.uint32)
|
decoder_inputs = mx.array([config.decoder_start_token_id])
|
||||||
|
|
||||||
start = perf_counter_ns()
|
start = perf_counter_ns()
|
||||||
n_tokens = 0
|
|
||||||
tokens = []
|
tokens = []
|
||||||
for token, _ in zip(
|
for token, n_tokens in zip(
|
||||||
generate(prompt, decoder_inputs, model, args.temp),
|
generate(prompt, decoder_inputs, model, args.temp),
|
||||||
range(args.max_tokens)
|
range(args.max_tokens)
|
||||||
):
|
):
|
||||||
tokens.append(token)
|
if token.item() == tokenizer.eos_token_id:
|
||||||
|
break
|
||||||
if (len(tokens) % 10) == 0:
|
tokens.append(token.item())
|
||||||
mx.eval(tokens)
|
# For some reason using the following line doesn't give spaces
|
||||||
eos_index = next(
|
# print(tokenizer.decode(token.item(), clean_up_tokenization_spaces=False), end="", flush=True)
|
||||||
(i for i, t in enumerate(tokens) if t.item() == tokenizer.eos_token_id),
|
print(tokenizer.decode(tokens), end="", flush=True)
|
||||||
None,
|
|
||||||
)
|
|
||||||
|
|
||||||
if eos_index is not None:
|
|
||||||
tokens = tokens[:eos_index]
|
|
||||||
|
|
||||||
s = tokenizer.decode([t.item() for t in tokens])
|
|
||||||
print(s, end="", flush=True)
|
|
||||||
n_tokens += len(tokens)
|
|
||||||
tokens = []
|
|
||||||
if eos_index is not None:
|
|
||||||
break
|
|
||||||
|
|
||||||
end = perf_counter_ns()
|
end = perf_counter_ns()
|
||||||
|
|
||||||
mx.eval(tokens)
|
|
||||||
s = tokenizer.decode([t.item() for t in tokens])
|
|
||||||
n_tokens += len(tokens)
|
|
||||||
elapsed = (end - start) / 1.0e9
|
elapsed = (end - start) / 1.0e9
|
||||||
print(s, flush=True)
|
print()
|
||||||
print(f"Time: {elapsed:.2f} seconds, tokens/s: {n_tokens / elapsed:.2f}")
|
print(f"Time: {elapsed:.2f} seconds, tokens/s: {n_tokens / elapsed:.2f}")
|
||||||
|
Loading…
Reference in New Issue
Block a user