Add Gaussian NLL loss function (#477)

* Add Gaussian NLL loss function

---------

Co-authored-by: Awni Hannun <awni@apple.com>
AtomicVar 2024-01-18 22:44:44 +08:00 committed by GitHub
parent 9c111f176d
commit d1fef34138
3 changed files with 126 additions and 17 deletions


@@ -12,6 +12,7 @@ Loss Functions
    binary_cross_entropy
    cosine_similarity_loss
    cross_entropy
+   gaussian_nll_loss
    hinge_loss
    huber_loss
    kl_div_loss


@@ -1,12 +1,14 @@
 # Copyright © 2023 Apple Inc.
 import math
+from typing import Literal
 import mlx.core as mx
 from mlx.nn.layers.base import Module
+Reduction = Literal["none", "mean", "sum"]
-def _reduce(loss: mx.array, reduction: str = "none"):
+def _reduce(loss: mx.array, reduction: Reduction = "none"):
     if reduction == "mean":
         return mx.mean(loss)
     elif reduction == "sum":
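For readers unfamiliar with the new Reduction alias introduced above: it is a typing.Literal type, so static checkers constrain the reduction argument to the three accepted strings while runtime behavior is unchanged. A minimal standalone sketch of the idea (the demo function here is illustrative, not the library's _reduce):

    from typing import Literal, get_args

    Reduction = Literal["none", "mean", "sum"]

    def reduce_demo(values: list[float], reduction: Reduction = "none"):
        # Mirrors the shape of _reduce, using plain Python for illustration.
        if reduction == "mean":
            return sum(values) / len(values)
        elif reduction == "sum":
            return sum(values)
        return values

    # A type checker such as mypy flags reduce_demo([1.0], reduction="avg");
    # the accepted strings can also be inspected at runtime:
    assert get_args(Reduction) == ("none", "mean", "sum")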
@@ -23,7 +25,7 @@ def cross_entropy(
     weights: mx.array = None,
     axis: int = -1,
     label_smoothing: float = 0.0,
-    reduction: str = "none",
+    reduction: Reduction = "none",
 ) -> mx.array:
     """
     Computes the cross entropy loss.
@@ -72,7 +74,7 @@ def cross_entropy(
 def binary_cross_entropy(
-    logits: mx.array, targets: mx.array, reduction: str = "none"
+    logits: mx.array, targets: mx.array, reduction: Reduction = "none"
 ) -> mx.array:
     """
     Computes the binary cross entropy loss.
@@ -99,7 +101,7 @@ def binary_cross_entropy(
 def l1_loss(
-    predictions: mx.array, targets: mx.array, reduction: str = "mean"
+    predictions: mx.array, targets: mx.array, reduction: Reduction = "mean"
 ) -> mx.array:
     """
     Computes the L1 loss.
@@ -124,7 +126,7 @@ def l1_loss(
 def mse_loss(
-    predictions: mx.array, targets: mx.array, reduction: str = "mean"
+    predictions: mx.array, targets: mx.array, reduction: Reduction = "mean"
 ) -> mx.array:
     """
     Computes the mean squared error loss.
@@ -149,7 +151,7 @@ def mse_loss(
 def nll_loss(
-    inputs: mx.array, targets: mx.array, axis: int = -1, reduction: str = "none"
+    inputs: mx.array, targets: mx.array, axis: int = -1, reduction: Reduction = "none"
 ) -> mx.array:
     """
     Computes the negative log likelihood loss.
@@ -169,8 +171,63 @@ def nll_loss(
     return _reduce(loss, reduction)
+def gaussian_nll_loss(
+    inputs: mx.array,
+    targets: mx.array,
+    vars: mx.array,
+    full: bool = False,
+    eps: float = 1e-6,
+    reduction: Reduction = "mean",
+) -> mx.array:
+    r"""
+    Computes the negative log likelihood loss for a Gaussian distribution.
+    The loss is given by:
+    .. math::
+        \frac{1}{2}\left(\log\left(\max\left(\text{vars}, \ \epsilon\right)\right)
+        + \frac{\left(\text{inputs} - \text{targets} \right)^2}
+        {\max\left(\text{vars}, \ \epsilon \right)}\right) + \text{const.}
+    where ``inputs`` are the predicted means and ``vars`` are the
+    predicted variances.
+    Args:
+        inputs (array): The predicted expectation of the Gaussian distribution.
+        targets (array): The target values (samples from the Gaussian distribution).
+        vars (array): The predicted variance of the Gaussian distribution.
+        full (bool, optional): Whether to include the constant term in the loss calculation.
+            Default: ``False``.
+        eps (float, optional): Small positive constant for numerical stability.
+            Default: ``1e-6``.
+        reduction (str, optional): Specifies the reduction to apply to the output:
+            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``.
+    Returns:
+        array: The Gaussian NLL loss.
+    """
+    if inputs.shape != targets.shape:
+        raise ValueError(
+            f"Inputs shape {inputs.shape} does not match targets shape {targets.shape}."
+        )
+    if inputs.shape != vars.shape:
+        raise ValueError(
+            f"Inputs shape {inputs.shape} does not match vars shape {vars.shape}."
+        )
+    # For stability
+    vars = mx.maximum(vars, eps)
+    loss = 0.5 * (mx.log(vars) + mx.square(targets - inputs) / vars)
+    if full:
+        loss += 0.5 * math.log(2 * math.pi)
+    return _reduce(loss, reduction)
 def kl_div_loss(
-    inputs: mx.array, targets: mx.array, axis: int = -1, reduction: str = "none"
+    inputs: mx.array, targets: mx.array, axis: int = -1, reduction: Reduction = "none"
 ) -> mx.array:
     """
     Computes the Kullback-Leibler divergence loss.
@@ -197,7 +254,10 @@ def kl_div_loss(
 def smooth_l1_loss(
-    predictions: mx.array, targets: mx.array, beta: float = 1.0, reduction: str = "mean"
+    predictions: mx.array,
+    targets: mx.array,
+    beta: float = 1.0,
+    reduction: Reduction = "mean",
 ) -> mx.array:
     r"""
     Computes the smooth L1 loss.
@@ -249,7 +309,7 @@ def triplet_loss(
     p: int = 2,
     margin: float = 1.0,
     eps: float = 1e-6,
-    reduction: str = "none",
+    reduction: Reduction = "none",
 ) -> mx.array:
     r"""
     Computes the triplet loss for a set of anchor, positive, and negative samples.
@@ -257,7 +317,7 @@ def triplet_loss(
     .. math::
-        L_{\text{triplet}} = \max\left(\|A - P\|_p - \|A - N\|_p + \alpha, 0\right)
+        \max\left(\|A - P\|_p - \|A - N\|_p + \alpha, 0\right)
     Args:
         anchors (array): The anchor samples.
@@ -284,7 +344,7 @@ def triplet_loss(
 def hinge_loss(
-    inputs: mx.array, targets: mx.array, reduction: str = "none"
+    inputs: mx.array, targets: mx.array, reduction: Reduction = "none"
 ) -> mx.array:
     r"""
     Computes the hinge loss between inputs and targets.
@@ -309,14 +369,17 @@ def hinge_loss(
 def huber_loss(
-    inputs: mx.array, targets: mx.array, delta: float = 1.0, reduction: str = "none"
+    inputs: mx.array,
+    targets: mx.array,
+    delta: float = 1.0,
+    reduction: Reduction = "none",
 ) -> mx.array:
     r"""
     Computes the Huber loss between inputs and targets.
     .. math::
-        L_{\delta}(a) =
+        l_{\delta}(a) =
         \left\{ \begin{array}{ll}
             \frac{1}{2} a^2 & \text{for } |a| \leq \delta, \\
             \delta \left( |a| - \frac{1}{2} \delta \right) & \text{otherwise.}
@@ -343,7 +406,7 @@ def huber_loss(
 def log_cosh_loss(
-    inputs: mx.array, targets: mx.array, reduction: str = "none"
+    inputs: mx.array, targets: mx.array, reduction: Reduction = "none"
 ) -> mx.array:
     r"""
     Computes the log cosh loss between inputs and targets.
@@ -379,7 +442,7 @@ def cosine_similarity_loss(
     x2: mx.array,
     axis: int = 1,
     eps: float = 1e-8,
-    reduction: str = "none",
+    reduction: Reduction = "none",
 ) -> mx.array:
     r"""
     Computes the cosine similarity between the two inputs.
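For context, a minimal usage sketch of the new function, calling nn.losses.gaussian_nll_loss with the same arrays the test below uses. It assumes the mlx package is installed and imported as shown; the variable names (mean_pred, var_pred) are illustrative, not from the commit:

    import mlx.core as mx
    import mlx.nn as nn

    mean_pred = mx.array([[0.1, 0.2], [0.3, 0.4]])  # predicted means
    targets = mx.array([[0.2, 0.1], [0.1, 0.2]])    # observed samples
    var_pred = mx.array([[0.1, 0.2], [0.3, 0.4]])   # predicted variances

    # Per-element loss, and the default reduction ("mean" per the signature above)
    per_elem = nn.losses.gaussian_nll_loss(mean_pred, targets, var_pred, reduction="none")
    avg_loss = nn.losses.gaussian_nll_loss(mean_pred, targets, var_pred)

    # Include the constant 0.5 * log(2 * pi) term
    full_loss = nn.losses.gaussian_nll_loss(mean_pred, targets, var_pred, full=True)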


@@ -211,6 +211,51 @@ class TestLosses(mlx_tests.MLXTestCase):
         expected_sum = mx.sum(expected_none)
         self.assertEqual(losses_sum, expected_sum)
+    def test_gaussian_nll_loss(self):
+        inputs = mx.array([[0.1, 0.2], [0.3, 0.4]])
+        targets = mx.array([[0.2, 0.1], [0.1, 0.2]])
+        vars = mx.array([[0.1, 0.2], [0.3, 0.4]])
+        # Test with reduction 'none', full=False
+        losses_none = nn.losses.gaussian_nll_loss(
+            inputs, targets, vars, reduction="none"
+        )
+        expected_none = mx.array([[-1.101293, -0.779719], [-0.535320, -0.408145]])
+        self.assertTrue(mx.allclose(losses_none, expected_none))
+        # Test with reduction 'mean', full=False
+        losses_mean = nn.losses.gaussian_nll_loss(
+            inputs, targets, vars, reduction="mean"
+        )
+        expected_mean = mx.mean(expected_none)
+        self.assertTrue(mx.allclose(losses_mean, expected_mean))
+        # Test with reduction 'sum', full=False
+        losses_sum = nn.losses.gaussian_nll_loss(inputs, targets, vars, reduction="sum")
+        expected_sum = mx.sum(expected_none)
+        self.assertTrue(mx.allclose(losses_sum, expected_sum))
+        # Test with reduction='none', full=True
+        losses_none_full = nn.losses.gaussian_nll_loss(
+            inputs, targets, vars, full=True, reduction="none"
+        )
+        expected_none_full = mx.array([[-0.182354, 0.139220], [0.383619, 0.510793]])
+        self.assertTrue(mx.allclose(losses_none_full, expected_none_full))
+        # Test with reduction='mean', full=True
+        losses_mean_full = nn.losses.gaussian_nll_loss(
+            inputs, targets, vars, full=True, reduction="mean"
+        )
+        expected_mean_full = mx.mean(expected_none_full)
+        self.assertTrue(mx.allclose(losses_mean_full, expected_mean_full))
+        # Test with reduction='sum', full=True
+        losses_sum_full = nn.losses.gaussian_nll_loss(
+            inputs, targets, vars, full=True, reduction="sum"
+        )
+        expected_sum_full = mx.sum(expected_none_full)
+        self.assertTrue(mx.allclose(losses_sum_full, expected_sum_full))
     def test_kl_div_loss(self):
         p_logits = mx.log(mx.array([[0.5, 0.5], [0.8, 0.2]]))
         q_logits = mx.log(mx.array([[0.5, 0.5], [0.2, 0.8]]))
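As a sanity check on the expected values in the test above: for the first element, inputs = 0.1, targets = 0.2, vars = 0.1, so the unreduced loss is 0.5 * (log(0.1) + (0.2 - 0.1)^2 / 0.1) ≈ -1.101293, and adding the constant 0.5 * log(2π) ≈ 0.918939 gives ≈ -0.182354, matching the full=True expectation. A plain-Python sketch of that arithmetic (standard library only):

    import math

    mean_pred, target, var = 0.1, 0.2, 0.1

    loss = 0.5 * (math.log(var) + (target - mean_pred) ** 2 / var)
    print(round(loss, 6))                                # -1.101293, matches expected_none[0][0]
    print(round(loss + 0.5 * math.log(2 * math.pi), 6))  # -0.182354, matches expected_none_full[0][0]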