mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 20:46:46 +08:00
add type hints for parameters and the return type
change Bilinear math to x_1 and x_2 change __call__ arguments to x and y instead of input and output add explanation to the Initialization
This commit is contained in:
parent
85da6e2626
commit
a4b43db9af
@ -1,6 +1,7 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
|
||||
import math
|
||||
from typing import Any
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx.nn.layers.base import Module
|
||||
@ -14,11 +15,11 @@ class Identity(Module):
|
||||
kwargs: any keyword argument (unused)
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
super().__init__()
|
||||
|
||||
def __call__(self, input):
|
||||
return input
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
return x
|
||||
|
||||
|
||||
class Linear(Module):
|
||||
@ -30,103 +31,94 @@ class Linear(Module):
|
||||
|
||||
y = W^\top x + b
|
||||
|
||||
where :math:`W` has shape ``[output_dims, input_dims]``.
|
||||
where:
|
||||
:math:`W` has shape ``[output_dims, input_dims]``.
|
||||
:math:`b` has shape ``[output_dims, ]``.
|
||||
|
||||
The values are initialized from :math:`\mathcal{U}(-{k}, {k})`, where
|
||||
:math:`k = \frac{1}{\sqrt{input\_dims}}`.
|
||||
|
||||
Args:
|
||||
input_dims (int): The dimensionality of the input features
|
||||
output_dims (int): The dimensionality of the output features
|
||||
bias (bool, optional): If set to ``False`` then the layer will
|
||||
not use a bias. Default ``True``.
|
||||
not use a bias. Default is ``True``.
|
||||
"""
|
||||
|
||||
def __init__(self, input_dims: int, output_dims: int, bias: bool = True):
|
||||
def __init__(self, input_dims: int, output_dims: int, bias: bool = True) -> None:
|
||||
super().__init__()
|
||||
self.input_dims = input_dims
|
||||
self.output_dims = output_dims
|
||||
self.weight = mx.zeros((output_dims, input_dims))
|
||||
if bias:
|
||||
self.bias = mx.zeros((output_dims,))
|
||||
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
scale = math.sqrt(1.0 / self.input_dims)
|
||||
scale = math.sqrt(1.0 / input_dims)
|
||||
self.weight = mx.random.uniform(
|
||||
low=-scale,
|
||||
high=scale,
|
||||
shape=(self.output_dims, self.input_dims),
|
||||
shape=(output_dims, input_dims),
|
||||
)
|
||||
if "bias" in self:
|
||||
if bias:
|
||||
self.bias = mx.random.uniform(
|
||||
low=-scale,
|
||||
high=scale,
|
||||
shape=(self.output_dims,),
|
||||
shape=(output_dims,),
|
||||
)
|
||||
|
||||
def _extra_repr(self):
|
||||
return f"input_dims={self.input_dims}, output_dims={self.output_dims}, bias={'bias' in self}"
|
||||
def _extra_repr(self) -> str:
|
||||
return f"input_dims={self.weight.shape[1]}, output_dims={self.weight.shape[0]}, bias={'bias' in self}"
|
||||
|
||||
def __call__(self, input):
|
||||
output = input @ self.weight.T
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
y = x @ self.weight.T
|
||||
if "bias" in self:
|
||||
output = output + self.bias
|
||||
return output
|
||||
y = y + self.bias
|
||||
return y
|
||||
|
||||
|
||||
class Bilinear(Module):
|
||||
r"""Applies a bilinear transformation to the input.
|
||||
r"""Applies a bilinear transformation to the inputs.
|
||||
|
||||
Concretely:
|
||||
|
||||
.. math::
|
||||
|
||||
y = input1^\top W input2 + b
|
||||
y = x_1^\top W x_2 + b
|
||||
|
||||
where :math:`W` has shape ``[output_dims, input1_dims, input2_dims]``.
|
||||
where
|
||||
:math:`W` has shape ``[output_dims, input1_dims, input2_dims]``.
|
||||
:math:`b` has shape ``[output_dims, ]``.
|
||||
|
||||
The values are initialized from :math:`\mathcal{U}(-{k}, {k})`, where
|
||||
:math:`k = \frac{1}{\sqrt{input1\_dims}}`.
|
||||
|
||||
Args:
|
||||
input1_dims (int): The dimensionality of the input1 features
|
||||
input2_dims (int): The dimensionality of the input2 features
|
||||
output_dims (int): The dimensionality of the output features
|
||||
bias (bool, optional): If set to ``False`` then the layer will
|
||||
not use a bias. Default ``True``.
|
||||
not use a bias. Default is ``True``.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, input1_dims: int, input2_dims: int, output_dims: int, bias: bool = True
|
||||
):
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.input1_dims = input1_dims
|
||||
self.input2_dims = input2_dims
|
||||
self.output_dims = output_dims
|
||||
self.weight = mx.zeros((output_dims, input1_dims, input2_dims))
|
||||
if bias:
|
||||
self.bias = mx.zeros((output_dims,))
|
||||
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
scale = math.sqrt(1.0 / self.input1_dims)
|
||||
scale = math.sqrt(1.0 / input1_dims)
|
||||
self.weight = mx.random.uniform(
|
||||
low=-scale,
|
||||
high=scale,
|
||||
shape=(self.output_dims, self.input1_dims, self.input2_dims),
|
||||
shape=(output_dims, input1_dims, input2_dims),
|
||||
)
|
||||
if "bias" in self:
|
||||
if bias:
|
||||
self.bias = mx.random.uniform(
|
||||
low=-scale,
|
||||
high=scale,
|
||||
shape=(self.output_dims,),
|
||||
shape=(output_dims,),
|
||||
)
|
||||
|
||||
def _extra_repr(self):
|
||||
def _extra_repr(self) -> str:
|
||||
return (
|
||||
f"input1_dims={self.input1_dims}, input2_dims={self.input2_dims}, output_dims={self.output_dims}, "
|
||||
f"bias={'bias' in self}"
|
||||
f"input1_dims={self.weight.shape[1]}, input2_dims={self.weight.shape[2]}, "
|
||||
f"output_dims={self.weight.shape[0]}, bias={'bias' in self}"
|
||||
)
|
||||
|
||||
def __call__(self, input1, input2):
|
||||
output = (input1 @ self.weight * input2.reshape(1, *input2.shape)).sum(-1).T
|
||||
def __call__(self, x1: mx.array, x2: mx.array) -> mx.array:
|
||||
y = (x1 @ self.weight * x2.reshape(1, *x2.shape)).sum(-1).T
|
||||
if "bias" in self:
|
||||
output = output + self.bias
|
||||
return output
|
||||
y = y + self.bias
|
||||
return y
|
||||
|
Loading…
Reference in New Issue
Block a user