Feat: Add weights argument in BCE Loss and tests (#620)

Aryan Gupta authored on 2024-02-07 23:09:52 +05:30, committed by GitHub
parent ea406d5e33
commit ef73393a19
2 changed files with 19 additions and 0 deletions


@@ -117,6 +117,7 @@ def cross_entropy(
 def binary_cross_entropy(
     inputs: mx.array,
     targets: mx.array,
+    weights: mx.array = None,
     with_logits: bool = True,
     reduction: Reduction = "mean",
 ) -> mx.array:
@@ -128,6 +129,7 @@ def binary_cross_entropy(
             ``inputs`` are unnormalized logits. Otherwise, ``inputs`` are probabilities.
         targets (array): The binary target values in {0, 1}.
         with_logits (bool, optional): Whether ``inputs`` are logits. Default: ``True``.
+        weights (array, optional): Optional weights for each target. Default: ``None``.
         reduction (str, optional): Specifies the reduction to apply to the output:
             ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``.
@@ -159,6 +161,15 @@ def binary_cross_entropy(
     else:
         loss = -(targets * mx.log(inputs) + (1 - targets) * mx.log(1 - inputs))
 
+    # Apply weights if provided
+    if weights is not None:
+        if weights.shape != loss.shape:
+            raise ValueError(
+                f"Weights with shape {weights.shape} is not the same as "
+                f"output loss with shape {loss.shape}."
+            )
+        loss *= weights
+
     return _reduce(loss, reduction)
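
For context, a minimal usage sketch of the new weights argument as added above; the logits, targets, and weight values below are illustrative only and not taken from the commit:

    import mlx.core as mx
    import mlx.nn as nn

    # Illustrative inputs; weights must have the same shape as the per-element loss.
    logits = mx.array([0.1, -0.4, 1.2, -2.0])
    targets = mx.array([0, 0, 1, 1])
    weights = mx.array([1.0, 2.0, 1.0, 2.0])

    # Per-element BCE scaled by weights; a mismatched weights shape raises ValueError.
    per_element = nn.losses.binary_cross_entropy(
        logits, targets, weights=weights, reduction="none"
    )

    # With the default reduction="mean", the weighted per-element losses are averaged.
    mean_loss = nn.losses.binary_cross_entropy(logits, targets, weights=weights)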


@@ -92,6 +92,14 @@ class TestLosses(mlx_tests.MLXTestCase):
             expected_sum = mx.sum(expected_none)
             self.assertEqual(losses_sum, expected_sum)
 
+            # With weights, no label smoothing
+            weights = mx.array([1.0, 2.0, 1.0, 2.0])
+            expected = mx.array([0.747215, 1.62186, 0.262365, 0.672944])
+            loss = nn.losses.binary_cross_entropy(
+                logits, targets, weights=weights, reduction="none"
+            )
+            self.assertTrue(mx.allclose(loss, expected))
+
         def _test_probs_as_inputs():
             probs = mx.array([0.5, 0.6, 0.7, 0.8])
             targets = mx.array([0, 0, 1, 1])
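
The expected values in the new test follow directly from the implementation change above: the weighted loss is just the unweighted per-element loss multiplied by weights. A small sketch of that relationship, with illustrative inputs rather than the test's values:

    import mlx.core as mx
    import mlx.nn as nn

    logits = mx.array([0.1, -0.4, 1.2, -2.0])   # illustrative, not the test's logits
    targets = mx.array([0, 0, 1, 1])
    weights = mx.array([1.0, 2.0, 1.0, 2.0])

    unweighted = nn.losses.binary_cross_entropy(logits, targets, reduction="none")
    weighted = nn.losses.binary_cross_entropy(
        logits, targets, weights=weights, reduction="none"
    )
    # The weighted loss equals the unweighted loss scaled element-wise by weights.
    assert mx.allclose(weighted, unweighted * weights)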