diff --git a/python/mlx/nn/layers/transformer.py b/python/mlx/nn/layers/transformer.py
index 971a2ad56..9b70221ff 100644
--- a/python/mlx/nn/layers/transformer.py
+++ b/python/mlx/nn/layers/transformer.py
@@ -9,6 +9,7 @@ from mlx.nn.layers.dropout import Dropout
 from mlx.nn.layers.linear import Linear
 from mlx.nn.layers.normalization import LayerNorm
+
 class MultiHeadAttention(Module):
     """Implements the scaled dot product attention with multiple heads.