Decode (broken after 1st token)

This commit is contained in:
Juarez Bochi 2023-12-16 14:53:50 -05:00
parent 31da1b0dab
commit 203f550ef9
No known key found for this signature in database
GPG Key ID: 34CCBB77DC8BEBB6

128
t5/t5.py
View File

@ -1,5 +1,6 @@
import argparse import argparse
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional
import numpy as np import numpy as np
import mlx.core as mx import mlx.core as mx
@ -10,23 +11,25 @@ from transformers import AutoTokenizer
@dataclass
class ModelArgs:
    """Hyperparameters and special-token ids used to build the T5 model.

    Every field has a default, so ``ModelArgs()`` yields a complete
    configuration; individual values can be overridden by keyword.
    """

    d_ff: int = 2048
    d_kv: int = 64
    d_model: int = 512  # hidden size; input width of the output head
    dropout_rate: float = 0.1  # fractional value, so annotated as float
    layer_norm_epsilon: float = 1e-06
    n_positions: int = 512
    # Size of the relative-position embedding table and the distance cap
    # passed to the relative position bucketing.
    relative_attention_num_buckets: int = 32
    relative_attention_max_distance: int = 128
    num_heads: int = 8
    num_layers: int = 6
    decoder_start_token_id: int = 0  # first decoder input token at generation
    eos_token_id: int = 1
    pad_token_id: int = 0
    vocab_size: int = 32128  # output width of the output head
def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128): def _relative_position_bucket(
relative_position, bidirectional=True, num_buckets=32, max_distance=128
):
""" """
Adapted from HF Tensorflow: Adapted from HF Tensorflow:
https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
@ -66,10 +69,10 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets
/ np.log(max_distance / max_exact) / np.log(max_distance / max_exact)
* (num_buckets - max_exact) * (num_buckets - max_exact)
).astype(mx.int16) ).astype(mx.int16)
relative_position_if_large = mx.minimum( relative_position_if_large = mx.minimum(relative_position_if_large, num_buckets - 1)
relative_position_if_large, num_buckets - 1 relative_buckets += mx.where(
is_small, relative_position, relative_position_if_large
) )
relative_buckets += mx.where(is_small, relative_position, relative_position_if_large)
return relative_buckets return relative_buckets
@ -80,22 +83,28 @@ class RelativePositionBias(nn.Module):
self.max_distance = config.relative_attention_max_distance self.max_distance = config.relative_attention_max_distance
self.n_heads = config.num_heads self.n_heads = config.num_heads
self.embeddings = nn.Embedding( self.embeddings = nn.Embedding(
config.relative_attention_num_buckets, config.relative_attention_num_buckets, config.num_heads
config.num_heads) )
def __call__(self, query_length, key_length):
    """Return the learned relative-position bias for an attention grid.

    Builds the (query_length, key_length) matrix of key-minus-query
    offsets, maps each offset to a bucket id, looks the buckets up in the
    per-head embedding table and rearranges the result so it can be added
    to attention scores.

    Returns an array of shape (1, num_heads, query_length, key_length).
    """
    query_idx = mx.arange(query_length, dtype=mx.int32)[:, None]
    key_idx = mx.arange(key_length, dtype=mx.int32)[None, :]

    # Broadcasting a column against a row gives every (query, key) offset.
    offsets = key_idx - query_idx  # shape (query_length, key_length)

    buckets = _relative_position_bucket(
        offsets,
        bidirectional=self.bidirectional,
        num_buckets=self.num_buckets,
        max_distance=self.max_distance,
    )

    # Embedding lookup yields (query_length, key_length, num_heads); move the
    # head axis to the front, then add a leading axis of size 1.
    per_head = self.embeddings(buckets).transpose(2, 0, 1)
    return mx.expand_dims(per_head, 0)
@ -132,10 +141,8 @@ class MultiHeadAttention(nn.Module):
if self.has_relative_attention_bias: if self.has_relative_attention_bias:
position_bias = self.relative_attention_bias(L, S) position_bias = self.relative_attention_bias(L, S)
scores += position_bias scores += position_bias
scores = mx.softmax(scores, axis=-1) scores = mx.softmax(scores, axis=-1)
values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.out_proj(values_hat) return self.out_proj(values_hat)
@staticmethod @staticmethod
@ -260,7 +267,7 @@ class TransformerDecoder(nn.Module):
class OutputHead(nn.Module):
    """Bias-free linear projection from ``d_model`` features to vocabulary logits."""

    def __init__(self, config: ModelArgs) -> None:
        # nn.Module keeps its own bookkeeping state; initialise it before
        # registering any submodule so that state exists on this instance.
        super().__init__()
        self.linear = nn.Linear(config.d_model, config.vocab_size, bias=False)

    def __call__(self, inputs):
        """Apply the projection to ``inputs`` and return the result."""
        return self.linear(inputs)
