From f4ddd7dc44aef7eb4d41b7f372d985e513ff808f Mon Sep 17 00:00:00 2001
From: __mo_san__ <50895527+m0saan@users.noreply.github.com>
Date: Mon, 11 Dec 2023 16:55:18 +0100
Subject: [PATCH] Add Binary Cross Entropy loss (#122)

* update BCE added tests for it ...

* added binary cross entropy loss to docs

* resolving conflicts for merge
---
 docs/src/python/nn.rst  |  1 +
 python/mlx/nn/losses.py | 83 ++++++++++++++++++++++++++++++++++++++++-
 python/tests/test_nn.py | 19 ++++++++++
 3 files changed, 102 insertions(+), 1 deletion(-)

diff --git a/docs/src/python/nn.rst b/docs/src/python/nn.rst
index fe3924593..15497206a 100644
--- a/docs/src/python/nn.rst
+++ b/docs/src/python/nn.rst
@@ -179,6 +179,7 @@ Loss Functions
    :template: nn-module-template.rst
 
    losses.cross_entropy
+   losses.binary_cross_entropy
    losses.l1_loss
    losses.mse_loss
    losses.nll_loss
diff --git a/python/mlx/nn/losses.py b/python/mlx/nn/losses.py
index c6ea53981..c5795574c 100644
--- a/python/mlx/nn/losses.py
+++ b/python/mlx/nn/losses.py
@@ -1,6 +1,17 @@
 # Copyright © 2023 Apple Inc.
 
 import mlx.core as mx
+from mlx.nn.layers.base import Module
+
+
+def _make_loss_module(f):
+    def decorator(klass):
+        klass.__call__ = lambda self, inputs, targets: f(
+            inputs, targets, self.reduction
+        )
+        return klass
+
+    return decorator
 
 
 def cross_entropy(
@@ -25,6 +36,76 @@ def cross_entropy(
     return _reduce(loss, reduction)
 
 
+def binary_cross_entropy(
+    inputs: mx.array, targets: mx.array, reduction: str = "none"
+) -> mx.array:
+    """
+    Computes the binary cross entropy loss between inputs and targets.
+
+    Args:
+        inputs (mx.array): The predicted inputs (post-sigmoid probabilities).
+        targets (mx.array): The target values (binary labels).
+        reduction (str, optional): Specifies the reduction to apply to the output:
+            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
+
+    Returns:
+        mx.array: The computed binary cross entropy loss.
+    Examples:
+        >>> import mlx.core as mx
+        >>> import mlx.nn as nn
+        >>> inputs = mx.array([0.1, 0.2, 0.3, 0.4])
+        >>> targets = mx.array([0, 0, 1, 1])
+        >>> loss = nn.losses.binary_cross_entropy(inputs, targets, reduction="mean")
+        >>> loss
+        array(0.612192, dtype=float32)
+    """
+    loss = -targets * mx.log(inputs) - (1 - targets) * mx.log(1 - inputs)
+    return _reduce(loss, reduction)
+
+
+@_make_loss_module(binary_cross_entropy)
+class BCELoss(Module):
+    """
+    Binary Cross Entropy Loss module.
+    It computes the binary cross entropy loss between predicted probabilities (post-sigmoid inputs) and target binary labels.
+
+    Args:
+        reduction (str, optional): Specifies the reduction to apply to the output:
+            - 'none': no reduction (default)
+            - 'mean': compute the mean loss
+            - 'sum': compute the sum of the loss
+
+    Examples:
+        >>> import mlx.core as mx
+        >>> from mlx.nn.losses import BCELoss
+        >>>
+        >>> # Create BCELoss module with default reduction ('none')
+        >>> loss_module_none = BCELoss()
+        >>> inputs = mx.array([0.5, 0.7, 0.3])
+        >>> targets = mx.array([1, 0, 1])
+        >>> loss_none = loss_module_none(inputs, targets)
+        >>> print(loss_none)
+        array([0.693147, 1.20397, 1.20397], dtype=float32)
+
+        >>> # Create BCELoss module with reduction 'mean'
+        >>> loss_module_mean = BCELoss(reduction='mean')
+        >>> loss_mean = loss_module_mean(inputs, targets)
+        >>> print(loss_mean)
+        array(1.0337, dtype=float32)
+
+        >>> # Create BCELoss module with reduction 'sum'
+        >>> loss_module_sum = BCELoss(reduction='sum')
+        >>> loss_sum = loss_module_sum(inputs, targets)
+        >>> print(loss_sum)
+        array(3.10109, dtype=float32)
+    """
+
+    def __init__(self, reduction: str = "none"):
+        super().__init__()
+
+        self.reduction = reduction
+
+
 def l1_loss(
     predictions: mx.array, targets: mx.array, reduction: str = "none"
 ) -> mx.array:
@@ -122,4 +203,4 @@ def _reduce(loss: mx.array, reduction: str = "none"):
     elif reduction == "none":
         return loss
     else:
-        raise ValueError("Invalid reduction. Must be 'none', 'mean', or 'sum'.")
+        raise ValueError("Invalid reduction. Must be 'none', 'mean', or 'sum'.")
\ No newline at end of file
diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py
index 707015814..54ef9b32c 100644
--- a/python/tests/test_nn.py
+++ b/python/tests/test_nn.py
@@ -100,6 +100,25 @@ class TestNN(mlx_tests.MLXTestCase):
             expected_sum = mx.sum(expected_none)
             self.assertTrue(mx.allclose(losses_sum, expected_sum))
 
+    def test_binary_cross_entropy(self):
+        inputs = mx.array([[0.5, 0.5], [0.5, 0.5]])
+        targets = mx.array([[0.0, 1.0], [1.0, 0.0]])
+
+        # Test with reduction 'none'
+        losses_none = nn.losses.binary_cross_entropy(inputs, targets, reduction="none")
+        expected_none = mx.array([[0.693147, 0.693147], [0.693147, 0.693147]])
+        self.assertTrue(mx.allclose(losses_none, expected_none))
+
+        # Test with reduction 'mean'
+        losses_mean = nn.losses.binary_cross_entropy(inputs, targets, reduction="mean")
+        expected_mean = mx.mean(expected_none)
+        self.assertTrue(mx.allclose(losses_mean, expected_mean))
+
+        # Test with reduction 'sum'
+        losses_sum = nn.losses.binary_cross_entropy(inputs, targets, reduction="sum")
+        expected_sum = mx.sum(expected_none)
+        self.assertTrue(mx.allclose(losses_sum, expected_sum))
+
     def test_gelu(self):
         inputs = [1.15286231, -0.81037411, 0.35816911, 0.77484438, 0.66276414]
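
For reference, a minimal usage sketch of the API this patch adds, mirroring the docstring examples above; it assumes a build of this branch is installed so that mlx.nn.losses exposes binary_cross_entropy and BCELoss.

    # Sketch only: values and names follow the docstring examples in this patch.
    import mlx.core as mx
    import mlx.nn as nn
    from mlx.nn.losses import BCELoss

    inputs = mx.array([0.5, 0.7, 0.3])   # post-sigmoid probabilities
    targets = mx.array([1, 0, 1])        # binary labels

    # Functional form: element-wise losses, then the reduced variants.
    print(nn.losses.binary_cross_entropy(inputs, targets))                    # per-element losses
    print(nn.losses.binary_cross_entropy(inputs, targets, reduction="mean"))  # scalar mean
    print(nn.losses.binary_cross_entropy(inputs, targets, reduction="sum"))   # scalar sum

    # Module form: BCELoss stores the reduction and forwards calls to the
    # function through the _make_loss_module decorator.
    bce = BCELoss(reduction="mean")
    print(bce(inputs, targets))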