add losses to the docs, fix black failure (#92)

Awni Hannun 2023-12-09 06:06:52 -08:00 committed by GitHub
parent 430bfb4944
commit 2520dbcf0a
2 changed files with 13 additions and 6 deletions


@@ -170,3 +170,13 @@ simple functions.
    gelu_fast_approx
    relu
    silu
+
+Loss Functions
+--------------
+
+.. autosummary::
+   :toctree: _autosummary_functions
+   :template: nn-module-template.rst
+
+   losses.cross_entropy
+   losses.l1_loss
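For readers skimming the diff, a minimal usage sketch of the two losses now listed in the docs. It assumes the mlx.nn.losses import path implied by the losses.* autosummary entries and the default reductions; the values are illustrative only, not from the commit.

```python
import mlx.core as mx
import mlx.nn as nn

# Cross entropy over class logits, with integer class indices as targets.
logits = mx.array([[2.0, -1.0, 0.5], [0.1, 1.5, -0.3]])
targets = mx.array([0, 1])
ce = nn.losses.cross_entropy(logits, targets)

# L1 loss between predictions and regression targets.
predictions = mx.array([1.0, 2.0, 3.0])
labels = mx.array([1.5, 2.0, 2.0])
l1 = nn.losses.l1_loss(predictions, labels)
```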


@@ -12,12 +12,9 @@ def cross_entropy(
     Args:
         logits (mx.array): The predicted logits.
         targets (mx.array): The target values.
-        axis (int, optional): The axis over which to compute softmax. Defaults to -1.
-        reduction (str, optional): Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'.
-            'none': no reduction will be applied.
-            'mean': the sum of the output will be divided by the number of elements in the output.
-            'sum': the output will be summed.
-            Defaults to 'none'.
+        axis (int, optional): The axis over which to compute softmax. Default: ``-1``.
+        reduction (str, optional): Specifies the reduction to apply to the output:
+            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.

     Returns:
         mx.array: The computed cross entropy loss.
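As a quick illustration of the documented reduction options (following the behavior described in the replaced docstring lines: 'none' applies no reduction, 'mean' divides the summed loss by the number of elements, 'sum' sums it), a short sketch with made-up values:

```python
import mlx.core as mx
import mlx.nn as nn

logits = mx.array([[2.0, -1.0, 0.5], [0.1, 1.5, -0.3]])
targets = mx.array([0, 1])

# 'none' (the default): one loss value per example, no reduction applied.
per_example = nn.losses.cross_entropy(logits, targets, reduction="none")

# 'mean': the summed loss divided by the number of elements.
mean_loss = nn.losses.cross_entropy(logits, targets, reduction="mean")

# 'sum': the per-example losses summed into a single scalar.
total_loss = nn.losses.cross_entropy(logits, targets, reduction="sum")
```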