mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-03 01:06:43 +08:00
add losses to the docs, fix black failur (#92)
This commit is contained in:
parent
430bfb4944
commit
2520dbcf0a
@ -170,3 +170,13 @@ simple functions.
|
|||||||
gelu_fast_approx
|
gelu_fast_approx
|
||||||
relu
|
relu
|
||||||
silu
|
silu
|
||||||
|
|
||||||
|
Loss Functions
|
||||||
|
--------------
|
||||||
|
|
||||||
|
.. autosummary::
|
||||||
|
:toctree: _autosummary_functions
|
||||||
|
:template: nn-module-template.rst
|
||||||
|
|
||||||
|
losses.cross_entropy
|
||||||
|
losses.l1_loss
|
||||||
|
@ -12,12 +12,9 @@ def cross_entropy(
|
|||||||
Args:
|
Args:
|
||||||
logits (mx.array): The predicted logits.
|
logits (mx.array): The predicted logits.
|
||||||
targets (mx.array): The target values.
|
targets (mx.array): The target values.
|
||||||
axis (int, optional): The axis over which to compute softmax. Defaults to -1.
|
axis (int, optional): The axis over which to compute softmax. Default: ``-1``.
|
||||||
reduction (str, optional): Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'.
|
reduction (str, optional): Specifies the reduction to apply to the output:
|
||||||
'none': no reduction will be applied.
|
``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
|
||||||
'mean': the sum of the output will be divided by the number of elements in the output.
|
|
||||||
'sum': the output will be summed.
|
|
||||||
Defaults to 'none'.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
mx.array: The computed cross entropy loss.
|
mx.array: The computed cross entropy loss.
|
||||||
|
Loading…
Reference in New Issue
Block a user