mirror of https://github.com/ml-explore/mlx.git

commit 0e0557b756 ("run precommit"), parent 297e69017c
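A formatting-only commit: per the message, the repository's pre-commit hooks were run over what is evidently MLX's transformer layers module (presumably python/mlx/nn/layers/transformer.py; the file path is not shown in this view). The changes below are consistent with black plus isort output, though the exact hook list is an assumption. A minimal sketch of reproducing such a pass locally, assuming `pre-commit` is installed and the repository defines a .pre-commit-config.yaml:

```python
# Hypothetical reproduction of a formatting-only pass; assumes `pre-commit`
# is on PATH and the repository ships a .pre-commit-config.yaml. Nothing
# here is specific to this commit beyond the workflow.
import subprocess

# Run every configured hook against the whole tree, not just staged files.
subprocess.run(["pre-commit", "run", "--all-files"], check=False)
```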
@@ -1,7 +1,7 @@
 # Copyright © 2023 Apple Inc.
 
 import math
-from typing import Any, Optional, Callable
+from typing import Any, Callable, Optional
 
 import mlx.core as mx
 from mlx.nn.layers.activations import relu
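The only change in this hunk is alphabetization of the names inside the `typing` import. A minimal illustration of the ordering rule, assuming the hook is isort (named here as an assumption; the diff shows only the result):

```python
# isort sorts names within a from-import alphabetically, which yields
# exactly the line on the + side of the hunk above.
names = ["Any", "Optional", "Callable"]
print(", ".join(sorted(names)))  # Any, Callable, Optional
```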
@@ -150,11 +150,20 @@ class TransformerEncoderLayer(Module):
 
 class TransformerEncoder(Module):
     def __init__(
-        self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None, dropout: float = 0.0, norm_first: bool = False, activation = relu
+        self,
+        num_layers: int,
+        dims: int,
+        num_heads: int,
+        mlp_dims: Optional[int] = None,
+        dropout: float = 0.0,
+        norm_first: bool = False,
+        activation=relu,
     ):
         super().__init__()
         self.layers = [
-            TransformerEncoderLayer(dims, num_heads, mlp_dims, dropout, norm_first, activation)
+            TransformerEncoderLayer(
+                dims, num_heads, mlp_dims, dropout, norm_first, activation
+            )
             for i in range(num_layers)
         ]
         self.ln = LayerNorm(dims)
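The change above is purely mechanical: black-style one-parameter-per-line signatures with trailing commas once a line exceeds the default width, plus the same treatment for the over-long constructor call. For reference, a minimal usage sketch of the reformatted class; the keyword names come from the signature in the diff, while the import path and the call convention are assumptions about this module:

```python
import mlx.core as mx
import mlx.nn as nn

# Keyword names match the __init__ signature shown above; importing from
# mlx.nn assumes the class is re-exported there (otherwise import it from
# mlx.nn.layers.transformer).
encoder = nn.TransformerEncoder(
    num_layers=2, dims=64, num_heads=4, mlp_dims=128, dropout=0.0
)
x = mx.random.normal((16, 64))  # (sequence_length, dims), illustrative shape
y = encoder(x, None)  # assumes an (x, mask) call signature with mask allowed to be None
print(y.shape)  # expected: (16, 64)
```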
@@ -231,11 +240,20 @@ class TransformerDecoderLayer(Module):
 
 class TransformerDecoder(Module):
     def __init__(
-        self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None, dropout: float = 0.0, norm_first: bool = False, activation = relu
+        self,
+        num_layers: int,
+        dims: int,
+        num_heads: int,
+        mlp_dims: Optional[int] = None,
+        dropout: float = 0.0,
+        norm_first: bool = False,
+        activation=relu,
     ):
         super().__init__()
         self.layers = [
-            TransformerDecoderLayer(dims, num_heads, mlp_dims, dropout, norm_first, activation)
+            TransformerDecoderLayer(
+                dims, num_heads, mlp_dims, dropout, norm_first, activation
+            )
             for i in range(num_layers)
         ]
         self.ln = LayerNorm(dims)
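The decoder stack gets the identical reformatting. A companion sketch with the same caveats; here the (x, memory, x_mask, memory_mask) call order is an assumption about this module, not something visible in the hunk:

```python
import mlx.core as mx
import mlx.nn as nn

# Constructor keywords match the __init__ signature shown above.
decoder = nn.TransformerDecoder(num_layers=2, dims=64, num_heads=4)
tgt = mx.random.normal((8, 64))      # decoder input, (target_length, dims)
memory = mx.random.normal((16, 64))  # encoder output, (source_length, dims)
# Assumed call order: (x, memory, x_mask, memory_mask); masks omitted here.
out = decoder(tgt, memory, None, None)
print(out.shape)  # expected: (8, 64)
```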
@@ -270,6 +288,7 @@ class Transformer(Module):
         norm_first (bool): if ``True``, encoder and decoder layers will perform LayerNorms before
             other attention and feedforward operations, otherwise after. Default is``False``.
     """
+
     def __init__(
         self,
         dims: int = 512,
@@ -281,21 +300,33 @@ class Transformer(Module):
         activation: Callable[[Any], Any] = relu,
         custom_encoder: Optional[Any] = None,
         custom_decoder: Optional[Any] = None,
-        norm_first: bool = False
+        norm_first: bool = False,
     ):
         super().__init__()
         if custom_encoder is not None:
             self.encoder = custom_encoder
         else:
             self.encoder = TransformerEncoder(
-                num_encoder_layers, dims, num_heads, mlp_dims, dropout, activation, norm_first
+                num_encoder_layers,
+                dims,
+                num_heads,
+                mlp_dims,
+                dropout,
+                activation,
+                norm_first,
             )
 
         if custom_decoder is not None:
             self.decoder = custom_decoder
         else:
             self.decoder = TransformerDecoder(
-                num_decoder_layers, dims, num_heads, mlp_dims, dropout, activation, norm_first
+                num_decoder_layers,
+                dims,
+                num_heads,
+                mlp_dims,
+                dropout,
+                activation,
+                norm_first,
             )
 
     def __call__(self, src, tgt, src_mask, tgt_mask, memory_mask):
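One thing worth flagging while reading this hunk: the encoder and decoder constructors defined earlier take (..., dropout, norm_first, activation), but `Transformer` forwards (..., dropout, activation, norm_first) positionally, so the last two arguments appear to be swapped; that is worth verifying against the layer definitions outside this hunk, since the formatting pass preserves it unchanged. Assuming the forwarding is, or gets, corrected, the `__call__` signature in the last context line gives enough for a minimal end-to-end sketch:

```python
import mlx.core as mx
import mlx.nn as nn

# Parameter names come from the Transformer signature shown in the diff.
model = nn.Transformer(
    dims=64, num_heads=4, num_encoder_layers=2, num_decoder_layers=2
)
src = mx.random.normal((16, 64))  # (source_length, dims), illustrative shape
tgt = mx.random.normal((8, 64))   # (target_length, dims)
# __call__ takes (src, tgt, src_mask, tgt_mask, memory_mask) per the diff;
# None stands in for "no mask" here, which is an assumption.
out = model(src, tgt, None, None, None)
print(out.shape)  # expected: (8, 64)
```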