From ef73393a19bfc0f005f585656cbd3fa0ab865ec4 Mon Sep 17 00:00:00 2001
From: Aryan Gupta <97878444+guptaaryan16@users.noreply.github.com>
Date: Wed, 7 Feb 2024 23:09:52 +0530
Subject: [PATCH] Feat: Add weights argument in BCE Loss and tests (#620)

---
 python/mlx/nn/losses.py     | 11 +++++++++++
 python/tests/test_losses.py |  8 ++++++++
 2 files changed, 19 insertions(+)

diff --git a/python/mlx/nn/losses.py b/python/mlx/nn/losses.py
index a466c10ed..ee33fde3e 100644
--- a/python/mlx/nn/losses.py
+++ b/python/mlx/nn/losses.py
@@ -117,6 +117,7 @@ def cross_entropy(
 def binary_cross_entropy(
     inputs: mx.array,
     targets: mx.array,
+    weights: mx.array = None,
     with_logits: bool = True,
     reduction: Reduction = "mean",
 ) -> mx.array:
@@ -128,6 +129,7 @@ def binary_cross_entropy(
             ``inputs`` are unnormalized logits. Otherwise, ``inputs`` are probabilities.
         targets (array): The binary target values in {0, 1}.
         with_logits (bool, optional): Whether ``inputs`` are logits. Default: ``True``.
+        weights (array, optional): Weights for each target. Default: ``None``.
         reduction (str, optional): Specifies the reduction to apply to the output:
             ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``.
 
@@ -159,6 +161,15 @@ def binary_cross_entropy(
     else:
         loss = -(targets * mx.log(inputs) + (1 - targets) * mx.log(1 - inputs))
 
+    # Apply weights if provided
+    if weights is not None:
+        if weights.shape != loss.shape:
+            raise ValueError(
+                f"Weights with shape {weights.shape} do not match "
+                f"the loss with shape {loss.shape}."
+            )
+        loss *= weights
+
     return _reduce(loss, reduction)
 
 
diff --git a/python/tests/test_losses.py b/python/tests/test_losses.py
index 2160b0a6e..3a430be21 100644
--- a/python/tests/test_losses.py
+++ b/python/tests/test_losses.py
@@ -92,6 +92,14 @@ class TestLosses(mlx_tests.MLXTestCase):
             expected_sum = mx.sum(expected_none)
             self.assertEqual(losses_sum, expected_sum)
 
+            # With weights, no label smoothing
+            weights = mx.array([1.0, 2.0, 1.0, 2.0])
+            expected = mx.array([0.747215, 1.62186, 0.262365, 0.672944])
+            loss = nn.losses.binary_cross_entropy(
+                logits, targets, weights=weights, reduction="none"
+            )
+            self.assertTrue(mx.allclose(loss, expected))
+
         def _test_probs_as_inputs():
             probs = mx.array([0.5, 0.6, 0.7, 0.8])
             targets = mx.array([0, 0, 1, 1])
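
For reviewers, a short usage sketch of the new ``weights`` argument as a standalone
script. The array values here are illustrative, not taken from the test suite:

import mlx.core as mx
import mlx.nn as nn

# Illustrative logits and binary targets (values assumed for this sketch).
logits = mx.array([0.1, 0.2, 1.2, 0.9])
targets = mx.array([0, 0, 1, 1])

# One weight per example; the shape must match the unreduced loss,
# otherwise the new shape check raises a ValueError.
weights = mx.array([1.0, 2.0, 1.0, 2.0])

# With reduction="none", each element of the loss is scaled by its weight.
per_example = nn.losses.binary_cross_entropy(
    logits, targets, weights=weights, reduction="none"
)

# With the default reduction="mean", the weighted losses are averaged.
mean_loss = nn.losses.binary_cross_entropy(logits, targets, weights=weights)

# A mismatched weights shape is rejected.
try:
    nn.losses.binary_cross_entropy(logits, targets, weights=mx.array([1.0]))
except ValueError as e:
    print(e)

Note that because the weighting is applied before ``_reduce``, ``reduction="mean"``
divides by the element count rather than by the sum of the weights.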