From 089480878fabbdd2215816bf3dc303602772b48d Mon Sep 17 00:00:00 2001 From: L Lllvvuu Date: Fri, 27 Dec 2024 16:01:52 -0800 Subject: [PATCH] dtype fix --- llms/mlx_lm/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 0925e469..9e3d3778 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -525,7 +525,7 @@ def batch_generate( tokenizer._tokenizer.pad_token_id = tokenizer.eos_token_id res = tokenizer._tokenizer(prompts, padding=True) input_ids, token_mask = mx.array(res["input_ids"]), mx.array(res["attention_mask"]) - dtype = None + dtype = mx.float32 for module in model.modules(): if isinstance(module, nn.QuantizedEmbedding) or isinstance(module, nn.Embedding): dtype = module(mx.zeros(1, dtype=input_ids.dtype)).dtype