mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 00:04:41 +08:00
Add Margin Ranking Loss (#536)
This commit is contained in:
@@ -359,6 +359,25 @@ class TestLosses(mlx_tests.MLXTestCase):
|
||||
expected_sum = mx.sum(expected_none)
|
||||
self.assertTrue(mx.allclose(losses_sum, expected_sum))
|
||||
|
||||
def test_margin_ranking_loss(self):
|
||||
inputs1 = mx.array([-0.573409, -0.765166, -0.0638])
|
||||
inputs2 = mx.array([0.75596, 0.225763, 0.256995])
|
||||
targets = mx.array([1, 1, -1])
|
||||
|
||||
# Test with no margin
|
||||
losses = nn.losses.margin_ranking_loss(
|
||||
inputs1, inputs2, targets, reduction="none"
|
||||
)
|
||||
expected = mx.array([1.329369, 0.990929, 0.0])
|
||||
self.assertTrue(mx.allclose(losses, expected))
|
||||
|
||||
# Test with margin
|
||||
losses = nn.losses.margin_ranking_loss(
|
||||
inputs1, inputs2, targets, margin=0.5, reduction="none"
|
||||
)
|
||||
expected = mx.array([1.829369, 1.490929, 0.179205])
|
||||
self.assertTrue(mx.allclose(losses, expected))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user