From 0aa65c7a6b9d9f043358b6cc3faa2e9336a9769e Mon Sep 17 00:00:00 2001
From: Hazem Essam
Date: Fri, 22 Dec 2023 00:36:38 +0200
Subject: [PATCH] Added ALiBi implementation (#232)

---
 python/mlx/nn/layers/__init__.py            |  2 +-
 python/mlx/nn/layers/positional_encoding.py | 37 +++++++++++++++++++++
 python/tests/test_nn.py                     | 12 +++++++
 3 files changed, 50 insertions(+), 1 deletion(-)

diff --git a/python/mlx/nn/layers/__init__.py b/python/mlx/nn/layers/__init__.py
index aa22e495b..3f03064bf 100644
--- a/python/mlx/nn/layers/__init__.py
+++ b/python/mlx/nn/layers/__init__.py
@@ -37,7 +37,7 @@ from mlx.nn.layers.dropout import Dropout
 from mlx.nn.layers.embedding import Embedding
 from mlx.nn.layers.linear import Linear
 from mlx.nn.layers.normalization import GroupNorm, LayerNorm, RMSNorm
-from mlx.nn.layers.positional_encoding import RoPE, SinusoidalPositionalEncoding
+from mlx.nn.layers.positional_encoding import ALiBi, RoPE, SinusoidalPositionalEncoding
 from mlx.nn.layers.quantized import QuantizedLinear
 from mlx.nn.layers.transformer import (
     MultiHeadAttention,
diff --git a/python/mlx/nn/layers/positional_encoding.py b/python/mlx/nn/layers/positional_encoding.py
index b121a5436..db436f407 100644
--- a/python/mlx/nn/layers/positional_encoding.py
+++ b/python/mlx/nn/layers/positional_encoding.py
@@ -142,3 +142,40 @@ class SinusoidalPositionalEncoding(Module):
         y = y * self.scale
 
         return y
+
+
+class ALiBi(Module):
+    @staticmethod
+    def create_alibi_matrix(
+        q_sequence_length: int,
+        k_sequence_length: int,
+        num_heads: int,
+        offset: int,
+        dtype=mx.float32,
+    ):
+        x1 = mx.arange(offset, q_sequence_length)
+        x2 = mx.arange(0, k_sequence_length)
+        distance_matrix = -mx.abs(
+            mx.expand_dims(x1[:, None] - x2[None, :], axis=(0, 1))
+        )
+        alibi_slope = ALiBi.create_alibi_slope(num_heads=num_heads)
+        alibi_mask = (distance_matrix * alibi_slope).astype(dtype)
+        return alibi_mask
+
+    @staticmethod
+    def create_alibi_slope(num_heads):
+        x = (2**8) ** (1 / num_heads)
+        out = mx.power(x, -mx.arange(1, num_heads + 1))
+        return mx.expand_dims(out, axis=(-1, -2))
+
+    def __call__(self, attention_scores, offset=0, mask=None):
+        alibi_mask = ALiBi.create_alibi_matrix(
+            q_sequence_length=attention_scores.shape[-2] + offset,
+            k_sequence_length=attention_scores.shape[-1],
+            num_heads=attention_scores.shape[1],
+            offset=offset,
+            dtype=attention_scores.dtype,
+        )
+        if mask is not None:
+            alibi_mask = alibi_mask + mask
+        return attention_scores + alibi_mask
diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py
index 17ec5175c..2c27d3587 100644
--- a/python/tests/test_nn.py
+++ b/python/tests/test_nn.py
@@ -570,6 +570,18 @@ class TestNN(mlx_tests.MLXTestCase):
             y = rope(x.astype(mx.float16))
             self.assertTrue(y.dtype, mx.float16)
 
+    def test_alibi(self):
+        # ALiBi takes no constructor arguments; it reads the head count from the input shape.
+        alibi = nn.ALiBi()
+        shape = [1, 8, 20, 20]
+        x = mx.random.uniform(shape=shape)
+        y = alibi(x)
+        self.assertEqual(y.shape, shape)
+        self.assertEqual(y.dtype, mx.float32)
+
+        y = alibi(x.astype(mx.float16))
+        self.assertEqual(y.dtype, mx.float16)
+
 
 if __name__ == "__main__":
     unittest.main()
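
Usage sketch (not part of the commit): a minimal example of how the new
nn.ALiBi module could be wired into a hand-rolled attention step. The
shapes, variable names, and the softmax call below are illustrative
assumptions rather than anything the patch prescribes:

    import mlx.core as mx
    import mlx.nn as nn

    alibi = nn.ALiBi()

    # Raw attention scores shaped (batch, num_heads, q_len, k_len);
    # ALiBi reads num_heads from axis 1.
    scores = mx.random.uniform(shape=[1, 8, 20, 20])

    # Adds the per-head linear distance penalty to the scores. An additive
    # mask (e.g. a causal mask) can be folded in via the `mask` argument,
    # and `offset` shifts the query positions for KV-cache decoding.
    biased = alibi(scores)
    weights = mx.softmax(biased, axis=-1)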