This commit is contained in:
Awni Hannun 2024-12-13 18:44:56 -08:00
parent f01bc5fb17
commit e34ecb79b4

View File

@ -260,6 +260,8 @@ def generate_step(
kv_bits=kv_bits, kv_bits=kv_bits,
) )
sampler = sampler or (lambda x: mx.argmax(x, axis=-1))
def _step(y): def _step(y):
with mx.stream(generation_stream): with mx.stream(generation_stream):
logits = model(y[None], cache=prompt_cache) logits = model(y[None], cache=prompt_cache)