with cache

Awni Hannun 2023-12-17 17:35:53 -08:00
parent 29bfb93455
commit 688a6e1e78
2 changed files with 139 additions and 126 deletions

t5/convert.py

@@ -14,10 +14,8 @@ SHARED_REPLACEMENT_PATTERNS = [
     (".layer.1.layer_norm.", ".ln2."),
     (".layer.2.layer_norm.", ".ln3."),
     (".final_layer_norm.", ".ln."),
-    (
-        ".relative_attention_bias.",
-        ".relative_attention_bias.embeddings."
-    ),
+    ("layers.0.layer.0.SelfAttention.relative_attention_bias.",
+     "relative_attention_bias.embeddings."),
 ]

 ENCODER_REPLACEMENT_PATTERNS = [
@@ -33,6 +31,7 @@ DECODER_REPLACEMENT_PATTERNS = [
     (".layer.2.DenseReluDense.wo.", ".linear2."),
 ]

+
 def replace_key(key: str) -> str:
     for old, new in SHARED_REPLACEMENT_PATTERNS:
         key = key.replace(old, new)
@@ -45,14 +44,22 @@ def replace_key(key: str) -> str:
     return key


-def convert():
-    model = T5ForConditionalGeneration.from_pretrained(
-        "t5-small", torch_dtype="auto"
-    )
-    state_dict = model.state_dict()
-    weights = {replace_key(k): v.numpy() for k, v in state_dict.items()}
+def convert(model_name):
+    model = T5ForConditionalGeneration.from_pretrained(model_name, torch_dtype="auto")
+    weights = {replace_key(k): v.numpy() for k, v in model.state_dict().items()}
     np.savez("weights.npz", **weights)


 if __name__ == "__main__":
-    convert()
+    import argparse
+
+    parser = argparse.ArgumentParser(description="Convert T5 weights to MLX")
+    parser.add_argument(
+        "--model_name",
+        type=str,
+        help="Name of the T5 model.",
+        choices=["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"],
+        default="t5-small",
+    )
+    args = parser.parse_args()
+    convert(args.model_name)
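Note: a quick way to sanity-check the converter's output is to load the archive back with MLX. A minimal sketch; "weights.npz" is the file name hard-coded above, mx.load is MLX's standard loader for .npz archives, and the "wte.weight" key is an assumption that follows from the "shared." -> "wte." rule in SHARED_REPLACEMENT_PATTERNS:

    import mlx.core as mx

    weights = mx.load("weights.npz")  # dict: parameter name -> mx.array
    print(len(weights), "tensors; wte.weight:", weights["wte.weight"].shape)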

t5/t5.py (236 lines changed)

@@ -1,5 +1,6 @@
 import argparse
 from dataclasses import dataclass
+from typing import Optional, Tuple, List
 from typing import Optional

 from time import perf_counter_ns
@@ -110,18 +111,23 @@ class RelativePositionBias(nn.Module):

 class MultiHeadAttention(nn.Module):
-    def __init__(self, config: ModelArgs, has_relative_attention_bias: bool = False):
+    def __init__(self, config: ModelArgs):
         super().__init__()
         self.num_heads = config.num_heads
         self.query_proj = nn.Linear(config.d_model, config.d_model, bias=False)
         self.key_proj = nn.Linear(config.d_model, config.d_model, bias=False)
         self.value_proj = nn.Linear(config.d_model, config.d_model, bias=False)
         self.out_proj = nn.Linear(config.d_model, config.d_model, bias=False)
-        self.has_relative_attention_bias = has_relative_attention_bias
-        if has_relative_attention_bias:
-            self.relative_attention_bias = RelativePositionBias(config)

-    def __call__(self, queries, keys, values, mask=None, position_bias=None):
+    def __call__(
+        self,
+        queries: mx.array,
+        keys: mx.array,
+        values: mx.array,
+        mask: mx.array,
+        cache: Optional[Tuple[mx.array, mx.array]] = None,
+    ) -> mx.array:
         queries = self.query_proj(queries)
         keys = self.key_proj(keys)
         values = self.value_proj(values)
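Note: the new signature returns (output, (keys, values)) so that the caller can thread the pair back in on the next step. A stand-in stub with the same contract (illustrative only, not the real module; the cache grows on axis 0 here purely for simplicity):

    import mlx.core as mx
    from typing import Optional, Tuple

    def attn_stub(x: mx.array, cache: Optional[Tuple[mx.array, mx.array]] = None):
        k = v = x
        if cache is not None:  # append this step's keys/values to the cache
            k = mx.concatenate([cache[0], k], axis=0)
            v = mx.concatenate([cache[1], v], axis=0)
        return x, (k, v)

    cache = None
    for _ in range(3):
        out, cache = attn_stub(mx.ones((1, 4)), cache)
    print(cache[0].shape)  # (3, 4): keys accumulated across three steps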
@@ -133,106 +139,95 @@ class MultiHeadAttention(nn.Module):
         keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1)
         values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)

+        if cache is not None:
+            key_cache, value_cache = cache
+            keys = mx.concatenate([key_cache, keys], axis=3)
+            values = mx.concatenate([value_cache, values], axis=2)
+
         # Dimensions are [batch x num heads x sequence x hidden dim]
         scores = queries @ keys
         if mask is not None:
             scores = scores + mask.astype(scores.dtype)
-        if self.has_relative_attention_bias:
-            position_bias = self.relative_attention_bias(L, S)
-        if position_bias is not None:
-            scores += position_bias

         scores = mx.softmax(scores, axis=-1)
         values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
-        out = self.out_proj(values_hat)
-        return out, position_bias
+        return self.out_proj(values_hat), (keys, values)

-    @staticmethod
-    def create_additive_causal_mask(N: int, dtype: mx.Dtype = mx.float32):
-        indices = mx.arange(N)
-        mask = indices[:, None] < indices[None]
-        # usually inf but 1e9 is as good and softmax(full(1e9)) != nan
-        # TODO: Should replace this with finfo(dtype).min
-        mask = mask.astype(dtype) * -1e9
-        return mask

-class LayerNorm(nn.Module):
+class RMSNorm(nn.Module):
     def __init__(self, dims: int, eps: float = 1e-5):
         super().__init__()
         self.weight = mx.ones((dims,))
         self.eps = eps
+        self.dims = dims
+
+    def _norm(self, x):
+        return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)

     def __call__(self, x):
-        var = x.var(axis=-1, keepdims=True)
-        x = x * mx.rsqrt(var + self.eps)
-        return x * self.weight
+        output = self._norm(x.astype(mx.float32)).astype(x.dtype)
+        return self.weight * output

 class TransformerEncoderLayer(nn.Module):
-    def __init__(self, config: ModelArgs, has_relative_attention_bias: bool = False):
+    def __init__(self, config: ModelArgs):
         super().__init__()
         mlp_dims = config.d_ff or config.d_model * 4
-        self.attention = MultiHeadAttention(
-            config, has_relative_attention_bias=has_relative_attention_bias
-        )
-        self.ln1 = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
-        self.ln2 = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
+        self.attention = MultiHeadAttention(config)
+        self.ln1 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
+        self.ln2 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
         self.linear1 = nn.Linear(config.d_model, mlp_dims, bias=False)
         self.linear2 = nn.Linear(mlp_dims, config.d_model, bias=False)

-    def __call__(self, x, mask, position_bias=None):
+    def __call__(self, x, mask):
         y = self.ln1(x)
-        y, position_bias = self.attention(
-            queries=y, keys=y, values=y, mask=mask, position_bias=position_bias
-        )
+        y, _ = self.attention(y, y, y, mask=mask)
         x = x + y

         y = self.ln2(x)
         y = self.linear1(y)
         y = mx.maximum(y, 0)
         y = self.linear2(y)
-        x = x + y
-        return x, position_bias
+        return x + y

 class TransformerEncoder(nn.Module):
     def __init__(self, config: ModelArgs):
         super().__init__()
         self.layers = [
-            TransformerEncoderLayer(config, has_relative_attention_bias=i == 0)
-            for i in range(config.num_layers)
+            TransformerEncoderLayer(config) for i in range(config.num_layers)
         ]
-        self.ln = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
+        self.ln = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
+        self.relative_attention_bias = RelativePositionBias(config)

-    def __call__(self, x, mask):
-        position_bias = None
+    def __call__(self, x):
+        pos_bias = self.relative_attention_bias(x.shape[1], x.shape[1])
         for layer in self.layers:
-            x, position_bias = layer(x, mask, position_bias=position_bias)
-        x = self.ln(x)
-        return x
+            x = layer(x, mask=pos_bias)
+        return self.ln(x)

 class TransformerDecoderLayer(nn.Module):
-    def __init__(self, config: ModelArgs, has_relative_attention_bias: bool = False):
+    def __init__(self, config: ModelArgs):
         super().__init__()
         mlp_dims = config.d_ff or config.d_model * 4
-        self.self_attention = MultiHeadAttention(
-            config, has_relative_attention_bias=has_relative_attention_bias
-        )
+        self.self_attention = MultiHeadAttention(config)
         self.cross_attention = MultiHeadAttention(config)
-        self.ln1 = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
-        self.ln2 = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
-        self.ln3 = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
+        self.ln1 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
+        self.ln2 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
+        self.ln3 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
         self.linear1 = nn.Linear(config.d_model, mlp_dims, bias=False)
         self.linear2 = nn.Linear(mlp_dims, config.d_model, bias=False)

-    def __call__(self, x, memory, x_mask, memory_mask, position_bias=None):
+    def __call__(
+        self,
+        x: mx.array,
+        memory: mx.array,
+        mask: mx.array,
+        memory_mask: mx.array,
+        cache: Optional[Tuple[mx.array, mx.array]] = None,
+    ) -> mx.array:
         y = self.ln1(x)
-        y, position_bias = self.self_attention(y, y, y, x_mask, position_bias=position_bias)
+        y, cache = self.self_attention(y, y, y, mask, cache)
         x = x + y

         y = self.ln2(x)
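Note: the cache axes follow directly from the transposes earlier in MultiHeadAttention: keys are stored pre-transposed as [batch, heads, head_dim, seq] so that scores = queries @ keys needs no extra transpose, hence new keys append on axis 3, while values stay [batch, heads, seq, head_dim] and grow on axis 2. The shape bookkeeping in isolation (sizes are illustrative):

    import mlx.core as mx

    B, H, D = 1, 8, 64                  # batch, heads, head_dim
    k_cache = mx.zeros((B, H, D, 5))    # 5 cached positions on axis 3
    v_cache = mx.zeros((B, H, 5, D))    # 5 cached positions on axis 2
    k_new, v_new = mx.zeros((B, H, D, 1)), mx.zeros((B, H, 1, D))

    keys = mx.concatenate([k_cache, k_new], axis=3)    # (1, 8, 64, 6)
    values = mx.concatenate([v_cache, v_new], axis=2)  # (1, 8, 6, 64)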
@@ -245,27 +240,39 @@ class TransformerDecoderLayer(nn.Module):
         y = self.linear2(y)
         x = x + y
-        return x, position_bias
+        return x, cache


 class TransformerDecoder(nn.Module):
     def __init__(self, config: ModelArgs):
         super().__init__()
         self.layers = [
-            TransformerDecoderLayer(config, has_relative_attention_bias=i == 0)
-            for i in range(config.num_layers)
+            TransformerDecoderLayer(config) for i in range(config.num_layers)
         ]
-        self.ln = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
+        self.ln = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
+        self.relative_attention_bias = RelativePositionBias(config)

-    def __call__(self, x, memory, x_mask, memory_mask):
-        position_bias = None
-        for layer in self.layers:
-            x, position_bias = layer(
-                x, memory, x_mask, memory_mask, position_bias=position_bias
-            )
+    def __call__(self, x, memory, mask, memory_mask, cache=None):
+        if cache is not None:
+            offset = cache[0][0].shape[3]
+        else:
+            offset = 0
+            cache = [None] * len(self.layers)
+
+        T = offset + x.shape[1]
+        # TODO, add offset to RelativePositionBias class to avoid wasted work
+        pos_bias = self.relative_attention_bias(T, T)
+        pos_bias = pos_bias[:, :, -x.shape[1]:, :]
+        if mask is not None:
+            mask += pos_bias
+        else:
+            mask = pos_bias
+
+        for e, layer in enumerate(self.layers):
+            x, cache[e] = layer(x, memory, mask, memory_mask, cache=cache[e])
         x = self.ln(x)
-        return x
+        return x, cache


 class OutputHead(nn.Module):
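Note: with a cache, only the rows of the full (T, T) bias that correspond to the new queries matter, which is exactly what the pos_bias[:, :, -x.shape[1]:, :] slice keeps (the TODO above would avoid computing the discarded rows at all). The slicing in isolation (sizes are illustrative):

    import mlx.core as mx

    T, step = 6, 1                        # 5 cached tokens + 1 new query
    pos_bias = mx.zeros((1, 8, T, T))     # [batch, heads, query, key]
    pos_bias = pos_bias[:, :, -step:, :]  # keep rows for the new query only
    print(pos_bias.shape)                 # (1, 8, 1, 6)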
@@ -283,26 +290,37 @@ class T5(nn.Module):
         self.decoder = TransformerDecoder(config)
         self.lm_head = OutputHead(config)

+    def encode(self, inputs: mx.array) -> mx.array:
+        return self.encoder(self.wte(inputs))
+
+    def decode(
+        self,
+        inputs: mx.array,
+        memory: mx.array,
+        cache=None,
+    ):
+        inputs = self.wte(inputs)
+        T = inputs.shape[1]
+        if T > 1:
+            mask = nn.MultiHeadAttention.create_additive_causal_mask(T)
+            mask = mask.astype(inputs.dtype)
+        else:
+            mask = None
+
+        y, cache = self.decoder(
+            inputs, memory=memory, mask=mask, memory_mask=None, cache=cache
+        )
+        return self.lm_head(y), cache
+
     def __call__(
         self,
         inputs: mx.array,
         decoder_inputs: mx.array,
-        mask: mx.array = None,
-        cache: mx.array = None,
-    ) -> tuple[mx.array, mx.array]:
-        x = self.wte(inputs)
-        y = self.encoder(x, mask=None)  # , cache)
-        decoder_inputs = self.wte(decoder_inputs)
-        decoder_n_tokens = decoder_inputs.shape[1]
-        if decoder_n_tokens > 1 and mask is None:
-            mask = MultiHeadAttention.create_additive_causal_mask(decoder_n_tokens)
-            mask = mask.astype(x.dtype)
-        y = self.decoder(
-            x=decoder_inputs, x_mask=mask, memory=y, memory_mask=None
-        )  # , cache)
-        return self.lm_head(y), cache
+    ) -> mx.array:
+        return self.decode(decoder_inputs, self.encode(inputs))[0]


 def generate(
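Note: the encode/decode split is what makes the cache pay off: the encoder runs once per prompt and only single decoder steps run afterwards. A minimal driver sketch, assuming model comes from load_model below and config is the ModelArgs used to build it (the token ids are made up):

    import mlx.core as mx

    memory = model.encode(mx.array([[37, 8, 712]]))  # encoder runs once
    cache = None
    y = mx.array([config.decoder_start_token_id])
    for _ in range(3):
        logits, cache = model.decode(y[None], memory, cache=cache)
        y = mx.argmax(logits[:, -1, :], axis=-1)     # greedy next token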
@@ -314,12 +332,15 @@ def generate(
     else:
         return mx.random.categorical(logits * (1 / temp))

+    memory = model.encode(inputs)
+    cache = None
+    y = decoder_inputs
     while True:
-        # TODO: add cache
-        logits, _ = model(inputs, decoder_inputs)
-        y = mx.expand_dims(sample(logits[:, -1, :]), 0)
-        decoder_inputs = mx.concatenate([decoder_inputs, y], axis=1)
-        yield y
+        logits, cache = model.decode(y[None], memory, cache=cache)
+        # logits, cache = model.decode(decoder_inputs[None], memory, cache=cache)
+        y = sample(logits[:, -1, :])
+        # decoder_inputs = mx.concatenate([decoder_inputs, y])
+        yield y.squeeze()


 def load_model(model_config):
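Note: on the sampling side, temp == 0 falls back to argmax while temp > 0 scales the logits before categorical sampling, as in the branch above. Both paths in isolation (values are illustrative):

    import mlx.core as mx

    logits = mx.array([[1.0, 2.0, 4.0]])
    greedy = mx.argmax(logits, axis=-1)                  # temp == 0
    sampled = mx.random.categorical(logits * (1 / 0.8))  # temp == 0.8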
@@ -340,6 +361,7 @@ def load_model(model_config):
             print("Expected shape: ", current_weights_dict[key].shape)
             print("Loading shape: ", weights_to_load_dict[key].shape)
     model.update(tree_unflatten(weights_to_load))
+    mx.eval(model.parameters())
     tokenizer = AutoTokenizer.from_pretrained("t5-small", trust_remote_code=True)
     return model, tokenizer
@@ -397,39 +419,23 @@ if __name__ == "__main__":
     print("[INFO] Generating with T5...", flush=True)
     print("Input: ", args.prompt, flush=True)

-    decoder_inputs = mx.array([[config.decoder_start_token_id]]).astype(mx.uint32)
+    decoder_inputs = mx.array([config.decoder_start_token_id])

     start = perf_counter_ns()
-    n_tokens = 0
     tokens = []
-    for token, _ in zip(
+    for token, n_tokens in zip(
         generate(prompt, decoder_inputs, model, args.temp),
         range(args.max_tokens),
     ):
-        tokens.append(token)
-
-        if (len(tokens) % 10) == 0:
-            mx.eval(tokens)
-            eos_index = next(
-                (i for i, t in enumerate(tokens) if t.item() == tokenizer.eos_token_id),
-                None,
-            )
-            if eos_index is not None:
-                tokens = tokens[:eos_index]
-            s = tokenizer.decode([t.item() for t in tokens])
-            print(s, end="", flush=True)
-            n_tokens += len(tokens)
-            tokens = []
-            if eos_index is not None:
-                break
+        if token.item() == tokenizer.eos_token_id:
+            break
+        tokens.append(token.item())
+        # For some reason using the following line doesn't give spaces
+        # print(tokenizer.decode(token.item(), clean_up_tokenization_spaces=False), end="", flush=True)
+        print(tokenizer.decode(tokens), end="", flush=True)

     end = perf_counter_ns()
-    mx.eval(tokens)
-    s = tokenizer.decode([t.item() for t in tokens])
-    n_tokens += len(tokens)
     elapsed = (end - start) / 1.0e9
-    print(s, flush=True)
+    print()
     print(f"Time: {elapsed:.2f} seconds, tokens/s: {n_tokens / elapsed:.2f}")