diff --git a/python/mlx/nn/layers/__init__.py b/python/mlx/nn/layers/__init__.py
index d001d98b6..66d0a7111 100644
--- a/python/mlx/nn/layers/__init__.py
+++ b/python/mlx/nn/layers/__init__.py
@@ -37,7 +37,7 @@ from mlx.nn.layers.containers import Sequential
 from mlx.nn.layers.convolution import Conv1d, Conv2d
 from mlx.nn.layers.dropout import Dropout, Dropout2d
 from mlx.nn.layers.embedding import Embedding
-from mlx.nn.layers.linear import Identity, Linear, Bilinear
+from mlx.nn.layers.linear import Bilinear, Identity, Linear
 from mlx.nn.layers.normalization import BatchNorm, GroupNorm, LayerNorm, RMSNorm
 from mlx.nn.layers.positional_encoding import ALiBi, RoPE, SinusoidalPositionalEncoding
 from mlx.nn.layers.quantized import QuantizedLinear
diff --git a/python/mlx/nn/layers/linear.py b/python/mlx/nn/layers/linear.py
index 9877aa77c..1ef7a0d5b 100644
--- a/python/mlx/nn/layers/linear.py
+++ b/python/mlx/nn/layers/linear.py
@@ -50,7 +50,7 @@ class Linear(Module):
         self.reset_parameters()
 
     def reset_parameters(self):
-        scale = math.sqrt(1. / self.input_dims)
+        scale = math.sqrt(1.0 / self.input_dims)
         self.weight = mx.random.uniform(
             low=-scale,
             high=scale,
@@ -92,7 +92,9 @@ class Bilinear(Module):
             not use a bias. Default ``True``.
     """
 
-    def __init__(self, input1_dims: int, input2_dims: int, output_dims: int, bias: bool = True):
+    def __init__(
+        self, input1_dims: int, input2_dims: int, output_dims: int, bias: bool = True
+    ):
         super().__init__()
         self.input1_dims = input1_dims
         self.input2_dims = input2_dims
@@ -104,7 +106,7 @@ class Bilinear(Module):
         self.reset_parameters()
 
     def reset_parameters(self):
-        scale = math.sqrt(1. / self.input1_dims)
+        scale = math.sqrt(1.0 / self.input1_dims)
         self.weight = mx.random.uniform(
             low=-scale,
             high=scale,
@@ -118,8 +120,10 @@ class Bilinear(Module):
         )
 
     def _extra_repr(self):
-        return (f"input1_dims={self.input1_dims}, input2_dims={self.input2_dims}, output_dims={self.output_dims}, "
-                f"bias={'bias' in self}")
+        return (
+            f"input1_dims={self.input1_dims}, input2_dims={self.input2_dims}, output_dims={self.output_dims}, "
+            f"bias={'bias' in self}"
+        )
 
     def __call__(self, input1, input2):
         output = (input1 @ self.weight * input2.reshape(1, *input2.shape)).sum(-1).T
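
For context, a minimal sketch of how the Bilinear layer touched above would be constructed, based only on the __init__ signature and _extra_repr visible in this diff; the dimension values are illustrative assumptions, and the exact call/broadcast semantics depend on the weight shape, which these hunks do not show:

import mlx.nn as nn

# Illustrative dimensions only; nothing in the diff fixes these values.
layer = nn.Bilinear(input1_dims=32, input2_dims=64, output_dims=16, bias=True)

# The reformatted _extra_repr in the last hunk is what feeds this summary line.
print(layer)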