From 2541f13907c4908335baf05bcdda5b2a04f94105 Mon Sep 17 00:00:00 2001
From: L Lllvvuu
Date: Fri, 27 Dec 2024 16:21:20 -0800
Subject: [PATCH] format

---
 llms/mlx_lm/models/base.py      |  2 +-
 llms/mlx_lm/utils.py            | 16 ++++++++++------
 llms/tests/test_generate.py     |  8 +++++---
 llms/tests/test_prompt_cache.py |  4 +++-
 4 files changed, 19 insertions(+), 11 deletions(-)

diff --git a/llms/mlx_lm/models/base.py b/llms/mlx_lm/models/base.py
index 3d402aa1..568b85ab 100644
--- a/llms/mlx_lm/models/base.py
+++ b/llms/mlx_lm/models/base.py
@@ -42,7 +42,7 @@ def create_causal_mask(
     mask = mask | (rinds >= lengths)
     # HACK: sometimes see NaN logprobs if no divide by 2 here
     # return mask * (mx.finfo(dtype).min / 2)
-    return mask.astype(dtype) * (-65504. / 2)
+    return mask.astype(dtype) * (-65504.0 / 2)
 
 
 def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index 9e3d3778..714609f5 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -329,9 +329,11 @@ def generate_step(
             model(
                 y[:, :prefill_step_size],
                 cache=prompt_cache,
-                mask=mask[:, :, :prefill_step_size, : offset + prefill_step_size]
-                if mask is not None
-                else None,
+                mask=(
+                    mask[:, :, :prefill_step_size, : offset + prefill_step_size]
+                    if mask is not None
+                    else None
+                ),
             )
             maybe_quantize_kv_cache(
                 prompt_cache, quantized_kv_start, kv_group_size, kv_bits
@@ -509,7 +511,7 @@ def batch_generate(
         kwargs: The remaining options get passed to :func:`generate_step`.
           See :func:`generate_step` for more details.
     """
-    if 'prompt_cache' in kwargs:
+    if "prompt_cache" in kwargs:
         # TODO: Handle `prompt_cache` and `prompt` both left-padded, so that
         # we have texttext. Should involve taking `prompt_cache_lens`
         # to extend `mask` below, and handling position_ids (see TODO below)
@@ -527,13 +529,15 @@ def batch_generate(
     input_ids, token_mask = mx.array(res["input_ids"]), mx.array(res["attention_mask"])
     dtype = mx.float32
     for module in model.modules():
-        if isinstance(module, nn.QuantizedEmbedding) or isinstance(module, nn.Embedding):
+        if isinstance(module, nn.QuantizedEmbedding) or isinstance(
+            module, nn.Embedding
+        ):
             dtype = module(mx.zeros(1, dtype=input_ids.dtype)).dtype
             break
     causal_mask = create_causal_mask(token_mask.shape[-1], dtype=dtype)
     # HACK: sometimes see NaN logprobs if no divide by 2 here
     # mask = mx.where(token_mask[:, None, None, :], causal_mask, mx.finfo(dtype).min / 2)
-    mask = mx.where(token_mask[:, None, None, :], causal_mask, -65504. / 2)
+    mask = mx.where(token_mask[:, None, None, :], causal_mask, -65504.0 / 2)
 
     output_toks = []
     prompt_time = None
diff --git a/llms/tests/test_generate.py b/llms/tests/test_generate.py
index 41f9704f..bcdb0d9f 100644
--- a/llms/tests/test_generate.py
+++ b/llms/tests/test_generate.py
@@ -61,11 +61,13 @@ class TestGenerate(unittest.TestCase):
             ],
             max_tokens=5,
             prefill_step_size=4,
-            sampler=make_sampler(temp=1., min_p=0.5),
-            logits_processors=make_logits_processors(logit_bias, repetition_penalty=2.0),
+            sampler=make_sampler(temp=1.0, min_p=0.5),
+            logits_processors=make_logits_processors(
+                logit_bias, repetition_penalty=2.0
+            ),
             verbose=False,
         )
-        self.assertEqual(texts, ['!', '!'])
+        self.assertEqual(texts, ["!", "!"])
 
 
 if __name__ == "__main__":
diff --git a/llms/tests/test_prompt_cache.py b/llms/tests/test_prompt_cache.py
index 6acab5a7..5fcc6834 100644
--- a/llms/tests/test_prompt_cache.py
+++ b/llms/tests/test_prompt_cache.py
@@ -208,7 +208,9 @@ class TestPromptCache(unittest.TestCase):
         prompt_cache = make_prompt_cache(model)
 
         # Generate one token so we process the full prompt
-        last_tok, _ = next(generate_step(prompt[None], model, prompt_cache=prompt_cache))
+        last_tok, _ = next(
+            generate_step(prompt[None], model, prompt_cache=prompt_cache)
+        )
 
         # Generate two more tokens
         results = zip(