Make shape a tuple (#591)

* shape tuple

* also remove simplify from docs

* rebase
This commit is contained in:
Awni Hannun
2024-01-30 13:11:01 -08:00
committed by GitHub
parent d3a9005454
commit 09b9275027
13 changed files with 141 additions and 140 deletions

View File

@@ -81,7 +81,7 @@ class Dropout2d(Module):
# Dropout is applied on the whole channel
# 3D input: (1, 1, C)
# 4D input: (B, 1, 1, C)
mask_shape = x.shape
mask_shape = list(x.shape)
mask_shape[-2] = mask_shape[-3] = 1
mask = mx.random.bernoulli(p=self._p_1, shape=mask_shape)