Mirror of https://github.com/ml-explore/mlx.git (synced 2025-06-24 17:31:16 +08:00)
Add Gaussian NLL loss function (#477)

* Add Gaussian NLL loss function

Co-authored-by: Awni Hannun <awni@apple.com>

This commit is contained in:
  parent 9c111f176d
  commit d1fef34138
@@ -12,6 +12,7 @@ Loss Functions
    binary_cross_entropy
    cosine_similarity_loss
    cross_entropy
+   gaussian_nll_loss
    hinge_loss
    huber_loss
    kl_div_loss
@@ -1,12 +1,14 @@
 # Copyright © 2023 Apple Inc.
 
 import math
+from typing import Literal
 
 import mlx.core as mx
-from mlx.nn.layers.base import Module
+
+Reduction = Literal["none", "mean", "sum"]
 
 
-def _reduce(loss: mx.array, reduction: str = "none"):
+def _reduce(loss: mx.array, reduction: Reduction = "none"):
     if reduction == "mean":
         return mx.mean(loss)
     elif reduction == "sum":
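The new `Reduction` literal only tightens the type annotations; `_reduce` behaves exactly as before. A minimal sketch (not part of the diff, with made-up values) of how the three reduction modes relate, using `l1_loss` from the same file:

import mlx.core as mx
import mlx.nn as nn

# Made-up predictions and targets, purely for illustration.
predictions = mx.array([0.5, 0.2, 0.9, 0.0])
targets = mx.array([0.7, 0.1, 0.5, 0.2])

# reduction is now typed as Literal["none", "mean", "sum"], so a static type
# checker can flag a typo such as reduction="avg" before it fails at runtime.
per_element = nn.losses.l1_loss(predictions, targets, reduction="none")
mean_loss = nn.losses.l1_loss(predictions, targets, reduction="mean")
sum_loss = nn.losses.l1_loss(predictions, targets, reduction="sum")

# _reduce simply applies mx.mean / mx.sum (or nothing) to the element-wise loss.
print(mx.allclose(mean_loss, mx.mean(per_element)))  # True
print(mx.allclose(sum_loss, mx.sum(per_element)))    # True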
@@ -23,7 +25,7 @@ def cross_entropy(
     weights: mx.array = None,
     axis: int = -1,
     label_smoothing: float = 0.0,
-    reduction: str = "none",
+    reduction: Reduction = "none",
 ) -> mx.array:
     """
     Computes the cross entropy loss.
@@ -72,7 +74,7 @@ def cross_entropy(
 
 
 def binary_cross_entropy(
-    logits: mx.array, targets: mx.array, reduction: str = "none"
+    logits: mx.array, targets: mx.array, reduction: Reduction = "none"
 ) -> mx.array:
     """
     Computes the binary cross entropy loss.
@@ -99,7 +101,7 @@ def binary_cross_entropy(
 
 
 def l1_loss(
-    predictions: mx.array, targets: mx.array, reduction: str = "mean"
+    predictions: mx.array, targets: mx.array, reduction: Reduction = "mean"
 ) -> mx.array:
     """
     Computes the L1 loss.
@@ -124,7 +126,7 @@ def l1_loss(
 
 
 def mse_loss(
-    predictions: mx.array, targets: mx.array, reduction: str = "mean"
+    predictions: mx.array, targets: mx.array, reduction: Reduction = "mean"
 ) -> mx.array:
     """
     Computes the mean squared error loss.
@@ -149,7 +151,7 @@ def mse_loss(
 
 
 def nll_loss(
-    inputs: mx.array, targets: mx.array, axis: int = -1, reduction: str = "none"
+    inputs: mx.array, targets: mx.array, axis: int = -1, reduction: Reduction = "none"
 ) -> mx.array:
     """
     Computes the negative log likelihood loss.
@@ -169,8 +171,63 @@ def nll_loss(
     return _reduce(loss, reduction)
 
 
+def gaussian_nll_loss(
+    inputs: mx.array,
+    targets: mx.array,
+    vars: mx.array,
+    full: bool = False,
+    eps: float = 1e-6,
+    reduction: Reduction = "mean",
+) -> mx.array:
+    r"""
+    Computes the negative log likelihood loss for a Gaussian distribution.
+
+    The loss is given by:
+
+    .. math::
+        \frac{1}{2}\left(\log\left(\max\left(\text{vars},
+        \ \epsilon\right)\right) + \frac{\left(\text{inputs} - \text{targets} \right)^2}
+        {\max\left(\text{vars}, \ \epsilon \right)}\right) + \text{const.}
+
+    where ``inputs`` are the predicted means and ``vars`` are the
+    predicted variances.
+
+    Args:
+        inputs (array): The predicted expectation of the Gaussian distribution.
+        targets (array): The target values (samples from the Gaussian distribution).
+        vars (array): The predicted variance of the Gaussian distribution.
+        full (bool, optional): Whether to include the constant term in the loss calculation.
+            Default: ``False``.
+        eps (float, optional): Small positive constant for numerical stability.
+            Default: ``1e-6``.
+        reduction (str, optional): Specifies the reduction to apply to the output:
+            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``.
+
+    Returns:
+        array: The Gaussian NLL loss.
+    """
+    if inputs.shape != targets.shape:
+        raise ValueError(
+            f"Inputs shape {inputs.shape} does not match targets shape {targets.shape}."
+        )
+
+    if inputs.shape != vars.shape:
+        raise ValueError(
+            f"Inputs shape {inputs.shape} does not match vars shape {vars.shape}."
+        )
+
+    # For stability, clamp the variance from below
+    vars = mx.maximum(vars, eps)
+    loss = 0.5 * (mx.log(vars) + mx.square(targets - inputs) / vars)
+
+    if full:
+        loss += 0.5 * math.log(2 * math.pi)
+
+    return _reduce(loss, reduction)
+
+
 def kl_div_loss(
-    inputs: mx.array, targets: mx.array, axis: int = -1, reduction: str = "none"
+    inputs: mx.array, targets: mx.array, axis: int = -1, reduction: Reduction = "none"
 ) -> mx.array:
     """
     Computes the Kullback-Leibler divergence loss.
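Illustrative usage of the new function (not part of the diff; the input values are made up, and the hand-computed number just instantiates the formula from the docstring):

import math

import mlx.core as mx
import mlx.nn as nn

# Made-up predicted means, observed targets, and predicted variances.
inputs = mx.array([0.1, 0.3])
targets = mx.array([0.2, 0.1])
vars = mx.array([0.1, 0.3])

# Per-element loss: 0.5 * (log(max(var, eps)) + (input - target)^2 / max(var, eps)).
loss = nn.losses.gaussian_nll_loss(inputs, targets, vars, reduction="none")

# The same quantity written out by hand for the first element.
by_hand = 0.5 * (math.log(0.1) + (0.1 - 0.2) ** 2 / 0.1)  # ~ -1.1013
print(loss[0], by_hand)

# full=True adds the constant 0.5 * log(2 * pi) to every element;
# note the signature default for reduction is "mean", so "none" is passed explicitly.
loss_full = nn.losses.gaussian_nll_loss(inputs, targets, vars, full=True, reduction="none")
print(loss_full[0])  # ~ -1.1013 + 0.9189 = -0.1824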
@@ -197,7 +254,10 @@ def kl_div_loss(
 
 
 def smooth_l1_loss(
-    predictions: mx.array, targets: mx.array, beta: float = 1.0, reduction: str = "mean"
+    predictions: mx.array,
+    targets: mx.array,
+    beta: float = 1.0,
+    reduction: Reduction = "mean",
 ) -> mx.array:
     r"""
     Computes the smooth L1 loss.
@@ -210,7 +270,7 @@ def smooth_l1_loss(
     .. math::
 
        l =
           \begin{cases}
             0.5 (x - y)^2, & \text{ if } & (x - y) < \beta \\
             |x - y| - 0.5 \beta, & & \text{otherwise}
@@ -249,7 +309,7 @@ def triplet_loss(
     p: int = 2,
     margin: float = 1.0,
     eps: float = 1e-6,
-    reduction: str = "none",
+    reduction: Reduction = "none",
 ) -> mx.array:
     r"""
     Computes the triplet loss for a set of anchor, positive, and negative samples.
@@ -257,7 +317,7 @@ def triplet_loss(
 
     .. math::
 
-       L_{\text{triplet}} = \max\left(\|A - P\|_p - \|A - N\|_p + \alpha, 0\right)
+       \max\left(\|A - P\|_p - \|A - N\|_p + \alpha, 0\right)
 
     Args:
        anchors (array): The anchor samples.
@@ -284,7 +344,7 @@ def triplet_loss(
 
 
 def hinge_loss(
-    inputs: mx.array, targets: mx.array, reduction: str = "none"
+    inputs: mx.array, targets: mx.array, reduction: Reduction = "none"
 ) -> mx.array:
     r"""
     Computes the hinge loss between inputs and targets.
@@ -309,14 +369,17 @@ def hinge_loss(
 
 
 def huber_loss(
-    inputs: mx.array, targets: mx.array, delta: float = 1.0, reduction: str = "none"
+    inputs: mx.array,
+    targets: mx.array,
+    delta: float = 1.0,
+    reduction: Reduction = "none",
 ) -> mx.array:
     r"""
     Computes the Huber loss between inputs and targets.
 
     .. math::
 
-        L_{\delta}(a) =
+        l_{\delta}(a) =
         \left\{ \begin{array}{ll}
             \frac{1}{2} a^2 & \text{for } |a| \leq \delta, \\
             \delta \left( |a| - \frac{1}{2} \delta \right) & \text{otherwise.}
@@ -343,7 +406,7 @@ def huber_loss(
 
 
 def log_cosh_loss(
-    inputs: mx.array, targets: mx.array, reduction: str = "none"
+    inputs: mx.array, targets: mx.array, reduction: Reduction = "none"
 ) -> mx.array:
     r"""
     Computes the log cosh loss between inputs and targets.
@@ -379,7 +442,7 @@ def cosine_similarity_loss(
     x2: mx.array,
     axis: int = 1,
     eps: float = 1e-8,
-    reduction: str = "none",
+    reduction: Reduction = "none",
 ) -> mx.array:
     r"""
     Computes the cosine similarity between the two inputs.
@@ -211,6 +211,51 @@ class TestLosses(mlx_tests.MLXTestCase):
         expected_sum = mx.sum(expected_none)
         self.assertEqual(losses_sum, expected_sum)
 
+    def test_gaussian_nll_loss(self):
+        inputs = mx.array([[0.1, 0.2], [0.3, 0.4]])
+        targets = mx.array([[0.2, 0.1], [0.1, 0.2]])
+        vars = mx.array([[0.1, 0.2], [0.3, 0.4]])
+
+        # Test with reduction 'none', full=False
+        losses_none = nn.losses.gaussian_nll_loss(
+            inputs, targets, vars, reduction="none"
+        )
+        expected_none = mx.array([[-1.101293, -0.779719], [-0.535320, -0.408145]])
+        self.assertTrue(mx.allclose(losses_none, expected_none))
+
+        # Test with reduction 'mean', full=False
+        losses_mean = nn.losses.gaussian_nll_loss(
+            inputs, targets, vars, reduction="mean"
+        )
+        expected_mean = mx.mean(expected_none)
+        self.assertTrue(mx.allclose(losses_mean, expected_mean))
+
+        # Test with reduction 'sum', full=False
+        losses_sum = nn.losses.gaussian_nll_loss(inputs, targets, vars, reduction="sum")
+        expected_sum = mx.sum(expected_none)
+        self.assertTrue(mx.allclose(losses_sum, expected_sum))
+
+        # Test with reduction='none', full=True
+        losses_none_full = nn.losses.gaussian_nll_loss(
+            inputs, targets, vars, full=True, reduction="none"
+        )
+        expected_none_full = mx.array([[-0.182354, 0.139220], [0.383619, 0.510793]])
+        self.assertTrue(mx.allclose(losses_none_full, expected_none_full))
+
+        # Test with reduction='mean', full=True
+        losses_mean_full = nn.losses.gaussian_nll_loss(
+            inputs, targets, vars, full=True, reduction="mean"
+        )
+        expected_mean_full = mx.mean(expected_none_full)
+        self.assertTrue(mx.allclose(losses_mean_full, expected_mean_full))
+
+        # Test with reduction='sum', full=True
+        losses_sum_full = nn.losses.gaussian_nll_loss(
+            inputs, targets, vars, full=True, reduction="sum"
+        )
+        expected_sum_full = mx.sum(expected_none_full)
+        self.assertTrue(mx.allclose(losses_sum_full, expected_sum_full))
+
     def test_kl_div_loss(self):
         p_logits = mx.log(mx.array([[0.5, 0.5], [0.8, 0.2]]))
         q_logits = mx.log(mx.array([[0.5, 0.5], [0.2, 0.8]]))
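As a sanity check on the fixtures above (not part of the test file), the first expected value follows directly from the docstring formula; a worked instance with plain Python floats:

import math

# First element of the test fixtures: input = 0.1, target = 0.2, var = 0.1.
inp, tgt, var = 0.1, 0.2, 0.1

# 0.5 * (log(var) + (input - target)^2 / var); eps plays no role since var > eps.
loss = 0.5 * (math.log(var) + (inp - tgt) ** 2 / var)
print(round(loss, 6))  # -1.101293, matching expected_none[0, 0]

# full=True adds the constant term 0.5 * log(2 * pi).
print(round(loss + 0.5 * math.log(2 * math.pi), 6))  # -0.182354, matching expected_none_full[0, 0]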