Add distributed layers to nn top-level

Angelos Katharopoulos 2024-07-15 13:23:51 -07:00
parent 060e1c9f92
commit a8b3da7946

python/mlx/nn/layers/__init__.py

@@ -60,6 +60,12 @@ from mlx.nn.layers.convolution_transpose import (
    ConvTranspose2d,
    ConvTranspose3d,
)
from mlx.nn.layers.distributed import (
    AllToShardedLinear,
    QuantizedAllToShardedLinear,
    QuantizedShardedToAllLinear,
    ShardedToAllLinear,
)
from mlx.nn.layers.dropout import Dropout, Dropout2d, Dropout3d
from mlx.nn.layers.embedding import Embedding
from mlx.nn.layers.linear import Bilinear, Identity, Linear
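
With this change the sharded linear layers resolve directly under mlx.nn instead of only under mlx.nn.layers.distributed. The snippet below is a minimal sketch of the new top-level access; the constructor signature (input_dims, output_dims) and the single-process fallback (world size 1 when no distributed backend is launched) are assumptions about the API, not something shown in this diff.

import mlx.core as mx
import mlx.nn as nn

# Names newly exported at the nn top level by this commit.
# Signature assumed to be (input_dims, output_dims, ...); see the hedge above.
shard_out = nn.AllToShardedLinear(512, 256)   # output features split across ranks
gather_in = nn.ShardedToAllLinear(256, 512)   # sharded input, full output on every rank

x = mx.zeros((1, 512))
y = gather_in(shard_out(x))
print(y.shape)  # (1, 512) when run as a single process (assumed world size 1)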