with cache

This commit is contained in:
Awni Hannun 2023-12-17 17:35:53 -08:00
parent 29bfb93455
commit 688a6e1e78
2 changed files with 139 additions and 126 deletions

View File

@ -14,10 +14,8 @@ SHARED_REPLACEMENT_PATTERNS = [
(".layer.1.layer_norm.", ".ln2."),
(".layer.2.layer_norm.", ".ln3."),
(".final_layer_norm.", ".ln."),
(
".relative_attention_bias.",
".relative_attention_bias.embeddings."
),
("layers.0.layer.0.SelfAttention.relative_attention_bias.",
"relative_attention_bias.embeddings."),
]
ENCODER_REPLACEMENT_PATTERNS = [
@ -33,6 +31,7 @@ DECODER_REPLACEMENT_PATTERNS = [
(".layer.2.DenseReluDense.wo.", ".linear2."),
]
def replace_key(key: str) -> str:
for old, new in SHARED_REPLACEMENT_PATTERNS:
key = key.replace(old, new)
@ -45,14 +44,22 @@ def replace_key(key: str) -> str:
return key
def convert():
model = T5ForConditionalGeneration.from_pretrained(
"t5-small", torch_dtype="auto"
)
state_dict = model.state_dict()
weights = {replace_key(k): v.numpy() for k, v in state_dict.items()}
def convert(model_name):
model = T5ForConditionalGeneration.from_pretrained(model_name, torch_dtype="auto")
weights = {replace_key(k): v.numpy() for k, v in model.state_dict().items()}
np.savez("weights.npz", **weights)
if __name__ == "__main__":
convert()
import argparse
parser = argparse.ArgumentParser(description="Convert T5 weights to MLX")
parser.add_argument(
"--model_name",
type=str,
help="Name of the T5 model.",
choices=["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"],
default="t5-small",
)
args = parser.parse_args()
convert(args.model_name)

232
t5/t5.py
View File

@ -1,5 +1,6 @@
import argparse
from dataclasses import dataclass
from typing import Optional, Tuple, List
from typing import Optional
from time import perf_counter_ns
@ -110,18 +111,23 @@ class RelativePositionBias(nn.Module):
class MultiHeadAttention(nn.Module):
def __init__(self, config: ModelArgs, has_relative_attention_bias: bool = False):
def __init__(self, config: ModelArgs):
super().__init__()
self.num_heads = config.num_heads
self.query_proj = nn.Linear(config.d_model, config.d_model, bias=False)
self.key_proj = nn.Linear(config.d_model, config.d_model, bias=False)
self.value_proj = nn.Linear(config.d_model, config.d_model, bias=False)
self.out_proj = nn.Linear(config.d_model, config.d_model, bias=False)
self.has_relative_attention_bias = has_relative_attention_bias
if has_relative_attention_bias:
self.relative_attention_bias = RelativePositionBias(config)
def __call__(self, queries, keys, values, mask=None, position_bias=None):
def __call__(
self,
queries: mx.array,
keys: mx.array,
values: mx.array,
mask: mx.array,
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> mx.array:
queries = self.query_proj(queries)
keys = self.key_proj(keys)
values = self.value_proj(values)
@ -133,106 +139,95 @@ class MultiHeadAttention(nn.Module):
keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1)
values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
if cache is not None:
key_cache, value_cache = cache
keys = mx.concatenate([key_cache, keys], axis=3)
values = mx.concatenate([value_cache, values], axis=2)
# Dimensions are [batch x num heads x sequence x hidden dim]
scores = queries @ keys
if mask is not None:
scores = scores + mask.astype(scores.dtype)
if self.has_relative_attention_bias:
position_bias = self.relative_attention_bias(L, S)
if position_bias is not None:
scores += position_bias
scores = mx.softmax(scores, axis=-1)
values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
out = self.out_proj(values_hat)
return out, position_bias
@staticmethod
def create_additive_causal_mask(N: int, dtype: mx.Dtype = mx.float32):
indices = mx.arange(N)
mask = indices[:, None] < indices[None]
# usually inf but 1e9 is as good and softmax(full(1e9)) != nan
# TODO: Should replace this with finfo(dtype).min
mask = mask.astype(dtype) * -1e9
return mask
return self.out_proj(values_hat), (keys, values)
class LayerNorm(nn.Module):
class RMSNorm(nn.Module):
def __init__(self, dims: int, eps: float = 1e-5):
super().__init__()
self.weight = mx.ones((dims,))
self.eps = eps
self.dims = dims
def _norm(self, x):
return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
def __call__(self, x):
var = x.var(axis=-1, keepdims=True)
x = x * mx.rsqrt(var + self.eps)
return x * self.weight
output = self._norm(x.astype(mx.float32)).astype(x.dtype)
return self.weight * output
class TransformerEncoderLayer(nn.Module):
def __init__(self, config: ModelArgs, has_relative_attention_bias: bool = False):
def __init__(self, config: ModelArgs):
super().__init__()
mlp_dims = config.d_ff or config.d_model * 4
self.attention = MultiHeadAttention(
config, has_relative_attention_bias=has_relative_attention_bias
)
self.ln1 = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
self.ln2 = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
self.attention = MultiHeadAttention(config)
self.ln1 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
self.ln2 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
self.linear1 = nn.Linear(config.d_model, mlp_dims, bias=False)
self.linear2 = nn.Linear(mlp_dims, config.d_model, bias=False)
def __call__(self, x, mask, position_bias=None):
def __call__(self, x, mask):
y = self.ln1(x)
y, position_bias = self.attention(
queries=y, keys=y, values=y, mask=mask, position_bias=position_bias
)
y, _ = self.attention(y, y, y, mask=mask)
x = x + y
y = self.ln2(x)
y = self.linear1(y)
y = mx.maximum(y, 0)
y = self.linear2(y)
x = x + y
return x, position_bias
return x + y
class TransformerEncoder(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.layers = [
TransformerEncoderLayer(config, has_relative_attention_bias=i == 0)
for i in range(config.num_layers)
TransformerEncoderLayer(config) for i in range(config.num_layers)
]
self.ln = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
self.ln = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
self.relative_attention_bias = RelativePositionBias(config)
def __call__(self, x, mask):
position_bias = None
def __call__(self, x):
pos_bias = self.relative_attention_bias(x.shape[1], x.shape[1])
for layer in self.layers:
x, position_bias = layer(x, mask, position_bias=position_bias)
x = self.ln(x)
return x
x = layer(x, mask=pos_bias)
return self.ln(x)
class TransformerDecoderLayer(nn.Module):
def __init__(self, config: ModelArgs, has_relative_attention_bias: bool = False):
def __init__(self, config: ModelArgs):
super().__init__()
mlp_dims = config.d_ff or config.d_model * 4
self.self_attention = MultiHeadAttention(
config, has_relative_attention_bias=has_relative_attention_bias
)
self.self_attention = MultiHeadAttention(config)
self.cross_attention = MultiHeadAttention(config)
self.ln1 = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
self.ln2 = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
self.ln3 = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
self.ln1 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
self.ln2 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
self.ln3 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
self.linear1 = nn.Linear(config.d_model, mlp_dims, bias=False)
self.linear2 = nn.Linear(mlp_dims, config.d_model, bias=False)
def __call__(self, x, memory, x_mask, memory_mask, position_bias=None):
def __call__(
self,
x: mx.array,
memory: mx.array,
mask: mx.array,
memory_mask: mx.array,
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> mx.array:
y = self.ln1(x)
y, position_bias = self.self_attention(y, y, y, x_mask, position_bias=position_bias)
y, cache = self.self_attention(y, y, y, mask, cache)
x = x + y
y = self.ln2(x)
@ -245,27 +240,39 @@ class TransformerDecoderLayer(nn.Module):
y = self.linear2(y)
x = x + y
return x, position_bias
return x, cache
class TransformerDecoder(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.layers = [
TransformerDecoderLayer(config, has_relative_attention_bias=i == 0)
for i in range(config.num_layers)
TransformerDecoderLayer(config) for i in range(config.num_layers)
]
self.ln = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
self.ln = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
self.relative_attention_bias = RelativePositionBias(config)
def __call__(self, x, memory, x_mask, memory_mask):
position_bias = None
for layer in self.layers:
x, position_bias = layer(
x, memory, x_mask, memory_mask, position_bias=position_bias
)
def __call__(self, x, memory, mask, memory_mask, cache=None):
if cache is not None:
offset = cache[0][0].shape[3]
else:
offset = 0
cache = [None] * len(self.layers)
T = offset + x.shape[1]
# TODO, add offset to RelativePositionBias class to avoid wasted work
pos_bias = self.relative_attention_bias(T, T)
pos_bias = pos_bias[:, :, -x.shape[1]:, :]
if mask is not None:
mask += pos_bias
else:
mask = pos_bias
for e, layer in enumerate(self.layers):
x, cache[e] = layer(x, memory, mask, memory_mask, cache=cache[e])
x = self.ln(x)
return x
return x, cache
class OutputHead(nn.Module):
@ -283,26 +290,37 @@ class T5(nn.Module):
self.decoder = TransformerDecoder(config)
self.lm_head = OutputHead(config)
def encode(
self,
inputs: mx.array
) -> mx.array:
return self.encoder(self.wte(inputs))
def decode(
self,
inputs: mx.array,
memory: mx.array,
cache = None,
):
inputs = self.wte(inputs)
T = inputs.shape[1]
if T > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(T)
mask = mask.astype(inputs.dtype)
else:
mask = None
y, cache = self.decoder(
inputs, memory=memory, mask=mask, memory_mask=None, cache=cache
)
return self.lm_head(y), cache
def __call__(
self,
inputs: mx.array,
decoder_inputs: mx.array,
mask: mx.array = None,
cache: mx.array = None,
) -> tuple[mx.array, mx.array]:
x = self.wte(inputs)
y = self.encoder(x, mask=None) # , cache)
decoder_inputs = self.wte(decoder_inputs)
decoder_n_tokens = decoder_inputs.shape[1]
if decoder_n_tokens > 1 and mask is None:
mask = MultiHeadAttention.create_additive_causal_mask(decoder_n_tokens)
mask = mask.astype(x.dtype)
y = self.decoder(
x=decoder_inputs, x_mask=mask, memory=y, memory_mask=None
) # , cache)
return self.lm_head(y), cache
) -> mx.array:
return decode(decoder_inputs, encode(inputs))[0]
def generate(
@ -314,12 +332,15 @@ def generate(
else:
return mx.random.categorical(logits * (1 / temp))
memory = model.encode(inputs)
cache = None
y = decoder_inputs
while True:
# TODO: add cache
logits, _ = model(inputs, decoder_inputs)
y = mx.expand_dims(sample(logits[:, -1, :]), 0)
decoder_inputs = mx.concatenate([decoder_inputs, y], axis=1)
yield y
logits, cache = model.decode(y[None], memory, cache=cache)
# logits, cache = model.decode(decoder_inputs[None], memory, cache=cache)
y = sample(logits[:, -1, :])
#decoder_inputs = mx.concatenate([decoder_inputs, y])
yield y.squeeze()
def load_model(model_config):
@ -340,6 +361,7 @@ def load_model(model_config):
print("Expected shape: ", current_weights_dict[key].shape)
print("Loading shape: ", weights_to_load_dict[key].shape)
model.update(tree_unflatten(weights_to_load))
mx.eval(model.parameters())
tokenizer = AutoTokenizer.from_pretrained("t5-small", trust_remote_code=True)
return model, tokenizer
@ -397,39 +419,23 @@ if __name__ == "__main__":
print("[INFO] Generating with T5...", flush=True)
print("Input: ", args.prompt, flush=True)
decoder_inputs = mx.array([[config.decoder_start_token_id]]).astype(mx.uint32)
decoder_inputs = mx.array([config.decoder_start_token_id])
start = perf_counter_ns()
n_tokens = 0
tokens = []
for token, _ in zip(
for token, n_tokens in zip(
generate(prompt, decoder_inputs, model, args.temp),
range(args.max_tokens)
):
tokens.append(token)
if (len(tokens) % 10) == 0:
mx.eval(tokens)
eos_index = next(
(i for i, t in enumerate(tokens) if t.item() == tokenizer.eos_token_id),
None,
)
if eos_index is not None:
tokens = tokens[:eos_index]
s = tokenizer.decode([t.item() for t in tokens])
print(s, end="", flush=True)
n_tokens += len(tokens)
tokens = []
if eos_index is not None:
if token.item() == tokenizer.eos_token_id:
break
tokens.append(token.item())
# For some reason using the following line doesn't give spaces
# print(tokenizer.decode(token.item(), clean_up_tokenization_spaces=False), end="", flush=True)
print(tokenizer.decode(tokens), end="", flush=True)
end = perf_counter_ns()
mx.eval(tokens)
s = tokenizer.decode([t.item() for t in tokens])
n_tokens += len(tokens)
elapsed = (end - start) / 1.0e9
print(s, flush=True)
print()
print(f"Time: {elapsed:.2f} seconds, tokens/s: {n_tokens / elapsed:.2f}")