Add Margin Ranking Loss (#536)

This commit is contained in:
AtomicVar
2024-02-03 02:57:31 +08:00
committed by GitHub
parent cb6156d35d
commit 83f63f2184
3 changed files with 75 additions and 0 deletions

View File

@@ -359,6 +359,25 @@ class TestLosses(mlx_tests.MLXTestCase):
expected_sum = mx.sum(expected_none)
self.assertTrue(mx.allclose(losses_sum, expected_sum))
def test_margin_ranking_loss(self):
inputs1 = mx.array([-0.573409, -0.765166, -0.0638])
inputs2 = mx.array([0.75596, 0.225763, 0.256995])
targets = mx.array([1, 1, -1])
# Test with no margin
losses = nn.losses.margin_ranking_loss(
inputs1, inputs2, targets, reduction="none"
)
expected = mx.array([1.329369, 0.990929, 0.0])
self.assertTrue(mx.allclose(losses, expected))
# Test with margin
losses = nn.losses.margin_ranking_loss(
inputs1, inputs2, targets, margin=0.5, reduction="none"
)
expected = mx.array([1.829369, 1.490929, 0.179205])
self.assertTrue(mx.allclose(losses, expected))
if __name__ == "__main__":
unittest.main()