mirror of https://github.com/ml-explore/mlx.git
doc nits
commit 865e53fcab (parent 15577cb727)
@@ -5,7 +5,7 @@ from mlx.nn.layers.base import Module


 class Dropout(Module):
-    """Randomly zero a portion of the elements during training.
+    r"""Randomly zero a portion of the elements during training.

     The remaining elements are multiplied with :math:`\frac{1}{1-p}` where
     :math:`p` is the probability of zeroing an element. This is done so the
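The Dropout docstring above describes zeroing elements with probability :math:`p` and scaling the survivors by :math:`\frac{1}{1-p}`. A minimal sketch of what that looks like in use (not part of this commit; it assumes the usual train()/eval() toggling on mlx modules):

import mlx.core as mx
import mlx.nn as nn

p = 0.5
x = mx.ones((4, 8))

layer = nn.Dropout(p)
layer.train()              # dropout is only applied in training mode
y = layer(x)               # survivors become 1 / (1 - p) = 2.0, the rest 0.0

layer.eval()
print(mx.array_equal(layer(x), x))  # expected: identity at evaluation time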
@@ -36,15 +36,13 @@ class Dropout(Module):


 class Dropout2d(Module):
-    """Apply 2D channel-wise dropout during training.
+    r"""Apply 2D channel-wise dropout during training.

     Randomly zero out entire channels independently with probability :math:`p`.
     This layer expects the channels to be last, i.e. the input shape should be
-    ``NWHC`` or ``WHC`` where:
-    - ``N`` is the batch dimension
-    - ``H`` is the input image height
-    - ``W`` is the input image width
-    - ``C`` is the number of input channels
+    ``NWHC`` or ``WHC`` where: ``N`` is the batch dimension, ``H`` is the input
+    image height, ``W`` is the input image width, and ``C`` is the number of
+    input channels.

     The remaining channels are scaled by :math:`\frac{1}{1-p}` to
     maintain the expected value of each element. Unlike traditional dropout,
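As a hedged illustration of the channel-wise behaviour the Dropout2d docstring describes (not from the commit; the constructor argument name ``p`` is assumed from the docstring):

import mlx.core as mx
import mlx.nn as nn

x = mx.ones((8, 8, 16))      # WHC input: an 8x8 image with 16 channels
layer = nn.Dropout2d(p=0.25)
layer.train()
y = layer(x)
# Each channel is either entirely zeroed or entirely scaled by 1 / (1 - p),
# so the per-channel sums are either 0 or 64 / 0.75.
print(mx.sum(y, axis=(0, 1)))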
@@ -194,24 +194,31 @@ class BatchNorm(Module):
     where :math:`\gamma` and :math:`\beta` are learned per feature dimension
     parameters initialized at 1 and 0 respectively.

-    [1]: https://arxiv.org/abs/1502.03167
+    The input shape is specified as ``NC`` or ``NLC``, where ``N`` is the
+    batch, ``C`` is the number of features or channels, and ``L`` is the
+    sequence length. The output has the same shape as the input. For
+    four-dimensional arrays, the shape is ``NHWC``, where ``H`` and ``W`` are
+    the height and width respectively.

-    The input tensor shape is specified as (N, C) or (N, L, C), representing the batch size (N), the number of features or channels (C), and optionally, the sequence length (L). The output tensor maintains the same shape as the input, adhering to (N, C) or (N, L, C).
-    For three-dimensional tensors, the shape is denoted as (N, H, W, C), where N signifies the batch size, C represents the number of channels, H corresponds to the height, and W denotes the width.
+    For more information on Batch Normalization, see the original paper `Batch
+    Normalization: Accelerating Deep Network Training by Reducing Internal
+    Covariate Shift <https://arxiv.org/abs/1502.03167>`_.

     Args:
-        num_features (int): The feature dimension of the input to normalize over.
-        eps (float, optional): A small additive constant for numerical stability. Default is 1e-5.
-        momentum (float, optional): The momentum for updating the running mean and variance. Default is 0.1.
-        affine (bool, optional): If True, learn an affine transform to apply after the normalization. Default is True.
-        track_running_stats (bool, optional): If True, track the running mean and variance. Default is True.
+        num_features (int): The feature dimension to normalize over.
+        eps (float, optional): A small additive constant for numerical
+          stability. Default: ``1e-5``.
+        momentum (float, optional): The momentum for updating the running
+          mean and variance. Default: ``0.1``.
+        affine (bool, optional): If ``True``, apply a learned affine
+          transformation after the normalization. Default: ``True``.
+        track_running_stats (bool, optional): If ``True``, track the
+          running mean and variance. Default: ``True``.

     Examples:
         >>> import mlx.core as mx
         >>> import mlx.nn as nn
-        >>> mx.random.seed(42)
-        >>> input = mx.random.normal((5, 4), dtype=mx.float32)
-        >>> # Batch norm
+        >>> x = mx.random.normal((5, 4))
         >>> bn = nn.BatchNorm(num_features=4, affine=True)
         >>> output = bn(x)
     """
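A sketch (not part of the diff) that reproduces the normalization the BatchNorm docstring defines, assuming the conventional placement of ``eps`` inside the square root, and checks it against the layer in training mode:

import mlx.core as mx
import mlx.nn as nn

x = mx.random.normal((5, 4))
bn = nn.BatchNorm(num_features=4, affine=True)
bn.train()
y = bn(x)

# Per-feature statistics over the batch axis, then the learned affine transform.
mean = mx.mean(x, axis=0, keepdims=True)
var = mx.var(x, axis=0, keepdims=True)
y_ref = (x - mean) / mx.sqrt(var + bn.eps) * bn.weight + bn.bias
print(mx.allclose(y, y_ref, atol=1e-5))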
@@ -229,10 +236,9 @@ class BatchNorm(Module):
         self.num_features = num_features
         self.eps = eps
         self.momentum = momentum
-        self.affine = affine
         self.track_running_stats = track_running_stats

-        if self.affine:
+        if affine:
             self.weight = mx.ones((num_features,))
             self.bias = mx.zeros((num_features,))

@@ -241,7 +247,11 @@ class BatchNorm(Module):
         self._running_var = mx.ones((num_features,))

     def _extra_repr(self):
-        return f"{self.num_features}, eps={self.eps}, momentum={self.momentum}, affine={'weight' in self}, track_running_stats={self.track_running_stats}"
+        return (
+            f"{self.num_features}, eps={self.eps}, "
+            f"momentum={self.momentum}, affine={'weight' in self}, "
+            f"track_running_stats={self.track_running_stats}"
+        )

     def _calc_stats(self, x: mx.array) -> Tuple[mx.array, mx.array]:
         """
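The new ``_extra_repr`` infers ``affine`` from whether a ``weight`` parameter exists. A small sketch of that behaviour (assuming, as the code itself does, that ``'weight' in module`` performs a parameter lookup):

import mlx.nn as nn

bn = nn.BatchNorm(num_features=4, affine=True)
print("weight" in bn)        # True: the affine scale and offset were created
print(bn._extra_repr())      # "..., affine=True, track_running_stats=True"

bn_plain = nn.BatchNorm(num_features=4, affine=False)
print("weight" in bn_plain)  # False: no learned scale or offset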
@@ -253,9 +263,7 @@ class BatchNorm(Module):
         Returns:
             tuple: Tuple containing mean and variance.
         """
-        reduction_axes = (
-            (0,) if len(x.shape) == 2 else (0, 1) if len(x.shape) == 3 else (0, 1, 2)
-        )
+        reduction_axes = tuple(range(0, x.ndim - 1))
         means = mx.mean(x, axis=reduction_axes, keepdims=True)
         var = mx.var(x, axis=reduction_axes, keepdims=True)

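The replacement computes the reduction axes as every axis except the trailing channel axis. A quick sketch (not part of the commit) showing it matches the old hand-written branches for 2D, 3D, and 4D inputs:

import mlx.core as mx

for shape in [(5, 4), (5, 3, 4), (5, 6, 7, 4)]:
    x = mx.random.normal(shape)
    axes = tuple(range(0, x.ndim - 1))   # (0,), (0, 1), (0, 1, 2)
    mean = mx.mean(x, axis=axes, keepdims=True)
    var = mx.var(x, axis=axes, keepdims=True)
    print(shape, axes, mean.shape, var.shape)  # one statistic per channel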
@@ -279,7 +287,7 @@ class BatchNorm(Module):
             mx.array: Output tensor.
         """

-        if x.ndim not in [2, 3, 4]:
+        if x.ndim < 2 or x.ndim > 4:
            raise ValueError(
                f"Expected input tensor to have 2, 3 or 4 dimensions, but got {x.ndim}"
            )
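The dimension check above rejects anything outside 2 to 4 dimensions; a tiny sketch of the resulting error (not part of the commit):

import mlx.core as mx
import mlx.nn as nn

bn = nn.BatchNorm(num_features=4)
try:
    bn(mx.zeros((4,)))       # 1D input
except ValueError as e:
    print(e)                 # Expected input tensor to have 2, 3 or 4 dimensions, but got 1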
@@ -286,7 +286,7 @@ def _reduce(loss: mx.array, reduction: str = "none"):
 def hinge_loss(
     inputs: mx.array, targets: mx.array, reduction: str = "none"
 ) -> mx.array:
-    """
+    r"""
     Computes the hinge loss between inputs and targets.

     .. math::
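The ``.. math::`` body is elided in this hunk; the conventional hinge loss it presumably renders is :math:`\max(0, 1 - y \hat{y})` with targets in :math:`\{-1, 1\}`. A hedged sketch comparing a manual computation with the function defined here:

import mlx.core as mx
import mlx.nn as nn

inputs = mx.array([0.8, -0.3, 1.5, -2.0])   # raw scores
targets = mx.array([1.0, 1.0, -1.0, -1.0])  # labels in {-1, 1}

manual = mx.maximum(1.0 - inputs * targets, 0.0)
print(manual)
print(nn.losses.hinge_loss(inputs, targets, reduction="none"))  # expected to match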
@@ -311,7 +311,7 @@ def hinge_loss(
 def huber_loss(
     inputs: mx.array, targets: mx.array, delta: float = 1.0, reduction: str = "none"
 ) -> mx.array:
-    """
+    r"""
     Computes the Huber loss between inputs and targets.

     .. math::
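Again the math body is elided; the usual Huber loss is quadratic for errors up to ``delta`` and linear beyond it. A hedged sketch against the function defined here:

import mlx.core as mx
import mlx.nn as nn

inputs = mx.array([0.1, 0.5, 2.0, -3.0])
targets = mx.zeros((4,))
delta = 1.0

err = mx.abs(inputs - targets)
manual = mx.where(err <= delta, 0.5 * err**2, delta * (err - 0.5 * delta))
print(manual)
print(nn.losses.huber_loss(inputs, targets, delta=delta, reduction="none"))  # expected to match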
@@ -345,7 +345,7 @@ def huber_loss(
 def log_cosh_loss(
     inputs: mx.array, targets: mx.array, reduction: str = "none"
 ) -> mx.array:
-    """
+    r"""
     Computes the log cosh loss between inputs and targets.

     Logcosh acts like L2 loss for small errors, ensuring stable gradients,
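The docstring notes that log-cosh behaves like an L2 loss for small errors (and roughly like L1 for large ones). A sketch of the naive definition; the library presumably uses a numerically stabler form, so treat this only as a reference computation:

import mlx.core as mx
import mlx.nn as nn

inputs = mx.array([0.01, 0.1, 3.0, 10.0])
targets = mx.zeros((4,))

manual = mx.log(mx.cosh(inputs - targets))   # ~0.5 * err^2 when small, ~|err| - log(2) when large
print(manual)
print(nn.losses.log_cosh_loss(inputs, targets, reduction="none"))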
@@ -3,7 +3,6 @@
 import os
 import tempfile
 import unittest
-from unittest.mock import Mock, patch

 import mlx.core as mx
 import mlx.nn as nn
@@ -342,9 +341,9 @@ class TestNN(mlx_tests.MLXTestCase):
         expected_mean = mx.array([0.008929, 0.005680, -0.016092, 0.027778])
         expected_var = mx.array([0.928435, 1.00455, 1.04117, 0.94258])
         self.assertTrue(x.shape == y.shape)
-        self.assertTrue(np.allclose(y, expected_y, atol=1e-5))
-        self.assertTrue(np.allclose(bn._running_mean, expected_mean, atol=1e-5))
-        self.assertTrue(np.allclose(bn._running_var, expected_var, atol=1e-5))
+        self.assertTrue(mx.allclose(y, expected_y, atol=1e-5))
+        self.assertTrue(mx.allclose(bn._running_mean, expected_mean, atol=1e-5))
+        self.assertTrue(mx.allclose(bn._running_var, expected_var, atol=1e-5))

         # test eval mode
         bn.eval()
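The expected running statistics in this test come from a momentum-style update; a hedged sketch of the usual rule (an assumption here, not taken from the diff):

import mlx.core as mx

momentum = 0.1
running_mean = mx.zeros((4,))
running_var = mx.ones((4,))

x = mx.random.normal((5, 4))
batch_mean = mx.mean(x, axis=0)
batch_var = mx.var(x, axis=0)

# Exponential moving average of the batch statistics.
running_mean = (1 - momentum) * running_mean + momentum * batch_mean
running_var = (1 - momentum) * running_var + momentum * batch_var
print(running_mean, running_var)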
@@ -360,7 +359,7 @@ class TestNN(mlx_tests.MLXTestCase):
         )

         self.assertTrue(x.shape == y.shape)
-        self.assertTrue(np.allclose(y, expected_y, atol=1e-5))
+        self.assertTrue(mx.allclose(y, expected_y, atol=1e-5))

         # test_no_affine
         bn = nn.BatchNorm(num_features=4, affine=False)
@@ -406,13 +405,13 @@ class TestNN(mlx_tests.MLXTestCase):
             ],
         ]
        )
-        self.assertTrue(np.allclose(y, expected_y, atol=1e-5))
+        self.assertTrue(mx.allclose(y, expected_y, atol=1e-5))
         expected_mean = mx.array(
             [[[0.00207845, -5.3259e-05, 0.04755, -0.0697296, 0.0236228]]]
         )
         expected_var = mx.array([[[0.968415, 1.05322, 0.96913, 0.932305, 0.967224]]])
-        self.assertTrue(np.allclose(bn._running_mean, expected_mean, atol=1e-5))
-        self.assertTrue(np.allclose(bn._running_var, expected_var, atol=1e-5))
+        self.assertTrue(mx.allclose(bn._running_mean, expected_mean, atol=1e-5))
+        self.assertTrue(mx.allclose(bn._running_var, expected_var, atol=1e-5))

         x = mx.random.normal((N, L, C, L, C), dtype=mx.float32)
         with self.assertRaises(ValueError):