Feat: Add weights argument in BCE Loss and tests (#620)

This commit is contained in:
Aryan Gupta
2024-02-07 23:09:52 +05:30
committed by GitHub
parent ea406d5e33
commit ef73393a19
2 changed files with 19 additions and 0 deletions

View File

@@ -92,6 +92,14 @@ class TestLosses(mlx_tests.MLXTestCase):
expected_sum = mx.sum(expected_none)
self.assertEqual(losses_sum, expected_sum)
# With weights, no label smoothing
weights = mx.array([1.0, 2.0, 1.0, 2.0])
expected = mx.array([0.747215, 1.62186, 0.262365, 0.672944])
loss = nn.losses.binary_cross_entropy(
logits, targets, weights=weights, reduction="none"
)
self.assertTrue(mx.allclose(loss, expected))
def _test_probs_as_inputs():
probs = mx.array([0.5, 0.6, 0.7, 0.8])
targets = mx.array([0, 0, 1, 1])