mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-04 23:24:41 +08:00
changed the returned variable name in create_alibi_matrix for consistency
This commit is contained in:
@@ -162,8 +162,8 @@ class ALiBi(Module):
|
||||
mx.expand_dims(x1[:, None] - x2[None, :], axis=(0, 1))
|
||||
)
|
||||
alibi_slope = ALiBi.create_alibi_slope(num_heads)
|
||||
alibi_matrix = (distance_matrix * alibi_slope).astype(dtype)
|
||||
return alibi_matrix
|
||||
alibi_mask = (distance_matrix * alibi_slope).astype(dtype)
|
||||
return alibi_mask
|
||||
|
||||
@staticmethod
|
||||
def create_alibi_slope(num_heads):
|
||||
|
Reference in New Issue
Block a user