mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-31 15:21:19 +08:00
add losses to the docs, fix black failur (#92)
This commit is contained in:
parent
430bfb4944
commit
2520dbcf0a
@ -170,3 +170,13 @@ simple functions.
|
||||
gelu_fast_approx
|
||||
relu
|
||||
silu
|
||||
|
||||
Loss Functions
|
||||
--------------
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary_functions
|
||||
:template: nn-module-template.rst
|
||||
|
||||
losses.cross_entropy
|
||||
losses.l1_loss
|
||||
|
@ -12,12 +12,9 @@ def cross_entropy(
|
||||
Args:
|
||||
logits (mx.array): The predicted logits.
|
||||
targets (mx.array): The target values.
|
||||
axis (int, optional): The axis over which to compute softmax. Defaults to -1.
|
||||
reduction (str, optional): Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'.
|
||||
'none': no reduction will be applied.
|
||||
'mean': the sum of the output will be divided by the number of elements in the output.
|
||||
'sum': the output will be summed.
|
||||
Defaults to 'none'.
|
||||
axis (int, optional): The axis over which to compute softmax. Default: ``-1``.
|
||||
reduction (str, optional): Specifies the reduction to apply to the output:
|
||||
``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
|
||||
|
||||
Returns:
|
||||
mx.array: The computed cross entropy loss.
|
||||
|
Loading…
Reference in New Issue
Block a user