Update binary_cross_entropy function to handle both logits and probabilities (#492)

This commit is contained in:
AtomicVar
2024-01-19 11:22:23 +08:00
committed by GitHub
parent f6e911ced0
commit 550d4bf7c0
2 changed files with 83 additions and 7 deletions

View File

@@ -105,6 +105,61 @@ class TestLosses(mlx_tests.MLXTestCase):
"Test case failed for cross_entropy --label_smoothing=0.3 --reduction='sum'",
)
def test_binary_cross_entropy(self):
def _test_logits_as_inputs():
logits = mx.array([0.105361, 0.223144, 1.20397, 0.916291])
targets = mx.array([0, 0, 1, 1])
# Test with reduction 'none'
losses_none = nn.losses.binary_cross_entropy(
logits, targets, reduction="none"
)
expected_none = mx.array([0.747215, 0.810930, 0.262365, 0.336472])
self.assertTrue(mx.allclose(losses_none, expected_none))
# Test with reduction 'mean'
losses_mean = nn.losses.binary_cross_entropy(
logits, targets, reduction="mean"
)
expected_mean = mx.mean(expected_none)
self.assertEqual(losses_mean, expected_mean)
# Test with reduction 'sum'
losses_sum = nn.losses.binary_cross_entropy(
logits, targets, reduction="sum"
)
expected_sum = mx.sum(expected_none)
self.assertEqual(losses_sum, expected_sum)
def _test_probs_as_inputs():
probs = mx.array([0.5, 0.6, 0.7, 0.8])
targets = mx.array([0, 0, 1, 1])
# Test with reduction 'none'
losses_none = nn.losses.binary_cross_entropy(
probs, targets, with_logits=False, reduction="none"
)
expected_none = mx.array([0.693147, 0.916291, 0.356675, 0.223144])
print(losses_none, expected_none)
self.assertTrue(mx.allclose(losses_none, expected_none))
# Test with reduction 'mean'
losses_mean = nn.losses.binary_cross_entropy(
probs, targets, with_logits=False, reduction="mean"
)
expected_mean = mx.mean(expected_none)
self.assertTrue(mx.allclose(losses_mean, expected_mean))
# Test with reduction 'sum'
losses_sum = nn.losses.binary_cross_entropy(
probs, targets, with_logits=False, reduction="sum"
)
expected_sum = mx.sum(expected_none)
self.assertTrue(mx.allclose(losses_sum, expected_sum))
_test_logits_as_inputs()
_test_probs_as_inputs()
def test_l1_loss(self):
predictions = mx.array([0.5, 0.2, 0.9, 0.0])
targets = mx.array([0.5, 0.2, 0.9, 0.0])