commit 2541f13907 (parent 089480878f)
Author: llvvuu
Date: 2024-12-27 16:21:20 -08:00
4 changed files with 19 additions and 11 deletions

View File

@@ -42,7 +42,7 @@ def create_causal_mask(
         mask = mask | (rinds >= lengths)
     # HACK: sometimes see NaN logprobs if no divide by 2 here
     # return mask * (mx.finfo(dtype).min / 2)
-    return mask.astype(dtype) * (-65504. / 2)
+    return mask.astype(dtype) * (-65504.0 / 2)


 def create_attention_mask(h: mx.array, cache: Optional[Any] = None):

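For context on the `-65504.0 / 2` value: it is half of the float16 minimum (`mx.finfo(mx.float16).min`), used as the additive attention bias so that adding it to finite scores is less likely to overflow and produce the NaN logprobs mentioned in the HACK comment above. Below is a minimal sketch of how such a mask is built and applied; `toy_causal_mask` is an illustrative stand-in, not the repository's `create_causal_mask`:

    import mlx.core as mx

    def toy_causal_mask(N, dtype=mx.float16):
        # Boolean mask: True where attention is NOT allowed (future keys).
        linds = mx.arange(N)[:, None]
        rinds = mx.arange(N)[None]
        mask = linds < rinds
        # Additive bias: half of the dtype minimum, mirroring the hack above.
        return mask.astype(dtype) * (mx.finfo(dtype).min / 2)

    scores = mx.zeros((4, 4), dtype=mx.float16)
    weights = mx.softmax(scores + toy_causal_mask(4), axis=-1)
    print(weights)  # lower-triangular attention pattern, no NaNs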
View File

@@ -329,9 +329,11 @@ def generate_step(
             model(
                 y[:, :prefill_step_size],
                 cache=prompt_cache,
-                mask=mask[:, :, :prefill_step_size, : offset + prefill_step_size]
-                if mask is not None
-                else None,
+                mask=(
+                    mask[:, :, :prefill_step_size, : offset + prefill_step_size]
+                    if mask is not None
+                    else None
+                ),
             )
             maybe_quantize_kv_cache(
                 prompt_cache, quantized_kv_start, kv_group_size, kv_bits
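The reformatted call above slices the full additive mask down to the current prefill chunk: the first `prefill_step_size` query rows still to be processed, against all `offset + prefill_step_size` keys seen so far. A rough sketch of that indexing, assuming (this is not shown in the hunk) that the surrounding loop advances `offset` and trims `y` and `mask` by `prefill_step_size` each iteration:

    import mlx.core as mx

    B, L, prefill_step_size = 2, 10, 4
    offset = 0
    y = mx.zeros((B, L), dtype=mx.int32)             # stand-in token ids
    mask = mx.zeros((B, 1, L, L), dtype=mx.float16)  # full additive mask

    while y.shape[1] > prefill_step_size:
        chunk_mask = (
            mask[:, :, :prefill_step_size, : offset + prefill_step_size]
            if mask is not None
            else None
        )
        # model(y[:, :prefill_step_size], cache=prompt_cache, mask=chunk_mask)
        offset += prefill_step_size
        y = y[:, prefill_step_size:]
        mask = mask[:, :, prefill_step_size:, :]

    # The remaining y.shape[1] <= prefill_step_size tokens go through the
    # final forward pass outside the loop.
    print(offset, y.shape)  # prints: 8 (2, 2)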
@@ -509,7 +511,7 @@ def batch_generate(
         kwargs: The remaining options get passed to :func:`generate_step`.
           See :func:`generate_step` for more details.
     """
-    if 'prompt_cache' in kwargs:
+    if "prompt_cache" in kwargs:
         # TODO: Handle `prompt_cache` and `prompt` both left-padded, so that
         # we have <pad>text<pad>text. Should involve taking `prompt_cache_lens`
         # to extend `mask` below, and handling position_ids (see TODO below)
@@ -527,13 +529,15 @@ def batch_generate(
     input_ids, token_mask = mx.array(res["input_ids"]), mx.array(res["attention_mask"])
     dtype = mx.float32
     for module in model.modules():
-        if isinstance(module, nn.QuantizedEmbedding) or isinstance(module, nn.Embedding):
+        if isinstance(module, nn.QuantizedEmbedding) or isinstance(
+            module, nn.Embedding
+        ):
             dtype = module(mx.zeros(1, dtype=input_ids.dtype)).dtype
             break
     causal_mask = create_causal_mask(token_mask.shape[-1], dtype=dtype)
     # HACK: sometimes see NaN logprobs if no divide by 2 here
     # mask = mx.where(token_mask[:, None, None, :], causal_mask, mx.finfo(dtype).min / 2)
-    mask = mx.where(token_mask[:, None, None, :], causal_mask, -65504. / 2)
+    mask = mx.where(token_mask[:, None, None, :], causal_mask, -65504.0 / 2)
     output_toks = []
     prompt_time = None

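In `batch_generate`, `token_mask` is the tokenizer's left-padded attention mask, and the `mx.where` line broadcasts it against the shared causal mask so padded key columns get the same large negative bias. A small illustration of the shapes involved, with made-up values for `token_mask`:

    import mlx.core as mx

    dtype = mx.float16
    token_mask = mx.array([[0, 0, 1, 1],   # sequence 0: two left-pad tokens
                           [1, 1, 1, 1]])  # sequence 1: no padding
    L = token_mask.shape[-1]

    # Causal part, (L, L): 0 where attention is allowed, large negative otherwise.
    linds = mx.arange(L)[:, None]
    rinds = mx.arange(L)[None]
    causal_mask = (linds < rinds).astype(dtype) * (-65504.0 / 2)

    # (B, 1, 1, L) padding mask broadcast against (L, L) causal mask -> (B, 1, L, L).
    mask = mx.where(token_mask[:, None, None, :], causal_mask, -65504.0 / 2)
    print(mask.shape)  # (2, 1, 4, 4)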
View File

@@ -61,11 +61,13 @@ class TestGenerate(unittest.TestCase):
             ],
             max_tokens=5,
             prefill_step_size=4,
-            sampler=make_sampler(temp=1., min_p=0.5),
-            logits_processors=make_logits_processors(logit_bias, repetition_penalty=2.0),
+            sampler=make_sampler(temp=1.0, min_p=0.5),
+            logits_processors=make_logits_processors(
+                logit_bias, repetition_penalty=2.0
+            ),
             verbose=False,
         )
-        self.assertEqual(texts, ['!', '!'])
+        self.assertEqual(texts, ["!", "!"])


 if __name__ == "__main__":

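The test exercises the sampling helpers from `mlx_lm.sample_utils`: `make_sampler(temp=1.0, min_p=0.5)` builds the token sampler and `make_logits_processors(logit_bias, repetition_penalty=2.0)` builds the per-step logits transforms. A standalone sketch of constructing them; the `logit_bias` values here are made up:

    from mlx_lm.sample_utils import make_logits_processors, make_sampler

    # Bias token id 0 upward; values are illustrative only.
    logit_bias = {0: 20.0}

    sampler = make_sampler(temp=1.0, min_p=0.5)
    logits_processors = make_logits_processors(logit_bias, repetition_penalty=2.0)

    # Both are then passed through to generate_step via batch_generate(...).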
View File

@@ -208,7 +208,9 @@ class TestPromptCache(unittest.TestCase):
         prompt_cache = make_prompt_cache(model)

         # Generate one token so we process the full prompt
-        last_tok, _ = next(generate_step(prompt[None], model, prompt_cache=prompt_cache))
+        last_tok, _ = next(
+            generate_step(prompt[None], model, prompt_cache=prompt_cache)
+        )

         # Generate two more tokens
         results = zip(