Make shape a tuple (#591)

* shape tuple

* also remove simplify from docs

* rebase
This commit is contained in:
Awni Hannun
2024-01-30 13:11:01 -08:00
committed by GitHub
parent d3a9005454
commit 09b9275027
13 changed files with 141 additions and 140 deletions

View File

@@ -81,7 +81,7 @@ class Dropout2d(Module):
# Dropout is applied on the whole channel
# 3D input: (1, 1, C)
# 4D input: (B, 1, 1, C)
mask_shape = x.shape
mask_shape = list(x.shape)
mask_shape[-2] = mask_shape[-3] = 1
mask = mx.random.bernoulli(p=self._p_1, shape=mask_shape)

View File

@@ -70,8 +70,9 @@ def cross_entropy(
targets_as_probs = targets.ndim == logits.ndim
def _drop_dim(shape, axis):
shape = list(shape)
shape.pop(axis)
return shape
return tuple(shape)
# Check shapes in two cases: targets as class indices and targets as probabilities
if (targets_as_probs and targets.shape != logits.shape) or (