mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 20:46:46 +08:00
Update normalization.py
This commit is contained in:
parent
a43b853194
commit
b444a6a693
@ -208,7 +208,7 @@ class BatchNorm(Module):
|
|||||||
>>> mx.random.seed(42)
|
>>> mx.random.seed(42)
|
||||||
>>> input = mx.random.normal((5, 4), dtype=mx.float32)
|
>>> input = mx.random.normal((5, 4), dtype=mx.float32)
|
||||||
>>> # Batch norm
|
>>> # Batch norm
|
||||||
>>> bn = nn.BatchNorm1d(num_features=4, affine=True)
|
>>> bn = nn.BatchNorm(num_features=4, affine=True)
|
||||||
>>> output = bn(x)
|
>>> output = bn(x)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user