From 0abc7f0d1005caa4802eb2ab5e5b891e53c14010 Mon Sep 17 00:00:00 2001
From: zorea
Date: Fri, 29 Dec 2023 17:15:22 +0200
Subject: [PATCH] Added Identity and Bilinear layers

Added a reset_parameters method
Added uniform init for the bias
---
 python/mlx/nn/layers/__init__.py |   2 +-
 python/mlx/nn/layers/linear.py   | 107 +++++++++++++++++++++++++++----
 python/tests/test_nn.py          |  13 ++++
 3 files changed, 109 insertions(+), 13 deletions(-)

diff --git a/python/mlx/nn/layers/__init__.py b/python/mlx/nn/layers/__init__.py
index 4dbe96eb6..d001d98b6 100644
--- a/python/mlx/nn/layers/__init__.py
+++ b/python/mlx/nn/layers/__init__.py
@@ -37,7 +37,7 @@ from mlx.nn.layers.containers import Sequential
 from mlx.nn.layers.convolution import Conv1d, Conv2d
 from mlx.nn.layers.dropout import Dropout, Dropout2d
 from mlx.nn.layers.embedding import Embedding
-from mlx.nn.layers.linear import Linear
+from mlx.nn.layers.linear import Bilinear, Identity, Linear
 from mlx.nn.layers.normalization import BatchNorm, GroupNorm, LayerNorm, RMSNorm
 from mlx.nn.layers.positional_encoding import ALiBi, RoPE, SinusoidalPositionalEncoding
 from mlx.nn.layers.quantized import QuantizedLinear
diff --git a/python/mlx/nn/layers/linear.py b/python/mlx/nn/layers/linear.py
index 0c7a1b907..9877aa77c 100644
--- a/python/mlx/nn/layers/linear.py
+++ b/python/mlx/nn/layers/linear.py
@@ -6,6 +6,21 @@ import mlx.core as mx
 from mlx.nn.layers.base import Module
 
 
+class Identity(Module):
+    r"""A placeholder identity operator that is argument-insensitive.
+
+    Args:
+        args: any argument (unused)
+        kwargs: any keyword argument (unused)
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__()
+
+    def __call__(self, input):
+        return input
+
+
 class Linear(Module):
     r"""Applies an affine transformation to the input.
 
@@ -26,20 +41,88 @@ class Linear(Module):
 
     def __init__(self, input_dims: int, output_dims: int, bias: bool = True):
         super().__init__()
-        scale = math.sqrt(1 / input_dims)
-        self.weight = mx.random.uniform(
-            low=-scale,
-            high=scale,
-            shape=(output_dims, input_dims),
-        )
+        self.input_dims = input_dims
+        self.output_dims = output_dims
+        self.weight = mx.zeros((output_dims, input_dims))
         if bias:
             self.bias = mx.zeros((output_dims,))
 
-    def _extra_repr(self):
-        return f"input_dims={self.weight.shape[1]}, output_dims={self.weight.shape[0]}, bias={'bias' in self}"
+        self.reset_parameters()
 
-    def __call__(self, x):
-        x = x @ self.weight.T
+    def reset_parameters(self):
+        scale = math.sqrt(1.0 / self.input_dims)
+        self.weight = mx.random.uniform(
+            low=-scale,
+            high=scale,
+            shape=(self.output_dims, self.input_dims),
+        )
         if "bias" in self:
-            x = x + self.bias
-        return x
+            self.bias = mx.random.uniform(
+                low=-scale,
+                high=scale,
+                shape=(self.output_dims,),
+            )
+
+    def _extra_repr(self):
+        return f"input_dims={self.input_dims}, output_dims={self.output_dims}, bias={'bias' in self}"
+
+    def __call__(self, input):
+        output = input @ self.weight.T
+        if "bias" in self:
+            output = output + self.bias
+        return output
+
+
+class Bilinear(Module):
+    r"""Applies a bilinear transformation to the inputs.
+
+    Concretely:
+
+    .. math::
+
+        y = \text{input1}^\top W \, \text{input2} + b
+
+    where :math:`W` has shape ``[output_dims, input1_dims, input2_dims]``.
+
+    Args:
+        input1_dims (int): The dimensionality of the ``input1`` features
+        input2_dims (int): The dimensionality of the ``input2`` features
+        output_dims (int): The dimensionality of the output features
+        bias (bool, optional): If set to ``False`` then the layer will
+            not use a bias. Default ``True``.
+    """
+
+    def __init__(self, input1_dims: int, input2_dims: int, output_dims: int, bias: bool = True):
+        super().__init__()
+        self.input1_dims = input1_dims
+        self.input2_dims = input2_dims
+        self.output_dims = output_dims
+        self.weight = mx.zeros((output_dims, input1_dims, input2_dims))
+        if bias:
+            self.bias = mx.zeros((output_dims,))
+
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        scale = math.sqrt(1.0 / self.input1_dims)
+        self.weight = mx.random.uniform(
+            low=-scale,
+            high=scale,
+            shape=(self.output_dims, self.input1_dims, self.input2_dims),
+        )
+        if "bias" in self:
+            self.bias = mx.random.uniform(
+                low=-scale,
+                high=scale,
+                shape=(self.output_dims,),
+            )
+
+    def _extra_repr(self):
+        return (f"input1_dims={self.input1_dims}, input2_dims={self.input2_dims}, output_dims={self.output_dims}, "
+                f"bias={'bias' in self}")
+
+    def __call__(self, input1, input2):
+        output = (input1 @ self.weight * input2.reshape(1, *input2.shape)).sum(-1).T
+        if "bias" in self:
+            output = output + self.bias
+        return output
diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py
index 2cfac4475..fe19b83fb 100644
--- a/python/tests/test_nn.py
+++ b/python/tests/test_nn.py
@@ -12,12 +12,25 @@ from mlx.utils import tree_flatten, tree_map, tree_unflatten
 
 
 class TestNN(mlx_tests.MLXTestCase):
+    def test_identity(self):
+        inputs = mx.zeros((10, 4))
+        layer = nn.Identity()
+        outputs = layer(inputs)
+        self.assertEqual(tuple(inputs.shape), tuple(outputs.shape))
+
     def test_linear(self):
         inputs = mx.zeros((10, 4))
         layer = nn.Linear(input_dims=4, output_dims=8)
         outputs = layer(inputs)
         self.assertEqual(tuple(outputs.shape), (10, 8))
 
+    def test_bilinear(self):
+        inputs1 = mx.zeros((10, 2))
+        inputs2 = mx.zeros((10, 4))
+        layer = nn.Bilinear(input1_dims=2, input2_dims=4, output_dims=6)
+        outputs = layer(inputs1, inputs2)
+        self.assertEqual(tuple(outputs.shape), (10, 6))
+
     def test_cross_entropy(self):
         logits = mx.array([[0.0, -float("inf")], [-float("inf"), 0.0]])
         targets = mx.array([0, 1])
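
A minimal usage sketch of the new layers (not part of the diff; it assumes the patch above is applied). In Bilinear.__call__, weight[k] is the [input1_dims, input2_dims] matrix for output feature k, so each entry is y[n, k] = x1[n] @ weight[k] @ x2[n] + bias[k]; the last snippet cross-checks one entry against that explicit form. Variable names (x, x1, x2, n, k) are illustrative only.

    import mlx.core as mx
    import mlx.nn as nn

    # Identity passes its input through unchanged.
    identity = nn.Identity()
    x = mx.random.normal((10, 4))
    print(mx.array_equal(identity(x), x))  # expect True

    # Bilinear maps a pair of feature vectors to output_dims features.
    bilinear = nn.Bilinear(input1_dims=2, input2_dims=4, output_dims=6)
    x1 = mx.random.normal((10, 2))
    x2 = mx.random.normal((10, 4))
    y = bilinear(x1, x2)
    print(tuple(y.shape))  # (10, 6)

    # Cross-check one entry against the explicit per-output computation.
    n, k = 3, 0
    ref = (x1[n].reshape(1, 2) @ bilinear.weight[k] @ x2[n].reshape(4, 1)).item()
    ref = ref + bilinear.bias[k].item()
    print(abs(y[n, k].item() - ref) < 1e-5)  # expect True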