mask dtype

This commit is contained in:
L Lllvvuu
2024-12-27 15:26:43 -08:00
parent 465eb79fff
commit fdd16caf7a
3 changed files with 14 additions and 4 deletions

View File

@@ -28,6 +28,7 @@ def create_causal_mask(
offset: int = 0,
window_size: Optional[int] = None,
lengths: Optional[mx.array] = None,
dtype: mx.Dtype = mx.float32,
):
rinds = mx.arange(offset + N)
linds = mx.arange(offset, offset + N) if offset else rinds
@@ -39,7 +40,9 @@ def create_causal_mask(
if lengths is not None:
lengths = lengths[:, None, None, None]
mask = mask | (rinds >= lengths)
return mask * -1e9
# HACK: sometimes see NaN logprobs if no divide by 2 here
# return mask * (mx.finfo(dtype).min / 2)
return mask.astype(dtype) * (-65504. / 2)
def create_attention_mask(h: mx.array, cache: Optional[Any] = None):

View File

@@ -524,8 +524,15 @@ def batch_generate(
tokenizer._tokenizer.pad_token_id = tokenizer.eos_token_id
res = tokenizer._tokenizer(prompts, padding=True)
input_ids, token_mask = mx.array(res["input_ids"]), mx.array(res["attention_mask"])
causal_mask = create_causal_mask(token_mask.shape[-1])
mask = mx.where(token_mask[:, None, None, :], causal_mask, -1e9)
dtype = None
for module in model.modules():
if isinstance(module, nn.QuantizedEmbedding) or isinstance(module, nn.Embedding):
dtype = module(mx.zeros(1, dtype=input_ids.dtype)).dtype
break
causal_mask = create_causal_mask(token_mask.shape[-1], dtype=dtype)
# HACK: sometimes see NaN logprobs if no divide by 2 here
# mask = mx.where(token_mask[:, None, None, :], causal_mask, mx.finfo(dtype).min / 2)
mask = mx.where(token_mask[:, None, None, :], causal_mask, -65504. / 2)
output_toks = []
prompt_time = None