mirror of https://github.com/ml-explore/mlx-examples.git

commit 2541f13907 (parent 089480878f)

    format
@@ -42,7 +42,7 @@ def create_causal_mask(
     mask = mask | (rinds >= lengths)
     # HACK: sometimes see NaN logprobs if no divide by 2 here
     # return mask * (mx.finfo(dtype).min / 2)
-    return mask.astype(dtype) * (-65504. / 2)
+    return mask.astype(dtype) * (-65504.0 / 2)


 def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
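For context on the magic number: -65504.0 is the most negative finite float16 value (what mx.finfo(mx.float16).min reports), and the halving matches the HACK comment above, since a full-minimum fill can overflow to -inf once masked scores are summed, which then surfaces as NaN logprobs after the softmax. A minimal self-contained sketch of the idea; the helper name and shapes are illustrative, not this repo's exact code:

    import mlx.core as mx

    def sketch_causal_mask(N: int, dtype=mx.float16) -> mx.array:
        # True above the diagonal: positions a query may not attend to.
        inds = mx.arange(N)
        mask = inds[:, None] < inds[None]
        # Fill with half the most negative finite value; using the full
        # minimum risks -inf after additions and NaN logprobs downstream.
        return mask.astype(dtype) * (mx.finfo(dtype).min / 2)

    scores = mx.zeros((4, 4), dtype=mx.float16) + sketch_causal_mask(4)
    print(mx.softmax(scores, axis=-1))  # lower-triangular attention pattern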
@@ -329,9 +329,11 @@ def generate_step(
         model(
             y[:, :prefill_step_size],
             cache=prompt_cache,
-            mask=mask[:, :, :prefill_step_size, : offset + prefill_step_size]
-            if mask is not None
-            else None,
+            mask=(
+                mask[:, :, :prefill_step_size, : offset + prefill_step_size]
+                if mask is not None
+                else None
+            ),
         )
         maybe_quantize_kv_cache(
             prompt_cache, quantized_kv_start, kv_group_size, kv_bits
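The parenthesized conditional is just black's preferred layout; the slice is the interesting part: each prefill chunk needs mask rows for its own prefill_step_size queries and mask columns for every key already in the cache plus the chunk itself. A hedged, self-contained sketch of that bookkeeping (fake_model, the shapes, and the trimming of y and mask are illustrative assumptions, not this repo's exact loop):

    import mlx.core as mx

    def fake_model(chunk, cache=None, mask=None):
        pass  # a real model would run attention and extend the KV cache here

    B, L, prefill_step_size = 1, 10, 4
    y = mx.zeros((B, L), dtype=mx.int32)
    mask = mx.zeros((B, 1, L, L))  # (batch, heads, queries, keys)
    offset = 0
    while y.shape[1] > prefill_step_size:
        fake_model(
            y[:, :prefill_step_size],
            # Rows: this chunk's queries. Columns: all keys processed so
            # far (offset) plus this chunk's own keys.
            mask=mask[:, :, :prefill_step_size, : offset + prefill_step_size],
        )
        offset += prefill_step_size
        y = y[:, prefill_step_size:]
        mask = mask[:, :, prefill_step_size:, :]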
@@ -509,7 +511,7 @@ def batch_generate(
         kwargs: The remaining options get passed to :func:`generate_step`.
           See :func:`generate_step` for more details.
     """
-    if 'prompt_cache' in kwargs:
+    if "prompt_cache" in kwargs:
         # TODO: Handle `prompt_cache` and `prompt` both left-padded, so that
         # we have <pad>text<pad>text. Should involve taking `prompt_cache_lens`
         # to extend `mask` below, and handling position_ids (see TODO below)
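The TODO's <pad>text layout is the standard left-padding setup for batched generation. Roughly how the res["input_ids"] / res["attention_mask"] pair consumed in the next hunk gets produced, assuming a Hugging Face-style tokenizer (the tokenizer name is a placeholder):

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("gpt2")  # placeholder tokenizer
    tok.pad_token = tok.eos_token  # gpt2 ships without a pad token
    tok.padding_side = "left"      # pads go before the text, as in the TODO
    res = tok(["short", "a much longer prompt"], padding=True)
    print(res["input_ids"])       # rows left-padded to equal length
    print(res["attention_mask"])  # 0 over pads, 1 over real tokens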
@@ -527,13 +529,15 @@ def batch_generate(
     input_ids, token_mask = mx.array(res["input_ids"]), mx.array(res["attention_mask"])
     dtype = mx.float32
     for module in model.modules():
-        if isinstance(module, nn.QuantizedEmbedding) or isinstance(module, nn.Embedding):
+        if isinstance(module, nn.QuantizedEmbedding) or isinstance(
+            module, nn.Embedding
+        ):
             dtype = module(mx.zeros(1, dtype=input_ids.dtype)).dtype
             break
     causal_mask = create_causal_mask(token_mask.shape[-1], dtype=dtype)
     # HACK: sometimes see NaN logprobs if no divide by 2 here
     # mask = mx.where(token_mask[:, None, None, :], causal_mask, mx.finfo(dtype).min / 2)
-    mask = mx.where(token_mask[:, None, None, :], causal_mask, -65504. / 2)
+    mask = mx.where(token_mask[:, None, None, :], causal_mask, -65504.0 / 2)

     output_toks = []
     prompt_time = None
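The mx.where line is where the padding mask and the causal mask merge: padded key columns get the large negative fill for every query, real columns keep the causal pattern. A small self-contained demonstration (values and shapes are made up; only the broadcasting trick is the point):

    import mlx.core as mx

    # Two left-padded prompts: 0 marks padding, 1 marks real tokens.
    token_mask = mx.array([[0, 0, 1, 1], [1, 1, 1, 1]])

    N = token_mask.shape[-1]
    inds = mx.arange(N)
    causal = (inds[None] > inds[:, None]).astype(mx.float16) * (-65504.0 / 2)

    # (batch, 1, 1, keys) against (queries, keys) broadcasts to
    # (batch, 1, queries, keys): padded columns are masked out everywhere.
    mask = mx.where(token_mask[:, None, None, :], causal, -65504.0 / 2)
    print(mask.shape)  # (2, 1, 4, 4)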
@@ -61,11 +61,13 @@ class TestGenerate(unittest.TestCase):
             ],
             max_tokens=5,
             prefill_step_size=4,
-            sampler=make_sampler(temp=1., min_p=0.5),
-            logits_processors=make_logits_processors(logit_bias, repetition_penalty=2.0),
+            sampler=make_sampler(temp=1.0, min_p=0.5),
+            logits_processors=make_logits_processors(
+                logit_bias, repetition_penalty=2.0
+            ),
             verbose=False,
         )
-        self.assertEqual(texts, ['!', '!'])
+        self.assertEqual(texts, ["!", "!"])


 if __name__ == "__main__":
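For readers who have not used these helpers: make_sampler and make_logits_processors live in mlx_lm.sample_utils and plug into the generation entry points. A hedged standalone example (the model repo, prompt, and logit-bias dict are placeholders, and the exact kwargs may vary across mlx_lm versions):

    from mlx_lm import load, generate
    from mlx_lm.sample_utils import make_logits_processors, make_sampler

    model, tokenizer = load("mlx-community/Qwen2.5-0.5B-Instruct-4bit")  # placeholder
    text = generate(
        model,
        tokenizer,
        prompt="hello",
        max_tokens=5,
        sampler=make_sampler(temp=1.0, min_p=0.5),
        # Hypothetical bias dict; the test passes its own `logit_bias`.
        logits_processors=make_logits_processors({0: 20.0}, repetition_penalty=2.0),
    )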
@@ -208,7 +208,9 @@ class TestPromptCache(unittest.TestCase):
         prompt_cache = make_prompt_cache(model)

         # Generate one token so we process the full prompt
-        last_tok, _ = next(generate_step(prompt[None], model, prompt_cache=prompt_cache))
+        last_tok, _ = next(
+            generate_step(prompt[None], model, prompt_cache=prompt_cache)
+        )

         # Generate two more tokens
         results = zip(
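The pattern under test, prefill the whole prompt into a cache with one step, then keep decoding from the last token, looks roughly like this outside the harness. A sketch assuming mlx_lm's make_prompt_cache and generate_step; the model repo and prompt are placeholders, and the batched continuation shapes follow this fork's API rather than stock mlx_lm:

    import itertools

    import mlx.core as mx
    from mlx_lm import load
    from mlx_lm.models.cache import make_prompt_cache
    from mlx_lm.utils import generate_step

    model, tokenizer = load("mlx-community/Qwen2.5-0.5B-Instruct-4bit")  # placeholder
    prompt = mx.array(tokenizer.encode("The quick brown fox"))
    prompt_cache = make_prompt_cache(model)

    # Prefill: the first step pushes the full prompt through the cache.
    last_tok, _ = next(generate_step(prompt[None], model, prompt_cache=prompt_cache))

    # Decode: later steps feed only the newest token; the cache holds the rest.
    for tok, _ in itertools.islice(
        generate_step(last_tok[None], model, prompt_cache=prompt_cache), 2
    ):
        last_tok = tok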