mlx.nn.MultiHeadAttention
class mlx.nn.MultiHeadAttention(dims: int, num_heads: int, query_input_dims: Optional[int] = None, key_input_dims: Optional[int] = None, value_input_dims: Optional[int] = None, value_dims: Optional[int] = None, value_output_dims: Optional[int] = None, bias: bool = False)
Implements the scaled dot product attention with multiple heads.
Given inputs for queries, keys and values, MultiHeadAttention produces new values by aggregating information from the input values according to the similarities of the input queries and keys. All inputs as well as the output are linearly projected, without biases by default (see the bias parameter).
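The computation described above can be sketched in NumPy as a reference. This is an illustrative re-implementation, not MLX's code: the helper names (`softmax`, `multi_head_attention`, `split_heads`), the `(in, out)` weight layout, and the `-1e9` mask value are assumptions of the sketch. The scores are scaled by `1/sqrt(head_dim)`, the standard choice for scaled dot-product attention, and an optional additive mask is added to the scores before the softmax.

```python
import numpy as np


def softmax(x, axis=-1):
    # Subtract the row maximum for numerical stability before exponentiating.
    x = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(x)
    return e / np.sum(e, axis=axis, keepdims=True)


def multi_head_attention(queries, keys, values, w_q, w_k, w_v, w_o, num_heads, mask=None):
    """NumPy reference sketch of bias-free multi-head attention (not MLX code).

    queries: (B, Lq, query_input_dims)   w_q: (query_input_dims, dims)
    keys:    (B, Lk, key_input_dims)     w_k: (key_input_dims, dims)
    values:  (B, Lk, value_input_dims)   w_v: (value_input_dims, value_dims)
                                         w_o: (value_dims, value_output_dims)
    mask: optional additive mask broadcastable to (B, num_heads, Lq, Lk).
    """
    # Linear projections without biases (weights stored as (in, out) in this sketch).
    q, k, v = queries @ w_q, keys @ w_k, values @ w_v
    B, Lq, _ = q.shape
    Lk = k.shape[1]

    # Split the feature axis into heads: (B, L, D) -> (B, num_heads, L, D // num_heads).
    def split_heads(x, length):
        return x.reshape(B, length, num_heads, -1).transpose(0, 2, 1, 3)

    q, k, v = split_heads(q, Lq), split_heads(k, Lk), split_heads(v, Lk)

    # Scaled dot-product scores, shape (B, num_heads, Lq, Lk).
    scores = (q @ k.transpose(0, 1, 3, 2)) / np.sqrt(q.shape[-1])
    if mask is not None:
        scores = scores + mask  # additive: blocked positions get a large negative score
    weights = softmax(scores, axis=-1)

    # Weighted sum of values per head, then merge heads back: (B, Lq, value_dims).
    out = (weights @ v).transpose(0, 2, 1, 3).reshape(B, Lq, -1)
    return out @ w_o


# Usage: self-attention with dims = 8 split across 2 heads, plus a causal mask.
rng = np.random.default_rng(0)
B, L, dims, num_heads = 2, 4, 8, 2
x = rng.standard_normal((B, L, dims))
w_q, w_k, w_v, w_o = (rng.standard_normal((dims, dims)) for _ in range(4))

# Additive causal mask: 0 where key j <= query i, -1e9 where j > i.
idx = np.arange(L)
causal_mask = np.where(idx[None, :] > idx[:, None], -1e9, 0.0)  # (L, L) broadcasts

out = multi_head_attention(x, x, x, w_q, w_k, w_v, w_o, num_heads, mask=causal_mask)
print(out.shape)  # (2, 4, 8)
```

With the causal mask, the first query can only attend to the first key, so its output reduces to the projected first value. A large finite negative value is used instead of `-inf` so that a row with every position masked yields uniform weights rather than NaN (softmax over all `-inf` evaluates `-inf - (-inf)`).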
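MultiHeadAttention also accepts an optional additive attention mask, which should be broadcastable with (batch, num_heads, # queries, # keys). The mask should contain -inf or very large negative values at the positions that should not be attended to; because it is added to the attention scores before the softmax, those positions receive (near-)zero attention weight.
Parameters: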
dims (int) – The model dimensions. If no other dims are provided, dims is used for queries, keys, values and the output.
num_heads (int) – How many attention heads to use.
query_input_dims (int, optional) – The input dimensions of the queries (default: dims).
key_input_dims (int, optional) – The input dimensions of the keys (default: dims).
value_input_dims (int, optional) – The input dimensions of the values (default: key_input_dims).
value_dims (int, optional) – The dimensions of the values after the projection (default: dims).
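value_output_dims (int, optional) – The dimensions the new values will be projected to (default: dims).
bias (bool, optional) – Whether to use a bias in the input and output projections (default: False).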
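The default chain in the parameter list is easy to misread, since value_input_dims falls back to key_input_dims rather than to dims. The hypothetical helper `resolve_dims` below (not part of MLX) applies those defaults and reports an `(in, out)` shape for each projection. The projection layout it returns (queries and keys projected to dims, values to value_dims, output from value_dims to value_output_dims) is inferred from the parameter descriptions, and the divisibility check is an assumption of this sketch.

```python
from typing import Optional


def resolve_dims(
    dims: int,
    num_heads: int,
    query_input_dims: Optional[int] = None,
    key_input_dims: Optional[int] = None,
    value_input_dims: Optional[int] = None,
    value_dims: Optional[int] = None,
    value_output_dims: Optional[int] = None,
) -> dict:
    """Hypothetical helper (not part of MLX): apply the documented defaults
    and return the assumed (in, out) shape of each projection."""

    def default(value: Optional[int], fallback: int) -> int:
        return fallback if value is None else value

    query_input_dims = default(query_input_dims, dims)
    key_input_dims = default(key_input_dims, dims)
    # Values fall back to the *key* input dims, not to dims.
    value_input_dims = default(value_input_dims, key_input_dims)
    value_dims = default(value_dims, dims)
    value_output_dims = default(value_output_dims, dims)

    # Assumption of this sketch: projected features must split evenly across heads.
    if dims % num_heads or value_dims % num_heads:
        raise ValueError("dims and value_dims must be divisible by num_heads")

    return {
        "query_proj": (query_input_dims, dims),
        "key_proj": (key_input_dims, dims),
        "value_proj": (value_input_dims, value_dims),
        "out_proj": (value_dims, value_output_dims),
    }


# Cross-attention where keys and values come from a 32-dimensional encoder.
print(resolve_dims(64, 8, key_input_dims=32))
# {'query_proj': (64, 64), 'key_proj': (32, 64), 'value_proj': (32, 64), 'out_proj': (64, 64)}
```

Setting only key_input_dims is enough for cross-attention in the common case where keys and values share a source, because the value input size follows it automatically.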