diff --git a/python/mlx/nn/losses.py b/python/mlx/nn/losses.py index 58232363a3..aceb1f98ab 100644 --- a/python/mlx/nn/losses.py +++ b/python/mlx/nn/losses.py @@ -352,7 +352,7 @@ def smooth_l1_loss( .. math:: l = \begin{cases} - 0.5 (x - y)^2, & \text{if } (x - y) < \beta \\ + 0.5 (x - y)^2 / \beta, & \text{if } |x - y| < \beta \\ |x - y| - 0.5 \beta, & \text{otherwise} \end{cases}