Feature expand nn linear (#315)
* Added Identity and Bilinear layers; added a reset_parameters option; added normal init for bias
* pre-commit run
* Added type hints for parameters and the return type; changed the Bilinear math to x_1 and x_2; changed the __call__ arguments to x and y instead of input and output; added an explanation of the initialization
* Removed an unnecessary reshape
* Added 'i' to the bilinear formula
* Changed the bilinear computation to two matrix multiplications
* Avoided saving intermediate results; kept y in Bilinear for clarity (it could be replaced with x1)
* Changed the math formula in Linear; added more explanation to the math formulas; changed the x1, x2 reshape to support all input sizes
This commit is contained in:
parent
44c1ce5e6a
commit
295ce9db09
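
In short, the commit adds nn.Identity and nn.Bilinear alongside nn.Linear and re-exports them from mlx.nn.layers. A minimal usage sketch (shapes mirror the tests further down; this snippet is illustrative, not part of the commit):

    import mlx.core as mx
    import mlx.nn as nn

    # Identity returns its input unchanged.
    identity = nn.Identity()
    x = mx.zeros((10, 4))
    assert tuple(identity(x).shape) == (10, 4)

    # Linear computes y = x @ W.T + b with W of shape (output_dims, input_dims).
    linear = nn.Linear(input_dims=4, output_dims=8)
    assert tuple(linear(x).shape) == (10, 8)

    # Bilinear combines two inputs: y_i = x1 @ W_i @ x2 + b_i.
    bilinear = nn.Bilinear(input1_dims=2, input2_dims=4, output_dims=6)
    x1, x2 = mx.zeros((10, 2)), mx.zeros((10, 4))
    assert tuple(bilinear(x1, x2).shape) == (10, 6)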
python/mlx/nn/layers/__init__.py

@@ -45,7 +45,7 @@ from mlx.nn.layers.containers import Sequential
 from mlx.nn.layers.convolution import Conv1d, Conv2d
 from mlx.nn.layers.dropout import Dropout, Dropout2d, Dropout3d
 from mlx.nn.layers.embedding import Embedding
-from mlx.nn.layers.linear import Linear
+from mlx.nn.layers.linear import Bilinear, Identity, Linear
 from mlx.nn.layers.normalization import BatchNorm, GroupNorm, LayerNorm, RMSNorm
 from mlx.nn.layers.positional_encoding import ALiBi, RoPE, SinusoidalPositionalEncoding
 from mlx.nn.layers.quantized import QuantizedLinear
python/mlx/nn/layers/linear.py

@@ -1,11 +1,27 @@
 # Copyright © 2023 Apple Inc.
 
 import math
+from typing import Any
 
 import mlx.core as mx
 from mlx.nn.layers.base import Module
 
 
+class Identity(Module):
+    r"""A placeholder identity operator that is argument-insensitive.
+
+    Args:
+        args: any argument (unused)
+        kwargs: any keyword argument (unused)
+    """
+
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        super().__init__()
+
+    def __call__(self, x: mx.array) -> mx.array:
+        return x
+
+
 class Linear(Module):
     r"""Applies an affine transformation to the input.
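Identity is mainly useful as a drop-in placeholder that lets a sub-module be disabled without changing the surrounding code. A hypothetical example (the names here are invented for illustration):

    import mlx.nn as nn

    use_projection = False
    # Swap a projection for a no-op without touching the call site.
    proj = nn.Linear(input_dims=32, output_dims=32) if use_projection else nn.Identity()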
@@ -13,33 +29,98 @@ class Linear(Module):
 
     .. math::
 
-        y = W^\top x + b
+        y = x W^\top + b
 
-    where :math:`W` has shape ``[output_dims, input_dims]``.
+    where :math:`W` has shape ``[output_dims, input_dims]`` and :math:`b` has shape ``[output_dims]``.
+
+    The values are initialized from the uniform distribution :math:`\mathcal{U}(-{k}, {k})`,
+    where :math:`k = \frac{1}{\sqrt{D_i}}` and :math:`D_i` is equal to ``input_dims``.
 
     Args:
         input_dims (int): The dimensionality of the input features
         output_dims (int): The dimensionality of the output features
         bias (bool, optional): If set to ``False`` then the layer will
-            not use a bias. Default ``True``.
+            not use a bias. Default is ``True``.
     """
 
-    def __init__(self, input_dims: int, output_dims: int, bias: bool = True):
+    def __init__(self, input_dims: int, output_dims: int, bias: bool = True) -> None:
         super().__init__()
-        scale = math.sqrt(1 / input_dims)
+        scale = math.sqrt(1.0 / input_dims)
         self.weight = mx.random.uniform(
             low=-scale,
             high=scale,
             shape=(output_dims, input_dims),
         )
         if bias:
-            self.bias = mx.zeros((output_dims,))
+            self.bias = mx.random.uniform(
+                low=-scale,
+                high=scale,
+                shape=(output_dims,),
+            )
 
-    def _extra_repr(self):
+    def _extra_repr(self) -> str:
         return f"input_dims={self.weight.shape[1]}, output_dims={self.weight.shape[0]}, bias={'bias' in self}"
 
-    def __call__(self, x):
+    def __call__(self, x: mx.array) -> mx.array:
         x = x @ self.weight.T
         if "bias" in self:
             x = x + self.bias
         return x
+
+
+class Bilinear(Module):
+    r"""Applies a bilinear transformation to the inputs.
+
+    Concretely:
+
+    .. math::
+
+        y_i = x_1^\top W_i x_2 + b_i
+
+    where :math:`W` has shape ``[output_dims, input1_dims, input2_dims]``,
+    :math:`b` has shape ``[output_dims]``, and :math:`i` indexes the output
+    dimension.
+
+    The values are initialized from the uniform distribution :math:`\mathcal{U}(-{k}, {k})`,
+    where :math:`k = \frac{1}{\sqrt{D_1}}` and :math:`D_1` is ``input1_dims``.
+
+    Args:
+        input1_dims (int): The dimensionality of the input1 features
+        input2_dims (int): The dimensionality of the input2 features
+        output_dims (int): The dimensionality of the output features
+        bias (bool, optional): If set to ``False`` then the layer will
+            not use a bias. Default is ``True``.
+    """
+
+    def __init__(
+        self, input1_dims: int, input2_dims: int, output_dims: int, bias: bool = True
+    ) -> None:
+        super().__init__()
+        scale = math.sqrt(1.0 / input1_dims)
+        self.weight = mx.random.uniform(
+            low=-scale,
+            high=scale,
+            shape=(1, output_dims, input1_dims, input2_dims),
+        )
+        if bias:
+            self.bias = mx.random.uniform(
+                low=-scale,
+                high=scale,
+                shape=(output_dims,),
+            )
+
+    def _extra_repr(self) -> str:
+        return (
+            f"input1_dims={self.weight.shape[2]}, input2_dims={self.weight.shape[3]}, "
+            f"output_dims={self.weight.shape[1]}, bias={'bias' in self}"
+        )
+
+    def __call__(self, x1: mx.array, x2: mx.array) -> mx.array:
+        # Broadcast x1 to (..., 1, 1, input1_dims) so it multiplies the
+        # (1, output_dims, input1_dims, input2_dims) weight per output dim.
+        x1 = mx.expand_dims(x1, axis=(-2, -3))
+        x2 = mx.expand_dims(x2, axis=(-2))
+        # First matmul contracts input1_dims, giving (..., output_dims, input2_dims);
+        # swapaxes readies it for the second contraction.
+        y = mx.squeeze(x1 @ self.weight, -2).swapaxes(-1, -2)
+        # Second matmul contracts input2_dims, giving (..., output_dims).
+        y = mx.squeeze(x2 @ y, -2)
+        if "bias" in self:
+            y = y + self.bias
+        return y
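The __call__ above realizes y_i = x_1^T W_i x_2 + b_i as two batched matrix multiplications over a weight stored with a leading broadcast axis. A quick equivalence check against a naive per-output-dimension loop (again illustrative, not part of the commit):

    import mlx.core as mx
    import mlx.nn as nn

    B, D1, D2, O = 3, 5, 7, 4
    x1 = mx.random.normal((B, D1))
    x2 = mx.random.normal((B, D2))
    layer = nn.Bilinear(input1_dims=D1, input2_dims=D2, output_dims=O)

    # Reference: y[n, i] = x1[n] @ W[i] @ x2[n] + b[i], one output dim at a time.
    W = layer.weight.reshape(O, D1, D2)  # drop the leading broadcast axis
    cols = [(x1 @ W[i] * x2).sum(axis=-1) + layer.bias[i] for i in range(O)]
    ref = mx.stack(cols, axis=-1)
    assert mx.allclose(layer(x1, x2), ref).item()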
python/tests/test_nn.py

@@ -12,12 +12,25 @@ from mlx.utils import tree_flatten, tree_map, tree_unflatten
 
 
 class TestNN(mlx_tests.MLXTestCase):
+    def test_identity(self):
+        inputs = mx.zeros((10, 4))
+        layer = nn.Identity()
+        outputs = layer(inputs)
+        self.assertEqual(tuple(inputs.shape), tuple(outputs.shape))
+
     def test_linear(self):
         inputs = mx.zeros((10, 4))
         layer = nn.Linear(input_dims=4, output_dims=8)
         outputs = layer(inputs)
         self.assertEqual(tuple(outputs.shape), (10, 8))
 
+    def test_bilinear(self):
+        inputs1 = mx.zeros((10, 2))
+        inputs2 = mx.zeros((10, 4))
+        layer = nn.Bilinear(input1_dims=2, input2_dims=4, output_dims=6)
+        outputs = layer(inputs1, inputs2)
+        self.assertEqual(tuple(outputs.shape), (10, 6))
+
     def test_cross_entropy(self):
         logits = mx.array([[0.0, -float("inf")], [-float("inf"), 0.0]])
         targets = mx.array([0, 1])
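
One consequence of the initialization change above: the bias is now drawn from the same uniform distribution as the weight, so both parameters of a Linear layer should fall within [-k, k] for k = 1/sqrt(input_dims). A sanity-check sketch, assuming nothing beyond the code in this commit:

    import math

    import mlx.core as mx
    import mlx.nn as nn

    layer = nn.Linear(input_dims=64, output_dims=32)
    k = 1.0 / math.sqrt(64)

    # Both weight and bias are drawn from U(-k, k), k = 1 / sqrt(input_dims).
    assert mx.all(mx.abs(layer.weight) <= k).item()
    assert mx.all(mx.abs(layer.bias) <= k).item()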