Add distributed layers to nn top-level

This commit is contained in:
Angelos Katharopoulos 2024-07-15 13:23:51 -07:00
parent 060e1c9f92
commit a8b3da7946

View File

@ -60,6 +60,12 @@ from mlx.nn.layers.convolution_transpose import (
ConvTranspose2d,
ConvTranspose3d,
)
from mlx.nn.layers.distributed import (
AllToShardedLinear,
QuantizedAllToShardedLinear,
QuantizedShardedToAllLinear,
ShardedToAllLinear,
)
from mlx.nn.layers.dropout import Dropout, Dropout2d, Dropout3d
from mlx.nn.layers.embedding import Embedding
from mlx.nn.layers.linear import Bilinear, Identity, Linear