diff --git a/python/mlx/nn/losses.py b/python/mlx/nn/losses.py
index 0a9367ec4b..067dcd6ddd 100644
--- a/python/mlx/nn/losses.py
+++ b/python/mlx/nn/losses.py
@@ -2,7 +2,10 @@
 
 import mlx.core as mx
 
-def cross_entropy(logits: mx.array, targets: mx.array, axis: int = -1, reduction: str = 'none'):
+
+def cross_entropy(
+    logits: mx.array, targets: mx.array, axis: int = -1, reduction: str = "none"
+):
     """
     Computes the cross entropy loss between logits and targets.
 
@@ -10,10 +13,10 @@ def cross_entropy(logits: mx.array, targets: mx.array, axis: int = -1, reduction
         logits (mx.array): The predicted logits.
         targets (mx.array): The target values.
         axis (int, optional): The axis over which to compute softmax. Defaults to -1.
-        reduction (str, optional): Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'. 
-                                  'none': no reduction will be applied. 
+        reduction (str, optional): Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'.
+                                  'none': no reduction will be applied.
                                   'mean': the sum of the output will be divided by the number of elements in the output.
-                                  'sum': the output will be summed. 
+                                  'sum': the output will be summed.
                                   Defaults to 'none'.
 
     Returns:
@@ -22,15 +25,16 @@ def cross_entropy(logits: mx.array, targets: mx.array, axis: int = -1, reduction
     score = mx.take_along_axis(logits, targets[..., None], axis).squeeze(-1)
     loss = mx.logsumexp(logits, axis=axis) - score
 
-    if reduction == 'mean':
+    if reduction == "mean":
         return mx.mean(loss)
-    elif reduction == 'sum':
+    elif reduction == "sum":
         return mx.sum(loss)
-    elif reduction == 'none':
+    elif reduction == "none":
         return loss
     else:
         raise ValueError("Invalid reduction. Must be 'none', 'mean', or 'sum'.")
 
+
 def l1_loss(predictions: mx.array, targets: mx.array):
     """
     Computes the L1 loss between predictions and targets.
@@ -43,4 +47,3 @@ def l1_loss(predictions: mx.array, targets: mx.array):
         mx.array: The computed L1 loss.
     """
     return mx.mean(mx.abs(predictions - targets))
-
diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py
index 3c1fbdc6d8..ede96aa325 100644
--- a/python/tests/test_nn.py
+++ b/python/tests/test_nn.py
@@ -10,6 +10,7 @@ import mlx_tests
 import numpy as np
 from mlx.utils import tree_flatten, tree_map, tree_unflatten
 
+
 class TestNN(mlx_tests.MLXTestCase):
     def test_linear(self):
         inputs = mx.zeros((10, 4))
@@ -22,17 +23,17 @@ class TestNN(mlx_tests.MLXTestCase):
         targets = mx.array([0, 1])
 
         # Test with reduction 'none'
-        losses_none = nn.losses.cross_entropy(logits, targets, reduction='none')
+        losses_none = nn.losses.cross_entropy(logits, targets, reduction="none")
         expected_none = mx.array([0.0, 0.0])
         self.assertTrue(mx.array_equal(losses_none, expected_none))
 
         # Test with reduction 'mean'
-        losses_mean = nn.losses.cross_entropy(logits, targets, reduction='mean')
+        losses_mean = nn.losses.cross_entropy(logits, targets, reduction="mean")
         expected_mean = mx.mean(expected_none)
         self.assertEqual(losses_mean, expected_mean)
 
         # Test with reduction 'sum'
-        losses_sum = nn.losses.cross_entropy(logits, targets, reduction='sum')
+        losses_sum = nn.losses.cross_entropy(logits, targets, reduction="sum")
         expected_sum = mx.sum(expected_none)
         self.assertEqual(losses_sum, expected_sum)
 
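For reference, a small usage sketch of the two loss functions this diff reformats (not part of the change itself). It assumes an environment with the mlx package installed, and reuses the logits/targets from test_nn.py; the expected values in the comments follow from the docstrings above.

import mlx.core as mx
import mlx.nn as nn

# Two samples, two classes; each row puts all probability mass on its target class,
# so every per-sample cross entropy loss is exactly zero.
logits = mx.array([[0.0, -float("inf")], [-float("inf"), 0.0]])
targets = mx.array([0, 1])

print(nn.losses.cross_entropy(logits, targets, reduction="none"))  # per-sample losses: [0, 0]
print(nn.losses.cross_entropy(logits, targets, reduction="mean"))  # scalar: mean of the per-sample losses
print(nn.losses.cross_entropy(logits, targets, reduction="sum"))   # scalar: sum of the per-sample losses

# l1_loss returns the mean absolute error over all elements.
predictions = mx.array([0.5, 0.2, 0.9, 0.0])
labels = mx.array([0.5, 0.2, 0.5, 0.5])
print(nn.losses.l1_loss(predictions, labels))  # mean(|0| + |0| + |0.4| + |0.5|) / 4 ≈ 0.225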