mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-19 10:48:09 +08:00
Make shape a tuple (#591)
* shape tuple * also remove simplify from docs * rebase
This commit is contained in:
@@ -81,7 +81,7 @@ class Dropout2d(Module):
|
||||
# Dropout is applied on the whole channel
|
||||
# 3D input: (1, 1, C)
|
||||
# 4D input: (B, 1, 1, C)
|
||||
mask_shape = x.shape
|
||||
mask_shape = list(x.shape)
|
||||
mask_shape[-2] = mask_shape[-3] = 1
|
||||
|
||||
mask = mx.random.bernoulli(p=self._p_1, shape=mask_shape)
|
||||
|
@@ -70,8 +70,9 @@ def cross_entropy(
|
||||
targets_as_probs = targets.ndim == logits.ndim
|
||||
|
||||
def _drop_dim(shape, axis):
|
||||
shape = list(shape)
|
||||
shape.pop(axis)
|
||||
return shape
|
||||
return tuple(shape)
|
||||
|
||||
# Check shapes in two cases: targets as class indices and targets as probabilities
|
||||
if (targets_as_probs and targets.shape != logits.shape) or (
|
||||
|
Reference in New Issue
Block a user