diff --git a/python/mlx/nn/layers/positional_encoding.py b/python/mlx/nn/layers/positional_encoding.py index 39dbdff51..76a9173f9 100644 --- a/python/mlx/nn/layers/positional_encoding.py +++ b/python/mlx/nn/layers/positional_encoding.py @@ -162,8 +162,8 @@ class ALiBi(Module): mx.expand_dims(x1[:, None] - x2[None, :], axis=(0, 1)) ) alibi_slope = ALiBi.create_alibi_slope(num_heads) - alibi_matrix = (distance_matrix * alibi_slope).astype(dtype) - return alibi_matrix + alibi_mask = (distance_matrix * alibi_slope).astype(dtype) + return alibi_mask @staticmethod def create_alibi_slope(num_heads):