diff --git a/python/mlx/nn/losses.py b/python/mlx/nn/losses.py index 35aedf755..cfb6ffa15 100644 --- a/python/mlx/nn/losses.py +++ b/python/mlx/nn/losses.py @@ -133,10 +133,6 @@ def mse_loss( f"targets shape {targets.shape}." ) - assert ( - predictions.shape == targets.shape - ), f"Shape of predictions {predictions.shape} and targets {targets.shape} must match" - loss = mx.square(predictions - targets) return _reduce(loss, reduction)