From a4b43db9af5a00285213d7c9ae710166663aca58 Mon Sep 17 00:00:00 2001
From: zorea
Date: Sat, 30 Dec 2023 23:53:09 +0200
Subject: [PATCH] add type hints for parameters and the return type

change Bilinear math to x_1 and x_2
change __call__ arguments to x and y instead of input and output
add explanation to the Initialization

---
 python/mlx/nn/layers/linear.py | 94 ++++++++++++++++------------------
 1 file changed, 43 insertions(+), 51 deletions(-)

diff --git a/python/mlx/nn/layers/linear.py b/python/mlx/nn/layers/linear.py
index 1ef7a0d5b..9bb3e1cd5 100644
--- a/python/mlx/nn/layers/linear.py
+++ b/python/mlx/nn/layers/linear.py
@@ -1,6 +1,7 @@
 # Copyright © 2023 Apple Inc.
 
 import math
+from typing import Any
 
 import mlx.core as mx
 from mlx.nn.layers.base import Module
@@ -14,11 +15,11 @@ class Identity(Module):
         kwargs: any keyword argument (unused)
     """
 
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
         super().__init__()
 
-    def __call__(self, input):
-        return input
+    def __call__(self, x: mx.array) -> mx.array:
+        return x
 
 
 class Linear(Module):
@@ -30,103 +31,94 @@ class Linear(Module):
 
         y = W^\top x + b
 
-    where :math:`W` has shape ``[output_dims, input_dims]``.
+    where:
+    :math:`W` has shape ``[output_dims, input_dims]``.
+    :math:`b` has shape ``[output_dims, ]``.
+
+    The values are initialized from :math:`\mathcal{U}(-{k}, {k})`, where
+    :math:`k = \frac{1}{\sqrt{input\_dims}}`.
 
     Args:
         input_dims (int): The dimensionality of the input features
         output_dims (int): The dimensionality of the output features
         bias (bool, optional): If set to ``False`` then the layer will
-            not use a bias. Default ``True``.
+            not use a bias. Default is ``True``.
     """
 
-    def __init__(self, input_dims: int, output_dims: int, bias: bool = True):
+    def __init__(self, input_dims: int, output_dims: int, bias: bool = True) -> None:
         super().__init__()
-        self.input_dims = input_dims
-        self.output_dims = output_dims
-        self.weight = mx.zeros((output_dims, input_dims))
-        if bias:
-            self.bias = mx.zeros((output_dims,))
-
-        self.reset_parameters()
-
-    def reset_parameters(self):
-        scale = math.sqrt(1.0 / self.input_dims)
+        scale = math.sqrt(1.0 / input_dims)
         self.weight = mx.random.uniform(
             low=-scale,
             high=scale,
-            shape=(self.output_dims, self.input_dims),
+            shape=(output_dims, input_dims),
         )
-        if "bias" in self:
+        if bias:
             self.bias = mx.random.uniform(
                 low=-scale,
                 high=scale,
-                shape=(self.output_dims,),
+                shape=(output_dims,),
             )
 
-    def _extra_repr(self):
-        return f"input_dims={self.input_dims}, output_dims={self.output_dims}, bias={'bias' in self}"
+    def _extra_repr(self) -> str:
+        return f"input_dims={self.weight.shape[1]}, output_dims={self.weight.shape[0]}, bias={'bias' in self}"
 
-    def __call__(self, input):
-        output = input @ self.weight.T
+    def __call__(self, x: mx.array) -> mx.array:
+        y = x @ self.weight.T
         if "bias" in self:
-            output = output + self.bias
-        return output
+            y = y + self.bias
+        return y
 
 
 class Bilinear(Module):
-    r"""Applies a bilinear transformation to the input.
+    r"""Applies a bilinear transformation to the inputs.
 
     Concretely:
 
     .. math::
 
-        y = input1^\top W input2 + b
+        y = x_1^\top W x_2 + b
 
-    where :math:`W` has shape ``[output_dims, input1_dims, input2_dims]``.
+    where
+    :math:`W` has shape ``[output_dims, input1_dims, input2_dims]``.
+    :math:`b` has shape ``[output_dims, ]``.
+
+    The values are initialized from :math:`\mathcal{U}(-{k}, {k})`, where
+    :math:`k = \frac{1}{\sqrt{input1\_dims}}`.
 
     Args:
         input1_dims (int): The dimensionality of the input1 features
         input2_dims (int): The dimensionality of the input2 features
         output_dims (int): The dimensionality of the output features
         bias (bool, optional): If set to ``False`` then the layer will
-            not use a bias. Default ``True``.
+            not use a bias. Default is ``True``.
     """
 
     def __init__(
         self, input1_dims: int, input2_dims: int, output_dims: int, bias: bool = True
-    ):
+    ) -> None:
         super().__init__()
-        self.input1_dims = input1_dims
-        self.input2_dims = input2_dims
-        self.output_dims = output_dims
-        self.weight = mx.zeros((output_dims, input1_dims, input2_dims))
-        if bias:
-            self.bias = mx.zeros((output_dims,))
-
-        self.reset_parameters()
-
-    def reset_parameters(self):
-        scale = math.sqrt(1.0 / self.input1_dims)
+        scale = math.sqrt(1.0 / input1_dims)
         self.weight = mx.random.uniform(
             low=-scale,
             high=scale,
-            shape=(self.output_dims, self.input1_dims, self.input2_dims),
+            shape=(output_dims, input1_dims, input2_dims),
         )
-        if "bias" in self:
+        if bias:
             self.bias = mx.random.uniform(
                 low=-scale,
                 high=scale,
-                shape=(self.output_dims,),
+                shape=(output_dims,),
             )
 
-    def _extra_repr(self):
+    def _extra_repr(self) -> str:
         return (
-            f"input1_dims={self.input1_dims}, input2_dims={self.input2_dims}, output_dims={self.output_dims}, "
-            f"bias={'bias' in self}"
+            f"input1_dims={self.weight.shape[1]}, input2_dims={self.weight.shape[2]}, "
+            f"output_dims={self.weight.shape[0]}, bias={'bias' in self}"
         )
 
-    def __call__(self, input1, input2):
-        output = (input1 @ self.weight * input2.reshape(1, *input2.shape)).sum(-1).T
+    def __call__(self, x1: mx.array, x2: mx.array) -> mx.array:
+        y = (x1 @ self.weight * x2.reshape(1, *x2.shape)).sum(-1).T
         if "bias" in self:
-            output = output + self.bias
-        return output
+            y = y + self.bias
+        return y
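
A quick usage sketch of the API after this patch, for review purposes. This is
not part of the diff; it assumes `Linear` and `Bilinear` are re-exported from
`mlx.nn`, and the example shapes are illustrative:

    import mlx.core as mx
    import mlx.nn as nn

    # Linear: y = W^T x + b, with weight of shape [output_dims, input_dims].
    lin = nn.Linear(input_dims=4, output_dims=3)
    x = mx.random.normal((2, 4))   # batch of 2 four-dimensional inputs
    y = lin(x)                     # -> shape (2, 3)

    # Bilinear: y = x_1^T W x_2 + b, with weight of shape
    # [output_dims, input1_dims, input2_dims].
    bil = nn.Bilinear(input1_dims=4, input2_dims=5, output_dims=3)
    x1 = mx.random.normal((2, 4))
    x2 = mx.random.normal((2, 5))
    z = bil(x1, x2)                # -> shape (2, 3)

    # With bias=False no bias parameter is created at all, so the
    # `"bias" in self` check in __call__ skips the add.
    lin_no_bias = nn.Linear(4, 3, bias=False)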