Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-09-01)
switch to t5
README.md
@@ -1,11 +1,11 @@
 # Speculative Decoding
 
-This example implements [speculative decoding] for text generation.[^1].
-Speculative decoding uses a smaller draft model to propose several tokens, and
-then a larger model which decides which tokens to accept. The generated text is
-identical to what the larger model would produce on its own, but with far fewer
-forward passes of the large model since it can evaluate the draft tokens in
-parallel.
+This example implements speculative decoding with the T5 model for text
+generation.[^1] Speculative decoding uses a smaller draft model to propose
+several tokens, and a larger model to decide which tokens to accept. The
+distribution of the generated text is identical to what the larger model would
+produce on its own, but with far fewer forward passes of the large model since
+it can evaluate the draft tokens in parallel.
 
 ### Setup
 
@@ -16,6 +16,19 @@ cd speculative_decoding
 pip install -r requirements.txt
 ```
 
+Then convert the model and the draft model. For example, you can convert the
+T5 11B model with:
+
+```
+python convert.py --model t5-11b
+```
+
+And for the draft model, convert the T5 small model with:
+
+```
+python convert.py --model t5-small
+```
+
 ### Run
 
 You can run with the default arguments:
@@ -24,9 +37,27 @@ You can run with the default arguments:
 python main.py
 ```
 
+To see a full list of options use:
+```
+python main.py --help
+```
+
+### Notes
+
 Speculative decoding works well when most of the tokens from the draft model
 are accepted by the larger model. That's more likely to happen if the models
-are trained on similar data. The default setting in this example uses TinyLlama
-as a draft model for Llama 7B.
-
-[^1]: See the paper [Fast Inference from Transformers via Speculative Decoding](https://arxiv.org/abs/2211.17192)
+are trained on similar data.
+
+One way to increase the chance of accepting a draft token is with the parameter
+`--delta`. This parameter can be in the range `[0, 1]`. If it is `1` then all
+the draft tokens will be accepted by the model. If it is `0`, then only draft
+tokens which match the original acceptance criterion are kept.[^1] Values closer to
+`1` increase the chance that a draft token is accepted.
+
+Conversely, the fewer draft tokens accepted by the model, the more expensive
+speculative decoding is. You can use `--draft` to tune the number of draft
+tokens per model evaluation in order to reduce the number of discarded draft
+tokens.
+
+[^1]: See the paper [Fast Inference from Transformers via Speculative
+Decoding](https://arxiv.org/abs/2211.17192)
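The README text above describes the propose-and-verify scheme in prose. The following is a minimal, self-contained sketch of that loop in plain Python, illustrative only and not the MLX implementation from this commit; the names `speculative_step`, `draft_next`, and `main_next` are invented for the example, and a real implementation verifies all draft positions in a single batched forward pass of the large model.

```python
from typing import Callable, List


def speculative_step(
    prefix: List[int],
    draft_next: Callable[[List[int]], int],
    main_next: Callable[[List[int]], int],
    num_draft: int = 5,
) -> List[int]:
    # 1. The small draft model proposes `num_draft` tokens autoregressively.
    draft: List[int] = []
    for _ in range(num_draft):
        draft.append(draft_next(prefix + draft))

    # 2. The large model scores every draft position (here per position for
    #    clarity; in practice this is one batched forward pass).
    accepted: List[int] = []
    for i, tok in enumerate(draft):
        target = main_next(prefix + draft[:i])
        if target != tok:
            # First mismatch: keep the large model's token and stop.
            accepted.append(target)
            break
        accepted.append(tok)
    else:
        # All draft tokens matched; the large model also yields one bonus token.
        accepted.append(main_next(prefix + draft))
    return accepted


if __name__ == "__main__":
    # Toy "models": the draft repeats the last token, the main model counts upward.
    out = speculative_step([1, 2, 3], lambda s: s[-1], lambda s: s[-1] + 1, num_draft=3)
    print(out)  # [4]: the toy draft diverges at once, so only the correction is kept
```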
llms/speculative_decoding/convert.py (new file, 75 lines)
@@ -0,0 +1,75 @@
+import numpy as np
+from transformers import T5ForConditionalGeneration
+
+SHARED_REPLACEMENT_PATTERNS = [
+    (".block.", ".layers."),
+    (".k.", ".key_proj."),
+    (".o.", ".out_proj."),
+    (".q.", ".query_proj."),
+    (".v.", ".value_proj."),
+    ("shared.", "wte."),
+    ("lm_head.", "lm_head.linear."),
+    (".layer.0.layer_norm.", ".ln1."),
+    (".layer.1.layer_norm.", ".ln2."),
+    (".layer.2.layer_norm.", ".ln3."),
+    (".final_layer_norm.", ".ln."),
+    (
+        "layers.0.layer.0.SelfAttention.relative_attention_bias.",
+        "relative_attention_bias.embeddings.",
+    ),
+]
+
+ENCODER_REPLACEMENT_PATTERNS = [
+    (".layer.0.SelfAttention.", ".attention."),
+    (".layer.1.DenseReluDense.", ".dense."),
+]
+
+DECODER_REPLACEMENT_PATTERNS = [
+    (".layer.0.SelfAttention.", ".self_attention."),
+    (".layer.1.EncDecAttention.", ".cross_attention."),
+    (".layer.2.DenseReluDense.", ".dense."),
+]
+
+
+def replace_key(key: str) -> str:
+    for old, new in SHARED_REPLACEMENT_PATTERNS:
+        key = key.replace(old, new)
+    if key.startswith("encoder."):
+        for old, new in ENCODER_REPLACEMENT_PATTERNS:
+            key = key.replace(old, new)
+    elif key.startswith("decoder."):
+        for old, new in DECODER_REPLACEMENT_PATTERNS:
+            key = key.replace(old, new)
+    return key
+
+
+def convert(model_name, dtype):
+    dtype = getattr(np, dtype)
+    model = T5ForConditionalGeneration.from_pretrained(model_name, torch_dtype="auto")
+    weights = {
+        replace_key(k): v.numpy().astype(dtype) for k, v in model.state_dict().items()
+    }
+    file_name = model_name.replace("/", "-")
+    print(f"Saving weights to {file_name}.npz")
+    np.savez(f"{file_name}.npz", **weights)
+
+
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser(description="Convert T5 weights to MLX")
+    parser.add_argument(
+        "--model",
+        type=str,
+        help="Name of the T5 model.",
+        default="t5-small",
+    )
+    parser.add_argument(
+        "--dtype",
+        help="The model data type.",
+        type=str,
+        choices=["float16", "float32"],
+        default="float32",
+    )
+    args = parser.parse_args()
+    convert(args.model, args.dtype)
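The pattern tables above drive a purely textual renaming of Hugging Face T5 parameter names into the layout expected by the MLX model. Assuming `convert.py` is importable, a quick check of `replace_key` shows the kind of mapping it performs:

```python
# Illustrative check of the remapping performed by convert.py.
from convert import replace_key

print(replace_key("decoder.block.0.layer.0.SelfAttention.q.weight"))
# -> decoder.layers.0.self_attention.query_proj.weight
```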
decoder.py
@@ -1,4 +1,3 @@
-import time
 from dataclasses import dataclass, field
 from typing import List, Optional
 
@@ -6,18 +5,26 @@ import mlx.core as mx
 import mlx.nn as nn
 import numpy as np
 import transformers
-from model import Llama
-from prompts import create_urial_prompt
+from model import Model
 
 
 class Tokenizer:
     def __init__(self, model_name: str):
-        self._tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
+        self._tokenizer = transformers.AutoTokenizer.from_pretrained(
+            model_name,
+            legacy=False,
+            model_max_length=512,
+        )
+        self._decoder_start_id = 0
 
     @property
     def eos_id(self) -> int:
         return self._tokenizer.eos_token_id
 
+    @property
+    def decoder_start_id(self) -> int:
+        return self._decoder_start_id
+
     def encode(self, s: str) -> mx.array:
         return mx.array(
             self._tokenizer(s, return_tensors="np", return_attention_mask=False,)[
@@ -25,44 +32,34 @@ class Tokenizer:
             ].squeeze(0)
         )
 
-    def decode(self, t: List[int], with_sep: bool = True) -> str:
-        tokens = self._tokenizer.convert_ids_to_tokens(t)
-        return "".join(t.replace("▁", " " if with_sep else "") for t in tokens)
+    def decode(self, t: List[int]) -> str:
+        return self._tokenizer.decode(t)
 
 
 class SpeculativeDecoder:
     def __init__(
         self,
-        model: str,
-        draft_model: str = None,
+        model: Model,
+        draft_model: Model,
+        tokenizer: str,
         num_draft: int = 5,
         delta: float = 0.0,
     ):
-        self.tokenizer = Tokenizer(model)
-        self.model = Llama.from_hugging_face(model)
-        if draft_model is not None:
-            self.draft_model = Llama.from_hugging_face(draft_model)
+        self.tokenizer = Tokenizer(tokenizer)
+        self.model = model
+        self.draft_model = draft_model
         self.num_draft = num_draft
         self.delta = delta
 
-    def tokenize(self, prompt):
-        # if self.tokenizer.chat_template is not None:
-        #     tokenized = self.tokenizer.apply_chat_template(
-        #         prompt, tokenize=True, add_generation_prompt=True
-        #     )
-        # else:
-        # use urial zero-shot template
-        tokenized = self.tokenizer.encode(create_urial_prompt(prompt["content"]))
-        return tokenized
-
     def _generate(
         self,
         x: mx.array,
+        memory: mx.array,
         draft: bool = False,
     ):
         model = self.draft_model if draft else self.model
         while True:
-            logits = model(x[None, :], next_tokens=1).squeeze((0, 1))
+            logits = model.decode(x[None], memory)[0, -1]
             x = mx.argmax(logits, keepdims=True)
             lognorm = mx.logsumexp(logits.astype(mx.float32))
             logprob = logits[x] - lognorm
@@ -72,25 +69,26 @@ class SpeculativeDecoder:
         self,
         prompt,
         max_tokens: int = 100,
-        draft: bool = False,
     ):
-        x = self.tokenize(prompt)
-        start = time.time()
-        for (token, _), n in zip(self._generate(x, draft=draft), range(max_tokens)):
-            token = token.item()
+        memory = self.model.encode(self.tokenizer.encode(prompt)[None])
+        x = mx.array([self.tokenizer.decoder_start_id])
+        skip = 0
+        outputs = []
+        for (token, _), n in zip(self._generate(x, memory), range(max_tokens)):
             if token == self.tokenizer.eos_id:
                 break
-            print(self.tokenizer.decode(token, with_sep=n > 0), end="", flush=True)
-        run_time = time.time() - start
+            outputs.append(token.item())
+            if (n + 1) % 10 == 0:
+                str_output = self.tokenizer.decode(outputs)
+                print(str_output[skip:], end="", flush=True)
+                skip = len(str_output)
+
+        print(self.tokenizer.decode(outputs)[skip:], end="", flush=True)
         print()
-        print(f"=== GENERATED {n + 1} TOKENS in {run_time} SECONDS ===")
-        if draft:
-            self.draft_model.reset_cache()
-        else:
-            self.model.reset_cache()
+        self.model.reset_cache()
 
     def _get_num_accept(self, draft_tokens, draft_probs, model_logits):
-        # equal_toks = sampled[:-1] == draft_tokens
+        # accept_toks = mx.argmax(model_logits, axis=-1) == draft_tokens
         model_probs = mx.take_along_axis(
             model_logits,
             draft_tokens[:, None],
@@ -111,14 +109,19 @@ class SpeculativeDecoder:
         def sample(logits):
             return mx.argmax(logits, axis=-1)
 
-        tokens = mx.array(self.tokenize(prompt), mx.uint32)
-        start = time.time()
+        prompt = mx.array(self.tokenizer.encode(prompt), mx.uint32)[None]
+        memory = self.model.encode(prompt)
+        draft_memory = self.draft_model.encode(prompt)
 
-        decoding_steps = 0
+        tokens = mx.array([self.tokenizer.decoder_start_id])
+
+        n_steps = 0
         ntoks = 0
-        accepted_draft_tokens = 0
-        total_draft_tokens = 0
+        n_accepted = 0
+        n_draft = 0
 
+        outputs = []
+        skip = 0
         draft_inputs = tokens
         inputs = tokens
         while True:
@@ -127,7 +130,7 @@ class SpeculativeDecoder:
             draft_probs = []
             for _, (t, p) in zip(
                 range(ntoks, min(ntoks + self.num_draft, max_tokens)),
-                self._generate(draft_inputs, draft=True),
+                self._generate(draft_inputs, draft_memory, draft=True),
             ):
                 draft_tokens.append(t)
                 draft_probs.append(p)
@@ -138,10 +141,10 @@ class SpeculativeDecoder:
             draft_tokens = mx.concatenate(draft_tokens)
             draft_probs = mx.concatenate(draft_probs)
             verify_tokens = mx.concatenate([inputs, draft_tokens])
-            logits = self.model(
-                verify_tokens[None, :], next_tokens=draft_tokens.size + 1
+            logits = self.model.decode(
+                verify_tokens[None, :],
+                memory,
             ).squeeze(0)
-            # sampled = sample(logits).squeeze(0)
 
             # Only keep samples that match the draft:
             num_to_accept = self._get_num_accept(
@@ -155,38 +158,34 @@ class SpeculativeDecoder:
                 [new_tokens, mx.argmax(logits[num_to_accept], keepdims=True)]
             )
 
-            accepted_draft_tokens += num_to_accept
-            total_draft_tokens += draft_tokens.size
+            n_accepted += num_to_accept
+            n_draft += draft_tokens.size
 
             # Rewind the cache for unaccepted tokens:
            if (n := draft_tokens.size) > num_to_accept:
                 self.draft_model.truncate_cache(n - new_tokens.size)
                 self.model.truncate_cache(n - new_tokens.size + 1)
 
-            decoding_steps += 1
+            n_steps += 1
 
             for t in new_tokens.tolist():
                 if t == self.tokenizer.eos_id or ntoks >= max_tokens:
                     break
-                print(self.tokenizer.decode(t, with_sep=ntoks > 0), end="", flush=True)
+                outputs.append(t)
                 ntoks += 1
 
+            str_output = self.tokenizer.decode(outputs)
+            print(str_output[skip:], end="", flush=True)
+            skip = len(str_output)
+
             if ntoks >= max_tokens or new_tokens[-1] == self.tokenizer.eos_id:
                 break
             draft_inputs = new_tokens[max(new_tokens.size - 2, 0) :]
             inputs = draft_inputs[-1:]
-        end = time.time()
+
+        print(self.tokenizer.decode(outputs)[skip:], end="", flush=True)
+        print()
+
         self.model.reset_cache()
         self.draft_model.reset_cache()
-        print()
-        print(
-            "=== GENERATED",
-            ntoks,
-            "TOKENS IN",
-            round(end - start, 2),
-            "SECONDS ===",
-        )
-        print(
-            f"=== ACCEPTED {accepted_draft_tokens} of {total_draft_tokens} DRAFT TOKENS ==="
-        )
-        print("=== DECODING STEPS", decoding_steps, "===")
+        return {"n_accepted": n_accepted, "n_draft": n_draft, "n_steps": n_steps}
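Only part of `_get_num_accept` appears in this diff, so the exact acceptance rule is not visible here. As a rough, hypothetical illustration of how a `delta`-style lenience (see the README change above) can relax an acceptance test, one possible sketch, not the repository's criterion, is:

```python
import math


def count_accepted(draft_logprobs, model_logprobs, delta=0.0):
    """Count how many leading draft tokens pass a lenient acceptance test.

    delta in [0, 1]: 0 keeps a strict comparison, values near 1 accept almost
    every draft token. Illustrative only; not the repository's exact rule.
    """
    n = 0
    for d, m in zip(draft_logprobs, model_logprobs):
        # Accept while the main model finds the draft token (almost) as likely
        # as the draft model did, with delta loosening the threshold.
        if m >= d + math.log(max(1.0 - delta, 1e-9)):
            n += 1
        else:
            break
    return n


print(count_accepted([-1.0, -0.5], [-1.2, -3.0], delta=0.3))  # 1
```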
main.py
@@ -1,31 +1,51 @@
 import argparse
+import glob
+import json
+import time
+from pathlib import Path
 
 import mlx.core as mx
+import mlx.nn as nn
 from decoder import SpeculativeDecoder
+from mlx.utils import tree_unflatten
+from model import Model
+from transformers import T5Config
+
+
+def load_model(model_name: str):
+    config = T5Config.from_pretrained(model_name)
+    model = Model(config)
+    weights = mx.load(f"{model_name}.npz")
+    weights = tree_unflatten(list(weights.items()))
+    model.update(weights)
+    mx.eval(model.parameters())
+    return model
 
 
 def main(args):
     mx.random.seed(args.seed)
 
     spec_decoder = SpeculativeDecoder(
-        # model="TinyLlama/TinyLlama-1.1B-intermediate-step-1195k-token-2.5T",
-        model="meta-llama/Llama-2-7b-hf",
-        draft_model="TinyLlama/TinyLlama-1.1B-intermediate-step-1195k-token-2.5T",
+        model=load_model(args.model_name),
+        draft_model=load_model(args.draft_model_name),
+        tokenizer=args.model_name,
         delta=args.delta,
         num_draft=args.num_draft,
     )
 
-    prompt = {"role": "user", "content": "Finish the monologue: To be, or not to be..."}
-
-    # Do 1 regular generation to get warmed up (the first one is slow)
-    # engine.generate(messages, max_tokens=1)
-    # engine.generate(messages, max_tokens=1, draft=True)
-
-    # Time regular generation
-    spec_decoder.generate(prompt, max_tokens=125)
-
-    # Time speculative decoding
-    spec_decoder.speculative_decode(prompt, max_tokens=125)
+    tic = time.time()
+    print(args.prompt)
+    if args.regular_decode:
+        spec_decoder.generate(args.prompt, max_tokens=args.max_tokens)
+    else:
+        stats = spec_decoder.speculative_decode(args.prompt, max_tokens=args.max_tokens)
+        print("=" * 10)
+        print(f"Accepted {stats['n_accepted']} / {stats['n_draft']}.")
+        print(f"Decoding steps {stats['n_steps']}.")
+
+    toc = time.time()
+    print("=" * 10)
+    print(f"Full generation time {toc - tic:.3f}")
 
 
 if __name__ == "__main__":
@@ -36,17 +56,44 @@ if __name__ == "__main__":
         default=5,
         help="Number of draft tokens to use per decoding step.",
     )
+    parser.add_argument(
+        "--model-name",
+        help="Name of the model.",
+        default="t5-small",
+    )
+    parser.add_argument(
+        "--draft-model-name",
+        help="Name of the draft model.",
+        default="t5-small",
+    )
     parser.add_argument(
         "--seed",
         type=int,
         default=0,
         help="PRNG seed.",
     )
+    parser.add_argument(
+        "--max-tokens",
+        "-m",
+        type=int,
+        default=100,
+        help="Maximum number of tokens to generate.",
+    )
+    parser.add_argument(
+        "--prompt",
+        default="translate English to French: Let's go to the store and buy some groceries including eggs, avocadoes, and bread.",
+        help="The prompt processed by the model.",
+    )
     parser.add_argument(
         "--delta",
         type=float,
        default=0.1,
         help="Lenience for accepting the proposal tokens.",
     )
+    parser.add_argument(
+        "--regular-decode",
+        action="store_true",
+        help="Use regular decoding instead of speculative decoding.",
+    )
     args = parser.parse_args()
     main(args)
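Given the arguments added above, and assuming the `.npz` weight files have been produced by `convert.py` first, an invocation of this example could look like the following (flag values are illustrative):

```
python convert.py --model t5-11b
python convert.py --model t5-small
python main.py --model-name t5-11b --draft-model-name t5-small --num-draft 5 --delta 0.1
```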
model.py
@@ -1,17 +1,135 @@
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple
 
 import mlx.core as mx
 import mlx.nn as nn
+import numpy as np
 from mlx.utils import tree_map, tree_unflatten
-from transformers import AutoModelForCausalLM, LlamaConfig
+from transformers import AutoTokenizer, T5Config
 
 
-def create_additive_causal_mask(N: int, offset: int = 0, dtype: mx.Dtype = mx.float32):
-    rinds = mx.arange(offset + N)
-    linds = mx.arange(offset, offset + N) if offset else rinds
-    mask = linds[:, None] < rinds[None]
-    mask = mask.astype(dtype) * -1e9
-    return mask
+def _relative_position_bucket(
+    relative_position, bidirectional=True, num_buckets=32, max_distance=128
+):
+    """
+    Adapted from HF Tensorflow:
+    https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
+
+    Translate relative position to a bucket number for relative attention. The relative position is defined as
+    memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
+    position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
+    small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
+    positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
+    This should allow for more graceful generalization to longer sequences than the model has been trained on
+
+    Args:
+        relative_position: an int32 Tensor
+        bidirectional: a boolean - whether the attention is bidirectional
+        num_buckets: an integer
+        max_distance: an integer
+
+    Returns:
+        a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
+    """
+    relative_buckets = 0
+    if bidirectional:
+        num_buckets //= 2
+        relative_buckets += (relative_position > 0).astype(mx.int16) * num_buckets
+        relative_position = mx.abs(relative_position)
+    else:
+        relative_position = -mx.minimum(
+            relative_position, mx.zeros_like(relative_position)
+        )
+    # now relative_position is in the range [0, inf)
+
+    # half of the buckets are for exact increments in positions
+    max_exact = num_buckets // 2
+    is_small = relative_position < max_exact
+
+    # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
+    scale = (num_buckets - max_exact) / np.log(max_distance / max_exact)
+    relative_position_if_large = max_exact + (
+        mx.log(relative_position.astype(mx.float32) / max_exact) * scale
+    ).astype(mx.int16)
+    relative_position_if_large = mx.minimum(relative_position_if_large, num_buckets - 1)
+    relative_buckets += mx.where(
+        is_small, relative_position, relative_position_if_large
+    )
+    return relative_buckets
+
+
+class RelativePositionBias(nn.Module):
+    def __init__(self, config: T5Config, bidirectional: bool):
+        self.bidirectional = bidirectional
+        self.num_buckets = config.relative_attention_num_buckets
+        self.max_distance = config.relative_attention_max_distance
+        self.n_heads = config.num_heads
+        self.embeddings = nn.Embedding(
+            config.relative_attention_num_buckets, config.num_heads
+        )
+
+    def __call__(self, query_length: int, key_length: int, offset: int = 0):
+        """Compute binned relative position bias"""
+        context_position = mx.arange(offset, query_length)[:, None]
+        memory_position = mx.arange(key_length)[None, :]
+
+        # shape (query_length, key_length)
+        relative_position = memory_position - context_position
+        relative_position_bucket = _relative_position_bucket(
+            relative_position,
+            bidirectional=self.bidirectional,
+            num_buckets=self.num_buckets,
+            max_distance=self.max_distance,
+        )
+
+        # shape (query_length, key_length, num_heads)
+        values = self.embeddings(relative_position_bucket)
+
+        # shape (num_heads, query_length, key_length)
+        return values.transpose(2, 0, 1)
+
+
+class MultiHeadAttention(nn.Module):
+    def __init__(self, config: T5Config):
+        super().__init__()
+        inner_dim = config.d_kv * config.num_heads
+        self.num_heads = config.num_heads
+        self.query_proj = nn.Linear(config.d_model, inner_dim, bias=False)
+        self.key_proj = nn.Linear(config.d_model, inner_dim, bias=False)
+        self.value_proj = nn.Linear(config.d_model, inner_dim, bias=False)
+        self.out_proj = nn.Linear(inner_dim, config.d_model, bias=False)
+
+    def __call__(
+        self,
+        queries: mx.array,
+        keys: mx.array,
+        values: mx.array,
+        mask: Optional[mx.array],
+        cache: Optional[Tuple[mx.array, mx.array]] = None,
+    ) -> [mx.array, Tuple[mx.array, mx.array]]:
+        queries = self.query_proj(queries)
+        keys = self.key_proj(keys)
+        values = self.value_proj(values)
+
+        num_heads = self.num_heads
+        B, L, _ = queries.shape
+        _, S, _ = keys.shape
+        queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
+        keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
+        values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
+
+        if cache is not None:
+            key_cache, value_cache = cache
+            keys = mx.concatenate([key_cache, keys], axis=2)
+            values = mx.concatenate([value_cache, values], axis=2)
+
+        # Dimensions are [batch x num heads x sequence x hidden dim]
+        scores = queries @ keys.transpose(0, 1, 3, 2)
+        if mask is not None:
+            scores = scores + mask.astype(scores.dtype)
+
+        scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
+        values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
+        return self.out_proj(values_hat), (keys, values)
 
 
 class RMSNorm(nn.Module):
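The `_relative_position_bucket` helper added above maps signed token offsets to a small set of bucket ids, with exact buckets for nearby positions and logarithmically coarser ones farther away. A small hedged check, assuming the `model.py` from this diff is importable and `mlx` is installed:

```python
import mlx.core as mx

from model import _relative_position_bucket  # assumes model.py from this diff

rel = mx.array([[-2, -1, 0, 1, 2]])  # memory_position - query_position
print(_relative_position_bucket(rel, bidirectional=True))
# With the defaults (32 buckets split between past and future positions) this
# should give small bucket ids such as [[2, 1, 0, 17, 18]].
```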
@@ -24,177 +142,200 @@ class RMSNorm(nn.Module):
         return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
 
     def __call__(self, x):
-        output = self._norm(x.astype(mx.float32)).astype(x.dtype)
+        t = x.dtype
+        output = self._norm(x).astype(t)
         return self.weight * output
 
 
-class Attention(nn.Module):
-    def __init__(self, config: LlamaConfig):
-        super().__init__()
-        self.config = config
-
-        self.n_heads: int = config.num_attention_heads
-        self.n_kv_heads: int = config.num_key_value_heads
-        self.repeats = self.n_heads // self.n_kv_heads
-        self.head_dim = config.hidden_size // self.n_heads
-        self.scale = self.head_dim**-0.5
-
-        self.q_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
-        self.k_proj = nn.Linear(
-            config.hidden_size, config.hidden_size // self.repeats, bias=False
-        )
-        self.v_proj = nn.Linear(
-            config.hidden_size, config.hidden_size // self.repeats, bias=False
-        )
-        self.o_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
-        self.rope = nn.RoPE(self.head_dim, traditional=False)
-
-    def __call__(
-        self,
-        x: mx.array,
-        mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
-    ) -> mx.array:
-        B, L, D = x.shape
-
-        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
-
-        # Prepare the queries, keys and values for the attention computation
-        queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
-        keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
-        values = values.reshape(B, L, self.n_kv_heads, -1).transpose(
-            0, 2, 1, 3
-        )  # B, n_kv_heads, L, head_dim
-
-        def repeat(a):
-            a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
-            kv_size = a.shape[-1]
-            return a.reshape([B, self.n_heads, -1, kv_size])
-
-        if cache is not None:
-            key_cache, value_cache = cache
-            queries = self.rope(queries, offset=key_cache.shape[2])
-            keys = self.rope(keys, offset=key_cache.shape[2])
-            keys = mx.concatenate([key_cache, keys], axis=2)
-            values = mx.concatenate([value_cache, values], axis=2)
-        else:
-            queries = self.rope(queries)
-            keys = self.rope(keys)
-
-        scores = (queries * self.scale) @ repeat(keys).transpose(0, 1, 3, 2)
-        if mask is not None:
-            scores += mask
-        scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
-        output = (scores @ repeat(values)).transpose(0, 2, 1, 3).reshape(B, L, -1)
-        return self.o_proj(output), (keys, values)
-
-
-class FeedForward(nn.Module):
-    def __init__(self, config: LlamaConfig):
-        super().__init__()
-        self.gate_proj = nn.Linear(
-            config.hidden_size, config.intermediate_size, bias=False
-        )
-        self.down_proj = nn.Linear(
-            config.intermediate_size, config.hidden_size, bias=False
-        )
-        self.up_proj = nn.Linear(
-            config.hidden_size, config.intermediate_size, bias=False
-        )
-
-    def __call__(self, x) -> mx.array:
-        return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
-
-
-class TransformerBlock(nn.Module):
-    def __init__(self, config: LlamaConfig):
-        super().__init__()
-        self.n_heads = config.num_attention_heads
-        self.dim = config.hidden_size
-        self.self_attn = Attention(config=config)
-        self.mlp = FeedForward(config=config)
-        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.post_attention_layernorm = RMSNorm(
-            config.hidden_size, eps=config.rms_norm_eps
-        )
-
-    def __call__(
-        self,
-        x: mx.array,
-        mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
-    ) -> mx.array:
-        r, cache = self.self_attn(self.input_layernorm(x), mask, cache)
-        h = x + r
-        r = self.mlp(self.post_attention_layernorm(h))
-        out = h + r
-        return out, cache
-
-
-class Llama(nn.Module):
-    def __init__(self, config: LlamaConfig):
-        super().__init__()
-        self.config = config
-        self.vocab_size = config.vocab_size
-        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
-        self.layers = [
-            TransformerBlock(config=config) for _ in range(config.num_hidden_layers)
-        ]
-        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
-        self.reset_cache()
-
-    def truncate_cache(self, num_to_truncate):
-        if num_to_truncate <= 0:
-            return
-        cache_length = self.kv_cache[0][0].shape[2]
-        if num_to_truncate < cache_length:
-            self.kv_cache = tree_map(
-                lambda x: x[:, :, :-num_to_truncate, :], self.kv_cache
-            )
-        else:
-            self.reset_cache()
-
-    def reset_cache(self):
-        self.kv_cache = [None] * len(self.layers)
-
-    def __call__(
-        self,
-        x: mx.array,
-        next_tokens: int = -1,
-    ):
-        if self.kv_cache[0]:
-            offset = self.kv_cache[0][0].shape[-2]
-        else:
-            offset = 0
-
-        if x.shape[1] > 1:
-            mask = create_additive_causal_mask(x.shape[1], offset)
-            mask = mask.astype(self.embed_tokens.weight.dtype)
-        else:
-            mask = None
-
-        x = self.embed_tokens(x)
-        for idx, layer in enumerate(self.layers):
-            x, self.kv_cache[idx] = layer(x, mask, cache=self.kv_cache[idx])
-
-        if next_tokens > 0:
-            x = x[:, -next_tokens:]
-
-        x = self.norm(x)
-        return self.lm_head(x)
-
-    @classmethod
-    def from_hugging_face(cls, model_path: str, quantized: bool = True):
-        config = LlamaConfig.from_pretrained(model_path)
-        torch_weights = AutoModelForCausalLM.from_pretrained(model_path).state_dict()
-        weights = {
-            k.replace("model.", ""): mx.array(v.numpy(), mx.float16)
-            for k, v in torch_weights.items()
-        }
-        model = cls(config)
-        model.update(tree_unflatten(list(weights.items())))
-        # if quantization is not None:
-        #     nn.QuantizedLinear.quantize_module(model, **quantization)
-        mx.eval(model.parameters())
-        return model
+class DenseActivation(nn.Module):
+    def __init__(self, config: T5Config):
+        super().__init__()
+        mlp_dims = config.d_ff or config.d_model * 4
+        self.gated = config.feed_forward_proj.startswith("gated")
+        if self.gated:
+            self.wi_0 = nn.Linear(config.d_model, mlp_dims, bias=False)
+            self.wi_1 = nn.Linear(config.d_model, mlp_dims, bias=False)
+        else:
+            self.wi = nn.Linear(config.d_model, mlp_dims, bias=False)
+        self.wo = nn.Linear(mlp_dims, config.d_model, bias=False)
+        activation = config.feed_forward_proj.removeprefix("gated-")
+        if activation == "relu":
+            self.act = nn.relu
+        elif activation == "gelu":
+            self.act = nn.gelu
+        elif activation == "silu":
+            self.act = nn.silu
+        else:
+            raise ValueError(f"Unknown activation: {activation}")
+
+    def __call__(self, x):
+        if self.gated:
+            hidden_act = self.act(self.wi_0(x))
+            hidden_linear = self.wi_1(x)
+            x = hidden_act * hidden_linear
+        else:
+            x = self.act(self.wi(x))
+        return self.wo(x)
+
+
+class TransformerEncoderLayer(nn.Module):
+    def __init__(self, config: T5Config):
+        super().__init__()
+        self.attention = MultiHeadAttention(config)
+        self.ln1 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
+        self.ln2 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
+        self.dense = DenseActivation(config)
+
+    def __call__(self, x, mask):
+        y = self.ln1(x)
+        y, _ = self.attention(y, y, y, mask=mask)
+        x = x + y
+
+        y = self.ln2(x)
+        y = self.dense(y)
+        return x + y
+
+
+class TransformerEncoder(nn.Module):
+    def __init__(self, config: T5Config):
+        super().__init__()
+        self.layers = [
+            TransformerEncoderLayer(config) for i in range(config.num_layers)
+        ]
+        self.ln = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
+        self.relative_attention_bias = RelativePositionBias(config, bidirectional=True)
+
+    def __call__(self, x: mx.array):
+        pos_bias = self.relative_attention_bias(x.shape[1], x.shape[1])
+        for layer in self.layers:
+            x = layer(x, mask=pos_bias)
+        return self.ln(x)
+
+
+class TransformerDecoderLayer(nn.Module):
+    def __init__(self, config: T5Config):
+        super().__init__()
+        self.self_attention = MultiHeadAttention(config)
+        self.cross_attention = MultiHeadAttention(config)
+        self.ln1 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
+        self.ln2 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
+        self.ln3 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
+        self.dense = DenseActivation(config)
+
+    def __call__(
+        self,
+        x: mx.array,
+        memory: mx.array,
+        mask: mx.array,
+        memory_mask: mx.array,
+        cache: Optional[List[Tuple[mx.array, mx.array]]] = None,
+    ):
+        y = self.ln1(x)
+        y, cache = self.self_attention(y, y, y, mask, cache)
+        x = x + y
+
+        y = self.ln2(x)
+        y, _ = self.cross_attention(y, memory, memory, memory_mask)
+        x = x + y
+
+        y = self.ln3(x)
+        y = self.dense(y)
+        x = x + y
+
+        return x, cache
+
+
+def create_additive_causal_mask(N: int, offset: int = 0):
+    rinds = mx.arange(offset + N)
+    linds = mx.arange(offset, offset + N) if offset else rinds
+    mask = linds[:, None] < rinds[None]
+    return mask * -1e9
+
+
+class TransformerDecoder(nn.Module):
+    def __init__(self, config: T5Config):
+        super().__init__()
+        n_layers = getattr(config, "num_decoder_layers", config.num_layers)
+        self.layers = [TransformerDecoderLayer(config) for i in range(n_layers)]
+        self.ln = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
+        self.relative_attention_bias = RelativePositionBias(config, bidirectional=False)
+
+    def __call__(self, x, memory, cache=None):
+        if cache[0] is not None:
+            offset = cache[0][0].shape[2]
+        else:
+            offset = 0
+
+        T = x.shape[1]
+        if T > 1:
+            mask = create_additive_causal_mask(T, offset)
+        else:
+            mask = None
+
+        pos_bias = self.relative_attention_bias(T + offset, T + offset, offset=offset)
+        if mask is not None:
+            mask += pos_bias
+        else:
+            mask = pos_bias
+
+        for e, layer in enumerate(self.layers):
+            x, cache[e] = layer(x, memory, mask, None, cache=cache[e])
+        x = self.ln(x)
+
+        return x, cache
+
+
+class OutputHead(nn.Module):
+    def __init__(self, config: T5Config):
+        self.linear = nn.Linear(config.d_model, config.vocab_size, bias=False)
+
+    def __call__(self, inputs):
+        return self.linear(inputs)
+
+
+class Model(nn.Module):
+    def __init__(self, config: T5Config):
+        self.wte = nn.Embedding(config.vocab_size, config.d_model)
+        self.encoder = TransformerEncoder(config)
+        self.decoder = TransformerDecoder(config)
+        self.tie_word_embeddings = config.tie_word_embeddings
+        if not self.tie_word_embeddings:
+            self.lm_head = OutputHead(config)
+        self.model_dim = config.d_model
+        self.reset_cache()
+
+    def encode(self, inputs: mx.array):
+        return self.encoder(self.wte(inputs))
+
+    def truncate_cache(self, num_to_truncate):
+        if num_to_truncate <= 0:
+            return
+        cache_length = self.cache[0][0].shape[2]
+        if num_to_truncate < cache_length:
+            self.cache = tree_map(lambda x: x[:, :, :-num_to_truncate, :], self.cache)
+        else:
+            self.reset_cache()
+
+    def reset_cache(self):
+        self.cache = [None] * len(self.decoder.layers)
+
+    def decode(
+        self,
+        inputs: mx.array,
+        memory: mx.array,
+    ):
+        inputs = self.wte(inputs)
+        y, self.cache = self.decoder(inputs, memory=memory, cache=self.cache)
+        if not self.tie_word_embeddings:
+            y *= self.model_dim**-0.5
+            y = self.lm_head(y)
+        else:
+            y = y @ self.wte.weight.T
+        return y
+
+    def __call__(
+        self,
+        inputs: mx.array,
+        decoder_inputs: mx.array,
+    ):
+        return self.decode(decoder_inputs, self.encode(inputs))[0]

requirements.txt
@@ -1,3 +1,4 @@
-mlx>=0.0.5
+mlx>=0.0.6
 transformers
 numpy
+accelerate

t5/t5.py
@@ -125,7 +125,6 @@ class MultiHeadAttention(nn.Module):
             values = mx.concatenate([value_cache, values], axis=2)
 
         # Dimensions are [batch x num heads x sequence x hidden dim]
-        queries = queries
        scores = queries @ keys
         if mask is not None:
             scores = scores + mask.astype(scores.dtype)
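Taken together, the pieces in this commit are used roughly as follows. This is a hedged sketch based on `load_model` from `main.py` and the `Model.encode`/`Model.decode` API above, not code from the repository, and it assumes `convert.py` has already produced `t5-small.npz` in the working directory:

```python
import mlx.core as mx

from decoder import Tokenizer
from main import load_model

model = load_model("t5-small")
tokenizer = Tokenizer("t5-small")

memory = model.encode(tokenizer.encode("translate English to French: Hello.")[None])
x = mx.array([tokenizer.decoder_start_id])      # T5 decoding starts from id 0
logits = model.decode(x[None], memory)[0, -1]   # same call the decoder loop uses
print(mx.argmax(logits).item())                 # id of the first generated token
```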