	Feature expand nn linear (#315)
* Added Identity and Bilinear layers; added a reset_parameters option; added normal init for the bias
* pre-commit run
* Added type hints for parameters and the return type; changed the Bilinear math to use x_1 and x_2; changed the __call__ arguments to x and y instead of input and output; added an explanation of the initialization
* Removed an unnecessary reshape
* Added 'i' to the bilinear formula
* Changed the bilinear computation to two matrix multiplications
* Avoided saving intermediate results; kept y in Bilinear for clarity (it can be replaced with x1)
* Changed the math formula in Linear; added more explanation to the math formulas; changed the x1, x2 reshape to support all input sizes
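For context, here is a minimal usage sketch of the layers this commit adds. It is illustrative only (not part of the change) and assumes mlx is installed with the layers exported as `mlx.nn.Identity` and `mlx.nn.Bilinear`, per the import change in the diff below.

```python
import mlx.core as mx
import mlx.nn as nn

# Identity returns its input unchanged; constructor arguments are ignored.
identity = nn.Identity()
x = mx.random.normal((10, 4))
assert identity(x).shape == x.shape

# Bilinear combines two inputs: y_i = x_1^T W_i x_2 + b_i for each output dimension i.
bilinear = nn.Bilinear(input1_dims=4, input2_dims=6, output_dims=8)
x1 = mx.random.normal((10, 4))
x2 = mx.random.normal((10, 6))
y = bilinear(x1, x2)
print(y.shape)  # (10, 8)
```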
		| @@ -45,7 +45,7 @@ from mlx.nn.layers.containers import Sequential | ||||
| from mlx.nn.layers.convolution import Conv1d, Conv2d | ||||
| from mlx.nn.layers.dropout import Dropout, Dropout2d, Dropout3d | ||||
| from mlx.nn.layers.embedding import Embedding | ||||
| from mlx.nn.layers.linear import Linear | ||||
| from mlx.nn.layers.linear import Bilinear, Identity, Linear | ||||
| from mlx.nn.layers.normalization import BatchNorm, GroupNorm, LayerNorm, RMSNorm | ||||
| from mlx.nn.layers.positional_encoding import ALiBi, RoPE, SinusoidalPositionalEncoding | ||||
| from mlx.nn.layers.quantized import QuantizedLinear | ||||
|   | ||||
| @@ -1,11 +1,27 @@ | ||||
| # Copyright © 2023 Apple Inc. | ||||
|  | ||||
| import math | ||||
| from typing import Any | ||||
|  | ||||
| import mlx.core as mx | ||||
| from mlx.nn.layers.base import Module | ||||
|  | ||||
|  | ||||
| class Identity(Module): | ||||
|     r"""A placeholder identity operator that is argument-insensitive. | ||||
|  | ||||
|     Args: | ||||
|         args: any argument (unused) | ||||
|         kwargs: any keyword argument (unused) | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, *args: Any, **kwargs: Any) -> None: | ||||
|         super().__init__() | ||||
|  | ||||
|     def __call__(self, x: mx.array) -> mx.array: | ||||
|         return x | ||||
|  | ||||
|  | ||||
| class Linear(Module): | ||||
|     r"""Applies an affine transformation to the input. | ||||
|  | ||||
| @@ -13,33 +29,98 @@ class Linear(Module): | ||||
|  | ||||
|     .. math:: | ||||
|  | ||||
|         y = W^\top x + b | ||||
|         y = x W^\top + b | ||||
|  | ||||
|     where :math:`W` has shape ``[output_dims, input_dims]``. | ||||
|     where: | ||||
|     :math:`W` has shape ``[output_dims, input_dims]`` and :math:`b` has shape ``[output_dims]``. | ||||
|  | ||||
|     The values are initialized from the uniform distribution :math:`\mathcal{U}(-{k}, {k})`, | ||||
|     where :math:`k = \frac{1}{\sqrt{D_i}}` and :math:`D_i` is equal to ``input_dims``. | ||||
|  | ||||
|     Args: | ||||
|         input_dims (int): The dimensionality of the input features | ||||
|         output_dims (int): The dimensionality of the output features | ||||
|         bias (bool, optional): If set to ``False`` then the layer will | ||||
|           not use a bias. Default ``True``. | ||||
|           not use a bias. Default is ``True``. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, input_dims: int, output_dims: int, bias: bool = True): | ||||
|     def __init__(self, input_dims: int, output_dims: int, bias: bool = True) -> None: | ||||
|         super().__init__() | ||||
|         scale = math.sqrt(1 / input_dims) | ||||
|         scale = math.sqrt(1.0 / input_dims) | ||||
|         self.weight = mx.random.uniform( | ||||
|             low=-scale, | ||||
|             high=scale, | ||||
|             shape=(output_dims, input_dims), | ||||
|         ) | ||||
|         if bias: | ||||
|             self.bias = mx.zeros((output_dims,)) | ||||
|             self.bias = mx.random.uniform( | ||||
|                 low=-scale, | ||||
|                 high=scale, | ||||
|                 shape=(output_dims,), | ||||
|             ) | ||||
|  | ||||
|     def _extra_repr(self): | ||||
|     def _extra_repr(self) -> str: | ||||
|         return f"input_dims={self.weight.shape[1]}, output_dims={self.weight.shape[0]}, bias={'bias' in self}" | ||||
|  | ||||
|     def __call__(self, x): | ||||
|     def __call__(self, x: mx.array) -> mx.array: | ||||
|         x = x @ self.weight.T | ||||
|         if "bias" in self: | ||||
|             x = x + self.bias | ||||
|         return x | ||||
|  | ||||
|  | ||||
| class Bilinear(Module): | ||||
|     r"""Applies a bilinear transformation to the inputs. | ||||
|  | ||||
|     Concretely: | ||||
|  | ||||
|     .. math:: | ||||
|  | ||||
|         y_i = x_1^\top W_i x_2 + b_i | ||||
|  | ||||
|     where: | ||||
|     :math:`W` has shape ``[output_dims, input1_dims, input2_dims]``, :math:`b` has shape ``[output_dims]``, | ||||
|     and :math:`i` indexes the output dimension. | ||||
|  | ||||
|     The values are initialized from the uniform distribution :math:`\mathcal{U}(-{k}, {k})`, | ||||
|     where :math:`k = \frac{1}{\sqrt{D_1}}` and :math:`D_1` is ``input1_dims``. | ||||
|  | ||||
|     Args: | ||||
|         input1_dims (int): The dimensionality of the input1 features | ||||
|         input2_dims (int): The dimensionality of the input2 features | ||||
|         output_dims (int): The dimensionality of the output features | ||||
|         bias (bool, optional): If set to ``False`` then the layer will | ||||
|           not use a bias. Default is ``True``. | ||||
|     """ | ||||
|  | ||||
|     def __init__( | ||||
|         self, input1_dims: int, input2_dims: int, output_dims: int, bias: bool = True | ||||
|     ) -> None: | ||||
|         super().__init__() | ||||
|         scale = math.sqrt(1.0 / input1_dims) | ||||
|         self.weight = mx.random.uniform( | ||||
|             low=-scale, | ||||
|             high=scale, | ||||
|             shape=(1, output_dims, input1_dims, input2_dims), | ||||
|         ) | ||||
|         if bias: | ||||
|             self.bias = mx.random.uniform( | ||||
|                 low=-scale, | ||||
|                 high=scale, | ||||
|                 shape=(output_dims,), | ||||
|             ) | ||||
|  | ||||
|     def _extra_repr(self) -> str: | ||||
|         return ( | ||||
|             f"input1_dims={self.weight.shape[2]}, input2_dims={self.weight.shape[3]}, " | ||||
|             f"output_dims={self.weight.shape[1]}, bias={'bias' in self}" | ||||
|         ) | ||||
|  | ||||
|     def __call__(self, x1: mx.array, x2: mx.array) -> mx.array: | ||||
|         x1 = mx.expand_dims(x1, axis=(-2, -3)) | ||||
|         x2 = mx.expand_dims(x2, axis=(-2)) | ||||
|         y = mx.squeeze(x1 @ self.weight, -2).swapaxes(-1, -2) | ||||
|         y = mx.squeeze(x2 @ y, -2) | ||||
|         if "bias" in self: | ||||
|             y = y + self.bias | ||||
|         return y | ||||
|   | ||||
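To make the two-matrix-multiplication formulation of `Bilinear` concrete, the following sketch (illustrative, not part of the commit) compares the layer's output against the explicit per-sample formula y_i = x_1^T W_i x_2 + b_i from the docstring:

```python
import mlx.core as mx
import mlx.nn as nn

layer = nn.Bilinear(input1_dims=3, input2_dims=5, output_dims=2)
x1 = mx.random.normal((4, 3))
x2 = mx.random.normal((4, 5))
y = layer(x1, x2)  # the two-matmul path from the diff above

# Recompute y_i = x_1^T W_i x_2 + b_i explicitly, one sample and output at a time.
W = layer.weight[0]  # weight is stored as (1, output_dims, input1_dims, input2_dims)
b = layer.bias
for n in range(x1.shape[0]):
    for i in range(W.shape[0]):
        ref = x1[n] @ W[i] @ x2[n] + b[i]
        assert mx.allclose(y[n, i], ref).item()
```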
| @@ -12,12 +12,25 @@ from mlx.utils import tree_flatten, tree_map, tree_unflatten | ||||
|  | ||||
|  | ||||
| class TestNN(mlx_tests.MLXTestCase): | ||||
|     def test_identity(self): | ||||
|         inputs = mx.zeros((10, 4)) | ||||
|         layer = nn.Identity() | ||||
|         outputs = layer(inputs) | ||||
|         self.assertEqual(tuple(inputs.shape), tuple(outputs.shape)) | ||||
|  | ||||
|     def test_linear(self): | ||||
|         inputs = mx.zeros((10, 4)) | ||||
|         layer = nn.Linear(input_dims=4, output_dims=8) | ||||
|         outputs = layer(inputs) | ||||
|         self.assertEqual(tuple(outputs.shape), (10, 8)) | ||||
|  | ||||
|     def test_bilinear(self): | ||||
|         inputs1 = mx.zeros((10, 2)) | ||||
|         inputs2 = mx.zeros((10, 4)) | ||||
|         layer = nn.Bilinear(input1_dims=2, input2_dims=4, output_dims=6) | ||||
|         outputs = layer(inputs1, inputs2) | ||||
|         self.assertEqual(tuple(outputs.shape), (10, 6)) | ||||
|  | ||||
|     def test_cross_entropy(self): | ||||
|         logits = mx.array([[0.0, -float("inf")], [-float("inf"), 0.0]]) | ||||
|         targets = mx.array([0, 1]) | ||||
|   | ||||
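The docstrings above state that the parameters are drawn from the uniform distribution U(-k, k) with k = 1/sqrt(input_dims). A quick spot-check of that bound for `Linear` might look like this (illustrative sketch, assuming an mlx build that includes this change, so the bias is also uniformly initialized):

```python
import math

import mlx.core as mx
import mlx.nn as nn

input_dims, output_dims = 64, 32
layer = nn.Linear(input_dims, output_dims)
k = math.sqrt(1.0 / input_dims)

# Both the weight and the bias should lie inside the documented range [-k, k].
for name in ("weight", "bias"):
    p = layer[name]
    assert mx.all(p >= -k).item() and mx.all(p <= k).item()
```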
Asaf Zorea