From 0b283996386bfe1bbe6c644bab8d3e3251f71ef0 Mon Sep 17 00:00:00 2001
From: Enoch Kan
Date: Sat, 9 Dec 2023 22:25:03 +0000
Subject: [PATCH] added mse_loss, nll_loss and kl_div_loss (#98)

* added mse_loss, nll_loss and kl_div_loss

* fixed axis not defined error in nll_loss

* fixed axis not defined in kl_div_loss

* added tests for mse, nll and kl_div

* modified docstrings and added reduce helper func

* updated docstring in kl_div_loss and moved helper func

* added new kl divergence implementation

* added reduction to test

* updated docstring of kl_div_loss with correct spelling

* added losses to nn.rst in docs
---
 docs/src/python/nn.rst  |  3 ++
 python/mlx/nn/losses.py | 95 ++++++++++++++++++++++++++++++++++-------
 python/tests/test_nn.py | 61 +++++++++++++++++++++++++-
 3 files changed, 142 insertions(+), 17 deletions(-)

diff --git a/docs/src/python/nn.rst b/docs/src/python/nn.rst
index a0aa0bfad..fe3924593 100644
--- a/docs/src/python/nn.rst
+++ b/docs/src/python/nn.rst
@@ -180,3 +180,6 @@ Loss Functions
 
    losses.cross_entropy
    losses.l1_loss
+   losses.mse_loss
+   losses.nll_loss
+   losses.kl_div_loss
diff --git a/python/mlx/nn/losses.py b/python/mlx/nn/losses.py
index 3445b686e..068d2db78 100644
--- a/python/mlx/nn/losses.py
+++ b/python/mlx/nn/losses.py
@@ -2,10 +2,9 @@
 
 import mlx.core as mx
 
-
 def cross_entropy(
     logits: mx.array, targets: mx.array, axis: int = -1, reduction: str = "none"
-):
+) -> mx.array:
     """
     Computes the cross entropy loss between logits and targets.
 
@@ -22,6 +21,84 @@
     score = mx.take_along_axis(logits, targets[..., None], axis).squeeze(-1)
     loss = mx.logsumexp(logits, axis=axis) - score
 
+    return _reduce(loss, reduction)
+
+
+def l1_loss(predictions: mx.array, targets: mx.array, reduction: str = "none") -> mx.array:
+    """
+    Computes the L1 loss between predictions and targets.
+
+    Args:
+        predictions (mx.array): The predicted values.
+        targets (mx.array): The target values.
+        reduction (str, optional): Specifies the reduction to apply to the output:
+            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
+
+    Returns:
+        mx.array: The computed L1 loss.
+    """
+    loss = mx.mean(mx.abs(predictions - targets))
+
+    return _reduce(loss, reduction)
+
+
+def mse_loss(predictions: mx.array, targets: mx.array, axis: int = -1, reduction: str = "none") -> mx.array:
+    """
+    Computes the mean squared error loss between predictions and targets.
+
+    Args:
+        predictions (mx.array): The predicted values.
+        targets (mx.array): The target values.
+        axis (int, optional): The axis over which the squared error is averaged. Default: ``-1``.
+        reduction (str, optional): Specifies the reduction to apply to the output:
+            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
+
+    Returns:
+        mx.array: The computed mean squared error loss.
+    """
+    loss = mx.mean(mx.square(predictions - targets), axis)
+
+    return _reduce(loss, reduction)
+
+
+def nll_loss(logits: mx.array, targets: mx.array, axis: int = -1, reduction: str = "none") -> mx.array:
+    """
+    Computes the negative log likelihood loss between logits and targets.
+
+    Args:
+        logits (mx.array): The predicted logits.
+        targets (mx.array): The target values.
+        axis (int, optional): The distribution axis. Default: ``-1``.
+        reduction (str, optional): Specifies the reduction to apply to the output:
+            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
+
+    Returns:
+        mx.array: The computed NLL loss.
+    """
+    loss = -mx.take_along_axis(logits, targets[..., None], axis).squeeze(-1)
+
+    return _reduce(loss, reduction)
+
+
+def kl_div_loss(logits: mx.array, targets: mx.array, axis: int = -1, reduction: str = "none") -> mx.array:
+    """
+    Computes the Kullback-Leibler divergence loss between logits and targets.
+
+    Args:
+        logits (mx.array): Log probabilities for the distribution p.
+        targets (mx.array): Log probabilities for the distribution q.
+        axis (int, optional): The distribution axis. Default: ``-1``.
+        reduction (str, optional): Specifies the reduction to apply to the output:
+            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
+
+    Returns:
+        mx.array: The computed Kullback-Leibler divergence loss.
+    """
+    loss = mx.sum(mx.exp(targets) * (targets - logits), axis)
+
+    return _reduce(loss, reduction)
+
+def _reduce(loss: mx.array, reduction: str = 'none'):
     if reduction == "mean":
         return mx.mean(loss)
     elif reduction == "sum":
@@ -30,17 +107,3 @@
         return loss
     else:
         raise ValueError("Invalid reduction. Must be 'none', 'mean', or 'sum'.")
-
-
-def l1_loss(predictions: mx.array, targets: mx.array):
-    """
-    Computes the L1 loss between predictions and targets.
-
-    Args:
-        predictions (mx.array): The predicted values.
-        targets (mx.array): The target values.
-
-    Returns:
-        mx.array: The computed L1 loss.
-    """
-    return mx.mean(mx.abs(predictions - targets))
diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py
index ede96aa32..19ef2eddd 100644
--- a/python/tests/test_nn.py
+++ b/python/tests/test_nn.py
@@ -40,8 +40,67 @@ class TestNN(mlx_tests.MLXTestCase):
     def test_l1_loss(self):
         predictions = mx.array([0.5, 0.2, 0.9, 0.0])
         targets = mx.array([0.5, 0.2, 0.9, 0.0])
-        losses = nn.losses.l1_loss(predictions, targets)
+        losses = nn.losses.l1_loss(predictions, targets, reduction="none")
         self.assertEqual(losses, 0.0)
+
+    def test_mse_loss(self):
+        predictions = mx.array([[0.5, 0.25], [1.0, 0.0]])
+        targets = mx.array([[0.75, 0.25], [0.5, 0.5]])
+
+        # Test with reduction 'none'
+        losses_none = nn.losses.mse_loss(predictions, targets, reduction="none")
+        expected_none = mx.array([0.03125, 0.25])
+        self.assertTrue(mx.array_equal(losses_none, expected_none))
+
+        # Test with reduction 'mean'
+        losses_mean = nn.losses.mse_loss(predictions, targets, reduction="mean")
+        expected_mean = mx.mean(expected_none)
+        self.assertEqual(losses_mean, expected_mean)
+
+        # Test with reduction 'sum'
+        losses_sum = nn.losses.mse_loss(predictions, targets, reduction="sum")
+        expected_sum = mx.sum(expected_none)
+        self.assertEqual(losses_sum, expected_sum)
+
+
+    def test_nll_loss(self):
+        logits = mx.array([[0.0, -float("inf")], [-float("inf"), 0.0]])
+        targets = mx.array([0, 1])
+
+        # Test with reduction 'none'
+        losses_none = nn.losses.nll_loss(logits, targets, reduction="none")
+        expected_none = mx.array([0.0, 0.0])
+        self.assertTrue(mx.array_equal(losses_none, expected_none))
+
+        # Test with reduction 'mean'
+        losses_mean = nn.losses.nll_loss(logits, targets, reduction="mean")
+        expected_mean = mx.mean(expected_none)
+        self.assertEqual(losses_mean, expected_mean)
+
+        # Test with reduction 'sum'
+        losses_sum = nn.losses.nll_loss(logits, targets, reduction="sum")
+        expected_sum = mx.sum(expected_none)
+        self.assertEqual(losses_sum, expected_sum)
+
+
+    def test_kl_div_loss(self):
+        p_logits = mx.array([[1.0, 2.0], [0.5, 1.0]])
+        q_logits = mx.array([[0.8, 1.5], [0.4, 1.2]])
+
+        # Test with reduction 'none'
+        losses_none = nn.losses.kl_div_loss(p_logits, q_logits, reduction="none")
+        expected_none = mx.sum(mx.exp(q_logits) * (q_logits - p_logits), axis=-1)
+        self.assertTrue(mx.array_equal(losses_none, expected_none))
+
+        # Test with reduction 'mean'
+        losses_mean = nn.losses.kl_div_loss(p_logits, q_logits, reduction="mean")
+        expected_mean = mx.mean(expected_none)
+        self.assertEqual(losses_mean, expected_mean)
+
+        # Test with reduction 'sum'
+        losses_sum = nn.losses.kl_div_loss(p_logits, q_logits, reduction="sum")
+        expected_sum = mx.sum(expected_none)
+        self.assertEqual(losses_sum, expected_sum)
 
     def test_gelu(self):
         inputs = [1.15286231, -0.81037411, 0.35816911, 0.77484438, 0.66276414]