mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 02:53:41 +08:00
dtype fix
This commit is contained in:
parent
30e98c85c1
commit
089480878f
@@ -525,7 +525,7 @@ def batch_generate(
     tokenizer._tokenizer.pad_token_id = tokenizer.eos_token_id
     res = tokenizer._tokenizer(prompts, padding=True)
     input_ids, token_mask = mx.array(res["input_ids"]), mx.array(res["attention_mask"])
-    dtype = None
+    dtype = mx.float32
     for module in model.modules():
         if isinstance(module, nn.QuantizedEmbedding) or isinstance(module, nn.Embedding):
             dtype = module(mx.zeros(1, dtype=input_ids.dtype)).dtype
Loading…
Reference in New Issue
Block a user