Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-08-29 18:26:37 +08:00)
update generate_step callsites
parent 3082db0143
commit a28ca03e04
@@ -132,7 +132,7 @@ def main():
     prompt = args.prompt

     cache = make_prompt_cache(model, args.max_kv_size)
-    y = mx.array(tokenizer.encode(prompt))
+    y = mx.array(tokenizer.encode(prompt))[None]

     # Process the prompt
     start = time.time()
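Context for the change above: indexing with `[None]` is NumPy-style newaxis indexing, turning the flat `(seq_len,)` token vector into a `(1, seq_len)` batch of one, which is the shape the batched `generate_step` now expects. A minimal standalone sketch of the shape change (the token ids are made up):

    import mlx.core as mx

    # Hypothetical token ids standing in for tokenizer.encode(prompt).
    token_ids = [101, 2023, 2003, 1037, 102]

    y = mx.array(token_ids)  # shape (5,): the old, unbatched form
    y = y[None]              # shape (1, 5): leading batch axis added
    print(y.shape)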
@@ -296,7 +296,7 @@ def generate_step(

     def _step(y):
         with mx.stream(generation_stream):
-            if y.ndims == 1:
+            if y.ndim == 1:
                 y = mx.expand_dims(y, axis=-1)
             logits = model(
                 y,
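Related detail: `mx.array` follows NumPy naming, so the attribute is `ndim`; `ndims` does not exist, and the old guard would have failed with an `AttributeError` as soon as `_step` ran. A standalone sketch of the corrected check:

    import mlx.core as mx

    y = mx.array([1, 2, 3])         # a 1-D token vector
    assert y.ndim == 1              # 'ndim' is the NumPy-style attribute
    y = mx.expand_dims(y, axis=-1)  # promote (3,) to (3, 1), as in _step
    print(y.shape)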
@@ -390,12 +390,16 @@ def stream_generate(
         prompt if isinstance(prompt, list) else tokenizer.encode(prompt)
     )

+    if prompt.ndim == 1:
+        prompt = prompt[None]
+
     detokenizer = tokenizer.detokenizer

     with wired_limit(model, [generation_stream]):
         detokenizer.reset()
         tic = time.perf_counter()
         for n, (token, logprobs) in enumerate(generate_step(prompt, model, **kwargs)):
+            token, logprobs = token.item(), logprobs.squeeze(0)
             if n == 0:
                 prompt_time = time.perf_counter() - tic
                 prompt_tps = prompt.size / prompt_time
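Taken together, the new calling convention for `generate_step` looks roughly like this. A hedged usage sketch, assuming `load` and `generate_step` are importable from `mlx_lm` / `mlx_lm.utils` as in the tests below; the model id is only an example:

    import mlx.core as mx
    from mlx_lm import load
    from mlx_lm.utils import generate_step

    model, tokenizer = load("mlx-community/Qwen1.5-0.5B-Chat-4bit")

    # Prompts are batched now: add a leading batch axis before calling in.
    prompt = mx.array(tokenizer.encode("hello"))[None]

    for _, (token, logprobs) in zip(range(8), generate_step(prompt, model)):
        # Each step yields a batched pair; unbatch it exactly as
        # stream_generate does above.
        token, logprobs = token.item(), logprobs.squeeze(0)
        print(tokenizer.decode([token]), end="", flush=True)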
@@ -121,21 +121,24 @@ class TestPromptCache(unittest.TestCase):
     def test_cache_with_generate(self):
         model, tokenizer = load(HF_MODEL_PATH)
         prompt = tokenizer.encode("this is a prompt", return_tensors="mlx")[0]
-        results = list(generate_step(prompt, model, max_tokens=4))
+        results = list(generate_step(prompt[None], model, max_tokens=4))
+        results = [(t.item(), l.squeeze(0)) for t, l in results]
         toks, all_logits = zip(*results)

         prompt_cache = make_prompt_cache(model)
         i = 0
         for tok, logits in generate_step(
-            prompt, model, prompt_cache=prompt_cache, max_tokens=2
+            prompt[None], model, prompt_cache=prompt_cache, max_tokens=2
         ):
+            tok, logits = tok.item(), logits.squeeze(0)
             self.assertEqual(tok, toks[i])
             self.assertTrue(mx.allclose(logits, all_logits[i]))
             i += 1

         for tok, logits in generate_step(
-            mx.array([toks[i]]), model, prompt_cache=prompt_cache, max_tokens=1
+            mx.array([[toks[i]]]), model, prompt_cache=prompt_cache, max_tokens=1
         ):
+            tok, logits = tok.item(), logits.squeeze(0)
             i += 1
             self.assertEqual(tok, toks[i])
             self.assertTrue(mx.allclose(logits, all_logits[i]))
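Note the double brackets in `mx.array([[toks[i]]])`: a single continuation token must now be shaped `(1, 1)` (batch axis, then sequence axis) rather than `(1,)`. A standalone shape sketch with a made-up token id:

    import mlx.core as mx

    tok = 42                      # hypothetical previously-sampled token id
    old_form = mx.array([tok])    # shape (1,): the pre-change, unbatched form
    new_form = mx.array([[tok]])  # shape (1, 1): batch axis first
    print(old_form.shape, new_form.shape)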
@@ -205,14 +208,14 @@ class TestPromptCache(unittest.TestCase):
         prompt_cache = make_prompt_cache(model)

         # Generate one token so we process the full prompt
-        last_tok, _ = next(generate_step(prompt, model, prompt_cache=prompt_cache))
-        last_tok = mx.array([last_tok])
+        last_tok, _ = next(generate_step(prompt[None], model, prompt_cache=prompt_cache))

         # Generate two more tokens
         results = zip(
-            range(2), generate_step(last_tok, model, prompt_cache=prompt_cache)
+            range(2), generate_step(last_tok[None], model, prompt_cache=prompt_cache)
         )
-        toks, all_logits = zip(*(r[1] for r in results))
+        results = [(t.item(), l.squeeze(0)) for _, (t, l) in results]
+        toks, all_logits = zip(*results)

         # To get back to the cache just after processing the prompt,
         # trim by 3 tokens
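The dropped `last_tok = mx.array([last_tok])` line reflects that `generate_step` now yields tokens as mlx arrays rather than Python ints, so `last_tok[None]` can restore the batch axis directly. A standalone sketch, assuming the yielded token has shape `(1,)` for a batch of one:

    import mlx.core as mx

    last_tok = mx.array([42])    # stand-in for a token yielded by generate_step
    print(last_tok.item())       # still readable as a Python int
    print(last_tok[None].shape)  # (1, 1): ready to feed back into generate_step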
@@ -220,9 +223,10 @@ class TestPromptCache(unittest.TestCase):

         # Generate the same thing again
         results = zip(
-            range(2), generate_step(last_tok, model, prompt_cache=prompt_cache)
+            range(2), generate_step(last_tok[None], model, prompt_cache=prompt_cache)
         )
-        second_toks, second_all_logits = zip(*(r[1] for r in results))
+        results = [(t.item(), l.squeeze(0)) for _, (t, l) in results]
+        second_toks, second_all_logits = zip(*results)
         self.assertEqual(toks, second_toks)
         self.assertTrue(
             all(mx.allclose(l, l2) for l, l2 in zip(all_logits, second_all_logits))
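The rewind-and-regenerate pattern these two hunks test can be written out end to end. A hedged sketch mirroring the test, assuming `trim_prompt_cache` lives in `mlx_lm.models.cache` alongside `make_prompt_cache`, and using an example model id:

    import mlx.core as mx
    from mlx_lm import load
    from mlx_lm.utils import generate_step
    from mlx_lm.models.cache import make_prompt_cache, trim_prompt_cache

    model, tokenizer = load("mlx-community/Qwen1.5-0.5B-Chat-4bit")
    prompt = mx.array(tokenizer.encode("this is a prompt"))[None]

    cache = make_prompt_cache(model)
    last_tok, _ = next(generate_step(prompt, model, prompt_cache=cache))

    first = [t.item() for _, (t, _) in
             zip(range(2), generate_step(last_tok[None], model, prompt_cache=cache))]

    # Rewind past the 3 generated tokens (1 from next(), 2 from the loop)...
    trim_prompt_cache(cache, 3)

    # ...and regeneration from the same token is deterministic.
    again = [t.item() for _, (t, _) in
             zip(range(2), generate_step(last_tok[None], model, prompt_cache=cache))]
    assert first == again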
@@ -278,14 +282,16 @@ class TestPromptCache(unittest.TestCase):
     def test_cache_to_quantized(self):
         model, tokenizer = load(HF_MODEL_PATH)
         prompt = tokenizer.encode("this is a prompt", return_tensors="mlx")[0]
-        results = zip(range(4), generate_step(prompt, model))
-        toks, all_logits = zip(*(r[1] for r in results))
+        results = zip(range(4), generate_step(prompt[None], model))
+        results = [(t.item(), l.squeeze(0)) for _, (t, l) in results]
+        toks, all_logits = zip(*results)

         prompt_cache = make_prompt_cache(model)
         i = 0
         for _, (tok, logits) in zip(
-            range(2), generate_step(prompt, model, prompt_cache=prompt_cache)
+            range(2), generate_step(prompt[None], model, prompt_cache=prompt_cache)
         ):
+            tok, logits = tok.item(), logits.squeeze(0)
             self.assertEqual(tok, toks[i])
             self.assertTrue(mx.allclose(logits, all_logits[i]))
             i += 1
@@ -294,8 +300,9 @@ class TestPromptCache(unittest.TestCase):

         for _, (tok, logits) in zip(
             range(1),
-            generate_step(mx.array([toks[i]]), model, prompt_cache=prompt_cache),
+            generate_step(mx.array([[toks[i]]]), model, prompt_cache=prompt_cache),
         ):
+            tok, logits = tok.item(), logits.squeeze(0)
             i += 1
             self.assertEqual(tok, toks[i])
             self.assertTrue(mx.allclose(logits, all_logits[i], rtol=2e-2))
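The looser `rtol=2e-2` in the final assertion acknowledges that quantizing the KV cache perturbs the logits slightly, so exact equality is not expected. A standalone illustration of what that tolerance admits:

    import mlx.core as mx

    a = mx.array([1.000, 2.000, 3.000])
    b = mx.array([1.015, 1.990, 3.020])
    # Each element of a is within 2% (relative) of b, so this passes.
    print(mx.allclose(a, b, rtol=2e-2))  # True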