diff --git a/docs/src/python/nn.rst b/docs/src/python/nn.rst
index 496c27823..2a253ab25 100644
--- a/docs/src/python/nn.rst
+++ b/docs/src/python/nn.rst
@@ -180,3 +180,4 @@ In detail:
    nn/layers
    nn/functions
    nn/losses
+   nn/init
diff --git a/docs/src/python/nn/init.rst b/docs/src/python/nn/init.rst
new file mode 100644
index 000000000..610d767d4
--- /dev/null
+++ b/docs/src/python/nn/init.rst
@@ -0,0 +1,59 @@
+.. _init:
+
+.. currentmodule:: mlx.nn.init
+
+Initializers
+------------
+
+The ``mlx.nn.init`` package contains commonly used initializers for neural
+network parameters. Initializers return a function that can be applied to any
+input :obj:`mlx.core.array` to produce an initialized output.
+
+For example:
+
+.. code:: python
+
+   import mlx.core as mx
+   import mlx.nn as nn
+
+   init_fn = nn.init.uniform()
+
+   # Produces a [2, 2] uniform matrix
+   param = init_fn(mx.zeros((2, 2)))
+
+To re-initialize all the parameters in an :obj:`mlx.nn.Module` from, say, a
+uniform distribution, you can do:
+
+.. code:: python
+
+   import mlx.nn as nn
+   model = nn.Sequential(nn.Linear(5, 10), nn.ReLU(), nn.Linear(10, 5))
+   init_fn = nn.init.uniform(low=-0.1, high=0.1)
+   model.apply(init_fn)
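+
+The Glorot and He initializers also accept arguments at call time. For
+instance, the function returned by :func:`he_normal` takes an optional
+``mode`` and ``gain`` to scale the sampled values:
+
+.. code:: python
+
+   import mlx.core as mx
+   import mlx.nn as nn
+
+   init_fn = nn.init.he_normal()
+
+   # Sample using fan_out scaling and a gain of 2
+   param = init_fn(mx.zeros((8, 4)), mode="fan_out", gain=2.0)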
+
+
+.. autosummary::
+   :toctree: _autosummary
+
+   constant
+   normal
+   uniform
+   identity
+   glorot_normal
+   glorot_uniform
+   he_normal
+   he_uniform
diff --git a/python/mlx/nn/__init__.py b/python/mlx/nn/__init__.py
index 9bb7cc63d..b2cb9e0f4 100644
--- a/python/mlx/nn/__init__.py
+++ b/python/mlx/nn/__init__.py
@@ -1,5 +1,5 @@
 # Copyright © 2023 Apple Inc.
 
-from mlx.nn import losses
+from mlx.nn import init, losses
 from mlx.nn.layers import *
 from mlx.nn.utils import value_and_grad
diff --git a/python/mlx/nn/init.py b/python/mlx/nn/init.py
new file mode 100644
index 000000000..5afc6170e
--- /dev/null
+++ b/python/mlx/nn/init.py
@@ -0,0 +1,354 @@
+# Copyright © 2023-2024 Apple Inc.
+
+import math
+from typing import Callable, Literal
+
+import mlx.core as mx
+
+
+def constant(
+    value: float, dtype: mx.Dtype = mx.float32
+) -> Callable[[mx.array], mx.array]:
+    r"""An initializer that returns an array filled with ``value``.
+
+    Args:
+        value (float): The value to fill the array with.
+        dtype (Dtype, optional): The data type of the array. Default:
+            ``float32``.
+
+    Returns:
+        Callable[[array], array]: An initializer that returns an array with the
+        same shape as the input, filled with ``value``.
+
+    Example:
+
+        >>> init_fn = nn.init.constant(0.5)
+        >>> init_fn(mx.zeros((2, 2)))
+        array([[0.5, 0.5],
+               [0.5, 0.5]], dtype=float32)
+    """
+
+    def initializer(a: mx.array) -> mx.array:
+        return mx.full(a.shape, value, dtype=dtype)
+
+    return initializer
+
+
+def normal(
+    mean: float = 0.0, std: float = 1.0, dtype: mx.Dtype = mx.float32
+) -> Callable[[mx.array], mx.array]:
+    r"""An initializer that returns samples from a normal distribution.
+
+    Args:
+        mean (float, optional): Mean of the normal distribution. Default:
+            ``0.0``.
+        std (float, optional): Standard deviation of the normal distribution.
+            Default: ``1.0``.
+        dtype (Dtype, optional): The data type of the array. Default:
+            ``float32``.
+
+    Returns:
+        Callable[[array], array]: An initializer that returns an array with the
+        same shape as the input, filled with samples from a normal distribution.
+
+    Example:
+
+        >>> init_fn = nn.init.normal()
+        >>> init_fn(mx.zeros((2, 2)))
+        array([[-0.982273, -0.534422],
+               [0.380709, 0.0645099]], dtype=float32)
+    """
+
+    def initializer(a: mx.array) -> mx.array:
+        return std * mx.random.normal(shape=a.shape, dtype=dtype) + mean
+
+    return initializer
+
+
+def uniform(
+    low: float = 0.0, high: float = 1.0, dtype: mx.Dtype = mx.float32
+) -> Callable[[mx.array], mx.array]:
+    r"""An initializer that returns samples from a uniform distribution.
+
+    Args:
+        low (float, optional): The lower bound of the uniform distribution.
+            Default: ``0.0``.
+        high (float, optional): The upper bound of the uniform distribution.
+            Default: ``1.0``.
+        dtype (Dtype, optional): The data type of the array. Default: ``float32``.
+
+    Returns:
+        Callable[[array], array]: An initializer that returns an array
+        with the same shape as the input, filled with samples from a uniform
+        distribution.
+
+    Example:
+
+        >>> init_fn = nn.init.uniform(low=0, high=1)
+        >>> init_fn(mx.zeros((2, 2)))
+        array([[0.883935, 0.863726],
+               [0.617261, 0.417497]], dtype=float32)
+    """
+
+    def initializer(a: mx.array) -> mx.array:
+        return mx.random.uniform(low, high, a.shape, dtype=dtype)
+
+    return initializer
+
+
+def identity(dtype: mx.Dtype = mx.float32) -> Callable[[mx.array], mx.array]:
+    r"""An initializer that returns an identity matrix.
+
+    Args:
+        dtype (Dtype, optional): The data type of the array. Default:
+            ``float32``.
+
+    Returns:
+        Callable[[array], array]: An initializer that returns an identity
+        matrix with the same shape as the input.
+
+    Example:
+
+        >>> init_fn = nn.init.identity()
+        >>> init_fn(mx.zeros((2, 2)))
+        array([[1, 0],
+               [0, 1]], dtype=float32)
+    """
+
+    def initializer(arr: mx.array) -> mx.array:
+        if arr.ndim != 2 or arr.shape[0] != arr.shape[1]:
+            raise ValueError(
+                f"The input array must be a square matrix but got shape {arr.shape}."
+            )
+        return mx.eye(n=arr.shape[0], dtype=dtype)
+
+    return initializer
+
+
+def _calculate_fan_in_fan_out(x):
+    if x.ndim < 2:
+        raise ValueError(
+            "Glorot / He initialization requires at least 2 dimensional input"
+            f" but the input has {x.ndim} dimensions."
+        )
+
+    fan_in = x.shape[-1]
+    fan_out = x.shape[0]
+
+    if x.ndim > 2:
+        receptive_field = 1
+        for d in x.shape[1:-1]:
+            receptive_field *= d
+
+        fan_in = fan_in * receptive_field
+        fan_out = fan_out * receptive_field
+
+    return fan_in, fan_out
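+
+# For example (illustrative): a conv-style weight of shape (8, 3, 3, 16),
+# laid out as (out_channels, kh, kw, in_channels), has a receptive field of
+# 3 * 3 = 9, giving fan_in = 16 * 9 = 144 and fan_out = 8 * 9 = 72.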
+
+
+def glorot_normal(
+    dtype: mx.Dtype = mx.float32,
+) -> Callable[[mx.array, float], mx.array]:
+    r"""A Glorot normal initializer.
+
+    This initializer samples from a normal distribution with a standard
+    deviation computed from the number of input (``fan_in``) and output
+    (``fan_out``) units according to:
+
+    .. math::
+        \sigma = \gamma \sqrt{\frac{2.0}{\text{fan\_in} + \text{fan\_out}}}
+
+    For more details see the original reference: `Understanding the difficulty
+    of training deep feedforward neural networks
+    <https://proceedings.mlr.press/v9/glorot10a.html>`_
+
+    Args:
+        dtype (Dtype, optional): The data type of the array. Default: ``float32``.
+
+    Returns:
+        Callable[[array, float], array]: An initializer that returns an array
+        with the same shape as the input, filled with samples from the Glorot
+        normal distribution.
+
+    Example:
+
+        >>> init_fn = nn.init.glorot_normal()
+        >>> init_fn(mx.zeros((2, 2)))
+        array([[0.191107, 1.61278],
+               [-0.150594, -0.363207]], dtype=float32)
+        >>> init_fn(mx.zeros((2, 2)), gain=4.0)
+        array([[1.89613, -4.53947],
+               [4.48095, 0.995016]], dtype=float32)
+    """
+
+    def initializer(a: mx.array, gain: float = 1.0) -> mx.array:
+        fan_in, fan_out = _calculate_fan_in_fan_out(a)
+        std = gain * math.sqrt(2.0 / (fan_in + fan_out))
+        return mx.random.normal(shape=a.shape, dtype=dtype) * std
+
+    return initializer
+
+
+def glorot_uniform(
+    dtype: mx.Dtype = mx.float32,
+) -> Callable[[mx.array, float], mx.array]:
+    r"""A Glorot uniform initializer.
+
+    This initializer samples from a uniform distribution :math:`U(-a, a)`
+    with a range :math:`a` computed from the number of input (``fan_in``)
+    and output (``fan_out``) units according to:
+
+    .. math::
+        a = \gamma \sqrt{\frac{6.0}{\text{fan\_in} + \text{fan\_out}}}
+
+    For more details see the original reference: `Understanding the difficulty
+    of training deep feedforward neural networks
+    <https://proceedings.mlr.press/v9/glorot10a.html>`_
+
+    Args:
+        dtype (Dtype, optional): The data type of the array. Default: ``float32``.
+
+    Returns:
+        Callable[[array, float], array]: An initializer that returns an array
+        with the same shape as the input, filled with samples from the Glorot
+        uniform distribution.
+
+    Example:
+
+        >>> init_fn = nn.init.glorot_uniform()
+        >>> init_fn(mx.zeros((2, 2)))
+        array([[0.223404, -0.890597],
+               [-0.379159, -0.776856]], dtype=float32)
+        >>> init_fn(mx.zeros((2, 2)), gain=4.0)
+        array([[-1.90041, 3.02264],
+               [-0.912766, 4.12451]], dtype=float32)
+    """
+
+    def initializer(a: mx.array, gain: float = 1.0) -> mx.array:
+        fan_in, fan_out = _calculate_fan_in_fan_out(a)
+        limit = gain * math.sqrt(6.0 / (fan_in + fan_out))
+        return mx.random.uniform(-limit, limit, a.shape, dtype=dtype)
+
+    return initializer
+
+
+def he_normal(
+    dtype: mx.Dtype = mx.float32,
+) -> Callable[[mx.array, str, float], mx.array]:
+    r"""A He normal (Kaiming normal) initializer.
+
+    This initializer samples from a normal distribution with a standard
+    deviation computed from the number of input (``fan_in``) or output
+    (``fan_out``) units according to:
+
+    .. math::
+        \sigma = \gamma \frac{1}{\sqrt{\text{fan}}}
+
+    where :math:`\text{fan}` is either the number of input units when the
+    ``mode`` is ``"fan_in"`` or output units when the ``mode`` is
+    ``"fan_out"``.
+
+    For more details see the original reference: `Delving Deep into Rectifiers:
+    Surpassing Human-Level Performance on ImageNet Classification
+    <https://arxiv.org/abs/1502.01852>`_
+
+    Args:
+        dtype (Dtype, optional): The data type of the array. Default: ``float32``.
+
+    Returns:
+        Callable[[array, str, float], array]: An initializer that returns an
+        array with the same shape as the input, filled with samples from the He
+        normal distribution.
+
+    Example:
+
+        >>> init_fn = nn.init.he_normal()
+        >>> init_fn(mx.zeros((2, 2)))  # uses fan_in
+        array([[-1.25211, 0.458835],
+               [-0.177208, -0.0137595]], dtype=float32)
+        >>> init_fn(mx.zeros((2, 2)), mode="fan_out", gain=5)
+        array([[5.6967, 4.02765],
+               [-4.15268, -2.75787]], dtype=float32)
+    """
+
+    def initializer(
+        a: mx.array,
+        mode: Literal["fan_in", "fan_out"] = "fan_in",
+        gain: float = 1.0,
+    ) -> mx.array:
+        fan_in, fan_out = _calculate_fan_in_fan_out(a)
+        if mode == "fan_in":
+            fan = fan_in
+        elif mode == "fan_out":
+            fan = fan_out
+        else:
+            raise ValueError(f"Invalid mode: {mode}. Valid modes are: fan_in, fan_out")
+
+        std = gain / math.sqrt(fan)
+        return mx.random.normal(shape=a.shape, dtype=dtype) * std
+
+    return initializer
+
+
+def he_uniform(
+    dtype: mx.Dtype = mx.float32,
+) -> Callable[[mx.array, str, float], mx.array]:
+    r"""A He uniform (Kaiming uniform) initializer.
+
+    This initializer samples from a uniform distribution :math:`U(-a, a)`
+    with a range :math:`a` computed from the number of input (``fan_in``)
+    or output (``fan_out``) units according to:
+
+    .. math::
+
+        a = \gamma \sqrt{\frac{3.0}{\text{fan}}}
+
+    where :math:`\text{fan}` is either the number of input units when the
+    ``mode`` is ``"fan_in"`` or output units when the ``mode`` is
+    ``"fan_out"``.
+
+    For more details see the original reference: `Delving Deep into Rectifiers:
+    Surpassing Human-Level Performance on ImageNet Classification
+    <https://arxiv.org/abs/1502.01852>`_
+
+    Args:
+        dtype (Dtype, optional): The data type of the array. Default: ``float32``.
+
+    Returns:
+        Callable[[array, str, float], array]: An initializer that returns an
+        array with the same shape as the input, filled with samples from the
+        He uniform distribution.
+
+    Example:
+
+        >>> init_fn = nn.init.he_uniform()
+        >>> init_fn(mx.zeros((2, 2)))  # uses fan_in
+        array([[0.0300242, -0.0184009],
+               [0.793615, 0.666329]], dtype=float32)
+        >>> init_fn(mx.zeros((2, 2)), mode="fan_out", gain=5)
+        array([[-1.64331, -2.16506],
+               [1.08619, 5.79854]], dtype=float32)
+    """
+
+    def initializer(
+        a: mx.array,
+        mode: Literal["fan_in", "fan_out"] = "fan_in",
+        gain: float = 1.0,
+    ) -> mx.array:
+        fan_in, fan_out = _calculate_fan_in_fan_out(a)
+        if mode == "fan_in":
+            fan = fan_in
+        elif mode == "fan_out":
+            fan = fan_out
+        else:
+            raise ValueError(f"Invalid mode: {mode}. Valid modes are: fan_in, fan_out")
+
+        limit = gain * math.sqrt(3.0 / fan)
+        return mx.random.uniform(-limit, limit, a.shape, dtype=dtype)
+
+    return initializer
diff --git a/python/tests/test_init.py b/python/tests/test_init.py
new file mode 100644
index 000000000..06211a14e
--- /dev/null
+++ b/python/tests/test_init.py
@@ -0,0 +1,103 @@
+# Copyright © 2023 Apple Inc.
+
+import unittest
+
+import mlx.core as mx
+import mlx.nn.init as init
+import mlx_tests
+
+
+class TestInit(mlx_tests.MLXTestCase):
+    def test_constant(self):
+        value = 5.0
+
+        for dtype in [mx.float32, mx.float16]:
+            initializer = init.constant(value, dtype)
+            for shape in [[3], [3, 3], [3, 3, 3]]:
+                result = initializer(mx.zeros(shape))
+                with self.subTest(shape=shape):
+                    self.assertEqual(result.shape, shape)
+                    self.assertEqual(result.dtype, dtype)
+
+    def test_normal(self):
+        mean = 0.0
+        std = 1.0
+        for dtype in [mx.float32, mx.float16]:
+            initializer = init.normal(mean, std, dtype=dtype)
+            for shape in [[3], [3, 3], [3, 3, 3]]:
+                result = initializer(mx.zeros(shape))
+                with self.subTest(shape=shape):
+                    self.assertEqual(result.shape, shape)
+                    self.assertEqual(result.dtype, dtype)
+
+    def test_uniform(self):
+        low = -1.0
+        high = 1.0
+
+        for dtype in [mx.float32, mx.float16]:
+            initializer = init.uniform(low, high, dtype)
+            for shape in [[3], [3, 3], [3, 3, 3]]:
+                result = initializer(mx.zeros(shape))
+                with self.subTest(shape=shape):
+                    self.assertEqual(result.shape, shape)
+                    self.assertEqual(result.dtype, dtype)
+                    self.assertTrue(mx.all(result >= low) and mx.all(result <= high))
+
+    def test_identity(self):
+        for dtype in [mx.float32, mx.float16]:
+            initializer = init.identity(dtype)
+            result = initializer(mx.zeros((3, 3)))
+            self.assertTrue(mx.array_equal(result, mx.eye(3)))
+            self.assertEqual(result.dtype, dtype)
+            with self.assertRaises(ValueError):
+                initializer(mx.zeros((3, 2)))
+
+    def test_glorot_normal(self):
+        for dtype in [mx.float32, mx.float16]:
+            initializer = init.glorot_normal(dtype)
+            for shape in [[3, 3], [3, 3, 3]]:
+                result = initializer(mx.zeros(shape))
+                with self.subTest(shape=shape):
+                    self.assertEqual(result.shape, shape)
+                    self.assertEqual(result.dtype, dtype)
+
+    def test_glorot_uniform(self):
+        for dtype in [mx.float32, mx.float16]:
+            initializer = init.glorot_uniform(dtype)
+            for shape in [[3, 3], [3, 3, 3]]:
+                result = initializer(mx.zeros(shape))
+                with self.subTest(shape=shape):
+                    self.assertEqual(result.shape, shape)
+                    self.assertEqual(result.dtype, dtype)
+
+    def test_he_normal(self):
+        for dtype in [mx.float32, mx.float16]:
+            initializer = init.he_normal(dtype)
+            for shape in [[3, 3], [3, 3, 3]]:
+                result = initializer(mx.zeros(shape))
+                with self.subTest(shape=shape):
+                    self.assertEqual(result.shape, shape)
+                    self.assertEqual(result.dtype, dtype)
+
+    def test_he_uniform(self):
+        for dtype in [mx.float32, mx.float16]:
+            initializer = init.he_uniform(dtype)
+            for shape in [[3, 3], [3, 3, 3]]:
+                result = initializer(mx.zeros(shape))
+                with self.subTest(shape=shape):
+                    self.assertEqual(result.shape, shape)
+                    self.assertEqual(result.dtype, dtype)
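+
+    def test_he_uniform_bounds(self):
+        # An illustrative sanity check of the sampling range, based on the
+        # limit used in the implementation above: gain * sqrt(3 / fan),
+        # with fan = fan_in = 4 for a [4, 4] matrix in the default mode.
+        initializer = init.he_uniform()
+        gain = 2.0
+        result = initializer(mx.zeros((4, 4)), mode="fan_in", gain=gain)
+        limit = gain * (3.0 / 4.0) ** 0.5
+        self.assertTrue(mx.all(mx.abs(result) <= limit))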
+
+
+if __name__ == "__main__":
+    unittest.main()