Add Gaussian NLL loss function (#477)

* Add Gaussian NLL loss function

---------

Co-authored-by: Awni Hannun <awni@apple.com>
Author: AtomicVar
Date: 2024-01-18 22:44:44 +08:00
Committed by: GitHub
Parent: 9c111f176d
Commit: d1fef34138
3 changed files with 126 additions and 17 deletions


@@ -1,12 +1,14 @@
 # Copyright © 2023 Apple Inc.

 import math
+from typing import Literal

 import mlx.core as mx
 from mlx.nn.layers.base import Module

+Reduction = Literal["none", "mean", "sum"]

-def _reduce(loss: mx.array, reduction: str = "none"):
+def _reduce(loss: mx.array, reduction: Reduction = "none"):
     if reduction == "mean":
         return mx.mean(loss)
     elif reduction == "sum":
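The Reduction alias narrows the reduction argument from an arbitrary str to the three accepted strings, so a type checker can flag a typo such as "avg" before runtime. Below is a minimal sketch of the idea; reduce_example is a hypothetical stand-in for the private _reduce helper above, not part of the library API.

from typing import Literal

import mlx.core as mx

Reduction = Literal["none", "mean", "sum"]

def reduce_example(loss: mx.array, reduction: Reduction = "none") -> mx.array:
    # "mean" and "sum" collapse the per-element loss to a scalar;
    # "none" returns it unchanged.
    if reduction == "mean":
        return mx.mean(loss)
    elif reduction == "sum":
        return mx.sum(loss)
    return loss

loss = mx.array([1.0, 2.0, 3.0])
print(reduce_example(loss, "mean"))  # scalar array: 2.0
print(reduce_example(loss, "sum"))   # scalar array: 6.0
reduce_example(loss, "avg")          # flagged by mypy/pyright; at runtime this sketch falls through to "none"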
@@ -23,7 +25,7 @@ def cross_entropy(
     weights: mx.array = None,
     axis: int = -1,
     label_smoothing: float = 0.0,
-    reduction: str = "none",
+    reduction: Reduction = "none",
 ) -> mx.array:
     """
     Computes the cross entropy loss.
@@ -72,7 +74,7 @@ def cross_entropy(
 def binary_cross_entropy(
-    logits: mx.array, targets: mx.array, reduction: str = "none"
+    logits: mx.array, targets: mx.array, reduction: Reduction = "none"
 ) -> mx.array:
     """
     Computes the binary cross entropy loss.
@@ -99,7 +101,7 @@ def binary_cross_entropy(
 def l1_loss(
-    predictions: mx.array, targets: mx.array, reduction: str = "mean"
+    predictions: mx.array, targets: mx.array, reduction: Reduction = "mean"
 ) -> mx.array:
     """
     Computes the L1 loss.
@@ -124,7 +126,7 @@ def l1_loss(
 def mse_loss(
-    predictions: mx.array, targets: mx.array, reduction: str = "mean"
+    predictions: mx.array, targets: mx.array, reduction: Reduction = "mean"
 ) -> mx.array:
     """
     Computes the mean squared error loss.
@@ -149,7 +151,7 @@ def mse_loss(
 def nll_loss(
-    inputs: mx.array, targets: mx.array, axis: int = -1, reduction: str = "none"
+    inputs: mx.array, targets: mx.array, axis: int = -1, reduction: Reduction = "none"
 ) -> mx.array:
     """
     Computes the negative log likelihood loss.
@@ -169,8 +171,63 @@ def nll_loss(
     return _reduce(loss, reduction)

+def gaussian_nll_loss(
+    inputs: mx.array,
+    targets: mx.array,
+    vars: mx.array,
+    full: bool = False,
+    eps: float = 1e-6,
+    reduction: Reduction = "mean",
+) -> mx.array:
r"""
Computes the negative log likelihood loss for a Gaussian distribution.
The loss is given by:
.. math::
\frac{1}{2}\left(\log\left(\max\left(\text{vars},
\ \epsilon\right)\right) + \frac{\left(\text{inputs} - \text{targets} \right)^2}
{\max\left(\text{vars}, \ \epsilon \right)}\right) + \text{const.}
where ``inputs`` are the predicted means and ``vars`` are the the
predicted variances.
Args:
inputs (array): The predicted expectation of the Gaussian distribution.
targets (array): The target values (samples from the Gaussian distribution).
vars (array): The predicted variance of the Gaussian distribution.
full (bool, optional): Whether to include the constant term in the loss calculation.
Default: ``False``.
eps (float, optional): Small positive constant for numerical stability.
Default: ``1e-6``.
reduction (str, optional): Specifies the reduction to apply to the output:
``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
Returns:
array: The Gaussian NLL loss.
"""
+    if inputs.shape != targets.shape:
+        raise ValueError(
+            f"Inputs shape {inputs.shape} does not match targets shape {targets.shape}."
+        )
+
+    if inputs.shape != vars.shape:
+        raise ValueError(
+            f"Inputs shape {inputs.shape} does not match vars shape {vars.shape}."
+        )
+
+    # For stability
+    vars = mx.maximum(vars, eps)
+    loss = 0.5 * (mx.log(vars) + mx.square(targets - inputs) / vars)
+    if full:
+        loss += 0.5 * math.log(2 * math.pi)
+
+    return _reduce(loss, reduction)
 def kl_div_loss(
-    inputs: mx.array, targets: mx.array, axis: int = -1, reduction: str = "none"
+    inputs: mx.array, targets: mx.array, axis: int = -1, reduction: Reduction = "none"
 ) -> mx.array:
     """
     Computes the Kullback-Leibler divergence loss.
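For reference, a minimal usage sketch of the new loss, assuming it is exposed as mlx.nn.losses.gaussian_nll_loss like the other losses in this file; the input values below are purely illustrative.

import math

import mlx.core as mx
import mlx.nn as nn

# Predicted means, predicted (positive) variances, and observed targets.
inputs = mx.array([0.1, 0.2, 0.3, 0.4])
vars = mx.array([0.5, 0.5, 0.5, 0.5])
targets = mx.array([0.0, 0.2, 0.6, 0.4])

# Default reduction is "mean"; reduction="none" returns the per-element loss.
loss = nn.losses.gaussian_nll_loss(inputs, targets, vars)
per_elem = nn.losses.gaussian_nll_loss(inputs, targets, vars, reduction="none")

# Manual check of the formula for element 0: 0.5 * (log(var) + (input - target)^2 / var).
expected0 = 0.5 * (math.log(0.5) + (0.1 - 0.0) ** 2 / 0.5)
print(float(per_elem[0]), expected0)  # both approximately -0.3366

# full=True adds the constant 0.5 * log(2 * pi) to every element.
full_loss = nn.losses.gaussian_nll_loss(inputs, targets, vars, full=True)
print(float(full_loss - loss))        # approximately 0.9189 = 0.5 * log(2 * pi)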
@@ -197,7 +254,10 @@ def kl_div_loss(
 def smooth_l1_loss(
-    predictions: mx.array, targets: mx.array, beta: float = 1.0, reduction: str = "mean"
+    predictions: mx.array,
+    targets: mx.array,
+    beta: float = 1.0,
+    reduction: Reduction = "mean",
 ) -> mx.array:
     r"""
     Computes the smooth L1 loss.
@@ -210,7 +270,7 @@ def smooth_l1_loss(
     .. math::

-       l =
+        l =
            \begin{cases}
              0.5 (x - y)^2, & \text{ if } & (x - y) < \beta \\
              |x - y| - 0.5 \beta, & & \text{otherwise}
@@ -249,7 +309,7 @@ def triplet_loss(
     p: int = 2,
     margin: float = 1.0,
     eps: float = 1e-6,
-    reduction: str = "none",
+    reduction: Reduction = "none",
 ) -> mx.array:
     r"""
     Computes the triplet loss for a set of anchor, positive, and negative samples.
@@ -257,7 +317,7 @@ def triplet_loss(
     .. math::

-        L_{\text{triplet}} = \max\left(\|A - P\|_p - \|A - N\|_p + \alpha, 0\right)
+        \max\left(\|A - P\|_p - \|A - N\|_p + \alpha, 0\right)

     Args:
         anchors (array): The anchor samples.
@@ -284,7 +344,7 @@ def triplet_loss(
 def hinge_loss(
-    inputs: mx.array, targets: mx.array, reduction: str = "none"
+    inputs: mx.array, targets: mx.array, reduction: Reduction = "none"
 ) -> mx.array:
     r"""
     Computes the hinge loss between inputs and targets.
@@ -309,14 +369,17 @@ def hinge_loss(
 def huber_loss(
-    inputs: mx.array, targets: mx.array, delta: float = 1.0, reduction: str = "none"
+    inputs: mx.array,
+    targets: mx.array,
+    delta: float = 1.0,
+    reduction: Reduction = "none",
 ) -> mx.array:
     r"""
     Computes the Huber loss between inputs and targets.

     .. math::

-        L_{\delta}(a) =
+        l_{\delta}(a) =
         \left\{ \begin{array}{ll}
             \frac{1}{2} a^2 & \text{for } |a| \leq \delta, \\
             \delta \left( |a| - \frac{1}{2} \delta \right) & \text{otherwise.}
@@ -343,7 +406,7 @@ def huber_loss(
 def log_cosh_loss(
-    inputs: mx.array, targets: mx.array, reduction: str = "none"
+    inputs: mx.array, targets: mx.array, reduction: Reduction = "none"
 ) -> mx.array:
     r"""
     Computes the log cosh loss between inputs and targets.
@@ -379,7 +442,7 @@ def cosine_similarity_loss(
     x2: mx.array,
     axis: int = 1,
     eps: float = 1e-8,
-    reduction: str = "none",
+    reduction: Reduction = "none",
 ) -> mx.array:
     r"""
     Computes the cosine similarity between the two inputs.