mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 02:53:41 +08:00
Decode (broken after 1st token)
This commit is contained in:
parent
31da1b0dab
commit
203f550ef9
96
t5/t5.py
96
t5/t5.py
@ -1,5 +1,6 @@
|
|||||||
import argparse
|
import argparse
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
@ -26,7 +27,9 @@ class ModelArgs:
|
|||||||
vocab_size: int = 32128
|
vocab_size: int = 32128
|
||||||
|
|
||||||
|
|
||||||
def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
|
def _relative_position_bucket(
|
||||||
|
relative_position, bidirectional=True, num_buckets=32, max_distance=128
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Adapted from HF Tensorflow:
|
Adapted from HF Tensorflow:
|
||||||
https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
|
https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
|
||||||
@ -66,10 +69,10 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets
|
|||||||
/ np.log(max_distance / max_exact)
|
/ np.log(max_distance / max_exact)
|
||||||
* (num_buckets - max_exact)
|
* (num_buckets - max_exact)
|
||||||
).astype(mx.int16)
|
).astype(mx.int16)
|
||||||
relative_position_if_large = mx.minimum(
|
relative_position_if_large = mx.minimum(relative_position_if_large, num_buckets - 1)
|
||||||
relative_position_if_large, num_buckets - 1
|
relative_buckets += mx.where(
|
||||||
|
is_small, relative_position, relative_position_if_large
|
||||||
)
|
)
|
||||||
relative_buckets += mx.where(is_small, relative_position, relative_position_if_large)
|
|
||||||
return relative_buckets
|
return relative_buckets
|
||||||
|
|
||||||
|
|
||||||
@ -80,22 +83,28 @@ class RelativePositionBias(nn.Module):
|
|||||||
self.max_distance = config.relative_attention_max_distance
|
self.max_distance = config.relative_attention_max_distance
|
||||||
self.n_heads = config.num_heads
|
self.n_heads = config.num_heads
|
||||||
self.embeddings = nn.Embedding(
|
self.embeddings = nn.Embedding(
|
||||||
config.relative_attention_num_buckets,
|
config.relative_attention_num_buckets, config.num_heads
|
||||||
config.num_heads)
|
)
|
||||||
|
|
||||||
def __call__(self, query_length, key_length):
|
def __call__(self, query_length, key_length):
|
||||||
"""Compute binned relative position bias"""
|
"""Compute binned relative position bias"""
|
||||||
context_position = mx.arange(query_length, dtype=mx.int32)[:, None]
|
context_position = mx.arange(query_length, dtype=mx.int32)[:, None]
|
||||||
memory_position = mx.arange(key_length, dtype=mx.int32)[None, :]
|
memory_position = mx.arange(key_length, dtype=mx.int32)[None, :]
|
||||||
relative_position = memory_position - context_position # shape (query_length, key_length)
|
relative_position = (
|
||||||
|
memory_position - context_position
|
||||||
|
) # shape (query_length, key_length)
|
||||||
relative_position_bucket = _relative_position_bucket(
|
relative_position_bucket = _relative_position_bucket(
|
||||||
relative_position, # shape (query_length, key_length)
|
relative_position, # shape (query_length, key_length)
|
||||||
bidirectional=self.bidirectional,
|
bidirectional=self.bidirectional,
|
||||||
num_buckets=self.num_buckets,
|
num_buckets=self.num_buckets,
|
||||||
max_distance=self.max_distance,
|
max_distance=self.max_distance,
|
||||||
)
|
)
|
||||||
values = self.embeddings(relative_position_bucket) # shape (query_length, key_length, num_heads)
|
values = self.embeddings(
|
||||||
values = mx.expand_dims(values.transpose(2, 0, 1), 0) # shape (1, num_heads, query_length, key_length)
|
relative_position_bucket
|
||||||
|
) # shape (query_length, key_length, num_heads)
|
||||||
|
values = mx.expand_dims(
|
||||||
|
values.transpose(2, 0, 1), 0
|
||||||
|
) # shape (1, num_heads, query_length, key_length)
|
||||||
return values
|
return values
|
||||||
|
|
||||||
|
|
||||||
@ -132,10 +141,8 @@ class MultiHeadAttention(nn.Module):
|
|||||||
if self.has_relative_attention_bias:
|
if self.has_relative_attention_bias:
|
||||||
position_bias = self.relative_attention_bias(L, S)
|
position_bias = self.relative_attention_bias(L, S)
|
||||||
scores += position_bias
|
scores += position_bias
|
||||||
|
|
||||||
scores = mx.softmax(scores, axis=-1)
|
scores = mx.softmax(scores, axis=-1)
|
||||||
values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
|
values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||||
|
|
||||||
return self.out_proj(values_hat)
|
return self.out_proj(values_hat)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -260,7 +267,7 @@ class TransformerDecoder(nn.Module):
|
|||||||
|
|
||||||
class OutputHead(nn.Module):
|
class OutputHead(nn.Module):
|
||||||
def __init__(self, config: ModelArgs) -> None:
|
def __init__(self, config: ModelArgs) -> None:
|
||||||
self.linear = nn.Linear(config.d_model, config.vocab_size)
|
self.linear = nn.Linear(config.d_model, config.vocab_size, bias=False)
|
||||||
|
|
||||||
def __call__(self, inputs):
|
def __call__(self, inputs):
|
||||||
return self.linear(inputs)
|
return self.linear(inputs)
|
||||||
@ -288,29 +295,35 @@ class T5(nn.Module):
|
|||||||
mask = mask.astype(x.dtype)
|
mask = mask.astype(x.dtype)
|
||||||
|
|
||||||
decoder_inputs = self.wte(decoder_inputs)
|
decoder_inputs = self.wte(decoder_inputs)
|
||||||
y, cache = self.decoder(x=decoder_inputs, x_mask=mask, memory=y) #, cache)
|
y, cache = self.decoder(
|
||||||
|
x=decoder_inputs, x_mask=mask, memory=y, memory_mask=None
|
||||||
|
) # , cache)
|
||||||
return self.lm_head(y), cache
|
return self.lm_head(y), cache
|
||||||
|
|
||||||
|
|
||||||
# def generate(prompt: mx.array, model: T5, temp: Optional[float] = 0.0):
|
def generate(
|
||||||
# def sample(logits):
|
inputs: mx.array, decoder_inputs: mx.array, model: T5, temp: Optional[float] = 0.0
|
||||||
# if temp == 0:
|
):
|
||||||
# return mx.argmax(logits, axis=-1)
|
def sample(logits):
|
||||||
# else:
|
if temp == 0:
|
||||||
# return mx.random.categorical(logits * (1 / temp))
|
return mx.argmax(logits, axis=-1)
|
||||||
|
else:
|
||||||
|
return mx.random.categorical(logits * (1 / temp))
|
||||||
|
|
||||||
# logits, cache = model(prompt)
|
logits, _ = model(inputs, decoder_inputs)
|
||||||
# y = sample(logits[:, -1, :])
|
y = sample(logits[:, -1, :])
|
||||||
# yield y
|
yield y
|
||||||
|
|
||||||
# while True:
|
while True:
|
||||||
# logits, cache = model(y[:, None], cache=cache)
|
# logits, cache = model(y[:, None], cache=cache)
|
||||||
# y = sample(logits.squeeze(1))
|
logits, _ = model(inputs, decoder_inputs)
|
||||||
# yield y
|
y = sample(logits.squeeze(1))
|
||||||
|
decoder_inputs = mx.concat(decoder_inputs, y, dim=1)
|
||||||
|
yield y
|
||||||
|
|
||||||
|
|
||||||
def load_model():
|
def load_model(model_config):
|
||||||
model = T5(ModelArgs())
|
model = T5(model_config)
|
||||||
weights = mx.load("weights.npz")
|
weights = mx.load("weights.npz")
|
||||||
current_weights = tree_flatten(model.parameters())
|
current_weights = tree_flatten(model.parameters())
|
||||||
weights_to_load = list(weights.items())
|
weights_to_load = list(weights.items())
|
||||||
@ -356,7 +369,8 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
mx.random.seed(args.seed)
|
mx.random.seed(args.seed)
|
||||||
|
|
||||||
model, tokenizer = load_model()
|
config = ModelArgs()
|
||||||
|
model, tokenizer = load_model(config)
|
||||||
|
|
||||||
prompt = tokenizer(
|
prompt = tokenizer(
|
||||||
args.prompt,
|
args.prompt,
|
||||||
@ -369,18 +383,20 @@ if __name__ == "__main__":
|
|||||||
print("[INFO] Generating with T5...", flush=True)
|
print("[INFO] Generating with T5...", flush=True)
|
||||||
print(args.prompt, end="", flush=True)
|
print(args.prompt, end="", flush=True)
|
||||||
|
|
||||||
print(model(prompt))
|
decoder_inputs = mx.array([[config.decoder_start_token_id]])
|
||||||
|
|
||||||
# tokens = []
|
tokens = []
|
||||||
# for token, _ in zip(generate(prompt, model), range(args.max_tokens)):
|
for token, _ in zip(
|
||||||
# tokens.append(token)
|
generate(prompt, decoder_inputs, model), range(args.max_tokens)
|
||||||
|
):
|
||||||
|
tokens.append(token)
|
||||||
|
|
||||||
# if (len(tokens) % 10) == 0:
|
if (len(tokens) % 10) == 0:
|
||||||
# mx.eval(tokens)
|
mx.eval(tokens)
|
||||||
# s = tokenizer.decode([t.item() for t in tokens])
|
s = tokenizer.decode([t.item() for t in tokens])
|
||||||
# print(s, end="", flush=True)
|
print(s, end="", flush=True)
|
||||||
# tokens = []
|
tokens = []
|
||||||
|
|
||||||
# mx.eval(tokens)
|
mx.eval(tokens)
|
||||||
# s = tokenizer.decode([t.item() for t in tokens])
|
s = tokenizer.decode([t.item() for t in tokens])
|
||||||
# print(s, flush=True)
|
print(s, flush=True)
|
||||||
|
Loading…
Reference in New Issue
Block a user