From 2d0130f80fe20484a9e4367dde80891194ece820 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Sun, 10 Dec 2023 10:08:19 -0800
Subject: [PATCH] fix loss tests (#118)

* fix loss tests

* use none as default
---
 python/mlx/nn/losses.py | 34 ++++++++++++++++++++--------------
 python/tests/test_nn.py | 14 +++++++-------
 2 files changed, 27 insertions(+), 21 deletions(-)

diff --git a/python/mlx/nn/losses.py b/python/mlx/nn/losses.py
index 4e0d14ef7..c6ea53981 100644
--- a/python/mlx/nn/losses.py
+++ b/python/mlx/nn/losses.py
@@ -46,7 +46,7 @@ def l1_loss(
 
 
 def mse_loss(
-    predictions: mx.array, targets: mx.array, axis: int = -1, reduction: str = "none"
+    predictions: mx.array, targets: mx.array, reduction: str = "none"
 ) -> mx.array:
     """
     Computes the mean squared error loss between predictions and targets.
@@ -54,56 +54,62 @@ def mse_loss(
     Args:
         predictions (mx.array): The predicted values.
         targets (mx.array): The target values.
-        axis (int, optional): The axis over which to compute softmax. Default: ``-1``.
         reduction (str, optional): Specifies the reduction to apply to the output:
           ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
 
     Returns:
         mx.array: The computed mean squared error loss.
     """
-    loss = mx.mean(mx.square(predictions - targets), axis)
+    loss = mx.square(predictions - targets)
 
     return _reduce(loss, reduction)
 
 
 def nll_loss(
-    logits: mx.array, targets: mx.array, axis: int = -1, reduction: str = "none"
+    inputs: mx.array, targets: mx.array, axis: int = -1, reduction: str = "none"
 ) -> mx.array:
     """
-    Computes the negative log likelihood loss between logits and targets.
+    Computes the negative log likelihood loss between inputs and targets.
 
     Args:
-        logits (mx.array): The predicted logits.
+        inputs (mx.array): The predicted distribution in log space.
         targets (mx.array): The target values.
-        axis (int, optional): The axis over which to compute softmax. Default: ``-1``.
+        axis (int, optional): The distribution axis. Default: ``-1``.
         reduction (str, optional): Specifies the reduction to apply to the output:
           ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
 
     Returns:
         mx.array: The computed NLL loss.
     """
-    loss = -mx.take_along_axis(logits, targets[..., None], axis).squeeze(-1)
+    loss = -mx.take_along_axis(inputs, targets[..., None], axis).squeeze(-1)
 
     return _reduce(loss, reduction)
 
 
 def kl_div_loss(
-    logits: mx.array, targets: mx.array, axis: int = -1, reduction: str = "none"
+    inputs: mx.array, targets: mx.array, axis: int = -1, reduction: str = "none"
 ) -> mx.array:
     """
-    Computes the Kullback-Leibler divergence loss between logits and targets.
+    Computes the Kullback-Leibler divergence loss between targets and the
+    inputs.
+
+    Computes the following when ``reduction == 'none'``:
+
+    .. code-block:: python
+
+        mx.exp(targets) * (targets - inputs).sum(axis)
 
     Args:
-        logits (mx.array): Logits for the distribution p.
-        targets (mx.array): Log probabilities for the distribution q.
-        axis (int, optional): The axis over which to compute softmax. Default: ``-1``.
+        inputs (mx.array): Log probabilities for the predicted distribution.
+        targets (mx.array): Log probabilities for the target distribution.
+        axis (int, optional): The distribution axis. Default: ``-1``.
         reduction (str, optional): Specifies the reduction to apply to the output:
           ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
 
     Returns:
         mx.array: The computed Kullback-Leibler divergence loss.
""" - loss = mx.sum(mx.exp(targets) * (targets - logits), axis) + loss = mx.sum(mx.exp(targets) * (targets - inputs), axis) return _reduce(loss, reduction) diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py index f9bb6a200..e434c3ae8 100644 --- a/python/tests/test_nn.py +++ b/python/tests/test_nn.py @@ -50,7 +50,7 @@ class TestNN(mlx_tests.MLXTestCase): # Test with reduction 'none' losses_none = nn.losses.mse_loss(predictions, targets, reduction="none") expected_none = mx.array([0.04, 0.01, 0.01, 0.04]) - self.assertTrue(mx.array_equal(losses_none, expected_none)) + self.assertTrue(mx.allclose(losses_none, expected_none)) # Test with reduction 'mean' losses_mean = nn.losses.mse_loss(predictions, targets, reduction="mean") @@ -82,23 +82,23 @@ class TestNN(mlx_tests.MLXTestCase): self.assertEqual(losses_sum, expected_sum) def test_kl_div_loss(self): - p_logits = mx.array([[1.0, 2.0], [0.5, 1.0]]) - q_logits = mx.array([[0.8, 1.5], [0.4, 1.2]]) + p_logits = mx.log(mx.array([[0.5, 0.5], [0.8, 0.2]])) + q_logits = mx.log(mx.array([[0.5, 0.5], [0.2, 0.8]])) # Test with reduction 'none' losses_none = nn.losses.kl_div_loss(p_logits, q_logits, reduction="none") - expected_none = mx.array([0.22314353, 0.09966799]) - self.assertTrue(mx.array_equal(losses_none, expected_none)) + expected_none = mx.array([0.0, 0.831777]) + self.assertTrue(mx.allclose(losses_none, expected_none)) # Test with reduction 'mean' losses_mean = nn.losses.kl_div_loss(p_logits, q_logits, reduction="mean") expected_mean = mx.mean(expected_none) - self.assertEqual(losses_mean, expected_mean) + self.assertTrue(mx.allclose(losses_mean, expected_mean)) # Test with reduction 'sum' losses_sum = nn.losses.kl_div_loss(p_logits, q_logits, reduction="sum") expected_sum = mx.sum(expected_none) - self.assertEqual(losses_sum, expected_sum) + self.assertTrue(mx.allclose(losses_sum, expected_sum)) def test_gelu(self): inputs = [1.15286231, -0.81037411, 0.35816911, 0.77484438, 0.66276414]