add type hints for parameters and the return type

change the Bilinear math notation to x_1 and x_2
rename the __call__ input argument to x and the output variable to y, instead of input and output
add an explanation of the initialization to the docstrings
zorea 2023-12-30 23:53:09 +02:00
parent 85da6e2626
commit a4b43db9af
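
The last point in the message refers to the new docstring text: weights are initialized from U(-k, k) with k = 1/sqrt(input_dims), so for example input_dims = 16 gives k = 0.25. A minimal sketch checking that bound (illustrative only: it assumes the mlx package is installed and exposes this layer as mlx.nn.Linear; the 16/8 dimensions are made up for the example):

    import math
    import mlx.core as mx
    import mlx.nn as nn

    # Per the new docstring, every weight entry is drawn from
    # U(-k, k) with k = 1/sqrt(input_dims), i.e. 0.25 here.
    layer = nn.Linear(input_dims=16, output_dims=8)
    k = math.sqrt(1.0 / 16)
    assert tuple(layer.weight.shape) == (8, 16)
    assert mx.min(layer.weight).item() >= -k
    assert mx.max(layer.weight).item() <= k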


@@ -1,6 +1,7 @@
 # Copyright © 2023 Apple Inc.
 
 import math
+from typing import Any
 
 import mlx.core as mx
 from mlx.nn.layers.base import Module
@@ -14,11 +15,11 @@ class Identity(Module):
         kwargs: any keyword argument (unused)
     """
 
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
         super().__init__()
 
-    def __call__(self, input):
-        return input
+    def __call__(self, x: mx.array) -> mx.array:
+        return x
 
 
 class Linear(Module):
@@ -30,103 +31,94 @@ class Linear(Module):
         y = W^\top x + b
 
-    where :math:`W` has shape ``[output_dims, input_dims]``.
+    where:
+    :math:`W` has shape ``[output_dims, input_dims]``.
+    :math:`b` has shape ``[output_dims, ]``.
+
+    The values are initialized from :math:`\mathcal{U}(-{k}, {k})`, where
+    :math:`k = \frac{1}{\sqrt{input\_dims}}`.
 
     Args:
         input_dims (int): The dimensionality of the input features
         output_dims (int): The dimensionality of the output features
         bias (bool, optional): If set to ``False`` then the layer will
-          not use a bias. Default ``True``.
+          not use a bias. Default is ``True``.
     """
 
-    def __init__(self, input_dims: int, output_dims: int, bias: bool = True):
+    def __init__(self, input_dims: int, output_dims: int, bias: bool = True) -> None:
         super().__init__()
-        self.input_dims = input_dims
-        self.output_dims = output_dims
-        self.weight = mx.zeros((output_dims, input_dims))
-        if bias:
-            self.bias = mx.zeros((output_dims,))
-        self.reset_parameters()
-
-    def reset_parameters(self):
-        scale = math.sqrt(1.0 / self.input_dims)
+        scale = math.sqrt(1.0 / input_dims)
         self.weight = mx.random.uniform(
             low=-scale,
             high=scale,
-            shape=(self.output_dims, self.input_dims),
+            shape=(output_dims, input_dims),
         )
-        if "bias" in self:
+        if bias:
             self.bias = mx.random.uniform(
                 low=-scale,
                 high=scale,
-                shape=(self.output_dims,),
+                shape=(output_dims,),
             )
 
-    def _extra_repr(self):
-        return f"input_dims={self.input_dims}, output_dims={self.output_dims}, bias={'bias' in self}"
+    def _extra_repr(self) -> str:
+        return f"input_dims={self.weight.shape[1]}, output_dims={self.weight.shape[0]}, bias={'bias' in self}"
 
-    def __call__(self, input):
-        output = input @ self.weight.T
+    def __call__(self, x: mx.array) -> mx.array:
+        y = x @ self.weight.T
         if "bias" in self:
-            output = output + self.bias
-        return output
+            y = y + self.bias
+        return y
 
 
 class Bilinear(Module):
-    r"""Applies a bilinear transformation to the input.
+    r"""Applies a bilinear transformation to the inputs.
 
     Concretely:
 
     .. math::
 
-        y = input1^\top W input2 + b
+        y = x_1^\top W x_2 + b
 
-    where :math:`W` has shape ``[output_dims, input1_dims, input2_dims]``.
+    where:
+    :math:`W` has shape ``[output_dims, input1_dims, input2_dims]``.
+    :math:`b` has shape ``[output_dims, ]``.
+
+    The values are initialized from :math:`\mathcal{U}(-{k}, {k})`, where
+    :math:`k = \frac{1}{\sqrt{input1\_dims}}`.
 
     Args:
         input1_dims (int): The dimensionality of the input1 features
         input2_dims (int): The dimensionality of the input2 features
         output_dims (int): The dimensionality of the output features
         bias (bool, optional): If set to ``False`` then the layer will
-          not use a bias. Default ``True``.
+          not use a bias. Default is ``True``.
     """
 
     def __init__(
         self, input1_dims: int, input2_dims: int, output_dims: int, bias: bool = True
-    ):
+    ) -> None:
         super().__init__()
-        self.input1_dims = input1_dims
-        self.input2_dims = input2_dims
-        self.output_dims = output_dims
-        self.weight = mx.zeros((output_dims, input1_dims, input2_dims))
-        if bias:
-            self.bias = mx.zeros((output_dims,))
-        self.reset_parameters()
-
-    def reset_parameters(self):
-        scale = math.sqrt(1.0 / self.input1_dims)
+        scale = math.sqrt(1.0 / input1_dims)
         self.weight = mx.random.uniform(
             low=-scale,
             high=scale,
-            shape=(self.output_dims, self.input1_dims, self.input2_dims),
+            shape=(output_dims, input1_dims, input2_dims),
         )
-        if "bias" in self:
+        if bias:
             self.bias = mx.random.uniform(
                 low=-scale,
                 high=scale,
-                shape=(self.output_dims,),
+                shape=(output_dims,),
             )
 
-    def _extra_repr(self):
+    def _extra_repr(self) -> str:
         return (
-            f"input1_dims={self.input1_dims}, input2_dims={self.input2_dims}, output_dims={self.output_dims}, "
-            f"bias={'bias' in self}"
+            f"input1_dims={self.weight.shape[1]}, input2_dims={self.weight.shape[2]}, "
+            f"output_dims={self.weight.shape[0]}, bias={'bias' in self}"
         )
 
-    def __call__(self, input1, input2):
-        output = (input1 @ self.weight * input2.reshape(1, *input2.shape)).sum(-1).T
+    def __call__(self, x1: mx.array, x2: mx.array) -> mx.array:
+        y = (x1 @ self.weight * x2.reshape(1, *x2.shape)).sum(-1).T
         if "bias" in self:
-            output = output + self.bias
-        return output
+            y = y + self.bias
+        return y
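
With the renames above, Linear.__call__ takes x and returns y = x @ W.T + b. A short usage sketch (illustrative shapes; assumes the mlx package is installed and exposes the layer as mlx.nn.Linear):

    import mlx.core as mx
    import mlx.nn as nn

    linear = nn.Linear(input_dims=4, output_dims=8)
    x = mx.ones((2, 4))      # a batch of 2 inputs with 4 features each
    y = linear(x)            # y = x @ W.T + b
    print(tuple(y.shape))    # (2, 8)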
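Likewise, Bilinear.__call__ now takes x1 and x2 and computes y = x_1^\top W x_2 + b, applying one (input1_dims, input2_dims) slice of W per output feature. Same assumptions as the sketch above:

    import mlx.core as mx
    import mlx.nn as nn

    bilinear = nn.Bilinear(input1_dims=3, input2_dims=4, output_dims=8)
    x1 = mx.ones((2, 3))     # batch of 2, first input has 3 features
    x2 = mx.ones((2, 4))     # batch of 2, second input has 4 features
    y = bilinear(x1, x2)     # y[n, o] = x1[n] @ W[o] @ x2[n] + b[o]
    print(tuple(y.shape))    # (2, 8)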