From 641d316484115bea9b58047fb9f15ce7749abd47 Mon Sep 17 00:00:00 2001
From: Kai Ma <77073162+k78ma@users.noreply.github.com>
Date: Fri, 8 Dec 2023 23:21:37 -0500
Subject: [PATCH] MLE and L1 loss functions (#88)

* MLE and L1 loss functions

* logsoftmax change and tests

* subtract max logit for numerical stability

* l1 name change

* cross entropy reduction + unit tests

* docstrings

* l1 test name change

* old loss impl + default none
---
 python/mlx/nn/losses.py | 42 +++++++++++++++++++++++++++++++++++++++--
 python/tests/test_nn.py | 24 ++++++++++++++++++++---
 2 files changed, 61 insertions(+), 5 deletions(-)

diff --git a/python/mlx/nn/losses.py b/python/mlx/nn/losses.py
index dfa806c34..0a9367ec4 100644
--- a/python/mlx/nn/losses.py
+++ b/python/mlx/nn/losses.py
@@ -2,7 +2,45 @@
 
 import mlx.core as mx
 
 
-def cross_entropy(logits: mx.array, targets: mx.array, axis: int = -1):
+def cross_entropy(logits: mx.array, targets: mx.array, axis: int = -1, reduction: str = 'none'):
+    """
+    Computes the cross entropy loss between logits and targets.
+
+    Args:
+        logits (mx.array): The predicted logits.
+        targets (mx.array): The target values.
+        axis (int, optional): The axis over which to compute softmax. Defaults to -1.
+        reduction (str, optional): Specifies the reduction to apply to the output:
+            'none' | 'mean' | 'sum'. 'none': no reduction will be applied.
+            'mean': the sum of the output will be divided by the number of elements in the output.
+            'sum': the output will be summed. Defaults to 'none'.
+
+    Returns:
+        mx.array: The computed cross entropy loss.
+    """
     score = mx.take_along_axis(logits, targets[..., None], axis).squeeze(-1)
-    return mx.logsumexp(logits, axis=axis) - score
+    loss = mx.logsumexp(logits, axis=axis) - score
+
+    if reduction == 'mean':
+        return mx.mean(loss)
+    elif reduction == 'sum':
+        return mx.sum(loss)
+    elif reduction == 'none':
+        return loss
+    else:
+        raise ValueError("Invalid reduction. Must be 'none', 'mean', or 'sum'.")
+
+def l1_loss(predictions: mx.array, targets: mx.array):
+    """
+    Computes the L1 loss between predictions and targets.
+
+    Args:
+        predictions (mx.array): The predicted values.
+        targets (mx.array): The target values.
+
+    Returns:
+        mx.array: The computed L1 loss.
+    """
+    return mx.mean(mx.abs(predictions - targets))
+
diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py
index e8bae0588..3c1fbdc6d 100644
--- a/python/tests/test_nn.py
+++ b/python/tests/test_nn.py
@@ -10,7 +10,6 @@
 import mlx_tests
 import numpy as np
 from mlx.utils import tree_flatten, tree_map, tree_unflatten
-
 class TestNN(mlx_tests.MLXTestCase):
     def test_linear(self):
         inputs = mx.zeros((10, 4))
@@ -21,8 +20,27 @@ class TestNN(mlx_tests.MLXTestCase):
     def test_cross_entropy(self):
         logits = mx.array([[0.0, -float("inf")], [-float("inf"), 0.0]])
         targets = mx.array([0, 1])
-        losses = nn.losses.cross_entropy(logits, targets)
-        self.assertTrue(mx.array_equal(losses, mx.zeros((2,))))
+
+        # Test with reduction 'none'
+        losses_none = nn.losses.cross_entropy(logits, targets, reduction='none')
+        expected_none = mx.array([0.0, 0.0])
+        self.assertTrue(mx.array_equal(losses_none, expected_none))
+
+        # Test with reduction 'mean'
+        losses_mean = nn.losses.cross_entropy(logits, targets, reduction='mean')
+        expected_mean = mx.mean(expected_none)
+        self.assertEqual(losses_mean, expected_mean)
+
+        # Test with reduction 'sum'
+        losses_sum = nn.losses.cross_entropy(logits, targets, reduction='sum')
+        expected_sum = mx.sum(expected_none)
+        self.assertEqual(losses_sum, expected_sum)
+
+    def test_l1_loss(self):
+        predictions = mx.array([0.5, 0.2, 0.9, 0.0])
+        targets = mx.array([0.5, 0.2, 0.9, 0.0])
+        losses = nn.losses.l1_loss(predictions, targets)
+        self.assertEqual(losses, 0.0)
 
     def test_gelu(self):
         inputs = [1.15286231, -0.81037411, 0.35816911, 0.77484438, 0.66276414]