@ -281,36 +288,42 @@ class T5(nn.Module):
cache: mx.array = None, cache: mx.array = None,
) -> tuple[mx.array, mx.array]: ) -> tuple[mx.array, mx.array]:
x = self.wte(inputs) x = self.wte(inputs)
y = self.encoder(x, mask=None) #, cache) y = self.encoder(x, mask=None) # , cache)
if x.shape[1] > 1 and mask is None: if x.shape[1] > 1 and mask is None:
mask = MultiHeadAttention.create_additive_causal_mask(x.shape[1]) mask = MultiHeadAttention.create_additive_causal_mask(x.shape[1])
mask = mask.astype(x.dtype) mask = mask.astype(x.dtype)
decoder_inputs = self.wte(decoder_inputs) decoder_inputs = self.wte(decoder_inputs)
y, cache = self.decoder(x=decoder_inputs, x_mask=mask, memory=y) #, cache) y, cache = self.decoder(
x=decoder_inputs, x_mask=mask, memory=y, memory_mask=None
) # , cache)
return self.lm_head(y), cache return self.lm_head(y), cache
def generate(
    inputs: mx.array, decoder_inputs: mx.array, model: T5, temp: Optional[float] = 0.0
):
    """Yield sampled tokens one at a time, without end.

    Each step runs the full model on ``inputs`` and the decoder sequence
    produced so far (no cache is used), samples a token from the logits at
    the last decoder position, appends it to the decoder sequence and yields
    it. The caller is responsible for stopping the iteration.

    Args:
        inputs: Encoder token ids passed unchanged to ``model`` on every step.
        decoder_inputs: Initial decoder token ids; assumed to be
            (batch, seq) so new tokens can be appended along axis 1.
        model: Callable returning ``(logits, cache)`` for
            ``model(inputs, decoder_inputs)``.
        temp: Sampling temperature. ``0`` (or ``None``) selects the argmax;
            any other value samples from the temperature-scaled logits.

    Yields:
        The token ids sampled at each step, one per batch row.
    """

    def sample(logits):
        if not temp:
            return mx.argmax(logits, axis=-1)
        return mx.random.categorical(logits * (1 / temp))

    while True:
        logits, _ = model(inputs, decoder_inputs)
        # Only the last decoder position predicts the next token; the decoder
        # sequence grows every step, so index it rather than squeezing.
        y = sample(logits[:, -1, :])
        # Append the new token before the next forward pass, otherwise the
        # model is fed the same decoder sequence again and repeats itself.
        next_column = y[:, None].astype(decoder_inputs.dtype)
        decoder_inputs = mx.concatenate([decoder_inputs, next_column], axis=1)
        yield y
def load_model(): def load_model(model_config):
model = T5(ModelArgs()) model = T5(model_config)
weights = mx.load("weights.npz") weights = mx.load("weights.npz")
current_weights = tree_flatten(model.parameters()) current_weights = tree_flatten(model.parameters())
weights_to_load = list(weights.items()) weights_to_load = list(weights.items())
@ -356,7 +369,8 @@ if __name__ == "__main__":
mx.random.seed(args.seed) mx.random.seed(args.seed)
model, tokenizer = load_model() config = ModelArgs()
model, tokenizer = load_model(config)
prompt = tokenizer( prompt = tokenizer(
args.prompt, args.prompt,
@ -369,18 +383,20 @@ if __name__ == "__main__":
print("[INFO] Generating with T5...", flush=True) print("[INFO] Generating with T5...", flush=True)
print(args.prompt, end="", flush=True) print(args.prompt, end="", flush=True)
print(model(prompt)) decoder_inputs = mx.array([[config.decoder_start_token_id]])
# tokens = [] tokens = []
# for token, _ in zip(generate(prompt, model), range(args.max_tokens)): for token, _ in zip(
# tokens.append(token) generate(prompt, decoder_inputs, model), range(args.max_tokens)
):
tokens.append(token)
# if (len(tokens) % 10) == 0: if (len(tokens) % 10) == 0:
# mx.eval(tokens) mx.eval(tokens)
# s = tokenizer.decode([t.item() for t in tokens]) s = tokenizer.decode([t.item() for t in tokens])
# print(s, end="", flush=True) print(s, end="", flush=True)
# tokens = [] tokens = []
# mx.eval(tokens) mx.eval(tokens)
# s = tokenizer.decode([t.item() for t in tokens]) s = tokenizer.decode([t.item() for t in tokens])
# print(s, flush=True) print(s, flush=True)