dtype fix

This commit is contained in:
L Lllvvuu 2024-12-27 16:01:52 -08:00
parent 30e98c85c1
commit 089480878f
No known key found for this signature in database
GPG Key ID: CFAD5A25056DDD0F

View File

@ -525,7 +525,7 @@ def batch_generate(
tokenizer._tokenizer.pad_token_id = tokenizer.eos_token_id
res = tokenizer._tokenizer(prompts, padding=True)
input_ids, token_mask = mx.array(res["input_ids"]), mx.array(res["attention_mask"])
dtype = None
dtype = mx.float32
for module in model.modules():
if isinstance(module, nn.QuantizedEmbedding) or isinstance(module, nn.Embedding):
dtype = module(mx.zeros(1, dtype=input_ids.dtype)).dtype