diff --git a/docs/src/python/nn/losses.rst b/docs/src/python/nn/losses.rst
index 31f40fb1f..6c4327eb8 100644
--- a/docs/src/python/nn/losses.rst
+++ b/docs/src/python/nn/losses.rst
@@ -12,6 +12,7 @@ Loss Functions
    binary_cross_entropy
    cosine_similarity_loss
    cross_entropy
+   gaussian_nll_loss
    hinge_loss
    huber_loss
    kl_div_loss
diff --git a/python/mlx/nn/losses.py b/python/mlx/nn/losses.py
index 2a4c5bd9b..0299e0e38 100644
--- a/python/mlx/nn/losses.py
+++ b/python/mlx/nn/losses.py
@@ -1,12 +1,14 @@
 # Copyright © 2023 Apple Inc.
 
 import math
+from typing import Literal
 
 import mlx.core as mx
-from mlx.nn.layers.base import Module
+
+Reduction = Literal["none", "mean", "sum"]
 
 
-def _reduce(loss: mx.array, reduction: str = "none"):
+def _reduce(loss: mx.array, reduction: Reduction = "none"):
     if reduction == "mean":
         return mx.mean(loss)
     elif reduction == "sum":
@@ -23,7 +25,7 @@ def cross_entropy(
     weights: mx.array = None,
     axis: int = -1,
     label_smoothing: float = 0.0,
-    reduction: str = "none",
+    reduction: Reduction = "none",
 ) -> mx.array:
     """
     Computes the cross entropy loss.
@@ -72,7 +74,7 @@ def cross_entropy(
 
 
 def binary_cross_entropy(
-    logits: mx.array, targets: mx.array, reduction: str = "none"
+    logits: mx.array, targets: mx.array, reduction: Reduction = "none"
 ) -> mx.array:
     """
     Computes the binary cross entropy loss.
@@ -99,7 +101,7 @@ def binary_cross_entropy(
 
 
 def l1_loss(
-    predictions: mx.array, targets: mx.array, reduction: str = "mean"
+    predictions: mx.array, targets: mx.array, reduction: Reduction = "mean"
 ) -> mx.array:
     """
     Computes the L1 loss.
@@ -124,7 +126,7 @@ def l1_loss(
 
 
 def mse_loss(
-    predictions: mx.array, targets: mx.array, reduction: str = "mean"
+    predictions: mx.array, targets: mx.array, reduction: Reduction = "mean"
 ) -> mx.array:
     """
     Computes the mean squared error loss.
@@ -149,7 +151,7 @@ def mse_loss(
 
 
 def nll_loss(
-    inputs: mx.array, targets: mx.array, axis: int = -1, reduction: str = "none"
+    inputs: mx.array, targets: mx.array, axis: int = -1, reduction: Reduction = "none"
 ) -> mx.array:
     """
     Computes the negative log likelihood loss.
@@ -169,8 +171,63 @@ def nll_loss(
     return _reduce(loss, reduction)
 
 
+def gaussian_nll_loss(
+    inputs: mx.array,
+    targets: mx.array,
+    vars: mx.array,
+    full: bool = False,
+    eps: float = 1e-6,
+    reduction: Reduction = "mean",
+) -> mx.array:
+    r"""
+    Computes the negative log likelihood loss for a Gaussian distribution.
+
+    The loss is given by:
+
+    .. math::
+        \frac{1}{2}\left(\log\left(\max\left(\text{vars},
+        \ \epsilon\right)\right) + \frac{\left(\text{inputs} - \text{targets} \right)^2}
+        {\max\left(\text{vars}, \ \epsilon \right)}\right) + \text{const.}
+
+    where ``inputs`` are the predicted means and ``vars`` are the
+    predicted variances.
+
+    Args:
+        inputs (array): The predicted expectation of the Gaussian distribution.
+        targets (array): The target values (samples from the Gaussian distribution).
+        vars (array): The predicted variance of the Gaussian distribution.
+        full (bool, optional): Whether to include the constant term in the loss calculation.
+            Default: ``False``.
+        eps (float, optional): Small positive constant for numerical stability.
+            Default: ``1e-6``.
+        reduction (str, optional): Specifies the reduction to apply to the output:
+            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``.
+
+    Returns:
+        array: The Gaussian NLL loss.
+    """
+    if inputs.shape != targets.shape:
+        raise ValueError(
+            f"Inputs shape {inputs.shape} does not match targets shape {targets.shape}."
+        )
+
+    if inputs.shape != vars.shape:
+        raise ValueError(
+            f"Inputs shape {inputs.shape} does not match vars shape {vars.shape}."
+        )
+
+    # For stability
+    vars = mx.maximum(vars, eps)
+    loss = 0.5 * (mx.log(vars) + mx.square(targets - inputs) / vars)
+
+    if full:
+        loss += 0.5 * math.log(2 * math.pi)
+
+    return _reduce(loss, reduction)
+
+
 def kl_div_loss(
-    inputs: mx.array, targets: mx.array, axis: int = -1, reduction: str = "none"
+    inputs: mx.array, targets: mx.array, axis: int = -1, reduction: Reduction = "none"
 ) -> mx.array:
     """
     Computes the Kullback-Leibler divergence loss.
@@ -197,7 +254,10 @@ def kl_div_loss(
 
 
 def smooth_l1_loss(
-    predictions: mx.array, targets: mx.array, beta: float = 1.0, reduction: str = "mean"
+    predictions: mx.array,
+    targets: mx.array,
+    beta: float = 1.0,
+    reduction: Reduction = "mean",
 ) -> mx.array:
     r"""
     Computes the smooth L1 loss.
@@ -210,7 +270,7 @@ def smooth_l1_loss(
 
     .. math::
 
-       l = 
+       l =
       \begin{cases}
         0.5 (x - y)^2, & \text{ if } & (x - y) < \beta \\
         |x - y| - 0.5 \beta, & & \text{otherwise}
@@ -249,7 +309,7 @@ def triplet_loss(
     p: int = 2,
     margin: float = 1.0,
     eps: float = 1e-6,
-    reduction: str = "none",
+    reduction: Reduction = "none",
 ) -> mx.array:
     r"""
     Computes the triplet loss for a set of anchor, positive, and negative samples.
@@ -257,7 +317,7 @@ def triplet_loss(
 
     .. math::
 
-        L_{\text{triplet}} = \max\left(\|A - P\|_p - \|A - N\|_p + \alpha, 0\right)
+        \max\left(\|A - P\|_p - \|A - N\|_p + \alpha, 0\right)
 
     Args:
         anchors (array): The anchor samples.
@@ -284,7 +344,7 @@ def triplet_loss(
 
 
 def hinge_loss(
-    inputs: mx.array, targets: mx.array, reduction: str = "none"
+    inputs: mx.array, targets: mx.array, reduction: Reduction = "none"
 ) -> mx.array:
     r"""
     Computes the hinge loss between inputs and targets.
@@ -309,14 +369,17 @@ def hinge_loss(
 
 
 def huber_loss(
-    inputs: mx.array, targets: mx.array, delta: float = 1.0, reduction: str = "none"
+    inputs: mx.array,
+    targets: mx.array,
+    delta: float = 1.0,
+    reduction: Reduction = "none",
 ) -> mx.array:
     r"""
     Computes the Huber loss between inputs and targets.
 
     .. math::
 
-        L_{\delta}(a) =
+        l_{\delta}(a) =
         \left\{ \begin{array}{ll}
             \frac{1}{2} a^2 & \text{for } |a| \leq \delta, \\
             \delta \left( |a| - \frac{1}{2} \delta \right) & \text{otherwise.}
@@ -343,7 +406,7 @@ def huber_loss(
 
 
 def log_cosh_loss(
-    inputs: mx.array, targets: mx.array, reduction: str = "none"
+    inputs: mx.array, targets: mx.array, reduction: Reduction = "none"
 ) -> mx.array:
     r"""
     Computes the log cosh loss between inputs and targets.
@@ -379,7 +442,7 @@ def cosine_similarity_loss(
     x2: mx.array,
     axis: int = 1,
     eps: float = 1e-8,
-    reduction: str = "none",
+    reduction: Reduction = "none",
 ) -> mx.array:
     r"""
     Computes the cosine similarity between the two inputs.
diff --git a/python/tests/test_losses.py b/python/tests/test_losses.py
index 63bd5a20e..2db0ebb58 100644
--- a/python/tests/test_losses.py
+++ b/python/tests/test_losses.py
@@ -211,6 +211,51 @@ class TestLosses(mlx_tests.MLXTestCase):
         expected_sum = mx.sum(expected_none)
         self.assertEqual(losses_sum, expected_sum)
 
+    def test_gaussian_nll_loss(self):
+        inputs = mx.array([[0.1, 0.2], [0.3, 0.4]])
+        targets = mx.array([[0.2, 0.1], [0.1, 0.2]])
+        vars = mx.array([[0.1, 0.2], [0.3, 0.4]])
+
+        # Test with reduction 'none', full=False
+        losses_none = nn.losses.gaussian_nll_loss(
+            inputs, targets, vars, reduction="none"
+        )
+        expected_none = mx.array([[-1.101293, -0.779719], [-0.535320, -0.408145]])
+        self.assertTrue(mx.allclose(losses_none, expected_none))
+
+        # Test with reduction 'mean', full=False
+        losses_mean = nn.losses.gaussian_nll_loss(
+            inputs, targets, vars, reduction="mean"
+        )
+        expected_mean = mx.mean(expected_none)
+        self.assertTrue(mx.allclose(losses_mean, expected_mean))
+
+        # Test with reduction 'sum', full=False
+        losses_sum = nn.losses.gaussian_nll_loss(inputs, targets, vars, reduction="sum")
+        expected_sum = mx.sum(expected_none)
+        self.assertTrue(mx.allclose(losses_sum, expected_sum))
+
+        # Test with reduction='none', full=True
+        losses_none_full = nn.losses.gaussian_nll_loss(
+            inputs, targets, vars, full=True, reduction="none"
+        )
+        expected_none_full = mx.array([[-0.182354, 0.139220], [0.383619, 0.510793]])
+        self.assertTrue(mx.allclose(losses_none_full, expected_none_full))
+
+        # Test with reduction='mean', full=True
+        losses_mean_full = nn.losses.gaussian_nll_loss(
+            inputs, targets, vars, full=True, reduction="mean"
+        )
+        expected_mean_full = mx.mean(expected_none_full)
+        self.assertTrue(mx.allclose(losses_mean_full, expected_mean_full))
+
+        # Test with reduction='sum', full=True
+        losses_sum_full = nn.losses.gaussian_nll_loss(
+            inputs, targets, vars, full=True, reduction="sum"
+        )
+        expected_sum_full = mx.sum(expected_none_full)
+        self.assertTrue(mx.allclose(losses_sum_full, expected_sum_full))
+
     def test_kl_div_loss(self):
         p_logits = mx.log(mx.array([[0.5, 0.5], [0.8, 0.2]]))
         q_logits = mx.log(mx.array([[0.5, 0.5], [0.2, 0.8]]))
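For anyone who wants to sanity-check the new loss outside the test suite, here is a minimal usage sketch (not part of the patch) that assumes the diff above is applied and `mlx` is installed; it reproduces the first expected value in test_gaussian_nll_loss directly from the docstring formula.

    # Usage sketch for gaussian_nll_loss (not part of the patch).
    import math

    import mlx.core as mx
    import mlx.nn as nn

    inputs = mx.array([[0.1, 0.2], [0.3, 0.4]])   # predicted means
    targets = mx.array([[0.2, 0.1], [0.1, 0.2]])  # observed samples
    vars = mx.array([[0.1, 0.2], [0.3, 0.4]])     # predicted variances

    # Elementwise loss without the constant term:
    # 0.5 * (log(var) + (input - target)^2 / var)
    loss = nn.losses.gaussian_nll_loss(inputs, targets, vars, reduction="none")
    print(loss)  # first element ≈ 0.5 * (math.log(0.1) + (0.1 - 0.2) ** 2 / 0.1) ≈ -1.101293

    # With full=True the constant 0.5 * log(2 * pi) ≈ 0.918939 is added elementwise,
    # shifting the first element to ≈ -0.182354, matching expected_none_full in the test.
    loss_full = nn.losses.gaussian_nll_loss(
        inputs, targets, vars, full=True, reduction="none"
    )

Note that gaussian_nll_loss defaults to reduction="mean" (like l1_loss and mse_loss), so reduction="none" must be passed explicitly to get the per-element values shown above.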