# mlx-examples/llms/export/export.py
# Snapshot metadata: 2025-01-08 17:07:02 -08:00 — 172 lines, 4.6 KiB, Python.

import time
from pathlib import Path
import fire
import mlx.core as mx
from mlx_lm import load
class ExportableCache:
    """Minimal per-layer KV cache usable inside an exported/compiled graph.

    Exposes the two members the model's attention layers expect from a
    cache object: ``update_and_fetch`` and the ``state`` property.
    """

    def __init__(self, keys=None, values=None, offset=0):
        # Position along the sequence axis (axis 2) where new entries are written.
        self.offset = offset
        self.keys = keys
        self.values = values

    def update_and_fetch(self, keys, values):
        """Insert `keys`/`values` at `offset` and return the full cached pair."""
        if self.keys is None:
            # First call: adopt the incoming tensors as the cache contents.
            self.keys, self.values = keys, values
        else:
            # Subsequent calls: write in place along the sequence axis.
            self.keys = mx.slice_update(self.keys, keys, self.offset, axes=(2,))
            self.values = mx.slice_update(self.values, values, self.offset, axes=(2,))
        return self.keys, self.values

    @property
    def state(self):
        # Raw (keys, values) pair, e.g. for flattening into exported state.
        return self.keys, self.values
def expand(cache, mask=None, cache_step_size=256):
    """Grow each KV tensor in `cache` past its current length, in fixed steps.

    The new sequence length is the current length rounded up beyond itself to
    a multiple of `cache_step_size`, so there is always room for at least one
    more step of tokens. Returns the padded cache list and a boolean mask
    marking which positions hold valid entries.
    """
    current = cache[0].shape[-2]
    # Next multiple of the step size strictly larger than `current`
    # (adds a full step when `current` is already a multiple).
    target = ((current + cache_step_size) // cache_step_size) * cache_step_size

    def pad_kv(kv):
        # Zero-pad along the sequence axis, keeping existing entries in place.
        batch, heads, _, head_dim = kv.shape
        padded = mx.zeros((batch, heads, target, head_dim), kv.dtype)
        padded[..., :current, :] = kv
        return padded

    grown = [pad_kv(kv) for kv in cache]

    if mask is None:
        # Fresh mask: valid up to the pre-expansion length, False beyond.
        mask = mx.full(target, False)
        mask[:current] = True
    else:
        # Existing mask already covers `current` positions; append one step.
        mask = mx.concatenate([mask, mx.full(cache_step_size, False)])
    return grown, mask
def causal_mask(N):
    """Boolean causal mask: entry (i, j) is True iff position i may attend to j (j <= i)."""
    positions = mx.arange(N)
    return positions[None] <= positions[:, None]
def step(model, y, *state):
    """Run one forward pass of `model`, threading the KV cache through `state`.

    Two call shapes are supported:
      - prompt processing: ``state == (mask,)`` — fresh, empty caches;
      - decoding: ``state == (k0, v0, k1, v1, ..., offset, mask)`` — caches
        rebuilt from the flattened key/value tensors plus the write offset.
    Returns the logits followed by the flattened per-layer (keys, values).
    """
    mask = state[-1]
    if len(state) == 1:
        # Prompt pass: one empty cache per transformer layer.
        caches = [ExportableCache() for _ in range(len(model.model.layers))]
    else:
        # Decode pass: interleaved keys/values, then the shared offset.
        *kv_flat, offset = state[:-1]
        caches = [
            ExportableCache(k, v, offset)
            for k, v in zip(kv_flat[::2], kv_flat[1::2])
        ]
    logits = model(y, cache=caches, mask=mask)
    flat_state = [tensor for c in caches for tensor in c.state]
    return logits, *flat_state
def generate_step(prompt, model, max_tokens):
    """Greedily generate up to `max_tokens` token ids from `model`.

    Yields one Python int per token. Keeps one step of compute in flight
    (via ``mx.async_eval``) so the next token is being computed while the
    caller consumes the current one.

    Fix: ``n`` is incremented once per loop iteration, at loop level. In the
    original (flattened) source the increment was guarded by
    ``n < max_tokens - 1``, so ``n`` could never reach ``max_tokens``: the
    ``break`` was unreachable and the generator looped forever re-yielding a
    stale token (and hit an unbound ``next_y`` when ``max_tokens == 1``).
    """
    mx.eval(model)

    # Shapeless compile so one compiled graph serves both the variable-length
    # prompt pass and the single-token decode passes.
    compiled_step = mx.compile(lambda *args: step(model, *args), shapeless=True)

    def _step(*args):
        logits, *cache = compiled_step(*args)
        # Greedy decoding: arg-max over the last position's logits.
        return mx.argmax(logits[:, -1], axis=-1), *cache

    # Prompt pass: full causal mask, empty cache.
    y, *cache = _step(prompt, causal_mask(prompt.size))
    mx.async_eval(y)
    offset = mx.array(prompt.size, mx.uint32)
    cache, mask = expand(cache)
    n = 0
    while True:
        if n < max_tokens - 1:
            # Grow the cache/mask in fixed steps once the next write position
            # would fall outside the current mask.
            if mask.size <= (prompt.size + n):
                cache, mask = expand(cache, mask)
            mask[prompt.size + n] = True
            # Kick off the next token before yielding the current one.
            next_y, *cache = _step(y[None], *cache, offset, mask)
            mx.async_eval(next_y)
            offset += 1
        n += 1
        yield y.item()
        if n == max_tokens:
            break
        y = next_y
def export(
    model="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit",
    path="llama3.1-instruct-4bit",
):
    """Save the tokenizer and export the model's step function to `path`.

    Exports two call signatures of the (shapeless-compiled) step function:
    a prompt-processing pass and a cached single-token decode pass.
    """
    model, tokenizer = load(model)
    mx.eval(model)
    tokenizer.save_pretrained(path)

    def run_step(*args):
        return step(model, *args)

    # Example inputs that pin down the two exported call signatures.
    example_prompt = mx.array([[0, 0]], mx.uint32)
    example_token = mx.array([[0]], mx.uint32)
    example_offset = mx.array([0], mx.uint32)
    prompt_mask = causal_mask(example_prompt.size)

    # Prime a KV cache by running the prompt pass once.
    _, *kv_state = run_step(example_prompt, prompt_mask)

    out_file = str(Path(path) / "model.mlxfn")
    with mx.exporter(out_file, run_step, shapeless=True) as exporter:
        # Signature 1: prompt processing (token batch + causal mask).
        exporter(example_prompt, prompt_mask)
        # Signature 2: single-token decode (cache, offset, padded mask).
        kv_state, decode_mask = expand(kv_state)
        exporter(example_token, *kv_state, example_offset, decode_mask)
def generate(
    prompt,
    model="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit",
    max_tokens=128,
):
    """Generate a response to `prompt` with `model`, streaming text to stdout.

    Prints prompt and generation throughput plus peak memory when finished.
    Assumes at least one token is produced (``max_tokens >= 1``), since the
    throughput stats reference the loop variables after the loop.

    Fix: removed the unused ``tokens = []`` local — nothing ever appended to
    or read from it.
    """
    print("[INFO] Loading model from disk.")
    model, tokenizer = load(model)
    prompt = tokenizer.apply_chat_template(
        [{"role": "user", "content": prompt}],
        add_generation_prompt=True,
        return_tensors="mlx",
    )
    print("[INFO] Starting generation...")
    tic = time.time()
    detokenizer = tokenizer.detokenizer
    detokenizer.reset()
    for n, token in enumerate(generate_step(prompt, model, max_tokens)):
        if n == 0:
            # First token marks the end of prompt processing; restart the
            # clock so generation throughput excludes the prompt pass.
            prompt_tps = prompt.size / (time.time() - tic)
            tic = time.time()
        if token in tokenizer.eos_token_ids:
            break
        detokenizer.add_token(token)
        print(detokenizer.last_segment, end="", flush=True)
    detokenizer.finalize()
    print(detokenizer.last_segment, flush=True)
    gen_tps = (n + 1) / (time.time() - tic)
    peak_memory = mx.metal.get_peak_memory() / 1e9
    print("=" * 10)
    print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
    print(f"Generation: {gen_tps:.3f} tokens-per-sec")
    print(f"Peak RAM: {peak_memory:.3f} GB")
if __name__ == "__main__":
    # Expose the two entry points as CLI subcommands, e.g.:
    #   python export.py export --path=...
    #   python export.py generate --prompt="..."
    commands = {
        "generate": generate,
        "export": export,
    }
    fire.Fire(commands)