mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Feat: Add weights argument in BCE Loss and tests (#620)
This commit is contained in:
parent
ea406d5e33
commit
ef73393a19
@ -117,6 +117,7 @@ def cross_entropy(
|
|||||||
def binary_cross_entropy(
|
def binary_cross_entropy(
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
targets: mx.array,
|
targets: mx.array,
|
||||||
|
weights: mx.array = None,
|
||||||
with_logits: bool = True,
|
with_logits: bool = True,
|
||||||
reduction: Reduction = "mean",
|
reduction: Reduction = "mean",
|
||||||
) -> mx.array:
|
) -> mx.array:
|
||||||
@ -128,6 +129,7 @@ def binary_cross_entropy(
|
|||||||
``inputs`` are unnormalized logits. Otherwise, ``inputs`` are probabilities.
|
``inputs`` are unnormalized logits. Otherwise, ``inputs`` are probabilities.
|
||||||
targets (array): The binary target values in {0, 1}.
|
targets (array): The binary target values in {0, 1}.
|
||||||
with_logits (bool, optional): Whether ``inputs`` are logits. Default: ``True``.
|
with_logits (bool, optional): Whether ``inputs`` are logits. Default: ``True``.
|
||||||
|
weights (array, optional): Optional weights for each target. Default: ``None``.
|
||||||
reduction (str, optional): Specifies the reduction to apply to the output:
|
reduction (str, optional): Specifies the reduction to apply to the output:
|
||||||
``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``.
|
``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``.
|
||||||
|
|
||||||
@ -159,6 +161,15 @@ def binary_cross_entropy(
|
|||||||
else:
|
else:
|
||||||
loss = -(targets * mx.log(inputs) + (1 - targets) * mx.log(1 - inputs))
|
loss = -(targets * mx.log(inputs) + (1 - targets) * mx.log(1 - inputs))
|
||||||
|
|
||||||
|
# Apply weights if provided
|
||||||
|
if weights is not None:
|
||||||
|
if weights.shape != loss.shape:
|
||||||
|
raise ValueError(
|
||||||
|
f"Weights with shape {weights.shape} is not the same as "
|
||||||
|
f"output loss with shape {loss.shape}."
|
||||||
|
)
|
||||||
|
loss *= weights
|
||||||
|
|
||||||
return _reduce(loss, reduction)
|
return _reduce(loss, reduction)
|
||||||
|
|
||||||
|
|
||||||
|
@ -92,6 +92,14 @@ class TestLosses(mlx_tests.MLXTestCase):
|
|||||||
expected_sum = mx.sum(expected_none)
|
expected_sum = mx.sum(expected_none)
|
||||||
self.assertEqual(losses_sum, expected_sum)
|
self.assertEqual(losses_sum, expected_sum)
|
||||||
|
|
||||||
|
# With weights, no label smoothing
|
||||||
|
weights = mx.array([1.0, 2.0, 1.0, 2.0])
|
||||||
|
expected = mx.array([0.747215, 1.62186, 0.262365, 0.672944])
|
||||||
|
loss = nn.losses.binary_cross_entropy(
|
||||||
|
logits, targets, weights=weights, reduction="none"
|
||||||
|
)
|
||||||
|
self.assertTrue(mx.allclose(loss, expected))
|
||||||
|
|
||||||
def _test_probs_as_inputs():
|
def _test_probs_as_inputs():
|
||||||
probs = mx.array([0.5, 0.6, 0.7, 0.8])
|
probs = mx.array([0.5, 0.6, 0.7, 0.8])
|
||||||
targets = mx.array([0, 0, 1, 1])
|
targets = mx.array([0, 0, 1, 1])
|
||||||
|
Loading…
Reference in New Issue
Block a user