mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Make shape a tuple (#591)
* shape tuple * also remove simplify from docs * rebase
This commit is contained in:
@@ -81,7 +81,7 @@ class Dropout2d(Module):
|
||||
# Dropout is applied on the whole channel
|
||||
# 3D input: (1, 1, C)
|
||||
# 4D input: (B, 1, 1, C)
|
||||
mask_shape = x.shape
|
||||
mask_shape = list(x.shape)
|
||||
mask_shape[-2] = mask_shape[-3] = 1
|
||||
|
||||
mask = mx.random.bernoulli(p=self._p_1, shape=mask_shape)
|
||||
|
||||
Reference in New Issue
Block a user