precommit changes

Jyun1998 2023-12-15 03:27:17 +08:00
parent 4bd3c02c00
commit 584a105c09


@@ -5,51 +5,9 @@ from typing import Any, Optional
 
 import mlx.core as mx
 from mlx.nn.layers.base import Module
+from mlx.nn.layers.dropout import Dropout
 from mlx.nn.layers.linear import Linear
 from mlx.nn.layers.normalization import LayerNorm
-from mlx.nn.layers.dropout import Dropout
-from mlx.nn.layers.positional_encoding import SinusoidalPositionalEncoding
-
-
-class MyPosEncoding(SinusoidalPositionalEncoding):
-    def __init__(
-        self,
-        dims: int,
-        min_freq: float = 0.0001,
-        max_freq: float = 1,
-        scale: Optional[float] = None,
-        cos_first: bool = False,
-        full_turns: bool = False,
-    ):
-        super().__init__(
-            dims,
-            min_freq=min_freq,
-            max_freq=max_freq,
-            scale=scale,
-            cos_first=cos_first,
-            full_turns=full_turns
-        )
-        self.dims = dims
-
-    def __call__(self, x):
-        seq_length = x.shape[1]  # Assuming x.shape [batch_size, sequence_length, embedding_dim]
-        position = mx.arange(seq_length)[..., None] * self._sigmas
-
-        # Generate positional encodings
-        div_term = mx.exp(mx.arange(0, self.dims, 2) * -(math.log(10000.0) / self.dims))
-        sinusoid_inp = position * div_term
-        y = mx.zeros((seq_length, self.dims))
-        if self.cos_first:
-            y[:, 0::2] = mx.cos(sinusoid_inp)
-            y[:, 1::2] = mx.sin(sinusoid_inp)
-        else:
-            y[:, 0::2] = mx.sin(sinusoid_inp)
-            y[:, 1::2] = mx.cos(sinusoid_inp)
-
-        if self.scale != 1:
-            y = y * self.scale
-
-        return x + y
-
-
 class MultiHeadAttention(Module):
     """Implements the scaled dot product attention with multiple heads.
@@ -161,7 +119,13 @@ class TransformerEncoderLayer(Module):
 
 
 class TransformerEncoderLayerWithDropout(Module):
-    def __init__(self, dims: int, num_heads: int, mlp_dims: Optional[int] = None, dropout_rate: float = 0.1):
+    def __init__(
+        self,
+        dims: int,
+        num_heads: int,
+        mlp_dims: Optional[int] = None,
+        dropout_rate: float = 0.1,
+    ):
         super().__init__()
         mlp_dims = mlp_dims or dims * 4
         self.attention = MultiHeadAttention(dims, num_heads)
@@ -224,7 +188,6 @@ class TransformerEncoderWithDropout(Module):
         return x
 
 
-
 class TransformerDecoderLayer(Module):
     def __init__(self, dims: int, num_heads: int, mlp_dims: Optional[int] = None):
         super().__init__()
@@ -253,7 +216,13 @@ class TransformerDecoderLayer(Module):
 
 
 class TransformerDecoderLayerWithDropout(Module):
-    def __init__(self, dims: int, num_heads: int, mlp_dims: Optional[int] = None, dropout_rate: float = 0.1):
+    def __init__(
+        self,
+        dims: int,
+        num_heads: int,
+        mlp_dims: Optional[int] = None,
+        dropout_rate: float = 0.1,
+    ):
         super().__init__()
         mlp_dims = mlp_dims or dims * 4
         self.self_attention = MultiHeadAttention(dims, num_heads)
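
For reference, a minimal usage sketch of the reformatted constructor. It assumes the layer is importable from mlx.nn.layers.transformer in this fork and that TransformerEncoderLayerWithDropout's __call__ takes (x, mask) like the stock TransformerEncoderLayer; neither detail is shown in this diff.

    # Minimal usage sketch; module path and __call__(x, mask) signature are assumptions.
    import mlx.core as mx
    from mlx.nn.layers.transformer import (
        MultiHeadAttention,
        TransformerEncoderLayerWithDropout,
    )

    # Constructor arguments match the reformatted signature above.
    layer = TransformerEncoderLayerWithDropout(
        dims=256,
        num_heads=4,
        mlp_dims=1024,
        dropout_rate=0.1,
    )

    x = mx.random.normal((2, 16, 256))  # [batch, sequence_length, dims]
    mask = MultiHeadAttention.create_additive_causal_mask(16)
    y = layer(x, mask)
    print(y.shape)  # batch, sequence_length, dims are preserved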