This commit is contained in:
L Lllvvuu 2024-12-27 16:21:20 -08:00
parent 089480878f
commit 2541f13907
No known key found for this signature in database
GPG Key ID: CFAD5A25056DDD0F
4 changed files with 19 additions and 11 deletions

View File

@ -42,7 +42,7 @@ def create_causal_mask(
mask = mask | (rinds >= lengths) mask = mask | (rinds >= lengths)
# HACK: sometimes see NaN logprobs if no divide by 2 here # HACK: sometimes see NaN logprobs if no divide by 2 here
# return mask * (mx.finfo(dtype).min / 2) # return mask * (mx.finfo(dtype).min / 2)
return mask.astype(dtype) * (-65504. / 2) return mask.astype(dtype) * (-65504.0 / 2)
def create_attention_mask(h: mx.array, cache: Optional[Any] = None): def create_attention_mask(h: mx.array, cache: Optional[Any] = None):

View File

@ -329,9 +329,11 @@ def generate_step(
model( model(
y[:, :prefill_step_size], y[:, :prefill_step_size],
cache=prompt_cache, cache=prompt_cache,
mask=mask[:, :, :prefill_step_size, : offset + prefill_step_size] mask=(
if mask is not None mask[:, :, :prefill_step_size, : offset + prefill_step_size]
else None, if mask is not None
else None
),
) )
maybe_quantize_kv_cache( maybe_quantize_kv_cache(
prompt_cache, quantized_kv_start, kv_group_size, kv_bits prompt_cache, quantized_kv_start, kv_group_size, kv_bits
@ -509,7 +511,7 @@ def batch_generate(
kwargs: The remaining options get passed to :func:`generate_step`. kwargs: The remaining options get passed to :func:`generate_step`.
See :func:`generate_step` for more details. See :func:`generate_step` for more details.
""" """
if 'prompt_cache' in kwargs: if "prompt_cache" in kwargs:
# TODO: Handle `prompt_cache` and `prompt` both left-padded, so that # TODO: Handle `prompt_cache` and `prompt` both left-padded, so that
# we have <pad>text<pad>text. Should involve taking `prompt_cache_lens` # we have <pad>text<pad>text. Should involve taking `prompt_cache_lens`
# to extend `mask` below, and handling position_ids (see TODO below) # to extend `mask` below, and handling position_ids (see TODO below)
@ -527,13 +529,15 @@ def batch_generate(
input_ids, token_mask = mx.array(res["input_ids"]), mx.array(res["attention_mask"]) input_ids, token_mask = mx.array(res["input_ids"]), mx.array(res["attention_mask"])
dtype = mx.float32 dtype = mx.float32
for module in model.modules(): for module in model.modules():
if isinstance(module, nn.QuantizedEmbedding) or isinstance(module, nn.Embedding): if isinstance(module, nn.QuantizedEmbedding) or isinstance(
module, nn.Embedding
):
dtype = module(mx.zeros(1, dtype=input_ids.dtype)).dtype dtype = module(mx.zeros(1, dtype=input_ids.dtype)).dtype
break break
causal_mask = create_causal_mask(token_mask.shape[-1], dtype=dtype) causal_mask = create_causal_mask(token_mask.shape[-1], dtype=dtype)
# HACK: sometimes see NaN logprobs if no divide by 2 here # HACK: sometimes see NaN logprobs if no divide by 2 here
# mask = mx.where(token_mask[:, None, None, :], causal_mask, mx.finfo(dtype).min / 2) # mask = mx.where(token_mask[:, None, None, :], causal_mask, mx.finfo(dtype).min / 2)
mask = mx.where(token_mask[:, None, None, :], causal_mask, -65504. / 2) mask = mx.where(token_mask[:, None, None, :], causal_mask, -65504.0 / 2)
output_toks = [] output_toks = []
prompt_time = None prompt_time = None

View File

@ -61,11 +61,13 @@ class TestGenerate(unittest.TestCase):
], ],
max_tokens=5, max_tokens=5,
prefill_step_size=4, prefill_step_size=4,
sampler=make_sampler(temp=1., min_p=0.5), sampler=make_sampler(temp=1.0, min_p=0.5),
logits_processors=make_logits_processors(logit_bias, repetition_penalty=2.0), logits_processors=make_logits_processors(
logit_bias, repetition_penalty=2.0
),
verbose=False, verbose=False,
) )
self.assertEqual(texts, ['!', '!']) self.assertEqual(texts, ["!", "!"])
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -208,7 +208,9 @@ class TestPromptCache(unittest.TestCase):
prompt_cache = make_prompt_cache(model) prompt_cache = make_prompt_cache(model)
# Generate one token so we process the full prompt # Generate one token so we process the full prompt
last_tok, _ = next(generate_step(prompt[None], model, prompt_cache=prompt_cache)) last_tok, _ = next(
generate_step(prompt[None], model, prompt_cache=prompt_cache)
)
# Generate two more tokens # Generate two more tokens
results = zip( results = zip(