changed the returned variable name in create_alibi_matrix for consistency

This commit is contained in:
Hazem
2023-12-21 21:41:43 +02:00
parent 33de113222
commit a673f620c1

View File

@@ -162,8 +162,8 @@ class ALiBi(Module):
mx.expand_dims(x1[:, None] - x2[None, :], axis=(0, 1))
)
alibi_slope = ALiBi.create_alibi_slope(num_heads)
alibi_matrix = (distance_matrix * alibi_slope).astype(dtype)
return alibi_matrix
alibi_mask = (distance_matrix * alibi_slope).astype(dtype)
return alibi_mask
@staticmethod
def create_alibi_slope(num_heads):