Add Gaussian NLL loss function (#477)

* Add Gaussian NLL loss function

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
AtomicVar
2024-01-18 22:44:44 +08:00
committed by GitHub
parent 9c111f176d
commit d1fef34138
3 changed files with 126 additions and 17 deletions

View File

@@ -211,6 +211,51 @@ class TestLosses(mlx_tests.MLXTestCase):
expected_sum = mx.sum(expected_none)
self.assertEqual(losses_sum, expected_sum)
def test_gaussian_nll_loss(self):
inputs = mx.array([[0.1, 0.2], [0.3, 0.4]])
targets = mx.array([[0.2, 0.1], [0.1, 0.2]])
vars = mx.array([[0.1, 0.2], [0.3, 0.4]])
# Test with reduction 'none', full=False
losses_none = nn.losses.gaussian_nll_loss(
inputs, targets, vars, reduction="none"
)
expected_none = mx.array([[-1.101293, -0.779719], [-0.535320, -0.408145]])
self.assertTrue(mx.allclose(losses_none, expected_none))
# Test with reduction 'mean', full=False
losses_mean = nn.losses.gaussian_nll_loss(
inputs, targets, vars, reduction="mean"
)
expected_mean = mx.mean(expected_none)
self.assertTrue(mx.allclose(losses_mean, expected_mean))
# Test with reduction 'sum', full=False
losses_sum = nn.losses.gaussian_nll_loss(inputs, targets, vars, reduction="sum")
expected_sum = mx.sum(expected_none)
self.assertTrue(mx.allclose(losses_sum, expected_sum))
# Test with reduction='none', full=True
losses_none_full = nn.losses.gaussian_nll_loss(
inputs, targets, vars, full=True, reduction="none"
)
expected_none_full = mx.array([[-0.182354, 0.139220], [0.383619, 0.510793]])
self.assertTrue(mx.allclose(losses_none_full, expected_none_full))
# Test with reduction='mean', full=True
losses_mean_full = nn.losses.gaussian_nll_loss(
inputs, targets, vars, full=True, reduction="mean"
)
expected_mean_full = mx.mean(expected_none_full)
self.assertTrue(mx.allclose(losses_mean_full, expected_mean_full))
# Test with reduction='sum', full=True
losses_sum_full = nn.losses.gaussian_nll_loss(
inputs, targets, vars, full=True, reduction="sum"
)
expected_sum_full = mx.sum(expected_none_full)
self.assertTrue(mx.allclose(losses_sum_full, expected_sum_full))
def test_kl_div_loss(self):
p_logits = mx.log(mx.array([[0.5, 0.5], [0.8, 0.2]]))
q_logits = mx.log(mx.array([[0.5, 0.5], [0.2, 0.8]]))