mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 18:26:37 +08:00
mask dtype
This commit is contained in:
parent
465eb79fff
commit
fdd16caf7a
@ -28,6 +28,7 @@ def create_causal_mask(
|
|||||||
offset: int = 0,
|
offset: int = 0,
|
||||||
window_size: Optional[int] = None,
|
window_size: Optional[int] = None,
|
||||||
lengths: Optional[mx.array] = None,
|
lengths: Optional[mx.array] = None,
|
||||||
|
dtype: mx.Dtype = mx.float32,
|
||||||
):
|
):
|
||||||
rinds = mx.arange(offset + N)
|
rinds = mx.arange(offset + N)
|
||||||
linds = mx.arange(offset, offset + N) if offset else rinds
|
linds = mx.arange(offset, offset + N) if offset else rinds
|
||||||
@ -39,7 +40,9 @@ def create_causal_mask(
|
|||||||
if lengths is not None:
|
if lengths is not None:
|
||||||
lengths = lengths[:, None, None, None]
|
lengths = lengths[:, None, None, None]
|
||||||
mask = mask | (rinds >= lengths)
|
mask = mask | (rinds >= lengths)
|
||||||
return mask * -1e9
|
# HACK: we sometimes see NaN logprobs if we do not divide by 2 here
|
||||||
|
# return mask * (mx.finfo(dtype).min / 2)
|
||||||
|
return mask.astype(dtype) * (-65504. / 2)
|
||||||
|
|
||||||
|
|
||||||
def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
|
def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
|
||||||
|
@ -524,8 +524,15 @@ def batch_generate(
|
|||||||
tokenizer._tokenizer.pad_token_id = tokenizer.eos_token_id
|
tokenizer._tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||||
res = tokenizer._tokenizer(prompts, padding=True)
|
res = tokenizer._tokenizer(prompts, padding=True)
|
||||||
input_ids, token_mask = mx.array(res["input_ids"]), mx.array(res["attention_mask"])
|
input_ids, token_mask = mx.array(res["input_ids"]), mx.array(res["attention_mask"])
|
||||||
causal_mask = create_causal_mask(token_mask.shape[-1])
|
dtype = None
|
||||||
mask = mx.where(token_mask[:, None, None, :], causal_mask, -1e9)
|
for module in model.modules():
|
||||||
|
if isinstance(module, nn.QuantizedEmbedding) or isinstance(module, nn.Embedding):
|
||||||
|
dtype = module(mx.zeros(1, dtype=input_ids.dtype)).dtype
|
||||||
|
break
|
||||||
|
causal_mask = create_causal_mask(token_mask.shape[-1], dtype=dtype)
|
||||||
|
# HACK: we sometimes see NaN logprobs if we do not divide by 2 here
|
||||||
|
# mask = mx.where(token_mask[:, None, None, :], causal_mask, mx.finfo(dtype).min / 2)
|
||||||
|
mask = mx.where(token_mask[:, None, None, :], causal_mask, -65504. / 2)
|
||||||
|
|
||||||
output_toks = []
|
output_toks = []
|
||||||
prompt_time = None
|
prompt_time = None
|
||||||
|
@ -61,7 +61,7 @@ class TestGenerate(unittest.TestCase):
|
|||||||
],
|
],
|
||||||
max_tokens=5,
|
max_tokens=5,
|
||||||
prefill_step_size=4,
|
prefill_step_size=4,
|
||||||
sampler=make_sampler(temp=1.0, min_p=0.1),
|
sampler=make_sampler(temp=1., min_p=0.5),
|
||||||
logits_processors=make_logits_processors(logit_bias, repetition_penalty=2.0),
|
logits_processors=make_logits_processors(logit_bias, repetition_penalty=2.0),
|
||||||
verbose=False,
|
verbose=False,
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user