From 550d4bf7c01b842677a5daaecf22fb6913c46c3d Mon Sep 17 00:00:00 2001
From: AtomicVar
Date: Fri, 19 Jan 2024 11:22:23 +0800
Subject: [PATCH] Update binary_cross_entropy function to handle both logits and probabilities (#492)

---
 python/mlx/nn/losses.py     | 35 ++++++++++++++++++-----
 python/tests/test_losses.py | 55 +++++++++++++++++++++++++++++++++++++
 2 files changed, 83 insertions(+), 7 deletions(-)

diff --git a/python/mlx/nn/losses.py b/python/mlx/nn/losses.py
index 0299e0e38..0f0021050 100644
--- a/python/mlx/nn/losses.py
+++ b/python/mlx/nn/losses.py
@@ -74,29 +74,50 @@ def cross_entropy(
 
 
 def binary_cross_entropy(
-    logits: mx.array, targets: mx.array, reduction: Reduction = "none"
+    inputs: mx.array,
+    targets: mx.array,
+    with_logits: bool = True,
+    reduction: Reduction = "mean",
 ) -> mx.array:
     """
     Computes the binary cross entropy loss.
 
     Args:
-        logits (array): The unnormalized (pre-sigmoid) predicted logits.
+        inputs (array): The predicted values. If ``with_logits`` is ``True``, then
+            ``inputs`` are unnormalized logits. Otherwise, ``inputs`` are probabilities.
         targets (array): The binary target values in {0, 1}.
+        with_logits (bool, optional): Whether ``inputs`` are logits. Default: ``True``.
         reduction (str, optional): Specifies the reduction to apply to the output:
-          ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
+          ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``.
 
     Returns:
         array: The computed binary cross entropy loss.
 
     Examples:
         >>> import mlx.core as mx
         >>> import mlx.nn as nn
-        >>> inputs = mx.array([0.105361, 0.223144, 1.20397, 0.916291])
+
+        >>> logits = mx.array([0.105361, 0.223144, 1.20397, 0.916291])
         >>> targets = mx.array([0, 0, 1, 1])
-        >>> loss = nn.losses.binary_cross_entropy(inputs, targets, "mean")
+        >>> loss = nn.losses.binary_cross_entropy(logits, targets, reduction="mean")
         >>> loss
-        array([0.612192], dtype=float32)
+        array(0.539245, dtype=float32)
+
+        >>> probs = mx.array([0.1, 0.1, 0.4, 0.4])
+        >>> targets = mx.array([0, 0, 1, 1])
+        >>> loss = nn.losses.binary_cross_entropy(probs, targets, with_logits=False, reduction="mean")
+        >>> loss
+        array(0.510826, dtype=float32)
     """
-    loss = mx.logaddexp(0.0, logits) - targets * logits
+    if inputs.shape != targets.shape:
+        raise ValueError(
+            f"Inputs shape {inputs.shape} does not match targets shape {targets.shape}."
+        )
+
+    if with_logits:
+        loss = mx.logaddexp(0.0, inputs) - inputs * targets
+    else:
+        loss = -(targets * mx.log(inputs) + (1 - targets) * mx.log(1 - inputs))
+
     return _reduce(loss, reduction)
diff --git a/python/tests/test_losses.py b/python/tests/test_losses.py
index 2db0ebb58..a23f26454 100644
--- a/python/tests/test_losses.py
+++ b/python/tests/test_losses.py
@@ -105,6 +105,61 @@ class TestLosses(mlx_tests.MLXTestCase):
             "Test case failed for cross_entropy --label_smoothing=0.3 --reduction='sum'",
         )
 
+    def test_binary_cross_entropy(self):
+        def _test_logits_as_inputs():
+            logits = mx.array([0.105361, 0.223144, 1.20397, 0.916291])
+            targets = mx.array([0, 0, 1, 1])
+
+            # Test with reduction 'none'
+            losses_none = nn.losses.binary_cross_entropy(
+                logits, targets, reduction="none"
+            )
+            expected_none = mx.array([0.747215, 0.810930, 0.262365, 0.336472])
+            self.assertTrue(mx.allclose(losses_none, expected_none))
+
+            # Test with reduction 'mean'
+            losses_mean = nn.losses.binary_cross_entropy(
+                logits, targets, reduction="mean"
+            )
+            expected_mean = mx.mean(expected_none)
+            self.assertEqual(losses_mean, expected_mean)
+
+            # Test with reduction 'sum'
+            losses_sum = nn.losses.binary_cross_entropy(
+                logits, targets, reduction="sum"
+            )
+            expected_sum = mx.sum(expected_none)
+            self.assertEqual(losses_sum, expected_sum)
+
+        def _test_probs_as_inputs():
+            probs = mx.array([0.5, 0.6, 0.7, 0.8])
+            targets = mx.array([0, 0, 1, 1])
+
+            # Test with reduction 'none'
+            losses_none = nn.losses.binary_cross_entropy(
+                probs, targets, with_logits=False, reduction="none"
+            )
+            expected_none = mx.array([0.693147, 0.916291, 0.356675, 0.223144])
+            print(losses_none, expected_none)
+            self.assertTrue(mx.allclose(losses_none, expected_none))
+
+            # Test with reduction 'mean'
+            losses_mean = nn.losses.binary_cross_entropy(
+                probs, targets, with_logits=False, reduction="mean"
+            )
+            expected_mean = mx.mean(expected_none)
+            self.assertTrue(mx.allclose(losses_mean, expected_mean))
+
+            # Test with reduction 'sum'
+            losses_sum = nn.losses.binary_cross_entropy(
+                probs, targets, with_logits=False, reduction="sum"
+            )
+            expected_sum = mx.sum(expected_none)
+            self.assertTrue(mx.allclose(losses_sum, expected_sum))
+
+        _test_logits_as_inputs()
+        _test_probs_as_inputs()
+
     def test_l1_loss(self):
         predictions = mx.array([0.5, 0.2, 0.9, 0.0])
         targets = mx.array([0.5, 0.2, 0.9, 0.0])