mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 20:46:46 +08:00
pre-commit run
This commit is contained in:
parent
0abc7f0d10
commit
7bf769a292
@ -37,7 +37,7 @@ from mlx.nn.layers.containers import Sequential
|
|||||||
from mlx.nn.layers.convolution import Conv1d, Conv2d
|
from mlx.nn.layers.convolution import Conv1d, Conv2d
|
||||||
from mlx.nn.layers.dropout import Dropout, Dropout2d
|
from mlx.nn.layers.dropout import Dropout, Dropout2d
|
||||||
from mlx.nn.layers.embedding import Embedding
|
from mlx.nn.layers.embedding import Embedding
|
||||||
from mlx.nn.layers.linear import Identity, Linear, Bilinear
|
from mlx.nn.layers.linear import Bilinear, Identity, Linear
|
||||||
from mlx.nn.layers.normalization import BatchNorm, GroupNorm, LayerNorm, RMSNorm
|
from mlx.nn.layers.normalization import BatchNorm, GroupNorm, LayerNorm, RMSNorm
|
||||||
from mlx.nn.layers.positional_encoding import ALiBi, RoPE, SinusoidalPositionalEncoding
|
from mlx.nn.layers.positional_encoding import ALiBi, RoPE, SinusoidalPositionalEncoding
|
||||||
from mlx.nn.layers.quantized import QuantizedLinear
|
from mlx.nn.layers.quantized import QuantizedLinear
|
||||||
|
@ -50,7 +50,7 @@ class Linear(Module):
|
|||||||
self.reset_parameters()
|
self.reset_parameters()
|
||||||
|
|
||||||
def reset_parameters(self):
|
def reset_parameters(self):
|
||||||
scale = math.sqrt(1. / self.input_dims)
|
scale = math.sqrt(1.0 / self.input_dims)
|
||||||
self.weight = mx.random.uniform(
|
self.weight = mx.random.uniform(
|
||||||
low=-scale,
|
low=-scale,
|
||||||
high=scale,
|
high=scale,
|
||||||
@ -92,7 +92,9 @@ class Bilinear(Module):
|
|||||||
not use a bias. Default ``True``.
|
not use a bias. Default ``True``.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, input1_dims: int, input2_dims: int, output_dims: int, bias: bool = True):
|
def __init__(
|
||||||
|
self, input1_dims: int, input2_dims: int, output_dims: int, bias: bool = True
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.input1_dims = input1_dims
|
self.input1_dims = input1_dims
|
||||||
self.input2_dims = input2_dims
|
self.input2_dims = input2_dims
|
||||||
@ -104,7 +106,7 @@ class Bilinear(Module):
|
|||||||
self.reset_parameters()
|
self.reset_parameters()
|
||||||
|
|
||||||
def reset_parameters(self):
|
def reset_parameters(self):
|
||||||
scale = math.sqrt(1. / self.input1_dims)
|
scale = math.sqrt(1.0 / self.input1_dims)
|
||||||
self.weight = mx.random.uniform(
|
self.weight = mx.random.uniform(
|
||||||
low=-scale,
|
low=-scale,
|
||||||
high=scale,
|
high=scale,
|
||||||
@ -118,8 +120,10 @@ class Bilinear(Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _extra_repr(self):
|
def _extra_repr(self):
|
||||||
return (f"input1_dims={self.input1_dims}, input2_dims={self.input2_dims}, output_dims={self.output_dims}, "
|
return (
|
||||||
f"bias={'bias' in self}")
|
f"input1_dims={self.input1_dims}, input2_dims={self.input2_dims}, output_dims={self.output_dims}, "
|
||||||
|
f"bias={'bias' in self}"
|
||||||
|
)
|
||||||
|
|
||||||
def __call__(self, input1, input2):
|
def __call__(self, input1, input2):
|
||||||
output = (input1 @ self.weight * input2.reshape(1, *input2.shape)).sum(-1).T
|
output = (input1 @ self.weight * input2.reshape(1, *input2.shape)).sum(-1).T
|
||||||
|
Loading…
Reference in New Issue
Block a user