mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-11 19:56:40 +08:00
Add distributed layers to nn top-level
This commit is contained in:
parent
060e1c9f92
commit
a8b3da7946
@ -60,6 +60,12 @@ from mlx.nn.layers.convolution_transpose import (
|
||||
ConvTranspose2d,
|
||||
ConvTranspose3d,
|
||||
)
|
||||
from mlx.nn.layers.distributed import (
|
||||
AllToShardedLinear,
|
||||
QuantizedAllToShardedLinear,
|
||||
QuantizedShardedToAllLinear,
|
||||
ShardedToAllLinear,
|
||||
)
|
||||
from mlx.nn.layers.dropout import Dropout, Dropout2d, Dropout3d
|
||||
from mlx.nn.layers.embedding import Embedding
|
||||
from mlx.nn.layers.linear import Bilinear, Identity, Linear
|
||||
|
Loading…
Reference in New Issue
Block a user