From fdd16caf7a021360ded434fa65060e4691d79441 Mon Sep 17 00:00:00 2001 From: L Lllvvuu Date: Fri, 27 Dec 2024 15:26:43 -0800 Subject: [PATCH] mask dtype --- llms/mlx_lm/models/base.py | 5 ++++- llms/mlx_lm/utils.py | 11 +++++++++-- llms/tests/test_generate.py | 2 +- 3 files changed, 14 insertions(+), 4 deletions(-) diff --git a/llms/mlx_lm/models/base.py b/llms/mlx_lm/models/base.py index ad7a4a65..3d402aa1 100644 --- a/llms/mlx_lm/models/base.py +++ b/llms/mlx_lm/models/base.py @@ -28,6 +28,7 @@ def create_causal_mask( offset: int = 0, window_size: Optional[int] = None, lengths: Optional[mx.array] = None, + dtype: mx.Dtype = mx.float32, ): rinds = mx.arange(offset + N) linds = mx.arange(offset, offset + N) if offset else rinds @@ -39,7 +40,9 @@ def create_causal_mask( if lengths is not None: lengths = lengths[:, None, None, None] mask = mask | (rinds >= lengths) - return mask * -1e9 + # HACK: sometimes see NaN logprobs if no divide by 2 here + # return mask * (mx.finfo(dtype).min / 2) + return mask.astype(dtype) * (-65504. / 2) def create_attention_mask(h: mx.array, cache: Optional[Any] = None): diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index f28fd830..185a0698 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -524,8 +524,15 @@ def batch_generate( tokenizer._tokenizer.pad_token_id = tokenizer.eos_token_id res = tokenizer._tokenizer(prompts, padding=True) input_ids, token_mask = mx.array(res["input_ids"]), mx.array(res["attention_mask"]) - causal_mask = create_causal_mask(token_mask.shape[-1]) - mask = mx.where(token_mask[:, None, None, :], causal_mask, -1e9) + dtype = None + for module in model.modules(): + if isinstance(module, nn.QuantizedEmbedding) or isinstance(module, nn.Embedding): + dtype = module(mx.zeros(1, dtype=input_ids.dtype)).dtype + break + causal_mask = create_causal_mask(token_mask.shape[-1], dtype=dtype) + # HACK: sometimes see NaN logprobs if no divide by 2 here + # mask = mx.where(token_mask[:, None, None, :], causal_mask, mx.finfo(dtype).min / 2) + mask = mx.where(token_mask[:, None, None, :], causal_mask, -65504. / 2) output_toks = [] prompt_time = None diff --git a/llms/tests/test_generate.py b/llms/tests/test_generate.py index 14fa75e9..41f9704f 100644 --- a/llms/tests/test_generate.py +++ b/llms/tests/test_generate.py @@ -61,7 +61,7 @@ class TestGenerate(unittest.TestCase): ], max_tokens=5, prefill_step_size=4, - sampler=make_sampler(temp=1.0, min_p=0.1), + sampler=make_sampler(temp=1., min_p=0.5), logits_processors=make_logits_processors(logit_bias, repetition_penalty=2.0), verbose=False, )