Added Identity and Bilinear layers

Added a reset_parameters method
Added uniform init for bias
zorea 2023-12-29 17:15:22 +02:00
parent 040c3bafab
commit 0abc7f0d10
3 changed files with 109 additions and 13 deletions

View File

@@ -37,7 +37,7 @@ from mlx.nn.layers.containers import Sequential
 from mlx.nn.layers.convolution import Conv1d, Conv2d
 from mlx.nn.layers.dropout import Dropout, Dropout2d
 from mlx.nn.layers.embedding import Embedding
-from mlx.nn.layers.linear import Linear
+from mlx.nn.layers.linear import Identity, Linear, Bilinear
 from mlx.nn.layers.normalization import BatchNorm, GroupNorm, LayerNorm, RMSNorm
 from mlx.nn.layers.positional_encoding import ALiBi, RoPE, SinusoidalPositionalEncoding
 from mlx.nn.layers.quantized import QuantizedLinear
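
With the changed import line, the new layers are re-exported alongside Linear and can be reached from the top-level nn namespace, which is how the new tests below refer to them. A minimal sketch (not part of the commit; it assumes the usual mlx.nn alias used in the tests):

    import mlx.nn as nn

    # Both classes resolve through the re-export added above.
    identity = nn.Identity()
    bilinear = nn.Bilinear(input1_dims=2, input2_dims=4, output_dims=6)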

View File

@@ -6,6 +6,21 @@ import mlx.core as mx
 from mlx.nn.layers.base import Module
 
 
+class Identity(Module):
+    r"""A placeholder identity operator that is argument-insensitive.
+
+    Args:
+        args: any argument (unused)
+        kwargs: any keyword argument (unused)
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__()
+
+    def __call__(self, input):
+        return input
+
+
 class Linear(Module):
     r"""Applies an affine transformation to the input.
 
@@ -26,20 +41,88 @@ class Linear(Module):
     def __init__(self, input_dims: int, output_dims: int, bias: bool = True):
         super().__init__()
-        scale = math.sqrt(1 / input_dims)
-        self.weight = mx.random.uniform(
-            low=-scale,
-            high=scale,
-            shape=(output_dims, input_dims),
-        )
+        self.input_dims = input_dims
+        self.output_dims = output_dims
+        self.weight = mx.zeros((output_dims, input_dims))
         if bias:
             self.bias = mx.zeros((output_dims,))
+        self.reset_parameters()
 
-    def _extra_repr(self):
-        return f"input_dims={self.weight.shape[1]}, output_dims={self.weight.shape[0]}, bias={'bias' in self}"
+    def reset_parameters(self):
+        scale = math.sqrt(1.0 / self.input_dims)
+        self.weight = mx.random.uniform(
+            low=-scale,
+            high=scale,
+            shape=(self.output_dims, self.input_dims),
+        )
+        if "bias" in self:
+            self.bias = mx.random.uniform(
+                low=-scale,
+                high=scale,
+                shape=(self.output_dims,),
+            )
 
-    def __call__(self, x):
-        x = x @ self.weight.T
-        if "bias" in self:
-            x = x + self.bias
-        return x
+    def _extra_repr(self):
+        return f"input_dims={self.input_dims}, output_dims={self.output_dims}, bias={'bias' in self}"
+
+    def __call__(self, input):
+        output = input @ self.weight.T
+        if "bias" in self:
+            output = output + self.bias
+        return output
+
+
+class Bilinear(Module):
+    r"""Applies a bilinear transformation to the input.
+
+    Concretely:
+
+    .. math::
+
+        y = input1^\top W input2 + b
+
+    where :math:`W` has shape ``[output_dims, input1_dims, input2_dims]``.
+
+    Args:
+        input1_dims (int): The dimensionality of the input1 features
+        input2_dims (int): The dimensionality of the input2 features
+        output_dims (int): The dimensionality of the output features
+        bias (bool, optional): If set to ``False`` then the layer will
+            not use a bias. Default ``True``.
+    """
+
+    def __init__(self, input1_dims: int, input2_dims: int, output_dims: int, bias: bool = True):
+        super().__init__()
+        self.input1_dims = input1_dims
+        self.input2_dims = input2_dims
+        self.output_dims = output_dims
+        self.weight = mx.zeros((output_dims, input1_dims, input2_dims))
+        if bias:
+            self.bias = mx.zeros((output_dims,))
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        scale = math.sqrt(1.0 / self.input1_dims)
+        self.weight = mx.random.uniform(
+            low=-scale,
+            high=scale,
+            shape=(self.output_dims, self.input1_dims, self.input2_dims),
+        )
+        if "bias" in self:
+            self.bias = mx.random.uniform(
+                low=-scale,
+                high=scale,
+                shape=(self.output_dims,),
+            )
+
+    def _extra_repr(self):
+        return (f"input1_dims={self.input1_dims}, input2_dims={self.input2_dims}, output_dims={self.output_dims}, "
+                f"bias={'bias' in self}")
+
+    def __call__(self, input1, input2):
+        output = (input1 @ self.weight * input2.reshape(1, *input2.shape)).sum(-1).T
+        if "bias" in self:
+            output = output + self.bias
+        return output
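
Bilinear.__call__ evaluates y = input1^T W input2 + b with a single broadcasted matmul against the (output_dims, input1_dims, input2_dims) weight. A short sketch of how that contraction lines up with the per-output definition (not part of the commit; the shapes are illustrative and mx.random.uniform is used only because it already appears in the diff):

    import mlx.core as mx
    import mlx.nn as nn

    x1 = mx.random.uniform(shape=(10, 2))  # (N, input1_dims)
    x2 = mx.random.uniform(shape=(10, 4))  # (N, input2_dims)
    layer = nn.Bilinear(input1_dims=2, input2_dims=4, output_dims=6)

    y = layer(x1, x2)  # (N, output_dims) == (10, 6)

    # For one output unit k, the layer computes the quadratic form x1[n] @ W[k] @ x2[n] + b[k]:
    k = 0
    y_k = (x1 @ layer.weight[k] * x2).sum(-1) + layer.bias[k]  # shape (10,), equals y[:, k]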

View File

@@ -12,12 +12,25 @@ from mlx.utils import tree_flatten, tree_map, tree_unflatten
 
 class TestNN(mlx_tests.MLXTestCase):
+    def test_identity(self):
+        inputs = mx.zeros((10, 4))
+        layer = nn.Identity()
+        outputs = layer(inputs)
+        self.assertEqual(tuple(inputs.shape), tuple(outputs.shape))
+
     def test_linear(self):
         inputs = mx.zeros((10, 4))
         layer = nn.Linear(input_dims=4, output_dims=8)
         outputs = layer(inputs)
         self.assertEqual(tuple(outputs.shape), (10, 8))
 
+    def test_bilinear(self):
+        inputs1 = mx.zeros((10, 2))
+        inputs2 = mx.zeros((10, 4))
+        layer = nn.Bilinear(input1_dims=2, input2_dims=4, output_dims=6)
+        outputs = layer(inputs1, inputs2)
+        self.assertEqual(tuple(outputs.shape), (10, 6))
+
     def test_cross_entropy(self):
         logits = mx.array([[0.0, -float("inf")], [-float("inf"), 0.0]])
         targets = mx.array([0, 1])
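
The new tests only check output shapes; reset_parameters is exercised indirectly through __init__. A minimal sketch of calling it directly to re-draw the parameters of an existing layer (not part of the commit; it relies only on the method added above):

    import mlx.nn as nn

    layer = nn.Linear(input_dims=4, output_dims=8)
    # Re-draws weight (and bias, if present) from U(-1/sqrt(input_dims), 1/sqrt(input_dims)).
    layer.reset_parameters()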