mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-08-30 02:53:41 +08:00)

Decode (broken after 1st token)
parent 31da1b0dab
commit 203f550ef9
t5/t5.py
@@ -1,5 +1,6 @@
 import argparse
 from dataclasses import dataclass
 from typing import Optional
 
+import numpy as np
 import mlx.core as mx
@@ -26,7 +27,9 @@ class ModelArgs:
     vocab_size: int = 32128
 
 
-def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
+def _relative_position_bucket(
+    relative_position, bidirectional=True, num_buckets=32, max_distance=128
+):
     """
     Adapted from HF Tensorflow:
     https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
@@ -66,10 +69,10 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets
         / np.log(max_distance / max_exact)
         * (num_buckets - max_exact)
     ).astype(mx.int16)
-    relative_position_if_large = mx.minimum(
-        relative_position_if_large, num_buckets - 1
-    )
-    relative_buckets += mx.where(
-        is_small, relative_position, relative_position_if_large
-    )
+    relative_position_if_large = mx.minimum(relative_position_if_large, num_buckets - 1)
+    relative_buckets += mx.where(is_small, relative_position, relative_position_if_large)
     return relative_buckets
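For reference, the bucketing reformatted above maps each signed relative offset to one of num_buckets ids: small offsets get exact buckets, larger ones get logarithmically spaced buckets, and the sign is folded into the upper half when bidirectional. A minimal NumPy sketch of the same scheme (a standalone illustration, not part of the commit; the helper name and the np.maximum guard against log(0) are mine):

import numpy as np

def bucket_sketch(rel_pos, bidirectional=True, num_buckets=32, max_distance=128):
    buckets = 0
    if bidirectional:
        num_buckets //= 2
        # Upper half of the buckets encodes "key after query".
        buckets += (rel_pos > 0).astype(np.int32) * num_buckets
        rel_pos = np.abs(rel_pos)
    else:
        rel_pos = -np.minimum(rel_pos, 0)
    max_exact = num_buckets // 2
    is_small = rel_pos < max_exact
    # Logarithmic buckets for offsets in [max_exact, max_distance).
    large = max_exact + (
        np.log(np.maximum(rel_pos, 1) / max_exact)
        / np.log(max_distance / max_exact)
        * (num_buckets - max_exact)
    ).astype(np.int32)
    large = np.minimum(large, num_buckets - 1)
    return buckets + np.where(is_small, rel_pos, large)

print(bucket_sketch(np.arange(-3, 4)))  # -> [3 2 1 0 17 18 19]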
@@ -80,22 +83,28 @@ class RelativePositionBias(nn.Module):
         self.max_distance = config.relative_attention_max_distance
         self.n_heads = config.num_heads
         self.embeddings = nn.Embedding(
-            config.relative_attention_num_buckets,
-            config.num_heads)
+            config.relative_attention_num_buckets, config.num_heads
+        )
 
     def __call__(self, query_length, key_length):
         """Compute binned relative position bias"""
         context_position = mx.arange(query_length, dtype=mx.int32)[:, None]
         memory_position = mx.arange(key_length, dtype=mx.int32)[None, :]
-        relative_position = memory_position - context_position  # shape (query_length, key_length)
+        relative_position = (
+            memory_position - context_position
+        )  # shape (query_length, key_length)
         relative_position_bucket = _relative_position_bucket(
             relative_position,  # shape (query_length, key_length)
             bidirectional=self.bidirectional,
             num_buckets=self.num_buckets,
             max_distance=self.max_distance,
         )
-        values = self.embeddings(relative_position_bucket)  # shape (query_length, key_length, num_heads)
-        values = mx.expand_dims(values.transpose(2, 0, 1), 0)  # shape (1, num_heads, query_length, key_length)
+        values = self.embeddings(
+            relative_position_bucket
+        )  # shape (query_length, key_length, num_heads)
+        values = mx.expand_dims(
+            values.transpose(2, 0, 1), 0
+        )  # shape (1, num_heads, query_length, key_length)
         return values
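The transpose/expand_dims pair exists so the bias, computed as (query_length, key_length, num_heads), broadcasts over the batch dimension of the attention scores. A toy shape check with made-up sizes:

import mlx.core as mx

B, H, L, S = 2, 8, 5, 5
scores = mx.zeros((B, H, L, S))
values = mx.zeros((L, S, H))                          # embeddings(...) output layout
bias = mx.expand_dims(values.transpose(2, 0, 1), 0)  # -> (1, H, L, S)
print((scores + bias).shape)                          # (2, 8, 5, 5)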
@@ -132,10 +141,8 @@ class MultiHeadAttention(nn.Module):
         if self.has_relative_attention_bias:
             position_bias = self.relative_attention_bias(L, S)
             scores += position_bias
-
         scores = mx.softmax(scores, axis=-1)
         values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
-
         return self.out_proj(values_hat)
 
     @staticmethod
@@ -260,7 +267,7 @@ class TransformerDecoder(nn.Module):
 
 class OutputHead(nn.Module):
     def __init__(self, config: ModelArgs) -> None:
-        self.linear = nn.Linear(config.d_model, config.vocab_size)
+        self.linear = nn.Linear(config.d_model, config.vocab_size, bias=False)
 
     def __call__(self, inputs):
        return self.linear(inputs)
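A likely motivation for bias=False, inferred rather than stated in the diff: the T5 checkpoint's output head carries no bias tensor, so a biased nn.Linear would leave a parameter that weights.npz never fills. A quick mismatch check along the lines the loader already uses (assuming model is the T5 instance and the same weights.npz layout):

import mlx.core as mx
from mlx.utils import tree_flatten

weights = mx.load("weights.npz")
params = dict(tree_flatten(model.parameters()))
print(sorted(set(params) - set(weights)))   # model params missing from the checkpoint
print(sorted(set(weights) - set(params)))   # checkpoint tensors with no home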
@@ -281,36 +288,42 @@ class T5(nn.Module):
         cache: mx.array = None,
     ) -> tuple[mx.array, mx.array]:
         x = self.wte(inputs)
-        y = self.encoder(x, mask=None) #, cache)
+        y = self.encoder(x, mask=None)  # , cache)
 
         if x.shape[1] > 1 and mask is None:
             mask = MultiHeadAttention.create_additive_causal_mask(x.shape[1])
             mask = mask.astype(x.dtype)
 
         decoder_inputs = self.wte(decoder_inputs)
-        y, cache = self.decoder(x=decoder_inputs, x_mask=mask, memory=y) #, cache)
+        y, cache = self.decoder(
+            x=decoder_inputs, x_mask=mask, memory=y, memory_mask=None
+        )  # , cache)
         return self.lm_head(y), cache
 
 
-# def generate(prompt: mx.array, model: T5, temp: Optional[float] = 0.0):
-#     def sample(logits):
-#         if temp == 0:
-#             return mx.argmax(logits, axis=-1)
-#         else:
-#             return mx.random.categorical(logits * (1 / temp))
+def generate(
+    inputs: mx.array, decoder_inputs: mx.array, model: T5, temp: Optional[float] = 0.0
+):
+    def sample(logits):
+        if temp == 0:
+            return mx.argmax(logits, axis=-1)
+        else:
+            return mx.random.categorical(logits * (1 / temp))
 
-#     logits, cache = model(prompt)
-#     y = sample(logits[:, -1, :])
-#     yield y
+    logits, _ = model(inputs, decoder_inputs)
+    y = sample(logits[:, -1, :])
+    yield y
 
-#     while True:
-#         logits, cache = model(y[:, None], cache=cache)
-#         y = sample(logits.squeeze(1))
-#         yield y
+    while True:
+        # logits, cache = model(y[:, None], cache=cache)
+        logits, _ = model(inputs, decoder_inputs)
+        y = sample(logits.squeeze(1))
+        decoder_inputs = mx.concat(decoder_inputs, y, dim=1)
+        yield y
 
 
-def load_model():
-    model = T5(ModelArgs())
+def load_model(model_config):
+    model = T5(model_config)
     weights = mx.load("weights.npz")
     current_weights = tree_flatten(model.parameters())
     weights_to_load = list(weights.items())
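The commit title already flags that decoding breaks after the first token, and the while loop above shows why: mx.concat(decoder_inputs, y, dim=1) does not match MLX's concatenate signature (a list of arrays plus axis=), and y, of shape (B,), lacks a sequence axis, so the sampled token is never actually appended. A minimal corrected sketch of the cache-less greedy loop (the function name is mine; it assumes model(inputs, decoder_inputs) returns (logits, cache) as defined above):

import mlx.core as mx

def generate_sketch(inputs, decoder_inputs, model, max_tokens=100):
    # Re-runs the full encoder/decoder each step: slow but correct until
    # the commented-out cache path is restored.
    for _ in range(max_tokens):
        logits, _ = model(inputs, decoder_inputs)
        y = mx.argmax(logits[:, -1, :], axis=-1)  # greedy sample, shape (B,)
        decoder_inputs = mx.concatenate(
            [decoder_inputs, y[:, None]], axis=1  # append along the sequence axis
        )
        yield y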
@@ -356,7 +369,8 @@ if __name__ == "__main__":
 
     mx.random.seed(args.seed)
 
-    model, tokenizer = load_model()
+    config = ModelArgs()
+    model, tokenizer = load_model(config)
 
     prompt = tokenizer(
         args.prompt,
@@ -369,18 +383,20 @@ if __name__ == "__main__":
     print("[INFO] Generating with T5...", flush=True)
     print(args.prompt, end="", flush=True)
 
-    print(model(prompt))
+    decoder_inputs = mx.array([[config.decoder_start_token_id]])
 
-    # tokens = []
-    # for token, _ in zip(generate(prompt, model), range(args.max_tokens)):
-    #     tokens.append(token)
+    tokens = []
+    for token, _ in zip(
+        generate(prompt, decoder_inputs, model), range(args.max_tokens)
+    ):
+        tokens.append(token)
 
-    #     if (len(tokens) % 10) == 0:
-    #         mx.eval(tokens)
-    #         s = tokenizer.decode([t.item() for t in tokens])
-    #         print(s, end="", flush=True)
-    #         tokens = []
+        if (len(tokens) % 10) == 0:
+            mx.eval(tokens)
+            s = tokenizer.decode([t.item() for t in tokens])
+            print(s, end="", flush=True)
+            tokens = []
 
-    # mx.eval(tokens)
-    # s = tokenizer.decode([t.item() for t in tokens])
-    # print(s, flush=True)
+    mx.eval(tokens)
+    s = tokenizer.decode([t.item() for t in tokens])
+    print(s, flush=True)
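One field worth flagging: config.decoder_start_token_id is read here but never appears in the ModelArgs hunk above. In the HF T5 configuration this id equals the pad token, 0, so the dataclass presumably carries something like the following (the default value is an assumption from the HF config, not from this diff):

from dataclasses import dataclass

@dataclass
class ModelArgs:
    vocab_size: int = 32128
    # Assumed: HF T5 starts decoding from the pad token (id 0).
    decoder_start_token_id: int = 0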