diff --git a/python/mlx/nn/losses.py b/python/mlx/nn/losses.py index bccf45b16..58232363a 100644 --- a/python/mlx/nn/losses.py +++ b/python/mlx/nn/losses.py @@ -373,7 +373,7 @@ def smooth_l1_loss( f"targets shape {targets.shape}." ) - diff = predictions - targets + diff = mx.abs(predictions - targets) loss = mx.where( diff < beta, 0.5 * mx.square(diff) / beta, mx.abs(diff) - 0.5 * beta )