Added ALiBi implementation (#232)

Hazem Essam 2023-12-22 00:36:38 +02:00 committed by GitHub
parent 794feb83df
commit 0aa65c7a6b
3 changed files with 50 additions and 1 deletion


@@ -37,7 +37,7 @@ from mlx.nn.layers.dropout import Dropout
 from mlx.nn.layers.embedding import Embedding
 from mlx.nn.layers.linear import Linear
 from mlx.nn.layers.normalization import GroupNorm, LayerNorm, RMSNorm
-from mlx.nn.layers.positional_encoding import RoPE, SinusoidalPositionalEncoding
+from mlx.nn.layers.positional_encoding import ALiBi, RoPE, SinusoidalPositionalEncoding
 from mlx.nn.layers.quantized import QuantizedLinear
 from mlx.nn.layers.transformer import (
     MultiHeadAttention,

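For orientation, a minimal smoke check of what the one-line re-export above buys, assuming the usual import mlx.nn as nn alias (this snippet is illustrative and not part of the commit):

import mlx.nn as nn
from mlx.nn.layers.positional_encoding import ALiBi

# Both import paths now resolve to the same class.
assert nn.ALiBi is ALiBi
print(nn.ALiBi)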

@@ -142,3 +142,40 @@ class SinusoidalPositionalEncoding(Module):
         y = y * self.scale
         return y
+
+
+class ALiBi(Module):
+    @staticmethod
+    def create_alibi_matrix(
+        q_sequence_length: int,
+        k_sequence_length: int,
+        num_heads: int,
+        offset: int,
+        dtype=mx.float32,
+    ):
+        x1 = mx.arange(offset, q_sequence_length)
+        x2 = mx.arange(0, k_sequence_length)
+        distance_matrix = -mx.abs(
+            mx.expand_dims(x1[:, None] - x2[None, :], axis=(0, 1))
+        )
+        alibi_slope = ALiBi.create_alibi_slope(num_heads=num_heads)
+        alibi_mask = (distance_matrix * alibi_slope).astype(dtype)
+        return alibi_mask
+
+    @staticmethod
+    def create_alibi_slope(num_heads):
+        x = (2**8) ** (1 / num_heads)
+        out = mx.power(x, -mx.arange(1, num_heads + 1))
+        return mx.expand_dims(out, axis=(-1, -2))
+
+    def __call__(self, attention_scores, offset=0, mask=None):
+        alibi_mask = ALiBi.create_alibi_matrix(
+            q_sequence_length=attention_scores.shape[-2] + offset,
+            k_sequence_length=attention_scores.shape[-1],
+            num_heads=attention_scores.shape[1],
+            offset=offset,
+            dtype=attention_scores.dtype,
+        )
+        if mask is not None:
+            alibi_mask = alibi_mask + mask
+        return attention_scores + alibi_mask

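As a usage sketch (not from the commit; the tensor names, head count, and sequence length are illustrative), the new layer is applied to raw attention scores of shape (batch, num_heads, q_len, k_len) before the softmax. The per-head slopes follow 2^(-8 i / H) for i = 1..H, so with H = 4 heads they are 1/4, 1/16, 1/64, 1/256:

import mlx.core as mx
import mlx.nn as nn

B, H, L, D = 1, 4, 6, 8  # illustrative batch, heads, sequence length, head dim
queries = mx.random.normal((B, H, L, D))
keys = mx.random.normal((B, H, L, D))

# Raw attention scores, shape (batch, heads, q_len, k_len).
scores = (queries @ keys.transpose(0, 1, 3, 2)) * D**-0.5

alibi = nn.ALiBi()
biased = alibi(scores)  # adds the per-head linear distance penalty
weights = mx.softmax(biased, axis=-1)
print(weights.shape)  # batch, heads, q_len, k_len -> 1, 4, 6, 6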

@@ -570,6 +570,18 @@ class TestNN(mlx_tests.MLXTestCase):
         y = rope(x.astype(mx.float16))
         self.assertTrue(y.dtype, mx.float16)
+
+    def test_alibi(self):
+        for kwargs in [{}]:
+            alibi = nn.ALiBi(**kwargs)
+            shape = [1, 8, 20, 20]
+            x = mx.random.uniform(shape=shape)
+            y = alibi(x)
+            self.assertTrue(y.shape, shape)
+            self.assertTrue(y.dtype, mx.float32)
+
+            y = alibi(x.astype(mx.float16))
+            self.assertTrue(y.dtype, mx.float16)
 
 
 if __name__ == "__main__":
     unittest.main()