Mirror of https://github.com/ml-explore/mlx.git
add type hints for parameters and the return type

- change the Bilinear math to x_1 and x_2
- change the __call__ arguments to x and y instead of input and output
- add an explanation of the initialization to the docstrings
commit a4b43db9af
parent 85da6e2626
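For reference (not part of the commit): the initialization described by the updated docstrings samples both the weight and the bias from U(-k, k) with k = 1/sqrt(input_dims), which is exactly what the new __init__ bodies in the diff below do. A minimal standalone sketch, with illustrative dimensions:

import math

import mlx.core as mx

# Illustrative dimensions; nothing below depends on these particular values.
input_dims, output_dims = 256, 128

# The bound from the docstring: k = 1 / sqrt(input_dims).
scale = math.sqrt(1.0 / input_dims)

# Weight and bias are both drawn from the uniform distribution U(-k, k),
# mirroring the new Linear.__init__ in the diff.
weight = mx.random.uniform(low=-scale, high=scale, shape=(output_dims, input_dims))
bias = mx.random.uniform(low=-scale, high=scale, shape=(output_dims,))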
@@ -1,6 +1,7 @@
 # Copyright © 2023 Apple Inc.
 
 import math
+from typing import Any
 
 import mlx.core as mx
 from mlx.nn.layers.base import Module
@@ -14,11 +15,11 @@ class Identity(Module):
         kwargs: any keyword argument (unused)
     """
 
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
         super().__init__()
 
-    def __call__(self, input):
-        return input
+    def __call__(self, x: mx.array) -> mx.array:
+        return x
 
 
 class Linear(Module):
@@ -30,103 +31,94 @@ class Linear(Module):
 
         y = W^\top x + b
 
-    where :math:`W` has shape ``[output_dims, input_dims]``.
+    where:
+    :math:`W` has shape ``[output_dims, input_dims]``.
+    :math:`b` has shape ``[output_dims, ]``.
 
+    The values are initialized from :math:`\mathcal{U}(-{k}, {k})`, where
+    :math:`k = \frac{1}{\sqrt{input\_dims}}`.
+
     Args:
         input_dims (int): The dimensionality of the input features
         output_dims (int): The dimensionality of the output features
         bias (bool, optional): If set to ``False`` then the layer will
-            not use a bias. Default ``True``.
+            not use a bias. Default is ``True``.
     """
 
-    def __init__(self, input_dims: int, output_dims: int, bias: bool = True):
+    def __init__(self, input_dims: int, output_dims: int, bias: bool = True) -> None:
         super().__init__()
-        self.input_dims = input_dims
-        self.output_dims = output_dims
-        self.weight = mx.zeros((output_dims, input_dims))
-        if bias:
-            self.bias = mx.zeros((output_dims,))
-
-        self.reset_parameters()
-
-    def reset_parameters(self):
-        scale = math.sqrt(1.0 / self.input_dims)
+        scale = math.sqrt(1.0 / input_dims)
         self.weight = mx.random.uniform(
             low=-scale,
             high=scale,
-            shape=(self.output_dims, self.input_dims),
+            shape=(output_dims, input_dims),
         )
-        if "bias" in self:
+        if bias:
             self.bias = mx.random.uniform(
                 low=-scale,
                 high=scale,
-                shape=(self.output_dims,),
+                shape=(output_dims,),
             )
 
-    def _extra_repr(self):
-        return f"input_dims={self.input_dims}, output_dims={self.output_dims}, bias={'bias' in self}"
+    def _extra_repr(self) -> str:
+        return f"input_dims={self.weight.shape[1]}, output_dims={self.weight.shape[0]}, bias={'bias' in self}"
 
-    def __call__(self, input):
-        output = input @ self.weight.T
+    def __call__(self, x: mx.array) -> mx.array:
+        y = x @ self.weight.T
         if "bias" in self:
-            output = output + self.bias
-        return output
+            y = y + self.bias
+        return y
 
 
 class Bilinear(Module):
-    r"""Applies a bilinear transformation to the input.
+    r"""Applies a bilinear transformation to the inputs.
 
     Concretely:
 
     .. math::
 
-        y = input1^\top W input2 + b
+        y = x_1^\top W x_2 + b
 
-    where :math:`W` has shape ``[output_dims, input1_dims, input2_dims]``.
+    where
+    :math:`W` has shape ``[output_dims, input1_dims, input2_dims]``.
+    :math:`b` has shape ``[output_dims, ]``.
 
+    The values are initialized from :math:`\mathcal{U}(-{k}, {k})`, where
+    :math:`k = \frac{1}{\sqrt{input1\_dims}}`.
+
     Args:
         input1_dims (int): The dimensionality of the input1 features
         input2_dims (int): The dimensionality of the input2 features
         output_dims (int): The dimensionality of the output features
         bias (bool, optional): If set to ``False`` then the layer will
-            not use a bias. Default ``True``.
+            not use a bias. Default is ``True``.
     """
 
     def __init__(
         self, input1_dims: int, input2_dims: int, output_dims: int, bias: bool = True
-    ):
+    ) -> None:
         super().__init__()
-        self.input1_dims = input1_dims
-        self.input2_dims = input2_dims
-        self.output_dims = output_dims
-        self.weight = mx.zeros((output_dims, input1_dims, input2_dims))
-        if bias:
-            self.bias = mx.zeros((output_dims,))
-
-        self.reset_parameters()
-
-    def reset_parameters(self):
-        scale = math.sqrt(1.0 / self.input1_dims)
+        scale = math.sqrt(1.0 / input1_dims)
         self.weight = mx.random.uniform(
             low=-scale,
             high=scale,
-            shape=(self.output_dims, self.input1_dims, self.input2_dims),
+            shape=(output_dims, input1_dims, input2_dims),
         )
-        if "bias" in self:
+        if bias:
             self.bias = mx.random.uniform(
                 low=-scale,
                 high=scale,
-                shape=(self.output_dims,),
+                shape=(output_dims,),
             )
 
-    def _extra_repr(self):
+    def _extra_repr(self) -> str:
         return (
-            f"input1_dims={self.input1_dims}, input2_dims={self.input2_dims}, output_dims={self.output_dims}, "
-            f"bias={'bias' in self}"
+            f"input1_dims={self.weight.shape[1]}, input2_dims={self.weight.shape[2]}, "
+            f"output_dims={self.weight.shape[0]}, bias={'bias' in self}"
         )
 
-    def __call__(self, input1, input2):
-        output = (input1 @ self.weight * input2.reshape(1, *input2.shape)).sum(-1).T
+    def __call__(self, x1: mx.array, x2: mx.array) -> mx.array:
+        y = (x1 @ self.weight * x2.reshape(1, *x2.shape)).sum(-1).T
         if "bias" in self:
-            output = output + self.bias
-        return output
+            y = y + self.bias
+        return y
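After this change, Linear.__call__ takes a single argument named x and returns y = x W^T + b. A quick usage sketch against the updated signature; the shapes are illustrative, not from the commit:

import mlx.core as mx
import mlx.nn as nn

layer = nn.Linear(input_dims=4, output_dims=8)

# A batch of two 4-dimensional inputs; the layer computes y = x @ W.T + b.
x = mx.random.normal((2, 4))
y = layer(x)
print(y.shape)  # (2, 8)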
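Similarly, Bilinear.__call__ now takes x1 and x2 and computes, per output feature k, y_k = x1^T W_k x2 + b_k. A small sketch, again with made-up dimensions:

import mlx.core as mx
import mlx.nn as nn

bilinear = nn.Bilinear(input1_dims=3, input2_dims=5, output_dims=7)

# x1 has shape (batch, input1_dims), x2 has shape (batch, input2_dims).
x1 = mx.random.normal((2, 3))
x2 = mx.random.normal((2, 5))

# y[k] = x1 @ W[k] @ x2 + b[k] for each of the 7 output features.
y = bilinear(x1, x2)
print(y.shape)  # (2, 7)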