mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	Feat: Add weights argument in BCE Loss and tests (#620)
This commit is contained in:
		| @@ -117,6 +117,7 @@ def cross_entropy( | ||||
| def binary_cross_entropy( | ||||
|     inputs: mx.array, | ||||
|     targets: mx.array, | ||||
|     weights: mx.array = None, | ||||
|     with_logits: bool = True, | ||||
|     reduction: Reduction = "mean", | ||||
| ) -> mx.array: | ||||
| @@ -128,6 +129,7 @@ def binary_cross_entropy( | ||||
|             ``inputs`` are unnormalized logits. Otherwise, ``inputs`` are probabilities. | ||||
|         targets (array): The binary target values in {0, 1}. | ||||
|         with_logits (bool, optional): Whether ``inputs`` are logits. Default: ``True``. | ||||
|         weights (array, optional): Optional weights for each target. Default: ``None``. | ||||
|         reduction (str, optional): Specifies the reduction to apply to the output: | ||||
|           ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``. | ||||
|  | ||||
| @@ -159,6 +161,15 @@ def binary_cross_entropy( | ||||
|     else: | ||||
|         loss = -(targets * mx.log(inputs) + (1 - targets) * mx.log(1 - inputs)) | ||||
|  | ||||
|     # Apply weights if provided | ||||
|     if weights is not None: | ||||
|         if weights.shape != loss.shape: | ||||
|             raise ValueError( | ||||
|                 f"Weights with shape {weights.shape} is not the same as " | ||||
|                 f"output loss with shape {loss.shape}." | ||||
|             ) | ||||
|         loss *= weights | ||||
|  | ||||
|     return _reduce(loss, reduction) | ||||
|  | ||||
|  | ||||
|   | ||||
| @@ -92,6 +92,14 @@ class TestLosses(mlx_tests.MLXTestCase): | ||||
|             expected_sum = mx.sum(expected_none) | ||||
|             self.assertEqual(losses_sum, expected_sum) | ||||
|  | ||||
|             # With weights, no label smoothing | ||||
|             weights = mx.array([1.0, 2.0, 1.0, 2.0]) | ||||
|             expected = mx.array([0.747215, 1.62186, 0.262365, 0.672944]) | ||||
|             loss = nn.losses.binary_cross_entropy( | ||||
|                 logits, targets, weights=weights, reduction="none" | ||||
|             ) | ||||
|             self.assertTrue(mx.allclose(loss, expected)) | ||||
|  | ||||
|         def _test_probs_as_inputs(): | ||||
|             probs = mx.array([0.5, 0.6, 0.7, 0.8]) | ||||
|             targets = mx.array([0, 0, 1, 1]) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Aryan Gupta
					Aryan Gupta