mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 07:30:06 +08:00
format
This commit is contained in:
parent
089480878f
commit
2541f13907
@ -42,7 +42,7 @@ def create_causal_mask(
|
||||
mask = mask | (rinds >= lengths)
|
||||
# HACK: sometimes see NaN logprobs if no divide by 2 here
|
||||
# return mask * (mx.finfo(dtype).min / 2)
|
||||
return mask.astype(dtype) * (-65504. / 2)
|
||||
return mask.astype(dtype) * (-65504.0 / 2)
|
||||
|
||||
|
||||
def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
|
||||
|
@ -329,9 +329,11 @@ def generate_step(
|
||||
model(
|
||||
y[:, :prefill_step_size],
|
||||
cache=prompt_cache,
|
||||
mask=mask[:, :, :prefill_step_size, : offset + prefill_step_size]
|
||||
if mask is not None
|
||||
else None,
|
||||
mask=(
|
||||
mask[:, :, :prefill_step_size, : offset + prefill_step_size]
|
||||
if mask is not None
|
||||
else None
|
||||
),
|
||||
)
|
||||
maybe_quantize_kv_cache(
|
||||
prompt_cache, quantized_kv_start, kv_group_size, kv_bits
|
||||
@ -509,7 +511,7 @@ def batch_generate(
|
||||
kwargs: The remaining options get passed to :func:`generate_step`.
|
||||
See :func:`generate_step` for more details.
|
||||
"""
|
||||
if 'prompt_cache' in kwargs:
|
||||
if "prompt_cache" in kwargs:
|
||||
# TODO: Handle `prompt_cache` and `prompt` both left-padded, so that
|
||||
# we have <pad>text<pad>text. Should involve taking `prompt_cache_lens`
|
||||
# to extend `mask` below, and handling position_ids (see TODO below)
|
||||
@ -527,13 +529,15 @@ def batch_generate(
|
||||
input_ids, token_mask = mx.array(res["input_ids"]), mx.array(res["attention_mask"])
|
||||
dtype = mx.float32
|
||||
for module in model.modules():
|
||||
if isinstance(module, nn.QuantizedEmbedding) or isinstance(module, nn.Embedding):
|
||||
if isinstance(module, nn.QuantizedEmbedding) or isinstance(
|
||||
module, nn.Embedding
|
||||
):
|
||||
dtype = module(mx.zeros(1, dtype=input_ids.dtype)).dtype
|
||||
break
|
||||
causal_mask = create_causal_mask(token_mask.shape[-1], dtype=dtype)
|
||||
# HACK: sometimes see NaN logprobs if no divide by 2 here
|
||||
# mask = mx.where(token_mask[:, None, None, :], causal_mask, mx.finfo(dtype).min / 2)
|
||||
mask = mx.where(token_mask[:, None, None, :], causal_mask, -65504. / 2)
|
||||
mask = mx.where(token_mask[:, None, None, :], causal_mask, -65504.0 / 2)
|
||||
|
||||
output_toks = []
|
||||
prompt_time = None
|
||||
|
@ -61,11 +61,13 @@ class TestGenerate(unittest.TestCase):
|
||||
],
|
||||
max_tokens=5,
|
||||
prefill_step_size=4,
|
||||
sampler=make_sampler(temp=1., min_p=0.5),
|
||||
logits_processors=make_logits_processors(logit_bias, repetition_penalty=2.0),
|
||||
sampler=make_sampler(temp=1.0, min_p=0.5),
|
||||
logits_processors=make_logits_processors(
|
||||
logit_bias, repetition_penalty=2.0
|
||||
),
|
||||
verbose=False,
|
||||
)
|
||||
self.assertEqual(texts, ['!', '!'])
|
||||
self.assertEqual(texts, ["!", "!"])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -208,7 +208,9 @@ class TestPromptCache(unittest.TestCase):
|
||||
prompt_cache = make_prompt_cache(model)
|
||||
|
||||
# Generate one token so we process the full prompt
|
||||
last_tok, _ = next(generate_step(prompt[None], model, prompt_cache=prompt_cache))
|
||||
last_tok, _ = next(
|
||||
generate_step(prompt[None], model, prompt_cache=prompt_cache)
|
||||
)
|
||||
|
||||
# Generate two more tokens
|
||||
results = zip(
|
||||
|
Loading…
Reference in New Issue
Block a user