diff --git a/python/mlx/nn/layers/transformer.py b/python/mlx/nn/layers/transformer.py
index 2c586cd3e..68d9303ac 100644
--- a/python/mlx/nn/layers/transformer.py
+++ b/python/mlx/nn/layers/transformer.py
@@ -16,7 +16,7 @@ class MultiHeadAttention(Module):
     new values by aggregating information from the input values according to
     the similarities of the input queries and keys.
 
-    All inputs as well as the output are lineary projected without biases.
+    All inputs as well as the output are linearly projected without biases.
 
     MultiHeadAttention also expects an additive attention mask that should be
     broadcastable with (batch, num_heads, # queries, # keys). The mask should
@@ -48,7 +48,7 @@ class MultiHeadAttention(Module):
 
         if (dims % num_heads) != 0:
             raise ValueError(
-                f"The input feature dimensions should be divisble by the number of heads ({dims} % {num_heads}) != 0"
+                f"The input feature dimensions should be divisible by the number of heads ({dims} % {num_heads}) != 0"
            )
 
         query_input_dims = query_input_dims or dims
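
For context, a minimal usage sketch (not part of the diff) exercising the two code paths the fixed strings describe: the dims/num_heads divisibility check and the additive attention mask contract. It assumes the public mlx API (mx.random.normal, the MultiHeadAttention call signature, and the create_additive_causal_mask static method); the shapes are illustrative values, not from the source.

# Minimal sketch, assuming the mlx public API; shapes are example values.
import mlx.core as mx
from mlx.nn import MultiHeadAttention

dims, num_heads = 64, 8  # dims must be divisible by num_heads
attn = MultiHeadAttention(dims, num_heads)

B, L = 2, 10  # batch size and sequence length (assumed example values)
x = mx.random.normal((B, L, dims))

# Additive mask, broadcastable with (batch, num_heads, # queries, # keys):
# 0 where attention is allowed, a large negative value where it is masked.
mask = MultiHeadAttention.create_additive_causal_mask(L)

out = attn(x, x, x, mask=mask)  # self-attention; out.shape == (B, L, dims)

# MultiHeadAttention(65, 8) would raise the (now correctly spelled)
# "should be divisible by the number of heads" ValueError.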