Mirror of https://github.com/ml-explore/mlx.git
	MLE and L1 loss functions (#88)
* MLE and L1 loss functions
* logsoftmax change and tests
* subtract max logit for numerical stability (sketched below)
* l1 name change
* cross entropy reduction + unit tests
* docstrings
* l1 test name change
* old loss impl + default none
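The "subtract max logit for numerical stability" item refers to the standard log-sum-exp shift: subtracting the per-row maximum before exponentiating keeps exp() from overflowing on large logits without changing the result. A minimal sketch of the identity, written with mlx for illustration (this is not the code from the commit):

    import mlx.core as mx

    def stable_logsumexp(x: mx.array, axis: int = -1) -> mx.array:
        # logsumexp(x) = m + log(sum(exp(x - m))), with m = max(x) along `axis`;
        # shifting by m bounds the exponent, so exp() cannot overflow.
        m = mx.max(x, axis=axis, keepdims=True)
        return m.squeeze(axis) + mx.log(mx.sum(mx.exp(x - m), axis=axis))

    x = mx.array([[1000.0, 1001.0]])
    print(stable_logsumexp(x))  # ~1001.31; a naive log(sum(exp(x))) would overflow to inf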
@@ -2,7 +2,45 @@
 
 import mlx.core as mx
 
-def cross_entropy(logits: mx.array, targets: mx.array, axis: int = -1):
+def cross_entropy(logits: mx.array, targets: mx.array, axis: int = -1, reduction: str = 'none'):
+    """
+    Computes the cross entropy loss between logits and targets.
+
+    Args:
+        logits (mx.array): The predicted logits.
+        targets (mx.array): The target values.
+        axis (int, optional): The axis over which to compute softmax. Defaults to -1.
+        reduction (str, optional): Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'.
+            'none': no reduction will be applied.
+            'mean': the sum of the output will be divided by the number of elements in the output.
+            'sum': the output will be summed.
+            Defaults to 'none'.
+
+    Returns:
+        mx.array: The computed cross entropy loss.
+    """
     score = mx.take_along_axis(logits, targets[..., None], axis).squeeze(-1)
-    return mx.logsumexp(logits, axis=axis) - score
+    loss = mx.logsumexp(logits, axis=axis) - score
+
+    if reduction == 'mean':
+        return mx.mean(loss)
+    elif reduction == 'sum':
+        return mx.sum(loss)
+    elif reduction == 'none':
+        return loss
+    else:
+        raise ValueError("Invalid reduction. Must be 'none', 'mean', or 'sum'.")
+
+def l1_loss(predictions: mx.array, targets: mx.array):
+    """
+    Computes the L1 loss between predictions and targets.
+
+    Args:
+        predictions (mx.array): The predicted values.
+        targets (mx.array): The target values.
+
+    Returns:
+        mx.array: The computed L1 loss.
+    """
+    return mx.mean(mx.abs(predictions - targets))
+
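For illustration, a short usage sketch of the two new functions; the values are made up, and the nn.losses module path follows the tests below:

    import mlx.core as mx
    import mlx.nn as nn

    logits = mx.array([[2.0, 0.5, 0.1], [0.1, 0.2, 3.0]])
    targets = mx.array([0, 2])

    per_example = nn.losses.cross_entropy(logits, targets)                  # default 'none': shape (2,)
    mean_loss = nn.losses.cross_entropy(logits, targets, reduction='mean')  # scalar, sum / num_elements
    sum_loss = nn.losses.cross_entropy(logits, targets, reduction='sum')    # scalar

    predictions = mx.array([0.5, 0.2, 0.9])
    labels = mx.array([0.4, 0.2, 1.0])
    mae = nn.losses.l1_loss(predictions, labels)  # mean absolute error

Note the asymmetry: l1_loss always returns the mean, while cross_entropy defaults to unreduced per-example losses.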
@@ -10,7 +10,6 @@ import mlx_tests
 import numpy as np
 from mlx.utils import tree_flatten, tree_map, tree_unflatten
 
-
 class TestNN(mlx_tests.MLXTestCase):
     def test_linear(self):
         inputs = mx.zeros((10, 4))
@@ -21,8 +20,27 @@ class TestNN(mlx_tests.MLXTestCase):
     def test_cross_entropy(self):
         logits = mx.array([[0.0, -float("inf")], [-float("inf"), 0.0]])
         targets = mx.array([0, 1])
-        losses = nn.losses.cross_entropy(logits, targets)
-        self.assertTrue(mx.array_equal(losses, mx.zeros((2,))))
+
+        # Test with reduction 'none'
+        losses_none = nn.losses.cross_entropy(logits, targets, reduction='none')
+        expected_none = mx.array([0.0, 0.0])
+        self.assertTrue(mx.array_equal(losses_none, expected_none))
+
+        # Test with reduction 'mean'
+        losses_mean = nn.losses.cross_entropy(logits, targets, reduction='mean')
+        expected_mean = mx.mean(expected_none)
+        self.assertEqual(losses_mean, expected_mean)
+
+        # Test with reduction 'sum'
+        losses_sum = nn.losses.cross_entropy(logits, targets, reduction='sum')
+        expected_sum = mx.sum(expected_none)
+        self.assertEqual(losses_sum, expected_sum)
+
+    def test_l1_loss(self):
+        predictions = mx.array([0.5, 0.2, 0.9, 0.0])
+        targets = mx.array([0.5, 0.2, 0.9, 0.0])
+        losses = nn.losses.l1_loss(predictions, targets)
+        self.assertEqual(losses, 0.0)
 
     def test_gelu(self):
         inputs = [1.15286231, -0.81037411, 0.35816911, 0.77484438, 0.66276414]
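Since the test module already imports numpy, one way to extend these tests beyond the hand-picked zero-loss cases is to compare against a small reference implementation. A hypothetical helper, not part of this commit:

    import numpy as np

    def np_cross_entropy(logits: np.ndarray, targets: np.ndarray) -> np.ndarray:
        # Reference: loss = logsumexp(logits) - logits[target], computed with the max shift.
        m = logits.max(axis=-1, keepdims=True)
        lse = m.squeeze(-1) + np.log(np.exp(logits - m).sum(axis=-1))
        score = np.take_along_axis(logits, targets[..., None], axis=-1).squeeze(-1)
        return lse - score

    # e.g. np.testing.assert_allclose(np.array(losses_none), np_cross_entropy(np_logits, np_targets))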
Author: Kai Ma