Decode (broken after 1st token)

This commit is contained in:
Juarez Bochi 2023-12-16 14:53:50 -05:00
parent 31da1b0dab
commit 203f550ef9
No known key found for this signature in database
GPG Key ID: 34CCBB77DC8BEBB6

View File

@ -1,5 +1,6 @@
import argparse import argparse
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional
import numpy as np import numpy as np
import mlx.core as mx import mlx.core as mx
@ -26,7 +27,9 @@ class ModelArgs:
vocab_size: int = 32128 vocab_size: int = 32128
def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128): def _relative_position_bucket(
relative_position, bidirectional=True, num_buckets=32, max_distance=128
):
""" """
Adapted from HF Tensorflow: Adapted from HF Tensorflow:
https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
@ -66,10 +69,10 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets
/ np.log(max_distance / max_exact) / np.log(max_distance / max_exact)
* (num_buckets - max_exact) * (num_buckets - max_exact)
).astype(mx.int16) ).astype(mx.int16)
relative_position_if_large = mx.minimum( relative_position_if_large = mx.minimum(relative_position_if_large, num_buckets - 1)
relative_position_if_large, num_buckets - 1 relative_buckets += mx.where(
is_small, relative_position, relative_position_if_large
) )
relative_buckets += mx.where(is_small, relative_position, relative_position_if_large)
return relative_buckets return relative_buckets
@ -80,22 +83,28 @@ class RelativePositionBias(nn.Module):
self.max_distance = config.relative_attention_max_distance self.max_distance = config.relative_attention_max_distance
self.n_heads = config.num_heads self.n_heads = config.num_heads
self.embeddings = nn.Embedding( self.embeddings = nn.Embedding(
config.relative_attention_num_buckets, config.relative_attention_num_buckets, config.num_heads
config.num_heads) )
def __call__(self, query_length, key_length): def __call__(self, query_length, key_length):
"""Compute binned relative position bias""" """Compute binned relative position bias"""
context_position = mx.arange(query_length, dtype=mx.int32)[:, None] context_position = mx.arange(query_length, dtype=mx.int32)[:, None]
memory_position = mx.arange(key_length, dtype=mx.int32)[None, :] memory_position = mx.arange(key_length, dtype=mx.int32)[None, :]
relative_position = memory_position - context_position # shape (query_length, key_length) relative_position = (
memory_position - context_position
) # shape (query_length, key_length)
relative_position_bucket = _relative_position_bucket( relative_position_bucket = _relative_position_bucket(
relative_position, # shape (query_length, key_length) relative_position, # shape (query_length, key_length)
bidirectional=self.bidirectional, bidirectional=self.bidirectional,
num_buckets=self.num_buckets, num_buckets=self.num_buckets,
max_distance=self.max_distance, max_distance=self.max_distance,
) )
values = self.embeddings(relative_position_bucket) # shape (query_length, key_length, num_heads) values = self.embeddings(
values = mx.expand_dims(values.transpose(2, 0, 1), 0) # shape (1, num_heads, query_length, key_length) relative_position_bucket
) # shape (query_length, key_length, num_heads)
values = mx.expand_dims(
values.transpose(2, 0, 1), 0
) # shape (1, num_heads, query_length, key_length)
return values return values
@ -132,10 +141,8 @@ class MultiHeadAttention(nn.Module):
if self.has_relative_attention_bias: if self.has_relative_attention_bias:
position_bias = self.relative_attention_bias(L, S) position_bias = self.relative_attention_bias(L, S)
scores += position_bias scores += position_bias
scores = mx.softmax(scores, axis=-1) scores = mx.softmax(scores, axis=-1)
values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.out_proj(values_hat) return self.out_proj(values_hat)
@staticmethod @staticmethod
@ -260,7 +267,7 @@ class TransformerDecoder(nn.Module):
class OutputHead(nn.Module): class OutputHead(nn.Module):
def __init__(self, config: ModelArgs) -> None: def __init__(self, config: ModelArgs) -> None:
self.linear = nn.Linear(config.d_model, config.vocab_size) self.linear = nn.Linear(config.d_model, config.vocab_size, bias=False)
def __call__(self, inputs): def __call__(self, inputs):
return self.linear(inputs) return self.linear(inputs)
@ -288,29 +295,35 @@ class T5(nn.Module):
mask = mask.astype(x.dtype) mask = mask.astype(x.dtype)
decoder_inputs = self.wte(decoder_inputs) decoder_inputs = self.wte(decoder_inputs)
y, cache = self.decoder(x=decoder_inputs, x_mask=mask, memory=y) #, cache) y, cache = self.decoder(
x=decoder_inputs, x_mask=mask, memory=y, memory_mask=None
) # , cache)
return self.lm_head(y), cache return self.lm_head(y), cache
# def generate(prompt: mx.array, model: T5, temp: Optional[float] = 0.0): def generate(
# def sample(logits): inputs: mx.array, decoder_inputs: mx.array, model: T5, temp: Optional[float] = 0.0
# if temp == 0: ):
# return mx.argmax(logits, axis=-1) def sample(logits):
# else: if temp == 0:
# return mx.random.categorical(logits * (1 / temp)) return mx.argmax(logits, axis=-1)
else:
return mx.random.categorical(logits * (1 / temp))
# logits, cache = model(prompt) logits, _ = model(inputs, decoder_inputs)
# y = sample(logits[:, -1, :]) y = sample(logits[:, -1, :])
# yield y yield y
# while True: while True:
# logits, cache = model(y[:, None], cache=cache) # logits, cache = model(y[:, None], cache=cache)
# y = sample(logits.squeeze(1)) logits, _ = model(inputs, decoder_inputs)
# yield y y = sample(logits.squeeze(1))
decoder_inputs = mx.concat(decoder_inputs, y, dim=1)
yield y
def load_model(): def load_model(model_config):
model = T5(ModelArgs()) model = T5(model_config)
weights = mx.load("weights.npz") weights = mx.load("weights.npz")
current_weights = tree_flatten(model.parameters()) current_weights = tree_flatten(model.parameters())
weights_to_load = list(weights.items()) weights_to_load = list(weights.items())
@ -356,7 +369,8 @@ if __name__ == "__main__":
mx.random.seed(args.seed) mx.random.seed(args.seed)
model, tokenizer = load_model() config = ModelArgs()
model, tokenizer = load_model(config)
prompt = tokenizer( prompt = tokenizer(
args.prompt, args.prompt,
@ -369,18 +383,20 @@ if __name__ == "__main__":
print("[INFO] Generating with T5...", flush=True) print("[INFO] Generating with T5...", flush=True)
print(args.prompt, end="", flush=True) print(args.prompt, end="", flush=True)
print(model(prompt)) decoder_inputs = mx.array([[config.decoder_start_token_id]])
# tokens = [] tokens = []
# for token, _ in zip(generate(prompt, model), range(args.max_tokens)): for token, _ in zip(
# tokens.append(token) generate(prompt, decoder_inputs, model), range(args.max_tokens)
):
tokens.append(token)
# if (len(tokens) % 10) == 0: if (len(tokens) % 10) == 0:
# mx.eval(tokens) mx.eval(tokens)
# s = tokenizer.decode([t.item() for t in tokens]) s = tokenizer.decode([t.item() for t in tokens])
# print(s, end="", flush=True) print(s, end="", flush=True)
# tokens = [] tokens = []
# mx.eval(tokens) mx.eval(tokens)
# s = tokenizer.decode([t.item() for t in tokens]) s = tokenizer.decode([t.item() for t in tokens])
# print(s, flush=True) print(s, flush=True)