mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
Put prompt processing in same stream (#1122)
* put prompt processing in same stream * patch
This commit is contained in:
parent
a5e173802e
commit
cfc29c29f4
@ -1,3 +1,3 @@
|
|||||||
# Copyright © 2023-2024 Apple Inc.
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
__version__ = "0.20.0"
|
__version__ = "0.20.1"
|
||||||
|
@ -274,13 +274,14 @@ def generate_step(
|
|||||||
y = sampler(logprobs)
|
y = sampler(logprobs)
|
||||||
return y, logprobs.squeeze(0)
|
return y, logprobs.squeeze(0)
|
||||||
|
|
||||||
while y.size > prefill_step_size:
|
with mx.stream(generation_stream):
|
||||||
model(y[:prefill_step_size][None], cache=prompt_cache)
|
while y.size > prefill_step_size:
|
||||||
mx.eval([c.state for c in prompt_cache])
|
model(y[:prefill_step_size][None], cache=prompt_cache)
|
||||||
y = y[prefill_step_size:]
|
mx.eval([c.state for c in prompt_cache])
|
||||||
mx.metal.clear_cache()
|
y = y[prefill_step_size:]
|
||||||
|
mx.metal.clear_cache()
|
||||||
|
|
||||||
y, logprobs = _step(y)
|
y, logprobs = _step(y)
|
||||||
|
|
||||||
mx.async_eval(y, logprobs)
|
mx.async_eval(y, logprobs)
|
||||||
n = 0
|
n = 0
|
||||||
|
Loading…
Reference in New Issue
Block a user