mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-12 12:16:43 +08:00
Add distributed layers to nn top-level
This commit is contained in:
parent
060e1c9f92
commit
a8b3da7946
@ -60,6 +60,12 @@ from mlx.nn.layers.convolution_transpose import (
|
|||||||
ConvTranspose2d,
|
ConvTranspose2d,
|
||||||
ConvTranspose3d,
|
ConvTranspose3d,
|
||||||
)
|
)
|
||||||
|
from mlx.nn.layers.distributed import (
|
||||||
|
AllToShardedLinear,
|
||||||
|
QuantizedAllToShardedLinear,
|
||||||
|
QuantizedShardedToAllLinear,
|
||||||
|
ShardedToAllLinear,
|
||||||
|
)
|
||||||
from mlx.nn.layers.dropout import Dropout, Dropout2d, Dropout3d
|
from mlx.nn.layers.dropout import Dropout, Dropout2d, Dropout3d
|
||||||
from mlx.nn.layers.embedding import Embedding
|
from mlx.nn.layers.embedding import Embedding
|
||||||
from mlx.nn.layers.linear import Bilinear, Identity, Linear
|
from mlx.nn.layers.linear import Bilinear, Identity, Linear
|
||||||
|
Loading…
Reference in New Issue
Block a user