switch to t5

This commit is contained in:
Awni Hannun 2023-12-28 14:55:25 -08:00
parent 42378e5861
commit 34a62ddc49
7 changed files with 529 additions and 236 deletions

View File

@ -1,11 +1,11 @@
# Speculative Decoding
This example implements [speculative decoding] for text generation.[^1]
Speculative decoding uses a smaller draft model to propose several tokens, and
then a larger model which decides which tokens to accept. The generated text is
identical to what the larger model would produce on its own, but with far fewer
forward passes of the large model since it can evaluate the draft tokens in
parallel.
This example implements speculative decoding with the T5 model for text
generation.[^1] Speculative decoding uses a smaller draft model to propose
several tokens, and a larger model to decide which tokens to accept. The
distribution of the generated text is identical to what the larger model would
produce on its own, but with far fewer forward passes of the large model since
it can evaluate the draft tokens in parallel.
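As a rough sketch of the idea (hypothetical `propose`, `score`, `num_accepted`, and `sample` helpers, not this example's API), one decoding step looks like:
```
# Sketch of one speculative decoding step (hypothetical helpers for illustration).
def speculative_step(model, draft_model, tokens, num_draft=5):
    # The small draft model proposes num_draft tokens autoregressively.
    draft = draft_model.propose(tokens, num_draft)
    # The large model scores the prompt plus all draft tokens in one forward pass.
    logits = model.score(tokens + draft)
    # Keep the longest accepted prefix of the draft, then take one token from the
    # large model itself, so every step produces at least one new token.
    n = num_accepted(draft, logits)
    return tokens + draft[:n] + [sample(logits[n])]
```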
### Setup
@ -16,6 +16,19 @@ cd speculative_decoding
pip install -r requirements.txt
```
Then convert the model and the draft model. For example, you can convert the
T5 11B model with:
```
python convert.py --model t5-11b
```
And for the draft model, convert the T5 small model with:
```
python convert.py --model t5-small
```
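`convert.py` also accepts a `--dtype` flag (`float16` or `float32`, default `float32`), so you can, for example, save the larger model in half precision:
```
python convert.py --model t5-11b --dtype float16
```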
### Run
You can run with the default arguments:
@ -24,9 +37,27 @@ You can run with the default arguments:
python main.py
```
To see a full list of options use:
```
python main.py --help
```
### Notes
Speculative decoding works well when most of the tokens from the draft model
are accepted by the larger model. That's more likely to happen if the models
are trained on similar data. The default setting in this example uses TinyLlama
as a draft model for Llama 7B.
are trained on similar data.
[^1] See the paper [Fast Inference from Transformers via Speculative Decoding](https://arxiv.org/abs/2211.17192)
One way to increase the chance of accepting a draft token is with the parameter
`--delta`. This parameter can be in the range `[0, 1]`. If it is `1` then all
the draft tokens will be accepted by the model. If it is `0`, then only draft
tokens which match the original acceptance criterion are kept.[^1] Values closer to
`1` increase the chance that a draft token is accepted.
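A simplified sketch of that lenient test (an assumed form for illustration; the actual rule in `decoder.py` works with log-probabilities):
```
# Simplified sketch of the lenient acceptance test (assumed form, not the exact code).
import numpy as np

def accept(p_model, p_draft, delta, rng=np.random.default_rng(0)):
    # Standard speculative decoding accepts a draft token with probability
    # min(1, p_model / p_draft). Adding delta relaxes the test: delta = 0
    # recovers the original criterion and delta = 1 accepts every draft token.
    return rng.uniform() <= p_model / p_draft + delta
```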
Conversely, the fewer draft tokens the model accepts, the more expensive
speculative decoding is. You can use `--num-draft` to tune the number of draft
tokens per model evaluation in order to reduce the number of discarded draft
tokens.
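For example, to try a larger draft window with a slightly more lenient threshold:
```
python main.py --num-draft 8 --delta 0.2
```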
[^1]: See the paper [Fast Inference from Transformers via Speculative
Decoding](https://arxiv.org/abs/2211.17192)

View File

@ -0,0 +1,75 @@
import numpy as np
from transformers import T5ForConditionalGeneration
SHARED_REPLACEMENT_PATTERNS = [
(".block.", ".layers."),
(".k.", ".key_proj."),
(".o.", ".out_proj."),
(".q.", ".query_proj."),
(".v.", ".value_proj."),
("shared.", "wte."),
("lm_head.", "lm_head.linear."),
(".layer.0.layer_norm.", ".ln1."),
(".layer.1.layer_norm.", ".ln2."),
(".layer.2.layer_norm.", ".ln3."),
(".final_layer_norm.", ".ln."),
(
"layers.0.layer.0.SelfAttention.relative_attention_bias.",
"relative_attention_bias.embeddings.",
),
]
ENCODER_REPLACEMENT_PATTERNS = [
(".layer.0.SelfAttention.", ".attention."),
(".layer.1.DenseReluDense.", ".dense."),
]
DECODER_REPLACEMENT_PATTERNS = [
(".layer.0.SelfAttention.", ".self_attention."),
(".layer.1.EncDecAttention.", ".cross_attention."),
(".layer.2.DenseReluDense.", ".dense."),
]
def replace_key(key: str) -> str:
for old, new in SHARED_REPLACEMENT_PATTERNS:
key = key.replace(old, new)
if key.startswith("encoder."):
for old, new in ENCODER_REPLACEMENT_PATTERNS:
key = key.replace(old, new)
elif key.startswith("decoder."):
for old, new in DECODER_REPLACEMENT_PATTERNS:
key = key.replace(old, new)
return key
def convert(model_name, dtype):
dtype = getattr(np, dtype)
model = T5ForConditionalGeneration.from_pretrained(model_name, torch_dtype="auto")
weights = {
replace_key(k): v.numpy().astype(dtype) for k, v in model.state_dict().items()
}
file_name = model_name.replace("/", "-")
print(f"Saving weights to {file_name}.npz")
np.savez(f"{file_name}.npz", **weights)
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="Convert T5 weights to MLX")
parser.add_argument(
"--model",
type=str,
help="Name of the T5 model.",
default="t5-small",
)
parser.add_argument(
"--dtype",
help="The model data type.",
type=str,
choices=["float16", "float32"],
default="float32",
)
args = parser.parse_args()
convert(args.model, args.dtype)

View File

@ -1,4 +1,3 @@
import time
from dataclasses import dataclass, field
from typing import List, Optional
@ -6,18 +5,26 @@ import mlx.core as mx
import mlx.nn as nn
import numpy as np
import transformers
from model import Llama
from prompts import create_urial_prompt
from model import Model
class Tokenizer:
def __init__(self, model_name: str):
self._tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
self._tokenizer = transformers.AutoTokenizer.from_pretrained(
model_name,
legacy=False,
model_max_length=512,
)
self._decoder_start_id = 0
@property
def eos_id(self) -> int:
return self._tokenizer.eos_token_id
@property
def decoder_start_id(self) -> int:
return self._decoder_start_id
def encode(self, s: str) -> mx.array:
return mx.array(
self._tokenizer(s, return_tensors="np", return_attention_mask=False,)[
@ -25,44 +32,34 @@ class Tokenizer:
].squeeze(0)
)
def decode(self, t: List[int], with_sep: bool = True) -> str:
tokens = self._tokenizer.convert_ids_to_tokens(t)
return "".join(t.replace("", " " if with_sep else "") for t in tokens)
def decode(self, t: List[int]) -> str:
return self._tokenizer.decode(t)
class SpeculativeDecoder:
def __init__(
self,
model: str,
draft_model: str = None,
model: Model,
draft_model: Model,
tokenizer: str,
num_draft: int = 5,
delta: float = 0.0,
):
self.tokenizer = Tokenizer(model)
self.model = Llama.from_hugging_face(model)
if draft_model is not None:
self.draft_model = Llama.from_hugging_face(draft_model)
self.tokenizer = Tokenizer(tokenizer)
self.model = model
self.draft_model = draft_model
self.num_draft = num_draft
self.delta = delta
def tokenize(self, prompt):
# if self.tokenizer.chat_template is not None:
# tokenized = self.tokenizer.apply_chat_template(
# prompt, tokenize=True, add_generation_prompt=True
# )
# else:
# use urial zero-shot template
tokenized = self.tokenizer.encode(create_urial_prompt(prompt["content"]))
return tokenized
def _generate(
self,
x: mx.array,
memory: mx.array,
draft: bool = False,
):
model = self.draft_model if draft else self.model
while True:
logits = model(x[None, :], next_tokens=1).squeeze((0, 1))
logits = model.decode(x[None], memory)[0, -1]
x = mx.argmax(logits, keepdims=True)
lognorm = mx.logsumexp(logits.astype(mx.float32))
logprob = logits[x] - lognorm
@ -72,25 +69,26 @@ class SpeculativeDecoder:
self,
prompt,
max_tokens: int = 100,
draft: bool = False,
):
x = self.tokenize(prompt)
start = time.time()
for (token, _), n in zip(self._generate(x, draft=draft), range(max_tokens)):
token = token.item()
memory = self.model.encode(self.tokenizer.encode(prompt)[None])
x = mx.array([self.tokenizer.decoder_start_id])
skip = 0
outputs = []
for (token, _), n in zip(self._generate(x, memory), range(max_tokens)):
if token == self.tokenizer.eos_id:
break
print(self.tokenizer.decode(token, with_sep=n > 0), end="", flush=True)
run_time = time.time() - start
outputs.append(token.item())
if (n + 1) % 10 == 0:
str_output = self.tokenizer.decode(outputs)
print(str_output[skip:], end="", flush=True)
skip = len(str_output)
print(self.tokenizer.decode(outputs)[skip:], end="", flush=True)
print()
print(f"=== GENERATED {n + 1} TOKENS in {run_time} SECONDS ===")
if draft:
self.draft_model.reset_cache()
else:
self.model.reset_cache()
self.model.reset_cache()
def _get_num_accept(self, draft_tokens, draft_probs, model_logits):
# equal_toks = sampled[:-1] == draft_tokens
# accept_toks = mx.argmax(model_logits, axis=-1) == draft_tokens
model_probs = mx.take_along_axis(
model_logits,
draft_tokens[:, None],
@ -111,14 +109,19 @@ class SpeculativeDecoder:
def sample(logits):
return mx.argmax(logits, axis=-1)
tokens = mx.array(self.tokenize(prompt), mx.uint32)
start = time.time()
prompt = mx.array(self.tokenizer.encode(prompt), mx.uint32)[None]
memory = self.model.encode(prompt)
draft_memory = self.draft_model.encode(prompt)
decoding_steps = 0
tokens = mx.array([self.tokenizer.decoder_start_id])
n_steps = 0
ntoks = 0
accepted_draft_tokens = 0
total_draft_tokens = 0
n_accepted = 0
n_draft = 0
outputs = []
skip = 0
draft_inputs = tokens
inputs = tokens
while True:
@ -127,7 +130,7 @@ class SpeculativeDecoder:
draft_probs = []
for _, (t, p) in zip(
range(ntoks, min(ntoks + self.num_draft, max_tokens)),
self._generate(draft_inputs, draft=True),
self._generate(draft_inputs, draft_memory, draft=True),
):
draft_tokens.append(t)
draft_probs.append(p)
@ -138,10 +141,10 @@ class SpeculativeDecoder:
draft_tokens = mx.concatenate(draft_tokens)
draft_probs = mx.concatenate(draft_probs)
verify_tokens = mx.concatenate([inputs, draft_tokens])
logits = self.model(
verify_tokens[None, :], next_tokens=draft_tokens.size + 1
logits = self.model.decode(
verify_tokens[None, :],
memory,
).squeeze(0)
# sampled = sample(logits).squeeze(0)
# Only keep samples that match the draft:
num_to_accept = self._get_num_accept(
@ -155,38 +158,34 @@ class SpeculativeDecoder:
[new_tokens, mx.argmax(logits[num_to_accept], keepdims=True)]
)
accepted_draft_tokens += num_to_accept
total_draft_tokens += draft_tokens.size
n_accepted += num_to_accept
n_draft += draft_tokens.size
# Rewind the cache for unaccepted tokens:
if (n := draft_tokens.size) > num_to_accept:
self.draft_model.truncate_cache(n - new_tokens.size)
self.model.truncate_cache(n - new_tokens.size + 1)
decoding_steps += 1
n_steps += 1
for t in new_tokens.tolist():
if t == self.tokenizer.eos_id or ntoks >= max_tokens:
break
print(self.tokenizer.decode(t, with_sep=ntoks > 0), end="", flush=True)
outputs.append(t)
ntoks += 1
str_output = self.tokenizer.decode(outputs)
print(str_output[skip:], end="", flush=True)
skip = len(str_output)
if ntoks >= max_tokens or new_tokens[-1] == self.tokenizer.eos_id:
break
draft_inputs = new_tokens[max(new_tokens.size - 2, 0) :]
inputs = draft_inputs[-1:]
end = time.time()
print(self.tokenizer.decode(outputs)[skip:], end="", flush=True)
print()
self.model.reset_cache()
self.draft_model.reset_cache()
print()
print(
"=== GENERATED",
ntoks,
"TOKENS IN",
round(end - start, 2),
"SECONDS ===",
)
print(
f"=== ACCEPTED {accepted_draft_tokens} of {total_draft_tokens} DRAFT TOKENS ==="
)
print("=== DECODING STEPS", decoding_steps, "===")
return {"n_accepted": n_accepted, "n_draft": n_draft, "n_steps": n_steps}

View File

@ -1,31 +1,51 @@
import argparse
import glob
import json
import time
from pathlib import Path
import mlx.core as mx
import mlx.nn as nn
from decoder import SpeculativeDecoder
from mlx.utils import tree_unflatten
from model import Model
from transformers import T5Config
def load_model(model_name: str):
config = T5Config.from_pretrained(model_name)
model = Model(config)
weights = mx.load(f"{model_name}.npz")
weights = tree_unflatten(list(weights.items()))
model.update(weights)
mx.eval(model.parameters())
return model
def main(args):
mx.random.seed(args.seed)
spec_decoder = SpeculativeDecoder(
# model="TinyLlama/TinyLlama-1.1B-intermediate-step-1195k-token-2.5T",
model="meta-llama/Llama-2-7b-hf",
draft_model="TinyLlama/TinyLlama-1.1B-intermediate-step-1195k-token-2.5T",
model=load_model(args.model_name),
draft_model=load_model(args.draft_model_name),
tokenizer=args.model_name,
delta=args.delta,
num_draft=args.num_draft,
)
prompt = {"role": "user", "content": "Finish the monologue: To be, or not to be..."}
tic = time.time()
print(args.prompt)
if args.regular_decode:
spec_decoder.generate(args.prompt, max_tokens=args.max_tokens)
else:
stats = spec_decoder.speculative_decode(args.prompt, max_tokens=args.max_tokens)
print("=" * 10)
print(f"Accepted {stats['n_accepted']} / {stats['n_draft']}.")
print(f"Decoding steps {stats['n_steps']}.")
# Do 1 regular generation to get warmed up (the first one is slow)
# engine.generate(messages, max_tokens=1)
# engine.generate(messages, max_tokens=1, draft=True)
# Time regular generation
spec_decoder.generate(prompt, max_tokens=125)
# Time speculative decoding
spec_decoder.speculative_decode(prompt, max_tokens=125)
toc = time.time()
print("=" * 10)
print(f"Full generation time {toc - tic:.3f}")
if __name__ == "__main__":
@ -36,17 +56,44 @@ if __name__ == "__main__":
default=5,
help="Number of draft tokens to use per decoding step.",
)
parser.add_argument(
"--model-name",
help="Name of the model.",
default="t5-small",
)
parser.add_argument(
"--draft-model-name",
help="Name of the draft model.",
default="t5-small",
)
parser.add_argument(
"--seed",
type=int,
default=0,
help="PRNG seed.",
)
parser.add_argument(
"--max-tokens",
"-m",
type=int,
default=100,
help="Maximum number of tokens to generate.",
)
parser.add_argument(
"--prompt",
default="translate English to French: Let's go to the store and buy some groceries including eggs, avocadoes, and bread.",
help="The prompt processed by the model.",
)
parser.add_argument(
"--delta",
type=float,
default=0.1,
help="Lenience for accepting the proposal tokens.",
)
parser.add_argument(
"--regular-decode",
action="store_true",
help="Use regular decoding instead of speculative decoding.",
)
args = parser.parse_args()
main(args)

View File

@ -1,17 +1,135 @@
from typing import Optional, Tuple
from typing import List, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from mlx.utils import tree_map, tree_unflatten
from transformers import AutoModelForCausalLM, LlamaConfig
from transformers import AutoTokenizer, T5Config
def create_additive_causal_mask(N: int, offset: int = 0, dtype: mx.Dtype = mx.float32):
rinds = mx.arange(offset + N)
linds = mx.arange(offset, offset + N) if offset else rinds
mask = linds[:, None] < rinds[None]
mask = mask.astype(dtype) * -1e9
return mask
def _relative_position_bucket(
relative_position, bidirectional=True, num_buckets=32, max_distance=128
):
"""
Adapted from HF Tensorflow:
https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
Translate relative position to a bucket number for relative attention. The relative position is defined as
memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
This should allow for more graceful generalization to longer sequences than the model has been trained on
Args:
relative_position: an int32 Tensor
bidirectional: a boolean - whether the attention is bidirectional
num_buckets: an integer
max_distance: an integer
Returns:
a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
"""
relative_buckets = 0
if bidirectional:
num_buckets //= 2
relative_buckets += (relative_position > 0).astype(mx.int16) * num_buckets
relative_position = mx.abs(relative_position)
else:
relative_position = -mx.minimum(
relative_position, mx.zeros_like(relative_position)
)
# now relative_position is in the range [0, inf)
# half of the buckets are for exact increments in positions
max_exact = num_buckets // 2
is_small = relative_position < max_exact
# The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
scale = (num_buckets - max_exact) / np.log(max_distance / max_exact)
relative_position_if_large = max_exact + (
mx.log(relative_position.astype(mx.float32) / max_exact) * scale
).astype(mx.int16)
relative_position_if_large = mx.minimum(relative_position_if_large, num_buckets - 1)
relative_buckets += mx.where(
is_small, relative_position, relative_position_if_large
)
return relative_buckets
class RelativePositionBias(nn.Module):
def __init__(self, config: T5Config, bidirectional: bool):
self.bidirectional = bidirectional
self.num_buckets = config.relative_attention_num_buckets
self.max_distance = config.relative_attention_max_distance
self.n_heads = config.num_heads
self.embeddings = nn.Embedding(
config.relative_attention_num_buckets, config.num_heads
)
def __call__(self, query_length: int, key_length: int, offset: int = 0):
"""Compute binned relative position bias"""
context_position = mx.arange(offset, query_length)[:, None]
memory_position = mx.arange(key_length)[None, :]
# shape (query_length, key_length)
relative_position = memory_position - context_position
relative_position_bucket = _relative_position_bucket(
relative_position,
bidirectional=self.bidirectional,
num_buckets=self.num_buckets,
max_distance=self.max_distance,
)
# shape (query_length, key_length, num_heads)
values = self.embeddings(relative_position_bucket)
# shape (num_heads, query_length, key_length)
return values.transpose(2, 0, 1)
class MultiHeadAttention(nn.Module):
def __init__(self, config: T5Config):
super().__init__()
inner_dim = config.d_kv * config.num_heads
self.num_heads = config.num_heads
self.query_proj = nn.Linear(config.d_model, inner_dim, bias=False)
self.key_proj = nn.Linear(config.d_model, inner_dim, bias=False)
self.value_proj = nn.Linear(config.d_model, inner_dim, bias=False)
self.out_proj = nn.Linear(inner_dim, config.d_model, bias=False)
def __call__(
self,
queries: mx.array,
keys: mx.array,
values: mx.array,
mask: Optional[mx.array],
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
queries = self.query_proj(queries)
keys = self.key_proj(keys)
values = self.value_proj(values)
num_heads = self.num_heads
B, L, _ = queries.shape
_, S, _ = keys.shape
queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
if cache is not None:
key_cache, value_cache = cache
keys = mx.concatenate([key_cache, keys], axis=2)
values = mx.concatenate([value_cache, values], axis=2)
# Dimensions are [batch x num heads x sequence x hidden dim]
scores = queries @ keys.transpose(0, 1, 3, 2)
if mask is not None:
scores = scores + mask.astype(scores.dtype)
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.out_proj(values_hat), (keys, values)
class RMSNorm(nn.Module):
@ -24,177 +142,200 @@ class RMSNorm(nn.Module):
return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
def __call__(self, x):
output = self._norm(x.astype(mx.float32)).astype(x.dtype)
t = x.dtype
output = self._norm(x).astype(t)
return self.weight * output
class Attention(nn.Module):
def __init__(self, config: LlamaConfig):
class DenseActivation(nn.Module):
def __init__(self, config: T5Config):
super().__init__()
self.config = config
self.n_heads: int = config.num_attention_heads
self.n_kv_heads: int = config.num_key_value_heads
self.repeats = self.n_heads // self.n_kv_heads
self.head_dim = config.hidden_size // self.n_heads
self.scale = self.head_dim**-0.5
self.q_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
self.k_proj = nn.Linear(
config.hidden_size, config.hidden_size // self.repeats, bias=False
)
self.v_proj = nn.Linear(
config.hidden_size, config.hidden_size // self.repeats, bias=False
)
self.o_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
self.rope = nn.RoPE(self.head_dim, traditional=False)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> mx.array:
B, L, D = x.shape
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(
0, 2, 1, 3
) # B, n_kv_heads, L, head_dim
def repeat(a):
a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
kv_size = a.shape[-1]
return a.reshape([B, self.n_heads, -1, kv_size])
if cache is not None:
key_cache, value_cache = cache
queries = self.rope(queries, offset=key_cache.shape[2])
keys = self.rope(keys, offset=key_cache.shape[2])
keys = mx.concatenate([key_cache, keys], axis=2)
values = mx.concatenate([value_cache, values], axis=2)
mlp_dims = config.d_ff or config.d_model * 4
self.gated = config.feed_forward_proj.startswith("gated")
if self.gated:
self.wi_0 = nn.Linear(config.d_model, mlp_dims, bias=False)
self.wi_1 = nn.Linear(config.d_model, mlp_dims, bias=False)
else:
queries = self.rope(queries)
keys = self.rope(keys)
self.wi = nn.Linear(config.d_model, mlp_dims, bias=False)
self.wo = nn.Linear(mlp_dims, config.d_model, bias=False)
activation = config.feed_forward_proj.removeprefix("gated-")
if activation == "relu":
self.act = nn.relu
elif activation == "gelu":
self.act = nn.gelu
elif activation == "silu":
self.act = nn.silu
else:
raise ValueError(f"Unknown activation: {activation}")
scores = (queries * self.scale) @ repeat(keys).transpose(0, 1, 3, 2)
if mask is not None:
scores += mask
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
output = (scores @ repeat(values)).transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output), (keys, values)
def __call__(self, x):
if self.gated:
hidden_act = self.act(self.wi_0(x))
hidden_linear = self.wi_1(x)
x = hidden_act * hidden_linear
else:
x = self.act(self.wi(x))
return self.wo(x)
class FeedForward(nn.Module):
def __init__(self, config: LlamaConfig):
class TransformerEncoderLayer(nn.Module):
def __init__(self, config: T5Config):
super().__init__()
self.gate_proj = nn.Linear(
config.hidden_size, config.intermediate_size, bias=False
)
self.down_proj = nn.Linear(
config.intermediate_size, config.hidden_size, bias=False
)
self.up_proj = nn.Linear(
config.hidden_size, config.intermediate_size, bias=False
)
self.attention = MultiHeadAttention(config)
self.ln1 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
self.ln2 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
self.dense = DenseActivation(config)
def __call__(self, x) -> mx.array:
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
def __call__(self, x, mask):
y = self.ln1(x)
y, _ = self.attention(y, y, y, mask=mask)
x = x + y
y = self.ln2(x)
y = self.dense(y)
return x + y
class TransformerBlock(nn.Module):
def __init__(self, config: LlamaConfig):
class TransformerEncoder(nn.Module):
def __init__(self, config: T5Config):
super().__init__()
self.n_heads = config.num_attention_heads
self.dim = config.hidden_size
self.self_attn = Attention(config=config)
self.mlp = FeedForward(config=config)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
self.layers = [
TransformerEncoderLayer(config) for i in range(config.num_layers)
]
self.ln = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
self.relative_attention_bias = RelativePositionBias(config, bidirectional=True)
def __call__(self, x: mx.array):
pos_bias = self.relative_attention_bias(x.shape[1], x.shape[1])
for layer in self.layers:
x = layer(x, mask=pos_bias)
return self.ln(x)
class TransformerDecoderLayer(nn.Module):
def __init__(self, config: T5Config):
super().__init__()
self.self_attention = MultiHeadAttention(config)
self.cross_attention = MultiHeadAttention(config)
self.ln1 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
self.ln2 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
self.ln3 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
self.dense = DenseActivation(config)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> mx.array:
r, cache = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
r = self.mlp(self.post_attention_layernorm(h))
out = h + r
return out, cache
memory: mx.array,
mask: mx.array,
memory_mask: mx.array,
cache: Optional[List[Tuple[mx.array, mx.array]]] = None,
):
y = self.ln1(x)
y, cache = self.self_attention(y, y, y, mask, cache)
x = x + y
y = self.ln2(x)
y, _ = self.cross_attention(y, memory, memory, memory_mask)
x = x + y
y = self.ln3(x)
y = self.dense(y)
x = x + y
return x, cache
class Llama(nn.Module):
def __init__(self, config: LlamaConfig):
def create_additive_causal_mask(N: int, offset: int = 0):
rinds = mx.arange(offset + N)
linds = mx.arange(offset, offset + N) if offset else rinds
mask = linds[:, None] < rinds[None]
return mask * -1e9
class TransformerDecoder(nn.Module):
def __init__(self, config: T5Config):
super().__init__()
self.config = config
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
self.layers = [
TransformerBlock(config=config) for _ in range(config.num_hidden_layers)
]
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
n_layers = getattr(config, "num_decoder_layers", config.num_layers)
self.layers = [TransformerDecoderLayer(config) for i in range(n_layers)]
self.ln = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
self.relative_attention_bias = RelativePositionBias(config, bidirectional=False)
def __call__(self, x, memory, cache=None):
if cache[0] is not None:
offset = cache[0][0].shape[2]
else:
offset = 0
T = x.shape[1]
if T > 1:
mask = create_additive_causal_mask(T, offset)
else:
mask = None
pos_bias = self.relative_attention_bias(T + offset, T + offset, offset=offset)
if mask is not None:
mask += pos_bias
else:
mask = pos_bias
for e, layer in enumerate(self.layers):
x, cache[e] = layer(x, memory, mask, None, cache=cache[e])
x = self.ln(x)
return x, cache
class OutputHead(nn.Module):
def __init__(self, config: T5Config):
self.linear = nn.Linear(config.d_model, config.vocab_size, bias=False)
def __call__(self, inputs):
return self.linear(inputs)
class Model(nn.Module):
def __init__(self, config: T5Config):
self.wte = nn.Embedding(config.vocab_size, config.d_model)
self.encoder = TransformerEncoder(config)
self.decoder = TransformerDecoder(config)
self.tie_word_embeddings = config.tie_word_embeddings
if not self.tie_word_embeddings:
self.lm_head = OutputHead(config)
self.model_dim = config.d_model
self.reset_cache()
def encode(self, inputs: mx.array):
return self.encoder(self.wte(inputs))
def truncate_cache(self, num_to_truncate):
if num_to_truncate <= 0:
return
cache_length = self.kv_cache[0][0].shape[2]
cache_length = self.cache[0][0].shape[2]
if num_to_truncate < cache_length:
self.kv_cache = tree_map(
lambda x: x[:, :, :-num_to_truncate, :], self.kv_cache
)
self.cache = tree_map(lambda x: x[:, :, :-num_to_truncate, :], self.cache)
else:
self.reset_cache()
def reset_cache(self):
self.kv_cache = [None] * len(self.layers)
self.cache = [None] * len(self.decoder.layers)
def decode(
self,
inputs: mx.array,
memory: mx.array,
):
inputs = self.wte(inputs)
y, self.cache = self.decoder(inputs, memory=memory, cache=self.cache)
if not self.tie_word_embeddings:
y *= self.model_dim**-0.5
y = self.lm_head(y)
else:
y = y @ self.wte.weight.T
return y
def __call__(
self,
x: mx.array,
next_tokens: int = -1,
inputs: mx.array,
decoder_inputs: mx.array,
):
if self.kv_cache[0]:
offset = self.kv_cache[0][0].shape[-2]
else:
offset = 0
if x.shape[1] > 1:
mask = create_additive_causal_mask(x.shape[1], offset)
mask = mask.astype(self.embed_tokens.weight.dtype)
else:
mask = None
x = self.embed_tokens(x)
for idx, layer in enumerate(self.layers):
x, self.kv_cache[idx] = layer(x, mask, cache=self.kv_cache[idx])
if next_tokens > 0:
x = x[:, -next_tokens:]
x = self.norm(x)
return self.lm_head(x)
@classmethod
def from_hugging_face(cls, model_path: str, quantized: bool = True):
config = LlamaConfig.from_pretrained(model_path)
torch_weights = AutoModelForCausalLM.from_pretrained(model_path).state_dict()
weights = {
k.replace("model.", ""): mx.array(v.numpy(), mx.float16)
for k, v in torch_weights.items()
}
model = cls(config)
model.update(tree_unflatten(list(weights.items())))
# if quantization is not None:
# nn.QuantizedLinear.quantize_module(model, **quantization)
mx.eval(model.parameters())
return model
return self.decode(decoder_inputs, self.encode(inputs))[0]

View File

@ -1,3 +1,4 @@
mlx>=0.0.5
mlx>=0.0.6
transformers
numpy
numpy
accelerate

View File

@ -125,7 +125,6 @@ class MultiHeadAttention(nn.Module):
values = mx.concatenate([value_cache, values], axis=2)
# Dimensions are [batch x num heads x sequence x hidden dim]
queries = queries
scores = queries @ keys
if mask is not None:
scores = scores + mask.astype(scores.dtype)