mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	Add Gaussian NLL loss function (#477)
* Add Gaussian NLL loss function --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
		| @@ -211,6 +211,51 @@ class TestLosses(mlx_tests.MLXTestCase): | ||||
|         expected_sum = mx.sum(expected_none) | ||||
|         self.assertEqual(losses_sum, expected_sum) | ||||
|  | ||||
|     def test_gaussian_nll_loss(self): | ||||
|         inputs = mx.array([[0.1, 0.2], [0.3, 0.4]]) | ||||
|         targets = mx.array([[0.2, 0.1], [0.1, 0.2]]) | ||||
|         vars = mx.array([[0.1, 0.2], [0.3, 0.4]]) | ||||
|  | ||||
|         # Test with reduction 'none', full=False | ||||
|         losses_none = nn.losses.gaussian_nll_loss( | ||||
|             inputs, targets, vars, reduction="none" | ||||
|         ) | ||||
|         expected_none = mx.array([[-1.101293, -0.779719], [-0.535320, -0.408145]]) | ||||
|         self.assertTrue(mx.allclose(losses_none, expected_none)) | ||||
|  | ||||
|         # Test with reduction 'mean', full=False | ||||
|         losses_mean = nn.losses.gaussian_nll_loss( | ||||
|             inputs, targets, vars, reduction="mean" | ||||
|         ) | ||||
|         expected_mean = mx.mean(expected_none) | ||||
|         self.assertTrue(mx.allclose(losses_mean, expected_mean)) | ||||
|  | ||||
|         # Test with reduction 'sum', full=False | ||||
|         losses_sum = nn.losses.gaussian_nll_loss(inputs, targets, vars, reduction="sum") | ||||
|         expected_sum = mx.sum(expected_none) | ||||
|         self.assertTrue(mx.allclose(losses_sum, expected_sum)) | ||||
|  | ||||
|         # Test with reduction='none', full=True | ||||
|         losses_none_full = nn.losses.gaussian_nll_loss( | ||||
|             inputs, targets, vars, full=True, reduction="none" | ||||
|         ) | ||||
|         expected_none_full = mx.array([[-0.182354, 0.139220], [0.383619, 0.510793]]) | ||||
|         self.assertTrue(mx.allclose(losses_none_full, expected_none_full)) | ||||
|  | ||||
|         # Test with reduction='mean', full=True | ||||
|         losses_mean_full = nn.losses.gaussian_nll_loss( | ||||
|             inputs, targets, vars, full=True, reduction="mean" | ||||
|         ) | ||||
|         expected_mean_full = mx.mean(expected_none_full) | ||||
|         self.assertTrue(mx.allclose(losses_mean_full, expected_mean_full)) | ||||
|  | ||||
|         # Test with reduction='sum', full=True | ||||
|         losses_sum_full = nn.losses.gaussian_nll_loss( | ||||
|             inputs, targets, vars, full=True, reduction="sum" | ||||
|         ) | ||||
|         expected_sum_full = mx.sum(expected_none_full) | ||||
|         self.assertTrue(mx.allclose(losses_sum_full, expected_sum_full)) | ||||
|  | ||||
|     def test_kl_div_loss(self): | ||||
|         p_logits = mx.log(mx.array([[0.5, 0.5], [0.8, 0.2]])) | ||||
|         q_logits = mx.log(mx.array([[0.5, 0.5], [0.2, 0.8]])) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 AtomicVar
					AtomicVar