mirror of https://github.com/ml-explore/mlx.git
synced 2025-08-22 04:56:41 +08:00
Added Identity and Bilinear layers
Added a reset_parameters option; added uniform init for the bias
parent 040c3bafab
commit 0abc7f0d10
@@ -37,7 +37,7 @@ from mlx.nn.layers.containers import Sequential
 from mlx.nn.layers.convolution import Conv1d, Conv2d
 from mlx.nn.layers.dropout import Dropout, Dropout2d
 from mlx.nn.layers.embedding import Embedding
-from mlx.nn.layers.linear import Linear
+from mlx.nn.layers.linear import Identity, Linear, Bilinear
 from mlx.nn.layers.normalization import BatchNorm, GroupNorm, LayerNorm, RMSNorm
 from mlx.nn.layers.positional_encoding import ALiBi, RoPE, SinusoidalPositionalEncoding
 from mlx.nn.layers.quantized import QuantizedLinear
@@ -6,6 +6,21 @@ import mlx.core as mx
 from mlx.nn.layers.base import Module
 
 
+class Identity(Module):
+    r"""A placeholder identity operator that is argument-insensitive.
+
+    Args:
+        args: any argument (unused)
+        kwargs: any keyword argument (unused)
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__()
+
+    def __call__(self, input):
+        return input
+
+
 class Linear(Module):
     r"""Applies an affine transformation to the input.
 
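For orientation, a minimal usage sketch of the new Identity layer (a hedged example, not part of the commit; it assumes mlx.nn re-exports Sequential, Linear, and Identity as in the import hunk above):

    import mlx.core as mx
    import mlx.nn as nn

    # Identity is a no-op stand-in: useful when ablating a layer or making
    # a block optional without restructuring the surrounding model.
    model = nn.Sequential(
        nn.Linear(input_dims=4, output_dims=8),
        nn.Identity(),  # could later be swapped for e.g. nn.Dropout(0.1)
        nn.Linear(input_dims=8, output_dims=2),
    )

    x = mx.random.normal(shape=(10, 4))
    print(model(x).shape)  # (10, 2)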
@@ -26,20 +41,88 @@ class Linear(Module):
 
     def __init__(self, input_dims: int, output_dims: int, bias: bool = True):
         super().__init__()
-        scale = math.sqrt(1 / input_dims)
-        self.weight = mx.random.uniform(
-            low=-scale,
-            high=scale,
-            shape=(output_dims, input_dims),
-        )
+        self.input_dims = input_dims
+        self.output_dims = output_dims
+        self.weight = mx.zeros((output_dims, input_dims))
         if bias:
             self.bias = mx.zeros((output_dims,))
 
-    def _extra_repr(self):
-        return f"input_dims={self.weight.shape[1]}, output_dims={self.weight.shape[0]}, bias={'bias' in self}"
+        self.reset_parameters()
 
-    def __call__(self, x):
-        x = x @ self.weight.T
+    def reset_parameters(self):
+        scale = math.sqrt(1. / self.input_dims)
+        self.weight = mx.random.uniform(
+            low=-scale,
+            high=scale,
+            shape=(self.output_dims, self.input_dims),
+        )
         if "bias" in self:
-            x = x + self.bias
-        return x
+            self.bias = mx.random.uniform(
+                low=-scale,
+                high=scale,
+                shape=(self.output_dims,),
+            )
+
+    def _extra_repr(self):
+        return f"input_dims={self.input_dims}, output_dims={self.output_dims}, bias={'bias' in self}"
+
+    def __call__(self, input):
+        output = input @ self.weight.T
+        if "bias" in self:
+            output = output + self.bias
+        return output
+
+
+class Bilinear(Module):
+    r"""Applies a bilinear transformation to the input.
+
+    Concretely:
+
+    .. math::
+
+        y = input1^\top W input2 + b
+
+    where :math:`W` has shape ``[output_dims, input1_dims, input2_dims]``.
+
+    Args:
+        input1_dims (int): The dimensionality of the input1 features
+        input2_dims (int): The dimensionality of the input2 features
+        output_dims (int): The dimensionality of the output features
+        bias (bool, optional): If set to ``False`` then the layer will
+            not use a bias. Default ``True``.
+    """
+
+    def __init__(self, input1_dims: int, input2_dims: int, output_dims: int, bias: bool = True):
+        super().__init__()
+        self.input1_dims = input1_dims
+        self.input2_dims = input2_dims
+        self.output_dims = output_dims
+        self.weight = mx.zeros((output_dims, input1_dims, input2_dims))
+        if bias:
+            self.bias = mx.zeros((output_dims,))
+
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        scale = math.sqrt(1. / self.input1_dims)
+        self.weight = mx.random.uniform(
+            low=-scale,
+            high=scale,
+            shape=(self.output_dims, self.input1_dims, self.input2_dims),
+        )
+        if "bias" in self:
+            self.bias = mx.random.uniform(
+                low=-scale,
+                high=scale,
+                shape=(self.output_dims,),
+            )
+
+    def _extra_repr(self):
+        return (f"input1_dims={self.input1_dims}, input2_dims={self.input2_dims}, output_dims={self.output_dims}, "
+                f"bias={'bias' in self}")
+
+    def __call__(self, input1, input2):
+        output = (input1 @ self.weight * input2.reshape(1, *input2.shape)).sum(-1).T
+        if "bias" in self:
+            output = output + self.bias
+        return output
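A note on the Bilinear forward pass: for input1 of shape (N, input1_dims) and input2 of shape (N, input2_dims), input1 @ self.weight broadcasts the matmul over the output axis to give (output_dims, N, input2_dims); multiplying by input2 and summing the last axis yields (output_dims, N), and the final .T transposes to (N, output_dims). So output[n, k] = input1[n].T @ W[k] @ input2[n] + b[k], matching the docstring. A hedged usage sketch (not part of the commit) of the new layer and the reset_parameters method that both Linear and Bilinear now expose:

    import mlx.core as mx
    import mlx.nn as nn

    x1 = mx.random.normal(shape=(10, 2))
    x2 = mx.random.normal(shape=(10, 4))

    layer = nn.Bilinear(input1_dims=2, input2_dims=4, output_dims=6)
    y = layer(x1, x2)
    print(y.shape)  # (10, 6)

    # reset_parameters redraws the weight (and bias, if present) from
    # U(-1/sqrt(input1_dims), 1/sqrt(input1_dims)), re-randomizing in place.
    layer.reset_parameters()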
@@ -12,12 +12,25 @@ from mlx.utils import tree_flatten, tree_map, tree_unflatten
 
 
 class TestNN(mlx_tests.MLXTestCase):
+    def test_identity(self):
+        inputs = mx.zeros((10, 4))
+        layer = nn.Identity()
+        outputs = layer(inputs)
+        self.assertEqual(tuple(inputs.shape), tuple(outputs.shape))
+
     def test_linear(self):
         inputs = mx.zeros((10, 4))
         layer = nn.Linear(input_dims=4, output_dims=8)
         outputs = layer(inputs)
         self.assertEqual(tuple(outputs.shape), (10, 8))
 
+    def test_bilinear(self):
+        inputs1 = mx.zeros((10, 2))
+        inputs2 = mx.zeros((10, 4))
+        layer = nn.Bilinear(input1_dims=2, input2_dims=4, output_dims=6)
+        outputs = layer(inputs1, inputs2)
+        self.assertEqual(tuple(outputs.shape), (10, 6))
+
     def test_cross_entropy(self):
         logits = mx.array([[0.0, -float("inf")], [-float("inf"), 0.0]])
         targets = mx.array([0, 1])
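One additional check worth considering (a hedged sketch, not part of this commit; it assumes mx.stack and mx.allclose from mlx.core): verify the broadcasting implementation of Bilinear against a per-output-feature reference.

    def test_bilinear_matches_reference(self):
        x1 = mx.random.normal(shape=(10, 2))
        x2 = mx.random.normal(shape=(10, 4))
        layer = nn.Bilinear(input1_dims=2, input2_dims=4, output_dims=6)
        outputs = layer(x1, x2)

        # Reference: y[n, k] = x1[n] @ W[k] @ x2[n] + b[k], one output
        # feature at a time.
        ref = mx.stack(
            [(x1 @ layer.weight[k] * x2).sum(-1) for k in range(6)],
            axis=-1,
        )
        ref = ref + layer.bias
        self.assertTrue(mx.allclose(outputs, ref))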