Decode (broken after 1st token)

Juarez Bochi 2023-12-16 14:53:50 -05:00
parent 31da1b0dab
commit 203f550ef9

@@ -1,5 +1,6 @@
 import argparse
 from dataclasses import dataclass
+from typing import Optional
 import numpy as np
 import mlx.core as mx
@@ -26,7 +27,9 @@ class ModelArgs:
     vocab_size: int = 32128


-def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
+def _relative_position_bucket(
+    relative_position, bidirectional=True, num_buckets=32, max_distance=128
+):
     """
     Adapted from HF Tensorflow:
     https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
@@ -66,10 +69,10 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets
         / np.log(max_distance / max_exact)
         * (num_buckets - max_exact)
     ).astype(mx.int16)
-    relative_position_if_large = mx.minimum(
-        relative_position_if_large, num_buckets - 1
-    )
-    relative_buckets += mx.where(is_small, relative_position, relative_position_if_large)
+    relative_position_if_large = mx.minimum(relative_position_if_large, num_buckets - 1)
+    relative_buckets += mx.where(
+        is_small, relative_position, relative_position_if_large
+    )
     return relative_buckets
@@ -80,22 +83,28 @@ class RelativePositionBias(nn.Module):
         self.max_distance = config.relative_attention_max_distance
         self.n_heads = config.num_heads
         self.embeddings = nn.Embedding(
-            config.relative_attention_num_buckets,
-            config.num_heads)
+            config.relative_attention_num_buckets, config.num_heads
+        )

     def __call__(self, query_length, key_length):
         """Compute binned relative position bias"""
         context_position = mx.arange(query_length, dtype=mx.int32)[:, None]
         memory_position = mx.arange(key_length, dtype=mx.int32)[None, :]
-        relative_position = memory_position - context_position  # shape (query_length, key_length)
+        relative_position = (
+            memory_position - context_position
+        )  # shape (query_length, key_length)
         relative_position_bucket = _relative_position_bucket(
             relative_position,  # shape (query_length, key_length)
             bidirectional=self.bidirectional,
             num_buckets=self.num_buckets,
             max_distance=self.max_distance,
         )
-        values = self.embeddings(relative_position_bucket)  # shape (query_length, key_length, num_heads)
-        values = mx.expand_dims(values.transpose(2, 0, 1), 0)  # shape (1, num_heads, query_length, key_length)
+        values = self.embeddings(
+            relative_position_bucket
+        )  # shape (query_length, key_length, num_heads)
+        values = mx.expand_dims(
+            values.transpose(2, 0, 1), 0
+        )  # shape (1, num_heads, query_length, key_length)
         return values
@@ -132,10 +141,8 @@ class MultiHeadAttention(nn.Module):
         if self.has_relative_attention_bias:
             position_bias = self.relative_attention_bias(L, S)
             scores += position_bias
         scores = mx.softmax(scores, axis=-1)
         values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
         return self.out_proj(values_hat)

     @staticmethod
@@ -260,7 +267,7 @@ class TransformerDecoder(nn.Module):
 class OutputHead(nn.Module):
     def __init__(self, config: ModelArgs) -> None:
-        self.linear = nn.Linear(config.d_model, config.vocab_size)
+        self.linear = nn.Linear(config.d_model, config.vocab_size, bias=False)

     def __call__(self, inputs):
         return self.linear(inputs)
@@ -288,29 +295,35 @@ class T5(nn.Module):
         mask = mask.astype(x.dtype)

         decoder_inputs = self.wte(decoder_inputs)
-        y, cache = self.decoder(x=decoder_inputs, x_mask=mask, memory=y)  # , cache)
+        y, cache = self.decoder(
+            x=decoder_inputs, x_mask=mask, memory=y, memory_mask=None
+        )  # , cache)
         return self.lm_head(y), cache


-# def generate(prompt: mx.array, model: T5, temp: Optional[float] = 0.0):
-#     def sample(logits):
-#         if temp == 0:
-#             return mx.argmax(logits, axis=-1)
-#         else:
-#             return mx.random.categorical(logits * (1 / temp))
+def generate(
+    inputs: mx.array, decoder_inputs: mx.array, model: T5, temp: Optional[float] = 0.0
+):
+    def sample(logits):
+        if temp == 0:
+            return mx.argmax(logits, axis=-1)
+        else:
+            return mx.random.categorical(logits * (1 / temp))

-#     logits, cache = model(prompt)
-#     y = sample(logits[:, -1, :])
-#     yield y
+    logits, _ = model(inputs, decoder_inputs)
+    y = sample(logits[:, -1, :])
+    yield y

-#     while True:
-#         logits, cache = model(y[:, None], cache=cache)
-#         y = sample(logits.squeeze(1))
-#         yield y
+    while True:
+        logits, _ = model(inputs, decoder_inputs)
+        y = sample(logits.squeeze(1))
+        decoder_inputs = mx.concat(decoder_inputs, y, dim=1)
+        yield y


-def load_model():
-    model = T5(ModelArgs())
+def load_model(model_config):
+    model = T5(model_config)
     weights = mx.load("weights.npz")
     current_weights = tree_flatten(model.parameters())
     weights_to_load = list(weights.items())
@@ -356,7 +369,8 @@ if __name__ == "__main__":
     mx.random.seed(args.seed)

-    model, tokenizer = load_model()
+    config = ModelArgs()
+    model, tokenizer = load_model(config)

     prompt = tokenizer(
         args.prompt,
@@ -369,18 +383,20 @@
     print("[INFO] Generating with T5...", flush=True)
     print(args.prompt, end="", flush=True)

-    print(model(prompt))
+    decoder_inputs = mx.array([[config.decoder_start_token_id]])

-    # tokens = []
-    # for token, _ in zip(generate(prompt, model), range(args.max_tokens)):
-    #     tokens.append(token)
+    tokens = []
+    for token, _ in zip(
+        generate(prompt, decoder_inputs, model), range(args.max_tokens)
+    ):
+        tokens.append(token)

-    #     if (len(tokens) % 10) == 0:
-    #         mx.eval(tokens)
-    #         s = tokenizer.decode([t.item() for t in tokens])
-    #         print(s, end="", flush=True)
-    #         tokens = []
+        if (len(tokens) % 10) == 0:
+            mx.eval(tokens)
+            s = tokenizer.decode([t.item() for t in tokens])
+            print(s, end="", flush=True)
+            tokens = []

-    # mx.eval(tokens)
-    # s = tokenizer.decode([t.item() for t in tokens])
-    # print(s, flush=True)
+    mx.eval(tokens)
+    s = tokenizer.decode([t.item() for t in tokens])
+    print(s, flush=True)
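
The generation loop added above is where the "broken after 1st token" of the commit title shows up: once inside the while loop, sampling uses logits.squeeze(1) even though the decoder now returns one logit vector per position, and mx.concat(decoder_inputs, y, dim=1) does not match MLX's concatenation API, which is mx.concatenate(list_of_arrays, axis=...). Below is a minimal sketch, not part of the commit, of how that loop could look; it assumes the model(inputs, decoder_inputs) call signature from this diff and ignores the commented-out cache path, so every step re-runs the full decoder over the tokens produced so far.

import mlx.core as mx


def generate(inputs: mx.array, decoder_inputs: mx.array, model, temp: float = 0.0):
    # Sample from the logits of a single position (greedy when temp == 0).
    def sample(logits):
        if temp == 0:
            return mx.argmax(logits, axis=-1)
        return mx.random.categorical(logits * (1 / temp))

    while True:
        # Uncached forward pass over everything decoded so far: (B, L, vocab_size).
        logits, _ = model(inputs, decoder_inputs)
        # Only the last position predicts the next token.
        y = sample(logits[:, -1, :])  # shape (B,)
        # Append the new token; y needs an explicit length axis to concatenate.
        decoder_inputs = mx.concatenate([decoder_inputs, y[:, None]], axis=1)
        yield y

The commented-out cache arguments scattered through the diff point at the intended follow-up: thread the per-layer key/value cache back into the decoder so each step only has to process the newly sampled token.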