Mirror of https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Common neural network initializers nn.initializers (#456)

* initial commit: constant, normal, uniform
* identity, glorot and he initializers
* docstrings
* rm file
* nits
* nits
* nits
* testing suite
* docs
* nits in docs
* more docs
* remove unused template
* rename package to nn.init
* docs, receptive field

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent
86e0c79467
commit
6b4b30e3fc
@@ -180,3 +180,4 @@ In detail:
    nn/layers
    nn/functions
    nn/losses
+   nn/init
docs/src/python/nn/init.rst (new file, 45 lines)
@@ -0,0 +1,45 @@
.. _init:

.. currentmodule:: mlx.nn.init

Initializers
------------

The ``mlx.nn.init`` package contains commonly used initializers for neural
network parameters. Initializers return a function which can be applied to any
input :obj:`mlx.core.array` to produce an initialized output.

For example:

.. code:: python

   import mlx.core as mx
   import mlx.nn as nn

   init_fn = nn.init.uniform()

   # Produces a [2, 2] uniform matrix
   param = init_fn(mx.zeros((2, 2)))

To re-initialize all the parameters in an :obj:`mlx.nn.Module` from, say, a
uniform distribution, you can do:

.. code:: python

   import mlx.nn as nn
   model = nn.Sequential(nn.Linear(5, 10), nn.ReLU(), nn.Linear(10, 5))
   init_fn = nn.init.uniform(low=-0.1, high=0.1)
   model.apply(init_fn)

.. autosummary::
   :toctree: _autosummary

   constant
   normal
   uniform
   identity
   glorot_normal
   glorot_uniform
   he_normal
   he_uniform
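Editor's note: the Glorot and He initializers also take arguments on the returned function (``gain``, plus ``mode`` for the He variants), while ``Module.apply`` passes only the parameter array, and the fan computation requires at least two dimensions. A minimal sketch of wrapping an initializer to handle both points (the lambda and the 1-D guard are illustrative additions, not part of this diff):

    import math

    import mlx.core as mx
    import mlx.nn as nn

    model = nn.Sequential(nn.Linear(5, 10), nn.ReLU(), nn.Linear(10, 5))
    he_fn = nn.init.he_normal()

    # Forward mode/gain through a wrapper, and leave 1-D parameters (biases)
    # untouched because fan_in/fan_out need at least a 2-D shape.
    # math.sqrt(2.0) is the commonly used gain for ReLU activations.
    model.apply(
        lambda a: he_fn(a, mode="fan_in", gain=math.sqrt(2.0)) if a.ndim >= 2 else a
    )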
python/mlx/nn/__init__.py
@@ -1,5 +1,5 @@
 # Copyright © 2023 Apple Inc.

-from mlx.nn import losses
+from mlx.nn import init, losses
 from mlx.nn.layers import *
 from mlx.nn.utils import value_and_grad
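Editor's note: a quick illustration of what this import change enables (not part of the diff). The new subpackage is now reachable as an attribute of ``mlx.nn``:

    import mlx.core as mx
    import mlx.nn as nn

    # `nn.init` is available without importing the submodule explicitly.
    init_fn = nn.init.constant(0.5)
    print(init_fn(mx.zeros((2, 2))))  # array of 0.5 with shape (2, 2)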
python/mlx/nn/init.py (new file, 350 lines)
@@ -0,0 +1,350 @@
# Copyright © 2023-2024 Apple Inc.

import math
from typing import Callable, Literal

import mlx.core as mx


def constant(
    value: float, dtype: mx.Dtype = mx.float32
) -> Callable[[mx.array], mx.array]:
    r"""An initializer that returns an array filled with ``value``.

    Args:
        value (float): The value to fill the array with.
        dtype (Dtype, optional): The data type of the array. Default:
            ``float32``.

    Returns:
        Callable[[array], array]: An initializer that returns an array with the
        same shape as the input, filled with ``value``.

    Example:

        >>> init_fn = nn.init.constant(0.5)
        >>> init_fn(mx.zeros((2, 2)))
        array([[0.5, 0.5],
               [0.5, 0.5]], dtype=float32)
    """

    def initializer(a: mx.array) -> mx.array:
        return mx.full(a.shape, value, dtype=dtype)

    return initializer


def normal(
    mean: float = 0.0, std: float = 1.0, dtype: mx.Dtype = mx.float32
) -> Callable[[mx.array], mx.array]:
    r"""An initializer that returns samples from a normal distribution.

    Args:
        mean (float, optional): Mean of the normal distribution. Default:
            ``0.0``.
        std (float, optional): Standard deviation of the normal distribution.
            Default: ``1.0``.
        dtype (Dtype, optional): The data type of the array. Default:
            ``float32``.

    Returns:
        Callable[[array], array]: An initializer that returns an array with the
        same shape as the input, filled with samples from a normal distribution.

    Example:

        >>> init_fn = nn.init.normal()
        >>> init_fn(mx.zeros((2, 2)))
        array([[-0.982273, -0.534422],
               [0.380709, 0.0645099]], dtype=float32)
    """

    def initializer(a: mx.array) -> mx.array:
        return std * mx.random.normal(shape=a.shape, dtype=dtype) + mean

    return initializer


def uniform(
    low: float = 0.0, high: float = 1.0, dtype: mx.Dtype = mx.float32
) -> Callable[[mx.array], mx.array]:
    r"""An initializer that returns samples from a uniform distribution.

    Args:
        low (float, optional): The lower bound of the uniform distribution.
            Default: ``0.0``.
        high (float, optional): The upper bound of the uniform distribution.
            Default: ``1.0``.
        dtype (Dtype, optional): The data type of the array. Default: ``float32``.

    Returns:
        Callable[[array], array]: An initializer that returns an array
        with the same shape as the input, filled with samples from a uniform
        distribution.

    Example:

        >>> init_fn = nn.init.uniform(low=0, high=1)
        >>> init_fn(mx.zeros((2, 2)))
        array([[0.883935, 0.863726],
               [0.617261, 0.417497]], dtype=float32)
    """

    def initializer(a: mx.array) -> mx.array:
        return mx.random.uniform(low, high, a.shape, dtype=dtype)

    return initializer


def identity(dtype: mx.Dtype = mx.float32) -> Callable[[mx.array], mx.array]:
    r"""An initializer that returns an identity matrix.

    Args:
        dtype (Dtype, optional): The data type of the array. Default:
            ``float32``.

    Returns:
        Callable[[array], array]: An initializer that returns an identity
        matrix with the same shape as the input.

    Example:

        >>> init_fn = nn.init.identity()
        >>> init_fn(mx.zeros((2, 2)))
        array([[1, 0],
               [0, 1]], dtype=float32)
    """

    def initializer(arr: mx.array) -> mx.array:
        if arr.ndim != 2 or arr.shape[0] != arr.shape[1]:
            raise ValueError(
                f"The input array must be a square matrix but got shape {arr.shape}."
            )
        return mx.eye(n=arr.shape[0], dtype=dtype)

    return initializer


def _calculate_fan_in_fan_out(x):
    if x.ndim < 2:
        raise ValueError(
            "Glorot / He initialization requires at least 2 dimensional input"
            f" but the input has {x.ndim} dimensions."
        )

    fan_in = x.shape[-1]
    fan_out = x.shape[0]

    if x.ndim > 2:
        receptive_field = 1
        for d in x.shape[1:-1]:
            receptive_field *= d

        fan_in = fan_in * receptive_field
        fan_out = fan_out * receptive_field

    return fan_in, fan_out


def glorot_normal(
    dtype: mx.Dtype = mx.float32,
) -> Callable[[mx.array, float], mx.array]:
    r"""A Glorot normal initializer.

    This initializer samples from a normal distribution with a standard
    deviation computed from the number of input (``fan_in``) and output
    (``fan_out``) units according to:

    .. math::
        \sigma = \gamma \sqrt{\frac{2.0}{\text{fan_in} + \text{fan_out}}}

    For more details see the original reference: `Understanding the difficulty
    of training deep feedforward neural networks
    <https://proceedings.mlr.press/v9/glorot10a.html>`_

    Args:
        dtype (Dtype, optional): The data type of the array. Default: ``float32``.

    Returns:
        Callable[[array, float], array]: An initializer that returns an array
        with the same shape as the input, filled with samples from the Glorot
        normal distribution.

    Example:

        >>> init_fn = nn.init.glorot_normal()
        >>> init_fn(mx.zeros((2, 2)))
        array([[0.191107, 1.61278],
               [-0.150594, -0.363207]], dtype=float32)
        >>> init_fn(mx.zeros((2, 2)), gain=4.0)
        array([[1.89613, -4.53947],
               [4.48095, 0.995016]], dtype=float32)
    """

    def initializer(a: mx.array, gain: float = 1.0) -> mx.array:
        fan_in, fan_out = _calculate_fan_in_fan_out(a)
        std = gain * math.sqrt(2.0 / (fan_in + fan_out))
        return mx.random.normal(shape=a.shape, dtype=dtype) * std

    return initializer


def glorot_uniform(
    dtype: mx.Dtype = mx.float32,
) -> Callable[[mx.array, float], mx.array]:
    r"""A Glorot uniform initializer.

    This initializer samples from a uniform distribution with a range
    computed from the number of input (``fan_in``) and output (``fan_out``)
    units according to:

    .. math::
        \sigma = \gamma \sqrt{\frac{6.0}{\text{fan_in} + \text{fan_out}}}

    For more details see the original reference: `Understanding the difficulty
    of training deep feedforward neural networks
    <https://proceedings.mlr.press/v9/glorot10a.html>`_

    Args:
        dtype (Dtype, optional): The data type of the array. Default: ``float32``.

    Returns:
        Callable[[array, float], array]: An initializer that returns an array
        with the same shape as the input, filled with samples from the Glorot
        uniform distribution.

    Example:

        >>> init_fn = nn.init.glorot_uniform()
        >>> init_fn(mx.zeros((2, 2)))
        array([[0.223404, -0.890597],
               [-0.379159, -0.776856]], dtype=float32)
        >>> init_fn(mx.zeros((2, 2)), gain=4.0)
        array([[-1.90041, 3.02264],
               [-0.912766, 4.12451]], dtype=float32)
    """

    def initializer(a: mx.array, gain: float = 1.0) -> mx.array:
        fan_in, fan_out = _calculate_fan_in_fan_out(a)
        limit = gain * math.sqrt(6.0 / (fan_in + fan_out))
        return mx.random.uniform(-limit, limit, a.shape, dtype=dtype)

    return initializer


def he_normal(
    dtype: mx.Dtype = mx.float32,
) -> Callable[[mx.array, str, float], mx.array]:
    r"""A He normal (Kaiming normal) initializer.

    This initializer samples from a normal distribution with a standard
    deviation computed from the number of input (``fan_in``) or output
    (``fan_out``) units according to:

    .. math::
        \sigma = \gamma \frac{1}{\sqrt{\text{fan}}}

    where :math:`\text{fan}` is either the number of input units when the
    ``mode`` is ``"fan_in"`` or output units when the ``mode`` is
    ``"fan_out"``.

    For more details see the original reference: `Delving Deep into Rectifiers:
    Surpassing Human-Level Performance on ImageNet Classification
    <https://arxiv.org/abs/1502.01852>`_

    Args:
        dtype (Dtype, optional): The data type of the array. Default: ``float32``.

    Returns:
        Callable[[array, str, float], array]: An initializer that returns an
        array with the same shape as the input, filled with samples from the He
        normal distribution.

    Example:

        >>> init_fn = nn.init.he_normal()
        >>> init_fn(mx.zeros((2, 2)))  # uses fan_in
        array([[-1.25211, 0.458835],
               [-0.177208, -0.0137595]], dtype=float32)
        >>> init_fn(mx.zeros((2, 2)), mode="fan_out", gain=5)
        array([[5.6967, 4.02765],
               [-4.15268, -2.75787]], dtype=float32)
    """

    def initializer(
        a: mx.array,
        mode: Literal["fan_in", "fan_out"] = "fan_in",
        gain: float = 1.0,
    ) -> mx.array:
        fan_in, fan_out = _calculate_fan_in_fan_out(a)
        if mode == "fan_in":
            fan = fan_in
        elif mode == "fan_out":
            fan = fan_out
        else:
            raise ValueError(f"Invalid mode: {mode}. Valid modes are: fan_in, fan_out")

        std = gain / math.sqrt(fan)
        return mx.random.normal(shape=a.shape, dtype=dtype) * std

    return initializer


def he_uniform(
    dtype: mx.Dtype = mx.float32,
) -> Callable[[mx.array, str, float], mx.array]:
    r"""A He uniform (Kaiming uniform) initializer.

    This initializer samples from a uniform distribution with a range
    computed from the number of input (``fan_in``) or output (``fan_out``)
    units according to:

    .. math::
        \sigma = \gamma \sqrt{\frac{3.0}{\text{fan}}}

    where :math:`\text{fan}` is either the number of input units when the
    ``mode`` is ``"fan_in"`` or output units when the ``mode`` is
    ``"fan_out"``.

    For more details see the original reference: `Delving Deep into Rectifiers:
    Surpassing Human-Level Performance on ImageNet Classification
    <https://arxiv.org/abs/1502.01852>`_

    Args:
        dtype (Dtype, optional): The data type of the array. Default: ``float32``.

    Returns:
        Callable[[array, str, float], array]: An initializer that returns an
        array with the same shape as the input, filled with samples from the
        He uniform distribution.

    Example:

        >>> init_fn = nn.init.he_uniform()
        >>> init_fn(mx.zeros((2, 2)))  # uses fan_in
        array([[0.0300242, -0.0184009],
               [0.793615, 0.666329]], dtype=float32)
        >>> init_fn(mx.zeros((2, 2)), mode="fan_out", gain=5)
        array([[-1.64331, -2.16506],
               [1.08619, 5.79854]], dtype=float32)
    """

    def initializer(
        a: mx.array,
        mode: Literal["fan_in", "fan_out"] = "fan_in",
        gain: float = 1.0,
    ) -> mx.array:
        fan_in, fan_out = _calculate_fan_in_fan_out(a)
        if mode == "fan_in":
            fan = fan_in
        elif mode == "fan_out":
            fan = fan_out
        else:
            raise ValueError(f"Invalid mode: {mode}. Valid modes are: fan_in, fan_out")

        limit = gain * math.sqrt(3.0 / fan)
        return mx.random.uniform(-limit, limit, a.shape, dtype=dtype)

    return initializer
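Editor's note, for intuition on ``_calculate_fan_in_fan_out``: the last axis is treated as the input units, the first axis as the output units, and any axes in between as a spatial receptive field that scales both fans. A small worked example with hypothetical shapes (not part of the diff):

    import mlx.core as mx
    import mlx.nn.init as init

    # Conv-style weight (out_channels, kh, kw, in_channels) = (8, 3, 3, 4):
    #   receptive_field = 3 * 3 = 9
    #   fan_in  = 4 * 9 = 36
    #   fan_out = 8 * 9 = 72
    w = init.he_normal()(mx.zeros((8, 3, 3, 4)))

    # Default mode is "fan_in", so the samples have std = 1 / sqrt(36).
    print(w.shape)  # same shape as the input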
python/tests/test_init.py (new file, 94 lines)
@@ -0,0 +1,94 @@
# Copyright © 2023 Apple Inc.

import unittest

import mlx.core as mx
import mlx.nn.init as init
import mlx_tests
import numpy as np


class TestInit(mlx_tests.MLXTestCase):
    def test_constant(self):
        value = 5.0

        for dtype in [mx.float32, mx.float16]:
            initializer = init.constant(value, dtype)
            for shape in [[3], [3, 3], [3, 3, 3]]:
                result = initializer(mx.zeros(shape))
                with self.subTest(shape=shape):
                    self.assertEqual(result.shape, shape)
                    self.assertEqual(result.dtype, dtype)

    def test_normal(self):
        mean = 0.0
        std = 1.0
        for dtype in [mx.float32, mx.float16]:
            initializer = init.normal(mean, std, dtype=dtype)
            for shape in [[3], [3, 3], [3, 3, 3]]:
                result = initializer(mx.array(np.empty(shape)))
                with self.subTest(shape=shape):
                    self.assertEqual(result.shape, shape)
                    self.assertEqual(result.dtype, dtype)

    def test_uniform(self):
        low = -1.0
        high = 1.0

        for dtype in [mx.float32, mx.float16]:
            initializer = init.uniform(low, high, dtype)
            for shape in [[3], [3, 3], [3, 3, 3]]:
                result = initializer(mx.array(np.empty(shape)))
                with self.subTest(shape=shape):
                    self.assertEqual(result.shape, shape)
                    self.assertEqual(result.dtype, dtype)
                    self.assertTrue(mx.all(result >= low) and mx.all(result <= high))

    def test_identity(self):
        for dtype in [mx.float32, mx.float16]:
            initializer = init.identity(dtype)
            result = initializer(mx.zeros((3, 3)))
            self.assertTrue(mx.array_equal(result, mx.eye(3)))
            self.assertEqual(result.dtype, dtype)
            with self.assertRaises(ValueError):
                result = initializer(mx.zeros((3, 2)))

    def test_glorot_normal(self):
        for dtype in [mx.float32, mx.float16]:
            initializer = init.glorot_normal(dtype)
            for shape in [[3, 3], [3, 3, 3]]:
                result = initializer(mx.array(np.empty(shape)))
                with self.subTest(shape=shape):
                    self.assertEqual(result.shape, shape)
                    self.assertEqual(result.dtype, dtype)

    def test_glorot_uniform(self):
        for dtype in [mx.float32, mx.float16]:
            initializer = init.glorot_uniform(dtype)
            for shape in [[3, 3], [3, 3, 3]]:
                result = initializer(mx.array(np.empty(shape)))
                with self.subTest(shape=shape):
                    self.assertEqual(result.shape, shape)
                    self.assertEqual(result.dtype, dtype)

    def test_he_normal(self):
        for dtype in [mx.float32, mx.float16]:
            initializer = init.he_normal(dtype)
            for shape in [[3, 3], [3, 3, 3]]:
                result = initializer(mx.array(np.empty(shape)))
                with self.subTest(shape=shape):
                    self.assertEqual(result.shape, shape)
                    self.assertEqual(result.dtype, dtype)

    def test_he_uniform(self):
        for dtype in [mx.float32, mx.float16]:
            initializer = init.he_uniform(dtype)
            for shape in [[3, 3], [3, 3, 3]]:
                result = initializer(mx.array(np.empty(shape)))
                with self.subTest(shape=shape):
                    self.assertEqual(result.shape, shape)
                    self.assertEqual(result.dtype, dtype)


if __name__ == "__main__":
    unittest.main()
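Editor's note: a possible extension of this suite, sketched here and not part of the diff, is a statistical sanity check that the fan-based initializers produce roughly the expected scale on a large matrix:

    import math
    import unittest

    import mlx.core as mx
    import mlx.nn.init as init


    class TestInitScale(unittest.TestCase):
        def test_glorot_normal_scale(self):
            fan_in = fan_out = 256
            w = init.glorot_normal()(mx.zeros((fan_out, fan_in)))
            expected_std = math.sqrt(2.0 / (fan_in + fan_out))
            sample_std = mx.sqrt(mx.var(w)).item()
            # With 256 * 256 samples the empirical std should land well
            # within a loose 10% tolerance of the target.
            self.assertLess(abs(sample_std - expected_std), 0.1 * expected_std)


    if __name__ == "__main__":
        unittest.main()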