diff --git a/python/mlx/nn/layers/__init__.py b/python/mlx/nn/layers/__init__.py index 3b7e3487b..5ac82356a 100644 --- a/python/mlx/nn/layers/__init__.py +++ b/python/mlx/nn/layers/__init__.py @@ -37,7 +37,8 @@ from mlx.nn.layers.dropout import Dropout, Dropout2d from mlx.nn.layers.embedding import Embedding from mlx.nn.layers.linear import Linear from mlx.nn.layers.normalization import BatchNorm, GroupNorm, LayerNorm, RMSNorm -from mlx.nn.layers.positional_encoding import RoPE, SinusoidalPositionalEncoding +from mlx.nn.layers.positional_encoding import ALiBi, RoPE, SinusoidalPositionalEncoding +from mlx.nn.layers.quantized import QuantizedLinear from mlx.nn.layers.transformer import ( MultiHeadAttention, TransformerEncoder,