Added ALiBi implementation (#232)

This commit is contained in:
Hazem Essam 2023-12-22 00:36:38 +02:00 committed by GitHub
parent 794feb83df
commit 0aa65c7a6b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 50 additions and 1 deletions

View File

@ -37,7 +37,7 @@ from mlx.nn.layers.dropout import Dropout
from mlx.nn.layers.embedding import Embedding
from mlx.nn.layers.linear import Linear
from mlx.nn.layers.normalization import GroupNorm, LayerNorm, RMSNorm
from mlx.nn.layers.positional_encoding import RoPE, SinusoidalPositionalEncoding
from mlx.nn.layers.positional_encoding import ALiBi, RoPE, SinusoidalPositionalEncoding
from mlx.nn.layers.quantized import QuantizedLinear
from mlx.nn.layers.transformer import (
MultiHeadAttention,

View File

@ -142,3 +142,40 @@ class SinusoidalPositionalEncoding(Module):
y = y * self.scale
return y
class ALiBi(Module):
@staticmethod
def create_alibi_matrix(
q_sequence_length: int,
k_sequence_length: int,
num_heads: int,
offset: int,
dtype=mx.float32,
):
x1 = mx.arange(offset, q_sequence_length)
x2 = mx.arange(0, k_sequence_length)
distance_matrix = -mx.abs(
mx.expand_dims(x1[:, None] - x2[None, :], axis=(0, 1))
)
alibi_slope = ALiBi.create_alibi_slope(num_heads=num_heads)
alibi_mask = (distance_matrix * alibi_slope).astype(dtype)
return alibi_mask
@staticmethod
def create_alibi_slope(num_heads):
x = (2**8) ** (1 / num_heads)
out = mx.power(x, -mx.arange(1, num_heads + 1))
return mx.expand_dims(out, axis=(-1, -2))
def __call__(self, attention_scores, offset=0, mask=None):
alibi_mask = ALiBi.create_alibi_matrix(
q_sequence_length=attention_scores.shape[-2] + offset,
k_sequence_length=attention_scores.shape[-1],
num_heads=attention_scores.shape[1],
offset=offset,
dtype=attention_scores.dtype,
)
if mask is not None:
alibi_mask = alibi_mask + mask
return attention_scores + alibi_mask

View File

@ -570,6 +570,18 @@ class TestNN(mlx_tests.MLXTestCase):
y = rope(x.astype(mx.float16))
self.assertTrue(y.dtype, mx.float16)
def test_alibi(self):
for kwargs in [{"num_heads": 8}]:
alibi = nn.ALibi(**kwargs)
shape = [1, 8, 20, 20]
x = mx.random.uniform(shape=shape)
y = alibi(x)
self.assertTrue(y.shape, shape)
self.assertTrue(y.dtype, mx.float32)
y = alibi(x.astype(mx.float16))
self.assertTrue(y.dtype, mx.float16)
if __name__ == "__main__":
unittest.main()