Feature expand nn linear (#315)

* Added Identity and Bilinear layers
Added a reset_parameters option
Added normal init for bias

* pre-commit run

* add type hints for parameters and the return type
change Bilinear math to x_1 and x_2
change __call__ arguments to x and y instead of input and output
add an explanation of the initialization

* Remove unnecessary reshape

* Added 'i' to bilinear formula

* Changed bilinear computation to two matrix multiplications

* Avoided saving intermediate results; kept y in Bilinear for clarity (it can be replaced with x1)

* Changed math formula in Linear
Added more explanation to math formulas
Changed x1, x2 reshape to support all input sizes
Asaf Zorea authored 2024-01-02 16:08:53 +02:00, committed by GitHub
parent 44c1ce5e6a
commit 295ce9db09
3 changed files with 103 additions and 9 deletions

File: python/mlx/nn/layers/__init__.py

@@ -45,7 +45,7 @@ from mlx.nn.layers.containers import Sequential
from mlx.nn.layers.convolution import Conv1d, Conv2d
from mlx.nn.layers.dropout import Dropout, Dropout2d, Dropout3d
from mlx.nn.layers.embedding import Embedding
-from mlx.nn.layers.linear import Linear
+from mlx.nn.layers.linear import Bilinear, Identity, Linear
from mlx.nn.layers.normalization import BatchNorm, GroupNorm, LayerNorm, RMSNorm
from mlx.nn.layers.positional_encoding import ALiBi, RoPE, SinusoidalPositionalEncoding
from mlx.nn.layers.quantized import QuantizedLinear

File: python/mlx/nn/layers/linear.py

@@ -1,11 +1,27 @@
# Copyright © 2023 Apple Inc.

import math
from typing import Any

import mlx.core as mx
from mlx.nn.layers.base import Module


class Identity(Module):
    r"""A placeholder identity operator that is argument-insensitive.

    Args:
        args: any argument (unused)
        kwargs: any keyword argument (unused)
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__()

    def __call__(self, x: mx.array) -> mx.array:
        return x
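
As a usage sketch (not part of the diff): Identity ignores its constructor arguments and returns its input unchanged, which makes it a convenient drop-in placeholder, e.g. when ablating a layer in an existing model. The constructor arguments below are arbitrary placeholders, and the check assumes mx.array_equal is available in the installed mlx version.

import mlx.core as mx
import mlx.nn as nn

# Identity accepts (and ignores) any arguments, so it can replace an
# arbitrary layer without changing the surrounding code.
placeholder = nn.Identity(4, nonsense_kwarg="ignored")  # hypothetical arguments
x = mx.random.uniform(shape=(10, 4))
print(mx.array_equal(placeholder(x), x))  # True: the output is the input, untouched
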
class Linear(Module):
    r"""Applies an affine transformation to the input.

@@ -13,33 +29,98 @@ class Linear(Module):
    .. math::

-        y = W^\top x + b
+        y = x W^\top + b

-    where :math:`W` has shape ``[output_dims, input_dims]``.
+    where :math:`W` has shape ``[output_dims, input_dims]`` and :math:`b` has
+    shape ``[output_dims]``.
+
+    The values are initialized from the uniform distribution
+    :math:`\mathcal{U}(-k, k)`, where :math:`k = \frac{1}{\sqrt{D_i}}` and
+    :math:`D_i` is equal to ``input_dims``.

    Args:
        input_dims (int): The dimensionality of the input features
        output_dims (int): The dimensionality of the output features
        bias (bool, optional): If set to ``False`` then the layer will
-            not use a bias. Default ``True``.
+            not use a bias. Default is ``True``.
    """
-    def __init__(self, input_dims: int, output_dims: int, bias: bool = True):
+    def __init__(self, input_dims: int, output_dims: int, bias: bool = True) -> None:
        super().__init__()
-        scale = math.sqrt(1 / input_dims)
+        scale = math.sqrt(1.0 / input_dims)
        self.weight = mx.random.uniform(
            low=-scale,
            high=scale,
            shape=(output_dims, input_dims),
        )
        if bias:
-            self.bias = mx.zeros((output_dims,))
+            self.bias = mx.random.uniform(
+                low=-scale,
+                high=scale,
+                shape=(output_dims,),
+            )

-    def _extra_repr(self):
+    def _extra_repr(self) -> str:
        return f"input_dims={self.weight.shape[1]}, output_dims={self.weight.shape[0]}, bias={'bias' in self}"

-    def __call__(self, x):
+    def __call__(self, x: mx.array) -> mx.array:
        x = x @ self.weight.T
        if "bias" in self:
            x = x + self.bias
        return x
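
As a quick illustration (not from the patch): with this change the bias is no longer zero-initialized; both weight and bias are drawn from :math:`\mathcal{U}(-k, k)` with :math:`k = 1/\sqrt{D_i}`, matching the docstring. The sketch below assumes the usual mlx array ``.min()``/``.max()``/``.item()`` reductions.

import math
import mlx.core as mx
import mlx.nn as nn

layer = nn.Linear(input_dims=4, output_dims=8)
k = math.sqrt(1.0 / 4)  # initialization bound from the docstring
# Both parameters should lie inside [-k, k].
print(layer.weight.min().item() >= -k, layer.weight.max().item() <= k)
print(layer.bias.min().item() >= -k, layer.bias.max().item() <= k)
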
class Bilinear(Module):
    r"""Applies a bilinear transformation to the inputs.

    Concretely:

    .. math::

        y_i = x_1^\top W_i x_2 + b_i

    where :math:`W` has shape ``[output_dims, input1_dims, input2_dims]``,
    :math:`b` has shape ``[output_dims]``, and :math:`i` indexes the output
    dimension.

    The values are initialized from the uniform distribution
    :math:`\mathcal{U}(-k, k)`, where :math:`k = \frac{1}{\sqrt{D_1}}` and
    :math:`D_1` is ``input1_dims``.

    Args:
        input1_dims (int): The dimensionality of the input1 features
        input2_dims (int): The dimensionality of the input2 features
        output_dims (int): The dimensionality of the output features
        bias (bool, optional): If set to ``False`` then the layer will
            not use a bias. Default is ``True``.
    """

    def __init__(
        self, input1_dims: int, input2_dims: int, output_dims: int, bias: bool = True
    ) -> None:
        super().__init__()
        scale = math.sqrt(1.0 / input1_dims)
        self.weight = mx.random.uniform(
            low=-scale,
            high=scale,
            shape=(1, output_dims, input1_dims, input2_dims),
        )
        if bias:
            self.bias = mx.random.uniform(
                low=-scale,
                high=scale,
                shape=(output_dims,),
            )

    def _extra_repr(self) -> str:
        return (
            f"input1_dims={self.weight.shape[2]}, input2_dims={self.weight.shape[3]}, "
            f"output_dims={self.weight.shape[1]}, bias={'bias' in self}"
        )

    def __call__(self, x1: mx.array, x2: mx.array) -> mx.array:
        # Expand dims so the two batched matmuls broadcast against the
        # (1, output_dims, input1_dims, input2_dims) weight.
        x1 = mx.expand_dims(x1, axis=(-2, -3))
        x2 = mx.expand_dims(x2, axis=-2)
        y = mx.squeeze(x1 @ self.weight, -2).swapaxes(-1, -2)
        y = mx.squeeze(x2 @ y, -2)
        if "bias" in self:
            y = y + self.bias
        return y
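
To make the two-matrix-multiplication trick concrete, here is a minimal sketch (not part of the diff) that checks one output entry against the definition :math:`y_i = x_1^\top W_i x_2 + b_i`. It assumes only the operations already used in the patch (mx.expand_dims, mx.squeeze, array indexing, ``.item()``); the leading singleton axis of the stored weight exists purely so the batched matmuls broadcast.

import mlx.core as mx
import mlx.nn as nn

layer = nn.Bilinear(input1_dims=2, input2_dims=4, output_dims=6)
x1 = mx.random.uniform(shape=(10, 2))
x2 = mx.random.uniform(shape=(10, 4))
y = layer(x1, x2)  # shape (10, 6)

# Reference for one sample n and output channel i: y[n, i] = x1[n]^T W_i x2[n] + b_i
n, i = 0, 3
W_i = layer.weight[0, i]                       # (input1_dims, input2_dims)
a = mx.expand_dims(x1[n], 0)                   # (1, input1_dims)
b = mx.expand_dims(x2[n], 1)                   # (input2_dims, 1)
ref = mx.squeeze(a @ W_i @ b) + layer.bias[i]  # scalar
print(y[n, i].item(), ref.item())              # the two values should match
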

File: python/tests/test_nn.py

@@ -12,12 +12,25 @@ from mlx.utils import tree_flatten, tree_map, tree_unflatten
class TestNN(mlx_tests.MLXTestCase):
    def test_identity(self):
        inputs = mx.zeros((10, 4))
        layer = nn.Identity()
        outputs = layer(inputs)
        self.assertEqual(tuple(inputs.shape), tuple(outputs.shape))

    def test_linear(self):
        inputs = mx.zeros((10, 4))
        layer = nn.Linear(input_dims=4, output_dims=8)
        outputs = layer(inputs)
        self.assertEqual(tuple(outputs.shape), (10, 8))

    def test_bilinear(self):
        inputs1 = mx.zeros((10, 2))
        inputs2 = mx.zeros((10, 4))
        layer = nn.Bilinear(input1_dims=2, input2_dims=4, output_dims=6)
        outputs = layer(inputs1, inputs2)
        self.assertEqual(tuple(outputs.shape), (10, 6))

    def test_cross_entropy(self):
        logits = mx.array([[0.0, -float("inf")], [-float("inf"), 0.0]])
        targets = mx.array([0, 1])