mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 00:04:41 +08:00
Feat: Add weights argument in BCE Loss and tests (#620)
This commit is contained in:
@@ -92,6 +92,14 @@ class TestLosses(mlx_tests.MLXTestCase):
|
||||
expected_sum = mx.sum(expected_none)
|
||||
self.assertEqual(losses_sum, expected_sum)
|
||||
|
||||
# With weights, no label smoothing
|
||||
weights = mx.array([1.0, 2.0, 1.0, 2.0])
|
||||
expected = mx.array([0.747215, 1.62186, 0.262365, 0.672944])
|
||||
loss = nn.losses.binary_cross_entropy(
|
||||
logits, targets, weights=weights, reduction="none"
|
||||
)
|
||||
self.assertTrue(mx.allclose(loss, expected))
|
||||
|
||||
def _test_probs_as_inputs():
|
||||
probs = mx.array([0.5, 0.6, 0.7, 0.8])
|
||||
targets = mx.array([0, 0, 1, 1])
|
||||
|
Reference in New Issue
Block a user