From 295ce9db094ba6934e3882347b3a929accc2772c Mon Sep 17 00:00:00 2001
From: Asaf Zorea
Date: Tue, 2 Jan 2024 16:08:53 +0200
Subject: [PATCH] Feature expand nn linear (#315)

* Added identity and bilinear layers
Added a reset_parameters option
Added normal init for bias

* pre-commit run

* add type hints for parameters and the return type
change Bilinear math to x_1 and x_2
change __call__ arguments to x and y instead of input and output
add explanation of the initialization

* Remove unnecessary reshape

* Added 'i' to bilinear formula

* Changed bilinear computation to two matrix multiplications

* avoid saving intermediate results, kept y in bilinear for better clarity (can be replaced with x1)

* Changed math formula in Linear
Added more explanation to math formulas
Changed x1, x2 reshape to support all input sizes
---
 python/mlx/nn/layers/__init__.py |  2 +-
 python/mlx/nn/layers/linear.py   | 97 +++++++++++++++++++++++++++++---
 python/tests/test_nn.py          | 13 +++++
 3 files changed, 103 insertions(+), 9 deletions(-)

diff --git a/python/mlx/nn/layers/__init__.py b/python/mlx/nn/layers/__init__.py
index 31bcc59dc..29787a3cc 100644
--- a/python/mlx/nn/layers/__init__.py
+++ b/python/mlx/nn/layers/__init__.py
@@ -45,7 +45,7 @@ from mlx.nn.layers.containers import Sequential
 from mlx.nn.layers.convolution import Conv1d, Conv2d
 from mlx.nn.layers.dropout import Dropout, Dropout2d, Dropout3d
 from mlx.nn.layers.embedding import Embedding
-from mlx.nn.layers.linear import Linear
+from mlx.nn.layers.linear import Bilinear, Identity, Linear
 from mlx.nn.layers.normalization import BatchNorm, GroupNorm, LayerNorm, RMSNorm
 from mlx.nn.layers.positional_encoding import ALiBi, RoPE, SinusoidalPositionalEncoding
 from mlx.nn.layers.quantized import QuantizedLinear
diff --git a/python/mlx/nn/layers/linear.py b/python/mlx/nn/layers/linear.py
index 0c7a1b907..9c4ca6781 100644
--- a/python/mlx/nn/layers/linear.py
+++ b/python/mlx/nn/layers/linear.py
@@ -1,11 +1,27 @@
 # Copyright © 2023 Apple Inc.
 
 import math
+from typing import Any
 
 import mlx.core as mx
 from mlx.nn.layers.base import Module
 
 
+class Identity(Module):
+    r"""A placeholder identity operator that is argument-insensitive.
+
+    Args:
+        args: any argument (unused)
+        kwargs: any keyword argument (unused)
+    """
+
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        super().__init__()
+
+    def __call__(self, x: mx.array) -> mx.array:
+        return x
+
+
 class Linear(Module):
     r"""Applies an affine transformation to the input.
 
@@ -13,33 +29,98 @@ class Linear(Module):
 
     .. math::
 
-        y = W^\top x + b
+        y = x W^\top + b
 
-    where :math:`W` has shape ``[output_dims, input_dims]``.
+    where:
+    :math:`W` has shape ``[output_dims, input_dims]`` and :math:`b` has shape ``[output_dims]``.
+
+    The values are initialized from the uniform distribution :math:`\mathcal{U}(-{k}, {k})`,
+    where :math:`k = \frac{1}{\sqrt{D_i}}` and :math:`D_i` is equal to ``input_dims``.
 
     Args:
         input_dims (int): The dimensionality of the input features
         output_dims (int): The dimensionality of the output features
         bias (bool, optional): If set to ``False`` then the layer will
-            not use a bias. Default ``True``.
+            not use a bias. Default is ``True``.
""" - def __init__(self, input_dims: int, output_dims: int, bias: bool = True): + def __init__(self, input_dims: int, output_dims: int, bias: bool = True) -> None: super().__init__() - scale = math.sqrt(1 / input_dims) + scale = math.sqrt(1.0 / input_dims) self.weight = mx.random.uniform( low=-scale, high=scale, shape=(output_dims, input_dims), ) if bias: - self.bias = mx.zeros((output_dims,)) + self.bias = mx.random.uniform( + low=-scale, + high=scale, + shape=(output_dims,), + ) - def _extra_repr(self): + def _extra_repr(self) -> str: return f"input_dims={self.weight.shape[1]}, output_dims={self.weight.shape[0]}, bias={'bias' in self}" - def __call__(self, x): + def __call__(self, x: mx.array) -> mx.array: x = x @ self.weight.T if "bias" in self: x = x + self.bias return x + + +class Bilinear(Module): + r"""Applies a bilinear transformation to the inputs. + + Concretely: + + .. math:: + + y_i = x_1^\top W_i x_2 + b_i + + where: + :math:`W` has shape ``[output_dims, input1_dims, input2_dims]``, :math:`b` has shape ``[output_dims ]``, + and :math:`i` indexes the output dimension. + + The values are initialized from the uniform distribution :math:`\mathcal{U}(-{k}, {k})`, + where :math:`k = \frac{1}{\sqrt{D_1}}` and :math:`D_1` is ``input1_dims``. + + Args: + input1_dims (int): The dimensionality of the input1 features + input2_dims (int): The dimensionality of the input2 features + output_dims (int): The dimensionality of the output features + bias (bool, optional): If set to ``False`` then the layer will + not use a bias. Default is ``True``. + """ + + def __init__( + self, input1_dims: int, input2_dims: int, output_dims: int, bias: bool = True + ) -> None: + super().__init__() + scale = math.sqrt(1.0 / input1_dims) + self.weight = mx.random.uniform( + low=-scale, + high=scale, + shape=(1, output_dims, input1_dims, input2_dims), + ) + if bias: + self.bias = mx.random.uniform( + low=-scale, + high=scale, + shape=(output_dims,), + ) + + def _extra_repr(self) -> str: + return ( + f"input1_dims={self.weight.shape[2]}, input2_dims={self.weight.shape[3]}, " + f"output_dims={self.weight.shape[1]}, bias={'bias' in self}" + ) + + def __call__(self, x1: mx.array, x2: mx.array) -> mx.array: + x1 = mx.expand_dims(x1, axis=(-2, -3)) + x2 = mx.expand_dims(x2, axis=(-2)) + y = mx.squeeze(x1 @ self.weight, -2).swapaxes(-1, -2) + y = mx.squeeze(x2 @ y, -2) + if "bias" in self: + y = y + self.bias + return y diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py index e620ad831..a6ce87f34 100644 --- a/python/tests/test_nn.py +++ b/python/tests/test_nn.py @@ -12,12 +12,25 @@ from mlx.utils import tree_flatten, tree_map, tree_unflatten class TestNN(mlx_tests.MLXTestCase): + def test_identity(self): + inputs = mx.zeros((10, 4)) + layer = nn.Identity() + outputs = layer(inputs) + self.assertEqual(tuple(inputs.shape), tuple(outputs.shape)) + def test_linear(self): inputs = mx.zeros((10, 4)) layer = nn.Linear(input_dims=4, output_dims=8) outputs = layer(inputs) self.assertEqual(tuple(outputs.shape), (10, 8)) + def test_bilinear(self): + inputs1 = mx.zeros((10, 2)) + inputs2 = mx.zeros((10, 4)) + layer = nn.Bilinear(input1_dims=2, input2_dims=4, output_dims=6) + outputs = layer(inputs1, inputs2) + self.assertEqual(tuple(outputs.shape), (10, 6)) + def test_cross_entropy(self): logits = mx.array([[0.0, -float("inf")], [-float("inf"), 0.0]]) targets = mx.array([0, 1])