precommit changes

Jyun1998 2023-12-15 03:27:17 +08:00
parent 4bd3c02c00
commit 584a105c09


@@ -5,51 +5,9 @@ from typing import Any, Optional
import mlx.core as mx
from mlx.nn.layers.base import Module
from mlx.nn.layers.dropout import Dropout
from mlx.nn.layers.linear import Linear
from mlx.nn.layers.normalization import LayerNorm
from mlx.nn.layers.dropout import Dropout
from mlx.nn.layers.positional_encoding import SinusoidalPositionalEncoding
class MyPosEncoding(SinusoidalPositionalEncoding):
    def __init__(
        self,
        dims: int,
        min_freq: float = 0.0001,
        max_freq: float = 1,
        scale: Optional[float] = None,
        cos_first: bool = False,
        full_turns: bool = False,
    ):
        super().__init__(
            dims,
            min_freq=min_freq,
            max_freq=max_freq,
            scale=scale,
            cos_first=cos_first,
            full_turns=full_turns
        )
        self.dims = dims

    def __call__(self, x):
        seq_length = x.shape[1]  # Assuming x.shape [batch_size, sequence_length, embedding_dim]
        position = mx.arange(seq_length)[..., None] * self._sigmas
        # Generate positional encodings
        div_term = mx.exp(mx.arange(0, self.dims, 2) * -(math.log(10000.0) / self.dims))
        sinusoid_inp = position * div_term
        y = mx.zeros((seq_length, self.dims))
        if self.cos_first:
            y[:, 0::2] = mx.cos(sinusoid_inp)
            y[:, 1::2] = mx.sin(sinusoid_inp)
        else:
            y[:, 0::2] = mx.sin(sinusoid_inp)
            y[:, 1::2] = mx.cos(sinusoid_inp)
        if self.scale != 1:
            y = y * self.scale
        return x + y
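For context, a minimal usage sketch contrasting the stock SinusoidalPositionalEncoding, which is called on explicit positions, with the MyPosEncoding subclass above, which is called on the embeddings and returns them with the encodings added. The shapes and variable names here are illustrative only and are not part of the commit.

import mlx.core as mx
from mlx.nn.layers.positional_encoding import SinusoidalPositionalEncoding

# Illustrative shapes only.
batch, seq_len, dims = 2, 16, 32
x = mx.random.normal((batch, seq_len, dims))  # token embeddings

# Stock layer: call it on the positions, then add the result yourself.
pe = SinusoidalPositionalEncoding(dims)
angles = pe(mx.arange(seq_len))  # (seq_len, dims)
y_stock = x + angles             # broadcasts over the batch dimension

# Subclass above: call it on the embeddings directly (assuming it is in scope).
y_custom = MyPosEncoding(dims)(x)  # (batch, seq_len, dims)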
class MultiHeadAttention(Module):
    """Implements the scaled dot product attention with multiple heads.
@@ -151,7 +109,7 @@ class TransformerEncoderLayer(Module):
    def __call__(self, x, mask):
        y = self.attention(x, x, x, mask)
        y = self.ln1(x + y)
        y = self.linear1(y)
        y = mx.maximum(y, 0)
        y = self.linear2(y)
@@ -161,7 +119,13 @@ class TransformerEncoderLayer(Module):
class TransformerEncoderLayerWithDropout(Module):
    def __init__(self, dims: int, num_heads: int, mlp_dims: Optional[int] = None, dropout_rate: float = 0.1):
    def __init__(
        self,
        dims: int,
        num_heads: int,
        mlp_dims: Optional[int] = None,
        dropout_rate: float = 0.1,
    ):
        super().__init__()
        mlp_dims = mlp_dims or dims * 4
        self.attention = MultiHeadAttention(dims, num_heads)
@@ -176,7 +140,7 @@ class TransformerEncoderLayerWithDropout(Module):
        y = self.attention(x, x, x, mask)
        y = self.dropout1(y)
        y = self.ln1(x + y)
        y = self.linear1(y)
        y = mx.maximum(y, 0)
        y = self.linear2(y)
@@ -224,7 +188,6 @@ class TransformerEncoderWithDropout(Module):
        return x

class TransformerDecoderLayer(Module):
    def __init__(self, dims: int, num_heads: int, mlp_dims: Optional[int] = None):
        super().__init__()
@@ -240,7 +203,7 @@ class TransformerDecoderLayer(Module):
    def __call__(self, x, memory, x_mask, memory_mask):
        y = self.self_attention(x, x, x, x_mask)
        x = self.ln1(x + y)
        y = self.cross_attention(y, memory, memory, memory_mask)
        x = self.ln1(x + y)
@@ -253,7 +216,13 @@ class TransformerDecoderLayer(Module):
class TransformerDecoderLayerWithDropout(Module):
    def __init__(self, dims: int, num_heads: int, mlp_dims: Optional[int] = None, dropout_rate: float = 0.1):
    def __init__(
        self,
        dims: int,
        num_heads: int,
        mlp_dims: Optional[int] = None,
        dropout_rate: float = 0.1,
    ):
        super().__init__()
        mlp_dims = mlp_dims or dims * 4
        self.self_attention = MultiHeadAttention(dims, num_heads)
@@ -271,7 +240,7 @@ class TransformerDecoderLayerWithDropout(Module):
        y = self.self_attention(x, x, x, x_mask)
        y = self.dropout1(y)
        x = self.ln1(x + y)
        y = self.cross_attention(y, memory, memory, memory_mask)
        y = self.dropout2(y)
        x = self.ln1(x + y)
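For a quick sanity check of the dropout variants touched by this commit, here is a hedged smoke-test sketch. The transformer module path, the tensor shapes, and the assumption that the attention layers accept mask=None are mine, not part of the commit.

import mlx.core as mx

# Hypothetical import path for the file shown in this diff.
from transformer import (
    TransformerEncoderLayerWithDropout,
    TransformerDecoderLayerWithDropout,
)

batch, seq_len, dims, num_heads = 2, 10, 64, 4

enc_layer = TransformerEncoderLayerWithDropout(dims, num_heads, dropout_rate=0.1)
dec_layer = TransformerDecoderLayerWithDropout(dims, num_heads, dropout_rate=0.1)

src = mx.random.normal((batch, seq_len, dims))
tgt = mx.random.normal((batch, seq_len, dims))

# Additive causal mask for the decoder's self-attention.
idx = mx.arange(seq_len)
causal = mx.where(idx[:, None] >= idx[None, :], 0.0, float("-inf"))

memory = enc_layer(src, mask=None)  # assumes the attention layer tolerates mask=None
out = dec_layer(tgt, memory, x_mask=causal, memory_mask=None)
print(out.shape)  # expected: (batch, seq_len, dims)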